├── .github └── workflows │ └── black.yml ├── .gitignore ├── README.md ├── aimless ├── __init__.py ├── augment.py ├── lightning │ ├── __init__.py │ ├── freq_mask.py │ └── waveform.py ├── loss │ ├── __init__.py │ ├── freq.py │ └── time.py ├── models │ ├── __init__.py │ ├── band_split_rnn.py │ ├── demucs_split.py │ ├── rel_tdemucs.py │ └── xumx.py └── utils │ ├── __init__.py │ ├── mwf.py │ └── utils.py ├── cfg ├── cdx_a │ ├── bandsplit_rnn.yaml │ └── hdemucs.yaml ├── demucs.yaml ├── hdemucs.yaml ├── mdx_a │ └── hdemucs.yaml ├── speech_enhance.yaml └── xumx.yaml ├── data ├── __init__.py ├── augment.py ├── dataset │ ├── __init__.py │ ├── base.py │ ├── dnr.py │ ├── fast_musdb.py │ ├── label_noise_bleed.py │ └── speech.py └── lightning │ ├── __init__.py │ ├── bleed.py │ ├── dnr.py │ ├── label_noise.csv │ ├── label_noise.py │ ├── musdb.py │ └── speech.py ├── docs └── aimless-logo-crop.svg ├── environment.yml ├── main.py ├── scripts ├── audio_utils.py ├── convert_to_dnr.py ├── dataset_split_and_mix.py ├── download_youtube.py ├── musdb_to_voiceless.py ├── urls.txt └── webapp.py └── setup.py /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: psf/black@stable 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### CUDA template 2 | *.i 3 | *.ii 4 | *.gpu 5 | *.ptx 6 | *.cubin 7 | *.fatbin 8 | 9 | ### VirtualEnv template 10 | # Virtualenv 11 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 12 | .Python 13 | [Bb]in 14 | [Ii]nclude 15 | [Ll]ib 16 | [Ll]ib64 17 | [Ll]ocal 18 | pyvenv.cfg 19 | .venv 20 | pip-selfcheck.json 21 | 22 | ### Linux template 23 | *~ 24 | 25 | # temporary files which can be created if a process still has a handle open of a deleted file 26 | .fuse_hidden* 27 | 28 | # KDE directory preferences 29 | .directory 30 | 31 | # Linux trash folder which might appear on any partition or disk 32 | .Trash-* 33 | 34 | # .nfs files are created when an open file is removed but is still being accessed 35 | .nfs* 36 | 37 | ### JupyterNotebooks template 38 | # gitignore template for Jupyter Notebooks 39 | # website: http://jupyter.org/ 40 | 41 | .ipynb_checkpoints 42 | */.ipynb_checkpoints/* 43 | 44 | # IPython 45 | profile_default/ 46 | ipython_config.py 47 | 48 | # Remove previous ipynb_checkpoints 49 | # git rm -r .ipynb_checkpoints/ 50 | 51 | ### macOS template 52 | # General 53 | .AppleDouble 54 | .DS_Store 55 | .LSOverride 56 | 57 | # Icon must end with two \r 58 | Icon 59 | 60 | # Thumbnails 61 | ._* 62 | 63 | # Files that might appear in the root of a volume 64 | .DocumentRevisions-V100 65 | .fseventsd 66 | .Spotlight-V100 67 | .TemporaryItems 68 | .Trashes 69 | .VolumeIcon.icns 70 | .com.apple.timemachine.donotpresent 71 | 72 | # Directories potentially created on remote AFP share 73 | .AppleDB 74 | .AppleDesktop 75 | Network Trash Folder 76 | Temporary Items 77 | .apdisk 78 | 79 | ### Python template 80 | # Byte-compiled / optimized / DLL files 81 | __pycache__/ 82 | *.py[cod] 83 | *$py.class 84 | 85 | # C extensions 86 | *.so 87 | 88 | # Distribution / packaging 89 | develop-eggs/ 90 | downloads/ 91 | eggs/ 92 | .eggs/ 93 | lib/ 94 | lib64/ 95 | parts/ 96 | sdist/ 97 | var/ 98 | wheels/ 99 | share/python-wheels/ 100 | *.egg-info/ 101 | 
.installed.cfg 102 | *.egg 103 | MANIFEST 104 | 105 | # PyInstaller 106 | # Usually these files are written by a python script from a template 107 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 108 | *.manifest 109 | *.spec 110 | 111 | # Installer logs 112 | pip-log.txt 113 | pip-delete-this-directory.txt 114 | 115 | # Unit test / coverage reports 116 | htmlcov/ 117 | .tox/ 118 | .nox/ 119 | .coverage 120 | .coverage.* 121 | .cache 122 | nosetests.xml 123 | coverage.xml 124 | *.cover 125 | *.py,cover 126 | .hypothesis/ 127 | .pytest_cache/ 128 | cover/ 129 | 130 | # Translations 131 | *.mo 132 | *.pot 133 | 134 | # Django stuff: 135 | *.log 136 | local_settings.py 137 | db.sqlite3 138 | db.sqlite3-journal 139 | 140 | # Flask stuff: 141 | instance/ 142 | .webassets-cache 143 | 144 | # Scrapy stuff: 145 | .scrapy 146 | 147 | # Sphinx documentation 148 | docs/_build/ 149 | 150 | # PyBuilder 151 | .pybuilder/ 152 | target/ 153 | 154 | # Jupyter Notebook 155 | 156 | # IPython 157 | 158 | # pyenv 159 | # For a library or package, you might want to ignore these files since the code is 160 | # intended to run in multiple environments; otherwise, check them in: 161 | # .python-version 162 | 163 | # pipenv 164 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 165 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 166 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 167 | # install all needed dependencies. 168 | #Pipfile.lock 169 | 170 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 171 | __pypackages__/ 172 | 173 | # Celery stuff 174 | celerybeat-schedule 175 | celerybeat.pid 176 | 177 | # SageMath parsed files 178 | *.sage.py 179 | 180 | # Environments 181 | .env 182 | env/ 183 | venv/ 184 | ENV/ 185 | env.bak/ 186 | venv.bak/ 187 | 188 | # Spyder project settings 189 | .spyderproject 190 | .spyproject 191 | 192 | # Rope project settings 193 | .ropeproject 194 | 195 | # mkdocs documentation 196 | /site 197 | 198 | # mypy 199 | .mypy_cache/ 200 | .dmypy.json 201 | dmypy.json 202 | 203 | # Pyre type checker 204 | .pyre/ 205 | 206 | # pytype static type analyzer 207 | .pytype/ 208 | 209 | # Cython debug symbols 210 | cython_debug/ 211 | 212 | # Custom 213 | lightning_logs/* 214 | *.pyc 215 | build/ 216 | dist/ 217 | .idea/ 218 | out/ 219 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | 5 |
6 | 
7 | AIMLESS (Artificial Intelligence and Music League for Effective Source Separation) is a special interest group in audio source separation at C4DM, consisting of PhD students from the AIM CDT program. 
8 | This repository is adapted from [Danna-Sep](https://github.com/yoyololicon/music-demixing-challenge-ismir-2021-entry) and 
9 | contains our training code for the [SDX23 Sound Demixing Challenge](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023). 
10 | 
11 | 
12 | ## Quick Start 
13 | 
14 | The conda environment we used for training is described in `environment.yml`. 
15 | The commands below should run as-is if you're using the QMUL EECS servers. 
16 | To run them on your local machine, change the `root` param in the config to the directory where you downloaded the MUSDB18-HQ dataset. 
17 | 
18 | ### Frequency-masking-based model 
19 | 
20 | 
21 | ```commandline 
22 | python main.py fit --config cfg/xumx.yaml 
23 | ``` 
24 | 
25 | ### Raw-waveform-based model 
26 | 
27 | 
28 | ```commandline 
29 | python main.py fit --config cfg/demucs.yaml 
30 | ``` 
31 | 
32 | ## Install the repository as a package 
33 | 
34 | This step is required if you want to test our submission repositories (see section [Reproduce the winning submission](#reproduce-the-winning-submission)) locally. 
35 | ```sh 
36 | pip install git+https://github.com/aim-qmul/sdx23-aimless 
37 | ``` 
38 | 
39 | ## Reproduce the winning submission 
40 | 
41 | ### CDX Leaderboard A, submission ID 220319 
42 | 
43 | This section describes how to reproduce the [best-performing model](https://gitlab.aicrowd.com/yoyololicon/cdx-submissions/-/issues/90) we used on CDX leaderboard A. 
44 | The submission consists of one HDemucs predicting all the targets and one BandSplitRNN predicting the music from the mixture. 
45 | 
46 | To train the HDemucs: 
47 | ```commandline 
48 | python main.py fit --config cfg/cdx_a/hdemucs.yaml --data.init_args.root /DNR_DATASET_ROOT/dnr_v2/ 
49 | ``` 
50 | Remember to change `/DNR_DATASET_ROOT/dnr_v2/` to the download location of the [Divide and Remaster (DnR) dataset](https://zenodo.org/record/6949108). 
51 | 
52 | To train the BandSplitRNN: 
53 | ```commandline 
54 | python main.py fit --config cfg/cdx_a/bandsplit_rnn.yaml --data.init_args.root /DNR_DATASET_ROOT/dnr_v2/ 
55 | ``` 
56 | 
57 | We trained the models on no more than 4 GPUs, depending on the resources available at the time. 
58 | 
59 | After training, please go to our [submission repository](https://gitlab.aicrowd.com/yoyololicon/cdx-submissions/). 
60 | Then, copy the last checkpoint of HDemucs (usually located at `lightning_logs/version_**/checkpoints/last.ckpt`) to `my_submission/lightning_logs/hdemucs-64-sdr/checkpoints/last.ckpt` in the submission repository. 
61 | Similarly, copy the last checkpoint of BandSplitRNN to `my_submission/lightning_logs/bandsplitRNN-music/checkpoints/last.ckpt`. 
62 | After these steps, you have reproduced our submission! 
63 | 
64 | The inference procedure in our submission repository is a bit more involved. 
65 | Briefly, the HDemucs predicts the targets independently for each channel of the stereo mixture, as well as for the average (the mid) and the difference (the side) of the two channels. 
66 | The stereo separated sources are then formed as a linear combination of these mono predictions. 
67 | The separated music from the BandSplitRNN is enhanced by Wiener filtering, and the final music prediction is the average of the two models' outputs. 
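
A minimal sketch of this combination step is given below. The callables `hdemucs_separate` (mono in, per-target mono estimates out) and `bsrnn_separate_music` (stereo in, Wiener-filtered stereo music out) are hypothetical stand-ins for the two models' inference, and the 50/50 blending weights and the position of the music target are illustrative assumptions, not the exact coefficients used in the submission repository.

```python
import torch
from typing import Callable


def combine_cdx_a(
    mix: torch.Tensor,  # (2, samples) stereo mixture
    hdemucs_separate: Callable[[torch.Tensor], torch.Tensor],  # (1, T) -> (n_targets, 1, T)
    bsrnn_separate_music: Callable[[torch.Tensor], torch.Tensor],  # (2, T) -> (2, T), Wiener-filtered
    music_idx: int = 0,  # assumed index of the music target
) -> torch.Tensor:
    left, right = mix[0:1], mix[1:2]
    mid = 0.5 * (left + right)   # average of the two channels
    side = 0.5 * (left - right)  # difference of the two channels

    # Run HDemucs on the four mono signals independently.
    est = {k: hdemucs_separate(x) for k, x in (("l", left), ("r", right), ("m", mid), ("s", side))}

    # One plausible linear recombination into stereo: blend the direct per-channel
    # estimates with the mid/side reconstruction (mid + side ~ left, mid - side ~ right).
    stereo_left = 0.5 * est["l"] + 0.5 * (est["m"] + est["s"])
    stereo_right = 0.5 * est["r"] + 0.5 * (est["m"] - est["s"])
    targets = torch.cat([stereo_left, stereo_right], dim=1)  # (n_targets, 2, samples)

    # The final music prediction averages HDemucs with the BandSplitRNN estimate.
    targets[music_idx] = 0.5 * (targets[music_idx] + bsrnn_separate_music(mix))
    return targets
```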
68 | 
69 | ### MDX Leaderboard A (Label Noise), submission ID 220426 
70 | 
71 | This section describes how to reproduce the [best-performing model](https://gitlab.aicrowd.com/yoyololicon/mdx23-submissions/-/issues/76) we used on MDX leaderboard A. 
72 | 
73 | First, we manually inspected the [label noise dataset](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023/problems/music-demixing-track-mdx-23/dataset_files) (thanks @mhrice for the hard work!) and labeled the clean songs (no label noise). 
74 | The labels are recorded in `data/lightning/label_noise.csv`. 
75 | Then, an HDemucs was trained only on the clean-labeled songs with the following settings: 
76 | 
77 | * Negative SDR as the loss function 
78 | * Training on random chunks and random stem combinations of the clean songs 
79 | * Training batches augmented and processed with different random effects 
80 | * Because of all this randomization, validation is also done on the training dataset (no separate validation set) 
81 | 
82 | To reproduce the training: 
83 | ```commandline 
84 | python main.py fit --config cfg/mdx_a/hdemucs.yaml --data.init_args.root /DATASET_ROOT/ 
85 | ``` 
86 | Remember to place the label noise data under `/DATASET_ROOT/train/`. 
87 | 
88 | Other details: 
89 | * The model is trained for ~800 epochs (approx. 2 weeks on 4 RTX A5000 GPUs) 
90 | * During the last ~200 epochs, the learning rate is reduced to 0.001, gradient accumulation is increased to 64, and the effect randomization chance is increased by a factor of ~1.67 (e.g. from 30% to 50%) 
91 | 
92 | After training, please go to our [submission repository](https://gitlab.aicrowd.com/yoyololicon/mdx23-submissions/). 
93 | Then, copy the checkpoint to `my_submission/acc64_4devices_lr0001_e1213_last.ckpt` in the submission repository. 
94 | After these steps, you have reproduced our submission! 
95 | 
96 | 
97 | ## Structure 
98 | 
99 | * `aimless`: package root, which can be imported for submission. 
100 |   * `loss`: loss functions. 
101 |     * `freq.*`: loss functions for frequency-domain models. 
102 |     * `time.*`: loss functions for time-domain models. 
103 |   * `augment`: data augmentations that run better on GPU. 
104 |   * `lightning`: all lightning modules. 
105 |     * `waveform.WaveformSeparator`: trainer for time-domain models. 
106 |     * `freq_mask.MaskPredictor`: trainer for frequency-domain models. 
107 |   * `models`: your custom models. 
108 | * `cfg`: all config files. 
109 | * `data`: 
110 |   * `dataset`: custom PyTorch datasets. 
111 |   * `lightning`: all lightning data modules. 
112 |   * `augment`: data augmentations that run better on CPU. 
113 | 
114 | ## Streamlit 
115 | 
116 | Split a song in the browser with the pretrained Hybrid Demucs. 
117 | 
118 | ``` streamlit run scripts/webapp.py ``` 
119 | 
120 | Then open [http://localhost:8501/](http://localhost:8501/) in your browser. 
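
For reference, separating a file offline with a pretrained Hybrid Demucs can be sketched as below. This is a minimal sketch assuming the `HDEMUCS_HIGH_MUSDB_PLUS` pipeline shipped with torchaudio and a hard-coded stem order; `scripts/webapp.py` may load and run its model differently, and chunked inference plus the input normalization used in the torchaudio tutorial are omitted for brevity.

```python
import torch
import torchaudio
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS

bundle = HDEMUCS_HIGH_MUSDB_PLUS
model = bundle.get_model().eval()
sr = bundle.sample_rate  # 44100 Hz

wav, in_sr = torchaudio.load("song.wav")  # (channels, samples)
if in_sr != sr:
    wav = torchaudio.functional.resample(wav, in_sr, sr)

with torch.no_grad():
    # Single whole-song forward pass; long tracks should be processed in chunks.
    stems = model(wav.unsqueeze(0)).squeeze(0)  # (n_sources, channels, samples)

for name, stem in zip(["drums", "bass", "other", "vocals"], stems):  # assumed stem order
    torchaudio.save(f"{name}.wav", stem.cpu(), sr)
```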
121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /aimless/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/aimless/__init__.py -------------------------------------------------------------------------------- /aimless/augment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import torchaudio 5 | from torchaudio.transforms import TimeStretch, Spectrogram, InverseSpectrogram, Resample 6 | from torchaudio import functional as aF 7 | from torch_fftconv import fft_conv1d 8 | from pathlib import Path 9 | 10 | __all__ = ["SpeedPerturb", "RandomPitch", "RandomConvolutions", "CudaBase"] 11 | 12 | 13 | class CudaBase(nn.Module): 14 | def __init__(self, rand_size, p=0.2): 15 | super().__init__() 16 | self.p = p 17 | self.rand_size = rand_size 18 | 19 | def _transform(self, stems, index): 20 | """ 21 | Args: 22 | stems (torch.Tensor): (B, Num_channels, L) 23 | index (int): index of random transform 24 | Return: 25 | perturbed_stems (torch.Tensor): (B, Num_channels, L') 26 | """ 27 | raise NotImplementedError 28 | 29 | def forward(self, stems: torch.Tensor): 30 | """ 31 | Args: 32 | stems (torch.Tensor): (B, Num_sources, Num_channels, L) 33 | Return: 34 | perturbed_stems (torch.Tensor): (B, Num_sources, Num_channels, L') 35 | """ 36 | shape = stems.shape 37 | orig_len = shape[-1] 38 | stems = stems.view(-1, *shape[-2:]) 39 | select_mask = torch.rand(stems.shape[0], device=stems.device) < self.p 40 | if not torch.any(select_mask): 41 | return stems.view(*shape) 42 | 43 | select_idx = torch.where(select_mask)[0] 44 | perturbed_stems = torch.zeros_like(stems) 45 | perturbed_stems[~select_mask] = stems[~select_mask] 46 | selected_stems = stems[select_mask] 47 | rand_idx = torch.randint( 48 | self.rand_size, (selected_stems.shape[0],), device=stems.device 49 | ) 50 | 51 | for i in range(self.rand_size): 52 | mask = rand_idx == i 53 | if not torch.any(mask): 54 | continue 55 | masked_stems = selected_stems[mask] 56 | perturbed_audio = self._transform(masked_stems, i).to(perturbed_stems.dtype) 57 | 58 | diff = perturbed_audio.shape[-1] - orig_len 59 | 60 | put_idx = select_idx[mask] 61 | if diff >= 0: 62 | perturbed_stems[put_idx] = perturbed_audio[..., :orig_len] 63 | else: 64 | perturbed_stems[put_idx, :, : orig_len + diff] = perturbed_audio 65 | 66 | perturbed_stems = perturbed_stems.view(*shape) 67 | return perturbed_stems 68 | 69 | 70 | class SpeedPerturb(CudaBase): 71 | def __init__(self, orig_freq=44100, speeds=[90, 100, 110], **kwargs): 72 | super().__init__(len(speeds), **kwargs) 73 | self.orig_freq = orig_freq 74 | self.resamplers = nn.ModuleList() 75 | self.speeds = speeds 76 | for s in self.speeds: 77 | new_freq = self.orig_freq * s // 100 78 | self.resamplers.append(Resample(self.orig_freq, new_freq)) 79 | 80 | def _transform(self, stems, index): 81 | y = self.resamplers[index](stems.view(-1, stems.shape[-1])).view( 82 | *stems.shape[:-1], -1 83 | ) 84 | return y 85 | 86 | 87 | class RandomPitch(CudaBase): 88 | def __init__( 89 | self, semitones=[-2, -1, 0, 1, 2], n_fft=2048, hop_length=512, **kwargs 90 | ): 91 | super().__init__(len(semitones), **kwargs) 92 | self.resamplers = nn.ModuleList() 93 | 94 | semitones = torch.tensor(semitones, dtype=torch.float32) 95 | rates = 2 ** 
(-semitones / 12) 96 | rrates = rates.reciprocal() 97 | rrates = (rrates * 100).long() 98 | rrates[rrates % 2 == 1] += 1 99 | rates = 100 / rrates 100 | 101 | self.register_buffer("rates", rates) 102 | self.spec = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None) 103 | self.inv_spec = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length) 104 | self.stretcher = TimeStretch(hop_length, n_freq=n_fft // 2 + 1) 105 | 106 | for rr in rrates.tolist(): 107 | self.resamplers.append(Resample(rr, 100)) 108 | 109 | def _transform(self, stems, index): 110 | spec = self.spec(stems) 111 | stretched_spec = self.stretcher(spec, self.rates[index]) 112 | stretched_stems = self.inv_spec(stretched_spec) 113 | shifted_stems = self.resamplers[index]( 114 | stretched_stems.view(-1, stretched_stems.shape[-1]) 115 | ).view(*stretched_stems.shape[:-1], -1) 116 | return shifted_stems 117 | 118 | 119 | class RandomConvolutions(CudaBase): 120 | def __init__(self, target_sr: int, ir_folder: str, **kwargs): 121 | ir_folder = Path(ir_folder) 122 | ir_files = list(ir_folder.glob("**/*.wav")) 123 | impulses = [] 124 | for ir_file in ir_files: 125 | ir, sr = torchaudio.load(ir_file) 126 | if ir.shape[0] > 2: 127 | continue 128 | if sr != target_sr: 129 | ir = aF.resample(ir, sr, target_sr) 130 | if ir.shape[0] == 1: 131 | ir = ir.repeat(2, 1) 132 | impulses.append(ir) 133 | 134 | super().__init__(len(impulses), **kwargs) 135 | for i, impulse in enumerate(impulses): 136 | self.register_buffer(f"impulse_{i}", impulse) 137 | 138 | def _transform(self, stems, index): 139 | ir = self.get_buffer(f"impulse_{index}").unsqueeze(1) 140 | ir_flipped = ir.flip(-1) 141 | padded_stems = F.pad(stems, (ir.shape[-1] - 1, 0)) 142 | # TODO: dynamically use F.conv1d if impulse is short 143 | convolved_stems = fft_conv1d(padded_stems, ir_flipped, groups=2) 144 | return convolved_stems 145 | -------------------------------------------------------------------------------- /aimless/lightning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/aimless/lightning/__init__.py -------------------------------------------------------------------------------- /aimless/lightning/freq_mask.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | from typing import List, Dict 5 | from torchaudio.transforms import Spectrogram, InverseSpectrogram 6 | 7 | from ..loss.time import SDR 8 | from ..loss.freq import FLoss 9 | from ..augment import CudaBase 10 | 11 | from ..utils import MWF, MDX_SOURCES, SDX_SOURCES, SE_SOURCES 12 | 13 | 14 | class MaskPredictor(pl.LightningModule): 15 | def __init__( 16 | self, 17 | model: nn.Module, 18 | criterion: FLoss, 19 | transforms: List[CudaBase] = None, 20 | target_track: str = None, 21 | targets: Dict[str, None] = {}, 22 | n_fft: int = 4096, 23 | hop_length: int = 1024, 24 | **mwf_kwargs, 25 | ): 26 | super().__init__() 27 | 28 | self.model = model 29 | self.criterion = criterion 30 | self.sdr = SDR() 31 | self.mwf = MWF(**mwf_kwargs) 32 | self.spec = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None) 33 | self.inv_spec = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length) 34 | 35 | if transforms is None: 36 | transforms = [] 37 | 38 | self.transforms = nn.Sequential(*transforms) 39 | if target_track == "sdx": 40 | self.sources = SDX_SOURCES 41 | elif 
target_track == "mdx": 42 | self.sources = MDX_SOURCES 43 | elif target_track == "se": 44 | self.sources = SE_SOURCES 45 | else: 46 | raise ValueError(f"Invalid target track: {target_track}") 47 | self.register_buffer( 48 | "targets_idx", 49 | torch.tensor(sorted([self.sources.index(target) for target in targets])), 50 | ) 51 | 52 | def forward(self, x): 53 | X = self.spec(x) 54 | X_mag = X.abs() 55 | pred_mask = self.model(X_mag) 56 | Y = self.mwf(pred_mask, X) 57 | pred = self.inv_spec(Y) 58 | return pred 59 | 60 | def training_step(self, batch, batch_idx): 61 | x, y = batch 62 | if len(self.transforms) > 0: 63 | y = self.transforms(y) 64 | x = y.sum(1) 65 | y = y[:, self.targets_idx].squeeze(1) 66 | 67 | X = self.spec(x) 68 | Y = self.spec(y) 69 | X_mag = X.abs() 70 | pred_mask = self.model(X_mag) 71 | loss, values = self.criterion(pred_mask, Y, X, y, x) 72 | 73 | values["loss"] = loss 74 | self.log_dict(values, prog_bar=False, sync_dist=True) 75 | return loss 76 | 77 | def validation_step(self, batch, batch_idx): 78 | x, y = batch 79 | y = y[:, self.targets_idx].squeeze(1) 80 | 81 | X = self.spec(x) 82 | Y = self.spec(y) 83 | X_mag = X.abs() 84 | pred_mask = self.model(X_mag) 85 | loss, values = self.criterion(pred_mask, Y, X, y, x) 86 | 87 | pred = self.inv_spec(self.mwf(pred_mask, X)) 88 | 89 | batch = pred.shape[0] 90 | sdrs = ( 91 | self.sdr(pred.view(-1, *pred.shape[-2:]), y.view(-1, *y.shape[-2:])) 92 | .view(batch, -1) 93 | .mean(0) 94 | ) 95 | 96 | for i, t in enumerate(self.targets_idx): 97 | values[f"{self.sources[t]}_sdr"] = sdrs[i].item() 98 | values["avg_sdr"] = sdrs.mean().item() 99 | return loss, values 100 | 101 | def validation_epoch_end(self, outputs) -> None: 102 | avg_loss = sum(x[0] for x in outputs) / len(outputs) 103 | avg_values = {} 104 | for k in outputs[0][1].keys(): 105 | avg_values[k] = sum(x[1][k] for x in outputs) / len(outputs) 106 | 107 | self.log("val_loss", avg_loss, prog_bar=True, sync_dist=True) 108 | self.log_dict(avg_values, prog_bar=False, sync_dist=True) 109 | -------------------------------------------------------------------------------- /aimless/lightning/waveform.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | from typing import List, Dict 5 | 6 | from ..loss.time import TLoss, SDR 7 | from ..augment import CudaBase 8 | 9 | from ..utils import MDX_SOURCES, SDX_SOURCES, SE_SOURCES 10 | 11 | 12 | class WaveformSeparator(pl.LightningModule): 13 | def __init__( 14 | self, 15 | model: nn.Module, 16 | criterion: TLoss, 17 | transforms: List[CudaBase] = None, 18 | use_sdx_targets: bool = False, # will be deprecated, please use target_track 19 | target_track: str = None, 20 | targets: Dict[str, None] = {}, 21 | ): 22 | super().__init__() 23 | 24 | self.model = model 25 | self.criterion = criterion 26 | self.sdr = SDR() 27 | 28 | if transforms is None: 29 | transforms = [] 30 | 31 | if target_track is not None: 32 | if target_track == "sdx": 33 | self.sources = SDX_SOURCES 34 | elif target_track == "mdx": 35 | self.sources = MDX_SOURCES 36 | elif target_track == "se": 37 | self.sources = SE_SOURCES 38 | else: 39 | raise ValueError(f"Invalid target track: {target_track}") 40 | else: 41 | self.sources = SDX_SOURCES if use_sdx_targets else MDX_SOURCES 42 | 43 | self.transforms = nn.Sequential(*transforms) 44 | self.register_buffer( 45 | "targets_idx", 46 | torch.tensor(sorted([self.sources.index(target) for target in targets])), 47 | ) 
48 | 49 | def forward(self, x): 50 | return self.model(x) 51 | 52 | def training_step(self, batch, batch_idx): 53 | x, y = batch 54 | if len(self.transforms) > 0: 55 | y = self.transforms(y) 56 | x = y.sum(1) 57 | y = y[:, self.targets_idx].squeeze(1) 58 | 59 | pred = self.model(x) 60 | if pred.ndim == 4: 61 | pred = pred.squeeze(1) 62 | loss, values = self.criterion(pred, y, x) 63 | 64 | values["loss"] = loss 65 | self.log_dict(values, prog_bar=False, sync_dist=True) 66 | return loss 67 | 68 | def validation_step(self, batch, batch_idx): 69 | x, y = batch 70 | y = y[:, self.targets_idx].squeeze(1) 71 | 72 | pred = self.model(x) 73 | loss, values = self.criterion(pred, y, x) 74 | 75 | batch = pred.shape[0] 76 | sdrs = ( 77 | self.sdr(pred.view(-1, *pred.shape[-2:]), y.view(-1, *y.shape[-2:])) 78 | .view(batch, -1) 79 | .mean(0) 80 | ) 81 | 82 | for i, t in enumerate(self.targets_idx): 83 | values[f"{self.sources[t]}_sdr"] = sdrs[i].item() 84 | values["avg_sdr"] = sdrs.mean().item() 85 | return loss, values 86 | 87 | def validation_epoch_end(self, outputs) -> None: 88 | avg_loss = sum(x[0] for x in outputs) / len(outputs) 89 | avg_values = {} 90 | for k in outputs[0][1].keys(): 91 | avg_values[k] = sum(x[1][k] for x in outputs) / len(outputs) 92 | 93 | self.log("val_loss", avg_loss, prog_bar=True, sync_dist=True) 94 | self.log_dict(avg_values, prog_bar=False, sync_dist=True) 95 | -------------------------------------------------------------------------------- /aimless/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/aimless/loss/__init__.py -------------------------------------------------------------------------------- /aimless/loss/freq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from itertools import combinations, chain 4 | from ..utils import MWF 5 | from torchaudio.transforms import InverseSpectrogram 6 | 7 | 8 | class FLoss(torch.nn.Module): 9 | r"""Base class for frequency domain loss modules. 10 | You can't use this module directly. 11 | Your loss should also subclass this class. 12 | Args: 13 | msk_hat (Tensor): mask tensor with size (batch, *, channels, bins, frames), `*` is an optional multi targets dimension. 14 | The tensor value should always be non-negative 15 | gt_spec (Tensor): target spectrogram complex tensor with size (batch, *, channels, bins, frames), `*` is an optional multi targets dimension. 16 | mix_spec (Tensor): mixture spectrogram complex tensor with size (batch, channels, bins, frames) 17 | gt (Tensor): target time signal tensor with size (batch, *, channels, samples), `*` is an optional multi targets dimension. 
18 | mix (Tensor): mixture time signal tensor with size (batch, channels, samples) 19 | Returns: 20 | tuple: a length-2 tuple with the first element is the final loss tensor, 21 | and the second is a dict containing any intermediate loss value you want to monitor 22 | """ 23 | 24 | def forward(self, *args, **kwargs): 25 | return self._core_loss(*args, **kwargs) 26 | 27 | def _core_loss(self, msk_hat, gt_spec, mix_spec, gt, mix): 28 | raise NotImplementedError 29 | 30 | 31 | class CLoss(FLoss): 32 | def __init__( 33 | self, mcoeff=10, n_fft=4096, hop_length=1024, complex_mse=True, **mwf_kwargs 34 | ): 35 | super().__init__() 36 | self.mcoeff = mcoeff 37 | self.inv_spec = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length) 38 | self.complex_mse = complex_mse 39 | if len(mwf_kwargs): 40 | self.mwf = MWF(**mwf_kwargs) 41 | 42 | def _core_loss(self, msk_hat, gt_spec, mix_spec, gt, mix): 43 | if hasattr(self, "mwf"): 44 | Y = self.mwf(msk_hat, mix_spec) 45 | else: 46 | Y = msk_hat * mix_spec.unsqueeze(1) 47 | pred = self.inv_spec(Y) 48 | if self.complex_mse: 49 | loss_f = complex_mse_loss(Y, gt_spec) 50 | else: 51 | loss_f = real_mse_loss(msk_hat, gt_spec.abs(), mix_spec.abs()) 52 | loss_t = sdr_loss(pred, gt, mix) 53 | loss = loss_f + self.mcoeff * loss_t 54 | return loss, {"loss_f": loss_f.item(), "loss_t": loss_t.item()} 55 | 56 | 57 | class MDLoss(FLoss): 58 | def __init__(self, mcoeff=10, n_fft=4096, hop_length=1024): 59 | super().__init__() 60 | self.mcoeff = mcoeff 61 | self.inv_spec = InverseSpectrogram(n_fft, hop_length=hop_length) 62 | 63 | def _core_loss(self, msk_hat, gt_spec, mix_spec, gt, mix): 64 | pred_spec = msk_hat * mix_spec 65 | diff = pred_spec - gt_spec 66 | 67 | real = diff.real.reshape(-1) 68 | imag = diff.imag.reshape(-1) 69 | mse = real @ real + imag @ imag 70 | loss_f = mse / real.numel() 71 | 72 | pred = self.inv_spec(pred_spec) 73 | batch_size, n_channels, length = pred.shape 74 | 75 | # Fix Length 76 | mix = mix[..., :length].reshape(-1, length) 77 | gt = gt[..., :length].reshape(-1, length) 78 | pred = pred.view(-1, length) 79 | 80 | loss_t = _sdr_loss_core(pred, gt, mix) + 1.0 81 | loss = loss_f + self.mcoeff * loss_t 82 | return loss, {"loss_f": loss_f.item(), "loss_t": loss_t.item()} 83 | 84 | 85 | def bce_loss(msk_hat, gt_spec): 86 | assert msk_hat.shape == gt_spec.shape 87 | loss = [] 88 | gt_spec_power = gt_spec.abs() 89 | gt_spec_power *= gt_spec_power 90 | divider = gt_spec_power.sum(1) + 1e-10 91 | for c in chain( 92 | combinations(range(4), 1), combinations(range(4), 2), combinations(range(4), 3) 93 | ): 94 | m = sum([msk_hat[:, i] for i in c]) 95 | gt = sum([gt_spec_power[:, i] for i in c]) / divider 96 | loss.append(F.binary_cross_entropy(m, gt)) 97 | 98 | # All 14 Combination Losses (4C1 + 4C2 + 4C3) 99 | loss_mse = sum(loss) / len(loss) 100 | return loss_mse 101 | 102 | 103 | def complex_mse_loss(y_hat: torch.Tensor, gt_spec: torch.Tensor): 104 | assert y_hat.shape == gt_spec.shape 105 | assert gt_spec.is_complex() and y_hat.is_complex() 106 | 107 | loss = [] 108 | for c in chain( 109 | combinations(range(4), 1), combinations(range(4), 2), combinations(range(4), 3) 110 | ): 111 | m = sum([y_hat[:, i] for i in c]) 112 | gt = sum([gt_spec[:, i] for i in c]) 113 | diff = m - gt 114 | real = diff.real.reshape(-1) 115 | imag = diff.imag.reshape(-1) 116 | mse = real @ real + imag @ imag 117 | loss.append(mse / real.numel()) 118 | 119 | # All 14 Combination Losses (4C1 + 4C2 + 4C3) 120 | loss_mse = sum(loss) / len(loss) 121 | return loss_mse 122 | 123 | 
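# Note on the combination losses in this module: with 4 sources there are
# C(4,1) + C(4,2) + C(4,3) = 4 + 6 + 4 = 14 non-trivial source subsets (the
# full four-source combination is left out), and each loss averages its error
# over the summed estimates of every subset. This mirrors the multi-combination
# loss of CrossNet-Open-Unmix (X-UMX); the matching model is aimless/models/xumx.py.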
124 | def real_mse_loss(msk_hat: torch.Tensor, gt_spec: torch.Tensor, mix_spec: torch.Tensor): 125 | assert msk_hat.shape == gt_spec.shape 126 | assert ( 127 | msk_hat.is_floating_point() 128 | and gt_spec.is_floating_point() 129 | and mix_spec.is_floating_point() 130 | ) 131 | assert not gt_spec.is_complex() and not mix_spec.is_complex() 132 | 133 | loss = [] 134 | for c in chain( 135 | combinations(range(4), 1), combinations(range(4), 2), combinations(range(4), 3) 136 | ): 137 | m = sum([msk_hat[:, i] for i in c]) 138 | gt = sum([gt_spec[:, i] for i in c]) 139 | loss.append(F.mse_loss(m * mix_spec, gt)) 140 | 141 | # All 14 Combination Losses (4C1 + 4C2 + 4C3) 142 | loss_mse = sum(loss) / len(loss) 143 | return loss_mse 144 | 145 | 146 | def sdr_loss(pred, gt_time, mix): 147 | # SDR-Combination Loss 148 | 149 | batch_size, _, n_channels, length = pred.shape 150 | pred, gt_time = ( 151 | pred.transpose(0, 1).contiguous(), 152 | gt_time.transpose(0, 1).contiguous(), 153 | ) 154 | 155 | # Fix Length 156 | mix = mix[..., :length].reshape(-1, length) 157 | gt_time = gt_time[..., :length].reshape(_, -1, length) 158 | pred = pred.view(_, -1, length) 159 | 160 | extend_pred = [pred.view(-1, length)] 161 | extend_gt = [gt_time.view(-1, length)] 162 | 163 | for c in chain(combinations(range(4), 2), combinations(range(4), 3)): 164 | extend_pred.append(sum([pred[i] for i in c])) 165 | extend_gt.append(sum([gt_time[i] for i in c])) 166 | 167 | extend_pred = torch.cat(extend_pred, 0) 168 | extend_gt = torch.cat(extend_gt, 0) 169 | extend_mix = mix.repeat(14, 1) 170 | 171 | loss_sdr = _sdr_loss_core(extend_pred, extend_gt, extend_mix) 172 | 173 | return 1.0 + loss_sdr 174 | 175 | 176 | def _sdr_loss_core(x_hat, x, y): 177 | assert x.shape == y.shape == x_hat.shape # (Batch, Len) 178 | 179 | ns = y - x 180 | ns_hat = y - x_hat 181 | 182 | ns_norm = ns[:, None, :] @ ns[:, :, None] 183 | ns_hat_norm = ns_hat[:, None, :] @ ns_hat[:, :, None] 184 | 185 | x_norm = x[:, None, :] @ x[:, :, None] 186 | x_hat_norm = x_hat[:, None, :] @ x_hat[:, :, None] 187 | x_cross = x[:, None, :] @ x_hat[:, :, None] 188 | 189 | x_norm, x_hat_norm, ns_norm, ns_hat_norm = ( 190 | x_norm.relu(), 191 | x_hat_norm.relu(), 192 | ns_norm.relu(), 193 | ns_hat_norm.relu(), 194 | ) 195 | 196 | alpha = x_norm / (ns_norm + x_norm + 1e-10) 197 | 198 | # Target 199 | sdr_cln = x_cross / (x_norm.sqrt() * x_hat_norm.sqrt() + 1e-10) 200 | 201 | # Noise 202 | num_noise = ns[:, None, :] @ ns_hat[:, :, None] 203 | denom_noise = ns_norm.sqrt() * ns_hat_norm.sqrt() 204 | sdr_noise = num_noise / (denom_noise + 1e-10) 205 | 206 | return torch.mean(-alpha * sdr_cln - (1 - alpha) * sdr_noise) 207 | -------------------------------------------------------------------------------- /aimless/loss/time.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from itertools import combinations, chain 4 | 5 | 6 | class TLoss(torch.nn.Module): 7 | r"""Base class for time domain loss modules. 8 | You can't use this module directly. 9 | Your loss should also subclass this class. 10 | Args: 11 | pred (Tensor): predict time signal tensor with size (batch, *, channels, samples), `*` is an optional multi targets dimension. 12 | gt (Tensor): target time signal tensor with size (batch, *, channels, samples), `*` is an optional multi targets dimension. 
13 | mix (Tensor): mixture time signal tensor with size (batch, channels, samples) 14 | Returns: 15 | tuple: a length-2 tuple with the first element is the final loss tensor, 16 | and the second is a dict containing any intermediate loss value you want to monitor 17 | """ 18 | 19 | def forward(self, *args, **kwargs): 20 | return self._core_loss(*args, **kwargs) 21 | 22 | def _core_loss(self, pred, gt, mix): 23 | raise NotImplementedError 24 | 25 | 26 | class SDR(torch.nn.Module): 27 | def __init__(self) -> None: 28 | super().__init__() 29 | self.expr = "bi,bi->b" 30 | 31 | def _batch_dot(self, x, y): 32 | return torch.einsum(self.expr, x, y) 33 | 34 | def forward(self, estimates, references): 35 | if estimates.dtype != references.dtype: 36 | estimates = estimates.to(references.dtype) 37 | length = min(references.shape[-1], estimates.shape[-1]) 38 | references = references[..., :length].reshape(references.shape[0], -1) 39 | estimates = estimates[..., :length].reshape(estimates.shape[0], -1) 40 | 41 | delta = 1e-7 # avoid numerical errors 42 | num = self._batch_dot(references, references) 43 | den = ( 44 | num 45 | + self._batch_dot(estimates, estimates) 46 | - 2 * self._batch_dot(estimates, references) 47 | ) 48 | den = den.relu().add(delta).log10() 49 | num = num.add(delta).log10() 50 | return 10 * (num - den) 51 | 52 | 53 | class NegativeSDR(TLoss): 54 | def __init__(self) -> None: 55 | super().__init__() 56 | self.sdr = SDR() 57 | 58 | def _core_loss(self, pred, gt, mix): 59 | return -self.sdr(pred, gt).mean(), {} 60 | 61 | 62 | class CL1Loss(TLoss): 63 | def _core_loss(self, pred, gt, mix): 64 | gt = gt[..., : pred.shape[-1]] 65 | loss = [] 66 | for c in chain( 67 | combinations(range(4), 1), 68 | combinations(range(4), 2), 69 | combinations(range(4), 3), 70 | ): 71 | x = sum([pred[:, i] for i in c]) 72 | y = sum([gt[:, i] for i in c]) 73 | loss.append(F.l1_loss(x, y)) 74 | 75 | # All 14 Combination Losses (4C1 + 4C2 + 4C3) 76 | loss_l1 = sum(loss) / len(loss) 77 | return loss_l1, {} 78 | 79 | 80 | class L1Loss(TLoss): 81 | def _core_loss(self, pred, gt, mix): 82 | return F.l1_loss(pred, gt[..., : pred.shape[-1]]), {} 83 | -------------------------------------------------------------------------------- /aimless/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/aimless/models/__init__.py -------------------------------------------------------------------------------- /aimless/models/band_split_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import List, Tuple, Dict, Optional 5 | 6 | v7_freqs = [ 7 | 100, 8 | 200, 9 | 300, 10 | 400, 11 | 500, 12 | 600, 13 | 700, 14 | 800, 15 | 900, 16 | 1000, 17 | 1250, 18 | 1500, 19 | 1750, 20 | 2000, 21 | 2250, 22 | 2500, 23 | 2750, 24 | 3000, 25 | 3250, 26 | 3500, 27 | 3750, 28 | 4000, 29 | 4500, 30 | 5000, 31 | 5500, 32 | 6000, 33 | 6500, 34 | 7000, 35 | 7500, 36 | 8000, 37 | 9000, 38 | 10000, 39 | 11000, 40 | 12000, 41 | 13000, 42 | 14000, 43 | 15000, 44 | 16000, 45 | 18000, 46 | 20000, 47 | 22000, 48 | ] 49 | 50 | 51 | class BandSplitRNN(nn.Module): 52 | def __init__( 53 | self, 54 | n_fft: int, 55 | split_freqs: List[int] = v7_freqs, 56 | hidden_size: int = 128, 57 | num_layers: int = 12, 58 | norm_groups: int = 4, 59 | ) -> None: 60 | super().__init__() 61 | 62 | # 
get split freq bins index from split freqs 63 | index = [0] + [int(n_fft * f / 44100) for f in split_freqs] + [n_fft // 2 + 1] 64 | chunk_size = [index[i + 1] - index[i] for i in range(len(index) - 1)] 65 | self.split_sections = tuple(chunk_size) 66 | 67 | # stage 1: band split modules 68 | self.norm1_list = nn.ModuleList( 69 | [nn.LayerNorm(chunk_size[i]) for i in range(len(chunk_size))] 70 | ) 71 | self.fc1_list = nn.ModuleList( 72 | [nn.Linear(chunk_size[i], hidden_size) for i in range(len(chunk_size))] 73 | ) 74 | 75 | # stage 2: RNN modules 76 | self.band_lstms = nn.ModuleList() 77 | self.time_lstms = nn.ModuleList() 78 | self.band_group_norms = nn.ModuleList() 79 | self.time_group_norms = nn.ModuleList() 80 | for i in range(num_layers): 81 | self.band_group_norms.append(nn.GroupNorm(norm_groups, hidden_size)) 82 | self.band_lstms.append( 83 | nn.LSTM( 84 | input_size=hidden_size, 85 | hidden_size=hidden_size, 86 | batch_first=True, 87 | num_layers=1, 88 | bidirectional=True, 89 | proj_size=hidden_size // 2, 90 | ) 91 | ) 92 | 93 | self.time_lstms.append( 94 | nn.LSTM( 95 | input_size=hidden_size, 96 | hidden_size=hidden_size, 97 | batch_first=True, 98 | num_layers=1, 99 | bidirectional=True, 100 | proj_size=hidden_size // 2, 101 | ) 102 | ) 103 | self.time_group_norms.append(nn.GroupNorm(norm_groups, hidden_size)) 104 | 105 | # stage 3: band merge modules and mask prediction modules 106 | self.norm2_list = nn.ModuleList( 107 | [nn.LayerNorm(hidden_size) for i in range(len(chunk_size))] 108 | ) 109 | self.mlps = nn.ModuleList( 110 | [ 111 | nn.Sequential( 112 | nn.Linear(hidden_size, hidden_size * 4), 113 | nn.Tanh(), 114 | nn.Linear(hidden_size * 4, chunk_size[i]), 115 | ) 116 | for i in range(len(chunk_size)) 117 | ] 118 | ) 119 | 120 | def forward(self, mag: torch.Tensor): 121 | log_mag = torch.log(mag + 1e-8) 122 | batch, channels, freq_bins, time_bins = log_mag.shape 123 | # merge channels to batch 124 | log_mag = log_mag.view(-1, freq_bins, time_bins).transpose(1, 2) 125 | 126 | # stage 1: band split modules 127 | tmp = [] 128 | for fc, norm, sub_band in zip( 129 | self.fc1_list, self.norm1_list, log_mag.split(self.split_sections, dim=2) 130 | ): 131 | tmp.append(fc(norm(sub_band))) 132 | 133 | x = torch.stack(tmp, dim=1) 134 | 135 | # stage 2: RNN modules 136 | for band_lstm, time_lstm, band_norm, time_norm in zip( 137 | self.band_lstms, 138 | self.time_lstms, 139 | self.band_group_norms, 140 | self.time_group_norms, 141 | ): 142 | x_reshape = x.reshape(-1, x.shape[-2], x.shape[-1]) 143 | x_reshape = time_norm(x_reshape.transpose(1, 2)).transpose(1, 2) 144 | x = time_lstm(x_reshape)[0].view(*x.shape) + x 145 | 146 | x_reshape = x.transpose(1, 2).reshape(-1, x.shape[-3], x.shape[-1]) 147 | x_reshape = band_norm(x_reshape.transpose(1, 2)).transpose(1, 2) 148 | x = ( 149 | band_lstm(x_reshape)[0] 150 | .view(x.shape[0], x.shape[2], x.shape[1], x.shape[3]) 151 | .transpose(1, 2) 152 | + x 153 | ) 154 | 155 | # stage 3: band merge modules and mask prediction modules 156 | tmp = [] 157 | for i, (mlp, norm) in enumerate(zip(self.mlps, self.norm2_list)): 158 | tmp.append(mlp(norm(x[:, i]))) 159 | 160 | mask = ( 161 | torch.cat(tmp, dim=-1) 162 | .transpose(1, 2) 163 | .reshape(batch, channels, freq_bins, time_bins) 164 | .sigmoid() 165 | ) 166 | return mask 167 | 168 | 169 | class BandSplitRNNMulti(nn.Module): 170 | def __init__( 171 | self, 172 | n_fft: int, 173 | n_sources: int, 174 | split_freqs: List[int] = v7_freqs, 175 | hidden_size: int = 128, 176 | num_layers: int = 12, 177 | 
norm_groups: int = 4, 178 | ) -> None: 179 | super().__init__() 180 | 181 | # get split freq bins index from split freqs 182 | index = [0] + [int(n_fft * f / 44100) for f in split_freqs] + [n_fft // 2 + 1] 183 | chunk_size = [index[i + 1] - index[i] for i in range(len(index) - 1)] 184 | self.split_sections = tuple(chunk_size) 185 | 186 | self.n_sources = n_sources 187 | 188 | # stage 1: band split modules 189 | self.norm1_list = nn.ModuleList( 190 | [nn.LayerNorm(chunk_size[i]) for i in range(len(chunk_size))] 191 | ) 192 | self.fc1_list = nn.ModuleList( 193 | [nn.Linear(chunk_size[i], hidden_size) for i in range(len(chunk_size))] 194 | ) 195 | 196 | # stage 2: RNN modules 197 | self.band_lstms = nn.ModuleList() 198 | self.time_lstms = nn.ModuleList() 199 | self.band_group_norms = nn.ModuleList() 200 | self.time_group_norms = nn.ModuleList() 201 | for i in range(num_layers): 202 | self.band_group_norms.append(nn.GroupNorm(norm_groups, hidden_size)) 203 | self.band_lstms.append( 204 | nn.LSTM( 205 | input_size=hidden_size, 206 | hidden_size=hidden_size, 207 | batch_first=True, 208 | num_layers=1, 209 | bidirectional=True, 210 | proj_size=hidden_size // 2, 211 | ) 212 | ) 213 | 214 | self.time_lstms.append( 215 | nn.LSTM( 216 | input_size=hidden_size, 217 | hidden_size=hidden_size, 218 | batch_first=True, 219 | num_layers=1, 220 | bidirectional=True, 221 | proj_size=hidden_size // 2, 222 | ) 223 | ) 224 | self.time_group_norms.append(nn.GroupNorm(norm_groups, hidden_size)) 225 | 226 | # stage 3: band merge modules and mask prediction modules 227 | self.norm2_list = nn.ModuleList( 228 | [nn.LayerNorm(hidden_size) for i in range(len(chunk_size))] 229 | ) 230 | self.mlps = nn.ModuleList( 231 | [ 232 | nn.Sequential( 233 | nn.Linear(hidden_size, hidden_size * 4), 234 | nn.Tanh(), 235 | nn.Linear(hidden_size * 4, chunk_size[i] * n_sources), 236 | ) 237 | for i in range(len(chunk_size)) 238 | ] 239 | ) 240 | 241 | def forward(self, mag: torch.Tensor): 242 | log_mag = torch.log(mag + 1e-8) 243 | batch, channels, freq_bins, time_bins = log_mag.shape 244 | # merge channels to batch 245 | log_mag = log_mag.view(-1, freq_bins, time_bins).transpose(1, 2) 246 | 247 | # stage 1: band split modules 248 | tmp = [] 249 | for fc, norm, sub_band in zip( 250 | self.fc1_list, self.norm1_list, log_mag.split(self.split_sections, dim=2) 251 | ): 252 | tmp.append(fc(norm(sub_band))) 253 | 254 | x = torch.stack(tmp, dim=1) 255 | 256 | # stage 2: RNN modules 257 | for band_lstm, time_lstm, band_norm, time_norm in zip( 258 | self.band_lstms, 259 | self.time_lstms, 260 | self.band_group_norms, 261 | self.time_group_norms, 262 | ): 263 | x_reshape = x.reshape(-1, x.shape[-2], x.shape[-1]) 264 | x_reshape = time_norm(x_reshape.transpose(1, 2)).transpose(1, 2) 265 | x = time_lstm(x_reshape)[0].view(*x.shape) + x 266 | 267 | x_reshape = x.transpose(1, 2).reshape(-1, x.shape[-3], x.shape[-1]) 268 | x_reshape = band_norm(x_reshape.transpose(1, 2)).transpose(1, 2) 269 | x = ( 270 | band_lstm(x_reshape)[0] 271 | .view(x.shape[0], x.shape[2], x.shape[1], x.shape[3]) 272 | .transpose(1, 2) 273 | + x 274 | ) 275 | 276 | # stage 3: band merge modules and mask prediction modules 277 | tmp = [] 278 | for i, (mlp, norm) in enumerate(zip(self.mlps, self.norm2_list)): 279 | tmp.append( 280 | mlp(norm(x[:, i])).view(x.shape[0], x.shape[2], self.n_sources, -1) 281 | ) 282 | 283 | mask = ( 284 | torch.cat(tmp, dim=-1) 285 | .reshape(batch, channels, time_bins, self.n_sources, freq_bins) 286 | .permute(0, 3, 1, 4, 2) 287 | .softmax(dim=1) 
288 | ) 289 | return mask 290 | -------------------------------------------------------------------------------- /aimless/models/demucs_split.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchaudio.transforms import Resample 4 | 5 | 6 | @torch.jit.script 7 | def glu(a, b): 8 | return a * b.sigmoid() 9 | 10 | 11 | @torch.jit.script 12 | def standardize(x, mu, std): 13 | return (x - mu) / std 14 | 15 | 16 | @torch.jit.script 17 | def destandardize(x, mu, std): 18 | return x * std + mu 19 | 20 | 21 | def rescale_conv(reference): 22 | @torch.no_grad() 23 | def closure(m: nn.Module): 24 | if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)): 25 | std = m.weight.std() 26 | scale = (std / reference) ** 0.5 27 | m.weight.div_(scale) 28 | if m.bias is not None: 29 | m.bias.div_(scale) 30 | 31 | return closure 32 | 33 | 34 | class DemucsSplit(nn.Module): 35 | def __init__( 36 | self, 37 | channels=64, 38 | depth=6, 39 | rescale=0.1, 40 | resample=True, 41 | kernel_size=8, 42 | stride=4, 43 | lstm_layers=2, 44 | ): 45 | super().__init__() 46 | self.kernel_size = kernel_size 47 | self.stride = stride 48 | self.depth = depth 49 | self.channels = channels 50 | 51 | if resample: 52 | self.up_sample = Resample(1, 2) 53 | self.down_sample = Resample(2, 1) 54 | 55 | self.encoder = nn.ModuleList() 56 | self.convs_1x1 = nn.ModuleList() 57 | self.dec_pre_convs = nn.ModuleList() 58 | self.decoder = nn.ModuleList() 59 | 60 | in_channels = 2 61 | for index in range(depth): 62 | self.encoder.append( 63 | nn.Sequential( 64 | nn.Conv1d(in_channels, channels, kernel_size, stride), 65 | nn.ReLU(inplace=True), 66 | nn.Conv1d(channels, channels * 2, 1), 67 | nn.GLU(dim=1), 68 | ) 69 | ) 70 | 71 | decode = [] 72 | if index > 0: 73 | out_channels = in_channels * 2 74 | else: 75 | out_channels = 8 76 | 77 | self.convs_1x1.insert( 78 | 0, nn.Conv1d(channels, channels * 4, 3, padding=1, bias=False) 79 | ) 80 | self.dec_pre_convs.insert( 81 | 0, nn.Conv1d(channels * 2, channels * 4, 3, padding=1, groups=4) 82 | ) 83 | decode = [ 84 | nn.ConvTranspose1d( 85 | channels * 2, out_channels, kernel_size, stride, groups=4 86 | ) 87 | ] 88 | if index > 0: 89 | decode.append(nn.ReLU(inplace=True)) 90 | self.decoder.insert(0, nn.Sequential(*decode)) 91 | in_channels = channels 92 | channels *= 2 93 | 94 | channels = in_channels 95 | 96 | self.lstm = nn.LSTM( 97 | input_size=channels, 98 | hidden_size=channels, 99 | num_layers=lstm_layers, 100 | dropout=0, 101 | bidirectional=True, 102 | ) 103 | self.lstm_linear = nn.Linear(channels * 2, channels * 2) 104 | 105 | self.apply(rescale_conv(reference=rescale)) 106 | 107 | def forward(self, x): 108 | length = x.size(2) 109 | 110 | mono = x.mean(1, keepdim=True) 111 | mu = mono.mean(dim=-1, keepdim=True) 112 | std = mono.std(dim=-1, keepdim=True).add_(1e-5) 113 | x = standardize(x, mu, std) 114 | 115 | if hasattr(self, "up_sample"): 116 | x = self.up_sample(x) 117 | 118 | saved = [] 119 | for encode in self.encoder: 120 | x = encode(x) 121 | saved.append(x) 122 | 123 | x = x.permute(2, 0, 1) 124 | x = self.lstm(x)[0] 125 | x = self.lstm_linear(x).permute(1, 2, 0) 126 | 127 | for decode, pre_dec, conv1x1 in zip( 128 | self.decoder, self.dec_pre_convs, self.convs_1x1 129 | ): 130 | skip = saved.pop() 131 | 132 | x = pre_dec(x) + conv1x1(skip[..., : x.shape[2]]) 133 | a, b = x.view(x.shape[0], 4, -1, x.shape[2]).chunk(2, 2) 134 | x = glu(a, b) 135 | x = decode(x.view(x.shape[0], -1, x.shape[3])) 136 | 137 | 
if hasattr(self, "down_sample"): 138 | x = self.down_sample(x) 139 | 140 | x = destandardize(x, mu, std) 141 | x = x.view(-1, 4, 2, x.size(-1)) 142 | return x 143 | -------------------------------------------------------------------------------- /aimless/models/rel_tdemucs.py: -------------------------------------------------------------------------------- 1 | from math import inf 2 | from typing import Optional 3 | import torch 4 | from torch import nn, Tensor 5 | import torch.nn.functional as F 6 | from torchaudio.transforms import Resample 7 | 8 | from .demucs_split import standardize, destandardize, rescale_conv 9 | 10 | 11 | class PositionalEmbedding(nn.Module): 12 | def __init__(self, demb): 13 | super(PositionalEmbedding, self).__init__() 14 | 15 | self.demb = demb 16 | 17 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) 18 | self.register_buffer("inv_freq", inv_freq) 19 | 20 | def forward(self, pos_seq): 21 | sinusoid_inp = pos_seq[:, None] * self.inv_freq 22 | pos_emb = torch.view_as_real(torch.exp(1j * sinusoid_inp)).view( 23 | pos_seq.size(0), -1 24 | ) 25 | return pos_emb 26 | 27 | 28 | class RelMultiheadAttention(nn.MultiheadAttention): 29 | def __init__(self, *args, **kwargs): 30 | super().__init__( 31 | *args, 32 | bias=False, 33 | add_bias_kv=False, 34 | add_zero_attn=False, 35 | kdim=None, 36 | vdim=None, 37 | batch_first=True, 38 | **kwargs 39 | ) 40 | self.register_parameter( 41 | "u", nn.Parameter(torch.zeros(self.num_heads, self.head_dim)) 42 | ) 43 | self.register_parameter( 44 | "v", nn.Parameter(torch.zeros(self.num_heads, self.head_dim)) 45 | ) 46 | self.pos_emb = PositionalEmbedding(self.embed_dim) 47 | self.pos_emb_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) 48 | 49 | def forward(self, x, mask=None): 50 | # x: [B, T, C] 51 | B, T, _ = x.size() 52 | seq = torch.arange(-T + 1, T, device=x.device) 53 | pos_emb = self.pos_emb_proj(self.pos_emb(seq)) 54 | pos_emb = pos_emb.view(-1, self.num_heads, self.head_dim) 55 | 56 | h = x @ self.in_proj_weight.t() 57 | h = h.view(B, T, self.num_heads, self.head_dim * 3) 58 | w_head_q, w_head_k, w_head_v = h.chunk(3, dim=-1) 59 | 60 | rw_head_q = w_head_q + self.u 61 | AC = rw_head_q.transpose(1, 2) @ w_head_k.permute(0, 2, 3, 1) 62 | 63 | rr_head_q = w_head_q + self.v 64 | BD = rr_head_q.transpose(1, 2) @ pos_emb.permute(1, 2, 0) # [B, H, T, 2T-1] 65 | BD = F.pad(BD, (1, 1)).view(B, self.num_heads, 2 * T + 1, T)[ 66 | :, :, 1::2, : 67 | ] # [B, H, T, T] 68 | 69 | attn_score = (AC + BD) / self.head_dim**0.5 70 | 71 | if mask is not None: 72 | attn_score = attn_score.masked_fill(mask, -inf) 73 | 74 | with torch.cuda.amp.autocast(enabled=False): 75 | attn_prob = F.softmax(attn_score.float(), dim=-1) 76 | attn_prob = F.dropout(attn_prob, self.dropout, self.training) 77 | 78 | attn_vec = attn_prob @ w_head_v.transpose(1, 2) 79 | return self.out_proj(attn_vec.permute(0, 2, 1, 3).reshape(B, T, -1)) 80 | 81 | 82 | class RelEncoderLayer(nn.TransformerEncoderLayer): 83 | def __init__(self, d_model: int, nhead: int, *args, dropout: float = 0.1, **kwargs): 84 | super().__init__( 85 | d_model, nhead, *args, batch_first=True, dropout=dropout, **kwargs 86 | ) 87 | self.self_attn = RelMultiheadAttention(d_model, nhead, dropout=dropout) 88 | 89 | def _sa_block(self, x: Tensor, mask: Tensor = None) -> Tensor: 90 | return self.dropout1(self.self_attn(x, mask)) 91 | 92 | def forward( 93 | self, 94 | src: Tensor, 95 | src_mask: Optional[Tensor] = None, 96 | src_key_padding_mask: Optional[Tensor] = None, 97 | ) -> 
Tensor: 98 | x = src 99 | if self.norm_first: 100 | x = x + self._sa_block(self.norm1(x), src_mask) 101 | x = x + self._ff_block(self.norm2(x)) 102 | else: 103 | x = self.norm1(x + self._sa_block(x, src_mask)) 104 | x = self.norm2(x + self._ff_block(x)) 105 | return x 106 | 107 | 108 | class RelTDemucs(nn.Module): 109 | def __init__( 110 | self, 111 | in_channels=2, 112 | num_sources=4, 113 | channels=64, 114 | depth=6, 115 | context_size=None, 116 | rescale=0.1, 117 | resample=True, 118 | kernel_size=8, 119 | stride=4, 120 | attention_layers=8, 121 | **kwargs 122 | ): 123 | super().__init__() 124 | self.kernel_size = kernel_size 125 | self.stride = stride 126 | self.depth = depth 127 | self.channels = channels 128 | self.context_size = context_size 129 | 130 | if resample: 131 | self.up_sample = Resample(1, 2) 132 | self.down_sample = Resample(2, 1) 133 | 134 | self.encoder = nn.ModuleList() 135 | self.decoder = nn.ModuleList() 136 | 137 | current_channels = in_channels 138 | for index in range(depth): 139 | self.encoder.append( 140 | nn.Sequential( 141 | nn.Conv1d(current_channels, channels, kernel_size, stride), 142 | nn.ReLU(inplace=True), 143 | nn.Conv1d(channels, channels * 2, 1), 144 | nn.GLU(dim=1), 145 | ) 146 | ) 147 | 148 | out_channels = current_channels if index > 0 else num_sources * in_channels 149 | 150 | decode = [ 151 | nn.Conv1d(channels, channels * 2, 3, padding=1, bias=False), 152 | nn.GLU(dim=1), 153 | nn.ConvTranspose1d( 154 | channels, 155 | out_channels, 156 | kernel_size, 157 | stride, 158 | ), 159 | ] 160 | if index > 0: 161 | decode.append(nn.ReLU(inplace=True)) 162 | self.decoder.insert(0, nn.Sequential(*decode)) 163 | current_channels = channels 164 | channels *= 2 165 | 166 | channels = current_channels 167 | 168 | encoder_layer = RelEncoderLayer( 169 | d_model=channels, dim_feedforward=channels * 4, **kwargs 170 | ) 171 | self.transformer = nn.TransformerEncoder(encoder_layer, attention_layers) 172 | self.apply(rescale_conv(reference=rescale)) 173 | 174 | def forward(self, x, context_size: int = None): 175 | batch, ch, _ = x.shape 176 | mono = x.mean(1, keepdim=True) 177 | mu = mono.mean(dim=-1, keepdim=True) 178 | std = mono.std(dim=-1, keepdim=True).add_(1e-5) 179 | x = standardize(x, mu, std) 180 | 181 | if hasattr(self, "up_sample"): 182 | x = self.up_sample(x) 183 | 184 | saved = [] 185 | for encode in self.encoder: 186 | x = encode(x) 187 | saved.append(x) 188 | 189 | x = x.transpose(1, 2) 190 | mask = None 191 | if context_size is None and self.context_size is not None: 192 | context_size = self.context_size 193 | 194 | if context_size and context_size < x.size(1): 195 | mask = x.new_ones(x.size(1), x.size(1), dtype=torch.bool) 196 | mask = torch.triu(mask, diagonal=context_size) 197 | mask = mask | mask.T 198 | 199 | x = self.transformer(x, mask).transpose(1, 2) 200 | 201 | for decode in self.decoder: 202 | skip = saved.pop() 203 | x = decode(x + skip[..., : x.shape[-1]]) 204 | 205 | if hasattr(self, "down_sample"): 206 | x = self.down_sample(x) 207 | 208 | x = destandardize(x, mu, std) 209 | x = x.view(batch, -1, ch, x.shape[-1]) 210 | return x 211 | -------------------------------------------------------------------------------- /aimless/models/xumx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class X_UMX(nn.Module): 7 | __constants__ = ["max_bins"] 8 | 9 | max_bins: int 10 | 11 | def __init__( 12 | self, n_fft=4096, 
hidden_channels=512, max_bins=None, nb_channels=2, nb_layers=3 13 | ): 14 | super().__init__() 15 | 16 | self.nb_output_bins = n_fft // 2 + 1 17 | if max_bins: 18 | self.max_bins = max_bins 19 | else: 20 | self.max_bins = self.nb_output_bins 21 | self.hidden_channels = hidden_channels 22 | self.n_fft = n_fft 23 | self.nb_channels = nb_channels 24 | self.nb_layers = nb_layers 25 | 26 | self.input_means = nn.Parameter(torch.zeros(4 * self.max_bins)) 27 | self.input_scale = nn.Parameter(torch.ones(4 * self.max_bins)) 28 | 29 | self.output_means = nn.Parameter(torch.zeros(4 * self.nb_output_bins)) 30 | self.output_scale = nn.Parameter(torch.ones(4 * self.nb_output_bins)) 31 | 32 | self.affine1 = nn.Sequential( 33 | nn.Conv1d( 34 | nb_channels * self.max_bins * 4, 35 | hidden_channels * 4, 36 | 1, 37 | bias=False, 38 | groups=4, 39 | ), 40 | nn.BatchNorm1d(hidden_channels * 4), 41 | nn.Tanh(), 42 | ) 43 | self.bass_lstm = nn.LSTM( 44 | input_size=self.hidden_channels, 45 | hidden_size=self.hidden_channels // 2, 46 | num_layers=nb_layers, 47 | dropout=0.4, 48 | bidirectional=True, 49 | ) 50 | self.drums_lstm = nn.LSTM( 51 | input_size=self.hidden_channels, 52 | hidden_size=self.hidden_channels // 2, 53 | num_layers=nb_layers, 54 | dropout=0.4, 55 | bidirectional=True, 56 | ) 57 | self.vocals_lstm = nn.LSTM( 58 | input_size=self.hidden_channels, 59 | hidden_size=self.hidden_channels // 2, 60 | num_layers=nb_layers, 61 | dropout=0.4, 62 | bidirectional=True, 63 | ) 64 | self.other_lstm = nn.LSTM( 65 | input_size=self.hidden_channels, 66 | hidden_size=self.hidden_channels // 2, 67 | num_layers=nb_layers, 68 | dropout=0.4, 69 | bidirectional=True, 70 | ) 71 | 72 | self.affine2 = nn.Sequential( 73 | nn.Conv1d(hidden_channels * 2, hidden_channels * 4, 1, bias=False), 74 | nn.BatchNorm1d(hidden_channels * 4), 75 | nn.ReLU(inplace=True), 76 | nn.Conv1d( 77 | hidden_channels * 4, 78 | nb_channels * self.nb_output_bins * 4, 79 | 1, 80 | bias=False, 81 | groups=4, 82 | ), 83 | nn.BatchNorm1d(nb_channels * self.nb_output_bins * 4), 84 | ) 85 | 86 | def forward(self, spec: torch.Tensor): 87 | batch, channels, bins, frames = spec.shape 88 | spec = spec[..., : self.max_bins, :] 89 | 90 | x = ( 91 | spec.unsqueeze(1) + self.input_means.view(4, 1, -1, 1) 92 | ) * self.input_scale.view(4, 1, -1, 1) 93 | 94 | x = x.reshape(batch, -1, frames) 95 | cross_1 = self.affine1(x).view(batch, 4, -1, frames).mean(1) 96 | 97 | cross_1 = cross_1.permute(2, 0, 1) 98 | bass, *_ = self.bass_lstm(cross_1) 99 | drums, *_ = self.drums_lstm(cross_1) 100 | others, *_ = self.other_lstm(cross_1) 101 | vocals, *_ = self.vocals_lstm(cross_1) 102 | 103 | avg = (bass + drums + vocals + others) * 0.25 104 | cross_2 = torch.cat([cross_1, avg], 2).permute(1, 2, 0) 105 | 106 | mask = self.affine2(cross_2).view( 107 | batch, 4, channels, bins, frames 108 | ) * self.output_scale.view(4, 1, -1, 1) + self.output_means.view(4, 1, -1, 1) 109 | return mask.relu() 110 | -------------------------------------------------------------------------------- /aimless/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .mwf import MWF 2 | 3 | 4 | MDX_SOURCES = ["drums", "bass", "other", "vocals"] 5 | SDX_SOURCES = ["music", "sfx", "speech"] 6 | SE_SOURCES = ["speech", "noise"] 7 | -------------------------------------------------------------------------------- /aimless/utils/mwf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import norbert 3 | 4 | 
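# MWF applies a multichannel Wiener filter (via the norbert package) on top of the
# predicted masks: the masks scale the mixture magnitude into per-source estimates,
# optionally with a residual model and soft-masking, and norbert.wiener then runs
# n_iter EM refinement steps on the complex spectrograms. When softmask, n_iter and
# residual_model are all disabled, it falls back to plain mask multiplication.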
5 | class MWF(torch.nn.Module): 6 | def __init__( 7 | self, residual_model=False, softmask=False, alpha=1.0, n_iter=1 8 | ) -> None: 9 | super().__init__() 10 | self.residual_model = residual_model 11 | self.n_iter = n_iter 12 | self.softmask = softmask 13 | self.alpha = alpha 14 | 15 | def forward(self, msk_hat, mix_spec): 16 | assert msk_hat.ndim > mix_spec.ndim 17 | if not self.softmask and not self.n_iter and not self.residual_model: 18 | return msk_hat * mix_spec.unsqueeze(1) 19 | 20 | V = msk_hat * mix_spec.abs().unsqueeze(1) 21 | if self.softmask and self.alpha != 1: 22 | V = V.pow(self.alpha) 23 | 24 | X = mix_spec.transpose(1, 3).contiguous() 25 | V = V.permute(0, 4, 3, 2, 1).contiguous() 26 | 27 | if self.residual_model or V.shape[4] == 1: 28 | V = norbert.residual_model(V, X, self.alpha if self.softmask else 1) 29 | 30 | Y = norbert.wiener( 31 | V, X.to(torch.complex128), self.n_iter, use_softmask=self.softmask 32 | ).to(X.dtype) 33 | 34 | Y = Y.permute(0, 4, 3, 2, 1).contiguous() 35 | return Y 36 | -------------------------------------------------------------------------------- /aimless/utils/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/aimless/utils/utils.py -------------------------------------------------------------------------------- /cfg/cdx_a/bandsplit_rnn.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.5.post0 2 | seed_everything: 2434 3 | trainer: 4 | logger: true 5 | enable_checkpointing: true 6 | callbacks: 7 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 8 | init_args: 9 | dirpath: null 10 | filename: null 11 | monitor: null 12 | verbose: false 13 | save_last: true 14 | save_top_k: 1 15 | save_weights_only: false 16 | mode: min 17 | auto_insert_metric_name: true 18 | every_n_train_steps: 2000 19 | train_time_interval: null 20 | every_n_epochs: null 21 | save_on_train_epoch_end: null 22 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 23 | init_args: 24 | dirpath: null 25 | filename: null 26 | monitor: null 27 | verbose: false 28 | save_last: null 29 | save_top_k: -1 30 | save_weights_only: true 31 | mode: min 32 | auto_insert_metric_name: true 33 | every_n_train_steps: 1000 34 | train_time_interval: null 35 | every_n_epochs: null 36 | save_on_train_epoch_end: null 37 | - class_path: pytorch_lightning.callbacks.ModelSummary 38 | init_args: 39 | max_depth: 2 40 | default_root_dir: null 41 | gradient_clip_val: null 42 | gradient_clip_algorithm: null 43 | num_nodes: 1 44 | num_processes: null 45 | devices: null 46 | gpus: null 47 | auto_select_gpus: false 48 | tpu_cores: null 49 | ipus: null 50 | enable_progress_bar: true 51 | overfit_batches: 0.0 52 | track_grad_norm: -1 53 | check_val_every_n_epoch: 1 54 | fast_dev_run: false 55 | accumulate_grad_batches: null 56 | max_epochs: null 57 | min_epochs: null 58 | max_steps: 99000 59 | min_steps: null 60 | max_time: null 61 | limit_train_batches: null 62 | limit_val_batches: 0 63 | limit_test_batches: null 64 | limit_predict_batches: null 65 | val_check_interval: null 66 | log_every_n_steps: 1 67 | accelerator: gpu 68 | strategy: ddp 69 | sync_batchnorm: false 70 | precision: 32 71 | enable_model_summary: true 72 | num_sanity_val_steps: 0 73 | resume_from_checkpoint: null 74 | profiler: null 75 | benchmark: null 76 | deterministic: null 77 | reload_dataloaders_every_n_epochs: 0 78 | 
auto_lr_find: false 79 | replace_sampler_ddp: true 80 | detect_anomaly: false 81 | auto_scale_batch_size: false 82 | plugins: null 83 | amp_backend: native 84 | amp_level: null 85 | move_metrics_to_cpu: false 86 | multiple_trainloader_mode: max_size_cycle 87 | inference_mode: true 88 | ckpt_path: null 89 | model: 90 | class_path: aimless.lightning.freq_mask.MaskPredictor 91 | init_args: 92 | model: 93 | class_path: aimless.models.band_split_rnn.BandSplitRNN 94 | init_args: 95 | n_fft: 4096 96 | split_freqs: 97 | - 100 98 | - 200 99 | - 300 100 | - 400 101 | - 500 102 | - 600 103 | - 700 104 | - 800 105 | - 900 106 | - 1000 107 | - 1250 108 | - 1500 109 | - 1750 110 | - 2000 111 | - 2250 112 | - 2500 113 | - 2750 114 | - 3000 115 | - 3250 116 | - 3500 117 | - 3750 118 | - 4000 119 | - 4500 120 | - 5000 121 | - 5500 122 | - 6000 123 | - 6500 124 | - 7000 125 | - 7500 126 | - 8000 127 | - 9000 128 | - 10000 129 | - 11000 130 | - 12000 131 | - 13000 132 | - 14000 133 | - 15000 134 | - 16000 135 | - 18000 136 | - 20000 137 | - 22000 138 | hidden_size: 128 139 | num_layers: 12 140 | norm_groups: 4 141 | criterion: 142 | class_path: aimless.loss.freq.MDLoss 143 | init_args: 144 | mcoeff: 10 145 | n_fft: 4096 146 | hop_length: 1024 147 | transforms: 148 | - class_path: aimless.augment.SpeedPerturb 149 | init_args: 150 | orig_freq: 44100 151 | speeds: 152 | - 90 153 | - 100 154 | - 110 155 | p: 0.2 156 | - class_path: aimless.augment.RandomPitch 157 | init_args: 158 | semitones: 159 | - -1 160 | - 1 161 | - 0 162 | - 1 163 | - 2 164 | n_fft: 2048 165 | hop_length: 512 166 | p: 0.2 167 | target_track: sdx 168 | targets: 169 | music: null 170 | n_fft: 4096 171 | hop_length: 1024 172 | residual_model: true 173 | softmask: false 174 | alpha: 1.0 175 | n_iter: 1 176 | data: 177 | class_path: data.lightning.DnR 178 | init_args: 179 | root: /import/c4dm-datasets-ext/sdx-2023/dnr_v2/dnr_v2/ 180 | seq_duration: 3.0 181 | samples_per_track: 144 182 | random: true 183 | include_val: true 184 | random_track_mix: true 185 | transforms: 186 | - class_path: data.augment.RandomGain 187 | init_args: 188 | low: 0.25 189 | high: 1.25 190 | p: 1.0 191 | - class_path: data.augment.RandomFlipPhase 192 | init_args: 193 | p: 0.5 194 | - class_path: data.augment.RandomSwapLR 195 | init_args: 196 | p: 0.5 197 | batch_size: 16 198 | optimizer: 199 | class_path: torch.optim.Adam 200 | init_args: 201 | lr: 0.0003 202 | betas: 203 | - 0.9 204 | - 0.999 205 | eps: 1.0e-08 206 | weight_decay: 0 207 | amsgrad: false 208 | foreach: null 209 | maximize: false 210 | capturable: false 211 | differentiable: false 212 | fused: false 213 | -------------------------------------------------------------------------------- /cfg/cdx_a/hdemucs.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.5.post0 2 | seed_everything: 2434 3 | trainer: 4 | logger: true 5 | enable_checkpointing: true 6 | callbacks: 7 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 8 | init_args: 9 | dirpath: null 10 | filename: null 11 | monitor: null 12 | verbose: false 13 | save_last: true 14 | save_top_k: 1 15 | save_weights_only: false 16 | mode: min 17 | auto_insert_metric_name: true 18 | every_n_train_steps: 2000 19 | train_time_interval: null 20 | every_n_epochs: null 21 | save_on_train_epoch_end: null 22 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 23 | init_args: 24 | dirpath: null 25 | filename: null 26 | monitor: null 27 | verbose: false 28 | save_last: null 29 | 
save_top_k: -1 30 | save_weights_only: true 31 | mode: min 32 | auto_insert_metric_name: true 33 | every_n_train_steps: 2000 34 | train_time_interval: null 35 | every_n_epochs: null 36 | save_on_train_epoch_end: null 37 | - class_path: pytorch_lightning.callbacks.ModelSummary 38 | init_args: 39 | max_depth: 2 40 | default_root_dir: null 41 | gradient_clip_val: null 42 | gradient_clip_algorithm: null 43 | num_nodes: 1 44 | num_processes: null 45 | devices: null 46 | gpus: null 47 | auto_select_gpus: false 48 | tpu_cores: null 49 | ipus: null 50 | enable_progress_bar: true 51 | overfit_batches: 0.0 52 | track_grad_norm: -1 53 | check_val_every_n_epoch: 1 54 | fast_dev_run: false 55 | accumulate_grad_batches: 4 56 | max_epochs: null 57 | min_epochs: null 58 | max_steps: 30000 59 | min_steps: null 60 | max_time: null 61 | limit_train_batches: null 62 | limit_val_batches: 0 63 | limit_test_batches: null 64 | limit_predict_batches: null 65 | val_check_interval: null 66 | log_every_n_steps: 1 67 | accelerator: gpu 68 | strategy: ddp 69 | sync_batchnorm: false 70 | precision: 32 71 | enable_model_summary: true 72 | num_sanity_val_steps: 0 73 | resume_from_checkpoint: null 74 | profiler: null 75 | benchmark: null 76 | deterministic: null 77 | reload_dataloaders_every_n_epochs: 0 78 | auto_lr_find: false 79 | replace_sampler_ddp: true 80 | detect_anomaly: false 81 | auto_scale_batch_size: false 82 | plugins: null 83 | amp_backend: native 84 | amp_level: null 85 | move_metrics_to_cpu: false 86 | multiple_trainloader_mode: max_size_cycle 87 | inference_mode: true 88 | ckpt_path: null 89 | model: 90 | class_path: aimless.lightning.waveform.WaveformSeparator 91 | init_args: 92 | model: 93 | class_path: torchaudio.models.HDemucs 94 | init_args: 95 | sources: 96 | - music 97 | - sfx 98 | - speech 99 | audio_channels: 1 100 | channels: 64 101 | growth: 2 102 | nfft: 4096 103 | depth: 6 104 | freq_emb: 0.2 105 | emb_scale: 10 106 | emb_smooth: true 107 | kernel_size: 8 108 | time_stride: 2 109 | stride: 4 110 | context: 1 111 | context_enc: 0 112 | norm_starts: 4 113 | norm_groups: 4 114 | dconv_depth: 2 115 | dconv_comp: 4 116 | dconv_attn: 4 117 | dconv_lstm: 4 118 | dconv_init: 0.0001 119 | criterion: 120 | class_path: aimless.loss.time.NegativeSDR 121 | transforms: 122 | - class_path: aimless.augment.SpeedPerturb 123 | init_args: 124 | orig_freq: 44100 125 | speeds: 126 | - 50 127 | - 60 128 | - 70 129 | - 80 130 | - 90 131 | - 100 132 | - 110 133 | - 120 134 | - 130 135 | - 140 136 | - 150 137 | p: 0.3 138 | - class_path: aimless.augment.RandomPitch 139 | init_args: 140 | semitones: 141 | - -7 142 | - -6 143 | - -5 144 | - -4 145 | - -3 146 | - -2 147 | - -1 148 | - 1 149 | - 0 150 | - 1 151 | - 2 152 | - 3 153 | - 4 154 | - 5 155 | - 6 156 | - 7 157 | n_fft: 2048 158 | hop_length: 512 159 | p: 0.3 160 | use_sdx_targets: false 161 | target_track: sdx 162 | targets: 163 | music: null 164 | sfx: null 165 | speech: null 166 | data: 167 | class_path: data.lightning.DnR 168 | init_args: 169 | root: /import/c4dm-datasets-ext/sdx-2023/dnr_v2/dnr_v2/ 170 | seq_duration: 6.0 171 | samples_per_track: 144 172 | random: true 173 | include_val: true 174 | random_track_mix: true 175 | transforms: 176 | - class_path: data.augment.RandomGain 177 | init_args: 178 | low: 0.25 179 | high: 1.25 180 | p: 1.0 181 | - class_path: data.augment.RandomFlipPhase 182 | init_args: 183 | p: 0.5 184 | - class_path: data.augment.RandomSwapLR 185 | init_args: 186 | p: 0.5 187 | batch_size: 3 188 | optimizer: 189 | class_path: 
torch.optim.Adam 190 | init_args: 191 | lr: 0.0003 192 | betas: 193 | - 0.9 194 | - 0.999 195 | eps: 1.0e-08 196 | weight_decay: 0 197 | amsgrad: false 198 | foreach: null 199 | maximize: false 200 | capturable: false 201 | differentiable: false 202 | fused: false 203 | -------------------------------------------------------------------------------- /cfg/demucs.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.5.post0 2 | seed_everything: true 3 | trainer: 4 | callbacks: 5 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 6 | init_args: 7 | save_last: true 8 | every_n_train_steps: 2000 9 | filename: "{epoch}-{step}" 10 | - class_path: pytorch_lightning.callbacks.ModelSummary 11 | init_args: 12 | max_depth: 2 13 | logger: true 14 | enable_checkpointing: true 15 | callbacks: null 16 | default_root_dir: null 17 | gradient_clip_val: null 18 | gradient_clip_algorithm: null 19 | num_nodes: 1 20 | num_processes: null 21 | devices: null 22 | gpus: null 23 | auto_select_gpus: false 24 | tpu_cores: null 25 | ipus: null 26 | enable_progress_bar: true 27 | overfit_batches: 0.0 28 | track_grad_norm: -1 29 | check_val_every_n_epoch: 1 30 | fast_dev_run: false 31 | accumulate_grad_batches: null 32 | max_epochs: null 33 | min_epochs: null 34 | max_steps: -1 35 | min_steps: null 36 | max_time: null 37 | limit_train_batches: null 38 | limit_val_batches: null 39 | limit_test_batches: null 40 | limit_predict_batches: null 41 | val_check_interval: null 42 | log_every_n_steps: 1 43 | accelerator: gpu 44 | sync_batchnorm: true 45 | precision: 16 46 | enable_model_summary: true 47 | num_sanity_val_steps: 2 48 | resume_from_checkpoint: null 49 | profiler: null 50 | benchmark: null 51 | deterministic: null 52 | reload_dataloaders_every_n_epochs: 0 53 | auto_lr_find: false 54 | replace_sampler_ddp: true 55 | detect_anomaly: false 56 | auto_scale_batch_size: false 57 | plugins: null 58 | amp_backend: native 59 | amp_level: null 60 | move_metrics_to_cpu: false 61 | multiple_trainloader_mode: max_size_cycle 62 | inference_mode: true 63 | ckpt_path: null 64 | model: 65 | class_path: aimless.lightning.waveform.WaveformSeparator 66 | init_args: 67 | model: 68 | class_path: aimless.models.demucs_split.DemucsSplit 69 | init_args: 70 | channels: 48 71 | criterion: 72 | class_path: aimless.loss.time.L1Loss 73 | targets: {vocals, drums, bass, other} 74 | data: 75 | class_path: data.lightning.MUSDB 76 | init_args: 77 | root: /import/c4dm-datasets-ext/musdb18hq/ 78 | seq_duration: 10.0 79 | samples_per_track: 150 80 | random: true 81 | random_track_mix: true 82 | transforms: 83 | - class_path: data.augment.RandomGain 84 | - class_path: data.augment.RandomFlipPhase 85 | - class_path: data.augment.RandomSwapLR 86 | - class_path: data.augment.LimitAug 87 | init_args: 88 | sample_rate: 44100 89 | batch_size: 4 90 | optimizer: 91 | class_path: torch.optim.Adam 92 | init_args: 93 | lr: 0.0003 -------------------------------------------------------------------------------- /cfg/hdemucs.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.5.post0 2 | seed_everything: true 3 | trainer: 4 | callbacks: 5 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 6 | init_args: 7 | save_last: true 8 | every_n_train_steps: 2000 9 | filename: "{epoch}-{step}" 10 | - class_path: pytorch_lightning.callbacks.ModelSummary 11 | init_args: 12 | max_depth: 2 13 | logger: true 14 | enable_checkpointing: true 15 | 
# callbacks: null 16 | default_root_dir: null 17 | gradient_clip_val: null 18 | gradient_clip_algorithm: null 19 | num_nodes: 1 20 | num_processes: null 21 | devices: 1 22 | gpus: null 23 | auto_select_gpus: false 24 | tpu_cores: null 25 | ipus: null 26 | enable_progress_bar: true 27 | overfit_batches: 0.0 28 | track_grad_norm: -1 29 | check_val_every_n_epoch: 1 30 | fast_dev_run: false 31 | accumulate_grad_batches: null 32 | max_epochs: null 33 | min_epochs: null 34 | max_steps: -1 35 | min_steps: null 36 | max_time: null 37 | limit_train_batches: null 38 | limit_val_batches: null 39 | limit_test_batches: null 40 | limit_predict_batches: null 41 | val_check_interval: null 42 | log_every_n_steps: 1 43 | # accelerator: cpu 44 | accelerator: gpu 45 | strategy: ddp 46 | sync_batchnorm: true 47 | precision: 32 48 | enable_model_summary: true 49 | num_sanity_val_steps: 2 50 | resume_from_checkpoint: null 51 | profiler: null 52 | benchmark: null 53 | deterministic: null 54 | reload_dataloaders_every_n_epochs: 0 55 | auto_lr_find: false 56 | replace_sampler_ddp: true 57 | detect_anomaly: false 58 | auto_scale_batch_size: false 59 | plugins: null 60 | amp_backend: native 61 | amp_level: null 62 | move_metrics_to_cpu: false 63 | multiple_trainloader_mode: max_size_cycle 64 | inference_mode: true 65 | ckpt_path: null 66 | model: 67 | class_path: aimless.lightning.waveform.WaveformSeparator 68 | init_args: 69 | model: 70 | class_path: torchaudio.models.HDemucs 71 | init_args: 72 | sources: 73 | - vocals 74 | - drums 75 | - bass 76 | - other 77 | channels: 48 78 | criterion: 79 | class_path: aimless.loss.time.L1Loss 80 | transforms: 81 | - class_path: aimless.augment.SpeedPerturb 82 | init_args: 83 | orig_freq: 44100 84 | speeds: 85 | - 90 86 | - 100 87 | - 110 88 | p: 0.2 89 | - class_path: aimless.augment.RandomPitch 90 | init_args: 91 | semitones: 92 | - -1 93 | - 1 94 | - 0 95 | - 1 96 | - 2 97 | p: 0.2 98 | targets: {vocals, drums, bass, other} 99 | data: 100 | class_path: data.lightning.LabelNoise 101 | init_args: 102 | # root: /import/c4dm-datasets-ext/musdb18hq/ 103 | # root: /Volumes/samsung_t5/moisesdb23_labelnoise_v1.0_split 104 | root: /import/c4dm-datasets-ext/cm007/moisesdb23_labelnoise_v1.0_split 105 | seq_duration: 10.0 106 | # samples_per_track: 1 107 | samples_per_track: 100 108 | transforms: 109 | - class_path: data.augment.RandomParametricEQ 110 | init_args: 111 | sample_rate: 44100 112 | p: 0.7 113 | - class_path: data.augment.RandomPedalboardDistortion 114 | init_args: 115 | sample_rate: 44100 116 | p: 0.01 117 | - class_path: data.augment.RandomPedalboardDelay 118 | init_args: 119 | sample_rate: 44100 120 | p: 0.02 121 | - class_path: data.augment.RandomPedalboardChorus 122 | init_args: 123 | sample_rate: 44100 124 | p: 0.01 125 | - class_path: data.augment.RandomPedalboardPhaser 126 | init_args: 127 | sample_rate: 44100 128 | p: 0.01 129 | - class_path: data.augment.RandomPedalboardCompressor 130 | init_args: 131 | sample_rate: 44100 132 | p: 0.5 133 | - class_path: data.augment.RandomPedalboardReverb 134 | init_args: 135 | sample_rate: 44100 136 | p: 0.2 137 | - class_path: data.augment.RandomStereoWidener 138 | init_args: 139 | sample_rate: 44100 140 | p: 0.3 141 | - class_path: data.augment.RandomPedalboardLimiter 142 | init_args: 143 | sample_rate: 44100 144 | p: 0.1 145 | - class_path: data.augment.RandomVolumeAutomation 146 | init_args: 147 | sample_rate: 44100 148 | p: 0.1 149 | - class_path: data.augment.LoudnessNormalize 150 | init_args: 151 | sample_rate: 44100 152 | 
target_lufs_db: -32.0 153 | p: 1.0 154 | random: true 155 | random_track_mix: false 156 | # batch_size: 1 157 | batch_size: 5 158 | optimizer: 159 | class_path: torch.optim.Adam 160 | init_args: 161 | lr: 0.0003 162 | -------------------------------------------------------------------------------- /cfg/mdx_a/hdemucs.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.5.post0 2 | seed_everything: 45 3 | trainer: 4 | logger: true 5 | enable_checkpointing: true 6 | callbacks: 7 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 8 | init_args: 9 | dirpath: null 10 | filename: null 11 | monitor: null 12 | verbose: false 13 | save_last: true 14 | save_top_k: 1 15 | save_weights_only: false 16 | mode: min 17 | auto_insert_metric_name: true 18 | every_n_train_steps: 2000 19 | train_time_interval: null 20 | every_n_epochs: null 21 | save_on_train_epoch_end: null 22 | - class_path: pytorch_lightning.callbacks.ModelSummary 23 | init_args: 24 | max_depth: 2 25 | default_root_dir: /import/c4dm-datasets-ext/cm007/mdx_checkpoints/ 26 | gradient_clip_val: null 27 | gradient_clip_algorithm: null 28 | num_nodes: 1 29 | num_processes: null 30 | # devices: 1 31 | devices: 4 32 | gpus: null 33 | auto_select_gpus: false 34 | tpu_cores: null 35 | ipus: null 36 | enable_progress_bar: true 37 | overfit_batches: 0.0 38 | track_grad_norm: -1 39 | check_val_every_n_epoch: 10 40 | fast_dev_run: false 41 | accumulate_grad_batches: 4 42 | # accumulate_grad_batches: 32 43 | # accumulate_grad_batches: 64 44 | max_epochs: -1 45 | min_epochs: null 46 | max_steps: -1 47 | min_steps: null 48 | max_time: null 49 | limit_train_batches: null 50 | limit_val_batches: null 51 | limit_test_batches: null 52 | limit_predict_batches: null 53 | val_check_interval: null 54 | log_every_n_steps: 1 55 | accelerator: gpu 56 | strategy: ddp 57 | sync_batchnorm: false 58 | precision: 32 59 | enable_model_summary: true 60 | num_sanity_val_steps: 2 61 | resume_from_checkpoint: null 62 | profiler: null 63 | benchmark: null 64 | deterministic: null 65 | reload_dataloaders_every_n_epochs: 0 66 | auto_lr_find: false 67 | replace_sampler_ddp: true 68 | detect_anomaly: false 69 | auto_scale_batch_size: false 70 | plugins: null 71 | amp_backend: native 72 | amp_level: null 73 | move_metrics_to_cpu: false 74 | multiple_trainloader_mode: max_size_cycle 75 | inference_mode: true 76 | ckpt_path: null 77 | #ckpt_path: /import/c4dm-datasets-ext/cm007/mdx_checkpoints/epoch=750-step=212000.ckpt 78 | model: 79 | class_path: aimless.lightning.waveform.WaveformSeparator 80 | init_args: 81 | model: 82 | class_path: torchaudio.models.HDemucs 83 | init_args: 84 | sources: 85 | - vocals 86 | - drums 87 | - bass 88 | - other 89 | audio_channels: 2 90 | channels: 64 91 | growth: 2 92 | nfft: 4096 93 | depth: 6 94 | freq_emb: 0.2 95 | emb_scale: 10 96 | emb_smooth: true 97 | kernel_size: 8 98 | time_stride: 2 99 | stride: 4 100 | context: 1 101 | context_enc: 0 102 | norm_starts: 4 103 | norm_groups: 4 104 | dconv_depth: 2 105 | dconv_comp: 4 106 | dconv_attn: 4 107 | dconv_lstm: 4 108 | dconv_init: 0.0001 109 | criterion: 110 | class_path: aimless.loss.time.NegativeSDR 111 | transforms: 112 | - class_path: aimless.augment.SpeedPerturb 113 | init_args: 114 | orig_freq: 44100 115 | speeds: 116 | - 90 117 | - 100 118 | - 110 119 | p: 0.3 120 | - class_path: aimless.augment.RandomPitch 121 | init_args: 122 | semitones: 123 | - -1 124 | - 1 125 | - 0 126 | - 1 127 | - 2 128 | p: 0.3 129 | targets: 
{vocals, drums, bass, other} 130 | data: 131 | class_path: data.lightning.LabelNoise 132 | init_args: 133 | # root: /import/c4dm-datasets-ext/musdb18hq/ 134 | # root: /Volumes/samsung_t5/moisesdb23_labelnoise_v1.0_split 135 | root: /import/c4dm-datasets-ext/cm007/moisesdb23_labelnoise_v1.0_split 136 | seq_duration: 10.0 137 | # samples_per_track: 1 138 | samples_per_track: 144 139 | transforms: 140 | - class_path: data.augment.RandomParametricEQ 141 | init_args: 142 | sample_rate: 44100 143 | p: 0.3 144 | - class_path: data.augment.RandomPedalboardDistortion 145 | init_args: 146 | sample_rate: 44100 147 | p: 0.03 148 | - class_path: data.augment.RandomPedalboardDelay 149 | init_args: 150 | sample_rate: 44100 151 | p: 0.03 152 | - class_path: data.augment.RandomPedalboardChorus 153 | init_args: 154 | sample_rate: 44100 155 | p: 0.03 156 | - class_path: data.augment.RandomPedalboardPhaser 157 | init_args: 158 | sample_rate: 44100 159 | p: 0.03 160 | - class_path: data.augment.RandomPedalboardCompressor 161 | init_args: 162 | sample_rate: 44100 163 | p: 0.3 164 | - class_path: data.augment.RandomPedalboardReverb 165 | init_args: 166 | sample_rate: 44100 167 | p: 0.3 168 | - class_path: data.augment.RandomStereoWidener 169 | init_args: 170 | sample_rate: 44100 171 | p: 0.3 172 | - class_path: data.augment.RandomPedalboardLimiter 173 | init_args: 174 | sample_rate: 44100 175 | p: 0.3 176 | - class_path: data.augment.RandomVolumeAutomation 177 | init_args: 178 | sample_rate: 44100 179 | p: 0.3 180 | - class_path: data.augment.LoudnessNormalize 181 | init_args: 182 | sample_rate: 44100 183 | target_lufs_db: -32.0 184 | p: 0.3 185 | - class_path: data.augment.RandomGain 186 | init_args: 187 | low: 0.25 188 | high: 1.25 189 | p: 0.3 190 | - class_path: data.augment.RandomFlipPhase 191 | init_args: 192 | p: 0.3 193 | - class_path: data.augment.RandomSwapLR 194 | init_args: 195 | p: 0.3 196 | random: true 197 | # include_val: true 198 | random_track_mix: true 199 | # batch_size: 1 200 | batch_size: 3 201 | optimizer: 202 | class_path: torch.optim.Adam 203 | init_args: 204 | lr: 0.0003 205 | # lr: 0.0001 206 | betas: 207 | - 0.9 208 | - 0.999 209 | eps: 1.0e-08 210 | weight_decay: 0 211 | amsgrad: false 212 | foreach: null 213 | maximize: false 214 | capturable: false 215 | differentiable: false 216 | fused: false 217 | -------------------------------------------------------------------------------- /cfg/speech_enhance.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.5.post0 2 | seed_everything: true 3 | trainer: 4 | detect_anomaly: true 5 | callbacks: 6 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 7 | init_args: 8 | save_last: true 9 | every_n_train_steps: 2000 10 | filename: "{epoch}-{step}" 11 | - class_path: pytorch_lightning.callbacks.ModelSummary 12 | init_args: 13 | max_depth: 2 14 | log_every_n_steps: 1 15 | accelerator: gpu 16 | strategy: ddp 17 | sync_batchnorm: true 18 | precision: 32 19 | ckpt_path: null 20 | model: 21 | class_path: aimless.lightning.waveform.WaveformSeparator 22 | init_args: 23 | model: 24 | class_path: torchaudio.models.HDemucs 25 | init_args: 26 | sources: 27 | - speech 28 | channels: 48 29 | audio_channels: 1 30 | criterion: 31 | class_path: aimless.loss.time.L1Loss 32 | transforms: 33 | - class_path: aimless.augment.SpeedPerturb 34 | init_args: 35 | orig_freq: 44100 36 | speeds: 37 | - 90 38 | - 100 39 | - 110 40 | p: 0.2 41 | - class_path: aimless.augment.RandomPitch 42 | init_args: 43 
| semitones: 44 | - -1 45 | - 1 46 | - 0 47 | - 1 48 | - 2 49 | p: 0.2 50 | target_track: se 51 | targets: {speech} 52 | data: 53 | class_path: data.lightning.SpeechNoise 54 | init_args: 55 | speech_root: /import/c4dm-datasets/VCTK-Corpus-0.92/wav48_silence_trimmed/ 56 | noise_root: /import/c4dm-datasets-ext/musdb18hq-voiceless-mix/ 57 | seq_duration: 6.0 58 | samples_per_track: 64 59 | least_overlap_ratio: 0.5 60 | mono: true 61 | snr_sampler: 62 | class_path: torch.distributions.Uniform 63 | init_args: 64 | low: 0 65 | high: 20 66 | batch_size: 8 67 | optimizer: 68 | class_path: torch.optim.Adam 69 | init_args: 70 | lr: 0.0003 -------------------------------------------------------------------------------- /cfg/xumx.yaml: -------------------------------------------------------------------------------- 1 | # pytorch_lightning==1.8.5.post0 2 | seed_everything: true 3 | trainer: 4 | callbacks: 5 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 6 | init_args: 7 | save_last: true 8 | every_n_train_steps: 2000 9 | filename: "{epoch}-{step}" 10 | - class_path: pytorch_lightning.callbacks.ModelSummary 11 | init_args: 12 | max_depth: 2 13 | logger: true 14 | enable_checkpointing: true 15 | callbacks: null 16 | default_root_dir: null 17 | gradient_clip_val: null 18 | gradient_clip_algorithm: null 19 | num_nodes: 1 20 | num_processes: null 21 | devices: null 22 | gpus: null 23 | auto_select_gpus: false 24 | tpu_cores: null 25 | ipus: null 26 | enable_progress_bar: true 27 | overfit_batches: 0.0 28 | track_grad_norm: -1 29 | check_val_every_n_epoch: 1 30 | fast_dev_run: false 31 | accumulate_grad_batches: null 32 | max_epochs: null 33 | min_epochs: null 34 | max_steps: -1 35 | min_steps: null 36 | max_time: null 37 | limit_train_batches: null 38 | limit_val_batches: null 39 | limit_test_batches: null 40 | limit_predict_batches: null 41 | val_check_interval: null 42 | log_every_n_steps: 1 43 | accelerator: gpu 44 | sync_batchnorm: true 45 | precision: 32 46 | enable_model_summary: true 47 | num_sanity_val_steps: 2 48 | resume_from_checkpoint: null 49 | profiler: null 50 | benchmark: null 51 | deterministic: null 52 | reload_dataloaders_every_n_epochs: 0 53 | auto_lr_find: false 54 | replace_sampler_ddp: true 55 | detect_anomaly: false 56 | auto_scale_batch_size: false 57 | plugins: null 58 | amp_backend: native 59 | amp_level: null 60 | move_metrics_to_cpu: false 61 | multiple_trainloader_mode: max_size_cycle 62 | inference_mode: true 63 | ckpt_path: null 64 | model: 65 | class_path: aimless.lightning.freq_mask.MaskPredictor 66 | init_args: 67 | model: 68 | class_path: aimless.models.xumx.X_UMX 69 | init_args: 70 | n_fft: 4096 71 | hidden_channels: 512 72 | max_bins: 1487 73 | nb_channels: 2 74 | nb_layers: 3 75 | criterion: 76 | class_path: aimless.loss.freq.CLoss 77 | init_args: 78 | mcoeff: 10 79 | n_fft: 4096 80 | hop_length: 1024 81 | n_iter: 1 82 | transforms: 83 | - class_path: aimless.augment.SpeedPerturb 84 | init_args: 85 | orig_freq: 44100 86 | speeds: 87 | - 90 88 | - 100 89 | - 110 90 | p: 0.2 91 | - class_path: aimless.augment.RandomPitch 92 | init_args: 93 | semitones: 94 | - -1 95 | - 1 96 | - 0 97 | - 1 98 | - 2 99 | p: 0.2 100 | targets: {vocals, drums, bass, other} 101 | n_fft: 4096 102 | hop_length: 1024 103 | residual_model: false 104 | softmask: false 105 | alpha: 1.0 106 | n_iter: 1 107 | data: 108 | class_path: data.lightning.MUSDB 109 | init_args: 110 | root: /import/c4dm-datasets-ext/musdb18hq/ 111 | seq_duration: 6.0 112 | samples_per_track: 64 113 | random: true 
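      # `random: true` draws `samples_per_track` random chunks per track each epoch
      # instead of iterating fixed non-overlapping segments; `random_track_mix`
      # (next key) additionally loads each stem from a different random track and
      # sums the stems to form the training mixture (see data/dataset/base.py).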
114 | random_track_mix: true 115 | transforms: 116 | - class_path: data.augment.RandomGain 117 | - class_path: data.augment.RandomFlipPhase 118 | - class_path: data.augment.RandomSwapLR 119 | - class_path: data.augment.LimitAug 120 | init_args: 121 | sample_rate: 44100 122 | batch_size: 4 123 | optimizer: 124 | class_path: torch.optim.Adam 125 | init_args: 126 | lr: 0.0001 127 | weight_decay: 0.00001 128 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/data/__init__.py -------------------------------------------------------------------------------- /data/augment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torchaudio 4 | import numpy as np 5 | import scipy.stats 6 | import scipy.signal 7 | import pyloudnorm as pyln 8 | from typing import List, Tuple 9 | 10 | from pedalboard import ( 11 | Pedalboard, 12 | Gain, 13 | Chorus, 14 | Reverb, 15 | Compressor, 16 | Phaser, 17 | Delay, 18 | Distortion, 19 | Limiter, 20 | ) 21 | 22 | 23 | __all__ = [ 24 | "RandomSwapLR", 25 | "RandomGain", 26 | "RandomFlipPhase", 27 | "LimitAug", 28 | "CPUBase", 29 | "RandomParametricEQ", 30 | "RandomStereoWidener", 31 | "RandomVolumeAutomation", 32 | "RandomPedalboardCompressor", 33 | "RandomPedalboardDelay", 34 | "RandomPedalboardChorus", 35 | "RandomPedalboardDistortion", 36 | "RandomPedalboardPhaser", 37 | "RandomPedalboardReverb", 38 | "RandomPedalboardLimiter", 39 | "RandomSoxReverb", 40 | "LoudnessNormalize", 41 | "RandomPan", 42 | "Mono2Stereo", 43 | ] 44 | 45 | 46 | class CPUBase(object): 47 | def __init__(self, p: float = 1.0): 48 | assert 0 <= p <= 1, "invalid probability value" 49 | self.p = p 50 | 51 | def __call__( 52 | self, x: Tuple[torch.Tensor, torch.Tensor] 53 | ) -> Tuple[torch.Tensor, torch.Tensor]: 54 | """ 55 | Args: 56 | x (Tuple[torch.Tensor, torch.Tensor]): (mixture, stems) where 57 | mixture : (Num_channels, L) 58 | stems: (Num_sources, Num_channels, L) 59 | Return: 60 | Tuple[torch.Tensor, torch.Tensor]: (mixture, stems) where 61 | mixture: (Num_channels, L) 62 | stems: (Num_sources, Num_channels, L) 63 | """ 64 | mixture, stems = x 65 | stems_fx = [] 66 | for stem_idx in np.arange(stems.shape[0]): 67 | stem = stems[stem_idx, ...] 
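            # Each stem is transformed independently with probability p; the
            # mixture is re-summed from the (possibly processed) stems after the
            # loop, so the mixture always stays consistent with the stems.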
68 | if np.random.rand() < self.p: 69 | stems_fx.append(self._transform(stem)) 70 | else: 71 | stems_fx.append(stem) 72 | stems = torch.stack(stems_fx, dim=0) 73 | mixture = stems.sum(0) 74 | return (mixture, stems) 75 | 76 | def _transform(self, stems: torch.Tensor) -> torch.Tensor: 77 | raise NotImplementedError 78 | 79 | 80 | class RandomSwapLR(CPUBase): 81 | def __init__(self, p=0.5) -> None: 82 | super().__init__(p=p) 83 | 84 | def _transform(self, stems: torch.Tensor): 85 | return torch.flip(stems, [0]) 86 | 87 | 88 | class RandomGain(CPUBase): 89 | def __init__(self, low=0.25, high=1.25, **kwargs) -> None: 90 | super().__init__(**kwargs) 91 | self.low = low 92 | self.high = high 93 | 94 | def _transform(self, stems): 95 | gain = np.random.rand() * (self.high - self.low) + self.low 96 | return stems * gain 97 | 98 | 99 | class RandomFlipPhase(RandomSwapLR): 100 | def __init__(self, p=0.5) -> None: 101 | super().__init__(p=p) 102 | 103 | def _transform(self, stems: torch.Tensor): 104 | return -stems 105 | 106 | 107 | def db2linear(x): 108 | return 10 ** (x / 20) 109 | 110 | 111 | class LimitAug(CPUBase): 112 | def __init__( 113 | self, 114 | target_lufs_mean=-10.887, 115 | target_lufs_std=1.191, 116 | target_loudnorm_lufs=-14.0, 117 | max_release_ms=200.0, 118 | min_release_ms=30.0, 119 | sample_rate=44100, 120 | **kwargs, 121 | ) -> None: 122 | """ 123 | Args: 124 | target_lufs_mean (float): mean of target LUFS. default: -10.887 (corresponding to the statistics of musdb-L) 125 | target_lufs_std (float): std of target LUFS. default: 1.191 (corresponding to the statistics of musdb-L) 126 | target_loudnorm_lufs (float): target LUFS after loudnorm. default: -14.0 127 | max_release_ms (float): max release time of limiter. default: 200.0 128 | min_release_ms (float): min release time of limiter. default: 30.0 129 | sample_rate (int): sample rate of audio. 
default: 44100 130 | """ 131 | super().__init__(**kwargs) 132 | self.target_lufs_sampler = torch.distributions.Normal( 133 | target_lufs_mean, target_lufs_std 134 | ) 135 | self.target_loudnorm_lufs = target_loudnorm_lufs 136 | self.sample_rate = sample_rate 137 | self.board = Pedalboard([Gain(0), Limiter(threshold_db=0.0, release_ms=100.0)]) 138 | self.limiter_release_sampler = torch.distributions.Uniform( 139 | min_release_ms, max_release_ms 140 | ) 141 | self.meter = pyln.Meter(sample_rate) 142 | 143 | def __call__( 144 | self, x: Tuple[torch.Tensor, torch.Tensor] 145 | ) -> Tuple[torch.Tensor, torch.Tensor]: 146 | if np.random.rand() > self.p: 147 | return x 148 | mixture, stems = x 149 | mixture_np = mixture.numpy().T 150 | loudness = self.meter.integrated_loudness(mixture_np) 151 | target_lufs = self.target_lufs_sampler.sample().item() 152 | self.board[1].release_ms = self.limiter_release_sampler.sample().item() 153 | 154 | if np.isinf(loudness): 155 | aug_gain = 0.0 156 | else: 157 | aug_gain = target_lufs - loudness 158 | self.board[0].gain_db = aug_gain 159 | 160 | new_mixture_np = self.board(mixture_np, self.sample_rate) 161 | after_loudness = self.meter.integrated_loudness(new_mixture_np) 162 | 163 | if not np.isinf(after_loudness): 164 | target_gain = self.target_loudnorm_lufs - after_loudness 165 | new_mixture_np *= db2linear(target_gain) 166 | 167 | new_mixture = torch.tensor(new_mixture_np.T, dtype=mixture.dtype) 168 | # apply element-wise gain to stems 169 | stems *= new_mixture.abs() / mixture.abs().add(1e-8) 170 | return (new_mixture, stems) 171 | 172 | 173 | def loguniform(low=0, high=1): 174 | return scipy.stats.loguniform.rvs(low, high) 175 | 176 | 177 | def rand(low=0, high=1): 178 | return (torch.rand(1).numpy()[0] * (high - low)) + low 179 | 180 | 181 | def randint(low=0, high=1): 182 | return torch.randint(low, high + 1, (1,)).numpy()[0] 183 | 184 | 185 | def biqaud( 186 | gain_db: float, 187 | cutoff_freq: float, 188 | q_factor: float, 189 | sample_rate: float, 190 | filter_type: str, 191 | ): 192 | """Use design parameters to generate coeffieicnets for a specific filter type. 193 | Args: 194 | gain_db (float): Shelving filter gain in dB. 195 | cutoff_freq (float): Cutoff frequency in Hz. 196 | q_factor (float): Q factor. 197 | sample_rate (float): Sample rate in Hz. 198 | filter_type (str): Filter type. 
199 | One of "low_shelf", "high_shelf", or "peaking" 200 | Returns: 201 | b (np.ndarray): Numerator filter coefficients stored as [b0, b1, b2] 202 | a (np.ndarray): Denominator filter coefficients stored as [a0, a1, a2] 203 | """ 204 | 205 | A = 10 ** (gain_db / 40.0) 206 | w0 = 2.0 * np.pi * (cutoff_freq / sample_rate) 207 | alpha = np.sin(w0) / (2.0 * q_factor) 208 | 209 | cos_w0 = np.cos(w0) 210 | sqrt_A = np.sqrt(A) 211 | 212 | if filter_type == "high_shelf": 213 | b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha) 214 | b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0) 215 | b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha) 216 | a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha 217 | a1 = 2 * ((A - 1) - (A + 1) * cos_w0) 218 | a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha 219 | elif filter_type == "low_shelf": 220 | b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha) 221 | b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0) 222 | b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha) 223 | a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha 224 | a1 = -2 * ((A - 1) + (A + 1) * cos_w0) 225 | a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha 226 | elif filter_type == "peaking": 227 | b0 = 1 + alpha * A 228 | b1 = -2 * cos_w0 229 | b2 = 1 - alpha * A 230 | a0 = 1 + alpha / A 231 | a1 = -2 * cos_w0 232 | a2 = 1 - alpha / A 233 | else: 234 | pass 235 | # raise ValueError(f"Invalid filter_type: {filter_type}.") 236 | 237 | b = np.array([b0, b1, b2]) / a0 238 | a = np.array([a0, a1, a2]) / a0 239 | 240 | return b, a 241 | 242 | 243 | def parametric_eq( 244 | x: np.ndarray, 245 | sample_rate: float, 246 | low_shelf_gain_db: float = 0.0, 247 | low_shelf_cutoff_freq: float = 80.0, 248 | low_shelf_q_factor: float = 0.707, 249 | band_gains_db: List[float] = [0.0], 250 | band_cutoff_freqs: List[float] = [300.0], 251 | band_q_factors: List[float] = [0.707], 252 | high_shelf_gain_db: float = 0.0, 253 | high_shelf_cutoff_freq: float = 1000.0, 254 | high_shelf_q_factor: float = 0.707, 255 | dtype=np.float32, 256 | ): 257 | """Multiband parametric EQ. 258 | 259 | Low-shelf -> Band 1 -> ... 
-> Band N -> High-shelf 260 | 261 | Args: 262 | 263 | """ 264 | assert ( 265 | len(band_gains_db) == len(band_cutoff_freqs) == len(band_q_factors) 266 | ) # must define for all bands 267 | 268 | # -------- apply low-shelf filter -------- 269 | b, a = biqaud( 270 | low_shelf_gain_db, 271 | low_shelf_cutoff_freq, 272 | low_shelf_q_factor, 273 | sample_rate, 274 | "low_shelf", 275 | ) 276 | x = scipy.signal.lfilter(b, a, x) 277 | 278 | # -------- apply peaking filters -------- 279 | for gain_db, cutoff_freq, q_factor in zip( 280 | band_gains_db, band_cutoff_freqs, band_q_factors 281 | ): 282 | b, a = biqaud( 283 | gain_db, 284 | cutoff_freq, 285 | q_factor, 286 | sample_rate, 287 | "peaking", 288 | ) 289 | x = scipy.signal.lfilter(b, a, x) 290 | 291 | # -------- apply high-shelf filter -------- 292 | b, a = biqaud( 293 | high_shelf_gain_db, 294 | high_shelf_cutoff_freq, 295 | high_shelf_q_factor, 296 | sample_rate, 297 | "high_shelf", 298 | ) 299 | sos5 = np.concatenate((b, a)) 300 | x = scipy.signal.lfilter(b, a, x) 301 | 302 | return x.astype(dtype) 303 | 304 | 305 | class RandomParametricEQ(CPUBase): 306 | def __init__( 307 | self, 308 | sample_rate: float = 44100.0, 309 | num_bands: int = 3, 310 | min_gain_db: float = -6.0, 311 | max_gain_db: float = +6.0, 312 | min_cutoff_freq: float = 1000.0, 313 | max_cutoff_freq: float = 10000.0, 314 | min_q_factor: float = 0.1, 315 | max_q_factor: float = 4.0, 316 | **kwargs, 317 | ): 318 | super().__init__(**kwargs) 319 | self.sample_rate = sample_rate 320 | self.num_bands = num_bands 321 | self.min_gain_db = min_gain_db 322 | self.max_gain_db = max_gain_db 323 | self.min_cutoff_freq = min_cutoff_freq 324 | self.max_cutoff_freq = max_cutoff_freq 325 | self.min_q_factor = min_q_factor 326 | self.max_q_factor = max_q_factor 327 | 328 | def _transform(self, x: torch.Tensor): 329 | """ 330 | Args: 331 | x: (torch.Tensor): Array of audio samples with shape (chs, seq_leq). 332 | The filter will be applied the final dimension, and by default the same 333 | filter will be applied to all channels. 334 | """ 335 | low_shelf_gain_db = rand(self.min_gain_db, self.max_gain_db) 336 | low_shelf_cutoff_freq = loguniform(20.0, 200.0) 337 | low_shelf_q_factor = rand(self.min_q_factor, self.max_q_factor) 338 | 339 | high_shelf_gain_db = rand(self.min_gain_db, self.max_gain_db) 340 | high_shelf_cutoff_freq = loguniform(8000.0, 16000.0) 341 | high_shelf_q_factor = rand(self.min_q_factor, self.max_q_factor) 342 | 343 | band_gain_dbs = [] 344 | band_cutoff_freqs = [] 345 | band_q_factors = [] 346 | for _ in range(self.num_bands): 347 | band_gain_dbs.append(rand(self.min_gain_db, self.max_gain_db)) 348 | band_cutoff_freqs.append( 349 | loguniform(self.min_cutoff_freq, self.max_cutoff_freq) 350 | ) 351 | band_q_factors.append(rand(self.min_q_factor, self.max_q_factor)) 352 | 353 | y = parametric_eq( 354 | x.numpy(), 355 | self.sample_rate, 356 | low_shelf_gain_db=low_shelf_gain_db, 357 | low_shelf_cutoff_freq=low_shelf_cutoff_freq, 358 | low_shelf_q_factor=low_shelf_q_factor, 359 | band_gains_db=band_gain_dbs, 360 | band_cutoff_freqs=band_cutoff_freqs, 361 | band_q_factors=band_q_factors, 362 | high_shelf_gain_db=high_shelf_gain_db, 363 | high_shelf_cutoff_freq=high_shelf_cutoff_freq, 364 | high_shelf_q_factor=high_shelf_q_factor, 365 | ) 366 | 367 | return torch.from_numpy(y) 368 | 369 | 370 | def stereo_widener(x: torch.Tensor, width: torch.Tensor): 371 | sqrt2 = np.sqrt(2) 372 | 373 | left = x[0, ...] 374 | right = x[1, ...] 
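    # Classic mid/side widening: encode L/R into mid (sum) and side (difference),
    # scale them against each other with `width` (0 gives a mono-like image,
    # 0.5 leaves the signal unchanged, 1 keeps only the side component), then
    # decode back to L/R.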
375 | 376 | mid = (left + right) / sqrt2 377 | side = (left - right) / sqrt2 378 | 379 | # amplify mid and side signal seperately: 380 | mid *= 2 * (1 - width) 381 | side *= 2 * width 382 | 383 | left = (mid + side) / sqrt2 384 | right = (mid - side) / sqrt2 385 | 386 | x = torch.stack((left, right), dim=0) 387 | 388 | return x 389 | 390 | 391 | class RandomStereoWidener(CPUBase): 392 | def __init__( 393 | self, 394 | sample_rate: float = 44100.0, 395 | min_width: float = 0.0, 396 | max_width: float = 1.0, 397 | **kwargs, 398 | ) -> None: 399 | super().__init__(**kwargs) 400 | self.sample_rate = sample_rate 401 | self.min_width = min_width 402 | self.max_width = max_width 403 | 404 | def _transform(self, x: torch.Tensor): 405 | width = rand(self.min_width, self.max_width) 406 | return stereo_widener(x, width) 407 | 408 | 409 | class RandomVolumeAutomation(CPUBase): 410 | def __init__( 411 | self, 412 | sample_rate: float = 44100.0, 413 | min_segment_seconds: float = 3.0, 414 | min_gain_db: float = -6.0, 415 | max_gain_db: float = 6.0, 416 | **kwargs, 417 | ) -> None: 418 | super().__init__(**kwargs) 419 | self.sample_rate = sample_rate 420 | self.min_segment_seconds = min_segment_seconds 421 | self.min_gain_db = min_gain_db 422 | self.max_gain_db = max_gain_db 423 | 424 | def _transform(self, x: torch.Tensor): 425 | gain_db = torch.zeros(x.shape[-1]).type_as(x) 426 | 427 | seconds = x.shape[-1] / self.sample_rate 428 | max_num_segments = max(1, int(seconds // self.min_segment_seconds)) 429 | 430 | num_segments = randint(1, max_num_segments) 431 | segment_lengths = ( 432 | x.shape[-1] 433 | * np.random.dirichlet( 434 | [rand(0, 10) for _ in range(num_segments)], 1 435 | ) # TODO(cm): this can crash training 436 | ).astype("int")[0] 437 | 438 | segment_lengths = np.maximum(segment_lengths, 1) 439 | 440 | # check the sum is equal to the length of the signal 441 | diff = segment_lengths.sum() - x.shape[-1] 442 | if diff < 0: 443 | segment_lengths[-1] -= diff 444 | elif diff > 0: 445 | for idx in range(num_segments): 446 | if segment_lengths[idx] > diff + 1: 447 | segment_lengths[idx] -= diff 448 | break 449 | 450 | samples_filled = 0 451 | start_gain_db = 0 452 | for idx in range(num_segments): 453 | segment_samples = segment_lengths[idx] 454 | if idx != 0: 455 | start_gain_db = end_gain_db 456 | 457 | # sample random end gain 458 | end_gain_db = rand(self.min_gain_db, self.max_gain_db) 459 | fade = torch.linspace(start_gain_db, end_gain_db, steps=segment_samples) 460 | gain_db[samples_filled : samples_filled + segment_samples] = fade 461 | samples_filled = samples_filled + segment_samples 462 | 463 | # print(gain_db) 464 | x *= 10 ** (gain_db / 20.0) 465 | return x 466 | 467 | 468 | class RandomPedalboardCompressor(CPUBase): 469 | def __init__( 470 | self, 471 | sample_rate: float = 44100.0, 472 | min_threshold_db: float = -42.0, 473 | max_threshold_db: float = -6.0, 474 | min_ratio: float = 1.5, 475 | max_ratio: float = 4.0, 476 | min_attack_ms: float = 1.0, 477 | max_attack_ms: float = 50.0, 478 | min_release_ms: float = 10.0, 479 | max_release_ms: float = 250.0, 480 | **kwargs, 481 | ) -> None: 482 | super().__init__(**kwargs) 483 | self.sample_rate = sample_rate 484 | self.min_threshold_db = min_threshold_db 485 | self.max_threshold_db = max_threshold_db 486 | self.min_ratio = min_ratio 487 | self.max_ratio = max_ratio 488 | self.min_attack_ms = min_attack_ms 489 | self.max_attack_ms = max_attack_ms 490 | self.min_release_ms = min_release_ms 491 | self.max_release_ms = max_release_ms 492 
| 493 | def _transform(self, x: torch.Tensor): 494 | board = Pedalboard() 495 | threshold_db = rand(self.min_threshold_db, self.max_threshold_db) 496 | ratio = rand(self.min_ratio, self.max_ratio) 497 | attack_ms = rand(self.min_attack_ms, self.max_attack_ms) 498 | release_ms = rand(self.min_release_ms, self.max_release_ms) 499 | 500 | board.append( 501 | Compressor( 502 | threshold_db=threshold_db, 503 | ratio=ratio, 504 | attack_ms=attack_ms, 505 | release_ms=release_ms, 506 | ) 507 | ) 508 | 509 | # process audio using the pedalboard 510 | return torch.from_numpy(board(x.numpy(), self.sample_rate)) 511 | 512 | 513 | class RandomPedalboardDelay(CPUBase): 514 | def __init__( 515 | self, 516 | sample_rate: float = 44100.0, 517 | min_delay_seconds: float = 0.1, 518 | max_delay_sconds: float = 1.0, 519 | min_feedback: float = 0.05, 520 | max_feedback: float = 0.6, 521 | min_mix: float = 0.0, 522 | max_mix: float = 0.7, 523 | **kwargs, 524 | ) -> None: 525 | super().__init__(**kwargs) 526 | self.sample_rate = sample_rate 527 | self.min_delay_seconds = min_delay_seconds 528 | self.max_delay_seconds = max_delay_sconds 529 | self.min_feedback = min_feedback 530 | self.max_feedback = max_feedback 531 | self.min_mix = min_mix 532 | self.max_mix = max_mix 533 | 534 | def _transform(self, x: torch.Tensor): 535 | board = Pedalboard() 536 | delay_seconds = loguniform(self.min_delay_seconds, self.max_delay_seconds) 537 | feedback = rand(self.min_feedback, self.max_feedback) 538 | mix = rand(self.min_mix, self.max_mix) 539 | board.append(Delay(delay_seconds=delay_seconds, feedback=feedback, mix=mix)) 540 | return torch.from_numpy(board(x.numpy(), self.sample_rate)) 541 | 542 | 543 | class RandomPedalboardChorus(CPUBase): 544 | def __init__( 545 | self, 546 | sample_rate: float = 44100.0, 547 | min_rate_hz: float = 0.25, 548 | max_rate_hz: float = 4.0, 549 | min_depth: float = 0.0, 550 | max_depth: float = 0.6, 551 | min_centre_delay_ms: float = 5.0, 552 | max_centre_delay_ms: float = 10.0, 553 | min_feedback: float = 0.1, 554 | max_feedback: float = 0.6, 555 | min_mix: float = 0.1, 556 | max_mix: float = 0.7, 557 | **kwargs, 558 | ) -> None: 559 | super().__init__(**kwargs) 560 | self.sample_rate = sample_rate 561 | self.min_rate_hz = min_rate_hz 562 | self.max_rate_hz = max_rate_hz 563 | self.min_depth = min_depth 564 | self.max_depth = max_depth 565 | self.min_centre_delay_ms = min_centre_delay_ms 566 | self.max_centre_delay_ms = max_centre_delay_ms 567 | self.min_feedback = min_feedback 568 | self.max_feedback = max_feedback 569 | self.min_mix = min_mix 570 | self.max_mix = max_mix 571 | 572 | def _transform(self, x: torch.Tensor): 573 | board = Pedalboard() 574 | rate_hz = rand(self.min_rate_hz, self.max_rate_hz) 575 | depth = rand(self.min_depth, self.max_depth) 576 | centre_delay_ms = rand(self.min_centre_delay_ms, self.max_centre_delay_ms) 577 | feedback = rand(self.min_feedback, self.max_feedback) 578 | mix = rand(self.min_mix, self.max_mix) 579 | board.append( 580 | Chorus( 581 | rate_hz=rate_hz, 582 | depth=depth, 583 | centre_delay_ms=centre_delay_ms, 584 | feedback=feedback, 585 | mix=mix, 586 | ) 587 | ) 588 | # process audio using the pedalboard 589 | return torch.from_numpy(board(x.numpy(), self.sample_rate)) 590 | 591 | 592 | class RandomPedalboardPhaser(CPUBase): 593 | def __init__( 594 | self, 595 | sample_rate: float = 44100.0, 596 | min_rate_hz: float = 0.25, 597 | max_rate_hz: float = 5.0, 598 | min_depth: float = 0.1, 599 | max_depth: float = 0.6, 600 | min_centre_frequency_hz: 
float = 200.0, 601 | max_centre_frequency_hz: float = 600.0, 602 | min_feedback: float = 0.1, 603 | max_feedback: float = 0.6, 604 | min_mix: float = 0.1, 605 | max_mix: float = 0.7, 606 | **kwargs, 607 | ) -> None: 608 | super().__init__(**kwargs) 609 | self.sample_rate = sample_rate 610 | self.min_rate_hz = min_rate_hz 611 | self.max_rate_hz = max_rate_hz 612 | self.min_depth = min_depth 613 | self.max_depth = max_depth 614 | self.min_centre_frequency_hz = min_centre_frequency_hz 615 | self.max_centre_frequency_hz = max_centre_frequency_hz 616 | self.min_feedback = min_feedback 617 | self.max_feedback = max_feedback 618 | self.min_mix = min_mix 619 | self.max_mix = max_mix 620 | 621 | def _transform(self, x: torch.Tensor): 622 | board = Pedalboard() 623 | rate_hz = rand(self.min_rate_hz, self.max_rate_hz) 624 | depth = rand(self.min_depth, self.max_depth) 625 | centre_frequency_hz = rand( 626 | self.min_centre_frequency_hz, self.min_centre_frequency_hz 627 | ) 628 | feedback = rand(self.min_feedback, self.max_feedback) 629 | mix = rand(self.min_mix, self.max_mix) 630 | board.append( 631 | Phaser( 632 | rate_hz=rate_hz, 633 | depth=depth, 634 | centre_frequency_hz=centre_frequency_hz, 635 | feedback=feedback, 636 | mix=mix, 637 | ) 638 | ) 639 | # process audio using the pedalboard 640 | return torch.from_numpy(board(x.numpy(), self.sample_rate)) 641 | 642 | 643 | class RandomPedalboardLimiter(CPUBase): 644 | def __init__( 645 | self, 646 | sample_rate: float = 44100.0, 647 | min_threshold_db: float = -32.0, 648 | max_threshold_db: float = -6.0, 649 | min_release_ms: float = 10.0, 650 | max_release_ms: float = 300.0, 651 | **kwargs, 652 | ) -> None: 653 | super().__init__(**kwargs) 654 | self.sample_rate = sample_rate 655 | self.min_threshold_db = min_threshold_db 656 | self.max_threshold_db = max_threshold_db 657 | self.min_release_ms = min_release_ms 658 | self.max_release_ms = max_release_ms 659 | 660 | def _transform(self, x: torch.Tensor): 661 | board = Pedalboard() 662 | threshold_db = rand(self.min_threshold_db, self.max_threshold_db) 663 | release_ms = rand(self.min_release_ms, self.max_release_ms) 664 | board.append( 665 | Limiter( 666 | threshold_db=threshold_db, 667 | release_ms=release_ms, 668 | ) 669 | ) 670 | return torch.from_numpy(board(x.numpy(), self.sample_rate)) 671 | 672 | 673 | class RandomPedalboardDistortion(CPUBase): 674 | def __init__( 675 | self, 676 | sample_rate: float = 44100.0, 677 | min_drive_db: float = -20.0, 678 | max_drive_db: float = 12.0, 679 | **kwargs, 680 | ): 681 | super().__init__(**kwargs) 682 | self.sample_rate = sample_rate 683 | self.min_drive_db = min_drive_db 684 | self.max_drive_db = max_drive_db 685 | 686 | def _transform(self, x: torch.Tensor): 687 | board = Pedalboard() 688 | drive_db = rand(self.min_drive_db, self.max_drive_db) 689 | board.append(Distortion(drive_db=drive_db)) 690 | return torch.from_numpy(board(x.numpy(), self.sample_rate)) 691 | 692 | 693 | class RandomSoxReverb(CPUBase): 694 | def __init__( 695 | self, 696 | sample_rate: float = 44100.0, 697 | min_reverberance: float = 10.0, 698 | max_reverberance: float = 100.0, 699 | min_high_freq_damping: float = 0.0, 700 | max_high_freq_damping: float = 100.0, 701 | min_wet_dry: float = 0.0, 702 | max_wet_dry: float = 1.0, 703 | min_room_scale: float = 5.0, 704 | max_room_scale: float = 100.0, 705 | min_stereo_depth: float = 20.0, 706 | max_stereo_depth: float = 100.0, 707 | min_pre_delay: float = 0.0, 708 | max_pre_delay: float = 100.0, 709 | **kwargs, 710 | ) -> None: 711 | 
super().__init__(**kwargs) 712 | self.sample_rate = sample_rate 713 | self.min_reverberance = min_reverberance 714 | self.max_reverberance = max_reverberance 715 | self.min_high_freq_damping = min_high_freq_damping 716 | self.max_high_freq_damping = max_high_freq_damping 717 | self.min_wet_dry = min_wet_dry 718 | self.max_wet_dry = max_wet_dry 719 | self.min_room_scale = min_room_scale 720 | self.max_room_scale = max_room_scale 721 | self.min_stereo_depth = min_stereo_depth 722 | self.max_stereo_depth = max_stereo_depth 723 | self.min_pre_delay = min_pre_delay 724 | self.max_pre_delay = max_pre_delay 725 | 726 | def _transform(self, x: torch.Tensor): 727 | reverberance = rand(self.min_reverberance, self.max_reverberance) 728 | high_freq_damping = rand(self.min_high_freq_damping, self.max_high_freq_damping) 729 | room_scale = rand(self.min_room_scale, self.max_room_scale) 730 | stereo_depth = rand(self.min_stereo_depth, self.max_stereo_depth) 731 | wet_dry = rand(self.min_wet_dry, self.max_wet_dry) 732 | pre_delay = rand(self.min_pre_delay, self.max_pre_delay) 733 | 734 | effects = [ 735 | [ 736 | "reverb", 737 | f"{reverberance}", 738 | f"{high_freq_damping}", 739 | f"{room_scale}", 740 | f"{stereo_depth}", 741 | f"{pre_delay}", 742 | "--wet-only", 743 | ] 744 | ] 745 | y, _ = torchaudio.sox_effects.apply_effects_tensor( 746 | x, self.sample_rate, effects, channels_first=True 747 | ) 748 | 749 | # manual wet/dry mix 750 | return (x * (1 - wet_dry)) + (y * wet_dry) 751 | 752 | 753 | class RandomPedalboardReverb(CPUBase): 754 | def __init__( 755 | self, 756 | sample_rate: float = 44100.0, 757 | min_room_size: float = 0.0, 758 | max_room_size: float = 1.0, 759 | min_damping: float = 0.0, 760 | max_damping: float = 1.0, 761 | min_wet_dry: float = 0.0, 762 | max_wet_dry: float = 0.7, 763 | min_width: float = 0.0, 764 | max_width: float = 1.0, 765 | **kwargs, 766 | ) -> None: 767 | super().__init__(**kwargs) 768 | self.sample_rate = sample_rate 769 | self.min_room_size = min_room_size 770 | self.max_room_size = max_room_size 771 | self.min_damping = min_damping 772 | self.max_damping = max_damping 773 | self.min_wet_dry = min_wet_dry 774 | self.max_wet_dry = max_wet_dry 775 | self.min_width = min_width 776 | self.max_width = max_width 777 | 778 | def _transform(self, x: torch.Tensor): 779 | board = Pedalboard() 780 | room_size = rand(self.min_room_size, self.max_room_size) 781 | damping = rand(self.min_damping, self.max_damping) 782 | wet_dry = rand(self.min_wet_dry, self.max_wet_dry) 783 | width = rand(self.min_width, self.max_width) 784 | 785 | board.append( 786 | Reverb( 787 | room_size=room_size, 788 | damping=damping, 789 | wet_level=wet_dry, 790 | dry_level=(1 - wet_dry), 791 | width=width, 792 | ) 793 | ) 794 | 795 | return torch.from_numpy(board(x.numpy(), self.sample_rate)) 796 | 797 | 798 | class LoudnessNormalize(CPUBase): 799 | def __init__( 800 | self, 801 | sample_rate: float = 44100.0, 802 | target_lufs_db: float = -32.0, 803 | **kwargs, 804 | ) -> None: 805 | super().__init__(**kwargs) 806 | self.sample_rate = sample_rate 807 | self.meter = pyln.Meter(sample_rate) 808 | self.target_lufs_db = target_lufs_db 809 | 810 | def _transform(self, x: torch.Tensor): 811 | x_lufs_db = self.meter.integrated_loudness(x.permute(1, 0).numpy()) 812 | delta_lufs_db = torch.tensor([self.target_lufs_db - x_lufs_db]).float() 813 | gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0) 814 | return gain_lin * x 815 | 816 | 817 | class Mono2Stereo(CPUBase): 818 | def __init__(self) -> None: 819 | 
super().__init__(p=1.0) 820 | 821 | def _transform(self, x: torch.Tensor): 822 | assert x.ndim == 2 and x.shape[0] == 1, "x must be mono" 823 | return torch.cat([x, x], dim=0) 824 | 825 | 826 | class RandomPan(CPUBase): 827 | def __init__( 828 | self, 829 | min_pan: float = -1.0, 830 | max_pan: float = 1.0, 831 | **kwargs, 832 | ) -> None: 833 | super().__init__(**kwargs) 834 | self.min_pan = min_pan 835 | self.max_pan = max_pan 836 | 837 | def _transform(self, x: torch.Tensor): 838 | """Constant power panning""" 839 | assert x.ndim == 2 and x.shape[0] == 2, "x must be stereo" 840 | theta = rand(self.min_pan, self.max_pan) * np.pi / 4 841 | x = x * 0.707 # normalize to prevent clipping 842 | left_x, right_x = x[0], x[1] 843 | cos_theta = np.cos(theta) 844 | sin_theta = np.sin(theta) 845 | left_x = left_x * (cos_theta - sin_theta) 846 | right_x = right_x * (cos_theta + sin_theta) 847 | return torch.stack([left_x, right_x], dim=0) 848 | -------------------------------------------------------------------------------- /data/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseDataset 2 | from .fast_musdb import FastMUSDB 3 | from .dnr import DnR 4 | from .label_noise_bleed import LabelNoiseBleed 5 | -------------------------------------------------------------------------------- /data/dataset/base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import random 4 | import numpy as np 5 | import torchaudio 6 | from typing import Optional, Callable, List, Tuple 7 | from pathlib import Path 8 | 9 | __all__ = ["BaseDataset"] 10 | 11 | 12 | class BaseDataset(Dataset): 13 | sr: int = 44100 14 | 15 | def __init__( 16 | self, 17 | tracks: List[Path], 18 | track_lengths: List[int], 19 | sources: List[str], 20 | mix_name: str = "mix", 21 | seq_duration: float = 6.0, 22 | samples_per_track: int = 64, 23 | random: bool = False, 24 | random_track_mix: bool = False, 25 | transform: Optional[Callable] = None, 26 | ): 27 | super().__init__() 28 | self.tracks = tracks 29 | self.track_lengths = track_lengths 30 | self.sources = sources 31 | self.mix_name = mix_name 32 | self.seq_duration = seq_duration 33 | self.samples_per_track = samples_per_track 34 | self.segment = int(self.seq_duration * self.sr) 35 | self.random = random 36 | self.random_track_mix = random_track_mix 37 | self.transform = transform 38 | 39 | if self.seq_duration <= 0: 40 | self._size = len(self.tracks) 41 | elif self.random: 42 | self._size = len(self.tracks) * self.samples_per_track 43 | else: 44 | chunks = [l // self.segment for l in self.track_lengths] 45 | cum_chunks = np.cumsum(chunks) 46 | self.cum_chunks = cum_chunks 47 | self._size = cum_chunks[-1] 48 | 49 | def load_tracks(self) -> Tuple[List[Path], List[int]]: 50 | # Implement in child class 51 | # Return list of tracks and list of track lengths 52 | raise NotImplementedError 53 | 54 | def __len__(self): 55 | return self._size 56 | 57 | def _get_random_track_idx(self): 58 | return random.randrange(len(self.tracks)) 59 | 60 | def _get_random_start(self, length): 61 | return random.randrange(length - self.segment + 1) 62 | 63 | def _get_track_from_chunk(self, index): 64 | track_idx = np.digitize(index, self.cum_chunks) 65 | if track_idx > 0: 66 | chunk_start = (index - self.cum_chunks[track_idx - 1]) * self.segment 67 | else: 68 | chunk_start = index * self.segment 69 | return self.tracks[track_idx], chunk_start 
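    # When `random` is False, the dataset is indexed by fixed, non-overlapping
    # chunks: `cum_chunks` holds cumulative chunk counts per track, np.digitize
    # maps a flat index to its track, and the remaining offset (in chunks) times
    # `segment` gives the frame offset. E.g. with per-track chunk counts [10, 7]
    # (cum_chunks = [10, 17]), index 12 resolves to track 1, frame offset
    # 2 * segment.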
70 | 71 | def __getitem__(self, index): 72 | stems = [] 73 | if self.seq_duration <= 0: 74 | folder_name = self.tracks[index] 75 | x = torchaudio.load( 76 | folder_name / f"{self.mix_name}.wav", 77 | )[0] 78 | for s in self.sources: 79 | source_name = folder_name / (s + ".wav") 80 | audio = torchaudio.load(source_name)[0] 81 | stems.append(audio) 82 | else: 83 | if self.random: 84 | track_idx = index // self.samples_per_track 85 | folder_name, chunk_start = self.tracks[ 86 | track_idx 87 | ], self._get_random_start(self.track_lengths[track_idx]) 88 | else: 89 | folder_name, chunk_start = self._get_track_from_chunk(index) 90 | for s in self.sources: 91 | if self.random_track_mix and self.random: 92 | track_idx = self._get_random_track_idx() 93 | folder_name, chunk_start = self.tracks[ 94 | track_idx 95 | ], self._get_random_start(self.track_lengths[track_idx]) 96 | source_name = folder_name / (s + ".wav") 97 | audio = torchaudio.load( 98 | source_name, 99 | num_frames=self.segment, 100 | frame_offset=chunk_start, 101 | )[0] 102 | stems.append(audio) 103 | if self.random_track_mix and self.random: 104 | x = sum(stems) 105 | else: 106 | x = torchaudio.load( 107 | folder_name / f"{self.mix_name}.wav", 108 | num_frames=self.segment, 109 | frame_offset=chunk_start, 110 | )[0] 111 | y = torch.stack(stems) 112 | if self.transform is not None: 113 | x, y = self.transform((x, y)) 114 | return x, y 115 | -------------------------------------------------------------------------------- /data/dataset/dnr.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | from tqdm import tqdm 3 | from data.dataset import BaseDataset 4 | from pathlib import Path 5 | import os 6 | 7 | from aimless.utils import SDX_SOURCES as SOURCES 8 | 9 | __all__ = ["DnR"] 10 | 11 | 12 | class DnR(BaseDataset): 13 | def __init__(self, root: str, split: str, **kwargs): 14 | tracks, track_lengths = load_tracks(root, split) 15 | super().__init__( 16 | **kwargs, 17 | tracks=tracks, 18 | track_lengths=track_lengths, 19 | sources=SOURCES, 20 | mix_name="mix", 21 | ) 22 | 23 | 24 | def load_tracks(root: str, split: str): 25 | root = Path(os.path.expanduser(root)) 26 | if split == "train": 27 | split_root = root / "tr" 28 | elif split == "valid": 29 | split_root = root / "cv" 30 | elif split == "test": 31 | split_root = root / "tt" 32 | else: 33 | raise ValueError("Invalid split: {}".format(split)) 34 | tracks = sorted([x for x in split_root.iterdir() if x.is_dir()]) 35 | for x in tracks: 36 | assert torchaudio.info(str(x / "mix.wav")).sample_rate == DnR.sr 37 | 38 | track_lengths = [ 39 | torchaudio.info(str(x / "mix.wav")).num_frames for x in tqdm(tracks) 40 | ] 41 | return tracks, track_lengths 42 | -------------------------------------------------------------------------------- /data/dataset/fast_musdb.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | from tqdm import tqdm 3 | from data.dataset import BaseDataset 4 | import musdb 5 | import os 6 | import yaml 7 | from pathlib import Path 8 | from typing import List 9 | 10 | from aimless.utils import MDX_SOURCES as SOURCES 11 | 12 | __all__ = ["FastMUSDB"] 13 | 14 | 15 | class FastMUSDB(BaseDataset): 16 | def __init__( 17 | self, 18 | root: str, 19 | subsets: List[str] = ["train", "test"], 20 | split: str = None, 21 | **kwargs 22 | ): 23 | tracks, track_lengths = load_tracks(root, subsets, split) 24 | super().__init__( 25 | **kwargs, 26 | tracks=tracks, 27 | 
track_lengths=track_lengths, 28 | sources=SOURCES, 29 | mix_name="mixture", 30 | ) 31 | 32 | 33 | def load_tracks(root, subsets=None, split=None): 34 | root = Path(os.path.expanduser(root)) 35 | setup_path = os.path.join(musdb.__path__[0], "configs", "mus.yaml") 36 | with open(setup_path, "r") as f: 37 | setup = yaml.safe_load(f) 38 | if subsets is not None: 39 | if isinstance(subsets, str): 40 | subsets = [subsets] 41 | else: 42 | subsets = ["train", "test"] 43 | 44 | if subsets != ["train"] and split is not None: 45 | raise RuntimeError("Subset has to be set to `train` when split is used") 46 | 47 | print("Gathering files ...") 48 | tracks = [] 49 | track_lengths = [] 50 | for subset in subsets: 51 | subset_folder = root / subset 52 | for _, folders, _ in tqdm(os.walk(subset_folder)): 53 | # parse pcm tracks and sort by name 54 | for track_name in sorted(folders): 55 | if subset == "train": 56 | if split == "train" and track_name in setup["validation_tracks"]: 57 | continue 58 | elif ( 59 | split == "valid" 60 | and track_name not in setup["validation_tracks"] 61 | ): 62 | continue 63 | 64 | track_folder = subset_folder / track_name 65 | # add track to list of tracks 66 | tracks.append(track_folder) 67 | meta = torchaudio.info(os.path.join(track_folder, "mixture.wav")) 68 | assert meta.sample_rate == FastMUSDB.sr 69 | 70 | track_lengths.append(meta.num_frames) 71 | 72 | return tracks, track_lengths 73 | -------------------------------------------------------------------------------- /data/dataset/label_noise_bleed.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | from tqdm import tqdm 3 | from data.dataset import BaseDataset 4 | import os 5 | from pathlib import Path 6 | import csv 7 | 8 | from aimless.utils import MDX_SOURCES as SOURCES 9 | 10 | __all__ = ["LabelNoiseBleed"] 11 | 12 | 13 | # Run scripts/dataset_split_and_mix.py first!
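# The split script writes <root>/train and <root>/valid, each track folder holding
# bass.wav, drums.wav, other.wav, vocals.wav and a summed mixture.wav; load_tracks()
# below asserts every mixture is at 44100 Hz (BaseDataset.sr). When clean_csv_path is
# given, only track IDs marked "Y" in the Clean column of that CSV
# (see data/lightning/label_noise.csv) are kept.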
14 | class LabelNoiseBleed(BaseDataset): 15 | def __init__(self, root: str, split: str, clean_csv_path: str = None, **kwargs): 16 | tracks, track_lengths = load_tracks(root, split, clean_csv_path) 17 | super().__init__( 18 | **kwargs, 19 | tracks=tracks, 20 | track_lengths=track_lengths, 21 | sources=SOURCES, 22 | mix_name="mixture", 23 | ) 24 | 25 | 26 | def load_tracks(root: str, split: str, clean_csv_path: str): 27 | root = Path(os.path.expanduser(root)) 28 | if split == "train": 29 | split_root = root / "train" 30 | elif split == "valid": 31 | split_root = root / "valid" 32 | elif split == "test": 33 | split_root = root / "test" 34 | else: 35 | raise ValueError("Invalid split: {}".format(split)) 36 | 37 | tracks = sorted([x for x in split_root.iterdir() if x.is_dir()]) 38 | if clean_csv_path is not None: 39 | with open(clean_csv_path, "r") as f: 40 | reader = csv.reader(f) 41 | clean_tracks = [x[0] for x in reader if x[1] == "Y"] 42 | tracks = [x for x in tracks if x.name in clean_tracks] 43 | 44 | for x in tracks: 45 | assert torchaudio.info(str(x / "mixture.wav")).sample_rate == LabelNoiseBleed.sr 46 | 47 | track_lengths = [ 48 | torchaudio.info(str(x / "mixture.wav")).num_frames for x in tqdm(tracks) 49 | ] 50 | return tracks, track_lengths 51 | -------------------------------------------------------------------------------- /data/dataset/speech.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | from tqdm import tqdm 3 | from torch.utils.data import Dataset 4 | from pathlib import Path 5 | import random 6 | import torch 7 | import soundfile as sf 8 | import numpy as np 9 | from resampy import resample 10 | from typing import Optional, Callable 11 | 12 | __all__ = ["SpeechNoise"] 13 | 14 | 15 | class SpeechNoise(Dataset): 16 | sr: int = 44100 17 | 18 | def __init__( 19 | self, 20 | speech_root: str, 21 | noise_root: str, 22 | seq_duration: float = 6.0, 23 | samples_per_track: int = 64, 24 | least_overlap_ratio: float = 0.5, 25 | snr_sampler: Optional[torch.distributions.Distribution] = None, 26 | mono: bool = True, 27 | transform: Optional[Callable] = None, 28 | ): 29 | super().__init__() 30 | speech_root = Path(speech_root) 31 | noise_root = Path(noise_root) 32 | speech_files = list(speech_root.glob("**/*.wav")) + list( 33 | speech_root.glob("**/*.flac") 34 | ) 35 | noise_files = list(Path(noise_root).glob("**/*.wav")) + list( 36 | Path(noise_root).glob("**/*.flac") 37 | ) 38 | 39 | speech_track_frames = [] 40 | speech_track_sr = [] 41 | for x in tqdm(speech_files): 42 | info = torchaudio.info(x) 43 | speech_track_frames.append(info.num_frames) 44 | speech_track_sr.append(info.sample_rate) 45 | 46 | noise_track_frames = [] 47 | noise_track_sr = [] 48 | for x in tqdm(noise_files): 49 | info = torchaudio.info(x) 50 | noise_track_frames.append(info.num_frames) 51 | noise_track_sr.append(info.sample_rate) 52 | 53 | self.speech_files = speech_files 54 | self.noise_files = noise_files 55 | self.speech_track_frames = speech_track_frames 56 | self.noise_track_frames = noise_track_frames 57 | self.speech_track_sr = speech_track_sr 58 | self.noise_track_sr = noise_track_sr 59 | 60 | self.seq_duration = seq_duration 61 | self.samples_per_track = samples_per_track 62 | self.segment = int(self.seq_duration * self.sr) 63 | self.least_overlap_ratio = least_overlap_ratio 64 | self.least_overlap_segment = int(least_overlap_ratio * self.segment) 65 | self.transform = transform 66 | self.snr_sampler = snr_sampler 67 | self.mono = mono 68 | 69 
| self._size = len(self.noise_files) * self.samples_per_track 70 | 71 | def __len__(self): 72 | return self._size 73 | 74 | def _get_random_track_idx(self): 75 | return random.randrange(len(self.tracks)) 76 | 77 | def _get_random_start(self, length): 78 | return random.randrange(length - self.segment + 1) 79 | 80 | def __getitem__(self, index): 81 | track_idx = index // self.samples_per_track 82 | noise_sr = self.noise_track_sr[track_idx] 83 | noise_resample_ratio = self.sr / noise_sr 84 | noise_file = self.noise_files[track_idx] 85 | pos_start = int( 86 | self._get_random_start( 87 | int(self.noise_track_frames[track_idx] * noise_resample_ratio) 88 | ) 89 | / noise_resample_ratio 90 | ) 91 | frames = int(self.seq_duration * noise_sr) 92 | noise, _ = sf.read( 93 | noise_file, start=pos_start, frames=frames, fill_value=0, always_2d=True 94 | ) 95 | if noise_sr != self.sr: 96 | noise = resample(noise, noise_sr, self.sr, axis=0) 97 | if noise.shape[0] < self.segment: 98 | noise = np.pad( 99 | noise, ((0, self.segment - noise.shape[0]), (0, 0)), "constant" 100 | ) 101 | else: 102 | noise = noise[: self.segment] 103 | 104 | if self.mono: 105 | noise = noise.mean(axis=1, keepdims=True) 106 | else: 107 | noise = np.broadcast_to(noise, (noise.shape[0], 2)) 108 | 109 | # get a random speech file 110 | speech_idx = random.randint(0, len(self.speech_files) - 1) 111 | speech_file = self.speech_files[speech_idx] 112 | speech_sr = self.speech_track_sr[speech_idx] 113 | speech_resample_ratio = self.sr / speech_sr 114 | speech_resampled_length = int( 115 | self.speech_track_frames[speech_idx] * speech_resample_ratio 116 | ) 117 | 118 | if speech_resampled_length < self.least_overlap_segment: 119 | speech, _ = sf.read(speech_file, always_2d=True) 120 | if speech_sr != self.sr: 121 | speech = resample(speech, speech_sr, self.sr, axis=0) 122 | speech_resampled_length = speech.shape[0] 123 | 124 | if self.mono: 125 | speech = speech.mean(axis=1, keepdims=True) 126 | else: 127 | speech = np.broadcast_to(speech, (speech.shape[0], 2)) 128 | 129 | speech_energy = np.sum(speech**2) 130 | pos = random.randint(0, self.segment - speech_resampled_length) 131 | 132 | padded_speech = np.zeros_like(noise) 133 | padded_speech[pos : pos + speech_resampled_length] = speech 134 | speech = padded_speech 135 | else: 136 | pos = random.randint( 137 | self.least_overlap_segment - speech_resampled_length, 138 | self.segment - self.least_overlap_segment, 139 | ) 140 | if pos < 0: 141 | pos_start = int(-pos / speech_resample_ratio) 142 | frames = int( 143 | min(self.segment, (speech_resampled_length + pos)) 144 | / speech_resample_ratio 145 | ) 146 | else: 147 | pos_start = 0 148 | frames = int( 149 | min(speech_resampled_length, self.segment - pos) 150 | / speech_resample_ratio 151 | ) 152 | 153 | speech, _ = sf.read( 154 | speech_file, 155 | start=pos_start, 156 | frames=frames, 157 | fill_value=0, 158 | always_2d=True, 159 | ) 160 | if speech_sr != self.sr: 161 | speech = resample(speech, speech_sr, self.sr, axis=0) 162 | 163 | if self.mono: 164 | speech = speech.mean(axis=1, keepdims=True) 165 | else: 166 | speech = np.broadcast_to(speech, (speech.shape[0], 2)) 167 | 168 | speech_energy = np.sum(speech**2) 169 | 170 | padded_speech = np.zeros_like(noise) 171 | pos = max(0, pos) 172 | padded_speech[pos : pos + speech.shape[0]] = speech 173 | speech = padded_speech 174 | 175 | speech = torch.from_numpy(speech.T).float() 176 | noise = torch.from_numpy(noise.T).float() 177 | 178 | if self.snr_sampler is not None: 179 | snr = 
self.snr_sampler.sample() 180 | # scale noise to have the desired SNR 181 | noise_energy = noise.pow(2).sum() + 1e-8 182 | noise = noise * torch.sqrt(speech_energy / noise_energy) * 10 ** (-snr / 10) 183 | 184 | stems = torch.stack([speech, noise], dim=0) 185 | mix = speech + noise 186 | 187 | if self.transform is not None: 188 | mix, stems = self.transform((mix, stems)) 189 | 190 | return mix, stems 191 | -------------------------------------------------------------------------------- /data/lightning/__init__.py: -------------------------------------------------------------------------------- 1 | from .musdb import MUSDB 2 | from .dnr import DnR 3 | from .bleed import Bleed 4 | from .label_noise import LabelNoise 5 | from .speech import SpeechNoise 6 | -------------------------------------------------------------------------------- /data/lightning/bleed.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Compose 2 | from torch.utils.data import DataLoader 3 | import pytorch_lightning as pl 4 | from typing import List 5 | 6 | 7 | from data.dataset import LabelNoiseBleed 8 | from data.augment import CPUBase 9 | 10 | 11 | class Bleed(pl.LightningDataModule): 12 | def __init__( 13 | self, 14 | root: str, 15 | seq_duration: float = 6.0, 16 | samples_per_track: int = 64, 17 | random: bool = False, 18 | random_track_mix: bool = False, 19 | transforms: List[CPUBase] = None, 20 | batch_size: int = 16, 21 | ): 22 | super().__init__() 23 | self.save_hyperparameters( 24 | "root", 25 | "seq_duration", 26 | "samples_per_track", 27 | "random", 28 | "random_track_mix", 29 | "batch_size", 30 | ) 31 | # manually save transforms since pedalboard is not pickleable 32 | if transforms is None: 33 | self.transforms = None 34 | else: 35 | self.transforms = Compose(transforms) 36 | 37 | def setup(self, stage=None): 38 | if stage == "fit": 39 | self.train_dataset = LabelNoiseBleed( 40 | root=self.hparams.root, 41 | split="train", 42 | seq_duration=self.hparams.seq_duration, 43 | samples_per_track=self.hparams.samples_per_track, 44 | random=self.hparams.random, 45 | random_track_mix=self.hparams.random_track_mix, 46 | transform=self.transforms, 47 | ) 48 | 49 | if stage == "validate" or stage == "fit": 50 | self.val_dataset = LabelNoiseBleed( 51 | root=self.hparams.root, 52 | split="valid", 53 | seq_duration=self.hparams.seq_duration, 54 | ) 55 | 56 | def train_dataloader(self): 57 | return DataLoader( 58 | self.train_dataset, 59 | batch_size=self.hparams.batch_size, 60 | num_workers=4, 61 | shuffle=True, 62 | drop_last=True, 63 | ) 64 | 65 | def val_dataloader(self): 66 | return DataLoader( 67 | self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4 68 | ) 69 | -------------------------------------------------------------------------------- /data/lightning/dnr.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Compose 2 | from torch.utils.data import DataLoader, ConcatDataset 3 | import pytorch_lightning as pl 4 | from typing import List 5 | 6 | from data.dataset import DnR as DnRDataset 7 | from data.augment import CPUBase 8 | 9 | 10 | class DnR(pl.LightningDataModule): 11 | def __init__( 12 | self, 13 | root: str, 14 | seq_duration: float = 6.0, 15 | samples_per_track: int = 64, 16 | random: bool = False, 17 | include_val: bool = False, 18 | random_track_mix: bool = False, 19 | transforms: List[CPUBase] = None, 20 | batch_size: int = 16, 21 | ): 22 | 
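# include_val=True makes setup() concatenate the DnR validation split onto the
# training dataset (see the ConcatDataset below); transforms are composed and applied
# to the training dataset only, so validation batches are un-augmented.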
super().__init__() 23 | self.save_hyperparameters( 24 | "root", 25 | "seq_duration", 26 | "samples_per_track", 27 | "random", 28 | "include_val", 29 | "random_track_mix", 30 | "batch_size", 31 | ) 32 | # manually save transforms since pedalboard is not pickleable 33 | if transforms is None: 34 | self.transforms = None 35 | else: 36 | self.transforms = Compose(transforms) 37 | 38 | def setup(self, stage=None): 39 | if stage == "fit": 40 | self.train_dataset = DnRDataset( 41 | root=self.hparams.root, 42 | split="train", 43 | seq_duration=self.hparams.seq_duration, 44 | samples_per_track=self.hparams.samples_per_track, 45 | random=self.hparams.random, 46 | random_track_mix=self.hparams.random_track_mix, 47 | transform=self.transforms, 48 | ) 49 | 50 | if self.hparams.include_val: 51 | self.train_dataset = ConcatDataset( 52 | [ 53 | self.train_dataset, 54 | DnRDataset( 55 | root=self.hparams.root, 56 | split="valid", 57 | seq_duration=self.hparams.seq_duration, 58 | samples_per_track=self.hparams.samples_per_track, 59 | random=self.hparams.random, 60 | random_track_mix=self.hparams.random_track_mix, 61 | transform=self.transforms, 62 | ), 63 | ] 64 | ) 65 | 66 | if stage == "validate" or stage == "fit": 67 | self.val_dataset = DnRDataset( 68 | root=self.hparams.root, 69 | split="valid", 70 | seq_duration=self.hparams.seq_duration, 71 | ) 72 | 73 | def train_dataloader(self): 74 | return DataLoader( 75 | self.train_dataset, 76 | batch_size=self.hparams.batch_size, 77 | num_workers=4, 78 | shuffle=True, 79 | drop_last=True, 80 | ) 81 | 82 | def val_dataloader(self): 83 | return DataLoader( 84 | self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4 85 | ) 86 | -------------------------------------------------------------------------------- /data/lightning/label_noise.csv: -------------------------------------------------------------------------------- 1 | ID,Clean,Notes 2 | 0a589d65-50a3-4999-8f16-b5b6199bceee,, 3 | 0d528a19-cb0f-4421-b250-444f9343e51c,, 4 | 0e0d57cd-8662-4091-86d4-ed3e35d04ef6,, 5 | 0f5fb60c-51d4-4618-871d-650c9e927b79,, 6 | 1a32e987-cae8-458f-9c26-d9a6abf5348d,, 7 | 1ade22ad-99bc-4954-b2b9-fb9e8fa06d41,Y, 8 | 1afe1b3b-3e2e-48d3-b859-f50e222cbaf4,, 9 | 1ee489e8-55ad-4d4a-8277-24051b85c02d,, 10 | 1f98fe4d-26c7-460f-9f68-33964bc4d8d3,Y, 11 | 1fc37390-1769-452d-9bea-19025be4c467,Y, 12 | 2a3fd99d-c86f-4275-9321-ad7146364503,Y, 13 | 2b4e8304-c92d-4347-b09e-cfb9b3e29bf2,, 14 | 2c020edb-5947-4fa7-afea-ebc592cea683,, 15 | 2dc237cd-1637-46f0-8f58-ca68dc6f6031,Y, 16 | 2e5d996d-43f3-4359-b7c5-afebe9997556,, 17 | 02ee37da-eea3-42b4-83bf-ab7f243afa13,, 18 | 3a047d1a-f56d-4bf4-9910-f4f77206e53d,, 19 | 3a6023c3-4cd8-46c0-a376-d353743648a9,, 20 | 3c3b5fdb-f15e-4ba4-884a-b083ce2426c6,, 21 | 3c557409-3a34-43c2-9159-5421bbad5ecb,, 22 | 3e41f238-7c48-4a42-ba70-5ee39824a844,, 23 | 3e656eec-84d4-4a45-b410-d3817d849f92,, 24 | 3e7985e5-408f-4cf8-92b9-b9f62f738dd3,Y, 25 | 3e389000-8fdc-4b63-b8b8-ab044273790d,, 26 | 3f5233cb-57fa-4772-b389-f295a6f416ae,, 27 | 4a896cde-57c6-4646-b610-1b0b654d0349,, 28 | 4b9f86f4-23e4-458b-839e-8a63b584bea3,, 29 | 4cbd6c36-87a2-4d50-86e3-52d39b98fad3,, 30 | 5a6df5c7-a58a-479e-bdfb-c5946c221933,Y,"20s of harmonica, some drum bleed in vox" 31 | 05b0ab77-2495-438d-8831-e3d81f96c16d,Y,"Bass sometimes doesn't sound like ""bass""" 32 | 05e7af85-9721-4b42-952a-ccd34feb6033,, 33 | 5f04798d-c7be-4b8a-90bd-1fcd9946e875,, 34 | 6b168ae6-9d8a-4dc2-9d27-898e6871bf8b,, 35 | 06bfc6e7-e5ac-4827-bf8e-ffaf1675872f,Y,Crash cymbal in other 36 | 
6c70d5e0-5972-444a-86f8-a558dbb92d92,, 37 | 6cd44645-ed19-4ecc-a57c-58d400005b29,, 38 | 6ce087b4-e571-4472-9be2-04b5340311c6,, 39 | 6ceda40a-88bc-4e98-87c3-dd5c91725d41,, 40 | 6e50565e-8179-4913-af54-a2a7f0dcab2f,, 41 | 7a9f3169-7bdf-483c-9634-8e5b097d50df,, 42 | 7ba734f0-547e-4142-a9b4-01a0c25b9d1f,, 43 | 7bfa233c-24ed-4c7a-9096-11e3aa00c55d,, 44 | 7dd515b0-e218-425d-b8bf-a75056237d6a,, 45 | 7e74ce11-1603-440b-8794-b6e665b917e1,, 46 | 7efcb55d-c0db-472f-b765-1739ad7536aa,, 47 | 07fb2df2-91d6-458d-9230-9638b4edac08,, 48 | 8a6c9c1f-4865-404f-a805-1949de36a33c,, 49 | 8b83ba75-5e35-48c4-b42f-5419da2e6301,Y, 50 | 8ce11544-9a6f-4f1e-ac2f-fc10343f15c8,, 51 | 8e2d0c5c-6764-4d74-a740-391d7931ffd7,, 52 | 8f36f17f-c033-4c1b-a793-80dce43d507b,, 53 | 9ac2612b-e25f-4d27-8d43-b957e7e5a74b,Y, 54 | 9c8a5c66-f6d8-4425-8671-6b7aa6a2663b,, 55 | 9ce23a79-20eb-431a-80d2-eda3260ef503,, 56 | 9eb8bc50-cffb-4c19-be0e-d27423e3e102,, 57 | 9f581867-9f63-4ec4-8a43-e62c3c4230a6,, 58 | 9fc0cff7-bb02-496f-9fa1-db67c52b1b4b,, 59 | 13f233aa-a2e5-4683-8533-2f1e344b55b4,, 60 | 014f3712-293b-42af-9f29-0ed1785be792,, 61 | 16fb6c39-3834-4cfc-abf4-abb4a8d4646c,, 62 | 22d265ef-ee2b-4aba-8d60-c3430295cd6d,Y, 63 | 22ea41a3-1766-4a76-8071-380b27f1869a,, 64 | 24f8a652-0168-4702-8725-42f9924e6729,, 65 | 30cfc60a-5a57-4000-a05e-65006c8f6f74,, 66 | 35a19148-49bf-451d-9a0e-5ab8e914c367,, 67 | 36ee7fc6-604c-4a75-b4f0-0a9ceef3b9cb,Y, 68 | 43ca388d-3e62-4df2-b72d-abf407a7aa5a,, 69 | 045dcfd1-e960-4332-80cc-fdacc4a7c6a7,, 70 | 046ab651-a333-46e1-9d27-ab14ee036c42,, 71 | 46bc5393-7753-44ae-913b-bd5fa8f33e98,, 72 | 49ed7cef-8ffb-4833-bc96-6df7b8ff5b43,, 73 | 58efddfb-04b3-4951-858e-e7dbcfccfc21,, 74 | 58fa04aa-426f-4e16-a112-29eb6e2f2d3e,, 75 | 63b68795-0076-476b-a917-dec9e89bf91e,Y, 76 | 72c6f013-11ea-4bf7-93b2-c6ef2c117718,, 77 | 78ef22ce-472f-4f82-8656-16df73b9465f,Y, 78 | 87a5da23-f17b-44da-accf-c04832f81a14,, 79 | 88b545e5-4d06-4d55-a306-1bd3a2915ee5,, 80 | 89c515c9-5e93-4cb4-9806-20432d2d074d,, 81 | 89f2c781-5c67-4508-a2d6-236744b8c197,, 82 | 93dda0e8-dd76-49e4-b08f-54a82387cdf6,, 83 | 94fafb2a-9f4e-4a01-bee9-998008f95f41,, 84 | 97b07e0e-274e-4212-a66b-44210a48724d,Y,Strange sfx in Vox @ 0:30 85 | 125fc63d-9b69-4170-a46a-42c91bc28446,, 86 | 152d4f5d-4093-4fa4-a4a4-8a9b3502d89d,, 87 | 174a115f-3688-45dc-8c39-9d05f21758e1,, 88 | 0177be35-64de-469b-908b-2d9edb49c053,, 89 | 212bb137-fd01-465e-80f3-a890fb0ebcdd,, 90 | 260c431b-72d4-4ac6-bae5-ee49ce5c0fe4,, 91 | 312bec8d-1c61-43e0-924a-1fb87ddc3e41,Y, 92 | 322f4d9d-b0c9-4ab3-9e30-544a25331ffd,, 93 | 334ed0e5-1761-4ee9-9d39-a555eb9b64a0,, 94 | 0358fd1e-244a-4422-9a42-29b5d68f6e4b,Y,No Bass track 95 | 390b4fea-92be-4fe3-9576-86529f80b4ae,, 96 | 664cc931-4e00-441e-9a78-e7a292515cea,, 97 | 704f1de9-1d02-4c2b-af05-107a7700a51d,, 98 | 747d5c98-665b-4470-a696-7a6cf6968ef1,, 99 | 765d5131-afd9-4ad7-8786-8bef5705c1c2,Y, 100 | 825f697d-5fcd-4429-ba33-23163c726ca7,, 101 | 1921a83e-0373-4bf7-8dc9-6cc9401c9309,, 102 | 2973adc9-6bf8-4422-9f7a-6fd0038eb565,Y, 103 | 4380ad97-2620-4419-a011-ddfa29a87f54,, 104 | 4999a0bf-a753-4e0e-85b1-690259dabf96,, 105 | 6031b120-f6e2-4999-96ba-a1e31be68ea8,, 106 | 06114cc2-e34d-4f8f-82d3-cf6981572f2f,, 107 | 6681f493-c996-424a-9bdb-c671912ea9db,, 108 | 7180bffa-dd48-49b9-bc02-f6e3f7f165b0,, 109 | 8042b88a-6179-406b-9ec4-b45a4cdd4a71,, 110 | 8804c154-6294-481a-ad63-bc61162cae2f,, 111 | 28748b6e-6125-42f7-998d-2ad734e39b6c,, 112 | 49478a32-483f-48d2-a594-d272b44bf587,Y, 113 | 53808b95-cfe9-461d-a113-ffadf32817a1,, 114 | 95378cf3-e939-42e0-b486-ebf2ca951664,, 115 | 
169628c5-266d-4e11-993b-440fa5fa2167,Y, 116 | 378742ba-5ba8-44c2-9cfd-8a609decca57,, 117 | 553048ce-7afd-4e0e-b4cb-4896620287a1,, 118 | 731893c6-67b9-42f0-aea6-d1f70c2c9870,, 119 | 737356b2-ce9c-448b-877b-e42b3ed94563,, 120 | 763641c7-488f-4959-a554-fdbce9582644,, 121 | 04204031-4f98-44ba-9c47-98c2f2e6b8fc,, 122 | 04798708-6915-4dbc-842e-d394d545d4eb,, 123 | 4857878a-e44b-4143-90e9-b65d0b704306,, 124 | 5640831d-7853-4d06-8166-988e2844b652,, 125 | 7524054e-dc67-47e0-8c26-ea1d4d70d2fb,Y, 126 | 8427760a-b82e-4136-8f12-dfd53cad9bc9,, 127 | 25789239-1075-43b9-bfc9-51dff4a29590,, 128 | a0b9a4e4-51f5-4c98-a090-0317fb891056,, 129 | a0eae9d2-d97f-401a-a495-1e1d1cb84a9c,Y, 130 | a1dcaeb2-f4e6-4818-b490-09f44d624afc,, 131 | a6bccb70-62b5-42aa-bfc9-3b0b886a2b2d,Y, 132 | a56d9450-3a26-485c-8ac3-24b6b54e2c1d,, 133 | a199697c-3cdc-47c7-a9c7-c1b07dd6c9dd,, 134 | aa600069-0a98-45e6-94fa-4000bfe46c25,, 135 | aaf0fc7b-d7f5-412c-b3b9-313a7c483666,, 136 | ad6c3742-e517-42eb-9dec-e74ed05387cc,, 137 | aefc1609-976b-423e-8516-f7d588d64ff7,, 138 | afca84b2-0277-4b1b-8696-5f14543f338c,Y, 139 | b8a79d39-346e-4258-a810-572b3b2c9ab1,, 140 | b8d6f3eb-f2d6-4342-af90-6d09f5257b6b,, 141 | b92cb1ca-baa9-4c74-b6dc-36389671ed76,, 142 | b207da3d-4baf-485a-98e1-657602479b3a,, 143 | b876b54b-6007-4d36-a6f4-efed8829d5fc,, 144 | bacbb01f-b877-4d62-8050-992f1d85543a,, 145 | baea951d-526a-49aa-8329-c8de676341fb,, 146 | bb45cf1a-4c58-4fe3-88ea-97ef27527507,, 147 | bbf40b5a-8ef9-4aec-a6c3-8b8706eb2ba0,, 148 | bc1f2967-f834-43bd-aadc-95afc897cfe7,, 149 | bc964128-da16-4e4c-af95-4d1211e78c70,, 150 | bd25c90e-d307-4cd9-adfb-46f5e323a81b,, 151 | bdcc429e-ed95-40d3-a1af-bad268d66b25,, 152 | bdd109ec-d5dd-4d91-92ad-66b679518026,, 153 | c6d73235-1dd5-4085-a3b3-50a3466c6168,, 154 | c8f42ad5-5a2f-4398-b9d9-207fd4fcc551,Y, 155 | c15ade79-43c0-4271-9f00-a121cefc92e5,Y, 156 | c70471f9-9c4a-41c9-b8f8-20ac38847a8e,Y, 157 | c228818e-eabe-434b-9d60-2fb84a6c5b2a,, 158 | c2330200-ad8e-4848-8c2b-b70612f4b80e,Y, 159 | c8752696-4ae5-4e47-b2ad-622496966fa9,, 160 | ca080447-fe99-4f6d-98c9-c69b68dacba3,, 161 | cc3e4991-6cce-40fe-a917-81a4fbb92ea6,, 162 | cc7f7675-d3c8-4a49-a2d7-a8959b694004,, 163 | cda46831-26d1-4dd8-ac50-6004d27d45dd,, 164 | cff5bcde-6a15-4c5b-b529-e1c528c46335,Y,Some slight vocal bleed in other 165 | d4df499c-e394-4753-b459-e167e6a58bad,, 166 | d4fe2408-c123-4739-93bb-22f558ae99d7,, 167 | d7d28204-a8ac-4c2b-bb3c-c941f4a00b85,Y, 168 | d8f0e410-5761-4d4a-9000-effe11089bbd,, 169 | d028d7c2-45c1-4846-b7df-4964238fd460,, 170 | d45bb3a6-eb80-44b3-b2ef-56cc9d5b4914,, 171 | d072debf-ea5b-4e8a-a447-cc1868cfc5fa,, 172 | d890ff35-300d-49f2-8054-49ab47262987,Y, 173 | d2401d3d-967c-46be-b9a0-3da571105158,, 174 | d624037a-1a76-4dd9-9e60-4ba380748a0b,, 175 | d4262245-3143-4c05-8423-6cbdc6253042,, 176 | dda2c057-6d73-43fa-a130-d7d6562c09ca,Y, 177 | dfb0e076-cb6b-4dcc-9934-c60070ff04d7,, 178 | e2ccbc17-44bf-431a-af2b-4cf2fbd19a72,, 179 | e2e4ce50-cd0d-4144-809d-4cf8c8e4912e,, 180 | e3ab3975-033b-40e2-b538-09396b3d4244,Y, 181 | e4de8632-6f69-4c63-8081-f4c2b77b40df,Y, 182 | e37cdb09-e648-4e9b-bc06-d178a964161c,Y, 183 | e62afdcd-0c96-4bee-80c7-1c17b897a6d7,, 184 | e78fa5de-cfdd-44fa-87c1-10a337b7011f,, 185 | e9336d31-c0df-4c91-be2b-7c4420c9cd34,Y, 186 | e1108928-9776-434c-bc57-c32dfdb7839c,, 187 | ea29ab4d-7f72-4331-b2a4-d3945c754211,, 188 | ea898682-08e7-4818-b516-8c0e10a4c20a,, 189 | ebea0f1d-8e23-469e-8eed-5269a9c684f0,, 190 | ed90a89a-bf22-444d-af3d-d9ac3896ebd2,, 191 | ee082817-dbda-4fbf-b5aa-8dce2320ae35,, 192 | ef1510e0-ba23-4b59-ba53-14181d73f213,, 
193 | f0c565c5-fc73-4da1-b979-0fac0167f671,, 194 | f4b735de-14b1-4091-a9ba-c8b30c0740a7,Y, 195 | f9e58f4d-e361-4598-9c9a-d0a83529cc68,Y, 196 | f40ffd10-4e8b-41e6-bd8a-971929ca9138,, 197 | f76e2c13-9a9a-4cac-b6dd-45b5111aac6d,Y, 198 | fa46f72c-696d-45bc-bcc5-2b3305800565,, 199 | faad432d-6ad0-492d-96f1-321eeb9685b5,, 200 | fac94d9a-59da-4f83-9027-3eafe082ad16,Y, 201 | fcd1937f-2b21-4a78-889a-7b7e63e0ebdd,, 202 | fd6e4b4a-f33a-4f3c-aa6e-7c65fd5dc0bc,, 203 | fe3ae408-d35f-4c17-aa33-402238725a9d,Y, 204 | ff486935-7ce2-4e23-8908-0ff5fcc50856,Y, -------------------------------------------------------------------------------- /data/lightning/label_noise.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Compose 2 | from torch.utils.data import DataLoader 3 | import pytorch_lightning as pl 4 | from typing import List 5 | import os 6 | 7 | from data.dataset import LabelNoiseBleed 8 | from data.augment import CPUBase 9 | 10 | 11 | class LabelNoise(pl.LightningDataModule): 12 | def __init__( 13 | self, 14 | root: str, 15 | seq_duration: float = 6.0, 16 | samples_per_track: int = 64, 17 | random: bool = False, 18 | random_track_mix: bool = False, 19 | transforms: List[CPUBase] = None, 20 | batch_size: int = 16, 21 | ): 22 | super().__init__() 23 | self.save_hyperparameters( 24 | "root", 25 | "seq_duration", 26 | "samples_per_track", 27 | "random", 28 | "random_track_mix", 29 | "batch_size", 30 | ) 31 | # manually save transforms since pedalboard is not pickleable 32 | if transforms is None: 33 | self.transforms = None 34 | else: 35 | self.transforms = Compose(transforms) 36 | 37 | def setup(self, stage=None): 38 | label_noise_path = None 39 | current_dir = os.path.dirname(os.path.realpath(__file__)) 40 | if "label_noise.csv" in os.listdir(current_dir): 41 | label_noise_path = os.path.join(current_dir, "label_noise.csv") 42 | assert label_noise_path is not None 43 | if stage == "fit": 44 | self.train_dataset = LabelNoiseBleed( 45 | root=self.hparams.root, 46 | split="train", 47 | seq_duration=self.hparams.seq_duration, 48 | samples_per_track=self.hparams.samples_per_track, 49 | random=self.hparams.random, 50 | random_track_mix=self.hparams.random_track_mix, 51 | transform=self.transforms, 52 | clean_csv_path=label_noise_path, 53 | ) 54 | 55 | if stage == "validate" or stage == "fit": 56 | self.val_dataset = LabelNoiseBleed( 57 | root=self.hparams.root, 58 | split="train", 59 | seq_duration=self.hparams.seq_duration, 60 | random=self.hparams.random, 61 | random_track_mix=self.hparams.random_track_mix, 62 | transform=self.transforms, 63 | clean_csv_path=label_noise_path, 64 | ) 65 | 66 | def train_dataloader(self): 67 | return DataLoader( 68 | self.train_dataset, 69 | batch_size=self.hparams.batch_size, 70 | num_workers=8, 71 | shuffle=True, 72 | drop_last=True, 73 | ) 74 | 75 | def val_dataloader(self): 76 | return DataLoader( 77 | self.val_dataset, 78 | batch_size=self.hparams.batch_size, 79 | num_workers=8, 80 | shuffle=False, 81 | drop_last=True, 82 | ) 83 | -------------------------------------------------------------------------------- /data/lightning/musdb.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Compose 2 | from torch.utils.data import DataLoader 3 | import pytorch_lightning as pl 4 | from typing import List 5 | 6 | 7 | from data.dataset import FastMUSDB 8 | from data.augment import CPUBase 9 | 10 | 11 | class MUSDB(pl.LightningDataModule): 12 
| def __init__( 13 | self, 14 | root: str, 15 | seq_duration: float = 6.0, 16 | samples_per_track: int = 64, 17 | random: bool = False, 18 | random_track_mix: bool = False, 19 | transforms: List[CPUBase] = None, 20 | batch_size: int = 16, 21 | ): 22 | super().__init__() 23 | self.save_hyperparameters( 24 | "root", 25 | "seq_duration", 26 | "samples_per_track", 27 | "random", 28 | "random_track_mix", 29 | "batch_size", 30 | ) 31 | # manually save transforms since pedalboard is not pickleable 32 | if transforms is None: 33 | self.transforms = None 34 | else: 35 | self.transforms = Compose(transforms) 36 | 37 | def setup(self, stage=None): 38 | if stage == "fit": 39 | self.train_dataset = FastMUSDB( 40 | root=self.hparams.root, 41 | subsets=["train"], 42 | seq_duration=self.hparams.seq_duration, 43 | samples_per_track=self.hparams.samples_per_track, 44 | random=self.hparams.random, 45 | random_track_mix=self.hparams.random_track_mix, 46 | transform=self.transforms, 47 | ) 48 | 49 | if stage == "validate" or stage == "fit": 50 | self.val_dataset = FastMUSDB( 51 | root=self.hparams.root, 52 | subsets=["test"], 53 | seq_duration=self.hparams.seq_duration, 54 | ) 55 | 56 | def train_dataloader(self): 57 | return DataLoader( 58 | self.train_dataset, 59 | batch_size=self.hparams.batch_size, 60 | num_workers=4, 61 | shuffle=True, 62 | drop_last=True, 63 | ) 64 | 65 | def val_dataloader(self): 66 | return DataLoader( 67 | self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4 68 | ) 69 | -------------------------------------------------------------------------------- /data/lightning/speech.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Compose 2 | from torch.utils.data import DataLoader 3 | import pytorch_lightning as pl 4 | from typing import List, Optional 5 | import torch 6 | 7 | from data.dataset.speech import SpeechNoise as SpeechNoiseDataset 8 | from data.augment import CPUBase 9 | 10 | 11 | class SpeechNoise(pl.LightningDataModule): 12 | def __init__( 13 | self, 14 | speech_root: str, 15 | noise_root: str, 16 | seq_duration: float = 6.0, 17 | samples_per_track: int = 64, 18 | least_overlap_ratio: float = 0.5, 19 | snr_sampler: Optional[torch.distributions.Distribution] = None, 20 | mono: bool = True, 21 | transforms: List[CPUBase] = None, 22 | batch_size: int = 16, 23 | ): 24 | super().__init__() 25 | self.save_hyperparameters( 26 | "speech_root", 27 | "noise_root", 28 | "seq_duration", 29 | "samples_per_track", 30 | "least_overlap_ratio", 31 | "snr_sampler", 32 | "mono", 33 | "batch_size", 34 | ) 35 | # manually save transforms since pedalboard is not pickleable 36 | if transforms is None: 37 | self.transforms = None 38 | else: 39 | self.transforms = Compose(transforms) 40 | 41 | def setup(self, stage=None): 42 | if stage == "fit": 43 | self.train_dataset = SpeechNoiseDataset( 44 | speech_root=self.hparams.speech_root, 45 | noise_root=self.hparams.noise_root, 46 | seq_duration=self.hparams.seq_duration, 47 | samples_per_track=self.hparams.samples_per_track, 48 | least_overlap_ratio=self.hparams.least_overlap_ratio, 49 | snr_sampler=self.hparams.snr_sampler, 50 | mono=self.hparams.mono, 51 | transform=self.transforms, 52 | ) 53 | 54 | def train_dataloader(self): 55 | return DataLoader( 56 | self.train_dataset, 57 | batch_size=self.hparams.batch_size, 58 | num_workers=4, 59 | shuffle=True, 60 | drop_last=True, 61 | ) 62 | -------------------------------------------------------------------------------- 
/docs/aimless-logo-crop.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: demixing 2 | channels: 3 | - defaults 4 | - pytorch 5 | - nvidia 6 | 7 | dependencies: 8 | - pip 9 | - numpy 10 | - pytorch 11 | - torchaudio 12 | - torchvision 13 | - cudatoolkit=11.7 14 | - pip: 15 | - pytorch-lightning[extra]==1.9.4 16 | - torch-optimizer 17 | - musdb 18 | - git+https://github.com/yoyololicon/norbert 19 | - pedalboard 20 | - audio-data-pytorch 21 | - pyloudnorm 22 | - torch_fftconv 23 | - streamlit 24 | - librosa 25 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning.cli import LightningCLI 3 | from pytorch_lightning.strategies import DDPStrategy 4 | 5 | 6 | def cli_main(): 7 | torch.set_float32_matmul_precision("medium") 8 | 9 | cli = LightningCLI( 10 | trainer_defaults={ 11 | "accelerator": "gpu", 12 | "strategy": DDPStrategy(find_unused_parameters=False), 13 | "log_every_n_steps": 1, 14 | } 15 | ) 16 | 17 | 18 | if __name__ == "__main__": 19 | cli_main() 20 | -------------------------------------------------------------------------------- /scripts/audio_utils.py: -------------------------------------------------------------------------------- 1 | import soundfile as sf 2 | import numpy as np 3 | import resampy 4 | from math import ceil 5 | import librosa 6 | 7 | from scipy.signal import stft 8 | import pyloudnorm as pyln 9 | import math 10 | import warnings 11 | 12 | # We silence this warning as we peak normalize the samples before bouncing them 13 | warnings.filterwarnings("ignore", message="Possible clipped samples in output.") 14 | 15 | 16 | def trim_relative_silence_from_audio(audio, sr, frame_duration=0.04, hop_duration=0.01): 17 | assert 0 < hop_duration <= frame_duration 18 | frame_length = int(frame_duration * sr) 19 | hop_length = int(hop_duration * sr) 20 | 21 | _, _, S = stft( 22 | audio, 23 | nfft=frame_length, 24 | noverlap=frame_length - hop_length, 25 | padded=True, 26 | nperseg=frame_length, 27 | boundary="constant", 28 | ) 29 | rms = librosa.feature.rms( 30 | S=S, frame_length=frame_length, hop_length=hop_length, pad_mode="constant" 31 | ).flatten() 32 | threshold = 0.01 * rms.max() 33 | active_flag = rms >= threshold 34 | active_idxs = active_flag.nonzero()[0] 35 | start_idx = max(int(max(active_idxs[0] - 1, 0) * hop_duration * sr), 0) 36 | end_idx = min( 37 | int(ceil(min(active_idxs[-1] + 1, rms.shape[0]) * hop_duration * sr)), 38 | audio.shape[0], 39 | ) 40 | 41 | return start_idx, end_idx 42 | 43 | 44 | def lufs_norm(data, sr, norm=-6): 45 | loudness = get_lufs(data, sr) 46 | 47 | assert not math.isinf(loudness) 48 | 49 | norm_data = pyln.normalize.loudness(data, loudness, norm) 50 | n, d = np.sum(np.array(norm_data)), np.sum(np.array(data)) 51 | gain = n / d if d else 0.0 52 | 53 | return norm_data, gain 54 | 55 | 56 | def get_lufs(data, sr): 57 | block_size = 0.4 if len(data) / sr >= 0.4 else len(data) / sr 58 | # measure the loudness first 59 | meter = pyln.Meter(rate=sr, block_size=block_size) # create BS.1770 meter 60 | loudness = meter.integrated_loudness(data) 61 | 62 | return loudness 63 | 64 | 65 | def peak_norm(data, mx): 66 | eps = 1e-10 67 | max_sample = 
np.max(np.abs(data)) 68 | scale_factor = mx / (max_sample + eps) 69 | 70 | return data * scale_factor 71 | 72 | 73 | def gain_to_db(g): 74 | return 20 * np.log10(g) 75 | 76 | 77 | def db_to_gain(db): 78 | return 10 ** (db / 20.0) 79 | 80 | 81 | def gain_from_combined_db_levels(dbs): 82 | return db_to_gain(sum(dbs)) 83 | 84 | 85 | def validate_audio(d): 86 | assert np.isnan(d).any() == False, "Nan value found in mixture" 87 | assert np.isneginf(d).any() == False, "Neg. Inf value found in mixture" 88 | assert np.isposinf(d).any() == False, "Pos. Inf value found in mixture" 89 | assert np.isinf(d).any() == False, "Inf value found in mixture" 90 | -------------------------------------------------------------------------------- /scripts/convert_to_dnr.py: -------------------------------------------------------------------------------- 1 | # Similar to dnr-utils (https://github.com/darius522/dnr-utils) 2 | 3 | from tqdm import tqdm 4 | import numpy as np 5 | from scipy import stats 6 | import audio_utils 7 | import soundfile as sf 8 | import os 9 | import pandas as pd 10 | import random 11 | import matplotlib.pyplot as plt 12 | from glob import glob 13 | import collections 14 | from itertools import repeat 15 | import argparse 16 | import librosa 17 | import math 18 | 19 | 20 | def apply_fadeout(audio, sr, duration=3.0): 21 | # convert to audio indices (samples) 22 | length = int(duration * sr) 23 | end = audio.shape[0] 24 | start = end - length 25 | 26 | # compute fade out curve 27 | # linear fade 28 | fade_curve = np.linspace(1.0, 0.0, length) 29 | 30 | # apply the curve 31 | audio[start:end] = audio[start:end] * fade_curve 32 | return audio 33 | 34 | 35 | def create_dir(dir, subdirs=None): 36 | os.makedirs(dir, exist_ok=True) 37 | 38 | if subdirs: 39 | for s in subdirs: 40 | os.makedirs(os.path.join(dir, s), exist_ok=True) 41 | 42 | return dir 43 | 44 | 45 | def make_unique_filename(files, dir=None): 46 | """ 47 | Given a list of files, generate a new - unique - filename 48 | """ 49 | while True: 50 | f = str(np.random.randint(low=100, high=100000)) 51 | if not f in files: 52 | break 53 | # Optionally create dir 54 | if dir: 55 | create_dir(os.path.join(dir, f)) 56 | return f, os.path.join(dir, f) 57 | 58 | 59 | def set_seed(seed): 60 | random.seed(seed) 61 | np.random.seed(seed) 62 | 63 | 64 | def aug_gain(config, audio, low=0.0, high=1.0): 65 | gain = np.random.uniform(low=low, high=high) 66 | return audio * gain 67 | 68 | 69 | def aug_reverb(config, audio, ir): 70 | pass 71 | 72 | 73 | def gen_poisson(mu): 74 | """Adapted from https://github.com/rmalouf/learning/blob/master/zt.py""" 75 | r = np.random.uniform(low=stats.poisson.pmf(0, mu)) 76 | return int(stats.poisson.ppf(r, mu)) 77 | 78 | 79 | def gen_norm(mu, sig): 80 | return np.random.normal(loc=mu, scale=sig) 81 | 82 | 83 | def gen_skewnorm(skew, mu, sig): 84 | # negative a means skewed left while positive means skewed right; a=0 -> normal 85 | return stats.skewnorm(a=skew, loc=mu, scale=sig).rvs() 86 | 87 | 88 | def get_some_noise(shape): 89 | return np.random.randn(shape).astype(np.float32) 90 | 91 | 92 | class MixtureObj: 93 | def __init__( 94 | self, 95 | seq_dur=60.0, 96 | sr=44100, 97 | files=None, 98 | mix_lufs=None, 99 | partition="train", 100 | peak_norm_db=None, 101 | annot_path=None, 102 | class_dirs=None, 103 | ): 104 | self.seq_dur = seq_dur 105 | self.sr = sr 106 | self.input_dirs = class_dirs 107 | self.annot_path = annot_path 108 | self.mix = np.zeros(int(self.seq_dur * self.sr)) 109 | self.mix_lufs = mix_lufs 110 | 
self.partition = partition 111 | self.allocated_windows = { 112 | "music": {"times": [], "samples": [], "levels": [], "files": []}, 113 | "sfx": {"times": [], "samples": [], "levels": [], "files": []}, 114 | "speech": {"times": [], "samples": [], "levels": [], "files": []}, 115 | } 116 | 117 | # Set some submix-specific values 118 | self.submixes = {"music": [], "sfx": [], "speech": []} 119 | 120 | self.mix_max_peak = 0.0 121 | self.peak_norm = audio_utils.db_to_gain(peak_norm_db) 122 | 123 | self.files = files 124 | 125 | self.num_seg_music = gen_poisson(7.0) 126 | self.num_seg_sfx = gen_poisson(12.0) 127 | self.num_seg_speech = gen_poisson(8.0) 128 | 129 | def check_overlap_for_range(self, r1, c): 130 | ranges = self.allocated_windows[c]["times"] 131 | for r2 in ranges: 132 | if not (np.sum(r1) < r2[0] or r1[0] > np.sum(r2)): 133 | return True 134 | return False 135 | 136 | def _check_for_files_exhaustion(self): 137 | if len(self.files["speech"]) < self.num_seg_speech: 138 | return True 139 | return False 140 | 141 | def _load_wav(self, file): 142 | d, sr = librosa.load(file, sr=self.sr) 143 | return d 144 | 145 | def _register_event( 146 | self, mix_pos, length, cl, file, clip_start, clip_end, clip_gain 147 | ): 148 | self.allocated_windows[cl]["files"].append(file) 149 | self.allocated_windows[cl]["times"].append([mix_pos, length]) 150 | self.allocated_windows[cl]["samples"].append([clip_start, clip_end]) 151 | self.allocated_windows[cl]["levels"].append(clip_gain) 152 | 153 | def _update_event_levels(self, gain): 154 | for c in self.allocated_windows.keys(): 155 | self.allocated_windows[c]["levels"] = [ 156 | gain * x for x in self.allocated_windows[c]["levels"] 157 | ] 158 | 159 | def _get_time_range_for_submix(self, data, c, idx, num_events): 160 | """ 161 | Compute time window for sound in submix 162 | Args: 163 | data: array sound [Nspl, Nch] 164 | c: the sound class [music, sfx, speech] 165 | idx: the sound index 166 | """ 167 | file_len = len(data) / self.sr 168 | 169 | # We don't want two mix of the same classto overlap 170 | min_start = ( 171 | 0.0 172 | if not len(self.allocated_windows[c]["times"]) 173 | else np.ceil(np.sum(self.allocated_windows[c]["times"][-1])) 174 | ) 175 | 176 | # Check if we have enough time remaining (at least 30% length of current file or for speech the ENTIRE LENGTH) 177 | # if c == "speech": 178 | # import pdb 179 | 180 | # pdb.set_trace() 181 | if (self.seq_dur - min_start <= file_len * 0.3) or ( 182 | c == "speech" and (self.seq_dur - min_start <= file_len) 183 | ): 184 | return None, None 185 | 186 | # Pick start based on previous event pos. 
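# The start time drawn below comes from a right-skewed normal (gen_skewnorm with
# skew=5) centred just after the previous event of the same class, so clips tend to
# start shortly after min_start with an occasional later start; end_lim clamps the
# value so a speech clip always fits in full and any other clip starts at least 2 s
# before the end of the mixture.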
187 | end_lim = self.seq_dur - file_len if c == "speech" else self.seq_dur - 2.0 188 | start = min(gen_skewnorm(skew=5, mu=int(min_start + 1), sig=2.0), end_lim) 189 | max_len = ( 190 | file_len if self.seq_dur - start >= file_len else (self.seq_dur - start) 191 | ) 192 | 193 | # If speech, take the whole thing 194 | if c == "speech": 195 | length = max_len 196 | else: 197 | length = max( 198 | min(gen_norm(mu=max_len / 2.0, sig=max_len / 10.0), max_len), 0.1 199 | ) 200 | 201 | return start, length 202 | 203 | def _create_from_annots(self, annot_path): 204 | """ 205 | Create the mixture strictly from a given annotation file 206 | """ 207 | columns = [ 208 | "file", 209 | "class", 210 | "mix_start_time", 211 | "mix_end_time", 212 | "clip_start_time", 213 | "clip_end_time", 214 | "clip_gain", 215 | "annotation", 216 | ] 217 | annots = pd.read_csv(annot_path, names=columns, skiprows=1) 218 | self.submixes = {k: np.zeros_like(self.mix) for k in self.submixes.keys()} 219 | 220 | for index, row in annots.iterrows(): 221 | f_class = row["class"] 222 | abs_path = os.path.join( 223 | self.input_dirs[f_class], "/".join(row["file"].split("/")[1:]) 224 | ) 225 | data, _ = sf.read(abs_path) 226 | 227 | mix_start = int(row["mix_start_time"] * self.sr) 228 | mix_end = int(row["mix_end_time"] * self.sr) 229 | clip_start = int(row["clip_start_time"] * self.sr) 230 | clip_end = int(row["clip_end_time"] * self.sr) 231 | 232 | clip_gain = float(row["clip_gain"]) 233 | 234 | length = min(len(data), clip_end - clip_start, mix_end - mix_start) 235 | 236 | clip = data[clip_start : clip_start + length] * clip_gain 237 | self.submixes[f_class][mix_start : mix_start + length] += clip 238 | 239 | # Finally collapse submixes 240 | self.mix = np.sum( 241 | np.stack([self.submixes[c] for c in self.submixes.keys()], -1), -1 242 | ) 243 | 244 | def _set_submix(self, c, num_seg): 245 | submix = np.zeros(int(self.seq_dur * self.sr)) 246 | class_lufs = np.random.uniform( 247 | self.mix_lufs[c] - self.mix_lufs["ranges"]["class"], 248 | self.mix_lufs[c] + self.mix_lufs["ranges"]["class"], 249 | ) 250 | 251 | for i in range(num_seg): 252 | # If eval and speech, mix without replacement until exhaustion 253 | if self.partition == "eval" and c == "speech": 254 | f = self.files[c].pop(random.randrange(len(self.files[c]))) 255 | else: 256 | f = np.random.choice(self.files[c]) 257 | 258 | d = self._load_wav(f) 259 | if c == "speech": 260 | # If speech is too long, trim to 70% and apply 3 second fade-out 261 | max_speech_len = math.floor(0.7 * self.seq_dur * self.sr - 1) 262 | if len(d) >= max_speech_len: 263 | d = apply_fadeout(d[:max_speech_len], self.sr, 3) 264 | s, l = self._get_time_range_for_submix(d, c, i, num_seg) 265 | clip_lufs = np.random.uniform( 266 | class_lufs - self.mix_lufs["ranges"]["clip"], 267 | class_lufs + self.mix_lufs["ranges"]["clip"], 268 | ) 269 | if s != None and l != None: 270 | try: 271 | s_clip, l_clip = int(s * self.sr), int(l * self.sr) 272 | r_spl = np.random.randint(0, len(d) - l_clip) if c == "music" else 0 273 | data = np.copy(d[r_spl : r_spl + l_clip]) 274 | data_norm, gain = audio_utils.lufs_norm( 275 | data=data, sr=self.sr, norm=clip_lufs 276 | ) 277 | submix[s_clip : s_clip + l_clip] = data_norm 278 | self._register_event( 279 | mix_pos=s, 280 | length=l, 281 | cl=c, 282 | file=f, 283 | clip_start=r_spl / self.sr, 284 | clip_end=(r_spl / self.sr) + l, 285 | clip_gain=gain, 286 | ) 287 | except Exception as e: 288 | print(e) 289 | print("could not register event. 
Skipping...") 290 | print(f, c, s, l, s_clip, l_clip) 291 | elif c == "speech": 292 | self.files[c].append(f) 293 | 294 | self.mix_max_peak = max(self.mix_max_peak, np.max(np.abs(submix))) 295 | self.submixes[c] = submix 296 | 297 | def _create_final_mix(self): 298 | # Add some noise to sfx submix 299 | rand_range = self.mix_lufs["ranges"]["class"] 300 | 301 | # Set the peak norm gain 302 | peak_norm_gain = ( 303 | 1.0 304 | if self.mix_max_peak <= self.peak_norm 305 | else self.peak_norm / self.mix_max_peak 306 | ) 307 | # Compute master gain 308 | master_lufs = np.random.uniform( 309 | self.mix_lufs["master"] - rand_range, self.mix_lufs["master"] + rand_range 310 | ) 311 | 312 | # After adding to master, peak normalize the submixes 313 | for c in self.submixes.keys(): 314 | # Peak norm 315 | self.submixes[c] *= peak_norm_gain 316 | # Add submix to main mix 317 | self.mix += self.submixes[c] 318 | 319 | self.mix, master_gain = audio_utils.lufs_norm( 320 | data=self.mix, sr=self.sr, norm=master_lufs 321 | ) 322 | 323 | # Main LUFS norm 324 | for c in self.submixes.keys(): 325 | self.submixes[c] *= master_gain 326 | 327 | # Optionally add some noise to the mix (i.e. avoid digital silences) 328 | if self.mix_lufs["noise"] != None: 329 | noise_lufs = np.random.uniform( 330 | self.mix_lufs["noise"] - rand_range, self.mix_lufs["noise"] + rand_range 331 | ) 332 | noise = get_some_noise(shape=int(self.seq_dur * self.sr)) 333 | noise, _ = audio_utils.lufs_norm(data=noise, sr=self.sr, norm=noise_lufs) 334 | self.mix += noise 335 | 336 | # After peak norm / master norm, we need to update the registered events' props 337 | self._update_event_levels(master_gain * peak_norm_gain) 338 | 339 | def __call__(self): 340 | # Check if we build from the annotation or not 341 | if self.annot_path: 342 | self._create_from_annots(annot_path=self.annot_path) 343 | return self.submixes, self.mix, None, None 344 | # If not, Create submixes from scratch 345 | else: 346 | if not self._check_for_files_exhaustion(): 347 | self._set_submix(c="music", num_seg=self.num_seg_music) 348 | self._set_submix(c="speech", num_seg=self.num_seg_speech) 349 | self._set_submix(c="sfx", num_seg=self.num_seg_sfx) 350 | self._create_final_mix() 351 | 352 | return self.submixes, self.mix, self.allocated_windows, self.files 353 | else: 354 | return None, None, None, self.files 355 | 356 | 357 | class DatasetCompiler: 358 | def __init__( 359 | self, 360 | speech_set: str = None, 361 | fx_set: str = None, 362 | music_set: str = None, 363 | num_output_files: int = 1000, 364 | output_dir: str = ".", 365 | sample_rate: int = 44100, 366 | seq_dur: float = 60.0, 367 | partition: str = "tr", 368 | mix_lufs: dict = None, 369 | peak_norm_db: float = -0.5, 370 | ): 371 | self.output_files = [] 372 | self.speech_set = speech_set 373 | self.fx_set = fx_set 374 | self.music_set = music_set 375 | self.partition = partition 376 | self.sr = sample_rate 377 | self.num_files = num_output_files 378 | self.output_dir = output_dir 379 | self.wavfiles = { 380 | "speech": self._get_filepaths(speech_set), 381 | "sfx": self._get_filepaths(fx_set), 382 | "music": self._get_filepaths(music_set), 383 | } 384 | self.mix_lufs = mix_lufs 385 | self.peak_norm_db = peak_norm_db 386 | self.seq_dur = seq_dur 387 | 388 | self.stats = {"overlap": {}, "times": {}} 389 | 390 | def _get_mix_stats(self, data, res): 391 | """ 392 | Compute amount of overlap between given classes 393 | Args: 394 | data: times 395 | res: resolution of the timesteps 396 | """ 397 | # Average sample 
length 398 | for c in data.keys(): 399 | for t in data[c]["times"]: 400 | self.stats["times"].setdefault(c, []).append(t[-1]) 401 | 402 | overlap = 0.0 403 | classes = list(data.keys()) 404 | timesteps = np.arange(0, self.seq_dur, res) 405 | labels = [] 406 | 407 | # Walk through steps 408 | for i, j in enumerate(timesteps): 409 | # Check if each class is present at this step 410 | detect = [] 411 | for c in classes: 412 | for t in data[c]["times"]: 413 | s, e = t[0], sum(t) 414 | if j >= s and j <= e: 415 | detect.append(c) 416 | break 417 | labels.append("-".join(sorted(detect))) 418 | 419 | counter = collections.Counter(labels) 420 | 421 | values = list(counter.values()) 422 | keys = list(counter.keys()) 423 | values = [x / sum(values) for x in values] 424 | for k, v in zip(keys, values): 425 | self.stats["overlap"].setdefault(k, []).append(v) 426 | 427 | return overlap 428 | 429 | def _plot_stats(self): 430 | plt.rcParams.update({"font.size": 25}) 431 | 432 | # Build the plot 433 | fig, ax = plt.subplots(2, 1, figsize=(25, 25)) 434 | 435 | for i, stats in enumerate(self.stats.keys()): 436 | data = self.stats[stats] 437 | classes = list(data.keys()) 438 | 439 | x_pos = np.arange(len(classes)) 440 | means = [np.mean(data[k]) for k in classes] 441 | error = [np.std(data[k]) for k in classes] 442 | try: 443 | classes[classes.index("")] = "silence" 444 | except: 445 | pass 446 | 447 | ax[i].bar( 448 | x_pos, 449 | means, 450 | yerr=error, 451 | align="center", 452 | alpha=0.5, 453 | ecolor="black", 454 | capsize=10, 455 | ) 456 | ax[i].set_xticks(x_pos) 457 | ax[i].set_xticklabels( 458 | classes, rotation=45, va="center", position=(0, -0.28) 459 | ) 460 | ax[i].yaxis.grid(True) 461 | 462 | ax[0].set_ylabel("Presence Amount") 463 | ax[0].set_title( 464 | self.partition 465 | + " - Classes Overlap (" 466 | + str(len(self.output_files)) 467 | + " Mixtures)" 468 | ) 469 | ax[1].set_ylabel("Length (s)") 470 | ax[1].set_title( 471 | self.partition 472 | + " - Average Sample Length per Classes (" 473 | + str(len(self.output_files)) 474 | + " Mixtures)" 475 | ) 476 | 477 | # Save the figure and show 478 | plt.tight_layout() 479 | plt.savefig(os.path.join(self.output_dir, self.partition + "_set_stats.png")) 480 | 481 | def _get_filepaths(self, dir, c=""): 482 | files = glob(dir + "/**/*.wav", recursive=True) 483 | files += glob(dir + "/**/*.flac", recursive=True) 484 | # If eval and speech, repeat list twice 485 | if self.partition == "eval" and c == "speech": 486 | return [x for item in files for x in repeat(item, 2)] 487 | return files 488 | 489 | def _get_annot_paths(self, dir): 490 | return glob(dir + "/**/annots.csv", recursive=True) 491 | 492 | def _write_wav(self, f_dir, sr, data, source): 493 | audio_utils.validate_audio(data) 494 | sf.write(os.path.join(f_dir, source + ".wav"), data, sr, subtype="FLOAT") 495 | 496 | def _write_ground_truth(self, f_dir, data): 497 | out_df = pd.DataFrame( 498 | columns=[ 499 | "file", 500 | "class", 501 | "mix start time", 502 | "mix end time", 503 | "clip start sample", 504 | "clip end sample", 505 | "clip gain", 506 | "annotation", 507 | ] 508 | ) 509 | 510 | sources = list(data.keys()) 511 | # Sort events in ascending order, time-wise 512 | all_data = [ 513 | [t, spl, f, s, l] 514 | for s in sources 515 | for t, spl, f, l in zip( 516 | data[s]["times"], 517 | data[s]["samples"], 518 | data[s]["files"], 519 | data[s]["levels"], 520 | ) 521 | ] 522 | sorted_data = sorted(all_data, key=lambda x: x[0][0]) 523 | 524 | # Retrieve annotations and write to pd 
frame 525 | for e in sorted_data: 526 | time, spl, file, source, level = e 527 | f_id = os.path.basename(file).split(".")[0] 528 | annot = f_id 529 | 530 | row = [file, source, time[0], sum(time), spl[0], spl[1], level, annot] 531 | out_df.loc[len(out_df.index)] = row 532 | 533 | out_df.to_csv(os.path.join(f_dir, "annots.csv")) 534 | 535 | def __call__(self): 536 | print(f"Building {self.partition} dataset...") 537 | for i in tqdm(range(self.num_files)): 538 | sources, mix, annots, files = MixtureObj( 539 | seq_dur=self.seq_dur, 540 | partition=self.partition, 541 | sr=self.sr, 542 | files=self.wavfiles, 543 | mix_lufs=self.mix_lufs, 544 | peak_norm_db=self.peak_norm_db, 545 | )() 546 | # For eval set, we drain the speech dataset until empty 547 | if sources: 548 | f_name, f_dir = make_unique_filename(self.output_files, self.output_dir) 549 | self._write_ground_truth(f_dir, annots) 550 | sources["mix"] = mix 551 | for source in list(sources.keys()): 552 | self._write_wav(f_dir, int(self.sr), sources[source], source) 553 | self.output_files.append(f_name) 554 | self._get_mix_stats(annots, 0.1) 555 | 556 | 557 | def main(): 558 | set_seed(42) 559 | parser = argparse.ArgumentParser( 560 | description="Convert existing music, speech, and sfx datasets into format used by DNR" 561 | ) 562 | parser.add_argument("music_dir", type=str) 563 | parser.add_argument("speech_dir", type=str) 564 | parser.add_argument("sfx_dir", type=str) 565 | parser.add_argument( 566 | "-n", 567 | "--num_output_files", 568 | type=int, 569 | default=1000, 570 | help="Total number of output files before splitting", 571 | ) 572 | parser.add_argument("-sr", type=int, help="Sample rate", default=44100) 573 | parser.add_argument("-o", "--output_dir", type=str, required=False) 574 | 575 | args = parser.parse_args() 576 | music_dir = args.music_dir 577 | speech_dir = args.speech_dir 578 | sfx_dir = args.sfx_dir 579 | num_output_files = args.num_output_files 580 | sample_rate = args.sr 581 | output_dir = args.output_dir 582 | 583 | # Create output directory if it doesn't exist 584 | if not output_dir: 585 | output_dir = os.path.join(os.getcwd(), "data") 586 | 587 | splits = {"train": 0.7, "val": 0.1, "test": 0.2} 588 | 589 | mix_lufs = { 590 | "music": -24, 591 | "sfx": -21, 592 | "speech": -17, 593 | "master": -27, 594 | "noise": None, 595 | "ranges": { 596 | "class": 2, 597 | "clip": 1, 598 | }, 599 | } 600 | 601 | for split_name in ["train", "val", "test"]: 602 | split_output_dir = create_dir(os.path.join(output_dir, split_name)) 603 | DatasetCompiler( 604 | speech_dir, 605 | sfx_dir, 606 | music_dir, 607 | num_output_files=round(num_output_files * splits[split_name]), 608 | output_dir=split_output_dir, 609 | sample_rate=sample_rate, 610 | partition=split_name, 611 | mix_lufs=mix_lufs, 612 | )() 613 | 614 | 615 | if __name__ == "__main__": 616 | main() 617 | -------------------------------------------------------------------------------- /scripts/dataset_split_and_mix.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm import tqdm 3 | import torchaudio 4 | import sys 5 | 6 | 7 | # Run with python scripts/dataset_split_and_mix.py /path/to/dataset 8 | def main(): 9 | root = Path(sys.argv[1]) 10 | splits = [0.9, 0.1, 0] 11 | 12 | train = root / "train" 13 | # train = Path("./") / "train" 14 | valid = root / "valid" 15 | # valid = Path("./") / "valid" 16 | # test = root / "test" 17 | train.mkdir(exist_ok=True) 18 | valid.mkdir(exist_ok=True) 19 | # 
20 |     total_tracks = len(list(root.iterdir())) - 3  # 3 for train, valid, test
21 |     num_train = int(total_tracks * splits[0])
22 |     num_valid = int(total_tracks * splits[1])
23 |     num_test = int(total_tracks * splits[2])
24 | 
25 |     # Add remainder as necessary
26 |     remainder = total_tracks - (num_train + num_valid + num_test)
27 |     for i in range(remainder):
28 |         if i % 2 == 0:
29 |             num_train += 1
30 |         elif i % 2 == 1:
31 |             num_valid += 1
32 |         # else:
33 |         #     num_test += 1
34 | 
35 |     num = 0
36 |     for d in tqdm(root.iterdir(), total=total_tracks):
37 |         if d.is_dir() and d.name != "train" and d.name != "valid" and d.name != "test":
38 |             bass, sr = torchaudio.load(str(d / "bass.wav"))
39 |             drums, sr = torchaudio.load(str(d / "drums.wav"))
40 |             other, sr = torchaudio.load(str(d / "other.wav"))
41 |             vocals, sr = torchaudio.load(str(d / "vocals.wav"))
42 |             mixture = bass + drums + other + vocals
43 |             torchaudio.save(str(d / "mixture.wav"), mixture, sr)
44 | 
45 |             if num < num_train:
46 |                 d.rename(train / d.name)
47 |             # elif num < num_train + num_valid:
48 |             else:
49 |                 d.rename(valid / d.name)
50 |             # else:
51 |             #     d.rename(test / d.name)
52 |             num += 1
53 | 
54 | 
55 | if __name__ == "__main__":
56 |     main()
57 | 
--------------------------------------------------------------------------------
/scripts/download_youtube.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | from torchaudio.utils import download_asset
4 | from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
5 | from torchaudio.transforms import Fade
6 | from tqdm import tqdm
7 | from audio_data_pytorch import YoutubeDataset, Resample
8 | import os
9 | 
10 | bundle = HDEMUCS_HIGH_MUSDB_PLUS
11 | sample_rate = bundle.sample_rate
12 | fade_overlap = 0.1
13 | root = "./youtube_data"
14 | print(f"Sample rate: {sample_rate}")
15 | 
16 | 
17 | def main():
18 |     url_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "urls.txt")
19 |     with open(url_file_path) as f:
20 |         urls = f.readlines()
21 |     print("Loading dataset...")
22 |     dataset = YoutubeDataset(
23 |         root=root,
24 |         urls=urls,
25 |         crop_length=10,  # Crop source in 10s chunks (optional but suggested)
26 |         transforms=torch.nn.Sequential(
27 |             Resample(source=48000, target=sample_rate),
28 |             Fade(
29 |                 fade_in_len=0,
30 |                 fade_out_len=int(fade_overlap * sample_rate),
31 |                 fade_shape="linear",
32 |             ),
33 |         ),
34 |     )
35 |     print("Loading model...")
36 |     model = bundle.get_model()
37 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
38 |     model.to(device)
39 | 
40 |     print("Separating chunks...")
41 |     for i, chunk in enumerate(tqdm(dataset)):
42 |         original_url_idx = int(dataset.wavs[i].split("/")[-1].split(".")[0])
43 |         sources = separate(chunk, model=model, device=device)
44 | 
45 |         if not os.path.exists(f"{root}/separated/{original_url_idx}"):
46 |             os.makedirs(f"{root}/separated/{original_url_idx}")
47 |         for source in sources:
48 |             torchaudio.save(
49 |                 f"{root}/separated/{original_url_idx}/{source}.wav",
50 |                 sources[source].cpu(),  # (channels, time) tensor, moved back to CPU for saving
51 |                 sample_rate=sample_rate,
52 |             )
53 | 
54 | 
55 | def separate(chunk: torch.Tensor, model: torch.nn.Module, device: torch.device):
56 |     chunk = chunk.to(device)  # Tensor.to is not in-place, so reassign
57 |     ref = chunk.mean(0)
58 |     chunk = (chunk - ref.mean()) / ref.std()  # normalization
59 | 
60 |     with torch.no_grad():
61 |         out = model.forward(chunk[None])
62 | 
63 |     sources = out.squeeze(0)
64 |     sources = sources * ref.std() + ref.mean()  # denormalization
65 | 
66 |     sources_list = model.sources
67 |     sources = list(sources)
68 |     dict_sources = dict(zip(sources_list, sources))
69 |     return dict_sources
70 | 
71 | 
72 | if __name__ == "__main__":
73 |     main()
74 | 
--------------------------------------------------------------------------------
/scripts/musdb_to_voiceless.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pathlib
3 | import soundfile as sf
4 | 
5 | if __name__ == "__main__":
6 |     parser = argparse.ArgumentParser()
7 |     parser.add_argument("path", type=str, help="Path to the MUSDB18 dataset")
8 |     parser.add_argument("output", type=str, help="Path to the output folder")
9 |     args = parser.parse_args()
10 |     print("Converting MUSDB18 to voiceless")
11 |     print("This will take a while ...")
12 |     print("")
13 | 
14 |     path = pathlib.Path(args.path)
15 |     output = pathlib.Path(args.output)
16 |     output.mkdir(exist_ok=True)
17 | 
18 |     for track in path.glob("**/mixture.wav"):
19 |         folder = track.parent
20 |         print(f"Processing {folder}")
21 | 
22 |         bass, sr = sf.read(folder / "bass.wav")
23 |         drums, sr = sf.read(folder / "drums.wav")
24 |         other, sr = sf.read(folder / "other.wav")
25 |         mix = bass + drums + other
26 | 
27 |         output_folder = output / folder.relative_to(path)
28 |         output_folder.mkdir(exist_ok=True, parents=True)
29 |         sf.write(output_folder / "mixture.wav", mix, sr)
30 | 
--------------------------------------------------------------------------------
/scripts/urls.txt:
--------------------------------------------------------------------------------
1 | https://www.youtube.com/watch?v=mlB-3842Cjs
--------------------------------------------------------------------------------
/scripts/webapp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import librosa
4 | import torchaudio
5 | import numpy as np
6 | import streamlit as st
7 | import librosa.display
8 | import pyloudnorm as pyln
9 | import matplotlib.pyplot as plt
10 | 
11 | from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
12 | from torchaudio.transforms import Fade
13 | 
14 | 
15 | st.set_page_config(page_title="aimless-splitter")
16 | st.image("docs/aimless-logo-crop.png", use_column_width="always")
17 | 
18 | bundle = HDEMUCS_HIGH_MUSDB_PLUS
19 | sample_rate = bundle.sample_rate
20 | fade_overlap = 0.1
21 | 
22 | 
23 | @st.experimental_singleton
24 | def load_hdemucs():
25 |     print("Loading pretrained HDEMUCS model...")
26 |     model = bundle.get_model()
27 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28 |     model.to(device)
29 | 
30 |     return model
31 | 
32 | 
33 | def plot_spectrogram(y, *, sample_rate, figsize=(12, 3)):
34 |     # Convert to mono
35 |     if y.ndim > 1:
36 |         y = y[0]
37 | 
38 |     fig = plt.figure(figsize=figsize)
39 |     D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max, top_db=120)
40 |     img = librosa.display.specshow(D, y_axis="linear", x_axis="time", sr=sample_rate)
41 |     fig.colorbar(img, format="%+2.f dB")
42 |     st.pyplot(fig=fig, clear_figure=True)
43 | 
44 | 
45 | def separate_sources(
46 |     model: torch.nn.Module,
47 |     mix: torch.Tensor,
48 |     segment: float = 10.0,
49 |     overlap: float = 0.1,
50 |     device=None,
51 | ):
52 |     """
53 |     Apply the model to a mixture in overlapping segments, cross-fading the segment outputs back together.
54 | 
55 |     Args:
56 |         segment (float): segment length in seconds
57 |         device (torch.device, str, or None): if provided, device on which to
58 |             execute the computation, otherwise `mix.device` is assumed.
59 |             When `device` is different from `mix.device`, only local computations will
60 |             be on `device`, while the entire tracks will be stored on `mix.device`.
61 |     """
62 |     if device is None:
63 |         device = mix.device
64 |     else:
65 |         device = torch.device(device)
66 | 
67 |     batch, channels, length = mix.shape
68 | 
69 |     chunk_len = int(sample_rate * segment * (1 + overlap))
70 |     start = 0
71 |     end = chunk_len
72 |     overlap_frames = overlap * sample_rate
73 |     fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")
74 | 
75 |     final = torch.zeros(batch, len(model.sources), channels, length, device=device)
76 | 
77 |     while start < length - overlap_frames:
78 |         chunk = mix[:, :, start:end]
79 |         with torch.no_grad():
80 |             out = model.forward(chunk)
81 |         out = fade(out)
82 |         final[:, :, :, start:end] += out
83 |         if start == 0:
84 |             fade.fade_in_len = int(overlap_frames)
85 |             start += int(chunk_len - overlap_frames)
86 |         else:
87 |             start += chunk_len
88 |         end += chunk_len
89 |         if end >= length:
90 |             fade.fade_out_len = 0
91 |     return final
92 | 
93 | 
94 | def process_file(file, model: torch.nn.Module, device: torch.device):
95 |     import tempfile
96 |     import shutil
97 |     from pathlib import Path
98 | 
99 |     # Cache file to disk so we can read it with ffmpeg
100 |     with tempfile.NamedTemporaryFile("wb", suffix=Path(file.name).suffix) as f:
101 |         shutil.copyfileobj(file, f)
102 | 
103 |         duration = librosa.get_duration(filename=f.name)
104 |         num_frames = -1
105 |         if duration > 60:
106 |             st.write(f"File is {duration:.01f}s long. Loading the first 60s only.")
107 |             sr = librosa.get_samplerate(f.name)
108 |             num_frames = 60 * sr
109 | 
110 |         x, sr = torchaudio.load(f.name, num_frames=num_frames)
111 | 
112 |         # resample if needed
113 |         if sr != sample_rate:
114 |             x = torchaudio.functional.resample(x, sr, sample_rate)
115 | 
116 |         st.subheader("Mix")
117 |         x_numpy = x.numpy()
118 |         plot_spectrogram(x_numpy, sample_rate=sample_rate)
119 |         st.audio(x_numpy, sample_rate=sample_rate)
120 | 
121 |         waveform = x.to(device)
122 | 
123 |         # split into 10.0 sec chunks
124 |         ref = waveform.mean(0)
125 |         waveform = (waveform - ref.mean()) / ref.std()  # normalization
126 |         sources = separate_sources(
127 |             model,
128 |             waveform[None],
129 |             device=device,
130 |         )[0]
131 |         sources = sources * ref.std() + ref.mean()
132 | 
133 |         sources_list = model.sources
134 |         sources = list(sources)
135 | 
136 |         audios = dict(zip(sources_list, sources))
137 | 
138 |         for source, audio in audios.items():
139 |             audio = audio.cpu().numpy()
140 |             st.subheader(source.capitalize())
141 |             plot_spectrogram(audio, sample_rate=sample_rate)
142 |             st.audio(audio, sample_rate=sample_rate)
143 | 
144 | 
145 | # load pretrained model
146 | hdemucs = load_hdemucs()
147 | 
148 | # load audio
149 | uploaded_file = st.file_uploader("Choose a file to demix.")
150 | 
151 | if uploaded_file is not None:
152 |     # split with hdemucs
153 |     hdemucs_sources = process_file(uploaded_file, hdemucs, "cuda:0")
154 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | setuptools.setup(
4 |     name="aimless",
5 |     version="0.0.1",
6 |     author="Artificial Intelligence and Music League for Effective Source Separation",
7 |     author_email="chin-yun.yu@qmul.ac.uk",
8 |     packages=setuptools.find_packages(exclude=["tests", "tests.*", "data", "data.*"]),
9 |     install_requires=["torch", "pytorch-lightning", "torch_fftconv"],
10 |     classifiers=[
11 |         "Programming Language :: Python :: 3",
12 |         "License :: OSI Approved :: MIT License",
13 |         "Operating System :: OS Independent",
14 |     ],
15 | )
16 | 
--------------------------------------------------------------------------------