├── .github
│   └── workflows
│       └── black.yml
├── .gitignore
├── README.md
├── aimless
│   ├── __init__.py
│   ├── augment.py
│   ├── lightning
│   │   ├── __init__.py
│   │   ├── freq_mask.py
│   │   └── waveform.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── freq.py
│   │   └── time.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── band_split_rnn.py
│   │   ├── demucs_split.py
│   │   ├── rel_tdemucs.py
│   │   └── xumx.py
│   └── utils
│       ├── __init__.py
│       ├── mwf.py
│       └── utils.py
├── cfg
│   ├── cdx_a
│   │   ├── bandsplit_rnn.yaml
│   │   └── hdemucs.yaml
│   ├── demucs.yaml
│   ├── hdemucs.yaml
│   ├── mdx_a
│   │   └── hdemucs.yaml
│   ├── speech_enhance.yaml
│   └── xumx.yaml
├── data
│   ├── __init__.py
│   ├── augment.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── dnr.py
│   │   ├── fast_musdb.py
│   │   ├── label_noise_bleed.py
│   │   └── speech.py
│   └── lightning
│       ├── __init__.py
│       ├── bleed.py
│       ├── dnr.py
│       ├── label_noise.csv
│       ├── label_noise.py
│       ├── musdb.py
│       └── speech.py
├── docs
│   └── aimless-logo-crop.svg
├── environment.yml
├── main.py
├── scripts
│   ├── audio_utils.py
│   ├── convert_to_dnr.py
│   ├── dataset_split_and_mix.py
│   ├── download_youtube.py
│   ├── musdb_to_voiceless.py
│   ├── urls.txt
│   └── webapp.py
└── setup.py
/.github/workflows/black.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | lint:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v2
10 | - uses: psf/black@stable
11 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ### CUDA template
2 | *.i
3 | *.ii
4 | *.gpu
5 | *.ptx
6 | *.cubin
7 | *.fatbin
8 |
9 | ### VirtualEnv template
10 | # Virtualenv
11 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
12 | .Python
13 | [Bb]in
14 | [Ii]nclude
15 | [Ll]ib
16 | [Ll]ib64
17 | [Ll]ocal
18 | pyvenv.cfg
19 | .venv
20 | pip-selfcheck.json
21 |
22 | ### Linux template
23 | *~
24 |
25 | # temporary files which can be created if a process still has a handle open of a deleted file
26 | .fuse_hidden*
27 |
28 | # KDE directory preferences
29 | .directory
30 |
31 | # Linux trash folder which might appear on any partition or disk
32 | .Trash-*
33 |
34 | # .nfs files are created when an open file is removed but is still being accessed
35 | .nfs*
36 |
37 | ### JupyterNotebooks template
38 | # gitignore template for Jupyter Notebooks
39 | # website: http://jupyter.org/
40 |
41 | .ipynb_checkpoints
42 | */.ipynb_checkpoints/*
43 |
44 | # IPython
45 | profile_default/
46 | ipython_config.py
47 |
48 | # Remove previous ipynb_checkpoints
49 | # git rm -r .ipynb_checkpoints/
50 |
51 | ### macOS template
52 | # General
53 | .AppleDouble
54 | .DS_Store
55 | .LSOverride
56 |
57 | # Icon must end with two \r
58 | Icon
59 |
60 | # Thumbnails
61 | ._*
62 |
63 | # Files that might appear in the root of a volume
64 | .DocumentRevisions-V100
65 | .fseventsd
66 | .Spotlight-V100
67 | .TemporaryItems
68 | .Trashes
69 | .VolumeIcon.icns
70 | .com.apple.timemachine.donotpresent
71 |
72 | # Directories potentially created on remote AFP share
73 | .AppleDB
74 | .AppleDesktop
75 | Network Trash Folder
76 | Temporary Items
77 | .apdisk
78 |
79 | ### Python template
80 | # Byte-compiled / optimized / DLL files
81 | __pycache__/
82 | *.py[cod]
83 | *$py.class
84 |
85 | # C extensions
86 | *.so
87 |
88 | # Distribution / packaging
89 | develop-eggs/
90 | downloads/
91 | eggs/
92 | .eggs/
93 | lib/
94 | lib64/
95 | parts/
96 | sdist/
97 | var/
98 | wheels/
99 | share/python-wheels/
100 | *.egg-info/
101 | .installed.cfg
102 | *.egg
103 | MANIFEST
104 |
105 | # PyInstaller
106 | # Usually these files are written by a python script from a template
107 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
108 | *.manifest
109 | *.spec
110 |
111 | # Installer logs
112 | pip-log.txt
113 | pip-delete-this-directory.txt
114 |
115 | # Unit test / coverage reports
116 | htmlcov/
117 | .tox/
118 | .nox/
119 | .coverage
120 | .coverage.*
121 | .cache
122 | nosetests.xml
123 | coverage.xml
124 | *.cover
125 | *.py,cover
126 | .hypothesis/
127 | .pytest_cache/
128 | cover/
129 |
130 | # Translations
131 | *.mo
132 | *.pot
133 |
134 | # Django stuff:
135 | *.log
136 | local_settings.py
137 | db.sqlite3
138 | db.sqlite3-journal
139 |
140 | # Flask stuff:
141 | instance/
142 | .webassets-cache
143 |
144 | # Scrapy stuff:
145 | .scrapy
146 |
147 | # Sphinx documentation
148 | docs/_build/
149 |
150 | # PyBuilder
151 | .pybuilder/
152 | target/
153 |
154 | # Jupyter Notebook
155 |
156 | # IPython
157 |
158 | # pyenv
159 | # For a library or package, you might want to ignore these files since the code is
160 | # intended to run in multiple environments; otherwise, check them in:
161 | # .python-version
162 |
163 | # pipenv
164 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
165 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
166 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
167 | # install all needed dependencies.
168 | #Pipfile.lock
169 |
170 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
171 | __pypackages__/
172 |
173 | # Celery stuff
174 | celerybeat-schedule
175 | celerybeat.pid
176 |
177 | # SageMath parsed files
178 | *.sage.py
179 |
180 | # Environments
181 | .env
182 | env/
183 | venv/
184 | ENV/
185 | env.bak/
186 | venv.bak/
187 |
188 | # Spyder project settings
189 | .spyderproject
190 | .spyproject
191 |
192 | # Rope project settings
193 | .ropeproject
194 |
195 | # mkdocs documentation
196 | /site
197 |
198 | # mypy
199 | .mypy_cache/
200 | .dmypy.json
201 | dmypy.json
202 |
203 | # Pyre type checker
204 | .pyre/
205 |
206 | # pytype static type analyzer
207 | .pytype/
208 |
209 | # Cython debug symbols
210 | cython_debug/
211 |
212 | # Custom
213 | lightning_logs/*
214 | *.pyc
215 | build/
216 | dist/
217 | .idea/
218 | out/
219 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | AIMLESS (Artificial Intelligence and Music League for Effective Source Separation) is a special interest group in audio source separation at C4DM, consisting of PhD students from the AIM CDT program.
8 | This repository is adapted from [Danna-Sep](https://github.com/yoyololicon/music-demixing-challenge-ismir-2021-entry) and
9 | contains our training code for the [SDX23 Sound Demixing Challenge](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023).
10 |
11 |
12 | ## Quick Start
13 |
14 | The conda environment we used for training is described in `environment.yml`.
15 | The commands below should run as-is if you're using the QMUL EECS servers.
16 | If you want to run them on your local machine, change the `root` param in the config to the location of your MUSDB18-HQ download.
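
Alternatively, instead of editing the config, the dataset root can be overridden on the command line. A minimal sketch (the `--data.init_args.root` override mirrors the commands used later in this README; the dataset path is a placeholder):

```commandline
conda env create -f environment.yml
python main.py fit --config cfg/demucs.yaml --data.init_args.root /path/to/MUSDB18-HQ
```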
17 |
18 | ### Frequency-masking-based model
19 |
20 |
21 | ```commandline
22 | python main.py fit --config cfg/xumx.yaml
23 | ```
24 |
25 | ### Raw-waveform-based model
26 |
27 |
28 | ```commandline
29 | python main.py fit --config cfg/demucs.yaml
30 | ```
31 |
32 | ## Install the repository as a package
33 |
34 | This step is required if you want to test our submission repositories (see section [Reproduce the winning submission](#reproduce-the-winning-submission)) locally.
35 | ```sh
36 | pip install git+https://github.com/aim-qmul/sdx23-aimless
37 | ```
38 |
39 | ## Reproduce the winning submission
40 |
41 | ### CDX Leaderboard A, submission ID 220319
42 |
43 | This section describes how to reproduce the [best-performing model](https://gitlab.aicrowd.com/yoyololicon/cdx-submissions/-/issues/90) we used on CDX leaderboard A.
44 | The submission consists of one HDemucs predicting all the targets and one BandSplitRNN predicting the music stem from the mixture.
45 |
46 | To train the HDemucs:
47 | ```commandline
48 | python main.py fit --config cfg/cdx_a/hdemucs.yaml --data.init_args.root /DNR_DATASET_ROOT/dnr_v2/
49 | ```
50 | Remember to change `/DNR_DATASET_ROOT/dnr_v2/` to your download location of the [Divide and Remaster (DnR) dataset](https://zenodo.org/record/6949108).
51 |
52 | To train the BandSplitRNN:
53 | ```commandline
54 | python main.py fit --config cfg/cdx_a/bandsplit_rnn.yaml --data.init_args.root /DNR_DATASET_ROOT/dnr_v2/
55 | ```
56 |
57 | We trained the models with up to 4 GPUs, depending on the resources available at the time.
58 |
59 | After training, please go to our [submission repository](https://gitlab.aicrowd.com/yoyololicon/cdx-submissions/).
60 | Then, copy the last checkpoint of HDemucs (usually located at `lightning_logs/version_**/checkpoints/last.ckpt`) to `my_submission/lightning_logs/hdemucs-64-sdr/checkpoints/last.ckpt` in the submission repository.
61 | Similarly, copy the last checkpoint of BandSplitRNN to `my_submission/lightning_logs/bandsplitRNN-music/checkpoints/last.ckpt`.
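
For example, assuming the HDemucs run ended up in `version_0` and the submission repository is checked out next to this one (both paths are illustrative), the copy step looks like:

```sh
cp lightning_logs/version_0/checkpoints/last.ckpt \
   ../cdx-submissions/my_submission/lightning_logs/hdemucs-64-sdr/checkpoints/last.ckpt
```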
62 | After these steps, you have reproduced our submission!
63 |
64 | The inference procedure in our submission repository is a bit complex.
65 | Briefly, the HDemucs predicts the targets independently for each channel of the stereo mixture, as well as for the average (the mid) and the difference (the side) of the two channels.
66 | The stereo separated sources are then formed as a linear combination of these mono predictions.
67 | The separated music from the BandSplitRNN is enhanced by Wiener filtering, and the final music prediction is the average of the two models' outputs.
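
As a rough illustration of the mid/side recombination (the actual mixing weights are defined in the submission repository; the equal weights below are purely illustrative):

```python
import torch


def combine_predictions(left, right, mid, side):
    """Hypothetical sketch: recombine four mono HDemucs passes into stereo estimates.

    left, right, mid, side: tensors of shape (num_sources, num_samples), where the
    mid/side passes were run on (L+R)/2 and (L-R)/2 of the mixture respectively.
    """
    stereo_from_lr = torch.stack([left, right], dim=1)             # (sources, 2, samples)
    stereo_from_ms = torch.stack([mid + side, mid - side], dim=1)  # undo the M/S transform
    # Equal-weight linear combination of the two stereo reconstructions.
    return 0.5 * (stereo_from_lr + stereo_from_ms)
```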
68 |
69 | ### MDX Leaderboard A (Label Noise), submission ID 220426
70 |
71 | This section describes how to reproduce the [best-performing model](https://gitlab.aicrowd.com/yoyololicon/mdx23-submissions/-/issues/76) we used on MDX leaderboard A.
72 |
73 | First, we manually inspected the [label noise dataset](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023/problems/music-demixing-track-mdx-23/dataset_files) (thanks @mhrice for the hard work!) and labeled the clean songs (those without label noise).
74 | The labels are recorded in `data/lightning/label_noise.csv`.
75 | Then, an HDemucs was trained only on these clean songs with the following settings:
76 |
77 | * negative SDR as the loss function (see the formula below)
78 | * Training occurs on random chunks and random stem combinations of the clean songs
79 | * Training batches are augmented and processed using different random effects
80 | * Due to all this randomization, validation is also done on the training dataset (no separate validation set)
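
The negative SDR loss corresponds to `NegativeSDR` in `aimless/loss/time.py`, which (up to a small constant $\delta$ added for numerical stability) computes

$$\mathcal{L}(\hat{s}, s) = -\frac{1}{B}\sum_{b=1}^{B} 10\log_{10}\frac{\lVert s_b \rVert^2}{\lVert s_b - \hat{s}_b \rVert^2},$$

where $s_b$ and $\hat{s}_b$ are the flattened target and estimate of the $b$-th batch item.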
81 |
82 | To reproduce the training:
83 | ```commandline
84 | python main.py fit --config cfg/mdx_a/hdemucs.yaml --data.init_args.root /DATASET_ROOT/
85 | ```
86 | Remember to place the label noise data under `/DATASET_ROOT/train/`.
87 |
88 | Other details:
89 | * The model is trained for ~800 epochs (approx. 2 weeks on 4 RTX A5000 GPUs)
90 | * During the last ~200 epochs, the learning rate is reduced to 0.001, gradient accumulation is increased to 64, and the probability of each random effect is increased by a factor of 1.666 (e.g. from 30% to 50%)
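
One way to apply the gradient-accumulation change is through the Lightning CLI, for example (a sketch; adjusting the learning rate and effect probabilities is done in `cfg/mdx_a/hdemucs.yaml` itself and is not covered by this command):

```commandline
python main.py fit --config cfg/mdx_a/hdemucs.yaml --data.init_args.root /DATASET_ROOT/ \
    --trainer.accumulate_grad_batches 64
```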
91 |
92 | After training, please go to our [submission repository](https://gitlab.aicrowd.com/yoyololicon/mdx23-submissions/).
93 | Then, copy the checkpoint to `my_submission/acc64_4devices_lr0001_e1213_last.ckpt` in the submission repository.
94 | After these steps, you have reproduced our submission!
95 |
96 |
97 | ## Structure
98 |
99 | * `aimless`: package root, which can be imported for submission.
100 |   * `loss`: loss functions.
101 |     * `freq.*`: loss functions for frequency-domain models.
102 |     * `time.*`: loss functions for time-domain models.
103 |   * `augment`: data augmentations that run better on the GPU.
104 |   * `lightning`: all lightning modules.
105 |     * `waveform.WaveformSeparator`: trainer for time-domain models.
106 |     * `freq_mask.MaskPredictor`: trainer for frequency-domain models.
107 |   * `models`: your custom models.
108 | * `cfg`: all config files.
109 | * `data`:
110 |   * `dataset`: custom PyTorch datasets.
111 |   * `lightning`: all lightning data modules.
112 |   * `augment`: data augmentations that run better on the CPU.
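
As a quick orientation, the main entry points can be imported directly from the package; a minimal sketch (constructor arguments are documented in the corresponding modules):

```python
# Illustrative imports only; see the modules listed above for the full APIs.
from aimless.lightning.waveform import WaveformSeparator  # trainer for time-domain models
from aimless.lightning.freq_mask import MaskPredictor     # trainer for frequency-domain models
from aimless.loss.time import NegativeSDR                 # a time-domain (negative SDR) loss
from aimless.models.band_split_rnn import BandSplitRNN    # one of the custom models
```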
113 |
114 | ## Streamlit
115 |
116 | Split a song in the browser with the pretrained Hybrid Demucs.
117 |
118 | ``` streamlit run scripts/webapp.py ```
119 |
120 | Then open [http://localhost:8501/](http://localhost:8501/) in your browser.
121 |
122 |
123 |
124 |
--------------------------------------------------------------------------------
/aimless/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/aimless/__init__.py
--------------------------------------------------------------------------------
/aimless/augment.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import torchaudio
5 | from torchaudio.transforms import TimeStretch, Spectrogram, InverseSpectrogram, Resample
6 | from torchaudio import functional as aF
7 | from torch_fftconv import fft_conv1d
8 | from pathlib import Path
9 |
10 | __all__ = ["SpeedPerturb", "RandomPitch", "RandomConvolutions", "CudaBase"]
11 |
12 |
13 | class CudaBase(nn.Module):
14 | def __init__(self, rand_size, p=0.2):
15 | super().__init__()
16 | self.p = p
17 | self.rand_size = rand_size
18 |
19 | def _transform(self, stems, index):
20 | """
21 | Args:
22 | stems (torch.Tensor): (B, Num_channels, L)
23 | index (int): index of random transform
24 | Return:
25 | perturbed_stems (torch.Tensor): (B, Num_channels, L')
26 | """
27 | raise NotImplementedError
28 |
29 | def forward(self, stems: torch.Tensor):
30 | """
31 | Args:
32 | stems (torch.Tensor): (B, Num_sources, Num_channels, L)
33 | Return:
34 | perturbed_stems (torch.Tensor): (B, Num_sources, Num_channels, L')
35 | """
36 | shape = stems.shape
37 | orig_len = shape[-1]
38 | stems = stems.view(-1, *shape[-2:])
39 | select_mask = torch.rand(stems.shape[0], device=stems.device) < self.p
40 | if not torch.any(select_mask):
41 | return stems.view(*shape)
42 |
43 | select_idx = torch.where(select_mask)[0]
44 | perturbed_stems = torch.zeros_like(stems)
45 | perturbed_stems[~select_mask] = stems[~select_mask]
46 | selected_stems = stems[select_mask]
47 | rand_idx = torch.randint(
48 | self.rand_size, (selected_stems.shape[0],), device=stems.device
49 | )
50 |
51 | for i in range(self.rand_size):
52 | mask = rand_idx == i
53 | if not torch.any(mask):
54 | continue
55 | masked_stems = selected_stems[mask]
56 | perturbed_audio = self._transform(masked_stems, i).to(perturbed_stems.dtype)
57 |
58 | diff = perturbed_audio.shape[-1] - orig_len
59 |
60 | put_idx = select_idx[mask]
61 | if diff >= 0:
62 | perturbed_stems[put_idx] = perturbed_audio[..., :orig_len]
63 | else:
64 | perturbed_stems[put_idx, :, : orig_len + diff] = perturbed_audio
65 |
66 | perturbed_stems = perturbed_stems.view(*shape)
67 | return perturbed_stems
68 |
69 |
70 | class SpeedPerturb(CudaBase):
71 | def __init__(self, orig_freq=44100, speeds=[90, 100, 110], **kwargs):
72 | super().__init__(len(speeds), **kwargs)
73 | self.orig_freq = orig_freq
74 | self.resamplers = nn.ModuleList()
75 | self.speeds = speeds
76 | for s in self.speeds:
77 | new_freq = self.orig_freq * s // 100
78 | self.resamplers.append(Resample(self.orig_freq, new_freq))
79 |
80 | def _transform(self, stems, index):
81 | y = self.resamplers[index](stems.view(-1, stems.shape[-1])).view(
82 | *stems.shape[:-1], -1
83 | )
84 | return y
85 |
86 |
87 | class RandomPitch(CudaBase):
88 | def __init__(
89 | self, semitones=[-2, -1, 0, 1, 2], n_fft=2048, hop_length=512, **kwargs
90 | ):
91 | super().__init__(len(semitones), **kwargs)
92 | self.resamplers = nn.ModuleList()
93 |
94 | semitones = torch.tensor(semitones, dtype=torch.float32)
95 | rates = 2 ** (-semitones / 12)
96 | rrates = rates.reciprocal()
97 | rrates = (rrates * 100).long()
98 | rrates[rrates % 2 == 1] += 1
99 | rates = 100 / rrates
100 |
101 | self.register_buffer("rates", rates)
102 | self.spec = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)
103 | self.inv_spec = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)
104 | self.stretcher = TimeStretch(hop_length, n_freq=n_fft // 2 + 1)
105 |
106 | for rr in rrates.tolist():
107 | self.resamplers.append(Resample(rr, 100))
108 |
109 | def _transform(self, stems, index):
110 | spec = self.spec(stems)
111 | stretched_spec = self.stretcher(spec, self.rates[index])
112 | stretched_stems = self.inv_spec(stretched_spec)
113 | shifted_stems = self.resamplers[index](
114 | stretched_stems.view(-1, stretched_stems.shape[-1])
115 | ).view(*stretched_stems.shape[:-1], -1)
116 | return shifted_stems
117 |
118 |
119 | class RandomConvolutions(CudaBase):
120 | def __init__(self, target_sr: int, ir_folder: str, **kwargs):
121 | ir_folder = Path(ir_folder)
122 | ir_files = list(ir_folder.glob("**/*.wav"))
123 | impulses = []
124 | for ir_file in ir_files:
125 | ir, sr = torchaudio.load(ir_file)
126 | if ir.shape[0] > 2:
127 | continue
128 | if sr != target_sr:
129 | ir = aF.resample(ir, sr, target_sr)
130 | if ir.shape[0] == 1:
131 | ir = ir.repeat(2, 1)
132 | impulses.append(ir)
133 |
134 | super().__init__(len(impulses), **kwargs)
135 | for i, impulse in enumerate(impulses):
136 | self.register_buffer(f"impulse_{i}", impulse)
137 |
138 | def _transform(self, stems, index):
139 | ir = self.get_buffer(f"impulse_{index}").unsqueeze(1)
140 | ir_flipped = ir.flip(-1)
141 | padded_stems = F.pad(stems, (ir.shape[-1] - 1, 0))
142 | # TODO: dynamically use F.conv1d if impulse is short
143 | convolved_stems = fft_conv1d(padded_stems, ir_flipped, groups=2)
144 | return convolved_stems
145 |
--------------------------------------------------------------------------------
/aimless/lightning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/aimless/lightning/__init__.py
--------------------------------------------------------------------------------
/aimless/lightning/freq_mask.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | from torch import nn
4 | from typing import List, Dict
5 | from torchaudio.transforms import Spectrogram, InverseSpectrogram
6 |
7 | from ..loss.time import SDR
8 | from ..loss.freq import FLoss
9 | from ..augment import CudaBase
10 |
11 | from ..utils import MWF, MDX_SOURCES, SDX_SOURCES, SE_SOURCES
12 |
13 |
14 | class MaskPredictor(pl.LightningModule):
15 | def __init__(
16 | self,
17 | model: nn.Module,
18 | criterion: FLoss,
19 | transforms: List[CudaBase] = None,
20 | target_track: str = None,
21 | targets: Dict[str, None] = {},
22 | n_fft: int = 4096,
23 | hop_length: int = 1024,
24 | **mwf_kwargs,
25 | ):
26 | super().__init__()
27 |
28 | self.model = model
29 | self.criterion = criterion
30 | self.sdr = SDR()
31 | self.mwf = MWF(**mwf_kwargs)
32 | self.spec = Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)
33 | self.inv_spec = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)
34 |
35 | if transforms is None:
36 | transforms = []
37 |
38 | self.transforms = nn.Sequential(*transforms)
39 | if target_track == "sdx":
40 | self.sources = SDX_SOURCES
41 | elif target_track == "mdx":
42 | self.sources = MDX_SOURCES
43 | elif target_track == "se":
44 | self.sources = SE_SOURCES
45 | else:
46 | raise ValueError(f"Invalid target track: {target_track}")
47 | self.register_buffer(
48 | "targets_idx",
49 | torch.tensor(sorted([self.sources.index(target) for target in targets])),
50 | )
51 |
52 | def forward(self, x):
53 | X = self.spec(x)
54 | X_mag = X.abs()
55 | pred_mask = self.model(X_mag)
56 | Y = self.mwf(pred_mask, X)
57 | pred = self.inv_spec(Y)
58 | return pred
59 |
60 | def training_step(self, batch, batch_idx):
61 | x, y = batch
62 | if len(self.transforms) > 0:
63 | y = self.transforms(y)
64 | x = y.sum(1)
65 | y = y[:, self.targets_idx].squeeze(1)
66 |
67 | X = self.spec(x)
68 | Y = self.spec(y)
69 | X_mag = X.abs()
70 | pred_mask = self.model(X_mag)
71 | loss, values = self.criterion(pred_mask, Y, X, y, x)
72 |
73 | values["loss"] = loss
74 | self.log_dict(values, prog_bar=False, sync_dist=True)
75 | return loss
76 |
77 | def validation_step(self, batch, batch_idx):
78 | x, y = batch
79 | y = y[:, self.targets_idx].squeeze(1)
80 |
81 | X = self.spec(x)
82 | Y = self.spec(y)
83 | X_mag = X.abs()
84 | pred_mask = self.model(X_mag)
85 | loss, values = self.criterion(pred_mask, Y, X, y, x)
86 |
87 | pred = self.inv_spec(self.mwf(pred_mask, X))
88 |
89 | batch = pred.shape[0]
90 | sdrs = (
91 | self.sdr(pred.view(-1, *pred.shape[-2:]), y.view(-1, *y.shape[-2:]))
92 | .view(batch, -1)
93 | .mean(0)
94 | )
95 |
96 | for i, t in enumerate(self.targets_idx):
97 | values[f"{self.sources[t]}_sdr"] = sdrs[i].item()
98 | values["avg_sdr"] = sdrs.mean().item()
99 | return loss, values
100 |
101 | def validation_epoch_end(self, outputs) -> None:
102 | avg_loss = sum(x[0] for x in outputs) / len(outputs)
103 | avg_values = {}
104 | for k in outputs[0][1].keys():
105 | avg_values[k] = sum(x[1][k] for x in outputs) / len(outputs)
106 |
107 | self.log("val_loss", avg_loss, prog_bar=True, sync_dist=True)
108 | self.log_dict(avg_values, prog_bar=False, sync_dist=True)
109 |
--------------------------------------------------------------------------------
/aimless/lightning/waveform.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | from torch import nn
4 | from typing import List, Dict
5 |
6 | from ..loss.time import TLoss, SDR
7 | from ..augment import CudaBase
8 |
9 | from ..utils import MDX_SOURCES, SDX_SOURCES, SE_SOURCES
10 |
11 |
12 | class WaveformSeparator(pl.LightningModule):
13 | def __init__(
14 | self,
15 | model: nn.Module,
16 | criterion: TLoss,
17 | transforms: List[CudaBase] = None,
18 | use_sdx_targets: bool = False, # will be deprecated, please use target_track
19 | target_track: str = None,
20 | targets: Dict[str, None] = {},
21 | ):
22 | super().__init__()
23 |
24 | self.model = model
25 | self.criterion = criterion
26 | self.sdr = SDR()
27 |
28 | if transforms is None:
29 | transforms = []
30 |
31 | if target_track is not None:
32 | if target_track == "sdx":
33 | self.sources = SDX_SOURCES
34 | elif target_track == "mdx":
35 | self.sources = MDX_SOURCES
36 | elif target_track == "se":
37 | self.sources = SE_SOURCES
38 | else:
39 | raise ValueError(f"Invalid target track: {target_track}")
40 | else:
41 | self.sources = SDX_SOURCES if use_sdx_targets else MDX_SOURCES
42 |
43 | self.transforms = nn.Sequential(*transforms)
44 | self.register_buffer(
45 | "targets_idx",
46 | torch.tensor(sorted([self.sources.index(target) for target in targets])),
47 | )
48 |
49 | def forward(self, x):
50 | return self.model(x)
51 |
52 | def training_step(self, batch, batch_idx):
53 | x, y = batch
54 | if len(self.transforms) > 0:
55 | y = self.transforms(y)
56 | x = y.sum(1)
57 | y = y[:, self.targets_idx].squeeze(1)
58 |
59 | pred = self.model(x)
60 | if pred.ndim == 4:
61 | pred = pred.squeeze(1)
62 | loss, values = self.criterion(pred, y, x)
63 |
64 | values["loss"] = loss
65 | self.log_dict(values, prog_bar=False, sync_dist=True)
66 | return loss
67 |
68 | def validation_step(self, batch, batch_idx):
69 | x, y = batch
70 | y = y[:, self.targets_idx].squeeze(1)
71 |
72 | pred = self.model(x)
73 | loss, values = self.criterion(pred, y, x)
74 |
75 | batch = pred.shape[0]
76 | sdrs = (
77 | self.sdr(pred.view(-1, *pred.shape[-2:]), y.view(-1, *y.shape[-2:]))
78 | .view(batch, -1)
79 | .mean(0)
80 | )
81 |
82 | for i, t in enumerate(self.targets_idx):
83 | values[f"{self.sources[t]}_sdr"] = sdrs[i].item()
84 | values["avg_sdr"] = sdrs.mean().item()
85 | return loss, values
86 |
87 | def validation_epoch_end(self, outputs) -> None:
88 | avg_loss = sum(x[0] for x in outputs) / len(outputs)
89 | avg_values = {}
90 | for k in outputs[0][1].keys():
91 | avg_values[k] = sum(x[1][k] for x in outputs) / len(outputs)
92 |
93 | self.log("val_loss", avg_loss, prog_bar=True, sync_dist=True)
94 | self.log_dict(avg_values, prog_bar=False, sync_dist=True)
95 |
--------------------------------------------------------------------------------
/aimless/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/aimless/loss/__init__.py
--------------------------------------------------------------------------------
/aimless/loss/freq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from itertools import combinations, chain
4 | from ..utils import MWF
5 | from torchaudio.transforms import InverseSpectrogram
6 |
7 |
8 | class FLoss(torch.nn.Module):
9 | r"""Base class for frequency domain loss modules.
10 | You can't use this module directly.
11 | Your loss should also subclass this class.
12 | Args:
13 | msk_hat (Tensor): mask tensor with size (batch, *, channels, bins, frames), `*` is an optional multi targets dimension.
14 | The tensor value should always be non-negative
15 | gt_spec (Tensor): target spectrogram complex tensor with size (batch, *, channels, bins, frames), `*` is an optional multi targets dimension.
16 | mix_spec (Tensor): mixture spectrogram complex tensor with size (batch, channels, bins, frames)
17 | gt (Tensor): target time signal tensor with size (batch, *, channels, samples), `*` is an optional multi targets dimension.
18 | mix (Tensor): mixture time signal tensor with size (batch, channels, samples)
19 | Returns:
20 | tuple: a length-2 tuple with the first element is the final loss tensor,
21 | and the second is a dict containing any intermediate loss value you want to monitor
22 | """
23 |
24 | def forward(self, *args, **kwargs):
25 | return self._core_loss(*args, **kwargs)
26 |
27 | def _core_loss(self, msk_hat, gt_spec, mix_spec, gt, mix):
28 | raise NotImplementedError
29 |
30 |
31 | class CLoss(FLoss):
32 | def __init__(
33 | self, mcoeff=10, n_fft=4096, hop_length=1024, complex_mse=True, **mwf_kwargs
34 | ):
35 | super().__init__()
36 | self.mcoeff = mcoeff
37 | self.inv_spec = InverseSpectrogram(n_fft=n_fft, hop_length=hop_length)
38 | self.complex_mse = complex_mse
39 | if len(mwf_kwargs):
40 | self.mwf = MWF(**mwf_kwargs)
41 |
42 | def _core_loss(self, msk_hat, gt_spec, mix_spec, gt, mix):
43 | if hasattr(self, "mwf"):
44 | Y = self.mwf(msk_hat, mix_spec)
45 | else:
46 | Y = msk_hat * mix_spec.unsqueeze(1)
47 | pred = self.inv_spec(Y)
48 | if self.complex_mse:
49 | loss_f = complex_mse_loss(Y, gt_spec)
50 | else:
51 | loss_f = real_mse_loss(msk_hat, gt_spec.abs(), mix_spec.abs())
52 | loss_t = sdr_loss(pred, gt, mix)
53 | loss = loss_f + self.mcoeff * loss_t
54 | return loss, {"loss_f": loss_f.item(), "loss_t": loss_t.item()}
55 |
56 |
57 | class MDLoss(FLoss):
58 | def __init__(self, mcoeff=10, n_fft=4096, hop_length=1024):
59 | super().__init__()
60 | self.mcoeff = mcoeff
61 | self.inv_spec = InverseSpectrogram(n_fft, hop_length=hop_length)
62 |
63 | def _core_loss(self, msk_hat, gt_spec, mix_spec, gt, mix):
64 | pred_spec = msk_hat * mix_spec
65 | diff = pred_spec - gt_spec
66 |
67 | real = diff.real.reshape(-1)
68 | imag = diff.imag.reshape(-1)
69 | mse = real @ real + imag @ imag
70 | loss_f = mse / real.numel()
71 |
72 | pred = self.inv_spec(pred_spec)
73 | batch_size, n_channels, length = pred.shape
74 |
75 | # Fix Length
76 | mix = mix[..., :length].reshape(-1, length)
77 | gt = gt[..., :length].reshape(-1, length)
78 | pred = pred.view(-1, length)
79 |
80 | loss_t = _sdr_loss_core(pred, gt, mix) + 1.0
81 | loss = loss_f + self.mcoeff * loss_t
82 | return loss, {"loss_f": loss_f.item(), "loss_t": loss_t.item()}
83 |
84 |
85 | def bce_loss(msk_hat, gt_spec):
86 | assert msk_hat.shape == gt_spec.shape
87 | loss = []
88 | gt_spec_power = gt_spec.abs()
89 | gt_spec_power *= gt_spec_power
90 | divider = gt_spec_power.sum(1) + 1e-10
91 | for c in chain(
92 | combinations(range(4), 1), combinations(range(4), 2), combinations(range(4), 3)
93 | ):
94 | m = sum([msk_hat[:, i] for i in c])
95 | gt = sum([gt_spec_power[:, i] for i in c]) / divider
96 | loss.append(F.binary_cross_entropy(m, gt))
97 |
98 | # All 14 Combination Losses (4C1 + 4C2 + 4C3)
99 | loss_mse = sum(loss) / len(loss)
100 | return loss_mse
101 |
102 |
103 | def complex_mse_loss(y_hat: torch.Tensor, gt_spec: torch.Tensor):
104 | assert y_hat.shape == gt_spec.shape
105 | assert gt_spec.is_complex() and y_hat.is_complex()
106 |
107 | loss = []
108 | for c in chain(
109 | combinations(range(4), 1), combinations(range(4), 2), combinations(range(4), 3)
110 | ):
111 | m = sum([y_hat[:, i] for i in c])
112 | gt = sum([gt_spec[:, i] for i in c])
113 | diff = m - gt
114 | real = diff.real.reshape(-1)
115 | imag = diff.imag.reshape(-1)
116 | mse = real @ real + imag @ imag
117 | loss.append(mse / real.numel())
118 |
119 | # All 14 Combination Losses (4C1 + 4C2 + 4C3)
120 | loss_mse = sum(loss) / len(loss)
121 | return loss_mse
122 |
123 |
124 | def real_mse_loss(msk_hat: torch.Tensor, gt_spec: torch.Tensor, mix_spec: torch.Tensor):
125 | assert msk_hat.shape == gt_spec.shape
126 | assert (
127 | msk_hat.is_floating_point()
128 | and gt_spec.is_floating_point()
129 | and mix_spec.is_floating_point()
130 | )
131 | assert not gt_spec.is_complex() and not mix_spec.is_complex()
132 |
133 | loss = []
134 | for c in chain(
135 | combinations(range(4), 1), combinations(range(4), 2), combinations(range(4), 3)
136 | ):
137 | m = sum([msk_hat[:, i] for i in c])
138 | gt = sum([gt_spec[:, i] for i in c])
139 | loss.append(F.mse_loss(m * mix_spec, gt))
140 |
141 | # All 14 Combination Losses (4C1 + 4C2 + 4C3)
142 | loss_mse = sum(loss) / len(loss)
143 | return loss_mse
144 |
145 |
146 | def sdr_loss(pred, gt_time, mix):
147 | # SDR-Combination Loss
148 |
149 | batch_size, _, n_channels, length = pred.shape
150 | pred, gt_time = (
151 | pred.transpose(0, 1).contiguous(),
152 | gt_time.transpose(0, 1).contiguous(),
153 | )
154 |
155 | # Fix Length
156 | mix = mix[..., :length].reshape(-1, length)
157 | gt_time = gt_time[..., :length].reshape(_, -1, length)
158 | pred = pred.view(_, -1, length)
159 |
160 | extend_pred = [pred.view(-1, length)]
161 | extend_gt = [gt_time.view(-1, length)]
162 |
163 | for c in chain(combinations(range(4), 2), combinations(range(4), 3)):
164 | extend_pred.append(sum([pred[i] for i in c]))
165 | extend_gt.append(sum([gt_time[i] for i in c]))
166 |
167 | extend_pred = torch.cat(extend_pred, 0)
168 | extend_gt = torch.cat(extend_gt, 0)
169 | extend_mix = mix.repeat(14, 1)
170 |
171 | loss_sdr = _sdr_loss_core(extend_pred, extend_gt, extend_mix)
172 |
173 | return 1.0 + loss_sdr
174 |
175 |
176 | def _sdr_loss_core(x_hat, x, y):
177 | assert x.shape == y.shape == x_hat.shape # (Batch, Len)
178 |
179 | ns = y - x
180 | ns_hat = y - x_hat
181 |
182 | ns_norm = ns[:, None, :] @ ns[:, :, None]
183 | ns_hat_norm = ns_hat[:, None, :] @ ns_hat[:, :, None]
184 |
185 | x_norm = x[:, None, :] @ x[:, :, None]
186 | x_hat_norm = x_hat[:, None, :] @ x_hat[:, :, None]
187 | x_cross = x[:, None, :] @ x_hat[:, :, None]
188 |
189 | x_norm, x_hat_norm, ns_norm, ns_hat_norm = (
190 | x_norm.relu(),
191 | x_hat_norm.relu(),
192 | ns_norm.relu(),
193 | ns_hat_norm.relu(),
194 | )
195 |
196 | alpha = x_norm / (ns_norm + x_norm + 1e-10)
197 |
198 | # Target
199 | sdr_cln = x_cross / (x_norm.sqrt() * x_hat_norm.sqrt() + 1e-10)
200 |
201 | # Noise
202 | num_noise = ns[:, None, :] @ ns_hat[:, :, None]
203 | denom_noise = ns_norm.sqrt() * ns_hat_norm.sqrt()
204 | sdr_noise = num_noise / (denom_noise + 1e-10)
205 |
206 | return torch.mean(-alpha * sdr_cln - (1 - alpha) * sdr_noise)
207 |
--------------------------------------------------------------------------------
/aimless/loss/time.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from itertools import combinations, chain
4 |
5 |
6 | class TLoss(torch.nn.Module):
7 | r"""Base class for time domain loss modules.
8 | You can't use this module directly.
9 | Your loss should also subclass this class.
10 | Args:
11 | pred (Tensor): predict time signal tensor with size (batch, *, channels, samples), `*` is an optional multi targets dimension.
12 | gt (Tensor): target time signal tensor with size (batch, *, channels, samples), `*` is an optional multi targets dimension.
13 | mix (Tensor): mixture time signal tensor with size (batch, channels, samples)
14 | Returns:
15 | tuple: a length-2 tuple with the first element is the final loss tensor,
16 | and the second is a dict containing any intermediate loss value you want to monitor
17 | """
18 |
19 | def forward(self, *args, **kwargs):
20 | return self._core_loss(*args, **kwargs)
21 |
22 | def _core_loss(self, pred, gt, mix):
23 | raise NotImplementedError
24 |
25 |
26 | class SDR(torch.nn.Module):
27 | def __init__(self) -> None:
28 | super().__init__()
29 | self.expr = "bi,bi->b"
30 |
31 | def _batch_dot(self, x, y):
32 | return torch.einsum(self.expr, x, y)
33 |
34 | def forward(self, estimates, references):
35 | if estimates.dtype != references.dtype:
36 | estimates = estimates.to(references.dtype)
37 | length = min(references.shape[-1], estimates.shape[-1])
38 | references = references[..., :length].reshape(references.shape[0], -1)
39 | estimates = estimates[..., :length].reshape(estimates.shape[0], -1)
40 |
41 | delta = 1e-7 # avoid numerical errors
42 | num = self._batch_dot(references, references)
43 | den = (
44 | num
45 | + self._batch_dot(estimates, estimates)
46 | - 2 * self._batch_dot(estimates, references)
47 | )
48 | den = den.relu().add(delta).log10()
49 | num = num.add(delta).log10()
50 | return 10 * (num - den)
51 |
52 |
53 | class NegativeSDR(TLoss):
54 | def __init__(self) -> None:
55 | super().__init__()
56 | self.sdr = SDR()
57 |
58 | def _core_loss(self, pred, gt, mix):
59 | return -self.sdr(pred, gt).mean(), {}
60 |
61 |
62 | class CL1Loss(TLoss):
63 | def _core_loss(self, pred, gt, mix):
64 | gt = gt[..., : pred.shape[-1]]
65 | loss = []
66 | for c in chain(
67 | combinations(range(4), 1),
68 | combinations(range(4), 2),
69 | combinations(range(4), 3),
70 | ):
71 | x = sum([pred[:, i] for i in c])
72 | y = sum([gt[:, i] for i in c])
73 | loss.append(F.l1_loss(x, y))
74 |
75 | # All 14 Combination Losses (4C1 + 4C2 + 4C3)
76 | loss_l1 = sum(loss) / len(loss)
77 | return loss_l1, {}
78 |
79 |
80 | class L1Loss(TLoss):
81 | def _core_loss(self, pred, gt, mix):
82 | return F.l1_loss(pred, gt[..., : pred.shape[-1]]), {}
83 |
--------------------------------------------------------------------------------
/aimless/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/aimless/models/__init__.py
--------------------------------------------------------------------------------
/aimless/models/band_split_rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import List, Tuple, Dict, Optional
5 |
6 | v7_freqs = [
7 | 100,
8 | 200,
9 | 300,
10 | 400,
11 | 500,
12 | 600,
13 | 700,
14 | 800,
15 | 900,
16 | 1000,
17 | 1250,
18 | 1500,
19 | 1750,
20 | 2000,
21 | 2250,
22 | 2500,
23 | 2750,
24 | 3000,
25 | 3250,
26 | 3500,
27 | 3750,
28 | 4000,
29 | 4500,
30 | 5000,
31 | 5500,
32 | 6000,
33 | 6500,
34 | 7000,
35 | 7500,
36 | 8000,
37 | 9000,
38 | 10000,
39 | 11000,
40 | 12000,
41 | 13000,
42 | 14000,
43 | 15000,
44 | 16000,
45 | 18000,
46 | 20000,
47 | 22000,
48 | ]
49 |
50 |
51 | class BandSplitRNN(nn.Module):
52 | def __init__(
53 | self,
54 | n_fft: int,
55 | split_freqs: List[int] = v7_freqs,
56 | hidden_size: int = 128,
57 | num_layers: int = 12,
58 | norm_groups: int = 4,
59 | ) -> None:
60 | super().__init__()
61 |
62 | # get split freq bins index from split freqs
63 | index = [0] + [int(n_fft * f / 44100) for f in split_freqs] + [n_fft // 2 + 1]
64 | chunk_size = [index[i + 1] - index[i] for i in range(len(index) - 1)]
65 | self.split_sections = tuple(chunk_size)
66 |
67 | # stage 1: band split modules
68 | self.norm1_list = nn.ModuleList(
69 | [nn.LayerNorm(chunk_size[i]) for i in range(len(chunk_size))]
70 | )
71 | self.fc1_list = nn.ModuleList(
72 | [nn.Linear(chunk_size[i], hidden_size) for i in range(len(chunk_size))]
73 | )
74 |
75 | # stage 2: RNN modules
76 | self.band_lstms = nn.ModuleList()
77 | self.time_lstms = nn.ModuleList()
78 | self.band_group_norms = nn.ModuleList()
79 | self.time_group_norms = nn.ModuleList()
80 | for i in range(num_layers):
81 | self.band_group_norms.append(nn.GroupNorm(norm_groups, hidden_size))
82 | self.band_lstms.append(
83 | nn.LSTM(
84 | input_size=hidden_size,
85 | hidden_size=hidden_size,
86 | batch_first=True,
87 | num_layers=1,
88 | bidirectional=True,
89 | proj_size=hidden_size // 2,
90 | )
91 | )
92 |
93 | self.time_lstms.append(
94 | nn.LSTM(
95 | input_size=hidden_size,
96 | hidden_size=hidden_size,
97 | batch_first=True,
98 | num_layers=1,
99 | bidirectional=True,
100 | proj_size=hidden_size // 2,
101 | )
102 | )
103 | self.time_group_norms.append(nn.GroupNorm(norm_groups, hidden_size))
104 |
105 | # stage 3: band merge modules and mask prediction modules
106 | self.norm2_list = nn.ModuleList(
107 | [nn.LayerNorm(hidden_size) for i in range(len(chunk_size))]
108 | )
109 | self.mlps = nn.ModuleList(
110 | [
111 | nn.Sequential(
112 | nn.Linear(hidden_size, hidden_size * 4),
113 | nn.Tanh(),
114 | nn.Linear(hidden_size * 4, chunk_size[i]),
115 | )
116 | for i in range(len(chunk_size))
117 | ]
118 | )
119 |
120 | def forward(self, mag: torch.Tensor):
121 | log_mag = torch.log(mag + 1e-8)
122 | batch, channels, freq_bins, time_bins = log_mag.shape
123 | # merge channels to batch
124 | log_mag = log_mag.view(-1, freq_bins, time_bins).transpose(1, 2)
125 |
126 | # stage 1: band split modules
127 | tmp = []
128 | for fc, norm, sub_band in zip(
129 | self.fc1_list, self.norm1_list, log_mag.split(self.split_sections, dim=2)
130 | ):
131 | tmp.append(fc(norm(sub_band)))
132 |
133 | x = torch.stack(tmp, dim=1)
134 |
135 | # stage 2: RNN modules
136 | for band_lstm, time_lstm, band_norm, time_norm in zip(
137 | self.band_lstms,
138 | self.time_lstms,
139 | self.band_group_norms,
140 | self.time_group_norms,
141 | ):
142 | x_reshape = x.reshape(-1, x.shape[-2], x.shape[-1])
143 | x_reshape = time_norm(x_reshape.transpose(1, 2)).transpose(1, 2)
144 | x = time_lstm(x_reshape)[0].view(*x.shape) + x
145 |
146 | x_reshape = x.transpose(1, 2).reshape(-1, x.shape[-3], x.shape[-1])
147 | x_reshape = band_norm(x_reshape.transpose(1, 2)).transpose(1, 2)
148 | x = (
149 | band_lstm(x_reshape)[0]
150 | .view(x.shape[0], x.shape[2], x.shape[1], x.shape[3])
151 | .transpose(1, 2)
152 | + x
153 | )
154 |
155 | # stage 3: band merge modules and mask prediction modules
156 | tmp = []
157 | for i, (mlp, norm) in enumerate(zip(self.mlps, self.norm2_list)):
158 | tmp.append(mlp(norm(x[:, i])))
159 |
160 | mask = (
161 | torch.cat(tmp, dim=-1)
162 | .transpose(1, 2)
163 | .reshape(batch, channels, freq_bins, time_bins)
164 | .sigmoid()
165 | )
166 | return mask
167 |
168 |
169 | class BandSplitRNNMulti(nn.Module):
170 | def __init__(
171 | self,
172 | n_fft: int,
173 | n_sources: int,
174 | split_freqs: List[int] = v7_freqs,
175 | hidden_size: int = 128,
176 | num_layers: int = 12,
177 | norm_groups: int = 4,
178 | ) -> None:
179 | super().__init__()
180 |
181 | # get split freq bins index from split freqs
182 | index = [0] + [int(n_fft * f / 44100) for f in split_freqs] + [n_fft // 2 + 1]
183 | chunk_size = [index[i + 1] - index[i] for i in range(len(index) - 1)]
184 | self.split_sections = tuple(chunk_size)
185 |
186 | self.n_sources = n_sources
187 |
188 | # stage 1: band split modules
189 | self.norm1_list = nn.ModuleList(
190 | [nn.LayerNorm(chunk_size[i]) for i in range(len(chunk_size))]
191 | )
192 | self.fc1_list = nn.ModuleList(
193 | [nn.Linear(chunk_size[i], hidden_size) for i in range(len(chunk_size))]
194 | )
195 |
196 | # stage 2: RNN modules
197 | self.band_lstms = nn.ModuleList()
198 | self.time_lstms = nn.ModuleList()
199 | self.band_group_norms = nn.ModuleList()
200 | self.time_group_norms = nn.ModuleList()
201 | for i in range(num_layers):
202 | self.band_group_norms.append(nn.GroupNorm(norm_groups, hidden_size))
203 | self.band_lstms.append(
204 | nn.LSTM(
205 | input_size=hidden_size,
206 | hidden_size=hidden_size,
207 | batch_first=True,
208 | num_layers=1,
209 | bidirectional=True,
210 | proj_size=hidden_size // 2,
211 | )
212 | )
213 |
214 | self.time_lstms.append(
215 | nn.LSTM(
216 | input_size=hidden_size,
217 | hidden_size=hidden_size,
218 | batch_first=True,
219 | num_layers=1,
220 | bidirectional=True,
221 | proj_size=hidden_size // 2,
222 | )
223 | )
224 | self.time_group_norms.append(nn.GroupNorm(norm_groups, hidden_size))
225 |
226 | # stage 3: band merge modules and mask prediction modules
227 | self.norm2_list = nn.ModuleList(
228 | [nn.LayerNorm(hidden_size) for i in range(len(chunk_size))]
229 | )
230 | self.mlps = nn.ModuleList(
231 | [
232 | nn.Sequential(
233 | nn.Linear(hidden_size, hidden_size * 4),
234 | nn.Tanh(),
235 | nn.Linear(hidden_size * 4, chunk_size[i] * n_sources),
236 | )
237 | for i in range(len(chunk_size))
238 | ]
239 | )
240 |
241 | def forward(self, mag: torch.Tensor):
242 | log_mag = torch.log(mag + 1e-8)
243 | batch, channels, freq_bins, time_bins = log_mag.shape
244 | # merge channels to batch
245 | log_mag = log_mag.view(-1, freq_bins, time_bins).transpose(1, 2)
246 |
247 | # stage 1: band split modules
248 | tmp = []
249 | for fc, norm, sub_band in zip(
250 | self.fc1_list, self.norm1_list, log_mag.split(self.split_sections, dim=2)
251 | ):
252 | tmp.append(fc(norm(sub_band)))
253 |
254 | x = torch.stack(tmp, dim=1)
255 |
256 | # stage 2: RNN modules
257 | for band_lstm, time_lstm, band_norm, time_norm in zip(
258 | self.band_lstms,
259 | self.time_lstms,
260 | self.band_group_norms,
261 | self.time_group_norms,
262 | ):
263 | x_reshape = x.reshape(-1, x.shape[-2], x.shape[-1])
264 | x_reshape = time_norm(x_reshape.transpose(1, 2)).transpose(1, 2)
265 | x = time_lstm(x_reshape)[0].view(*x.shape) + x
266 |
267 | x_reshape = x.transpose(1, 2).reshape(-1, x.shape[-3], x.shape[-1])
268 | x_reshape = band_norm(x_reshape.transpose(1, 2)).transpose(1, 2)
269 | x = (
270 | band_lstm(x_reshape)[0]
271 | .view(x.shape[0], x.shape[2], x.shape[1], x.shape[3])
272 | .transpose(1, 2)
273 | + x
274 | )
275 |
276 | # stage 3: band merge modules and mask prediction modules
277 | tmp = []
278 | for i, (mlp, norm) in enumerate(zip(self.mlps, self.norm2_list)):
279 | tmp.append(
280 | mlp(norm(x[:, i])).view(x.shape[0], x.shape[2], self.n_sources, -1)
281 | )
282 |
283 | mask = (
284 | torch.cat(tmp, dim=-1)
285 | .reshape(batch, channels, time_bins, self.n_sources, freq_bins)
286 | .permute(0, 3, 1, 4, 2)
287 | .softmax(dim=1)
288 | )
289 | return mask
290 |
--------------------------------------------------------------------------------
/aimless/models/demucs_split.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchaudio.transforms import Resample
4 |
5 |
6 | @torch.jit.script
7 | def glu(a, b):
8 | return a * b.sigmoid()
9 |
10 |
11 | @torch.jit.script
12 | def standardize(x, mu, std):
13 | return (x - mu) / std
14 |
15 |
16 | @torch.jit.script
17 | def destandardize(x, mu, std):
18 | return x * std + mu
19 |
20 |
21 | def rescale_conv(reference):
22 | @torch.no_grad()
23 | def closure(m: nn.Module):
24 | if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
25 | std = m.weight.std()
26 | scale = (std / reference) ** 0.5
27 | m.weight.div_(scale)
28 | if m.bias is not None:
29 | m.bias.div_(scale)
30 |
31 | return closure
32 |
33 |
34 | class DemucsSplit(nn.Module):
35 | def __init__(
36 | self,
37 | channels=64,
38 | depth=6,
39 | rescale=0.1,
40 | resample=True,
41 | kernel_size=8,
42 | stride=4,
43 | lstm_layers=2,
44 | ):
45 | super().__init__()
46 | self.kernel_size = kernel_size
47 | self.stride = stride
48 | self.depth = depth
49 | self.channels = channels
50 |
51 | if resample:
52 | self.up_sample = Resample(1, 2)
53 | self.down_sample = Resample(2, 1)
54 |
55 | self.encoder = nn.ModuleList()
56 | self.convs_1x1 = nn.ModuleList()
57 | self.dec_pre_convs = nn.ModuleList()
58 | self.decoder = nn.ModuleList()
59 |
60 | in_channels = 2
61 | for index in range(depth):
62 | self.encoder.append(
63 | nn.Sequential(
64 | nn.Conv1d(in_channels, channels, kernel_size, stride),
65 | nn.ReLU(inplace=True),
66 | nn.Conv1d(channels, channels * 2, 1),
67 | nn.GLU(dim=1),
68 | )
69 | )
70 |
71 | decode = []
72 | if index > 0:
73 | out_channels = in_channels * 2
74 | else:
75 | out_channels = 8
76 |
77 | self.convs_1x1.insert(
78 | 0, nn.Conv1d(channels, channels * 4, 3, padding=1, bias=False)
79 | )
80 | self.dec_pre_convs.insert(
81 | 0, nn.Conv1d(channels * 2, channels * 4, 3, padding=1, groups=4)
82 | )
83 | decode = [
84 | nn.ConvTranspose1d(
85 | channels * 2, out_channels, kernel_size, stride, groups=4
86 | )
87 | ]
88 | if index > 0:
89 | decode.append(nn.ReLU(inplace=True))
90 | self.decoder.insert(0, nn.Sequential(*decode))
91 | in_channels = channels
92 | channels *= 2
93 |
94 | channels = in_channels
95 |
96 | self.lstm = nn.LSTM(
97 | input_size=channels,
98 | hidden_size=channels,
99 | num_layers=lstm_layers,
100 | dropout=0,
101 | bidirectional=True,
102 | )
103 | self.lstm_linear = nn.Linear(channels * 2, channels * 2)
104 |
105 | self.apply(rescale_conv(reference=rescale))
106 |
107 | def forward(self, x):
108 | length = x.size(2)
109 |
110 | mono = x.mean(1, keepdim=True)
111 | mu = mono.mean(dim=-1, keepdim=True)
112 | std = mono.std(dim=-1, keepdim=True).add_(1e-5)
113 | x = standardize(x, mu, std)
114 |
115 | if hasattr(self, "up_sample"):
116 | x = self.up_sample(x)
117 |
118 | saved = []
119 | for encode in self.encoder:
120 | x = encode(x)
121 | saved.append(x)
122 |
123 | x = x.permute(2, 0, 1)
124 | x = self.lstm(x)[0]
125 | x = self.lstm_linear(x).permute(1, 2, 0)
126 |
127 | for decode, pre_dec, conv1x1 in zip(
128 | self.decoder, self.dec_pre_convs, self.convs_1x1
129 | ):
130 | skip = saved.pop()
131 |
132 | x = pre_dec(x) + conv1x1(skip[..., : x.shape[2]])
133 | a, b = x.view(x.shape[0], 4, -1, x.shape[2]).chunk(2, 2)
134 | x = glu(a, b)
135 | x = decode(x.view(x.shape[0], -1, x.shape[3]))
136 |
137 | if hasattr(self, "down_sample"):
138 | x = self.down_sample(x)
139 |
140 | x = destandardize(x, mu, std)
141 | x = x.view(-1, 4, 2, x.size(-1))
142 | return x
143 |
--------------------------------------------------------------------------------
/aimless/models/rel_tdemucs.py:
--------------------------------------------------------------------------------
1 | from math import inf
2 | from typing import Optional
3 | import torch
4 | from torch import nn, Tensor
5 | import torch.nn.functional as F
6 | from torchaudio.transforms import Resample
7 |
8 | from .demucs_split import standardize, destandardize, rescale_conv
9 |
10 |
11 | class PositionalEmbedding(nn.Module):
12 | def __init__(self, demb):
13 | super(PositionalEmbedding, self).__init__()
14 |
15 | self.demb = demb
16 |
17 | inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
18 | self.register_buffer("inv_freq", inv_freq)
19 |
20 | def forward(self, pos_seq):
21 | sinusoid_inp = pos_seq[:, None] * self.inv_freq
22 | pos_emb = torch.view_as_real(torch.exp(1j * sinusoid_inp)).view(
23 | pos_seq.size(0), -1
24 | )
25 | return pos_emb
26 |
27 |
28 | class RelMultiheadAttention(nn.MultiheadAttention):
29 | def __init__(self, *args, **kwargs):
30 | super().__init__(
31 | *args,
32 | bias=False,
33 | add_bias_kv=False,
34 | add_zero_attn=False,
35 | kdim=None,
36 | vdim=None,
37 | batch_first=True,
38 | **kwargs
39 | )
40 | self.register_parameter(
41 | "u", nn.Parameter(torch.zeros(self.num_heads, self.head_dim))
42 | )
43 | self.register_parameter(
44 | "v", nn.Parameter(torch.zeros(self.num_heads, self.head_dim))
45 | )
46 | self.pos_emb = PositionalEmbedding(self.embed_dim)
47 | self.pos_emb_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
48 |
49 | def forward(self, x, mask=None):
50 | # x: [B, T, C]
51 | B, T, _ = x.size()
52 | seq = torch.arange(-T + 1, T, device=x.device)
53 | pos_emb = self.pos_emb_proj(self.pos_emb(seq))
54 | pos_emb = pos_emb.view(-1, self.num_heads, self.head_dim)
55 |
56 | h = x @ self.in_proj_weight.t()
57 | h = h.view(B, T, self.num_heads, self.head_dim * 3)
58 | w_head_q, w_head_k, w_head_v = h.chunk(3, dim=-1)
59 |
60 | rw_head_q = w_head_q + self.u
61 | AC = rw_head_q.transpose(1, 2) @ w_head_k.permute(0, 2, 3, 1)
62 |
63 | rr_head_q = w_head_q + self.v
64 | BD = rr_head_q.transpose(1, 2) @ pos_emb.permute(1, 2, 0) # [B, H, T, 2T-1]
65 | BD = F.pad(BD, (1, 1)).view(B, self.num_heads, 2 * T + 1, T)[
66 | :, :, 1::2, :
67 | ] # [B, H, T, T]
68 |
69 | attn_score = (AC + BD) / self.head_dim**0.5
70 |
71 | if mask is not None:
72 | attn_score = attn_score.masked_fill(mask, -inf)
73 |
74 | with torch.cuda.amp.autocast(enabled=False):
75 | attn_prob = F.softmax(attn_score.float(), dim=-1)
76 | attn_prob = F.dropout(attn_prob, self.dropout, self.training)
77 |
78 | attn_vec = attn_prob @ w_head_v.transpose(1, 2)
79 | return self.out_proj(attn_vec.permute(0, 2, 1, 3).reshape(B, T, -1))
80 |
81 |
82 | class RelEncoderLayer(nn.TransformerEncoderLayer):
83 | def __init__(self, d_model: int, nhead: int, *args, dropout: float = 0.1, **kwargs):
84 | super().__init__(
85 | d_model, nhead, *args, batch_first=True, dropout=dropout, **kwargs
86 | )
87 | self.self_attn = RelMultiheadAttention(d_model, nhead, dropout=dropout)
88 |
89 | def _sa_block(self, x: Tensor, mask: Tensor = None) -> Tensor:
90 | return self.dropout1(self.self_attn(x, mask))
91 |
92 | def forward(
93 | self,
94 | src: Tensor,
95 | src_mask: Optional[Tensor] = None,
96 | src_key_padding_mask: Optional[Tensor] = None,
97 | ) -> Tensor:
98 | x = src
99 | if self.norm_first:
100 | x = x + self._sa_block(self.norm1(x), src_mask)
101 | x = x + self._ff_block(self.norm2(x))
102 | else:
103 | x = self.norm1(x + self._sa_block(x, src_mask))
104 | x = self.norm2(x + self._ff_block(x))
105 | return x
106 |
107 |
108 | class RelTDemucs(nn.Module):
109 | def __init__(
110 | self,
111 | in_channels=2,
112 | num_sources=4,
113 | channels=64,
114 | depth=6,
115 | context_size=None,
116 | rescale=0.1,
117 | resample=True,
118 | kernel_size=8,
119 | stride=4,
120 | attention_layers=8,
121 | **kwargs
122 | ):
123 | super().__init__()
124 | self.kernel_size = kernel_size
125 | self.stride = stride
126 | self.depth = depth
127 | self.channels = channels
128 | self.context_size = context_size
129 |
130 | if resample:
131 | self.up_sample = Resample(1, 2)
132 | self.down_sample = Resample(2, 1)
133 |
134 | self.encoder = nn.ModuleList()
135 | self.decoder = nn.ModuleList()
136 |
137 | current_channels = in_channels
138 | for index in range(depth):
139 | self.encoder.append(
140 | nn.Sequential(
141 | nn.Conv1d(current_channels, channels, kernel_size, stride),
142 | nn.ReLU(inplace=True),
143 | nn.Conv1d(channels, channels * 2, 1),
144 | nn.GLU(dim=1),
145 | )
146 | )
147 |
148 | out_channels = current_channels if index > 0 else num_sources * in_channels
149 |
150 | decode = [
151 | nn.Conv1d(channels, channels * 2, 3, padding=1, bias=False),
152 | nn.GLU(dim=1),
153 | nn.ConvTranspose1d(
154 | channels,
155 | out_channels,
156 | kernel_size,
157 | stride,
158 | ),
159 | ]
160 | if index > 0:
161 | decode.append(nn.ReLU(inplace=True))
162 | self.decoder.insert(0, nn.Sequential(*decode))
163 | current_channels = channels
164 | channels *= 2
165 |
166 | channels = current_channels
167 |
168 | encoder_layer = RelEncoderLayer(
169 | d_model=channels, dim_feedforward=channels * 4, **kwargs
170 | )
171 | self.transformer = nn.TransformerEncoder(encoder_layer, attention_layers)
172 | self.apply(rescale_conv(reference=rescale))
173 |
174 | def forward(self, x, context_size: int = None):
175 | batch, ch, _ = x.shape
176 | mono = x.mean(1, keepdim=True)
177 | mu = mono.mean(dim=-1, keepdim=True)
178 | std = mono.std(dim=-1, keepdim=True).add_(1e-5)
179 | x = standardize(x, mu, std)
180 |
181 | if hasattr(self, "up_sample"):
182 | x = self.up_sample(x)
183 |
184 | saved = []
185 | for encode in self.encoder:
186 | x = encode(x)
187 | saved.append(x)
188 |
189 | x = x.transpose(1, 2)
190 | mask = None
191 | if context_size is None and self.context_size is not None:
192 | context_size = self.context_size
193 |
194 | if context_size and context_size < x.size(1):
195 | mask = x.new_ones(x.size(1), x.size(1), dtype=torch.bool)
196 | mask = torch.triu(mask, diagonal=context_size)
197 | mask = mask | mask.T
198 |
199 | x = self.transformer(x, mask).transpose(1, 2)
200 |
201 | for decode in self.decoder:
202 | skip = saved.pop()
203 | x = decode(x + skip[..., : x.shape[-1]])
204 |
205 | if hasattr(self, "down_sample"):
206 | x = self.down_sample(x)
207 |
208 | x = destandardize(x, mu, std)
209 | x = x.view(batch, -1, ch, x.shape[-1])
210 | return x
211 |
--------------------------------------------------------------------------------
/aimless/models/xumx.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class X_UMX(nn.Module):
7 | __constants__ = ["max_bins"]
8 |
9 | max_bins: int
10 |
11 | def __init__(
12 | self, n_fft=4096, hidden_channels=512, max_bins=None, nb_channels=2, nb_layers=3
13 | ):
14 | super().__init__()
15 |
16 | self.nb_output_bins = n_fft // 2 + 1
17 | if max_bins:
18 | self.max_bins = max_bins
19 | else:
20 | self.max_bins = self.nb_output_bins
21 | self.hidden_channels = hidden_channels
22 | self.n_fft = n_fft
23 | self.nb_channels = nb_channels
24 | self.nb_layers = nb_layers
25 |
26 | self.input_means = nn.Parameter(torch.zeros(4 * self.max_bins))
27 | self.input_scale = nn.Parameter(torch.ones(4 * self.max_bins))
28 |
29 | self.output_means = nn.Parameter(torch.zeros(4 * self.nb_output_bins))
30 | self.output_scale = nn.Parameter(torch.ones(4 * self.nb_output_bins))
31 |
32 | self.affine1 = nn.Sequential(
33 | nn.Conv1d(
34 | nb_channels * self.max_bins * 4,
35 | hidden_channels * 4,
36 | 1,
37 | bias=False,
38 | groups=4,
39 | ),
40 | nn.BatchNorm1d(hidden_channels * 4),
41 | nn.Tanh(),
42 | )
43 | self.bass_lstm = nn.LSTM(
44 | input_size=self.hidden_channels,
45 | hidden_size=self.hidden_channels // 2,
46 | num_layers=nb_layers,
47 | dropout=0.4,
48 | bidirectional=True,
49 | )
50 | self.drums_lstm = nn.LSTM(
51 | input_size=self.hidden_channels,
52 | hidden_size=self.hidden_channels // 2,
53 | num_layers=nb_layers,
54 | dropout=0.4,
55 | bidirectional=True,
56 | )
57 | self.vocals_lstm = nn.LSTM(
58 | input_size=self.hidden_channels,
59 | hidden_size=self.hidden_channels // 2,
60 | num_layers=nb_layers,
61 | dropout=0.4,
62 | bidirectional=True,
63 | )
64 | self.other_lstm = nn.LSTM(
65 | input_size=self.hidden_channels,
66 | hidden_size=self.hidden_channels // 2,
67 | num_layers=nb_layers,
68 | dropout=0.4,
69 | bidirectional=True,
70 | )
71 |
72 | self.affine2 = nn.Sequential(
73 | nn.Conv1d(hidden_channels * 2, hidden_channels * 4, 1, bias=False),
74 | nn.BatchNorm1d(hidden_channels * 4),
75 | nn.ReLU(inplace=True),
76 | nn.Conv1d(
77 | hidden_channels * 4,
78 | nb_channels * self.nb_output_bins * 4,
79 | 1,
80 | bias=False,
81 | groups=4,
82 | ),
83 | nn.BatchNorm1d(nb_channels * self.nb_output_bins * 4),
84 | )
85 |
86 | def forward(self, spec: torch.Tensor):
87 | batch, channels, bins, frames = spec.shape
88 | spec = spec[..., : self.max_bins, :]
89 |
90 | x = (
91 | spec.unsqueeze(1) + self.input_means.view(4, 1, -1, 1)
92 | ) * self.input_scale.view(4, 1, -1, 1)
93 |
94 | x = x.reshape(batch, -1, frames)
95 | cross_1 = self.affine1(x).view(batch, 4, -1, frames).mean(1)
96 |
97 | cross_1 = cross_1.permute(2, 0, 1)
98 | bass, *_ = self.bass_lstm(cross_1)
99 | drums, *_ = self.drums_lstm(cross_1)
100 | others, *_ = self.other_lstm(cross_1)
101 | vocals, *_ = self.vocals_lstm(cross_1)
102 |
103 | avg = (bass + drums + vocals + others) * 0.25
104 | cross_2 = torch.cat([cross_1, avg], 2).permute(1, 2, 0)
105 |
106 | mask = self.affine2(cross_2).view(
107 | batch, 4, channels, bins, frames
108 | ) * self.output_scale.view(4, 1, -1, 1) + self.output_means.view(4, 1, -1, 1)
109 | return mask.relu()
110 |
--------------------------------------------------------------------------------
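For reference, a minimal shape check of X_UMX as it is configured in cfg/xumx.yaml. This is only a sketch: the spectrogram is random data, chosen purely to illustrate the expected (batch, channels, bins, frames) input and the (batch, 4, channels, bins, frames) mask output.

import torch
from aimless.models.xumx import X_UMX

# Shape check with random data; 2049 bins follow from n_fft=4096 and the
# 44-frame spectrogram length is arbitrary.  All values are illustrative only.
model = X_UMX(n_fft=4096, hidden_channels=512, max_bins=1487, nb_channels=2)
spec = torch.rand(2, 2, 2049, 44)   # (batch, channels, n_fft // 2 + 1, frames), magnitudes
masks = model(spec)                 # one non-negative mask per source
print(masks.shape)                  # expected: torch.Size([2, 4, 2, 2049, 44])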
/aimless/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .mwf import MWF
2 |
3 |
4 | MDX_SOURCES = ["drums", "bass", "other", "vocals"]
5 | SDX_SOURCES = ["music", "sfx", "speech"]
6 | SE_SOURCES = ["speech", "noise"]
7 |
--------------------------------------------------------------------------------
/aimless/utils/mwf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import norbert
3 |
4 |
5 | class MWF(torch.nn.Module):
6 | def __init__(
7 | self, residual_model=False, softmask=False, alpha=1.0, n_iter=1
8 | ) -> None:
9 | super().__init__()
10 | self.residual_model = residual_model
11 | self.n_iter = n_iter
12 | self.softmask = softmask
13 | self.alpha = alpha
14 |
15 | def forward(self, msk_hat, mix_spec):
16 | assert msk_hat.ndim > mix_spec.ndim
17 | if not self.softmask and not self.n_iter and not self.residual_model:
18 | return msk_hat * mix_spec.unsqueeze(1)
19 |
20 | V = msk_hat * mix_spec.abs().unsqueeze(1)
21 | if self.softmask and self.alpha != 1:
22 | V = V.pow(self.alpha)
23 |
24 | X = mix_spec.transpose(1, 3).contiguous()
25 | V = V.permute(0, 4, 3, 2, 1).contiguous()
26 |
27 | if self.residual_model or V.shape[4] == 1:
28 | V = norbert.residual_model(V, X, self.alpha if self.softmask else 1)
29 |
30 | Y = norbert.wiener(
31 | V, X.to(torch.complex128), self.n_iter, use_softmask=self.softmask
32 | ).to(X.dtype)
33 |
34 | Y = Y.permute(0, 4, 3, 2, 1).contiguous()
35 | return Y
36 |
--------------------------------------------------------------------------------
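A minimal sketch of MWF's fast path: when softmask, residual_model, and n_iter are all falsy, the module simply multiplies the predicted masks with the complex mixture spectrogram. The default n_iter=1 instead routes through norbert's Wiener filtering, so the norbert package must be installed either way; shapes below are placeholders.

import torch
from aimless.utils.mwf import MWF

# Fast path: n_iter=0 (with softmask/residual_model off) skips the Wiener EM
# iterations and just applies the masks to the mixture STFT.
mwf = MWF(residual_model=False, softmask=False, n_iter=0)
mix_spec = torch.randn(1, 2, 2049, 44, dtype=torch.complex64)  # (batch, channels, bins, frames)
msk_hat = torch.rand(1, 4, 2, 2049, 44)                        # (batch, sources, channels, bins, frames)
est_spec = mwf(msk_hat, mix_spec)                              # same shape as msk_hat, complex-valued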
/aimless/utils/utils.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/aimless/utils/utils.py
--------------------------------------------------------------------------------
/cfg/cdx_a/bandsplit_rnn.yaml:
--------------------------------------------------------------------------------
1 | # pytorch_lightning==1.8.5.post0
2 | seed_everything: 2434
3 | trainer:
4 | logger: true
5 | enable_checkpointing: true
6 | callbacks:
7 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint
8 | init_args:
9 | dirpath: null
10 | filename: null
11 | monitor: null
12 | verbose: false
13 | save_last: true
14 | save_top_k: 1
15 | save_weights_only: false
16 | mode: min
17 | auto_insert_metric_name: true
18 | every_n_train_steps: 2000
19 | train_time_interval: null
20 | every_n_epochs: null
21 | save_on_train_epoch_end: null
22 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint
23 | init_args:
24 | dirpath: null
25 | filename: null
26 | monitor: null
27 | verbose: false
28 | save_last: null
29 | save_top_k: -1
30 | save_weights_only: true
31 | mode: min
32 | auto_insert_metric_name: true
33 | every_n_train_steps: 1000
34 | train_time_interval: null
35 | every_n_epochs: null
36 | save_on_train_epoch_end: null
37 | - class_path: pytorch_lightning.callbacks.ModelSummary
38 | init_args:
39 | max_depth: 2
40 | default_root_dir: null
41 | gradient_clip_val: null
42 | gradient_clip_algorithm: null
43 | num_nodes: 1
44 | num_processes: null
45 | devices: null
46 | gpus: null
47 | auto_select_gpus: false
48 | tpu_cores: null
49 | ipus: null
50 | enable_progress_bar: true
51 | overfit_batches: 0.0
52 | track_grad_norm: -1
53 | check_val_every_n_epoch: 1
54 | fast_dev_run: false
55 | accumulate_grad_batches: null
56 | max_epochs: null
57 | min_epochs: null
58 | max_steps: 99000
59 | min_steps: null
60 | max_time: null
61 | limit_train_batches: null
62 | limit_val_batches: 0
63 | limit_test_batches: null
64 | limit_predict_batches: null
65 | val_check_interval: null
66 | log_every_n_steps: 1
67 | accelerator: gpu
68 | strategy: ddp
69 | sync_batchnorm: false
70 | precision: 32
71 | enable_model_summary: true
72 | num_sanity_val_steps: 0
73 | resume_from_checkpoint: null
74 | profiler: null
75 | benchmark: null
76 | deterministic: null
77 | reload_dataloaders_every_n_epochs: 0
78 | auto_lr_find: false
79 | replace_sampler_ddp: true
80 | detect_anomaly: false
81 | auto_scale_batch_size: false
82 | plugins: null
83 | amp_backend: native
84 | amp_level: null
85 | move_metrics_to_cpu: false
86 | multiple_trainloader_mode: max_size_cycle
87 | inference_mode: true
88 | ckpt_path: null
89 | model:
90 | class_path: aimless.lightning.freq_mask.MaskPredictor
91 | init_args:
92 | model:
93 | class_path: aimless.models.band_split_rnn.BandSplitRNN
94 | init_args:
95 | n_fft: 4096
96 | split_freqs:
97 | - 100
98 | - 200
99 | - 300
100 | - 400
101 | - 500
102 | - 600
103 | - 700
104 | - 800
105 | - 900
106 | - 1000
107 | - 1250
108 | - 1500
109 | - 1750
110 | - 2000
111 | - 2250
112 | - 2500
113 | - 2750
114 | - 3000
115 | - 3250
116 | - 3500
117 | - 3750
118 | - 4000
119 | - 4500
120 | - 5000
121 | - 5500
122 | - 6000
123 | - 6500
124 | - 7000
125 | - 7500
126 | - 8000
127 | - 9000
128 | - 10000
129 | - 11000
130 | - 12000
131 | - 13000
132 | - 14000
133 | - 15000
134 | - 16000
135 | - 18000
136 | - 20000
137 | - 22000
138 | hidden_size: 128
139 | num_layers: 12
140 | norm_groups: 4
141 | criterion:
142 | class_path: aimless.loss.freq.MDLoss
143 | init_args:
144 | mcoeff: 10
145 | n_fft: 4096
146 | hop_length: 1024
147 | transforms:
148 | - class_path: aimless.augment.SpeedPerturb
149 | init_args:
150 | orig_freq: 44100
151 | speeds:
152 | - 90
153 | - 100
154 | - 110
155 | p: 0.2
156 | - class_path: aimless.augment.RandomPitch
157 | init_args:
158 | semitones:
159 | - -1
160 | - 1
161 | - 0
162 | - 1
163 | - 2
164 | n_fft: 2048
165 | hop_length: 512
166 | p: 0.2
167 | target_track: sdx
168 | targets:
169 | music: null
170 | n_fft: 4096
171 | hop_length: 1024
172 | residual_model: true
173 | softmask: false
174 | alpha: 1.0
175 | n_iter: 1
176 | data:
177 | class_path: data.lightning.DnR
178 | init_args:
179 | root: /import/c4dm-datasets-ext/sdx-2023/dnr_v2/dnr_v2/
180 | seq_duration: 3.0
181 | samples_per_track: 144
182 | random: true
183 | include_val: true
184 | random_track_mix: true
185 | transforms:
186 | - class_path: data.augment.RandomGain
187 | init_args:
188 | low: 0.25
189 | high: 1.25
190 | p: 1.0
191 | - class_path: data.augment.RandomFlipPhase
192 | init_args:
193 | p: 0.5
194 | - class_path: data.augment.RandomSwapLR
195 | init_args:
196 | p: 0.5
197 | batch_size: 16
198 | optimizer:
199 | class_path: torch.optim.Adam
200 | init_args:
201 | lr: 0.0003
202 | betas:
203 | - 0.9
204 | - 0.999
205 | eps: 1.0e-08
206 | weight_decay: 0
207 | amsgrad: false
208 | foreach: null
209 | maximize: false
210 | capturable: false
211 | differentiable: false
212 | fused: false
213 |
--------------------------------------------------------------------------------
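The two ModelCheckpoint entries in cfg/cdx_a/bandsplit_rnn.yaml above map roughly onto the following direct PyTorch Lightning calls. This is only a sketch of the API those init_args feed into; in practice the trainer builds the callbacks from the YAML, not from this code.

from pytorch_lightning.callbacks import ModelCheckpoint

# A rolling checkpoint (plus a "last" copy) every 2000 training steps, and a
# weights-only snapshot of every 1000th step; save_top_k=-1 keeps them all.
rolling_ckpt = ModelCheckpoint(save_last=True, save_top_k=1, every_n_train_steps=2000)
weights_ckpt = ModelCheckpoint(save_top_k=-1, save_weights_only=True, every_n_train_steps=1000)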
/cfg/cdx_a/hdemucs.yaml:
--------------------------------------------------------------------------------
1 | # pytorch_lightning==1.8.5.post0
2 | seed_everything: 2434
3 | trainer:
4 | logger: true
5 | enable_checkpointing: true
6 | callbacks:
7 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint
8 | init_args:
9 | dirpath: null
10 | filename: null
11 | monitor: null
12 | verbose: false
13 | save_last: true
14 | save_top_k: 1
15 | save_weights_only: false
16 | mode: min
17 | auto_insert_metric_name: true
18 | every_n_train_steps: 2000
19 | train_time_interval: null
20 | every_n_epochs: null
21 | save_on_train_epoch_end: null
22 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint
23 | init_args:
24 | dirpath: null
25 | filename: null
26 | monitor: null
27 | verbose: false
28 | save_last: null
29 | save_top_k: -1
30 | save_weights_only: true
31 | mode: min
32 | auto_insert_metric_name: true
33 | every_n_train_steps: 2000
34 | train_time_interval: null
35 | every_n_epochs: null
36 | save_on_train_epoch_end: null
37 | - class_path: pytorch_lightning.callbacks.ModelSummary
38 | init_args:
39 | max_depth: 2
40 | default_root_dir: null
41 | gradient_clip_val: null
42 | gradient_clip_algorithm: null
43 | num_nodes: 1
44 | num_processes: null
45 | devices: null
46 | gpus: null
47 | auto_select_gpus: false
48 | tpu_cores: null
49 | ipus: null
50 | enable_progress_bar: true
51 | overfit_batches: 0.0
52 | track_grad_norm: -1
53 | check_val_every_n_epoch: 1
54 | fast_dev_run: false
55 | accumulate_grad_batches: 4
56 | max_epochs: null
57 | min_epochs: null
58 | max_steps: 30000
59 | min_steps: null
60 | max_time: null
61 | limit_train_batches: null
62 | limit_val_batches: 0
63 | limit_test_batches: null
64 | limit_predict_batches: null
65 | val_check_interval: null
66 | log_every_n_steps: 1
67 | accelerator: gpu
68 | strategy: ddp
69 | sync_batchnorm: false
70 | precision: 32
71 | enable_model_summary: true
72 | num_sanity_val_steps: 0
73 | resume_from_checkpoint: null
74 | profiler: null
75 | benchmark: null
76 | deterministic: null
77 | reload_dataloaders_every_n_epochs: 0
78 | auto_lr_find: false
79 | replace_sampler_ddp: true
80 | detect_anomaly: false
81 | auto_scale_batch_size: false
82 | plugins: null
83 | amp_backend: native
84 | amp_level: null
85 | move_metrics_to_cpu: false
86 | multiple_trainloader_mode: max_size_cycle
87 | inference_mode: true
88 | ckpt_path: null
89 | model:
90 | class_path: aimless.lightning.waveform.WaveformSeparator
91 | init_args:
92 | model:
93 | class_path: torchaudio.models.HDemucs
94 | init_args:
95 | sources:
96 | - music
97 | - sfx
98 | - speech
99 | audio_channels: 1
100 | channels: 64
101 | growth: 2
102 | nfft: 4096
103 | depth: 6
104 | freq_emb: 0.2
105 | emb_scale: 10
106 | emb_smooth: true
107 | kernel_size: 8
108 | time_stride: 2
109 | stride: 4
110 | context: 1
111 | context_enc: 0
112 | norm_starts: 4
113 | norm_groups: 4
114 | dconv_depth: 2
115 | dconv_comp: 4
116 | dconv_attn: 4
117 | dconv_lstm: 4
118 | dconv_init: 0.0001
119 | criterion:
120 | class_path: aimless.loss.time.NegativeSDR
121 | transforms:
122 | - class_path: aimless.augment.SpeedPerturb
123 | init_args:
124 | orig_freq: 44100
125 | speeds:
126 | - 50
127 | - 60
128 | - 70
129 | - 80
130 | - 90
131 | - 100
132 | - 110
133 | - 120
134 | - 130
135 | - 140
136 | - 150
137 | p: 0.3
138 | - class_path: aimless.augment.RandomPitch
139 | init_args:
140 | semitones:
141 | - -7
142 | - -6
143 | - -5
144 | - -4
145 | - -3
146 | - -2
147 | - -1
148 | - 1
149 | - 0
150 | - 1
151 | - 2
152 | - 3
153 | - 4
154 | - 5
155 | - 6
156 | - 7
157 | n_fft: 2048
158 | hop_length: 512
159 | p: 0.3
160 | use_sdx_targets: false
161 | target_track: sdx
162 | targets:
163 | music: null
164 | sfx: null
165 | speech: null
166 | data:
167 | class_path: data.lightning.DnR
168 | init_args:
169 | root: /import/c4dm-datasets-ext/sdx-2023/dnr_v2/dnr_v2/
170 | seq_duration: 6.0
171 | samples_per_track: 144
172 | random: true
173 | include_val: true
174 | random_track_mix: true
175 | transforms:
176 | - class_path: data.augment.RandomGain
177 | init_args:
178 | low: 0.25
179 | high: 1.25
180 | p: 1.0
181 | - class_path: data.augment.RandomFlipPhase
182 | init_args:
183 | p: 0.5
184 | - class_path: data.augment.RandomSwapLR
185 | init_args:
186 | p: 0.5
187 | batch_size: 3
188 | optimizer:
189 | class_path: torch.optim.Adam
190 | init_args:
191 | lr: 0.0003
192 | betas:
193 | - 0.9
194 | - 0.999
195 | eps: 1.0e-08
196 | weight_decay: 0
197 | amsgrad: false
198 | foreach: null
199 | maximize: false
200 | capturable: false
201 | differentiable: false
202 | fused: false
203 |
--------------------------------------------------------------------------------
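The model section of cfg/cdx_a/hdemucs.yaml corresponds roughly to instantiating torchaudio's HDemucs directly. A minimal sketch, assuming a torchaudio release that ships torchaudio.models.HDemucs; the random waveform is only there to show the expected shapes.

import torch
from torchaudio.models import HDemucs

# Mono, three-source HDemucs (music / sfx / speech) as configured for CDX A.
model = HDemucs(sources=["music", "sfx", "speech"], audio_channels=1, channels=64, nfft=4096, depth=6)
waveform = torch.randn(1, 1, 44100 * 6)   # (batch, audio_channels, samples), 6 s at 44.1 kHz
separated = model(waveform)               # (batch, 3, 1, samples)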
/cfg/demucs.yaml:
--------------------------------------------------------------------------------
1 | # pytorch_lightning==1.8.5.post0
2 | seed_everything: true
3 | trainer:
4 | callbacks:
5 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint
6 | init_args:
7 | save_last: true
8 | every_n_train_steps: 2000
9 | filename: "{epoch}-{step}"
10 | - class_path: pytorch_lightning.callbacks.ModelSummary
11 | init_args:
12 | max_depth: 2
13 | logger: true
14 | enable_checkpointing: true
15 |   # callbacks: null
16 | default_root_dir: null
17 | gradient_clip_val: null
18 | gradient_clip_algorithm: null
19 | num_nodes: 1
20 | num_processes: null
21 | devices: null
22 | gpus: null
23 | auto_select_gpus: false
24 | tpu_cores: null
25 | ipus: null
26 | enable_progress_bar: true
27 | overfit_batches: 0.0
28 | track_grad_norm: -1
29 | check_val_every_n_epoch: 1
30 | fast_dev_run: false
31 | accumulate_grad_batches: null
32 | max_epochs: null
33 | min_epochs: null
34 | max_steps: -1
35 | min_steps: null
36 | max_time: null
37 | limit_train_batches: null
38 | limit_val_batches: null
39 | limit_test_batches: null
40 | limit_predict_batches: null
41 | val_check_interval: null
42 | log_every_n_steps: 1
43 | accelerator: gpu
44 | sync_batchnorm: true
45 | precision: 16
46 | enable_model_summary: true
47 | num_sanity_val_steps: 2
48 | resume_from_checkpoint: null
49 | profiler: null
50 | benchmark: null
51 | deterministic: null
52 | reload_dataloaders_every_n_epochs: 0
53 | auto_lr_find: false
54 | replace_sampler_ddp: true
55 | detect_anomaly: false
56 | auto_scale_batch_size: false
57 | plugins: null
58 | amp_backend: native
59 | amp_level: null
60 | move_metrics_to_cpu: false
61 | multiple_trainloader_mode: max_size_cycle
62 | inference_mode: true
63 | ckpt_path: null
64 | model:
65 | class_path: aimless.lightning.waveform.WaveformSeparator
66 | init_args:
67 | model:
68 | class_path: aimless.models.demucs_split.DemucsSplit
69 | init_args:
70 | channels: 48
71 | criterion:
72 | class_path: aimless.loss.time.L1Loss
73 | targets: {vocals, drums, bass, other}
74 | data:
75 | class_path: data.lightning.MUSDB
76 | init_args:
77 | root: /import/c4dm-datasets-ext/musdb18hq/
78 | seq_duration: 10.0
79 | samples_per_track: 150
80 | random: true
81 | random_track_mix: true
82 | transforms:
83 | - class_path: data.augment.RandomGain
84 | - class_path: data.augment.RandomFlipPhase
85 | - class_path: data.augment.RandomSwapLR
86 | - class_path: data.augment.LimitAug
87 | init_args:
88 | sample_rate: 44100
89 | batch_size: 4
90 | optimizer:
91 | class_path: torch.optim.Adam
92 | init_args:
93 | lr: 0.0003
--------------------------------------------------------------------------------
/cfg/hdemucs.yaml:
--------------------------------------------------------------------------------
1 | # pytorch_lightning==1.8.5.post0
2 | seed_everything: true
3 | trainer:
4 | callbacks:
5 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint
6 | init_args:
7 | save_last: true
8 | every_n_train_steps: 2000
9 | filename: "{epoch}-{step}"
10 | - class_path: pytorch_lightning.callbacks.ModelSummary
11 | init_args:
12 | max_depth: 2
13 | logger: true
14 | enable_checkpointing: true
15 | # callbacks: null
16 | default_root_dir: null
17 | gradient_clip_val: null
18 | gradient_clip_algorithm: null
19 | num_nodes: 1
20 | num_processes: null
21 | devices: 1
22 | gpus: null
23 | auto_select_gpus: false
24 | tpu_cores: null
25 | ipus: null
26 | enable_progress_bar: true
27 | overfit_batches: 0.0
28 | track_grad_norm: -1
29 | check_val_every_n_epoch: 1
30 | fast_dev_run: false
31 | accumulate_grad_batches: null
32 | max_epochs: null
33 | min_epochs: null
34 | max_steps: -1
35 | min_steps: null
36 | max_time: null
37 | limit_train_batches: null
38 | limit_val_batches: null
39 | limit_test_batches: null
40 | limit_predict_batches: null
41 | val_check_interval: null
42 | log_every_n_steps: 1
43 | # accelerator: cpu
44 | accelerator: gpu
45 | strategy: ddp
46 | sync_batchnorm: true
47 | precision: 32
48 | enable_model_summary: true
49 | num_sanity_val_steps: 2
50 | resume_from_checkpoint: null
51 | profiler: null
52 | benchmark: null
53 | deterministic: null
54 | reload_dataloaders_every_n_epochs: 0
55 | auto_lr_find: false
56 | replace_sampler_ddp: true
57 | detect_anomaly: false
58 | auto_scale_batch_size: false
59 | plugins: null
60 | amp_backend: native
61 | amp_level: null
62 | move_metrics_to_cpu: false
63 | multiple_trainloader_mode: max_size_cycle
64 | inference_mode: true
65 | ckpt_path: null
66 | model:
67 | class_path: aimless.lightning.waveform.WaveformSeparator
68 | init_args:
69 | model:
70 | class_path: torchaudio.models.HDemucs
71 | init_args:
72 | sources:
73 | - vocals
74 | - drums
75 | - bass
76 | - other
77 | channels: 48
78 | criterion:
79 | class_path: aimless.loss.time.L1Loss
80 | transforms:
81 | - class_path: aimless.augment.SpeedPerturb
82 | init_args:
83 | orig_freq: 44100
84 | speeds:
85 | - 90
86 | - 100
87 | - 110
88 | p: 0.2
89 | - class_path: aimless.augment.RandomPitch
90 | init_args:
91 | semitones:
92 | - -1
93 | - 1
94 | - 0
95 | - 1
96 | - 2
97 | p: 0.2
98 | targets: {vocals, drums, bass, other}
99 | data:
100 | class_path: data.lightning.LabelNoise
101 | init_args:
102 | # root: /import/c4dm-datasets-ext/musdb18hq/
103 | # root: /Volumes/samsung_t5/moisesdb23_labelnoise_v1.0_split
104 | root: /import/c4dm-datasets-ext/cm007/moisesdb23_labelnoise_v1.0_split
105 | seq_duration: 10.0
106 | # samples_per_track: 1
107 | samples_per_track: 100
108 | transforms:
109 | - class_path: data.augment.RandomParametricEQ
110 | init_args:
111 | sample_rate: 44100
112 | p: 0.7
113 | - class_path: data.augment.RandomPedalboardDistortion
114 | init_args:
115 | sample_rate: 44100
116 | p: 0.01
117 | - class_path: data.augment.RandomPedalboardDelay
118 | init_args:
119 | sample_rate: 44100
120 | p: 0.02
121 | - class_path: data.augment.RandomPedalboardChorus
122 | init_args:
123 | sample_rate: 44100
124 | p: 0.01
125 | - class_path: data.augment.RandomPedalboardPhaser
126 | init_args:
127 | sample_rate: 44100
128 | p: 0.01
129 | - class_path: data.augment.RandomPedalboardCompressor
130 | init_args:
131 | sample_rate: 44100
132 | p: 0.5
133 | - class_path: data.augment.RandomPedalboardReverb
134 | init_args:
135 | sample_rate: 44100
136 | p: 0.2
137 | - class_path: data.augment.RandomStereoWidener
138 | init_args:
139 | sample_rate: 44100
140 | p: 0.3
141 | - class_path: data.augment.RandomPedalboardLimiter
142 | init_args:
143 | sample_rate: 44100
144 | p: 0.1
145 | - class_path: data.augment.RandomVolumeAutomation
146 | init_args:
147 | sample_rate: 44100
148 | p: 0.1
149 | - class_path: data.augment.LoudnessNormalize
150 | init_args:
151 | sample_rate: 44100
152 | target_lufs_db: -32.0
153 | p: 1.0
154 | random: true
155 | random_track_mix: false
156 | # batch_size: 1
157 | batch_size: 5
158 | optimizer:
159 | class_path: torch.optim.Adam
160 | init_args:
161 | lr: 0.0003
162 |
--------------------------------------------------------------------------------
/cfg/mdx_a/hdemucs.yaml:
--------------------------------------------------------------------------------
1 | # pytorch_lightning==1.8.5.post0
2 | seed_everything: 45
3 | trainer:
4 | logger: true
5 | enable_checkpointing: true
6 | callbacks:
7 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint
8 | init_args:
9 | dirpath: null
10 | filename: null
11 | monitor: null
12 | verbose: false
13 | save_last: true
14 | save_top_k: 1
15 | save_weights_only: false
16 | mode: min
17 | auto_insert_metric_name: true
18 | every_n_train_steps: 2000
19 | train_time_interval: null
20 | every_n_epochs: null
21 | save_on_train_epoch_end: null
22 | - class_path: pytorch_lightning.callbacks.ModelSummary
23 | init_args:
24 | max_depth: 2
25 | default_root_dir: /import/c4dm-datasets-ext/cm007/mdx_checkpoints/
26 | gradient_clip_val: null
27 | gradient_clip_algorithm: null
28 | num_nodes: 1
29 | num_processes: null
30 | # devices: 1
31 | devices: 4
32 | gpus: null
33 | auto_select_gpus: false
34 | tpu_cores: null
35 | ipus: null
36 | enable_progress_bar: true
37 | overfit_batches: 0.0
38 | track_grad_norm: -1
39 | check_val_every_n_epoch: 10
40 | fast_dev_run: false
41 | accumulate_grad_batches: 4
42 | # accumulate_grad_batches: 32
43 | # accumulate_grad_batches: 64
44 | max_epochs: -1
45 | min_epochs: null
46 | max_steps: -1
47 | min_steps: null
48 | max_time: null
49 | limit_train_batches: null
50 | limit_val_batches: null
51 | limit_test_batches: null
52 | limit_predict_batches: null
53 | val_check_interval: null
54 | log_every_n_steps: 1
55 | accelerator: gpu
56 | strategy: ddp
57 | sync_batchnorm: false
58 | precision: 32
59 | enable_model_summary: true
60 | num_sanity_val_steps: 2
61 | resume_from_checkpoint: null
62 | profiler: null
63 | benchmark: null
64 | deterministic: null
65 | reload_dataloaders_every_n_epochs: 0
66 | auto_lr_find: false
67 | replace_sampler_ddp: true
68 | detect_anomaly: false
69 | auto_scale_batch_size: false
70 | plugins: null
71 | amp_backend: native
72 | amp_level: null
73 | move_metrics_to_cpu: false
74 | multiple_trainloader_mode: max_size_cycle
75 | inference_mode: true
76 | ckpt_path: null
77 | #ckpt_path: /import/c4dm-datasets-ext/cm007/mdx_checkpoints/epoch=750-step=212000.ckpt
78 | model:
79 | class_path: aimless.lightning.waveform.WaveformSeparator
80 | init_args:
81 | model:
82 | class_path: torchaudio.models.HDemucs
83 | init_args:
84 | sources:
85 | - vocals
86 | - drums
87 | - bass
88 | - other
89 | audio_channels: 2
90 | channels: 64
91 | growth: 2
92 | nfft: 4096
93 | depth: 6
94 | freq_emb: 0.2
95 | emb_scale: 10
96 | emb_smooth: true
97 | kernel_size: 8
98 | time_stride: 2
99 | stride: 4
100 | context: 1
101 | context_enc: 0
102 | norm_starts: 4
103 | norm_groups: 4
104 | dconv_depth: 2
105 | dconv_comp: 4
106 | dconv_attn: 4
107 | dconv_lstm: 4
108 | dconv_init: 0.0001
109 | criterion:
110 | class_path: aimless.loss.time.NegativeSDR
111 | transforms:
112 | - class_path: aimless.augment.SpeedPerturb
113 | init_args:
114 | orig_freq: 44100
115 | speeds:
116 | - 90
117 | - 100
118 | - 110
119 | p: 0.3
120 | - class_path: aimless.augment.RandomPitch
121 | init_args:
122 | semitones:
123 | - -1
124 | - 1
125 | - 0
126 | - 1
127 | - 2
128 | p: 0.3
129 | targets: {vocals, drums, bass, other}
130 | data:
131 | class_path: data.lightning.LabelNoise
132 | init_args:
133 | # root: /import/c4dm-datasets-ext/musdb18hq/
134 | # root: /Volumes/samsung_t5/moisesdb23_labelnoise_v1.0_split
135 | root: /import/c4dm-datasets-ext/cm007/moisesdb23_labelnoise_v1.0_split
136 | seq_duration: 10.0
137 | # samples_per_track: 1
138 | samples_per_track: 144
139 | transforms:
140 | - class_path: data.augment.RandomParametricEQ
141 | init_args:
142 | sample_rate: 44100
143 | p: 0.3
144 | - class_path: data.augment.RandomPedalboardDistortion
145 | init_args:
146 | sample_rate: 44100
147 | p: 0.03
148 | - class_path: data.augment.RandomPedalboardDelay
149 | init_args:
150 | sample_rate: 44100
151 | p: 0.03
152 | - class_path: data.augment.RandomPedalboardChorus
153 | init_args:
154 | sample_rate: 44100
155 | p: 0.03
156 | - class_path: data.augment.RandomPedalboardPhaser
157 | init_args:
158 | sample_rate: 44100
159 | p: 0.03
160 | - class_path: data.augment.RandomPedalboardCompressor
161 | init_args:
162 | sample_rate: 44100
163 | p: 0.3
164 | - class_path: data.augment.RandomPedalboardReverb
165 | init_args:
166 | sample_rate: 44100
167 | p: 0.3
168 | - class_path: data.augment.RandomStereoWidener
169 | init_args:
170 | sample_rate: 44100
171 | p: 0.3
172 | - class_path: data.augment.RandomPedalboardLimiter
173 | init_args:
174 | sample_rate: 44100
175 | p: 0.3
176 | - class_path: data.augment.RandomVolumeAutomation
177 | init_args:
178 | sample_rate: 44100
179 | p: 0.3
180 | - class_path: data.augment.LoudnessNormalize
181 | init_args:
182 | sample_rate: 44100
183 | target_lufs_db: -32.0
184 | p: 0.3
185 | - class_path: data.augment.RandomGain
186 | init_args:
187 | low: 0.25
188 | high: 1.25
189 | p: 0.3
190 | - class_path: data.augment.RandomFlipPhase
191 | init_args:
192 | p: 0.3
193 | - class_path: data.augment.RandomSwapLR
194 | init_args:
195 | p: 0.3
196 | random: true
197 | # include_val: true
198 | random_track_mix: true
199 | # batch_size: 1
200 | batch_size: 3
201 | optimizer:
202 | class_path: torch.optim.Adam
203 | init_args:
204 | lr: 0.0003
205 | # lr: 0.0001
206 | betas:
207 | - 0.9
208 | - 0.999
209 | eps: 1.0e-08
210 | weight_decay: 0
211 | amsgrad: false
212 | foreach: null
213 | maximize: false
214 | capturable: false
215 | differentiable: false
216 | fused: false
217 |
--------------------------------------------------------------------------------
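One detail worth spelling out from cfg/mdx_a/hdemucs.yaml: with DDP over 4 devices, a per-device batch of 3, and 4-step gradient accumulation, each optimizer update effectively sees 48 ten-second excerpts. Plain arithmetic, not something the config computes itself:

# Effective batch size per optimizer step under DDP + gradient accumulation.
batch_size, devices, accumulate_grad_batches = 3, 4, 4
print(batch_size * devices * accumulate_grad_batches)  # 48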
/cfg/speech_enhance.yaml:
--------------------------------------------------------------------------------
1 | # pytorch_lightning==1.8.5.post0
2 | seed_everything: true
3 | trainer:
4 | detect_anomaly: true
5 | callbacks:
6 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint
7 | init_args:
8 | save_last: true
9 | every_n_train_steps: 2000
10 | filename: "{epoch}-{step}"
11 | - class_path: pytorch_lightning.callbacks.ModelSummary
12 | init_args:
13 | max_depth: 2
14 | log_every_n_steps: 1
15 | accelerator: gpu
16 | strategy: ddp
17 | sync_batchnorm: true
18 | precision: 32
19 | ckpt_path: null
20 | model:
21 | class_path: aimless.lightning.waveform.WaveformSeparator
22 | init_args:
23 | model:
24 | class_path: torchaudio.models.HDemucs
25 | init_args:
26 | sources:
27 | - speech
28 | channels: 48
29 | audio_channels: 1
30 | criterion:
31 | class_path: aimless.loss.time.L1Loss
32 | transforms:
33 | - class_path: aimless.augment.SpeedPerturb
34 | init_args:
35 | orig_freq: 44100
36 | speeds:
37 | - 90
38 | - 100
39 | - 110
40 | p: 0.2
41 | - class_path: aimless.augment.RandomPitch
42 | init_args:
43 | semitones:
44 | - -1
45 | - 1
46 | - 0
47 | - 1
48 | - 2
49 | p: 0.2
50 | target_track: se
51 | targets: {speech}
52 | data:
53 | class_path: data.lightning.SpeechNoise
54 | init_args:
55 | speech_root: /import/c4dm-datasets/VCTK-Corpus-0.92/wav48_silence_trimmed/
56 | noise_root: /import/c4dm-datasets-ext/musdb18hq-voiceless-mix/
57 | seq_duration: 6.0
58 | samples_per_track: 64
59 | least_overlap_ratio: 0.5
60 | mono: true
61 | snr_sampler:
62 | class_path: torch.distributions.Uniform
63 | init_args:
64 | low: 0
65 | high: 20
66 | batch_size: 8
67 | optimizer:
68 | class_path: torch.optim.Adam
69 | init_args:
70 | lr: 0.0003
--------------------------------------------------------------------------------
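The snr_sampler entry in cfg/speech_enhance.yaml hands the data module a torch.distributions.Uniform over 0-20 dB. The SpeechNoise dataset itself is not shown in this dump, so the sketch below is only an illustration of how a sampled SNR is typically applied; all tensors are placeholders.

import torch

# Scale the noise so the speech-to-noise power ratio matches the sampled SNR.
snr_sampler = torch.distributions.Uniform(0.0, 20.0)
snr_db = snr_sampler.sample()
speech = torch.randn(1, 44100 * 6)
noise = torch.randn(1, 44100 * 6)
gain = (speech.pow(2).mean() / (noise.pow(2).mean() * 10 ** (snr_db / 10))).sqrt()
noisy = speech + gain * noise   # mixture at (approximately) the sampled SNR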
/cfg/xumx.yaml:
--------------------------------------------------------------------------------
1 | # pytorch_lightning==1.8.5.post0
2 | seed_everything: true
3 | trainer:
4 | callbacks:
5 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint
6 | init_args:
7 | save_last: true
8 | every_n_train_steps: 2000
9 | filename: "{epoch}-{step}"
10 | - class_path: pytorch_lightning.callbacks.ModelSummary
11 | init_args:
12 | max_depth: 2
13 | logger: true
14 | enable_checkpointing: true
15 |   # callbacks: null
16 | default_root_dir: null
17 | gradient_clip_val: null
18 | gradient_clip_algorithm: null
19 | num_nodes: 1
20 | num_processes: null
21 | devices: null
22 | gpus: null
23 | auto_select_gpus: false
24 | tpu_cores: null
25 | ipus: null
26 | enable_progress_bar: true
27 | overfit_batches: 0.0
28 | track_grad_norm: -1
29 | check_val_every_n_epoch: 1
30 | fast_dev_run: false
31 | accumulate_grad_batches: null
32 | max_epochs: null
33 | min_epochs: null
34 | max_steps: -1
35 | min_steps: null
36 | max_time: null
37 | limit_train_batches: null
38 | limit_val_batches: null
39 | limit_test_batches: null
40 | limit_predict_batches: null
41 | val_check_interval: null
42 | log_every_n_steps: 1
43 | accelerator: gpu
44 | sync_batchnorm: true
45 | precision: 32
46 | enable_model_summary: true
47 | num_sanity_val_steps: 2
48 | resume_from_checkpoint: null
49 | profiler: null
50 | benchmark: null
51 | deterministic: null
52 | reload_dataloaders_every_n_epochs: 0
53 | auto_lr_find: false
54 | replace_sampler_ddp: true
55 | detect_anomaly: false
56 | auto_scale_batch_size: false
57 | plugins: null
58 | amp_backend: native
59 | amp_level: null
60 | move_metrics_to_cpu: false
61 | multiple_trainloader_mode: max_size_cycle
62 | inference_mode: true
63 | ckpt_path: null
64 | model:
65 | class_path: aimless.lightning.freq_mask.MaskPredictor
66 | init_args:
67 | model:
68 | class_path: aimless.models.xumx.X_UMX
69 | init_args:
70 | n_fft: 4096
71 | hidden_channels: 512
72 | max_bins: 1487
73 | nb_channels: 2
74 | nb_layers: 3
75 | criterion:
76 | class_path: aimless.loss.freq.CLoss
77 | init_args:
78 | mcoeff: 10
79 | n_fft: 4096
80 | hop_length: 1024
81 | n_iter: 1
82 | transforms:
83 | - class_path: aimless.augment.SpeedPerturb
84 | init_args:
85 | orig_freq: 44100
86 | speeds:
87 | - 90
88 | - 100
89 | - 110
90 | p: 0.2
91 | - class_path: aimless.augment.RandomPitch
92 | init_args:
93 | semitones:
94 | - -1
95 | - 1
96 | - 0
97 | - 1
98 | - 2
99 | p: 0.2
100 | targets: {vocals, drums, bass, other}
101 | n_fft: 4096
102 | hop_length: 1024
103 | residual_model: false
104 | softmask: false
105 | alpha: 1.0
106 | n_iter: 1
107 | data:
108 | class_path: data.lightning.MUSDB
109 | init_args:
110 | root: /import/c4dm-datasets-ext/musdb18hq/
111 | seq_duration: 6.0
112 | samples_per_track: 64
113 | random: true
114 | random_track_mix: true
115 | transforms:
116 | - class_path: data.augment.RandomGain
117 | - class_path: data.augment.RandomFlipPhase
118 | - class_path: data.augment.RandomSwapLR
119 | - class_path: data.augment.LimitAug
120 | init_args:
121 | sample_rate: 44100
122 | batch_size: 4
123 | optimizer:
124 | class_path: torch.optim.Adam
125 | init_args:
126 | lr: 0.0001
127 | weight_decay: 0.00001
128 |
--------------------------------------------------------------------------------
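A note on max_bins: 1487 in cfg/xumx.yaml: assuming 44.1 kHz audio (as the transforms above do), a quick calculation shows it truncates the input spectrogram at roughly 16 kHz, similar to the bandwidth cap commonly used for Open-Unmix-style models.

# With n_fft=4096 at 44.1 kHz, STFT bin k sits at k * 44100 / 4096 Hz, so
# keeping bins 0..1486 (max_bins=1487) covers roughly 0-16 kHz.
n_fft, sr, max_bins = 4096, 44100, 1487
print((max_bins - 1) * sr / n_fft)   # ~15999.2 Hz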
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aim-qmul/sdx23-aimless/878101248aeed466fe3a05d0645eab048b0fc393/data/__init__.py
--------------------------------------------------------------------------------
/data/augment.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import torchaudio
4 | import numpy as np
5 | import scipy.stats
6 | import scipy.signal
7 | import pyloudnorm as pyln
8 | from typing import List, Tuple
9 |
10 | from pedalboard import (
11 | Pedalboard,
12 | Gain,
13 | Chorus,
14 | Reverb,
15 | Compressor,
16 | Phaser,
17 | Delay,
18 | Distortion,
19 | Limiter,
20 | )
21 |
22 |
23 | __all__ = [
24 | "RandomSwapLR",
25 | "RandomGain",
26 | "RandomFlipPhase",
27 | "LimitAug",
28 | "CPUBase",
29 | "RandomParametricEQ",
30 | "RandomStereoWidener",
31 | "RandomVolumeAutomation",
32 | "RandomPedalboardCompressor",
33 | "RandomPedalboardDelay",
34 | "RandomPedalboardChorus",
35 | "RandomPedalboardDistortion",
36 | "RandomPedalboardPhaser",
37 | "RandomPedalboardReverb",
38 | "RandomPedalboardLimiter",
39 | "RandomSoxReverb",
40 | "LoudnessNormalize",
41 | "RandomPan",
42 | "Mono2Stereo",
43 | ]
44 |
45 |
46 | class CPUBase(object):
47 | def __init__(self, p: float = 1.0):
48 | assert 0 <= p <= 1, "invalid probability value"
49 | self.p = p
50 |
51 | def __call__(
52 | self, x: Tuple[torch.Tensor, torch.Tensor]
53 | ) -> Tuple[torch.Tensor, torch.Tensor]:
54 | """
55 | Args:
56 | x (Tuple[torch.Tensor, torch.Tensor]): (mixture, stems) where
57 | mixture : (Num_channels, L)
58 | stems: (Num_sources, Num_channels, L)
59 | Return:
60 | Tuple[torch.Tensor, torch.Tensor]: (mixture, stems) where
61 | mixture: (Num_channels, L)
62 | stems: (Num_sources, Num_channels, L)
63 | """
64 | mixture, stems = x
65 | stems_fx = []
66 | for stem_idx in np.arange(stems.shape[0]):
67 | stem = stems[stem_idx, ...]
68 | if np.random.rand() < self.p:
69 | stems_fx.append(self._transform(stem))
70 | else:
71 | stems_fx.append(stem)
72 | stems = torch.stack(stems_fx, dim=0)
73 | mixture = stems.sum(0)
74 | return (mixture, stems)
75 |
76 | def _transform(self, stems: torch.Tensor) -> torch.Tensor:
77 | raise NotImplementedError
78 |
79 |
80 | class RandomSwapLR(CPUBase):
81 | def __init__(self, p=0.5) -> None:
82 | super().__init__(p=p)
83 |
84 | def _transform(self, stems: torch.Tensor):
85 | return torch.flip(stems, [0])
86 |
87 |
88 | class RandomGain(CPUBase):
89 | def __init__(self, low=0.25, high=1.25, **kwargs) -> None:
90 | super().__init__(**kwargs)
91 | self.low = low
92 | self.high = high
93 |
94 | def _transform(self, stems):
95 | gain = np.random.rand() * (self.high - self.low) + self.low
96 | return stems * gain
97 |
98 |
99 | class RandomFlipPhase(RandomSwapLR):
100 | def __init__(self, p=0.5) -> None:
101 | super().__init__(p=p)
102 |
103 | def _transform(self, stems: torch.Tensor):
104 | return -stems
105 |
106 |
107 | def db2linear(x):
108 | return 10 ** (x / 20)
109 |
110 |
111 | class LimitAug(CPUBase):
112 | def __init__(
113 | self,
114 | target_lufs_mean=-10.887,
115 | target_lufs_std=1.191,
116 | target_loudnorm_lufs=-14.0,
117 | max_release_ms=200.0,
118 | min_release_ms=30.0,
119 | sample_rate=44100,
120 | **kwargs,
121 | ) -> None:
122 | """
123 | Args:
124 | target_lufs_mean (float): mean of target LUFS. default: -10.887 (corresponding to the statistics of musdb-L)
125 | target_lufs_std (float): std of target LUFS. default: 1.191 (corresponding to the statistics of musdb-L)
126 | target_loudnorm_lufs (float): target LUFS after loudnorm. default: -14.0
127 | max_release_ms (float): max release time of limiter. default: 200.0
128 | min_release_ms (float): min release time of limiter. default: 30.0
129 | sample_rate (int): sample rate of audio. default: 44100
130 | """
131 | super().__init__(**kwargs)
132 | self.target_lufs_sampler = torch.distributions.Normal(
133 | target_lufs_mean, target_lufs_std
134 | )
135 | self.target_loudnorm_lufs = target_loudnorm_lufs
136 | self.sample_rate = sample_rate
137 | self.board = Pedalboard([Gain(0), Limiter(threshold_db=0.0, release_ms=100.0)])
138 | self.limiter_release_sampler = torch.distributions.Uniform(
139 | min_release_ms, max_release_ms
140 | )
141 | self.meter = pyln.Meter(sample_rate)
142 |
143 | def __call__(
144 | self, x: Tuple[torch.Tensor, torch.Tensor]
145 | ) -> Tuple[torch.Tensor, torch.Tensor]:
146 | if np.random.rand() > self.p:
147 | return x
148 | mixture, stems = x
149 | mixture_np = mixture.numpy().T
150 | loudness = self.meter.integrated_loudness(mixture_np)
151 | target_lufs = self.target_lufs_sampler.sample().item()
152 | self.board[1].release_ms = self.limiter_release_sampler.sample().item()
153 |
154 | if np.isinf(loudness):
155 | aug_gain = 0.0
156 | else:
157 | aug_gain = target_lufs - loudness
158 | self.board[0].gain_db = aug_gain
159 |
160 | new_mixture_np = self.board(mixture_np, self.sample_rate)
161 | after_loudness = self.meter.integrated_loudness(new_mixture_np)
162 |
163 | if not np.isinf(after_loudness):
164 | target_gain = self.target_loudnorm_lufs - after_loudness
165 | new_mixture_np *= db2linear(target_gain)
166 |
167 | new_mixture = torch.tensor(new_mixture_np.T, dtype=mixture.dtype)
168 | # apply element-wise gain to stems
169 | stems *= new_mixture.abs() / mixture.abs().add(1e-8)
170 | return (new_mixture, stems)
171 |
172 |
173 | def loguniform(low=0, high=1):
174 | return scipy.stats.loguniform.rvs(low, high)
175 |
176 |
177 | def rand(low=0, high=1):
178 | return (torch.rand(1).numpy()[0] * (high - low)) + low
179 |
180 |
181 | def randint(low=0, high=1):
182 | return torch.randint(low, high + 1, (1,)).numpy()[0]
183 |
184 |
185 | def biqaud(
186 | gain_db: float,
187 | cutoff_freq: float,
188 | q_factor: float,
189 | sample_rate: float,
190 | filter_type: str,
191 | ):
192 |     """Use design parameters to generate coefficients for a specific filter type.
193 | Args:
194 | gain_db (float): Shelving filter gain in dB.
195 | cutoff_freq (float): Cutoff frequency in Hz.
196 | q_factor (float): Q factor.
197 | sample_rate (float): Sample rate in Hz.
198 | filter_type (str): Filter type.
199 | One of "low_shelf", "high_shelf", or "peaking"
200 | Returns:
201 | b (np.ndarray): Numerator filter coefficients stored as [b0, b1, b2]
202 | a (np.ndarray): Denominator filter coefficients stored as [a0, a1, a2]
203 | """
204 |
205 | A = 10 ** (gain_db / 40.0)
206 | w0 = 2.0 * np.pi * (cutoff_freq / sample_rate)
207 | alpha = np.sin(w0) / (2.0 * q_factor)
208 |
209 | cos_w0 = np.cos(w0)
210 | sqrt_A = np.sqrt(A)
211 |
212 | if filter_type == "high_shelf":
213 | b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
214 | b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
215 | b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
216 | a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
217 | a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
218 | a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
219 | elif filter_type == "low_shelf":
220 | b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
221 | b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
222 | b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
223 | a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
224 | a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
225 | a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
226 | elif filter_type == "peaking":
227 | b0 = 1 + alpha * A
228 | b1 = -2 * cos_w0
229 | b2 = 1 - alpha * A
230 | a0 = 1 + alpha / A
231 | a1 = -2 * cos_w0
232 | a2 = 1 - alpha / A
233 |     else:
234 |         # unknown filter types would otherwise hit a NameError below
235 |         raise ValueError(f"Invalid filter_type: {filter_type}.")
236 |
237 | b = np.array([b0, b1, b2]) / a0
238 | a = np.array([a0, a1, a2]) / a0
239 |
240 | return b, a
241 |
242 |
243 | def parametric_eq(
244 | x: np.ndarray,
245 | sample_rate: float,
246 | low_shelf_gain_db: float = 0.0,
247 | low_shelf_cutoff_freq: float = 80.0,
248 | low_shelf_q_factor: float = 0.707,
249 | band_gains_db: List[float] = [0.0],
250 | band_cutoff_freqs: List[float] = [300.0],
251 | band_q_factors: List[float] = [0.707],
252 | high_shelf_gain_db: float = 0.0,
253 | high_shelf_cutoff_freq: float = 1000.0,
254 | high_shelf_q_factor: float = 0.707,
255 | dtype=np.float32,
256 | ):
257 | """Multiband parametric EQ.
258 |
259 | Low-shelf -> Band 1 -> ... -> Band N -> High-shelf
260 |
261 |     Args:
262 |         x (np.ndarray): Audio samples with the time axis last; the remaining arguments set the gain (dB), cutoff frequency (Hz), and Q of the low-shelf, peaking, and high-shelf stages.
263 |     """
264 | assert (
265 | len(band_gains_db) == len(band_cutoff_freqs) == len(band_q_factors)
266 | ) # must define for all bands
267 |
268 | # -------- apply low-shelf filter --------
269 | b, a = biqaud(
270 | low_shelf_gain_db,
271 | low_shelf_cutoff_freq,
272 | low_shelf_q_factor,
273 | sample_rate,
274 | "low_shelf",
275 | )
276 | x = scipy.signal.lfilter(b, a, x)
277 |
278 | # -------- apply peaking filters --------
279 | for gain_db, cutoff_freq, q_factor in zip(
280 | band_gains_db, band_cutoff_freqs, band_q_factors
281 | ):
282 | b, a = biqaud(
283 | gain_db,
284 | cutoff_freq,
285 | q_factor,
286 | sample_rate,
287 | "peaking",
288 | )
289 | x = scipy.signal.lfilter(b, a, x)
290 |
291 | # -------- apply high-shelf filter --------
292 | b, a = biqaud(
293 | high_shelf_gain_db,
294 | high_shelf_cutoff_freq,
295 | high_shelf_q_factor,
296 | sample_rate,
297 | "high_shelf",
298 | )
299 | sos5 = np.concatenate((b, a))
300 | x = scipy.signal.lfilter(b, a, x)
301 |
302 | return x.astype(dtype)
303 |
304 |
305 | class RandomParametricEQ(CPUBase):
306 | def __init__(
307 | self,
308 | sample_rate: float = 44100.0,
309 | num_bands: int = 3,
310 | min_gain_db: float = -6.0,
311 | max_gain_db: float = +6.0,
312 | min_cutoff_freq: float = 1000.0,
313 | max_cutoff_freq: float = 10000.0,
314 | min_q_factor: float = 0.1,
315 | max_q_factor: float = 4.0,
316 | **kwargs,
317 | ):
318 | super().__init__(**kwargs)
319 | self.sample_rate = sample_rate
320 | self.num_bands = num_bands
321 | self.min_gain_db = min_gain_db
322 | self.max_gain_db = max_gain_db
323 | self.min_cutoff_freq = min_cutoff_freq
324 | self.max_cutoff_freq = max_cutoff_freq
325 | self.min_q_factor = min_q_factor
326 | self.max_q_factor = max_q_factor
327 |
328 | def _transform(self, x: torch.Tensor):
329 | """
330 | Args:
331 |             x (torch.Tensor): Array of audio samples with shape (chs, seq_len).
332 |                 The filter will be applied along the final dimension, and by default the same
333 |                 filter will be applied to all channels.
334 | """
335 | low_shelf_gain_db = rand(self.min_gain_db, self.max_gain_db)
336 | low_shelf_cutoff_freq = loguniform(20.0, 200.0)
337 | low_shelf_q_factor = rand(self.min_q_factor, self.max_q_factor)
338 |
339 | high_shelf_gain_db = rand(self.min_gain_db, self.max_gain_db)
340 | high_shelf_cutoff_freq = loguniform(8000.0, 16000.0)
341 | high_shelf_q_factor = rand(self.min_q_factor, self.max_q_factor)
342 |
343 | band_gain_dbs = []
344 | band_cutoff_freqs = []
345 | band_q_factors = []
346 | for _ in range(self.num_bands):
347 | band_gain_dbs.append(rand(self.min_gain_db, self.max_gain_db))
348 | band_cutoff_freqs.append(
349 | loguniform(self.min_cutoff_freq, self.max_cutoff_freq)
350 | )
351 | band_q_factors.append(rand(self.min_q_factor, self.max_q_factor))
352 |
353 | y = parametric_eq(
354 | x.numpy(),
355 | self.sample_rate,
356 | low_shelf_gain_db=low_shelf_gain_db,
357 | low_shelf_cutoff_freq=low_shelf_cutoff_freq,
358 | low_shelf_q_factor=low_shelf_q_factor,
359 | band_gains_db=band_gain_dbs,
360 | band_cutoff_freqs=band_cutoff_freqs,
361 | band_q_factors=band_q_factors,
362 | high_shelf_gain_db=high_shelf_gain_db,
363 | high_shelf_cutoff_freq=high_shelf_cutoff_freq,
364 | high_shelf_q_factor=high_shelf_q_factor,
365 | )
366 |
367 | return torch.from_numpy(y)
368 |
369 |
370 | def stereo_widener(x: torch.Tensor, width: torch.Tensor):
371 | sqrt2 = np.sqrt(2)
372 |
373 | left = x[0, ...]
374 | right = x[1, ...]
375 |
376 | mid = (left + right) / sqrt2
377 | side = (left - right) / sqrt2
378 |
379 |     # amplify mid and side signals separately:
380 | mid *= 2 * (1 - width)
381 | side *= 2 * width
382 |
383 | left = (mid + side) / sqrt2
384 | right = (mid - side) / sqrt2
385 |
386 | x = torch.stack((left, right), dim=0)
387 |
388 | return x
389 |
390 |
391 | class RandomStereoWidener(CPUBase):
392 | def __init__(
393 | self,
394 | sample_rate: float = 44100.0,
395 | min_width: float = 0.0,
396 | max_width: float = 1.0,
397 | **kwargs,
398 | ) -> None:
399 | super().__init__(**kwargs)
400 | self.sample_rate = sample_rate
401 | self.min_width = min_width
402 | self.max_width = max_width
403 |
404 | def _transform(self, x: torch.Tensor):
405 | width = rand(self.min_width, self.max_width)
406 | return stereo_widener(x, width)
407 |
408 |
409 | class RandomVolumeAutomation(CPUBase):
410 | def __init__(
411 | self,
412 | sample_rate: float = 44100.0,
413 | min_segment_seconds: float = 3.0,
414 | min_gain_db: float = -6.0,
415 | max_gain_db: float = 6.0,
416 | **kwargs,
417 | ) -> None:
418 | super().__init__(**kwargs)
419 | self.sample_rate = sample_rate
420 | self.min_segment_seconds = min_segment_seconds
421 | self.min_gain_db = min_gain_db
422 | self.max_gain_db = max_gain_db
423 |
424 | def _transform(self, x: torch.Tensor):
425 | gain_db = torch.zeros(x.shape[-1]).type_as(x)
426 |
427 | seconds = x.shape[-1] / self.sample_rate
428 | max_num_segments = max(1, int(seconds // self.min_segment_seconds))
429 |
430 | num_segments = randint(1, max_num_segments)
431 | segment_lengths = (
432 | x.shape[-1]
433 | * np.random.dirichlet(
434 | [rand(0, 10) for _ in range(num_segments)], 1
435 | ) # TODO(cm): this can crash training
436 | ).astype("int")[0]
437 |
438 | segment_lengths = np.maximum(segment_lengths, 1)
439 |
440 | # check the sum is equal to the length of the signal
441 | diff = segment_lengths.sum() - x.shape[-1]
442 | if diff < 0:
443 | segment_lengths[-1] -= diff
444 | elif diff > 0:
445 | for idx in range(num_segments):
446 | if segment_lengths[idx] > diff + 1:
447 | segment_lengths[idx] -= diff
448 | break
449 |
450 | samples_filled = 0
451 | start_gain_db = 0
452 | for idx in range(num_segments):
453 | segment_samples = segment_lengths[idx]
454 | if idx != 0:
455 | start_gain_db = end_gain_db
456 |
457 | # sample random end gain
458 | end_gain_db = rand(self.min_gain_db, self.max_gain_db)
459 | fade = torch.linspace(start_gain_db, end_gain_db, steps=segment_samples)
460 | gain_db[samples_filled : samples_filled + segment_samples] = fade
461 | samples_filled = samples_filled + segment_samples
462 |
463 | # print(gain_db)
464 | x *= 10 ** (gain_db / 20.0)
465 | return x
466 |
467 |
468 | class RandomPedalboardCompressor(CPUBase):
469 | def __init__(
470 | self,
471 | sample_rate: float = 44100.0,
472 | min_threshold_db: float = -42.0,
473 | max_threshold_db: float = -6.0,
474 | min_ratio: float = 1.5,
475 | max_ratio: float = 4.0,
476 | min_attack_ms: float = 1.0,
477 | max_attack_ms: float = 50.0,
478 | min_release_ms: float = 10.0,
479 | max_release_ms: float = 250.0,
480 | **kwargs,
481 | ) -> None:
482 | super().__init__(**kwargs)
483 | self.sample_rate = sample_rate
484 | self.min_threshold_db = min_threshold_db
485 | self.max_threshold_db = max_threshold_db
486 | self.min_ratio = min_ratio
487 | self.max_ratio = max_ratio
488 | self.min_attack_ms = min_attack_ms
489 | self.max_attack_ms = max_attack_ms
490 | self.min_release_ms = min_release_ms
491 | self.max_release_ms = max_release_ms
492 |
493 | def _transform(self, x: torch.Tensor):
494 | board = Pedalboard()
495 | threshold_db = rand(self.min_threshold_db, self.max_threshold_db)
496 | ratio = rand(self.min_ratio, self.max_ratio)
497 | attack_ms = rand(self.min_attack_ms, self.max_attack_ms)
498 | release_ms = rand(self.min_release_ms, self.max_release_ms)
499 |
500 | board.append(
501 | Compressor(
502 | threshold_db=threshold_db,
503 | ratio=ratio,
504 | attack_ms=attack_ms,
505 | release_ms=release_ms,
506 | )
507 | )
508 |
509 | # process audio using the pedalboard
510 | return torch.from_numpy(board(x.numpy(), self.sample_rate))
511 |
512 |
513 | class RandomPedalboardDelay(CPUBase):
514 | def __init__(
515 | self,
516 | sample_rate: float = 44100.0,
517 | min_delay_seconds: float = 0.1,
518 |         max_delay_seconds: float = 1.0,
519 | min_feedback: float = 0.05,
520 | max_feedback: float = 0.6,
521 | min_mix: float = 0.0,
522 | max_mix: float = 0.7,
523 | **kwargs,
524 | ) -> None:
525 | super().__init__(**kwargs)
526 | self.sample_rate = sample_rate
527 | self.min_delay_seconds = min_delay_seconds
528 |         self.max_delay_seconds = max_delay_seconds
529 | self.min_feedback = min_feedback
530 | self.max_feedback = max_feedback
531 | self.min_mix = min_mix
532 | self.max_mix = max_mix
533 |
534 | def _transform(self, x: torch.Tensor):
535 | board = Pedalboard()
536 | delay_seconds = loguniform(self.min_delay_seconds, self.max_delay_seconds)
537 | feedback = rand(self.min_feedback, self.max_feedback)
538 | mix = rand(self.min_mix, self.max_mix)
539 | board.append(Delay(delay_seconds=delay_seconds, feedback=feedback, mix=mix))
540 | return torch.from_numpy(board(x.numpy(), self.sample_rate))
541 |
542 |
543 | class RandomPedalboardChorus(CPUBase):
544 | def __init__(
545 | self,
546 | sample_rate: float = 44100.0,
547 | min_rate_hz: float = 0.25,
548 | max_rate_hz: float = 4.0,
549 | min_depth: float = 0.0,
550 | max_depth: float = 0.6,
551 | min_centre_delay_ms: float = 5.0,
552 | max_centre_delay_ms: float = 10.0,
553 | min_feedback: float = 0.1,
554 | max_feedback: float = 0.6,
555 | min_mix: float = 0.1,
556 | max_mix: float = 0.7,
557 | **kwargs,
558 | ) -> None:
559 | super().__init__(**kwargs)
560 | self.sample_rate = sample_rate
561 | self.min_rate_hz = min_rate_hz
562 | self.max_rate_hz = max_rate_hz
563 | self.min_depth = min_depth
564 | self.max_depth = max_depth
565 | self.min_centre_delay_ms = min_centre_delay_ms
566 | self.max_centre_delay_ms = max_centre_delay_ms
567 | self.min_feedback = min_feedback
568 | self.max_feedback = max_feedback
569 | self.min_mix = min_mix
570 | self.max_mix = max_mix
571 |
572 | def _transform(self, x: torch.Tensor):
573 | board = Pedalboard()
574 | rate_hz = rand(self.min_rate_hz, self.max_rate_hz)
575 | depth = rand(self.min_depth, self.max_depth)
576 | centre_delay_ms = rand(self.min_centre_delay_ms, self.max_centre_delay_ms)
577 | feedback = rand(self.min_feedback, self.max_feedback)
578 | mix = rand(self.min_mix, self.max_mix)
579 | board.append(
580 | Chorus(
581 | rate_hz=rate_hz,
582 | depth=depth,
583 | centre_delay_ms=centre_delay_ms,
584 | feedback=feedback,
585 | mix=mix,
586 | )
587 | )
588 | # process audio using the pedalboard
589 | return torch.from_numpy(board(x.numpy(), self.sample_rate))
590 |
591 |
592 | class RandomPedalboardPhaser(CPUBase):
593 | def __init__(
594 | self,
595 | sample_rate: float = 44100.0,
596 | min_rate_hz: float = 0.25,
597 | max_rate_hz: float = 5.0,
598 | min_depth: float = 0.1,
599 | max_depth: float = 0.6,
600 | min_centre_frequency_hz: float = 200.0,
601 | max_centre_frequency_hz: float = 600.0,
602 | min_feedback: float = 0.1,
603 | max_feedback: float = 0.6,
604 | min_mix: float = 0.1,
605 | max_mix: float = 0.7,
606 | **kwargs,
607 | ) -> None:
608 | super().__init__(**kwargs)
609 | self.sample_rate = sample_rate
610 | self.min_rate_hz = min_rate_hz
611 | self.max_rate_hz = max_rate_hz
612 | self.min_depth = min_depth
613 | self.max_depth = max_depth
614 | self.min_centre_frequency_hz = min_centre_frequency_hz
615 | self.max_centre_frequency_hz = max_centre_frequency_hz
616 | self.min_feedback = min_feedback
617 | self.max_feedback = max_feedback
618 | self.min_mix = min_mix
619 | self.max_mix = max_mix
620 |
621 | def _transform(self, x: torch.Tensor):
622 | board = Pedalboard()
623 | rate_hz = rand(self.min_rate_hz, self.max_rate_hz)
624 | depth = rand(self.min_depth, self.max_depth)
625 | centre_frequency_hz = rand(
626 |             self.min_centre_frequency_hz, self.max_centre_frequency_hz
627 | )
628 | feedback = rand(self.min_feedback, self.max_feedback)
629 | mix = rand(self.min_mix, self.max_mix)
630 | board.append(
631 | Phaser(
632 | rate_hz=rate_hz,
633 | depth=depth,
634 | centre_frequency_hz=centre_frequency_hz,
635 | feedback=feedback,
636 | mix=mix,
637 | )
638 | )
639 | # process audio using the pedalboard
640 | return torch.from_numpy(board(x.numpy(), self.sample_rate))
641 |
642 |
643 | class RandomPedalboardLimiter(CPUBase):
644 | def __init__(
645 | self,
646 | sample_rate: float = 44100.0,
647 | min_threshold_db: float = -32.0,
648 | max_threshold_db: float = -6.0,
649 | min_release_ms: float = 10.0,
650 | max_release_ms: float = 300.0,
651 | **kwargs,
652 | ) -> None:
653 | super().__init__(**kwargs)
654 | self.sample_rate = sample_rate
655 | self.min_threshold_db = min_threshold_db
656 | self.max_threshold_db = max_threshold_db
657 | self.min_release_ms = min_release_ms
658 | self.max_release_ms = max_release_ms
659 |
660 | def _transform(self, x: torch.Tensor):
661 | board = Pedalboard()
662 | threshold_db = rand(self.min_threshold_db, self.max_threshold_db)
663 | release_ms = rand(self.min_release_ms, self.max_release_ms)
664 | board.append(
665 | Limiter(
666 | threshold_db=threshold_db,
667 | release_ms=release_ms,
668 | )
669 | )
670 | return torch.from_numpy(board(x.numpy(), self.sample_rate))
671 |
672 |
673 | class RandomPedalboardDistortion(CPUBase):
674 | def __init__(
675 | self,
676 | sample_rate: float = 44100.0,
677 | min_drive_db: float = -20.0,
678 | max_drive_db: float = 12.0,
679 | **kwargs,
680 | ):
681 | super().__init__(**kwargs)
682 | self.sample_rate = sample_rate
683 | self.min_drive_db = min_drive_db
684 | self.max_drive_db = max_drive_db
685 |
686 | def _transform(self, x: torch.Tensor):
687 | board = Pedalboard()
688 | drive_db = rand(self.min_drive_db, self.max_drive_db)
689 | board.append(Distortion(drive_db=drive_db))
690 | return torch.from_numpy(board(x.numpy(), self.sample_rate))
691 |
692 |
693 | class RandomSoxReverb(CPUBase):
694 | def __init__(
695 | self,
696 | sample_rate: float = 44100.0,
697 | min_reverberance: float = 10.0,
698 | max_reverberance: float = 100.0,
699 | min_high_freq_damping: float = 0.0,
700 | max_high_freq_damping: float = 100.0,
701 | min_wet_dry: float = 0.0,
702 | max_wet_dry: float = 1.0,
703 | min_room_scale: float = 5.0,
704 | max_room_scale: float = 100.0,
705 | min_stereo_depth: float = 20.0,
706 | max_stereo_depth: float = 100.0,
707 | min_pre_delay: float = 0.0,
708 | max_pre_delay: float = 100.0,
709 | **kwargs,
710 | ) -> None:
711 | super().__init__(**kwargs)
712 | self.sample_rate = sample_rate
713 | self.min_reverberance = min_reverberance
714 | self.max_reverberance = max_reverberance
715 | self.min_high_freq_damping = min_high_freq_damping
716 | self.max_high_freq_damping = max_high_freq_damping
717 | self.min_wet_dry = min_wet_dry
718 | self.max_wet_dry = max_wet_dry
719 | self.min_room_scale = min_room_scale
720 | self.max_room_scale = max_room_scale
721 | self.min_stereo_depth = min_stereo_depth
722 | self.max_stereo_depth = max_stereo_depth
723 | self.min_pre_delay = min_pre_delay
724 | self.max_pre_delay = max_pre_delay
725 |
726 | def _transform(self, x: torch.Tensor):
727 | reverberance = rand(self.min_reverberance, self.max_reverberance)
728 | high_freq_damping = rand(self.min_high_freq_damping, self.max_high_freq_damping)
729 | room_scale = rand(self.min_room_scale, self.max_room_scale)
730 | stereo_depth = rand(self.min_stereo_depth, self.max_stereo_depth)
731 | wet_dry = rand(self.min_wet_dry, self.max_wet_dry)
732 | pre_delay = rand(self.min_pre_delay, self.max_pre_delay)
733 |
734 | effects = [
735 | [
736 | "reverb",
737 | f"{reverberance}",
738 | f"{high_freq_damping}",
739 | f"{room_scale}",
740 | f"{stereo_depth}",
741 | f"{pre_delay}",
742 | "--wet-only",
743 | ]
744 | ]
745 | y, _ = torchaudio.sox_effects.apply_effects_tensor(
746 | x, self.sample_rate, effects, channels_first=True
747 | )
748 |
749 | # manual wet/dry mix
750 | return (x * (1 - wet_dry)) + (y * wet_dry)
751 |
752 |
753 | class RandomPedalboardReverb(CPUBase):
754 | def __init__(
755 | self,
756 | sample_rate: float = 44100.0,
757 | min_room_size: float = 0.0,
758 | max_room_size: float = 1.0,
759 | min_damping: float = 0.0,
760 | max_damping: float = 1.0,
761 | min_wet_dry: float = 0.0,
762 | max_wet_dry: float = 0.7,
763 | min_width: float = 0.0,
764 | max_width: float = 1.0,
765 | **kwargs,
766 | ) -> None:
767 | super().__init__(**kwargs)
768 | self.sample_rate = sample_rate
769 | self.min_room_size = min_room_size
770 | self.max_room_size = max_room_size
771 | self.min_damping = min_damping
772 | self.max_damping = max_damping
773 | self.min_wet_dry = min_wet_dry
774 | self.max_wet_dry = max_wet_dry
775 | self.min_width = min_width
776 | self.max_width = max_width
777 |
778 | def _transform(self, x: torch.Tensor):
779 | board = Pedalboard()
780 | room_size = rand(self.min_room_size, self.max_room_size)
781 | damping = rand(self.min_damping, self.max_damping)
782 | wet_dry = rand(self.min_wet_dry, self.max_wet_dry)
783 | width = rand(self.min_width, self.max_width)
784 |
785 | board.append(
786 | Reverb(
787 | room_size=room_size,
788 | damping=damping,
789 | wet_level=wet_dry,
790 | dry_level=(1 - wet_dry),
791 | width=width,
792 | )
793 | )
794 |
795 | return torch.from_numpy(board(x.numpy(), self.sample_rate))
796 |
797 |
798 | class LoudnessNormalize(CPUBase):
799 | def __init__(
800 | self,
801 | sample_rate: float = 44100.0,
802 | target_lufs_db: float = -32.0,
803 | **kwargs,
804 | ) -> None:
805 | super().__init__(**kwargs)
806 | self.sample_rate = sample_rate
807 | self.meter = pyln.Meter(sample_rate)
808 | self.target_lufs_db = target_lufs_db
809 |
810 | def _transform(self, x: torch.Tensor):
811 | x_lufs_db = self.meter.integrated_loudness(x.permute(1, 0).numpy())
812 | delta_lufs_db = torch.tensor([self.target_lufs_db - x_lufs_db]).float()
813 | gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0)
814 | return gain_lin * x
815 |
816 |
817 | class Mono2Stereo(CPUBase):
818 | def __init__(self) -> None:
819 | super().__init__(p=1.0)
820 |
821 | def _transform(self, x: torch.Tensor):
822 | assert x.ndim == 2 and x.shape[0] == 1, "x must be mono"
823 | return torch.cat([x, x], dim=0)
824 |
825 |
826 | class RandomPan(CPUBase):
827 | def __init__(
828 | self,
829 | min_pan: float = -1.0,
830 | max_pan: float = 1.0,
831 | **kwargs,
832 | ) -> None:
833 | super().__init__(**kwargs)
834 | self.min_pan = min_pan
835 | self.max_pan = max_pan
836 |
837 | def _transform(self, x: torch.Tensor):
838 | """Constant power panning"""
839 | assert x.ndim == 2 and x.shape[0] == 2, "x must be stereo"
840 | theta = rand(self.min_pan, self.max_pan) * np.pi / 4
841 | x = x * 0.707 # normalize to prevent clipping
842 | left_x, right_x = x[0], x[1]
843 | cos_theta = np.cos(theta)
844 | sin_theta = np.sin(theta)
845 | left_x = left_x * (cos_theta - sin_theta)
846 | right_x = right_x * (cos_theta + sin_theta)
847 | return torch.stack([left_x, right_x], dim=0)
848 |
--------------------------------------------------------------------------------
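A minimal sketch of how the CPU-side stem transforms in data/augment.py compose: each takes and returns a (mixture, stems) tuple, decides per stem whether to apply its effect, and re-sums the mixture from the (possibly) modified stems. It assumes the module's dependencies (pedalboard, pyloudnorm) are importable; the tensors are placeholders.

import torch
from data.augment import RandomGain, RandomFlipPhase, RandomSwapLR

stems = torch.randn(4, 2, 44100)   # (sources, channels, samples)
mixture = stems.sum(0)
for t in (RandomGain(p=1.0), RandomFlipPhase(p=0.5), RandomSwapLR(p=0.5)):
    mixture, stems = t((mixture, stems))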
/data/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseDataset
2 | from .fast_musdb import FastMUSDB
3 | from .dnr import DnR
4 | from .label_noise_bleed import LabelNoiseBleed
5 |
--------------------------------------------------------------------------------
/data/dataset/base.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | import random
4 | import numpy as np
5 | import torchaudio
6 | from typing import Optional, Callable, List, Tuple
7 | from pathlib import Path
8 |
9 | __all__ = ["BaseDataset"]
10 |
11 |
12 | class BaseDataset(Dataset):
13 | sr: int = 44100
14 |
15 | def __init__(
16 | self,
17 | tracks: List[Path],
18 | track_lengths: List[int],
19 | sources: List[str],
20 | mix_name: str = "mix",
21 | seq_duration: float = 6.0,
22 | samples_per_track: int = 64,
23 | random: bool = False,
24 | random_track_mix: bool = False,
25 | transform: Optional[Callable] = None,
26 | ):
27 | super().__init__()
28 | self.tracks = tracks
29 | self.track_lengths = track_lengths
30 | self.sources = sources
31 | self.mix_name = mix_name
32 | self.seq_duration = seq_duration
33 | self.samples_per_track = samples_per_track
34 | self.segment = int(self.seq_duration * self.sr)
35 | self.random = random
36 | self.random_track_mix = random_track_mix
37 | self.transform = transform
38 |
39 | if self.seq_duration <= 0:
40 | self._size = len(self.tracks)
41 | elif self.random:
42 | self._size = len(self.tracks) * self.samples_per_track
43 | else:
44 | chunks = [l // self.segment for l in self.track_lengths]
45 | cum_chunks = np.cumsum(chunks)
46 | self.cum_chunks = cum_chunks
47 | self._size = cum_chunks[-1]
48 |
49 | def load_tracks(self) -> Tuple[List[Path], List[int]]:
50 | # Implement in child class
51 | # Return list of tracks and list of track lengths
52 | raise NotImplementedError
53 |
54 | def __len__(self):
55 | return self._size
56 |
57 | def _get_random_track_idx(self):
58 | return random.randrange(len(self.tracks))
59 |
60 | def _get_random_start(self, length):
61 | return random.randrange(length - self.segment + 1)
62 |
63 | def _get_track_from_chunk(self, index):
64 | track_idx = np.digitize(index, self.cum_chunks)
65 | if track_idx > 0:
66 | chunk_start = (index - self.cum_chunks[track_idx - 1]) * self.segment
67 | else:
68 | chunk_start = index * self.segment
69 | return self.tracks[track_idx], chunk_start
70 |
71 | def __getitem__(self, index):
72 | stems = []
73 | if self.seq_duration <= 0:
74 | folder_name = self.tracks[index]
75 | x = torchaudio.load(
76 | folder_name / f"{self.mix_name}.wav",
77 | )[0]
78 | for s in self.sources:
79 | source_name = folder_name / (s + ".wav")
80 | audio = torchaudio.load(source_name)[0]
81 | stems.append(audio)
82 | else:
83 | if self.random:
84 | track_idx = index // self.samples_per_track
85 | folder_name, chunk_start = self.tracks[
86 | track_idx
87 | ], self._get_random_start(self.track_lengths[track_idx])
88 | else:
89 | folder_name, chunk_start = self._get_track_from_chunk(index)
90 | for s in self.sources:
91 | if self.random_track_mix and self.random:
92 | track_idx = self._get_random_track_idx()
93 | folder_name, chunk_start = self.tracks[
94 | track_idx
95 | ], self._get_random_start(self.track_lengths[track_idx])
96 | source_name = folder_name / (s + ".wav")
97 | audio = torchaudio.load(
98 | source_name,
99 | num_frames=self.segment,
100 | frame_offset=chunk_start,
101 | )[0]
102 | stems.append(audio)
103 | if self.random_track_mix and self.random:
104 | x = sum(stems)
105 | else:
106 | x = torchaudio.load(
107 | folder_name / f"{self.mix_name}.wav",
108 | num_frames=self.segment,
109 | frame_offset=chunk_start,
110 | )[0]
111 | y = torch.stack(stems)
112 | if self.transform is not None:
113 | x, y = self.transform((x, y))
114 | return x, y
115 |
--------------------------------------------------------------------------------
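When `random` is off, BaseDataset maps a flat sample index to a (track, chunk offset) pair through the cumulative chunk counts built in __init__. A small worked example of that mapping with made-up track lengths, mirroring the np.digitize logic in `_get_track_from_chunk`:

    import numpy as np

    segment = 10                      # stands in for sr * seq_duration
    track_lengths = [35, 12, 28]      # hypothetical lengths -> 3, 1 and 2 full chunks
    cum_chunks = np.cumsum([l // segment for l in track_lengths])   # [3, 4, 6]

    def get_track_from_chunk(index):
        track_idx = np.digitize(index, cum_chunks)
        if track_idx > 0:
            start = (index - cum_chunks[track_idx - 1]) * segment
        else:
            start = index * segment
        return int(track_idx), int(start)

    print([get_track_from_chunk(i) for i in range(cum_chunks[-1])])
    # [(0, 0), (0, 10), (0, 20), (1, 0), (2, 0), (2, 10)]
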
/data/dataset/dnr.py:
--------------------------------------------------------------------------------
1 | import torchaudio
2 | from tqdm import tqdm
3 | from data.dataset import BaseDataset
4 | from pathlib import Path
5 | import os
6 |
7 | from aimless.utils import SDX_SOURCES as SOURCES
8 |
9 | __all__ = ["DnR"]
10 |
11 |
12 | class DnR(BaseDataset):
13 | def __init__(self, root: str, split: str, **kwargs):
14 | tracks, track_lengths = load_tracks(root, split)
15 | super().__init__(
16 | **kwargs,
17 | tracks=tracks,
18 | track_lengths=track_lengths,
19 | sources=SOURCES,
20 | mix_name="mix",
21 | )
22 |
23 |
24 | def load_tracks(root: str, split: str):
25 | root = Path(os.path.expanduser(root))
26 | if split == "train":
27 | split_root = root / "tr"
28 | elif split == "valid":
29 | split_root = root / "cv"
30 | elif split == "test":
31 | split_root = root / "tt"
32 | else:
33 | raise ValueError("Invalid split: {}".format(split))
34 | tracks = sorted([x for x in split_root.iterdir() if x.is_dir()])
35 | for x in tracks:
36 | assert torchaudio.info(str(x / "mix.wav")).sample_rate == DnR.sr
37 |
38 | track_lengths = [
39 | torchaudio.info(str(x / "mix.wav")).num_frames for x in tqdm(tracks)
40 | ]
41 | return tracks, track_lengths
42 |
--------------------------------------------------------------------------------
/data/dataset/fast_musdb.py:
--------------------------------------------------------------------------------
1 | import torchaudio
2 | from tqdm import tqdm
3 | from data.dataset import BaseDataset
4 | import musdb
5 | import os
6 | import yaml
7 | from pathlib import Path
8 | from typing import List
9 |
10 | from aimless.utils import MDX_SOURCES as SOURCES
11 |
12 | __all__ = ["FastMUSDB"]
13 |
14 |
15 | class FastMUSDB(BaseDataset):
16 | def __init__(
17 | self,
18 | root: str,
19 | subsets: List[str] = ["train", "test"],
20 | split: str = None,
21 | **kwargs
22 | ):
23 | tracks, track_lengths = load_tracks(root, subsets, split)
24 | super().__init__(
25 | **kwargs,
26 | tracks=tracks,
27 | track_lengths=track_lengths,
28 | sources=SOURCES,
29 | mix_name="mixture",
30 | )
31 |
32 |
33 | def load_tracks(root, subsets=None, split=None):
34 | root = Path(os.path.expanduser(root))
35 | setup_path = os.path.join(musdb.__path__[0], "configs", "mus.yaml")
36 | with open(setup_path, "r") as f:
37 | setup = yaml.safe_load(f)
38 | if subsets is not None:
39 | if isinstance(subsets, str):
40 | subsets = [subsets]
41 | else:
42 | subsets = ["train", "test"]
43 |
44 | if subsets != ["train"] and split is not None:
45 |         raise RuntimeError("subsets must be ['train'] when split is used")
46 |
47 | print("Gathering files ...")
48 | tracks = []
49 | track_lengths = []
50 | for subset in subsets:
51 | subset_folder = root / subset
52 | for _, folders, _ in tqdm(os.walk(subset_folder)):
53 | # parse pcm tracks and sort by name
54 | for track_name in sorted(folders):
55 | if subset == "train":
56 | if split == "train" and track_name in setup["validation_tracks"]:
57 | continue
58 | elif (
59 | split == "valid"
60 | and track_name not in setup["validation_tracks"]
61 | ):
62 | continue
63 |
64 | track_folder = subset_folder / track_name
65 | # add track to list of tracks
66 | tracks.append(track_folder)
67 | meta = torchaudio.info(os.path.join(track_folder, "mixture.wav"))
68 | assert meta.sample_rate == FastMUSDB.sr
69 |
70 | track_lengths.append(meta.num_frames)
71 |
72 | return tracks, track_lengths
73 |
--------------------------------------------------------------------------------
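FastMUSDB does not define its own validation tracks; it resolves the `valid` split by name against `validation_tracks` in the mus.yaml bundled with the musdb package. A hedged sketch of that lookup on its own (assumes musdb is installed; `/data/musdb18hq` is a placeholder root):

    import os, yaml, musdb

    setup_path = os.path.join(musdb.__path__[0], "configs", "mus.yaml")
    with open(setup_path) as f:
        validation_tracks = set(yaml.safe_load(f)["validation_tracks"])

    # Folders under <root>/train whose names are listed go to "valid"; the rest stay in "train".
    for name in sorted(os.listdir("/data/musdb18hq/train")):   # placeholder path
        print("valid" if name in validation_tracks else "train", name)
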
/data/dataset/label_noise_bleed.py:
--------------------------------------------------------------------------------
1 | import torchaudio
2 | from tqdm import tqdm
3 | from data.dataset import BaseDataset
4 | import os
5 | from pathlib import Path
6 | import csv
7 |
8 | from aimless.utils import MDX_SOURCES as SOURCES
9 |
10 | __all__ = ["LabelNoiseBleed"]
11 |
12 |
13 | # Run scripts/dataset_split_and_mix.py first!
14 | class LabelNoiseBleed(BaseDataset):
15 | def __init__(self, root: str, split: str, clean_csv_path: str = None, **kwargs):
16 | tracks, track_lengths = load_tracks(root, split, clean_csv_path)
17 | super().__init__(
18 | **kwargs,
19 | tracks=tracks,
20 | track_lengths=track_lengths,
21 | sources=SOURCES,
22 | mix_name="mixture",
23 | )
24 |
25 |
26 | def load_tracks(root: str, split: str, clean_csv_path: str):
27 | root = Path(os.path.expanduser(root))
28 | if split == "train":
29 | split_root = root / "train"
30 | elif split == "valid":
31 | split_root = root / "valid"
32 | elif split == "test":
33 | split_root = root / "test"
34 | else:
35 | raise ValueError("Invalid split: {}".format(split))
36 |
37 | tracks = sorted([x for x in split_root.iterdir() if x.is_dir()])
38 | if clean_csv_path is not None:
39 | with open(clean_csv_path, "r") as f:
40 | reader = csv.reader(f)
41 | clean_tracks = [x[0] for x in reader if x[1] == "Y"]
42 | tracks = [x for x in tracks if x.name in clean_tracks]
43 |
44 | for x in tracks:
45 | assert torchaudio.info(str(x / "mixture.wav")).sample_rate == LabelNoiseBleed.sr
46 |
47 | track_lengths = [
48 | torchaudio.info(str(x / "mixture.wav")).num_frames for x in tqdm(tracks)
49 | ]
50 | return tracks, track_lengths
51 |
--------------------------------------------------------------------------------
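The optional `clean_csv_path` filter expects a CSV shaped like data/lightning/label_noise.csv further down in this listing: an ID column plus a Clean column where "Y" marks tracks judged free of label noise. A minimal sketch of the same filter in isolation (both paths are placeholders):

    import csv
    from pathlib import Path

    with open("data/lightning/label_noise.csv") as f:            # placeholder path
        clean_ids = {row[0] for row in csv.reader(f) if len(row) > 1 and row[1] == "Y"}

    root = Path("/data/sdx23_labelnoise/train")                  # placeholder root
    tracks = [p for p in sorted(root.iterdir()) if p.is_dir() and p.name in clean_ids]
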
/data/dataset/speech.py:
--------------------------------------------------------------------------------
1 | import torchaudio
2 | from tqdm import tqdm
3 | from torch.utils.data import Dataset
4 | from pathlib import Path
5 | import random
6 | import torch
7 | import soundfile as sf
8 | import numpy as np
9 | from resampy import resample
10 | from typing import Optional, Callable
11 |
12 | __all__ = ["SpeechNoise"]
13 |
14 |
15 | class SpeechNoise(Dataset):
16 | sr: int = 44100
17 |
18 | def __init__(
19 | self,
20 | speech_root: str,
21 | noise_root: str,
22 | seq_duration: float = 6.0,
23 | samples_per_track: int = 64,
24 | least_overlap_ratio: float = 0.5,
25 | snr_sampler: Optional[torch.distributions.Distribution] = None,
26 | mono: bool = True,
27 | transform: Optional[Callable] = None,
28 | ):
29 | super().__init__()
30 | speech_root = Path(speech_root)
31 | noise_root = Path(noise_root)
32 | speech_files = list(speech_root.glob("**/*.wav")) + list(
33 | speech_root.glob("**/*.flac")
34 | )
35 | noise_files = list(Path(noise_root).glob("**/*.wav")) + list(
36 | Path(noise_root).glob("**/*.flac")
37 | )
38 |
39 | speech_track_frames = []
40 | speech_track_sr = []
41 | for x in tqdm(speech_files):
42 | info = torchaudio.info(x)
43 | speech_track_frames.append(info.num_frames)
44 | speech_track_sr.append(info.sample_rate)
45 |
46 | noise_track_frames = []
47 | noise_track_sr = []
48 | for x in tqdm(noise_files):
49 | info = torchaudio.info(x)
50 | noise_track_frames.append(info.num_frames)
51 | noise_track_sr.append(info.sample_rate)
52 |
53 | self.speech_files = speech_files
54 | self.noise_files = noise_files
55 | self.speech_track_frames = speech_track_frames
56 | self.noise_track_frames = noise_track_frames
57 | self.speech_track_sr = speech_track_sr
58 | self.noise_track_sr = noise_track_sr
59 |
60 | self.seq_duration = seq_duration
61 | self.samples_per_track = samples_per_track
62 | self.segment = int(self.seq_duration * self.sr)
63 | self.least_overlap_ratio = least_overlap_ratio
64 | self.least_overlap_segment = int(least_overlap_ratio * self.segment)
65 | self.transform = transform
66 | self.snr_sampler = snr_sampler
67 | self.mono = mono
68 |
69 | self._size = len(self.noise_files) * self.samples_per_track
70 |
71 | def __len__(self):
72 | return self._size
73 |
74 |     def _get_random_track_idx(self):
75 |         return random.randrange(len(self.noise_files))
76 |
77 | def _get_random_start(self, length):
78 | return random.randrange(length - self.segment + 1)
79 |
80 | def __getitem__(self, index):
81 | track_idx = index // self.samples_per_track
82 | noise_sr = self.noise_track_sr[track_idx]
83 | noise_resample_ratio = self.sr / noise_sr
84 | noise_file = self.noise_files[track_idx]
85 | pos_start = int(
86 | self._get_random_start(
87 | int(self.noise_track_frames[track_idx] * noise_resample_ratio)
88 | )
89 | / noise_resample_ratio
90 | )
91 | frames = int(self.seq_duration * noise_sr)
92 | noise, _ = sf.read(
93 | noise_file, start=pos_start, frames=frames, fill_value=0, always_2d=True
94 | )
95 | if noise_sr != self.sr:
96 | noise = resample(noise, noise_sr, self.sr, axis=0)
97 | if noise.shape[0] < self.segment:
98 | noise = np.pad(
99 | noise, ((0, self.segment - noise.shape[0]), (0, 0)), "constant"
100 | )
101 | else:
102 | noise = noise[: self.segment]
103 |
104 | if self.mono:
105 | noise = noise.mean(axis=1, keepdims=True)
106 | else:
107 | noise = np.broadcast_to(noise, (noise.shape[0], 2))
108 |
109 | # get a random speech file
110 | speech_idx = random.randint(0, len(self.speech_files) - 1)
111 | speech_file = self.speech_files[speech_idx]
112 | speech_sr = self.speech_track_sr[speech_idx]
113 | speech_resample_ratio = self.sr / speech_sr
114 | speech_resampled_length = int(
115 | self.speech_track_frames[speech_idx] * speech_resample_ratio
116 | )
117 |
118 | if speech_resampled_length < self.least_overlap_segment:
119 | speech, _ = sf.read(speech_file, always_2d=True)
120 | if speech_sr != self.sr:
121 | speech = resample(speech, speech_sr, self.sr, axis=0)
122 | speech_resampled_length = speech.shape[0]
123 |
124 | if self.mono:
125 | speech = speech.mean(axis=1, keepdims=True)
126 | else:
127 | speech = np.broadcast_to(speech, (speech.shape[0], 2))
128 |
129 | speech_energy = np.sum(speech**2)
130 | pos = random.randint(0, self.segment - speech_resampled_length)
131 |
132 | padded_speech = np.zeros_like(noise)
133 | padded_speech[pos : pos + speech_resampled_length] = speech
134 | speech = padded_speech
135 | else:
136 | pos = random.randint(
137 | self.least_overlap_segment - speech_resampled_length,
138 | self.segment - self.least_overlap_segment,
139 | )
140 | if pos < 0:
141 | pos_start = int(-pos / speech_resample_ratio)
142 | frames = int(
143 | min(self.segment, (speech_resampled_length + pos))
144 | / speech_resample_ratio
145 | )
146 | else:
147 | pos_start = 0
148 | frames = int(
149 | min(speech_resampled_length, self.segment - pos)
150 | / speech_resample_ratio
151 | )
152 |
153 | speech, _ = sf.read(
154 | speech_file,
155 | start=pos_start,
156 | frames=frames,
157 | fill_value=0,
158 | always_2d=True,
159 | )
160 | if speech_sr != self.sr:
161 | speech = resample(speech, speech_sr, self.sr, axis=0)
162 |
163 | if self.mono:
164 | speech = speech.mean(axis=1, keepdims=True)
165 | else:
166 | speech = np.broadcast_to(speech, (speech.shape[0], 2))
167 |
168 | speech_energy = np.sum(speech**2)
169 |
170 | padded_speech = np.zeros_like(noise)
171 | pos = max(0, pos)
172 | padded_speech[pos : pos + speech.shape[0]] = speech
173 | speech = padded_speech
174 |
175 | speech = torch.from_numpy(speech.T).float()
176 | noise = torch.from_numpy(noise.T).float()
177 |
178 | if self.snr_sampler is not None:
179 | snr = self.snr_sampler.sample()
180 |             # scale the noise amplitude so the mixture hits the sampled SNR (dB)
181 |             noise_energy = noise.pow(2).sum() + 1e-8
182 |             noise = noise * torch.sqrt(speech_energy / noise_energy) * 10 ** (-snr / 20)
183 |
184 | stems = torch.stack([speech, noise], dim=0)
185 | mix = speech + noise
186 |
187 | if self.transform is not None:
188 | mix, stems = self.transform((mix, stems))
189 |
190 | return mix, stems
191 |
--------------------------------------------------------------------------------
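The SNR scaling at the end of __getitem__ uses the standard energy-ratio relationship: for a target SNR of s dB, the noise is multiplied by sqrt(E_speech / E_noise) * 10^(-s/20), so that 10 * log10(E_speech / E_scaled_noise) == s. A standalone check of that relationship with arbitrary signals:

    import torch

    speech = torch.randn(1, 44100)
    noise = torch.randn(1, 44100) * 3.0
    target_snr_db = 6.0

    speech_energy = speech.pow(2).sum()
    noise_energy = noise.pow(2).sum() + 1e-8
    noise = noise * torch.sqrt(speech_energy / noise_energy) * 10 ** (-target_snr_db / 20)

    achieved_db = 10 * torch.log10(speech_energy / noise.pow(2).sum())
    print(float(achieved_db))   # ~6.0
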
/data/lightning/__init__.py:
--------------------------------------------------------------------------------
1 | from .musdb import MUSDB
2 | from .dnr import DnR
3 | from .bleed import Bleed
4 | from .label_noise import LabelNoise
5 | from .speech import SpeechNoise
6 |
--------------------------------------------------------------------------------
/data/lightning/bleed.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import Compose
2 | from torch.utils.data import DataLoader
3 | import pytorch_lightning as pl
4 | from typing import List
5 |
6 |
7 | from data.dataset import LabelNoiseBleed
8 | from data.augment import CPUBase
9 |
10 |
11 | class Bleed(pl.LightningDataModule):
12 | def __init__(
13 | self,
14 | root: str,
15 | seq_duration: float = 6.0,
16 | samples_per_track: int = 64,
17 | random: bool = False,
18 | random_track_mix: bool = False,
19 | transforms: List[CPUBase] = None,
20 | batch_size: int = 16,
21 | ):
22 | super().__init__()
23 | self.save_hyperparameters(
24 | "root",
25 | "seq_duration",
26 | "samples_per_track",
27 | "random",
28 | "random_track_mix",
29 | "batch_size",
30 | )
31 | # manually save transforms since pedalboard is not pickleable
32 | if transforms is None:
33 | self.transforms = None
34 | else:
35 | self.transforms = Compose(transforms)
36 |
37 | def setup(self, stage=None):
38 | if stage == "fit":
39 | self.train_dataset = LabelNoiseBleed(
40 | root=self.hparams.root,
41 | split="train",
42 | seq_duration=self.hparams.seq_duration,
43 | samples_per_track=self.hparams.samples_per_track,
44 | random=self.hparams.random,
45 | random_track_mix=self.hparams.random_track_mix,
46 | transform=self.transforms,
47 | )
48 |
49 | if stage == "validate" or stage == "fit":
50 | self.val_dataset = LabelNoiseBleed(
51 | root=self.hparams.root,
52 | split="valid",
53 | seq_duration=self.hparams.seq_duration,
54 | )
55 |
56 | def train_dataloader(self):
57 | return DataLoader(
58 | self.train_dataset,
59 | batch_size=self.hparams.batch_size,
60 | num_workers=4,
61 | shuffle=True,
62 | drop_last=True,
63 | )
64 |
65 | def val_dataloader(self):
66 | return DataLoader(
67 | self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4
68 | )
69 |
--------------------------------------------------------------------------------
/data/lightning/dnr.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import Compose
2 | from torch.utils.data import DataLoader, ConcatDataset
3 | import pytorch_lightning as pl
4 | from typing import List
5 |
6 | from data.dataset import DnR as DnRDataset
7 | from data.augment import CPUBase
8 |
9 |
10 | class DnR(pl.LightningDataModule):
11 | def __init__(
12 | self,
13 | root: str,
14 | seq_duration: float = 6.0,
15 | samples_per_track: int = 64,
16 | random: bool = False,
17 | include_val: bool = False,
18 | random_track_mix: bool = False,
19 | transforms: List[CPUBase] = None,
20 | batch_size: int = 16,
21 | ):
22 | super().__init__()
23 | self.save_hyperparameters(
24 | "root",
25 | "seq_duration",
26 | "samples_per_track",
27 | "random",
28 | "include_val",
29 | "random_track_mix",
30 | "batch_size",
31 | )
32 | # manually save transforms since pedalboard is not pickleable
33 | if transforms is None:
34 | self.transforms = None
35 | else:
36 | self.transforms = Compose(transforms)
37 |
38 | def setup(self, stage=None):
39 | if stage == "fit":
40 | self.train_dataset = DnRDataset(
41 | root=self.hparams.root,
42 | split="train",
43 | seq_duration=self.hparams.seq_duration,
44 | samples_per_track=self.hparams.samples_per_track,
45 | random=self.hparams.random,
46 | random_track_mix=self.hparams.random_track_mix,
47 | transform=self.transforms,
48 | )
49 |
50 | if self.hparams.include_val:
51 | self.train_dataset = ConcatDataset(
52 | [
53 | self.train_dataset,
54 | DnRDataset(
55 | root=self.hparams.root,
56 | split="valid",
57 | seq_duration=self.hparams.seq_duration,
58 | samples_per_track=self.hparams.samples_per_track,
59 | random=self.hparams.random,
60 | random_track_mix=self.hparams.random_track_mix,
61 | transform=self.transforms,
62 | ),
63 | ]
64 | )
65 |
66 | if stage == "validate" or stage == "fit":
67 | self.val_dataset = DnRDataset(
68 | root=self.hparams.root,
69 | split="valid",
70 | seq_duration=self.hparams.seq_duration,
71 | )
72 |
73 | def train_dataloader(self):
74 | return DataLoader(
75 | self.train_dataset,
76 | batch_size=self.hparams.batch_size,
77 | num_workers=4,
78 | shuffle=True,
79 | drop_last=True,
80 | )
81 |
82 | def val_dataloader(self):
83 | return DataLoader(
84 | self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4
85 | )
86 |
--------------------------------------------------------------------------------
/data/lightning/label_noise.csv:
--------------------------------------------------------------------------------
1 | ID,Clean,Notes
2 | 0a589d65-50a3-4999-8f16-b5b6199bceee,,
3 | 0d528a19-cb0f-4421-b250-444f9343e51c,,
4 | 0e0d57cd-8662-4091-86d4-ed3e35d04ef6,,
5 | 0f5fb60c-51d4-4618-871d-650c9e927b79,,
6 | 1a32e987-cae8-458f-9c26-d9a6abf5348d,,
7 | 1ade22ad-99bc-4954-b2b9-fb9e8fa06d41,Y,
8 | 1afe1b3b-3e2e-48d3-b859-f50e222cbaf4,,
9 | 1ee489e8-55ad-4d4a-8277-24051b85c02d,,
10 | 1f98fe4d-26c7-460f-9f68-33964bc4d8d3,Y,
11 | 1fc37390-1769-452d-9bea-19025be4c467,Y,
12 | 2a3fd99d-c86f-4275-9321-ad7146364503,Y,
13 | 2b4e8304-c92d-4347-b09e-cfb9b3e29bf2,,
14 | 2c020edb-5947-4fa7-afea-ebc592cea683,,
15 | 2dc237cd-1637-46f0-8f58-ca68dc6f6031,Y,
16 | 2e5d996d-43f3-4359-b7c5-afebe9997556,,
17 | 02ee37da-eea3-42b4-83bf-ab7f243afa13,,
18 | 3a047d1a-f56d-4bf4-9910-f4f77206e53d,,
19 | 3a6023c3-4cd8-46c0-a376-d353743648a9,,
20 | 3c3b5fdb-f15e-4ba4-884a-b083ce2426c6,,
21 | 3c557409-3a34-43c2-9159-5421bbad5ecb,,
22 | 3e41f238-7c48-4a42-ba70-5ee39824a844,,
23 | 3e656eec-84d4-4a45-b410-d3817d849f92,,
24 | 3e7985e5-408f-4cf8-92b9-b9f62f738dd3,Y,
25 | 3e389000-8fdc-4b63-b8b8-ab044273790d,,
26 | 3f5233cb-57fa-4772-b389-f295a6f416ae,,
27 | 4a896cde-57c6-4646-b610-1b0b654d0349,,
28 | 4b9f86f4-23e4-458b-839e-8a63b584bea3,,
29 | 4cbd6c36-87a2-4d50-86e3-52d39b98fad3,,
30 | 5a6df5c7-a58a-479e-bdfb-c5946c221933,Y,"20s of harmonica, some drum bleed in vox"
31 | 05b0ab77-2495-438d-8831-e3d81f96c16d,Y,"Bass sometimes doesn't sound like ""bass"""
32 | 05e7af85-9721-4b42-952a-ccd34feb6033,,
33 | 5f04798d-c7be-4b8a-90bd-1fcd9946e875,,
34 | 6b168ae6-9d8a-4dc2-9d27-898e6871bf8b,,
35 | 06bfc6e7-e5ac-4827-bf8e-ffaf1675872f,Y,Crash cymbal in other
36 | 6c70d5e0-5972-444a-86f8-a558dbb92d92,,
37 | 6cd44645-ed19-4ecc-a57c-58d400005b29,,
38 | 6ce087b4-e571-4472-9be2-04b5340311c6,,
39 | 6ceda40a-88bc-4e98-87c3-dd5c91725d41,,
40 | 6e50565e-8179-4913-af54-a2a7f0dcab2f,,
41 | 7a9f3169-7bdf-483c-9634-8e5b097d50df,,
42 | 7ba734f0-547e-4142-a9b4-01a0c25b9d1f,,
43 | 7bfa233c-24ed-4c7a-9096-11e3aa00c55d,,
44 | 7dd515b0-e218-425d-b8bf-a75056237d6a,,
45 | 7e74ce11-1603-440b-8794-b6e665b917e1,,
46 | 7efcb55d-c0db-472f-b765-1739ad7536aa,,
47 | 07fb2df2-91d6-458d-9230-9638b4edac08,,
48 | 8a6c9c1f-4865-404f-a805-1949de36a33c,,
49 | 8b83ba75-5e35-48c4-b42f-5419da2e6301,Y,
50 | 8ce11544-9a6f-4f1e-ac2f-fc10343f15c8,,
51 | 8e2d0c5c-6764-4d74-a740-391d7931ffd7,,
52 | 8f36f17f-c033-4c1b-a793-80dce43d507b,,
53 | 9ac2612b-e25f-4d27-8d43-b957e7e5a74b,Y,
54 | 9c8a5c66-f6d8-4425-8671-6b7aa6a2663b,,
55 | 9ce23a79-20eb-431a-80d2-eda3260ef503,,
56 | 9eb8bc50-cffb-4c19-be0e-d27423e3e102,,
57 | 9f581867-9f63-4ec4-8a43-e62c3c4230a6,,
58 | 9fc0cff7-bb02-496f-9fa1-db67c52b1b4b,,
59 | 13f233aa-a2e5-4683-8533-2f1e344b55b4,,
60 | 014f3712-293b-42af-9f29-0ed1785be792,,
61 | 16fb6c39-3834-4cfc-abf4-abb4a8d4646c,,
62 | 22d265ef-ee2b-4aba-8d60-c3430295cd6d,Y,
63 | 22ea41a3-1766-4a76-8071-380b27f1869a,,
64 | 24f8a652-0168-4702-8725-42f9924e6729,,
65 | 30cfc60a-5a57-4000-a05e-65006c8f6f74,,
66 | 35a19148-49bf-451d-9a0e-5ab8e914c367,,
67 | 36ee7fc6-604c-4a75-b4f0-0a9ceef3b9cb,Y,
68 | 43ca388d-3e62-4df2-b72d-abf407a7aa5a,,
69 | 045dcfd1-e960-4332-80cc-fdacc4a7c6a7,,
70 | 046ab651-a333-46e1-9d27-ab14ee036c42,,
71 | 46bc5393-7753-44ae-913b-bd5fa8f33e98,,
72 | 49ed7cef-8ffb-4833-bc96-6df7b8ff5b43,,
73 | 58efddfb-04b3-4951-858e-e7dbcfccfc21,,
74 | 58fa04aa-426f-4e16-a112-29eb6e2f2d3e,,
75 | 63b68795-0076-476b-a917-dec9e89bf91e,Y,
76 | 72c6f013-11ea-4bf7-93b2-c6ef2c117718,,
77 | 78ef22ce-472f-4f82-8656-16df73b9465f,Y,
78 | 87a5da23-f17b-44da-accf-c04832f81a14,,
79 | 88b545e5-4d06-4d55-a306-1bd3a2915ee5,,
80 | 89c515c9-5e93-4cb4-9806-20432d2d074d,,
81 | 89f2c781-5c67-4508-a2d6-236744b8c197,,
82 | 93dda0e8-dd76-49e4-b08f-54a82387cdf6,,
83 | 94fafb2a-9f4e-4a01-bee9-998008f95f41,,
84 | 97b07e0e-274e-4212-a66b-44210a48724d,Y,Strange sfx in Vox @ 0:30
85 | 125fc63d-9b69-4170-a46a-42c91bc28446,,
86 | 152d4f5d-4093-4fa4-a4a4-8a9b3502d89d,,
87 | 174a115f-3688-45dc-8c39-9d05f21758e1,,
88 | 0177be35-64de-469b-908b-2d9edb49c053,,
89 | 212bb137-fd01-465e-80f3-a890fb0ebcdd,,
90 | 260c431b-72d4-4ac6-bae5-ee49ce5c0fe4,,
91 | 312bec8d-1c61-43e0-924a-1fb87ddc3e41,Y,
92 | 322f4d9d-b0c9-4ab3-9e30-544a25331ffd,,
93 | 334ed0e5-1761-4ee9-9d39-a555eb9b64a0,,
94 | 0358fd1e-244a-4422-9a42-29b5d68f6e4b,Y,No Bass track
95 | 390b4fea-92be-4fe3-9576-86529f80b4ae,,
96 | 664cc931-4e00-441e-9a78-e7a292515cea,,
97 | 704f1de9-1d02-4c2b-af05-107a7700a51d,,
98 | 747d5c98-665b-4470-a696-7a6cf6968ef1,,
99 | 765d5131-afd9-4ad7-8786-8bef5705c1c2,Y,
100 | 825f697d-5fcd-4429-ba33-23163c726ca7,,
101 | 1921a83e-0373-4bf7-8dc9-6cc9401c9309,,
102 | 2973adc9-6bf8-4422-9f7a-6fd0038eb565,Y,
103 | 4380ad97-2620-4419-a011-ddfa29a87f54,,
104 | 4999a0bf-a753-4e0e-85b1-690259dabf96,,
105 | 6031b120-f6e2-4999-96ba-a1e31be68ea8,,
106 | 06114cc2-e34d-4f8f-82d3-cf6981572f2f,,
107 | 6681f493-c996-424a-9bdb-c671912ea9db,,
108 | 7180bffa-dd48-49b9-bc02-f6e3f7f165b0,,
109 | 8042b88a-6179-406b-9ec4-b45a4cdd4a71,,
110 | 8804c154-6294-481a-ad63-bc61162cae2f,,
111 | 28748b6e-6125-42f7-998d-2ad734e39b6c,,
112 | 49478a32-483f-48d2-a594-d272b44bf587,Y,
113 | 53808b95-cfe9-461d-a113-ffadf32817a1,,
114 | 95378cf3-e939-42e0-b486-ebf2ca951664,,
115 | 169628c5-266d-4e11-993b-440fa5fa2167,Y,
116 | 378742ba-5ba8-44c2-9cfd-8a609decca57,,
117 | 553048ce-7afd-4e0e-b4cb-4896620287a1,,
118 | 731893c6-67b9-42f0-aea6-d1f70c2c9870,,
119 | 737356b2-ce9c-448b-877b-e42b3ed94563,,
120 | 763641c7-488f-4959-a554-fdbce9582644,,
121 | 04204031-4f98-44ba-9c47-98c2f2e6b8fc,,
122 | 04798708-6915-4dbc-842e-d394d545d4eb,,
123 | 4857878a-e44b-4143-90e9-b65d0b704306,,
124 | 5640831d-7853-4d06-8166-988e2844b652,,
125 | 7524054e-dc67-47e0-8c26-ea1d4d70d2fb,Y,
126 | 8427760a-b82e-4136-8f12-dfd53cad9bc9,,
127 | 25789239-1075-43b9-bfc9-51dff4a29590,,
128 | a0b9a4e4-51f5-4c98-a090-0317fb891056,,
129 | a0eae9d2-d97f-401a-a495-1e1d1cb84a9c,Y,
130 | a1dcaeb2-f4e6-4818-b490-09f44d624afc,,
131 | a6bccb70-62b5-42aa-bfc9-3b0b886a2b2d,Y,
132 | a56d9450-3a26-485c-8ac3-24b6b54e2c1d,,
133 | a199697c-3cdc-47c7-a9c7-c1b07dd6c9dd,,
134 | aa600069-0a98-45e6-94fa-4000bfe46c25,,
135 | aaf0fc7b-d7f5-412c-b3b9-313a7c483666,,
136 | ad6c3742-e517-42eb-9dec-e74ed05387cc,,
137 | aefc1609-976b-423e-8516-f7d588d64ff7,,
138 | afca84b2-0277-4b1b-8696-5f14543f338c,Y,
139 | b8a79d39-346e-4258-a810-572b3b2c9ab1,,
140 | b8d6f3eb-f2d6-4342-af90-6d09f5257b6b,,
141 | b92cb1ca-baa9-4c74-b6dc-36389671ed76,,
142 | b207da3d-4baf-485a-98e1-657602479b3a,,
143 | b876b54b-6007-4d36-a6f4-efed8829d5fc,,
144 | bacbb01f-b877-4d62-8050-992f1d85543a,,
145 | baea951d-526a-49aa-8329-c8de676341fb,,
146 | bb45cf1a-4c58-4fe3-88ea-97ef27527507,,
147 | bbf40b5a-8ef9-4aec-a6c3-8b8706eb2ba0,,
148 | bc1f2967-f834-43bd-aadc-95afc897cfe7,,
149 | bc964128-da16-4e4c-af95-4d1211e78c70,,
150 | bd25c90e-d307-4cd9-adfb-46f5e323a81b,,
151 | bdcc429e-ed95-40d3-a1af-bad268d66b25,,
152 | bdd109ec-d5dd-4d91-92ad-66b679518026,,
153 | c6d73235-1dd5-4085-a3b3-50a3466c6168,,
154 | c8f42ad5-5a2f-4398-b9d9-207fd4fcc551,Y,
155 | c15ade79-43c0-4271-9f00-a121cefc92e5,Y,
156 | c70471f9-9c4a-41c9-b8f8-20ac38847a8e,Y,
157 | c228818e-eabe-434b-9d60-2fb84a6c5b2a,,
158 | c2330200-ad8e-4848-8c2b-b70612f4b80e,Y,
159 | c8752696-4ae5-4e47-b2ad-622496966fa9,,
160 | ca080447-fe99-4f6d-98c9-c69b68dacba3,,
161 | cc3e4991-6cce-40fe-a917-81a4fbb92ea6,,
162 | cc7f7675-d3c8-4a49-a2d7-a8959b694004,,
163 | cda46831-26d1-4dd8-ac50-6004d27d45dd,,
164 | cff5bcde-6a15-4c5b-b529-e1c528c46335,Y,Some slight vocal bleed in other
165 | d4df499c-e394-4753-b459-e167e6a58bad,,
166 | d4fe2408-c123-4739-93bb-22f558ae99d7,,
167 | d7d28204-a8ac-4c2b-bb3c-c941f4a00b85,Y,
168 | d8f0e410-5761-4d4a-9000-effe11089bbd,,
169 | d028d7c2-45c1-4846-b7df-4964238fd460,,
170 | d45bb3a6-eb80-44b3-b2ef-56cc9d5b4914,,
171 | d072debf-ea5b-4e8a-a447-cc1868cfc5fa,,
172 | d890ff35-300d-49f2-8054-49ab47262987,Y,
173 | d2401d3d-967c-46be-b9a0-3da571105158,,
174 | d624037a-1a76-4dd9-9e60-4ba380748a0b,,
175 | d4262245-3143-4c05-8423-6cbdc6253042,,
176 | dda2c057-6d73-43fa-a130-d7d6562c09ca,Y,
177 | dfb0e076-cb6b-4dcc-9934-c60070ff04d7,,
178 | e2ccbc17-44bf-431a-af2b-4cf2fbd19a72,,
179 | e2e4ce50-cd0d-4144-809d-4cf8c8e4912e,,
180 | e3ab3975-033b-40e2-b538-09396b3d4244,Y,
181 | e4de8632-6f69-4c63-8081-f4c2b77b40df,Y,
182 | e37cdb09-e648-4e9b-bc06-d178a964161c,Y,
183 | e62afdcd-0c96-4bee-80c7-1c17b897a6d7,,
184 | e78fa5de-cfdd-44fa-87c1-10a337b7011f,,
185 | e9336d31-c0df-4c91-be2b-7c4420c9cd34,Y,
186 | e1108928-9776-434c-bc57-c32dfdb7839c,,
187 | ea29ab4d-7f72-4331-b2a4-d3945c754211,,
188 | ea898682-08e7-4818-b516-8c0e10a4c20a,,
189 | ebea0f1d-8e23-469e-8eed-5269a9c684f0,,
190 | ed90a89a-bf22-444d-af3d-d9ac3896ebd2,,
191 | ee082817-dbda-4fbf-b5aa-8dce2320ae35,,
192 | ef1510e0-ba23-4b59-ba53-14181d73f213,,
193 | f0c565c5-fc73-4da1-b979-0fac0167f671,,
194 | f4b735de-14b1-4091-a9ba-c8b30c0740a7,Y,
195 | f9e58f4d-e361-4598-9c9a-d0a83529cc68,Y,
196 | f40ffd10-4e8b-41e6-bd8a-971929ca9138,,
197 | f76e2c13-9a9a-4cac-b6dd-45b5111aac6d,Y,
198 | fa46f72c-696d-45bc-bcc5-2b3305800565,,
199 | faad432d-6ad0-492d-96f1-321eeb9685b5,,
200 | fac94d9a-59da-4f83-9027-3eafe082ad16,Y,
201 | fcd1937f-2b21-4a78-889a-7b7e63e0ebdd,,
202 | fd6e4b4a-f33a-4f3c-aa6e-7c65fd5dc0bc,,
203 | fe3ae408-d35f-4c17-aa33-402238725a9d,Y,
204 | ff486935-7ce2-4e23-8908-0ff5fcc50856,Y,
--------------------------------------------------------------------------------
/data/lightning/label_noise.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import Compose
2 | from torch.utils.data import DataLoader
3 | import pytorch_lightning as pl
4 | from typing import List
5 | import os
6 |
7 | from data.dataset import LabelNoiseBleed
8 | from data.augment import CPUBase
9 |
10 |
11 | class LabelNoise(pl.LightningDataModule):
12 | def __init__(
13 | self,
14 | root: str,
15 | seq_duration: float = 6.0,
16 | samples_per_track: int = 64,
17 | random: bool = False,
18 | random_track_mix: bool = False,
19 | transforms: List[CPUBase] = None,
20 | batch_size: int = 16,
21 | ):
22 | super().__init__()
23 | self.save_hyperparameters(
24 | "root",
25 | "seq_duration",
26 | "samples_per_track",
27 | "random",
28 | "random_track_mix",
29 | "batch_size",
30 | )
31 | # manually save transforms since pedalboard is not pickleable
32 | if transforms is None:
33 | self.transforms = None
34 | else:
35 | self.transforms = Compose(transforms)
36 |
37 | def setup(self, stage=None):
38 | label_noise_path = None
39 | current_dir = os.path.dirname(os.path.realpath(__file__))
40 | if "label_noise.csv" in os.listdir(current_dir):
41 | label_noise_path = os.path.join(current_dir, "label_noise.csv")
42 | assert label_noise_path is not None
43 | if stage == "fit":
44 | self.train_dataset = LabelNoiseBleed(
45 | root=self.hparams.root,
46 | split="train",
47 | seq_duration=self.hparams.seq_duration,
48 | samples_per_track=self.hparams.samples_per_track,
49 | random=self.hparams.random,
50 | random_track_mix=self.hparams.random_track_mix,
51 | transform=self.transforms,
52 | clean_csv_path=label_noise_path,
53 | )
54 |
55 | if stage == "validate" or stage == "fit":
56 | self.val_dataset = LabelNoiseBleed(
57 | root=self.hparams.root,
58 | split="train",
59 | seq_duration=self.hparams.seq_duration,
60 | random=self.hparams.random,
61 | random_track_mix=self.hparams.random_track_mix,
62 | transform=self.transforms,
63 | clean_csv_path=label_noise_path,
64 | )
65 |
66 | def train_dataloader(self):
67 | return DataLoader(
68 | self.train_dataset,
69 | batch_size=self.hparams.batch_size,
70 | num_workers=8,
71 | shuffle=True,
72 | drop_last=True,
73 | )
74 |
75 | def val_dataloader(self):
76 | return DataLoader(
77 | self.val_dataset,
78 | batch_size=self.hparams.batch_size,
79 | num_workers=8,
80 | shuffle=False,
81 | drop_last=True,
82 | )
83 |
--------------------------------------------------------------------------------
/data/lightning/musdb.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import Compose
2 | from torch.utils.data import DataLoader
3 | import pytorch_lightning as pl
4 | from typing import List
5 |
6 |
7 | from data.dataset import FastMUSDB
8 | from data.augment import CPUBase
9 |
10 |
11 | class MUSDB(pl.LightningDataModule):
12 | def __init__(
13 | self,
14 | root: str,
15 | seq_duration: float = 6.0,
16 | samples_per_track: int = 64,
17 | random: bool = False,
18 | random_track_mix: bool = False,
19 | transforms: List[CPUBase] = None,
20 | batch_size: int = 16,
21 | ):
22 | super().__init__()
23 | self.save_hyperparameters(
24 | "root",
25 | "seq_duration",
26 | "samples_per_track",
27 | "random",
28 | "random_track_mix",
29 | "batch_size",
30 | )
31 | # manually save transforms since pedalboard is not pickleable
32 | if transforms is None:
33 | self.transforms = None
34 | else:
35 | self.transforms = Compose(transforms)
36 |
37 | def setup(self, stage=None):
38 | if stage == "fit":
39 | self.train_dataset = FastMUSDB(
40 | root=self.hparams.root,
41 | subsets=["train"],
42 | seq_duration=self.hparams.seq_duration,
43 | samples_per_track=self.hparams.samples_per_track,
44 | random=self.hparams.random,
45 | random_track_mix=self.hparams.random_track_mix,
46 | transform=self.transforms,
47 | )
48 |
49 | if stage == "validate" or stage == "fit":
50 | self.val_dataset = FastMUSDB(
51 | root=self.hparams.root,
52 | subsets=["test"],
53 | seq_duration=self.hparams.seq_duration,
54 | )
55 |
56 | def train_dataloader(self):
57 | return DataLoader(
58 | self.train_dataset,
59 | batch_size=self.hparams.batch_size,
60 | num_workers=4,
61 | shuffle=True,
62 | drop_last=True,
63 | )
64 |
65 | def val_dataloader(self):
66 | return DataLoader(
67 | self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4
68 | )
69 |
--------------------------------------------------------------------------------
/data/lightning/speech.py:
--------------------------------------------------------------------------------
1 | from torchvision.transforms import Compose
2 | from torch.utils.data import DataLoader
3 | import pytorch_lightning as pl
4 | from typing import List, Optional
5 | import torch
6 |
7 | from data.dataset.speech import SpeechNoise as SpeechNoiseDataset
8 | from data.augment import CPUBase
9 |
10 |
11 | class SpeechNoise(pl.LightningDataModule):
12 | def __init__(
13 | self,
14 | speech_root: str,
15 | noise_root: str,
16 | seq_duration: float = 6.0,
17 | samples_per_track: int = 64,
18 | least_overlap_ratio: float = 0.5,
19 | snr_sampler: Optional[torch.distributions.Distribution] = None,
20 | mono: bool = True,
21 | transforms: List[CPUBase] = None,
22 | batch_size: int = 16,
23 | ):
24 | super().__init__()
25 | self.save_hyperparameters(
26 | "speech_root",
27 | "noise_root",
28 | "seq_duration",
29 | "samples_per_track",
30 | "least_overlap_ratio",
31 | "snr_sampler",
32 | "mono",
33 | "batch_size",
34 | )
35 | # manually save transforms since pedalboard is not pickleable
36 | if transforms is None:
37 | self.transforms = None
38 | else:
39 | self.transforms = Compose(transforms)
40 |
41 | def setup(self, stage=None):
42 | if stage == "fit":
43 | self.train_dataset = SpeechNoiseDataset(
44 | speech_root=self.hparams.speech_root,
45 | noise_root=self.hparams.noise_root,
46 | seq_duration=self.hparams.seq_duration,
47 | samples_per_track=self.hparams.samples_per_track,
48 | least_overlap_ratio=self.hparams.least_overlap_ratio,
49 | snr_sampler=self.hparams.snr_sampler,
50 | mono=self.hparams.mono,
51 | transform=self.transforms,
52 | )
53 |
54 | def train_dataloader(self):
55 | return DataLoader(
56 | self.train_dataset,
57 | batch_size=self.hparams.batch_size,
58 | num_workers=4,
59 | shuffle=True,
60 | drop_last=True,
61 | )
62 |
--------------------------------------------------------------------------------
/docs/aimless-logo-crop.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: demixing
2 | channels:
3 | - defaults
4 | - pytorch
5 | - nvidia
6 |
7 | dependencies:
8 | - pip
9 | - numpy
10 | - pytorch
11 | - torchaudio
12 | - torchvision
13 | - cudatoolkit=11.7
14 | - pip:
15 | - pytorch-lightning[extra]==1.9.4
16 | - torch-optimizer
17 | - musdb
18 | - git+https://github.com/yoyololicon/norbert
19 | - pedalboard
20 | - audio-data-pytorch
21 | - pyloudnorm
22 | - torch_fftconv
23 | - streamlit
24 | - librosa
25 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_lightning.cli import LightningCLI
3 | from pytorch_lightning.strategies import DDPStrategy
4 |
5 |
6 | def cli_main():
7 | torch.set_float32_matmul_precision("medium")
8 |
9 | cli = LightningCLI(
10 | trainer_defaults={
11 | "accelerator": "gpu",
12 | "strategy": DDPStrategy(find_unused_parameters=False),
13 | "log_every_n_steps": 1,
14 | }
15 | )
16 |
17 |
18 | if __name__ == "__main__":
19 | cli_main()
20 |
--------------------------------------------------------------------------------
/scripts/audio_utils.py:
--------------------------------------------------------------------------------
1 | import soundfile as sf
2 | import numpy as np
3 | import resampy
4 | from math import ceil
5 | import librosa
6 |
7 | from scipy.signal import stft
8 | import pyloudnorm as pyln
9 | import math
10 | import warnings
11 |
12 | # We silence this warning as we peak normalize the samples before bouncing them
13 | warnings.filterwarnings("ignore", message="Possible clipped samples in output.")
14 |
15 |
16 | def trim_relative_silence_from_audio(audio, sr, frame_duration=0.04, hop_duration=0.01):
17 | assert 0 < hop_duration <= frame_duration
18 | frame_length = int(frame_duration * sr)
19 | hop_length = int(hop_duration * sr)
20 |
21 | _, _, S = stft(
22 | audio,
23 | nfft=frame_length,
24 | noverlap=frame_length - hop_length,
25 | padded=True,
26 | nperseg=frame_length,
27 | boundary="constant",
28 | )
29 | rms = librosa.feature.rms(
30 | S=S, frame_length=frame_length, hop_length=hop_length, pad_mode="constant"
31 | ).flatten()
32 | threshold = 0.01 * rms.max()
33 | active_flag = rms >= threshold
34 | active_idxs = active_flag.nonzero()[0]
35 | start_idx = max(int(max(active_idxs[0] - 1, 0) * hop_duration * sr), 0)
36 | end_idx = min(
37 | int(ceil(min(active_idxs[-1] + 1, rms.shape[0]) * hop_duration * sr)),
38 | audio.shape[0],
39 | )
40 |
41 | return start_idx, end_idx
42 |
43 |
44 | def lufs_norm(data, sr, norm=-6):
45 | loudness = get_lufs(data, sr)
46 |
47 | assert not math.isinf(loudness)
48 |
49 | norm_data = pyln.normalize.loudness(data, loudness, norm)
50 | n, d = np.sum(np.array(norm_data)), np.sum(np.array(data))
51 | gain = n / d if d else 0.0
52 |
53 | return norm_data, gain
54 |
55 |
56 | def get_lufs(data, sr):
57 | block_size = 0.4 if len(data) / sr >= 0.4 else len(data) / sr
58 | # measure the loudness first
59 | meter = pyln.Meter(rate=sr, block_size=block_size) # create BS.1770 meter
60 | loudness = meter.integrated_loudness(data)
61 |
62 | return loudness
63 |
64 |
65 | def peak_norm(data, mx):
66 | eps = 1e-10
67 | max_sample = np.max(np.abs(data))
68 | scale_factor = mx / (max_sample + eps)
69 |
70 | return data * scale_factor
71 |
72 |
73 | def gain_to_db(g):
74 | return 20 * np.log10(g)
75 |
76 |
77 | def db_to_gain(db):
78 | return 10 ** (db / 20.0)
79 |
80 |
81 | def gain_from_combined_db_levels(dbs):
82 | return db_to_gain(sum(dbs))
83 |
84 |
85 | def validate_audio(d):
86 | assert np.isnan(d).any() == False, "Nan value found in mixture"
87 | assert np.isneginf(d).any() == False, "Neg. Inf value found in mixture"
88 | assert np.isposinf(d).any() == False, "Pos. Inf value found in mixture"
89 | assert np.isinf(d).any() == False, "Inf value found in mixture"
90 |
--------------------------------------------------------------------------------
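gain_to_db and db_to_gain above are the usual 20 * log10 amplitude conversions, and gain_from_combined_db_levels relies on dB levels adding where linear gains multiply. A quick numerical check of that identity (plain numpy, no project imports):

    import numpy as np

    def db_to_gain(db):
        return 10 ** (db / 20.0)

    def gain_to_db(g):
        return 20 * np.log10(g)

    # Two cascaded -6 dB stages equal a single -12 dB stage:
    print(db_to_gain(-6) * db_to_gain(-6))   # ~0.2512
    print(db_to_gain(-12))                   # ~0.2512
    print(gain_to_db(db_to_gain(-12)))       # -12.0
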
/scripts/convert_to_dnr.py:
--------------------------------------------------------------------------------
1 | # Similar to dnr-utils (https://github.com/darius522/dnr-utils)
2 |
3 | from tqdm import tqdm
4 | import numpy as np
5 | from scipy import stats
6 | import audio_utils
7 | import soundfile as sf
8 | import os
9 | import pandas as pd
10 | import random
11 | import matplotlib.pyplot as plt
12 | from glob import glob
13 | import collections
14 | from itertools import repeat
15 | import argparse
16 | import librosa
17 | import math
18 |
19 |
20 | def apply_fadeout(audio, sr, duration=3.0):
21 | # convert to audio indices (samples)
22 | length = int(duration * sr)
23 | end = audio.shape[0]
24 | start = end - length
25 |
26 | # compute fade out curve
27 | # linear fade
28 | fade_curve = np.linspace(1.0, 0.0, length)
29 |
30 | # apply the curve
31 | audio[start:end] = audio[start:end] * fade_curve
32 | return audio
33 |
34 |
35 | def create_dir(dir, subdirs=None):
36 | os.makedirs(dir, exist_ok=True)
37 |
38 | if subdirs:
39 | for s in subdirs:
40 | os.makedirs(os.path.join(dir, s), exist_ok=True)
41 |
42 | return dir
43 |
44 |
45 | def make_unique_filename(files, dir=None):
46 | """
47 | Given a list of files, generate a new - unique - filename
48 | """
49 | while True:
50 | f = str(np.random.randint(low=100, high=100000))
51 | if not f in files:
52 | break
53 | # Optionally create dir
54 | if dir:
55 | create_dir(os.path.join(dir, f))
56 | return f, os.path.join(dir, f)
57 |
58 |
59 | def set_seed(seed):
60 | random.seed(seed)
61 | np.random.seed(seed)
62 |
63 |
64 | def aug_gain(config, audio, low=0.0, high=1.0):
65 | gain = np.random.uniform(low=low, high=high)
66 | return audio * gain
67 |
68 |
69 | def aug_reverb(config, audio, ir):
70 | pass
71 |
72 |
73 | def gen_poisson(mu):
74 | """Adapted from https://github.com/rmalouf/learning/blob/master/zt.py"""
75 | r = np.random.uniform(low=stats.poisson.pmf(0, mu))
76 | return int(stats.poisson.ppf(r, mu))
77 |
78 |
79 | def gen_norm(mu, sig):
80 | return np.random.normal(loc=mu, scale=sig)
81 |
82 |
83 | def gen_skewnorm(skew, mu, sig):
84 | # negative a means skewed left while positive means skewed right; a=0 -> normal
85 | return stats.skewnorm(a=skew, loc=mu, scale=sig).rvs()
86 |
87 |
88 | def get_some_noise(shape):
89 | return np.random.randn(shape).astype(np.float32)
90 |
91 |
92 | class MixtureObj:
93 | def __init__(
94 | self,
95 | seq_dur=60.0,
96 | sr=44100,
97 | files=None,
98 | mix_lufs=None,
99 | partition="train",
100 | peak_norm_db=None,
101 | annot_path=None,
102 | class_dirs=None,
103 | ):
104 | self.seq_dur = seq_dur
105 | self.sr = sr
106 | self.input_dirs = class_dirs
107 | self.annot_path = annot_path
108 | self.mix = np.zeros(int(self.seq_dur * self.sr))
109 | self.mix_lufs = mix_lufs
110 | self.partition = partition
111 | self.allocated_windows = {
112 | "music": {"times": [], "samples": [], "levels": [], "files": []},
113 | "sfx": {"times": [], "samples": [], "levels": [], "files": []},
114 | "speech": {"times": [], "samples": [], "levels": [], "files": []},
115 | }
116 |
117 | # Set some submix-specific values
118 | self.submixes = {"music": [], "sfx": [], "speech": []}
119 |
120 | self.mix_max_peak = 0.0
121 | self.peak_norm = audio_utils.db_to_gain(peak_norm_db)
122 |
123 | self.files = files
124 |
125 | self.num_seg_music = gen_poisson(7.0)
126 | self.num_seg_sfx = gen_poisson(12.0)
127 | self.num_seg_speech = gen_poisson(8.0)
128 |
129 | def check_overlap_for_range(self, r1, c):
130 | ranges = self.allocated_windows[c]["times"]
131 | for r2 in ranges:
132 | if not (np.sum(r1) < r2[0] or r1[0] > np.sum(r2)):
133 | return True
134 | return False
135 |
136 | def _check_for_files_exhaustion(self):
137 | if len(self.files["speech"]) < self.num_seg_speech:
138 | return True
139 | return False
140 |
141 | def _load_wav(self, file):
142 | d, sr = librosa.load(file, sr=self.sr)
143 | return d
144 |
145 | def _register_event(
146 | self, mix_pos, length, cl, file, clip_start, clip_end, clip_gain
147 | ):
148 | self.allocated_windows[cl]["files"].append(file)
149 | self.allocated_windows[cl]["times"].append([mix_pos, length])
150 | self.allocated_windows[cl]["samples"].append([clip_start, clip_end])
151 | self.allocated_windows[cl]["levels"].append(clip_gain)
152 |
153 | def _update_event_levels(self, gain):
154 | for c in self.allocated_windows.keys():
155 | self.allocated_windows[c]["levels"] = [
156 | gain * x for x in self.allocated_windows[c]["levels"]
157 | ]
158 |
159 | def _get_time_range_for_submix(self, data, c, idx, num_events):
160 | """
161 | Compute time window for sound in submix
162 | Args:
163 | data: array sound [Nspl, Nch]
164 | c: the sound class [music, sfx, speech]
165 | idx: the sound index
166 | """
167 | file_len = len(data) / self.sr
168 |
169 |         # We don't want two events of the same class to overlap
170 | min_start = (
171 | 0.0
172 | if not len(self.allocated_windows[c]["times"])
173 | else np.ceil(np.sum(self.allocated_windows[c]["times"][-1]))
174 | )
175 |
176 |         # Check that enough time remains: at least 30% of the current file's
177 |         # length or, for speech, room for the entire file.
181 | if (self.seq_dur - min_start <= file_len * 0.3) or (
182 | c == "speech" and (self.seq_dur - min_start <= file_len)
183 | ):
184 | return None, None
185 |
186 | # Pick start based on previous event pos.
187 | end_lim = self.seq_dur - file_len if c == "speech" else self.seq_dur - 2.0
188 | start = min(gen_skewnorm(skew=5, mu=int(min_start + 1), sig=2.0), end_lim)
189 | max_len = (
190 | file_len if self.seq_dur - start >= file_len else (self.seq_dur - start)
191 | )
192 |
193 | # If speech, take the whole thing
194 | if c == "speech":
195 | length = max_len
196 | else:
197 | length = max(
198 | min(gen_norm(mu=max_len / 2.0, sig=max_len / 10.0), max_len), 0.1
199 | )
200 |
201 | return start, length
202 |
203 | def _create_from_annots(self, annot_path):
204 | """
205 | Create the mixture strictly from a given annotation file
206 | """
207 | columns = [
208 | "file",
209 | "class",
210 | "mix_start_time",
211 | "mix_end_time",
212 | "clip_start_time",
213 | "clip_end_time",
214 | "clip_gain",
215 | "annotation",
216 | ]
217 | annots = pd.read_csv(annot_path, names=columns, skiprows=1)
218 | self.submixes = {k: np.zeros_like(self.mix) for k in self.submixes.keys()}
219 |
220 | for index, row in annots.iterrows():
221 | f_class = row["class"]
222 | abs_path = os.path.join(
223 | self.input_dirs[f_class], "/".join(row["file"].split("/")[1:])
224 | )
225 | data, _ = sf.read(abs_path)
226 |
227 | mix_start = int(row["mix_start_time"] * self.sr)
228 | mix_end = int(row["mix_end_time"] * self.sr)
229 | clip_start = int(row["clip_start_time"] * self.sr)
230 | clip_end = int(row["clip_end_time"] * self.sr)
231 |
232 | clip_gain = float(row["clip_gain"])
233 |
234 | length = min(len(data), clip_end - clip_start, mix_end - mix_start)
235 |
236 | clip = data[clip_start : clip_start + length] * clip_gain
237 | self.submixes[f_class][mix_start : mix_start + length] += clip
238 |
239 | # Finally collapse submixes
240 | self.mix = np.sum(
241 | np.stack([self.submixes[c] for c in self.submixes.keys()], -1), -1
242 | )
243 |
244 | def _set_submix(self, c, num_seg):
245 | submix = np.zeros(int(self.seq_dur * self.sr))
246 | class_lufs = np.random.uniform(
247 | self.mix_lufs[c] - self.mix_lufs["ranges"]["class"],
248 | self.mix_lufs[c] + self.mix_lufs["ranges"]["class"],
249 | )
250 |
251 | for i in range(num_seg):
252 | # If eval and speech, mix without replacement until exhaustion
253 | if self.partition == "eval" and c == "speech":
254 | f = self.files[c].pop(random.randrange(len(self.files[c])))
255 | else:
256 | f = np.random.choice(self.files[c])
257 |
258 | d = self._load_wav(f)
259 | if c == "speech":
260 | # If speech is too long, trim to 70% and apply 3 second fade-out
261 | max_speech_len = math.floor(0.7 * self.seq_dur * self.sr - 1)
262 | if len(d) >= max_speech_len:
263 | d = apply_fadeout(d[:max_speech_len], self.sr, 3)
264 | s, l = self._get_time_range_for_submix(d, c, i, num_seg)
265 | clip_lufs = np.random.uniform(
266 | class_lufs - self.mix_lufs["ranges"]["clip"],
267 | class_lufs + self.mix_lufs["ranges"]["clip"],
268 | )
269 | if s != None and l != None:
270 | try:
271 | s_clip, l_clip = int(s * self.sr), int(l * self.sr)
272 | r_spl = np.random.randint(0, len(d) - l_clip) if c == "music" else 0
273 | data = np.copy(d[r_spl : r_spl + l_clip])
274 | data_norm, gain = audio_utils.lufs_norm(
275 | data=data, sr=self.sr, norm=clip_lufs
276 | )
277 | submix[s_clip : s_clip + l_clip] = data_norm
278 | self._register_event(
279 | mix_pos=s,
280 | length=l,
281 | cl=c,
282 | file=f,
283 | clip_start=r_spl / self.sr,
284 | clip_end=(r_spl / self.sr) + l,
285 | clip_gain=gain,
286 | )
287 | except Exception as e:
288 | print(e)
289 | print("could not register event. Skipping...")
290 | print(f, c, s, l, s_clip, l_clip)
291 | elif c == "speech":
292 | self.files[c].append(f)
293 |
294 | self.mix_max_peak = max(self.mix_max_peak, np.max(np.abs(submix)))
295 | self.submixes[c] = submix
296 |
297 | def _create_final_mix(self):
298 |         # LUFS randomization range, reused below for the master and optional noise levels
299 | rand_range = self.mix_lufs["ranges"]["class"]
300 |
301 | # Set the peak norm gain
302 | peak_norm_gain = (
303 | 1.0
304 | if self.mix_max_peak <= self.peak_norm
305 | else self.peak_norm / self.mix_max_peak
306 | )
307 | # Compute master gain
308 | master_lufs = np.random.uniform(
309 | self.mix_lufs["master"] - rand_range, self.mix_lufs["master"] + rand_range
310 | )
311 |
312 |         # Peak-normalize each submix, then add it to the main mix
313 | for c in self.submixes.keys():
314 | # Peak norm
315 | self.submixes[c] *= peak_norm_gain
316 | # Add submix to main mix
317 | self.mix += self.submixes[c]
318 |
319 | self.mix, master_gain = audio_utils.lufs_norm(
320 | data=self.mix, sr=self.sr, norm=master_lufs
321 | )
322 |
323 |         # Apply the master gain to the submixes so they stay consistent with the normalized mix
324 | for c in self.submixes.keys():
325 | self.submixes[c] *= master_gain
326 |
327 | # Optionally add some noise to the mix (i.e. avoid digital silences)
328 | if self.mix_lufs["noise"] != None:
329 | noise_lufs = np.random.uniform(
330 | self.mix_lufs["noise"] - rand_range, self.mix_lufs["noise"] + rand_range
331 | )
332 | noise = get_some_noise(shape=int(self.seq_dur * self.sr))
333 | noise, _ = audio_utils.lufs_norm(data=noise, sr=self.sr, norm=noise_lufs)
334 | self.mix += noise
335 |
336 | # After peak norm / master norm, we need to update the registered events' props
337 | self._update_event_levels(master_gain * peak_norm_gain)
338 |
339 | def __call__(self):
340 | # Check if we build from the annotation or not
341 | if self.annot_path:
342 | self._create_from_annots(annot_path=self.annot_path)
343 | return self.submixes, self.mix, None, None
344 | # If not, Create submixes from scratch
345 | else:
346 | if not self._check_for_files_exhaustion():
347 | self._set_submix(c="music", num_seg=self.num_seg_music)
348 | self._set_submix(c="speech", num_seg=self.num_seg_speech)
349 | self._set_submix(c="sfx", num_seg=self.num_seg_sfx)
350 | self._create_final_mix()
351 |
352 | return self.submixes, self.mix, self.allocated_windows, self.files
353 | else:
354 | return None, None, None, self.files
355 |
356 |
357 | class DatasetCompiler:
358 | def __init__(
359 | self,
360 | speech_set: str = None,
361 | fx_set: str = None,
362 | music_set: str = None,
363 | num_output_files: int = 1000,
364 | output_dir: str = ".",
365 | sample_rate: int = 44100,
366 | seq_dur: float = 60.0,
367 | partition: str = "tr",
368 | mix_lufs: dict = None,
369 | peak_norm_db: float = -0.5,
370 | ):
371 | self.output_files = []
372 | self.speech_set = speech_set
373 | self.fx_set = fx_set
374 | self.music_set = music_set
375 | self.partition = partition
376 | self.sr = sample_rate
377 | self.num_files = num_output_files
378 | self.output_dir = output_dir
379 |         self.wavfiles = {
380 |             "speech": self._get_filepaths(speech_set, "speech"),
381 |             "sfx": self._get_filepaths(fx_set, "sfx"),
382 |             "music": self._get_filepaths(music_set, "music"),
383 |         }
384 | self.mix_lufs = mix_lufs
385 | self.peak_norm_db = peak_norm_db
386 | self.seq_dur = seq_dur
387 |
388 | self.stats = {"overlap": {}, "times": {}}
389 |
390 | def _get_mix_stats(self, data, res):
391 | """
392 | Compute amount of overlap between given classes
393 | Args:
394 | data: times
395 | res: resolution of the timesteps
396 | """
397 | # Average sample length
398 | for c in data.keys():
399 | for t in data[c]["times"]:
400 | self.stats["times"].setdefault(c, []).append(t[-1])
401 |
402 | overlap = 0.0
403 | classes = list(data.keys())
404 | timesteps = np.arange(0, self.seq_dur, res)
405 | labels = []
406 |
407 | # Walk through steps
408 | for i, j in enumerate(timesteps):
409 | # Check if each class is present at this step
410 | detect = []
411 | for c in classes:
412 | for t in data[c]["times"]:
413 | s, e = t[0], sum(t)
414 | if j >= s and j <= e:
415 | detect.append(c)
416 | break
417 | labels.append("-".join(sorted(detect)))
418 |
419 | counter = collections.Counter(labels)
420 |
421 | values = list(counter.values())
422 | keys = list(counter.keys())
423 | values = [x / sum(values) for x in values]
424 | for k, v in zip(keys, values):
425 | self.stats["overlap"].setdefault(k, []).append(v)
426 |
427 | return overlap
428 |
429 | def _plot_stats(self):
430 | plt.rcParams.update({"font.size": 25})
431 |
432 | # Build the plot
433 | fig, ax = plt.subplots(2, 1, figsize=(25, 25))
434 |
435 | for i, stats in enumerate(self.stats.keys()):
436 | data = self.stats[stats]
437 | classes = list(data.keys())
438 |
439 | x_pos = np.arange(len(classes))
440 | means = [np.mean(data[k]) for k in classes]
441 | error = [np.std(data[k]) for k in classes]
442 | try:
443 | classes[classes.index("")] = "silence"
444 | except:
445 | pass
446 |
447 | ax[i].bar(
448 | x_pos,
449 | means,
450 | yerr=error,
451 | align="center",
452 | alpha=0.5,
453 | ecolor="black",
454 | capsize=10,
455 | )
456 | ax[i].set_xticks(x_pos)
457 | ax[i].set_xticklabels(
458 | classes, rotation=45, va="center", position=(0, -0.28)
459 | )
460 | ax[i].yaxis.grid(True)
461 |
462 | ax[0].set_ylabel("Presence Amount")
463 | ax[0].set_title(
464 | self.partition
465 | + " - Classes Overlap ("
466 | + str(len(self.output_files))
467 | + " Mixtures)"
468 | )
469 | ax[1].set_ylabel("Length (s)")
470 | ax[1].set_title(
471 | self.partition
472 | + " - Average Sample Length per Classes ("
473 | + str(len(self.output_files))
474 | + " Mixtures)"
475 | )
476 |
477 | # Save the figure and show
478 | plt.tight_layout()
479 | plt.savefig(os.path.join(self.output_dir, self.partition + "_set_stats.png"))
480 |
481 | def _get_filepaths(self, dir, c=""):
482 | files = glob(dir + "/**/*.wav", recursive=True)
483 | files += glob(dir + "/**/*.flac", recursive=True)
484 | # If eval and speech, repeat list twice
485 | if self.partition == "eval" and c == "speech":
486 | return [x for item in files for x in repeat(item, 2)]
487 | return files
488 |
489 | def _get_annot_paths(self, dir):
490 | return glob(dir + "/**/annots.csv", recursive=True)
491 |
492 | def _write_wav(self, f_dir, sr, data, source):
493 | audio_utils.validate_audio(data)
494 | sf.write(os.path.join(f_dir, source + ".wav"), data, sr, subtype="FLOAT")
495 |
496 | def _write_ground_truth(self, f_dir, data):
497 | out_df = pd.DataFrame(
498 | columns=[
499 | "file",
500 | "class",
501 | "mix start time",
502 | "mix end time",
503 | "clip start sample",
504 | "clip end sample",
505 | "clip gain",
506 | "annotation",
507 | ]
508 | )
509 |
510 | sources = list(data.keys())
511 | # Sort events in ascending order, time-wise
512 | all_data = [
513 | [t, spl, f, s, l]
514 | for s in sources
515 | for t, spl, f, l in zip(
516 | data[s]["times"],
517 | data[s]["samples"],
518 | data[s]["files"],
519 | data[s]["levels"],
520 | )
521 | ]
522 | sorted_data = sorted(all_data, key=lambda x: x[0][0])
523 |
524 | # Retrieve annotations and write to pd frame
525 | for e in sorted_data:
526 | time, spl, file, source, level = e
527 | f_id = os.path.basename(file).split(".")[0]
528 | annot = f_id
529 |
530 | row = [file, source, time[0], sum(time), spl[0], spl[1], level, annot]
531 | out_df.loc[len(out_df.index)] = row
532 |
533 | out_df.to_csv(os.path.join(f_dir, "annots.csv"))
534 |
535 | def __call__(self):
536 | print(f"Building {self.partition} dataset...")
537 | for i in tqdm(range(self.num_files)):
538 | sources, mix, annots, files = MixtureObj(
539 | seq_dur=self.seq_dur,
540 | partition=self.partition,
541 | sr=self.sr,
542 | files=self.wavfiles,
543 | mix_lufs=self.mix_lufs,
544 | peak_norm_db=self.peak_norm_db,
545 | )()
546 | # For eval set, we drain the speech dataset until empty
547 | if sources:
548 | f_name, f_dir = make_unique_filename(self.output_files, self.output_dir)
549 | self._write_ground_truth(f_dir, annots)
550 | sources["mix"] = mix
551 | for source in list(sources.keys()):
552 | self._write_wav(f_dir, int(self.sr), sources[source], source)
553 | self.output_files.append(f_name)
554 | self._get_mix_stats(annots, 0.1)
555 |
556 |
557 | def main():
558 | set_seed(42)
559 | parser = argparse.ArgumentParser(
560 | description="Convert existing music, speech, and sfx datasets into format used by DNR"
561 | )
562 | parser.add_argument("music_dir", type=str)
563 | parser.add_argument("speech_dir", type=str)
564 | parser.add_argument("sfx_dir", type=str)
565 | parser.add_argument(
566 | "-n",
567 | "--num_output_files",
568 | type=int,
569 | default=1000,
570 | help="Total number of output files before splitting",
571 | )
572 | parser.add_argument("-sr", type=int, help="Sample rate", default=44100)
573 | parser.add_argument("-o", "--output_dir", type=str, required=False)
574 |
575 | args = parser.parse_args()
576 | music_dir = args.music_dir
577 | speech_dir = args.speech_dir
578 | sfx_dir = args.sfx_dir
579 | num_output_files = args.num_output_files
580 | sample_rate = args.sr
581 | output_dir = args.output_dir
582 |
583 |     # Default to ./data when no output directory is given
584 | if not output_dir:
585 | output_dir = os.path.join(os.getcwd(), "data")
586 |
587 | splits = {"train": 0.7, "val": 0.1, "test": 0.2}
588 |
589 | mix_lufs = {
590 | "music": -24,
591 | "sfx": -21,
592 | "speech": -17,
593 | "master": -27,
594 | "noise": None,
595 | "ranges": {
596 | "class": 2,
597 | "clip": 1,
598 | },
599 | }
600 |
601 | for split_name in ["train", "val", "test"]:
602 | split_output_dir = create_dir(os.path.join(output_dir, split_name))
603 | DatasetCompiler(
604 | speech_dir,
605 | sfx_dir,
606 | music_dir,
607 | num_output_files=round(num_output_files * splits[split_name]),
608 | output_dir=split_output_dir,
609 | sample_rate=sample_rate,
610 | partition=split_name,
611 | mix_lufs=mix_lufs,
612 | )()
613 |
614 |
615 | if __name__ == "__main__":
616 | main()
617 |
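618 | # Example invocation (a sketch; the dataset paths below are placeholders):
619 | #   python scripts/convert_to_dnr.py /path/to/music /path/to/speech /path/to/sfx \
620 | #       -n 1000 -sr 44100 -o /path/to/output
621 | # This writes train/val/test subsets (70/10/20 split) of DNR-style mixtures, each
622 | # folder containing the individual source stems, the mix, and an annots.csv.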
--------------------------------------------------------------------------------
/scripts/dataset_split_and_mix.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from tqdm import tqdm
3 | import torchaudio
4 | import sys
5 |
6 |
7 | # Run with python scripts/dataset_split_and_mix.py /path/to/dataset
8 | def main():
9 | root = Path(sys.argv[1])
10 | splits = [0.9, 0.1, 0]
11 |
12 | train = root / "train"
13 | # train = Path("./") / "train"
14 | valid = root / "valid"
15 | # valid = Path("./") / "valid"
16 | # test = root / "test"
17 | train.mkdir(exist_ok=True)
18 | valid.mkdir(exist_ok=True)
19 | # test.mkdir(exist_ok=True)
20 |     total_tracks = len([d for d in root.iterdir() if d.is_dir()]) - 2  # train, valid
21 | num_train = int(total_tracks * splits[0])
22 | num_valid = int(total_tracks * splits[1])
23 | num_test = int(total_tracks * splits[2])
24 |
25 | # Add remainder as necessary
26 | remainder = total_tracks - (num_train + num_valid + num_test)
27 | for i in range(remainder):
28 | if i % 2 == 0:
29 | num_train += 1
30 | elif i % 2 == 1:
31 | num_valid += 1
32 | # else:
33 | # num_test += 1
34 |
35 | num = 0
36 | for d in tqdm(root.iterdir(), total=total_tracks):
37 | if d.is_dir() and d.name != "train" and d.name != "valid" and d.name != "test":
38 | bass, sr = torchaudio.load(str(d / "bass.wav"))
39 | drums, sr = torchaudio.load(str(d / "drums.wav"))
40 | other, sr = torchaudio.load(str(d / "other.wav"))
41 | vocals, sr = torchaudio.load(str(d / "vocals.wav"))
42 | mixture = bass + drums + other + vocals
43 | torchaudio.save(str(d / "mixture.wav"), mixture, sr)
44 |
45 | if num < num_train:
46 | d.rename(train / d.name)
47 | # elif num < num_train + num_valid:
48 | else:
49 | d.rename(valid / d.name)
50 | # else:
51 | # d.rename(test / d.name)
52 | num += 1
53 |
54 |
55 | if __name__ == "__main__":
56 | main()
57 |
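58 | # Usage sketch (the dataset path is a placeholder):
59 | #   python scripts/dataset_split_and_mix.py /path/to/dataset
60 | # Expects each track folder under the root to contain bass.wav, drums.wav,
61 | # other.wav and vocals.wav; writes a summed mixture.wav per track and moves
62 | # the folders into train/ and valid/ using the 90/10 split above.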
--------------------------------------------------------------------------------
/scripts/download_youtube.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | from torchaudio.utils import download_asset
4 | from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
5 | from torchaudio.transforms import Fade
6 | from tqdm import tqdm
7 | from audio_data_pytorch import YoutubeDataset, Resample
8 | import os
9 |
10 | bundle = HDEMUCS_HIGH_MUSDB_PLUS
11 | sample_rate = bundle.sample_rate
12 | fade_overlap = 0.1
13 | root = "./youtube_data"
14 | print(f"Sample rate: {sample_rate}")
15 |
16 |
17 | def main():
18 | url_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "urls.txt")
19 | with open(url_file_path) as f:
20 | urls = f.readlines()
21 | print("Loading dataset...")
22 | dataset = YoutubeDataset(
23 | root=root,
24 | urls=urls,
25 | crop_length=10, # Crop source in 10s chunks (optional but suggested)
26 | transforms=torch.nn.Sequential(
27 | Resample(source=48000, target=sample_rate),
28 | Fade(
29 | fade_in_len=0,
30 | fade_out_len=int(fade_overlap * sample_rate),
31 | fade_shape="linear",
32 | ),
33 | ),
34 | )
35 | print("Loading model...")
36 | model = bundle.get_model()
37 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
38 | model.to(device)
39 |
40 | print("Separating chunks...")
41 | for i, chunk in enumerate(tqdm(dataset)):
42 | original_url_idx = int(dataset.wavs[i].split("/")[-1].split(".")[0])
43 | sources = separate(chunk, model=model, device=device)
44 |
45 | if not os.path.exists(f"{root}/separated/{original_url_idx}"):
46 | os.makedirs(f"{root}/separated/{original_url_idx}")
47 | for source in sources:
48 | torchaudio.save(
49 | f"{root}/separated/{original_url_idx}/{source}.wav",
50 |                 sources[source],  # (channels, time), as torchaudio.save expects
51 | sample_rate=sample_rate,
52 | )
53 |
54 |
55 | def separate(chunk: torch.Tensor, model: torch.nn.Module, device: torch.device):
56 |     chunk = chunk.to(device)  # Tensor.to is not in-place; keep the returned copy
57 | ref = chunk.mean(0)
58 | chunk = (chunk - ref.mean()) / ref.std() # normalization
59 |
60 | with torch.no_grad():
61 | out = model.forward(chunk[None])
62 |
63 | sources = out.squeeze(0)
64 | sources = sources * ref.std() + ref.mean() # denormalization
65 |
66 | sources_list = model.sources
67 | sources = list(sources)
68 | dict_sources = dict(zip(sources_list, sources))
69 | return dict_sources
70 |
71 |
72 | if __name__ == "__main__":
73 | main()
74 |
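75 | # Usage sketch: list one YouTube URL per line in scripts/urls.txt, then run
76 | #   python scripts/download_youtube.py
77 | # audio_data_pytorch's YoutubeDataset (imported above) is expected to fetch the
78 | # audio under ./youtube_data in 10 s chunks; each chunk is separated with the
79 | # pretrained HDEMUCS_HIGH_MUSDB_PLUS bundle and its stems are written to
80 | # ./youtube_data/separated/<url index>/<source>.wav.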
--------------------------------------------------------------------------------
/scripts/musdb_to_voiceless.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pathlib
3 | import soundfile as sf
4 |
5 | if __name__ == "__main__":
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("path", type=str, help="Path to the MUSDB18 dataset")
8 | parser.add_argument("output", type=str, help="Path to the output folder")
9 | args = parser.parse_args()
10 | print("Converting MUSDB18 to voiceless")
11 | print("This will take a while ...")
12 | print("")
13 |
14 | path = pathlib.Path(args.path)
15 | output = pathlib.Path(args.output)
16 | output.mkdir(exist_ok=True)
17 |
18 | for track in path.glob("**/mixture.wav"):
19 | folder = track.parent
20 | print(f"Processing {folder}")
21 |
22 | bass, sr = sf.read(folder / "bass.wav")
23 | drums, sr = sf.read(folder / "drums.wav")
24 | other, sr = sf.read(folder / "other.wav")
25 | mix = bass + drums + other
26 |
27 | output_folder = output / folder.relative_to(path)
28 | output_folder.mkdir(exist_ok=True, parents=True)
29 | sf.write(output_folder / "mixture.wav", mix, sr)
30 |
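31 | # Usage sketch (paths are placeholders):
32 | #   python scripts/musdb_to_voiceless.py /path/to/MUSDB18-HQ /path/to/output
33 | # For every track, mixture.wav is rebuilt as bass + drums + other (vocals are
34 | # dropped), mirroring the original folder structure under the output path.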
--------------------------------------------------------------------------------
/scripts/urls.txt:
--------------------------------------------------------------------------------
1 | https://www.youtube.com/watch?v=mlB-3842Cjs
--------------------------------------------------------------------------------
/scripts/webapp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import librosa
4 | import torchaudio
5 | import numpy as np
6 | import streamlit as st
7 | import librosa.display
8 | import pyloudnorm as pyln
9 | import matplotlib.pyplot as plt
10 |
11 | from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
12 | from torchaudio.transforms import Fade
13 |
14 |
15 | st.set_page_config(page_title="aimless-splitter")
16 | st.image("docs/aimless-logo-crop.png", use_column_width="always")
17 |
18 | bundle = HDEMUCS_HIGH_MUSDB_PLUS
19 | sample_rate = bundle.sample_rate
20 | fade_overlap = 0.1
21 |
22 |
23 | @st.experimental_singleton
24 | def load_hdemucs():
25 | print("Loading pretrained HDEMUCS model...")
26 | model = bundle.get_model()
27 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28 | model.to(device)
29 |
30 | return model
31 |
32 |
33 | def plot_spectrogram(y, *, sample_rate, figsize=(12, 3)):
34 | # Convert to mono
35 | if y.ndim > 1:
36 | y = y[0]
37 |
38 | fig = plt.figure(figsize=figsize)
39 | D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max, top_db=120)
40 | img = librosa.display.specshow(D, y_axis="linear", x_axis="time", sr=sample_rate)
41 | fig.colorbar(img, format="%+2.f dB")
42 | st.pyplot(fig=fig, clear_figure=True)
43 |
44 |
45 | def separate_sources(
46 | model: torch.nn.Module,
47 | mix: torch.Tensor,
48 | segment: float = 10.0,
49 | overlap: float = 0.1,
50 | device=None,
51 | ):
52 | """
53 |     Apply the model to a mixture in overlapping, cross-faded segments (overlap-add).
54 |
55 | Args:
56 |         segment (float): segment length in seconds
57 | device (torch.device, str, or None): if provided, device on which to
58 | execute the computation, otherwise `mix.device` is assumed.
59 | When `device` is different from `mix.device`, only local computations will
60 | be on `device`, while the entire tracks will be stored on `mix.device`.
61 | """
62 | if device is None:
63 | device = mix.device
64 | else:
65 | device = torch.device(device)
66 |
67 | batch, channels, length = mix.shape
68 |
69 | chunk_len = int(sample_rate * segment * (1 + overlap))
70 | start = 0
71 | end = chunk_len
72 | overlap_frames = overlap * sample_rate
73 | fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")
74 |
75 | final = torch.zeros(batch, len(model.sources), channels, length, device=device)
76 |
77 | while start < length - overlap_frames:
78 | chunk = mix[:, :, start:end]
79 | with torch.no_grad():
80 | out = model.forward(chunk)
81 | out = fade(out)
82 | final[:, :, :, start:end] += out
83 | if start == 0:
84 | fade.fade_in_len = int(overlap_frames)
85 | start += int(chunk_len - overlap_frames)
86 | else:
87 | start += chunk_len
88 | end += chunk_len
89 | if end >= length:
90 | fade.fade_out_len = 0
91 | return final
92 |
93 |
94 | def process_file(file, model: torch.nn.Module, device: torch.device):
95 | import tempfile
96 | import shutil
97 | from pathlib import Path
98 |
99 | # Cache file to disk so we can read it with ffmpeg
100 | with tempfile.NamedTemporaryFile("wb", suffix=Path(file.name).suffix) as f:
101 | shutil.copyfileobj(file, f)
102 |         f.flush()  # ensure buffered bytes are on disk before reading f.name below
103 | duration = librosa.get_duration(filename=f.name)
104 | num_frames = -1
105 | if duration > 60:
106 | st.write(f"File is {duration:.01f}s long. Loading the first 60s only.")
107 | sr = librosa.get_samplerate(f.name)
108 | num_frames = 60 * sr
109 |
110 | x, sr = torchaudio.load(f.name, num_frames=num_frames)
111 |
112 | # resample if needed
113 | if sr != sample_rate:
114 |         x = torchaudio.functional.resample(x, sr, sample_rate)
115 |
116 | st.subheader("Mix")
117 | x_numpy = x.numpy()
118 | plot_spectrogram(x_numpy, sample_rate=sample_rate)
119 | st.audio(x_numpy, sample_rate=sample_rate)
120 |
121 | waveform = x.to(device)
122 |
123 | # split into 10.0 sec chunks
124 | ref = waveform.mean(0)
125 | waveform = (waveform - ref.mean()) / ref.std() # normalization
126 | sources = separate_sources(
127 | model,
128 | waveform[None],
129 | device=device,
130 | )[0]
131 | sources = sources * ref.std() + ref.mean()
132 |
133 | sources_list = model.sources
134 | sources = list(sources)
135 |
136 | audios = dict(zip(sources_list, sources))
137 |
138 | for source, audio in audios.items():
139 | audio = audio.cpu().numpy()
140 | st.subheader(source.capitalize())
141 | plot_spectrogram(audio, sample_rate=sample_rate)
142 | st.audio(audio, sample_rate=sample_rate)
143 |
144 |
145 | # load pretrained model
146 | hdemucs = load_hdemucs()
147 |
148 | # load audio
149 | uploaded_file = st.file_uploader("Choose a file to demix.")
150 |
151 | if uploaded_file is not None:
152 | # split with hdemucs
153 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
154 |     process_file(uploaded_file, hdemucs, device)
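155 |
156 | # Launch sketch: this file is a Streamlit app, so it is started with
157 | #   streamlit run scripts/webapp.py
158 | # Uploads longer than 60 s are truncated to their first 60 s before separation
159 | # with the pretrained HDEMUCS_HIGH_MUSDB_PLUS bundle.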
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | setuptools.setup(
4 | name="aimless",
5 | version="0.0.1",
6 | author="Artificial Intelligence and Music League for Effective Source Separation",
7 | author_email="chin-yun.yu@qmul.ac.uk",
8 | packages=setuptools.find_packages(exclude=["tests", "tests.*", "data", "data.*"]),
9 | install_requires=["torch", "pytorch-lightning", "torch_fftconv"],
10 | classifiers=[
11 | "Programming Language :: Python :: 3",
12 | "License :: OSI Approved :: MIT License",
13 | "Operating System :: OS Independent",
14 | ],
15 | )
16 |
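17 | # Install sketch (standard setuptools/pip workflow, not specific to this repo):
18 | #   pip install -e .
19 | # pulls in torch, pytorch-lightning and torch_fftconv from install_requires.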
--------------------------------------------------------------------------------