├── .gitmodules
├── .gitignore
├── .env.tmp
├── requirements.txt
├── config.yaml
├── .pre-commit-config.yaml
├── LICENSE
├── main
│   ├── dataset.py
│   ├── utils.py
│   ├── inference.py
│   └── module_vckt.py
├── exp
│   ├── base_vctk_k_2.yaml
│   ├── base_vctk_k_none.yaml
│   └── singing.yaml
├── train.py
├── README.md
├── medley_vox.py
├── mlc.py
├── eval.py
└── eval_nmf.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "datasets"]
2 | path = datasets
3 | url = https://github.com/yoyololicon/pytorch-wav-datasets.git
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .mypy_cache
3 | .env
4 | .DS_Store
5 | .DS_Store/
6 | .hydra
7 | venv/
8 | logs/
9 | data/
10 | train.log
11 | .idea/
12 | .vscode/
13 | est/
--------------------------------------------------------------------------------
/.env.tmp:
--------------------------------------------------------------------------------
1 | DIR_LOGS=/logs
2 | DIR_DATA=/data
3 |
4 | # Required if using wandb logger
5 | WANDB_PROJECT=wandbprojectname
6 | WANDB_ENTITY=wandbuser
7 | WANDB_API_KEY=wandbapikey
8 |
9 | # Required if using Common Voice dataset
10 | HUGGINGFACE_TOKEN=huggingfacetoken
11 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | pytorch-lightning==1.9.5
3 | python-dotenv
4 | hydra-core
5 | hydra-colorlog
6 | wandb
7 | auraloss
8 | yt-dlp
9 | datasets
10 | pyloudnorm
11 | einops
12 | omegaconf
13 | rich
14 | plotly
15 | librosa
16 | transformers
17 | eng-to-ipa
18 | ema-pytorch
19 | py7zr
20 |
21 | audio-diffusion-pytorch==0.0.92
22 | audio-encoders-pytorch
23 | audio-data-pytorch
24 | quantizer-pytorch
25 | difformer-pytorch
26 | a-transformers-pytorch
27 | asteroid
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - exp: null
4 | - override hydra/hydra_logging: colorlog
5 | - override hydra/job_logging: colorlog
6 |
7 | seed: 12345
8 | train: True
9 | ignore_warnings: True
10 | print_config: False # Prints tree with all configurations
11 | work_dir: ${hydra:runtime.cwd} # This is the root of the project
12 | logs_dir: ${work_dir}${oc.env:DIR_LOGS} # This is the root for all logs
13 | data_dir: ${work_dir}${oc.env:DIR_DATA} # This is the root for all data
14 | ckpt_dir: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
15 | # Hydra experiment configs log dir
16 | hydra:
17 | run:
18 | dir: ${logs_dir}/runs/${now:%Y-%m-%d-%H-%M-%S}
19 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v2.3.0
4 | hooks:
5 | - id: end-of-file-fixer
6 | - id: trailing-whitespace
7 |
8 | # Formats code correctly
9 | - repo: https://github.com/psf/black
10 | rev: 21.12b0
11 | hooks:
12 | - id: black
13 | args: [
14 | '--experimental-string-processing'
15 | ]
16 |
17 | # Checks types
18 | - repo: https://github.com/pre-commit/mirrors-mypy
19 | rev: 'v0.971'
20 | hooks:
21 | - id: mypy
22 | additional_dependencies: [data-science-types>=0.2, torch>=1.6]
23 |
24 | # Sorts imports
25 | - repo: https://github.com/pycqa/isort
26 | rev: 5.10.1
27 | hooks:
28 | - id: isort
29 | name: isort (python)
30 | args: ["--profile", "black"]
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/main/dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.utils.data import DataLoader
7 | from torchaudio.datasets.vctk import VCTK_092
8 | from torchaudio.functional import resample
9 |
10 |
11 | def vctk_collate(batch,
12 | target_rate=22050,
13 | length=131072,
14 |                  mix_k=None):  # when set, sum groups of mix_k utterances into mixtures
15 | def pad_tensor(t):
16 | current_length = t.shape[1]
17 | padding_length = length - current_length
18 | return F.pad(t, (0, padding_length))
19 |
20 | original_sr = batch[0][1]
21 | waveforms = [pad_tensor(resample(waveform=data[0],
22 | orig_freq=original_sr,
23 | new_freq=target_rate)) for data in batch]
24 | if not mix_k:
25 | return torch.stack(waveforms)
26 | if len(waveforms) % mix_k != 0:
27 |         raise ValueError("Batch size must be divisible by mix_k")
28 | random.shuffle(waveforms)
29 | partitioned_waveforms = [waveforms[i:i + mix_k] for i in range(0, len(waveforms), mix_k)]
30 | summed_list = [torch.sum(torch.stack(subset), dim=0) for subset in partitioned_waveforms]
31 | return torch.stack(summed_list)
32 |
33 |
34 | if __name__ == '__main__':
35 |
36 | dataset = VCTK_092(root="/home/emilian/PycharmProjects/multi-speaker-diff-sep/data/vctk", download=True)
37 | dataloader = DataLoader(dataset=dataset,
38 | batch_size=16,
39 | collate_fn=partial(vctk_collate, mix_k=2),
40 | pin_memory=True,
41 | num_workers=0)
42 | output = next(iter(dataloader))
--------------------------------------------------------------------------------
/exp/base_vctk_k_2.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # To pick dataset path execute with: ++datamodule.dataset.path=/your_wav_root/
4 |
5 | sampling_rate: 22050
6 | length: 131072
7 | channels: 1
8 | log_every_n_steps: 1000
9 |
10 | model:
11 | _target_: main.module_vckt.Model
12 | lr: 1e-4
13 | lr_beta1: 0.95
14 | lr_beta2: 0.999
15 | lr_eps: 1e-6
16 | lr_weight_decay: 1e-3
17 | # ema_beta: 0.9999
18 | # ema_power: 0.7
19 |
20 | model:
21 | _target_: audio_diffusion_pytorch.AudioDiffusionModel
22 | in_channels: ${channels}
23 | channels: 256
24 | patch_size: 16
25 | resnet_groups: 8
26 | kernel_multiplier_downsample: 2
27 | multipliers: [ 1, 2, 4, 4, 4, 4, 4 ]
28 | factors: [ 4, 4, 4, 2, 2, 2 ]
29 | num_blocks: [ 2, 2, 2, 2, 2, 2 ]
30 | attentions: [ 0, 0, 0, 1, 1, 1, 1 ]
31 | attention_heads: 8
32 | attention_features: 128
33 | attention_multiplier: 2
34 | use_nearest_upsample: False
35 | use_skip_scale: True
36 | diffusion_type: k
37 | diffusion_sigma_distribution:
38 | _target_: audio_diffusion_pytorch.LogNormalDistribution
39 | mean: -3.0
40 | std: 1.0
41 | diffusion_sigma_data: 0.2
42 |
43 |
44 | datamodule:
45 | _target_: main.module_vckt.Datamodule
46 | dataset:
47 | _target_: torchaudio.datasets.vctk.VCTK_092
48 | root: null
49 | download: True
50 |
51 | collate_fn:
52 | _target_: main.dataset.vctk_collate
53 | _partial_: True
54 | target_rate: ${sampling_rate}
55 | length: ${length}
56 | mix_k: 2
57 |
58 | val_split: 0.01
59 | batch_size: 32
60 | num_workers: 8
61 | pin_memory: True
62 |
63 | callbacks:
64 | rich_progress_bar:
65 | _target_: pytorch_lightning.callbacks.RichProgressBar
66 |
67 | model_checkpoint:
68 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
69 | monitor: "valid_loss" # name of the logged metric which determines when model is improving
70 | save_top_k: 1 # save k best models (determined by above metric)
71 |     save_last: True # additionally always save model from last epoch
72 | mode: "min" # can be "max" or "min"
73 | verbose: False
74 | dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
75 | filename: '{epoch:02d}-{valid_loss:.3f}'
76 |
77 | model_summary:
78 | _target_: pytorch_lightning.callbacks.RichModelSummary
79 | max_depth: 2
80 |
81 | audio_samples_logger:
82 | _target_: main.module_vckt.SampleLogger
83 | num_items: 4
84 | channels: ${channels}
85 | sampling_rate: ${sampling_rate}
86 | length: ${length}
87 | sampling_steps: [100]
88 | # use_ema_model: True
89 | diffusion_sampler:
90 | _target_: audio_diffusion_pytorch.ADPM2Sampler
91 | rho: 1.0
92 | diffusion_schedule:
93 | _target_: audio_diffusion_pytorch.KarrasSchedule
94 | sigma_min: 0.0001
95 | sigma_max: 5.0
96 | rho: 9.0
97 |
98 | loggers:
99 | wandb:
100 | _target_: pytorch_lightning.loggers.wandb.WandbLogger
101 | project: ${oc.env:WANDB_PROJECT}
102 | entity: ${oc.env:WANDB_ENTITY}
103 | # offline: False # set True to store all logs only locally
104 | job_type: "train"
105 | group: ""
106 | save_dir: ${logs_dir}
107 |
108 | trainer:
109 | _target_: pytorch_lightning.Trainer
110 | gpus: 0 # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0`
111 | precision: 32 # Precision used for tensors, default `32`
112 | accelerator: null # `ddp` GPUs train individually and sync gradients, default `None`
113 | min_epochs: 0
114 | max_epochs: -1
115 | enable_model_summary: False
116 | log_every_n_steps: 1 # Logs metrics every N batches
117 | check_val_every_n_epoch: null
118 | val_check_interval: ${log_every_n_steps}
119 |
--------------------------------------------------------------------------------
/exp/base_vctk_k_none.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # To pick dataset path execute with: ++datamodule.dataset.path=/your_wav_root/
4 |
5 | sampling_rate: 22050
6 | length: 131072
7 | channels: 1
8 | log_every_n_steps: 1000
9 |
10 | model:
11 | _target_: main.module_vckt.Model
12 | lr: 1e-4
13 | lr_beta1: 0.95
14 | lr_beta2: 0.999
15 | lr_eps: 1e-6
16 | lr_weight_decay: 1e-3
17 | # ema_beta: 0.9999
18 | # ema_power: 0.7
19 |
20 | model:
21 | _target_: audio_diffusion_pytorch.AudioDiffusionModel
22 | in_channels: ${channels}
23 | channels: 256
24 | patch_size: 16
25 | resnet_groups: 8
26 | kernel_multiplier_downsample: 2
27 | multipliers: [ 1, 2, 4, 4, 4, 4, 4 ]
28 | factors: [ 4, 4, 4, 2, 2, 2 ]
29 | num_blocks: [ 2, 2, 2, 2, 2, 2 ]
30 | attentions: [ 0, 0, 0, 1, 1, 1, 1 ]
31 | attention_heads: 8
32 | attention_features: 128
33 | attention_multiplier: 2
34 | use_nearest_upsample: False
35 | use_skip_scale: True
36 | diffusion_type: k
37 | diffusion_sigma_distribution:
38 | _target_: audio_diffusion_pytorch.LogNormalDistribution
39 | mean: -3.0
40 | std: 1.0
41 | diffusion_sigma_data: 0.2
42 |
43 |
44 | datamodule:
45 | _target_: main.module_vckt.Datamodule
46 | dataset:
47 | _target_: torchaudio.datasets.vctk.VCTK_092
48 | root: null
49 | download: True
50 |
51 | collate_fn:
52 | _target_: main.dataset.vctk_collate
53 | _partial_: True
54 | target_rate: ${sampling_rate}
55 | length: ${length}
56 | mix_k: null
57 |
58 | val_split: 0.01
59 | batch_size: 32
60 | num_workers: 8
61 | pin_memory: True
62 |
63 | callbacks:
64 | rich_progress_bar:
65 | _target_: pytorch_lightning.callbacks.RichProgressBar
66 |
67 | model_checkpoint:
68 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
69 | monitor: "valid_loss" # name of the logged metric which determines when model is improving
70 | save_top_k: 1 # save k best models (determined by above metric)
71 |     save_last: True # additionally always save model from last epoch
72 | mode: "min" # can be "max" or "min"
73 | verbose: False
74 | dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
75 | filename: '{epoch:02d}-{valid_loss:.3f}'
76 |
77 | model_summary:
78 | _target_: pytorch_lightning.callbacks.RichModelSummary
79 | max_depth: 2
80 |
81 | audio_samples_logger:
82 | _target_: main.module_vckt.SampleLogger
83 | num_items: 4
84 | channels: ${channels}
85 | sampling_rate: ${sampling_rate}
86 | length: ${length}
87 | sampling_steps: [100]
88 | # use_ema_model: True
89 | diffusion_sampler:
90 | _target_: audio_diffusion_pytorch.ADPM2Sampler
91 | rho: 1.0
92 | diffusion_schedule:
93 | _target_: audio_diffusion_pytorch.KarrasSchedule
94 | sigma_min: 0.0001
95 | sigma_max: 5.0
96 | rho: 9.0
97 |
98 | loggers:
99 | wandb:
100 | _target_: pytorch_lightning.loggers.wandb.WandbLogger
101 | project: ${oc.env:WANDB_PROJECT}
102 | entity: ${oc.env:WANDB_ENTITY}
103 | # offline: False # set True to store all logs only locally
104 | job_type: "train"
105 | group: ""
106 | save_dir: ${logs_dir}
107 |
108 | trainer:
109 | _target_: pytorch_lightning.Trainer
110 | gpus: 0 # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0`
111 | precision: 32 # Precision used for tensors, default `32`
112 | accelerator: null # `ddp` GPUs train individually and sync gradients, default `None`
113 | min_epochs: 0
114 | max_epochs: -1
115 | enable_model_summary: False
116 | log_every_n_steps: 1 # Logs metrics every N batches
117 | check_val_every_n_epoch: null
118 | val_check_interval: ${log_every_n_steps}
119 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import dotenv
4 | import hydra
5 | import pytorch_lightning as pl
6 | from main import utils
7 | from omegaconf import DictConfig, open_dict
8 |
9 | # Load environment variables from `.env`.
10 | dotenv.load_dotenv(override=True)
11 | log = utils.get_logger(__name__)
12 |
13 |
14 | @hydra.main(config_path=".", config_name="config.yaml", version_base=None)
15 | def main(config: DictConfig) -> None:
16 |
17 | # Logs config tree
18 | utils.extras(config)
19 |
20 | # Apply seed for reproducibility
21 | pl.seed_everything(config.seed)
22 |
23 | # Initialize datamodule
24 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>.")
25 | datamodule = hydra.utils.instantiate(config.datamodule, _convert_="partial")
26 |
27 | # Initialize model
28 | log.info(f"Instantiating model <{config.model._target_}>.")
29 | model = hydra.utils.instantiate(config.model, _convert_="partial")
30 |
31 | # Initialize all callbacks (e.g. checkpoints, early stopping)
32 | callbacks = []
33 |
34 | # If save is provided add callback that saves and stops, to be used with +ckpt
35 | if "save" in config:
36 | # Ignore loggers and other callbacks
37 | with open_dict(config):
38 | config.pop("loggers")
39 | config.pop("callbacks")
40 | config.trainer.num_sanity_val_steps = 0
41 | attribute, path = config.get("save"), config.get("ckpt_dir")
42 | filename = os.path.join(path, f"{attribute}.pt")
43 | callbacks += [utils.SavePytorchModelAndStopCallback(filename, attribute)]
44 |
45 | if "callbacks" in config:
46 | for _, cb_conf in config["callbacks"].items():
47 | if "_target_" in cb_conf:
48 | log.info(f"Instantiating callback <{cb_conf._target_}>.")
49 | callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))
50 |
51 | # Initialize loggers (e.g. wandb)
52 | loggers = []
53 | if "loggers" in config:
54 | for _, lg_conf in config["loggers"].items():
55 | if "_target_" in lg_conf:
56 | log.info(f"Instantiating logger <{lg_conf._target_}>.")
57 | # Sometimes wandb throws error if slow connection...
58 | logger = utils.retry_if_error(
59 | lambda: hydra.utils.instantiate(lg_conf, _convert_="partial")
60 | )
61 | loggers.append(logger)
62 |
63 | # Initialize trainer
64 | log.info(f"Instantiating trainer <{config.trainer._target_}>.")
65 | trainer = hydra.utils.instantiate(
66 | config.trainer, callbacks=callbacks, logger=loggers, _convert_="partial"
67 | )
68 |
69 | # Send some parameters from config to all lightning loggers
70 | log.info("Logging hyperparameters!")
71 | utils.log_hyperparameters(
72 | config=config,
73 | model=model,
74 | datamodule=datamodule,
75 | trainer=trainer,
76 | callbacks=callbacks,
77 | logger=loggers,
78 | )
79 |
80 | # Train with checkpoint if present, otherwise from start
81 | if "ckpt" in config:
82 | ckpt = config.get("ckpt")
83 | log.info(f"Starting training from {ckpt}")
84 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt)
85 | else:
86 | log.info("Starting training.")
87 | trainer.fit(model=model, datamodule=datamodule)
88 |
89 | # Make sure everything closed properly
90 | log.info("Finalizing!")
91 | utils.finish(
92 | config=config,
93 | model=model,
94 | datamodule=datamodule,
95 | trainer=trainer,
96 | callbacks=callbacks,
97 | logger=loggers,
98 | )
99 |
100 | # Print path to best checkpoint
101 | if (
102 | not config.trainer.get("fast_dev_run")
103 | and config.get("train")
104 | and not config.get("save")
105 | ):
106 | log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}")
107 |
108 |
109 | if __name__ == "__main__":
110 | main()
111 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Singing Voice Separation
2 |
3 | Source code of the paper [Zero-Shot Duet Singing Voices Separation with
4 | Diffusion Models](https://sdx-workshop.github.io/papers/Yu.pdf) at the SDX workshop 2023.
5 |
6 | ## Setup
7 |
8 | Install requirements
9 |
10 | ```bash
11 | pip install -r requirements.txt
12 | ```
13 |
14 | To set up the environment variables, rename `.env.tmp` to `.env` and replace the placeholder values with your own (the example values below are random):
15 | ```bash
16 | DIR_LOGS=/logs
17 | DIR_DATA=/data
18 |
19 | # Required if using wandb logger
20 | WANDB_PROJECT=audioproject
21 | WANDB_ENTITY=johndoe
22 | WANDB_API_KEY=a21dzbqlybbzccqla4txa21dzbqlybbzccqla4tx
23 | ```
24 |
25 | ## Training
26 |
27 | The config we used for the paper is [`exp/singing.yaml`](exp/singing.yaml); you can run it with
28 | ```bash
29 | python train.py exp=singing
30 | ```
31 | You'll need to download the relevant datasets and resample them to 24 kHz.
32 | Then, modify the `datamodule` section of the config to point to the right paths.
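
If you need a quick way to do the resampling, below is a minimal sketch using `torchaudio` (the source/target paths are placeholders; the 24 kHz target matches `sampling_rate` in [`exp/singing.yaml`](exp/singing.yaml)):

```py
from pathlib import Path

import torchaudio
from torchaudio.functional import resample

TARGET_SR = 24000  # matches `sampling_rate` in exp/singing.yaml
src_root, dst_root = Path("/path/to/raw"), Path("/path/to/dataset-24k")

for wav_path in src_root.rglob("*.wav"):
    waveform, sr = torchaudio.load(str(wav_path))
    waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono, as in the configs
    waveform = resample(waveform, orig_freq=sr, new_freq=TARGET_SR)
    out_path = dst_root / wav_path.relative_to(src_root)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(str(out_path), waveform, TARGET_SR)
```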
33 |
34 | Resume run from a checkpoint
35 |
36 | ```bash
37 | python train.py exp=singing +ckpt=/logs/ckpts/2022-08-17-01-22-18/'last.ckpt'
38 | ```
39 |
40 | ## Evaluation
41 |
42 | First, download the [MedleyVox](https://github.com/jeonchangbin49/MedleyVox?tab=readme-ov-file) dataset.
43 | Then, run the following command to evaluate the model on the `duet` subset of the dataset.
44 |
45 | ```bash
46 | python eval.py logs/runs/XXXX/.hydra/config.yaml logs/ckpts/XXXX/last.ckpt /your/path/to/MedleyVox -T 100 --cond --hop-length 32768 --self-cond --retry 2
47 | ```
48 |
49 | Some important arguments:
50 |
51 | 1. `-T`: number of diffusion steps
52 | 2. `--cond`: use auto-regressive conditioning on the ground truth (teacher forcing). Without this flag, the model will generate the full-length audio at once
53 | 3. `--self-cond`: perform auto-regressive conditioning on the generated audio if used together with `--cond`
54 | 4. `--hop-length`: the hop length of the moving window
55 | 5. `--window`: the size of the moving window. Defaults to the same length as the training data
56 | 6. `--retry`: number of retries for each auto-regressive step. The algorithm will generate `retry + 1` candidates and pick the one most similar to the ground truth. Defaults to 0
57 |
58 | For other arguments, please check out the code.
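
For reference, the numbers reported by `eval.py` are the SI-SDR/SDR metrics from `asteroid` (see `COMPUTE_METRICS` and the PIT imports at the top of that file). A minimal, self-contained sketch of computing them for a single item looks roughly like the following; the random tensors are stand-ins for a real mixture, its ground-truth sources, and the model's estimates:

```py
import torch
from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
from asteroid.metrics import get_metrics

# Stand-in tensors: mix is (time,), sources and estimates are (n_src, time).
mix = torch.randn(24000)
sources = torch.randn(2, 24000)
estimates = torch.randn(2, 24000)

# Resolve the source permutation with a PIT wrapper around negative SI-SDR.
pit = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
_, reordered = pit(estimates.unsqueeze(0), sources.unsqueeze(0), return_est=True)

# SI-SDR / SDR of the (reordered) estimates; the returned dict also contains
# the corresponding `input_*` metrics of the unprocessed mixture.
metrics = get_metrics(
    mix.numpy(),
    sources.numpy(),
    reordered.squeeze(0).numpy(),
    sample_rate=24000,
    metrics_list=["si_sdr", "sdr"],
)
print(metrics)
```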
59 |
60 | ### NMF baseline
61 |
62 | This baseline depends on `torchnmf`, which is not listed in `requirements.txt` and must be installed separately.
63 |
64 | ```bash
65 | python eval_nmf.py /your/path/to/MedleyVox/ --thresh 0.08 --division 10 --kernel-size 7
66 | ```
67 |
68 | ### Checkpoint/Logs
69 |
70 | Our pre-trained singing voice diffusion model can be downloaded [here](https://drive.google.com/drive/folders/1nAj0JDiG70ddr_7UnhszpIiVCh4SzqgW?usp=sharing).
71 | You can find the training logs and unconditional singing samples generated during training on [wandb](https://api.wandb.ai/links/aimless/fqtcyjke).
72 |
73 | ## FAQ
74 |
75 |
76 | **How do I load the model once I'm done training?**
77 |
78 | If you want to load the checkpoint to restore training with the trainer you can do `python train.py exp=my_experiment +ckpt=/logs/ckpts/2022-08-17-01-22-18/'last.ckpt'`.
79 |
80 | Otherwise if you want to instantiate a model from the checkpoint:
81 | ```py
82 | from main.module_vckt import Model
83 | model = Model.load_from_checkpoint(
84 |     checkpoint_path='my_checkpoint.ckpt',
85 |     lr=1e-4,
86 |     lr_beta1=0.95,
87 |     lr_beta2=0.999,
88 |     lr_eps=1e-6,
89 |     lr_weight_decay=1e-3,
90 |     all_other_parameters_here...
91 | )
92 | ```
93 | To get only the PyTorch `.pt` checkpoint, you can save the internal model weights with `torch.save(model.model.state_dict(), 'torchckpt.pt')`.
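
As a sketch of the reverse direction, you can rebuild the bare diffusion network from that exported state dict by instantiating `AudioDiffusionModel` with the same hyperparameters as the `model.model` section of the experiment config the checkpoint was trained with; for example, with the values from [`exp/base_vctk_k_2.yaml`](exp/base_vctk_k_2.yaml):

```py
import torch
from audio_diffusion_pytorch import AudioDiffusionModel, LogNormalDistribution

# Hyperparameters copied from the `model.model` section of exp/base_vctk_k_2.yaml;
# they must match the config used to train the checkpoint being loaded.
diffusion_model = AudioDiffusionModel(
    in_channels=1,
    channels=256,
    patch_size=16,
    resnet_groups=8,
    kernel_multiplier_downsample=2,
    multipliers=[1, 2, 4, 4, 4, 4, 4],
    factors=[4, 4, 4, 2, 2, 2],
    num_blocks=[2, 2, 2, 2, 2, 2],
    attentions=[0, 0, 0, 1, 1, 1, 1],
    attention_heads=8,
    attention_features=128,
    attention_multiplier=2,
    use_nearest_upsample=False,
    use_skip_scale=True,
    diffusion_type='k',
    diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
    diffusion_sigma_data=0.2,
)
diffusion_model.load_state_dict(torch.load('torchckpt.pt'))
```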
94 |
95 |
96 |
97 |
98 |
99 | **Why is no checkpoint created at the end of the epoch?**
100 |
101 | If the epoch is shorter than `log_every_n_steps`, the checkpoint is not saved at the end of the epoch but only after the given number of steps. If you want to checkpoint more frequently, you can add `every_n_train_steps` to the `ModelCheckpoint`, e.g.:
102 | ```yaml
103 | model_checkpoint:
104 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
105 | monitor: "valid_loss" # name of the logged metric which determines when model is improving
106 | save_top_k: 1 # save k best models (determined by above metric)
107 |   save_last: True # additionally always save model from last epoch
108 | mode: "min" # can be "max" or "min"
109 | verbose: False
110 | dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
111 | filename: '{epoch:02d}-{valid_loss:.3f}'
112 | every_n_train_steps: 10
113 | ```
114 | Note that saving checkpoints this frequently is generally not recommended, since writing the file takes some time.
115 |
116 |
117 |
--------------------------------------------------------------------------------
/exp/singing.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # To pick dataset path execute with: ++datamodule.dataset.path=/your_wav_root/
4 |
5 | sampling_rate: 24000
6 | length: 131072
7 | channels: 1
8 | log_every_n_steps: 3000
9 | overlap: 65536
10 |
11 | model:
12 | _target_: main.module_vckt.Model
13 | lr: 1e-4
14 | lr_beta1: 0.95
15 | lr_beta2: 0.999
16 | lr_eps: 1e-6
17 | lr_weight_decay: 1e-3
18 | # ema_beta: 0.9999
19 | # ema_power: 0.7
20 |
21 | model:
22 | _target_: audio_diffusion_pytorch.AudioDiffusionModel
23 | in_channels: ${channels}
24 | channels: 256
25 | patch_size: 12
26 | resnet_groups: 8
27 | kernel_multiplier_downsample: 2
28 | multipliers: [ 1, 2, 4, 4, 4, 4, 4 ]
29 | factors: [ 4, 4, 4, 2, 2, 2 ]
30 | num_blocks: [ 2, 2, 2, 2, 2, 2 ]
31 | attentions: [ 0, 0, 0, 1, 1, 1, 1 ]
32 | attention_heads: 8
33 | attention_features: 128
34 | attention_multiplier: 2
35 | use_nearest_upsample: False
36 | use_skip_scale: True
37 | diffusion_type: k
38 | diffusion_sigma_distribution:
39 | _target_: audio_diffusion_pytorch.LogNormalDistribution
40 | mean: -3.0
41 | std: 1.0
42 | diffusion_sigma_data: 0.2
43 |
44 |
45 | datamodule:
46 | _target_: main.module_vckt.Datamodule
47 | dataset:
48 | _target_: torch.utils.data.ConcatDataset
49 | datasets:
50 | - _target_: datasets.wav.WAVDataset
51 | data_dir: /import/c4dm-datasets-ext/m4singer-24k/
52 | segment: ${length}
53 | overlap: ${overlap}
54 | mono: True
55 | keepdim: True
56 | - _target_: datasets.wav.WAVDataset
57 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/OpenCPOP-24k/
58 | segment: ${length}
59 | overlap: ${overlap}
60 | mono: True
61 | keepdim: True
62 | - _target_: datasets.wav.WAVDataset
63 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/VocalSet-24k/
64 | segment: ${length}
65 | overlap: ${overlap}
66 | mono: True
67 | keepdim: True
68 | - _target_: datasets.wav.WAVDataset
69 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/OpenSinger-24k/
70 | segment: ${length}
71 | overlap: ${overlap}
72 | mono: True
73 | keepdim: True
74 | - _target_: datasets.wav.WAVDataset
75 | data_dir: /import/c4dm-datasets-ext/jvs_music_ver1/
76 | segment: ${length}
77 | overlap: ${overlap}
78 | mono: True
79 | keepdim: True
80 | - _target_: datasets.wav.WAVDataset
81 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/CSD-24k/
82 | segment: ${length}
83 | overlap: ${overlap}
84 | mono: True
85 | keepdim: True
86 | - _target_: datasets.wav.WAVDataset
87 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/NUS-24k/
88 | segment: ${length}
89 | overlap: ${overlap}
90 | mono: True
91 | keepdim: True
92 | - _target_: datasets.wav.WAVDataset
93 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/PJS-24k/
94 | segment: ${length}
95 | overlap: ${overlap}
96 | mono: True
97 | keepdim: True
98 |
99 |
100 | collate_fn: null
101 |
102 | val_split: 0.01
103 | batch_size: 32
104 | num_workers: 8
105 | pin_memory: True
106 |
107 | callbacks:
108 | rich_progress_bar:
109 | _target_: pytorch_lightning.callbacks.RichProgressBar
110 |
111 | model_checkpoint:
112 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
113 | monitor: "valid_loss" # name of the logged metric which determines when model is improving
114 | save_top_k: 1 # save k best models (determined by above metric)
115 |     save_last: True # additionally always save model from last epoch
116 | mode: "min" # can be "max" or "min"
117 | verbose: False
118 | dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
119 | filename: '{epoch:02d}-{valid_loss:.3f}'
120 |
121 | model_summary:
122 | _target_: pytorch_lightning.callbacks.RichModelSummary
123 | max_depth: 2
124 |
125 | audio_samples_logger:
126 | _target_: main.module_vckt.SampleLogger
127 | num_items: 4
128 | channels: ${channels}
129 | sampling_rate: ${sampling_rate}
130 | length: ${length}
131 | sampling_steps: [100]
132 | # use_ema_model: True
133 | diffusion_sampler:
134 | _target_: audio_diffusion_pytorch.ADPM2Sampler
135 | rho: 1.0
136 | diffusion_schedule:
137 | _target_: audio_diffusion_pytorch.KarrasSchedule
138 | sigma_min: 0.0001
139 | sigma_max: 5.0
140 | rho: 9.0
141 |
142 | loggers:
143 | wandb:
144 | _target_: pytorch_lightning.loggers.wandb.WandbLogger
145 | project: ${oc.env:WANDB_PROJECT}
146 | entity: ${oc.env:WANDB_ENTITY}
147 | # offline: False # set True to store all logs only locally
148 | job_type: "train"
149 | group: ""
150 | save_dir: ${logs_dir}
151 |
152 | trainer:
153 | _target_: pytorch_lightning.Trainer
154 | gpus: -1 # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0`
155 | precision: 32 # Precision used for tensors, default `32`
156 | accelerator: null # `ddp` GPUs train individually and sync gradients, default `None`
157 | min_epochs: 0
158 | max_epochs: -1
159 | enable_model_summary: False
160 | log_every_n_steps: 1 # Logs metrics every N batches
161 | check_val_every_n_epoch: null
162 | val_check_interval: ${log_every_n_steps}
163 |
--------------------------------------------------------------------------------
/medley_vox.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 |
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset
8 | import librosa
9 |
10 | exclude_duet = {
11 | "CatMartino_IPromise",
12 | "TleilaxEnsemble_Late",
13 | "TleilaxEnsemble_MelancholyFlowers",
14 | }
15 |
16 |
17 | class MedleyVox(Dataset):
18 | """Dataset class for MedleyVox source separation tasks.
19 |
20 | Args:
21 | task (str): One of ``'unison'``, ``'duet'``, ``'main_vs_rest'`` or
22 |             ``'n_singing'`` :
23 | * ``'unison'`` for unison vocal separation.
24 | * ``'duet'`` for duet vocal separation.
25 |             * ``'main_vs_rest'`` for main vs. rest vocal separation.
26 | * ``'n_singing'`` for N-singing separation. We will use all of the duet, unison, and main vs. rest data.
27 |
28 | sample_rate (int) : The sample rate of the sources and mixtures.
29 |         n_src (int) : The number of sources in the mixture. Actually, this is fixed to 2 for our tasks. It needs to be specified for N-singing training (future work).
30 |         segment (int, optional) : The desired length of the sources and mixtures in seconds.
31 | """
32 |
33 | dataset_name = "MedleyVox"
34 |
35 | def __init__(
36 | self,
37 | root_dir,
38 | metadata_dir=None,
39 | task="duet",
40 | sample_rate=24000,
41 | n_src=2,
42 | segment=None,
43 | return_id=True,
44 | drop_duet=False,
45 | ):
46 | self.root_dir = root_dir # /path/to/data/test_medleyDB
47 | self.metadata_dir = metadata_dir # ./testset/testset_config
48 | self.task = task.lower()
49 | self.return_id = return_id
50 | # Get the csv corresponding to the task
51 | if self.task == "unison":
52 | self.total_segments_list = glob.glob(f"{self.root_dir}/unison/*/*")
53 | elif self.task == "duet":
54 | self.total_segments_list = glob.glob(f"{self.root_dir}/duet/*/*")
55 | if drop_duet:
56 | total_segments_list = list(
57 | filter(
58 | lambda x: x.split("/")[-2] not in exclude_duet,
59 | self.total_segments_list,
60 | )
61 | )
62 | print(
63 | f"Drop {len(self.total_segments_list) - len(total_segments_list)} duet songs."
64 | )
65 | self.total_segments_list = total_segments_list
66 |
67 | elif self.task == "main_vs_rest":
68 | self.total_segments_list = glob.glob(f"{self.root_dir}/rest/*/*")
69 | elif self.task == "n_singing":
70 | self.total_segments_list = (
71 | glob.glob(f"{self.root_dir}/unison/*/*")
72 | + glob.glob(f"{self.root_dir}/duet/*/*")
73 | + glob.glob(f"{self.root_dir}/rest/*/*")
74 | )
75 | self.segment = segment
76 | self.sample_rate = sample_rate
77 | self.n_src = n_src
78 |
79 | def __len__(self):
80 | return len(self.total_segments_list)
81 |
82 | def __getitem__(self, idx):
83 | song_name = self.total_segments_list[idx].split("/")[-2]
84 | segment_name = self.total_segments_list[idx].split("/")[-1]
85 | mixture_path = (
86 | f"{self.total_segments_list[idx]}/mix/{song_name} - {segment_name}.wav"
87 | )
88 | self.mixture_path = mixture_path
89 | sources_path_list = glob.glob(f"{self.total_segments_list[idx]}/gt/*.wav")
90 |
91 | if self.task == "main_vs_rest" or self.task == "n_singing":
92 | if os.path.exists(
93 | f"{self.metadata_dir}/V1_rest_vocals_only_config/{song_name}.json"
94 | ):
95 | metadata_json_path = (
96 | f"{self.metadata_dir}/V1_rest_vocals_only_config/{song_name}.json"
97 | )
98 | elif os.path.exists(
99 | f"{self.metadata_dir}/V2_vocals_only_config/{song_name}.json"
100 | ):
101 | metadata_json_path = (
102 | f"{self.metadata_dir}/V2_vocals_only_config/{song_name}.json"
103 | )
104 | else:
105 | print("main vs. rest metadata not found.")
106 | raise AttributeError
107 | with open(metadata_json_path, "r") as json_file:
108 | metadata_json = json.load(json_file)
109 |
110 | # Read sources
111 | sources_list = []
112 | ids = []
113 | if self.task == "main_vs_rest" or self.task == "n_singing":
114 | gt_main_name = metadata_json[segment_name]["main_vocal"]
115 | gt_source, sr = librosa.load(
116 | f"{self.total_segments_list[idx]}/gt/{gt_main_name} - {segment_name}.wav",
117 | sr=self.sample_rate,
118 | )
119 | gt_rest_list = metadata_json[segment_name]["other_vocals"]
120 | ids.append(f"{gt_main_name} - {segment_name}")
121 |
122 | rest_sources_list = []
123 | for other_vocal_name in gt_rest_list:
124 | s, sr = librosa.load(
125 | f"{self.total_segments_list[idx]}/gt/{other_vocal_name} - {segment_name}.wav",
126 | sr=self.sample_rate,
127 | )
128 | rest_sources_list.append(s)
129 | ids.append(f"{other_vocal_name} - {segment_name}")
130 | rest_sources_list = np.stack(rest_sources_list, axis=0)
131 | gt_rest = rest_sources_list.sum(0)
132 |
133 | sources_list.append(gt_source)
134 | sources_list.append(gt_rest)
135 | else: # self.task == 'unison' or self.task == 'duet'
136 | for i, source_path in enumerate(sources_path_list):
137 | s, sr = librosa.load(source_path, sr=self.sample_rate)
138 | sources_list.append(s)
139 | ids.append(os.path.basename(source_path).replace(".wav", ""))
140 | # Read the mixture
141 | mixture, sr = librosa.load(mixture_path, sr=self.sample_rate)
142 | # Convert to torch tensor
143 | mixture = torch.as_tensor(mixture, dtype=torch.float32)
144 | # Stack sources
145 | sources = np.vstack(sources_list)
146 | # Convert sources to tensor
147 | sources = torch.as_tensor(sources, dtype=torch.float32)
148 | if not self.return_id:
149 | return mixture, sources
150 | # 5400-34479-0005_4973-24515-0007.wav
151 | # id1, id2 = mixture_path.split("/")[-1].split(".")[0].split("_")
152 |
153 | return mixture, sources, ids
154 |
--------------------------------------------------------------------------------
/main/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import warnings
4 | from typing import Callable, List, Optional, Sequence
5 |
6 | import pkg_resources # type: ignore
7 | import pytorch_lightning as pl
8 | import rich.syntax
9 | import rich.tree
10 | import torch
11 | from omegaconf import DictConfig, OmegaConf
12 | from pytorch_lightning import Callback
13 | from pytorch_lightning.utilities import rank_zero_only
14 |
15 |
16 | def get_logger(name=__name__) -> logging.Logger:
17 | """Initializes multi-GPU-friendly python command line logger."""
18 |
19 | logger = logging.getLogger(name)
20 |
21 | # this ensures all logging levels get marked with the rank zero decorator
22 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup
23 | for level in (
24 | "debug",
25 | "info",
26 | "warning",
27 | "error",
28 | "exception",
29 | "fatal",
30 | "critical",
31 | ):
32 | setattr(logger, level, rank_zero_only(getattr(logger, level)))
33 |
34 | return logger
35 |
36 |
37 | log = get_logger(__name__)
38 |
39 |
40 | def extras(config: DictConfig) -> None:
41 | """Applies optional utilities, controlled by config flags.
42 | Utilities:
43 | - Ignoring python warnings
44 | - Rich config printing
45 | """
46 |
47 |     # disable python warnings if config.ignore_warnings is True
48 | if config.get("ignore_warnings"):
49 | log.info("Disabling python warnings! ")
50 | warnings.filterwarnings("ignore")
51 |
52 |     # pretty print config tree using Rich library if config.print_config is True
53 | if config.get("print_config"):
54 | log.info("Printing config tree with Rich! ")
55 | print_config(config, resolve=True)
56 |
57 |
58 | @rank_zero_only
59 | def print_config(
60 | config: DictConfig,
61 | print_order: Sequence[str] = (
62 | "datamodule",
63 | "model",
64 | "callbacks",
65 | "logger",
66 | "trainer",
67 | ),
68 | resolve: bool = True,
69 | ) -> None:
70 | """Prints content of DictConfig using Rich library and its tree structure.
71 | Args:
72 | config (DictConfig): Configuration composed by Hydra.
73 | print_order (Sequence[str], optional): Determines in what order config components are printed.
74 | resolve (bool, optional): Whether to resolve reference fields of DictConfig.
75 | """
76 |
77 | style = "dim"
78 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
79 |
80 |     queue = []
81 | 
82 |     for field in print_order:
83 |         queue.append(field) if field in config else log.info(
84 |             f"Field '{field}' not found in config"
85 |         )
86 | 
87 |     for field in config:
88 |         if field not in queue:
89 |             queue.append(field)
90 | 
91 |     for field in queue:
92 | branch = tree.add(field, style=style, guide_style=style)
93 |
94 | config_group = config[field]
95 | if isinstance(config_group, DictConfig):
96 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
97 | else:
98 | branch_content = str(config_group)
99 |
100 | branch.add(rich.syntax.Syntax(branch_content, "yaml"))
101 |
102 | rich.print(tree)
103 |
104 | with open("config_tree.log", "w") as file:
105 | rich.print(tree, file=file)
106 |
107 |
108 | @rank_zero_only
109 | def log_hyperparameters(
110 | config: DictConfig,
111 | model: pl.LightningModule,
112 | datamodule: pl.LightningDataModule,
113 | trainer: pl.Trainer,
114 | callbacks: List[pl.Callback],
115 | logger: List[pl.loggers.Logger],
116 | ) -> None:
117 | """Controls which config parts are saved by Lightning loggers.
118 |     Additionally saves:
119 | - number of model parameters
120 | """
121 |
122 | if not trainer.logger:
123 | return
124 |
125 | hparams = {}
126 |
127 | # choose which parts of hydra config will be saved to loggers
128 | hparams["model"] = config["model"]
129 |
130 | # save number of model parameters
131 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
132 | hparams["model/params/trainable"] = sum(
133 | p.numel() for p in model.parameters() if p.requires_grad
134 | )
135 | hparams["model/params/non_trainable"] = sum(
136 | p.numel() for p in model.parameters() if not p.requires_grad
137 | )
138 |
139 | hparams["datamodule"] = config["datamodule"]
140 | hparams["trainer"] = config["trainer"]
141 |
142 | if "seed" in config:
143 | hparams["seed"] = config["seed"]
144 | if "callbacks" in config:
145 | hparams["callbacks"] = config["callbacks"]
146 |
147 |     hparams["packages"] = get_packages_list()
148 |
149 | # send hparams to all loggers
150 | trainer.logger.log_hyperparams(hparams)
151 |
152 |
153 | def finish(
154 | config: DictConfig,
155 | model: pl.LightningModule,
156 | datamodule: pl.LightningDataModule,
157 | trainer: pl.Trainer,
158 | callbacks: List[pl.Callback],
159 | logger: List[pl.loggers.Logger],
160 | ) -> None:
161 | """Makes sure everything closed properly."""
162 |
163 | # without this sweeps with wandb logger might crash!
164 | for lg in logger:
165 | if isinstance(lg, pl.loggers.wandb.WandbLogger):
166 | import wandb
167 |
168 | wandb.finish()
169 |
170 |
171 | def get_packages_list() -> List[str]:
172 | return [f"{p.project_name}=={p.version}" for p in pkg_resources.working_set]
173 |
174 |
175 | def retry_if_error(fn: Callable, num_attempts: int = 10):
176 |     for attempt in range(num_attempts):
177 | try:
178 | return fn()
179 |         except Exception:
180 | print(f"Retrying, attempt {attempt+1}")
181 | pass
182 | return fn()
183 |
184 |
185 | class SavePytorchModelAndStopCallback(Callback):
186 | def __init__(self, path: str, attribute: Optional[str] = None):
187 | self.path = path
188 | self.attribute = attribute
189 |
190 | def on_train_start(self, trainer, pl_module):
191 | model, path = pl_module, self.path
192 | if self.attribute is not None:
193 | assert_message = "provided model attribute not found in pl_module"
194 | assert hasattr(pl_module, self.attribute), assert_message
195 |             model = getattr(
196 |                 pl_module, self.attribute
197 |             )
198 | # Make dir if not existent
199 | os.makedirs(os.path.split(path)[0], exist_ok=True)
200 | # Save model
201 | torch.save(model, path)
202 | log.info(f"PyTorch model saved at: {path}")
203 | # Stop trainer
204 | trainer.should_stop = True
205 |
--------------------------------------------------------------------------------
/main/inference.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from functools import partial
3 | from pathlib import Path
4 | from typing import List, Optional, Callable, Mapping
5 | import pytorch_lightning as pl
6 |
7 | import torch
8 | import torchaudio
9 | import tqdm
10 | from math import sqrt, ceil
11 | import hydra
12 | from audio_data_pytorch.utils import fractional_random_split
13 |
14 | from audio_diffusion_pytorch.diffusion import Schedule
15 | from torch.utils.data import DataLoader
16 | from torchaudio.datasets import VCTK_092
17 |
18 | from main.dataset import vctk_collate
19 | from main.module_vckt import Model
20 |
21 |
22 | class Separator(torch.nn.Module, abc.ABC):
23 | def __init__(self):
24 | super().__init__()
25 |
26 | @abc.abstractmethod
27 |     def separate(self, mixture, num_steps) -> Mapping[str, torch.Tensor]:
28 | ...
29 |
30 |
31 | class WeaklyMSDMSeparator(Separator):
32 | def __init__(self, stem_to_model: Mapping[str, Model], sigma_schedule, **kwargs):
33 | super().__init__()
34 | self.stem_to_model = stem_to_model
35 | self.separation_kwargs = kwargs
36 | self.sigma_schedule = sigma_schedule
37 |
38 | def separate(self, mixture: torch.Tensor, num_steps: int):
39 | stems = self.stem_to_model.keys()
40 | models = [self.stem_to_model[s] for s in stems]
41 | fns = [m.model.diffusion.diffusion.denoise_fn for m in models]
42 |
43 | # get device of models
44 | devices = {m.device for m in models}
45 | assert len(devices) == 1, devices
46 | (device,) = devices
47 |
48 | mixture = mixture.to(device)
49 | batch_size, _, length_samples = mixture.shape
50 |
51 | def denoise_fn(x, sigma):
52 |             xs = [x[:, i:i + 1] for i in range(len(fns))]
53 | xs = [fn(x, sigma=sigma) for fn, x in zip(fns, xs)]
54 | return torch.cat(xs, dim=1)
55 |
56 | y = separate_mixture(
57 | mixture=mixture,
58 | denoise_fn=denoise_fn,
59 | sigmas=self.sigma_schedule(num_steps, device),
60 | noises=torch.randn(batch_size, len(stems), length_samples).to(device),
61 | **self.separation_kwargs,
62 | )
63 | return {stem: y[:, i:i + 1, :] for i, stem in enumerate(stems)}
64 |
65 |
66 | # Algorithms ------------------------------------------------------------------
67 |
68 | def differential_with_dirac(x, sigma, denoise_fn, mixture, source_id=0):
69 | num_sources = x.shape[1]
70 | x[:, [source_id], :] = mixture - (x.sum(dim=1, keepdim=True) - x[:, [source_id], :])
71 | score = (x - denoise_fn(x, sigma=sigma)) / sigma
72 | scores = [score[:, si] for si in range(num_sources)]
73 | ds = [s - score[:, source_id] for s in scores]
74 | return torch.stack(ds, dim=1)
75 |
76 |
77 | def differential_with_gaussian(x, sigma, denoise_fn, mixture, gamma_fn=None):
78 | gamma = sigma if gamma_fn is None else gamma_fn(sigma)
79 | d = (x - denoise_fn(x, sigma=sigma)) / sigma
80 | d = d - sigma / (2* gamma ** 2) * (mixture - x.sum(dim=[1], keepdim=True))
81 | return d
82 |
83 |
84 | @torch.no_grad()
85 | def separate_mixture(
86 | mixture: torch.Tensor,
87 | denoise_fn: Callable,
88 | sigmas: torch.Tensor,
89 | noises: Optional[torch.Tensor],
90 | differential_fn: Callable = differential_with_dirac,
91 | s_churn: float = 40.0, # > 0 to add randomness
92 | num_resamples: int = 2,
93 | use_tqdm: bool = False,
94 | ):
95 | # Set initial noise
96 | x = sigmas[0] * noises # [batch_size, num-sources, sample-length]
97 |
98 | vis_wrapper = tqdm.tqdm if use_tqdm else lambda x: x
99 | for i in vis_wrapper(range(len(sigmas) - 1)):
100 | sigma, sigma_next = sigmas[i], sigmas[i + 1]
101 |
102 | for r in range(num_resamples):
103 | # Inject randomness
104 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1)
105 | sigma_hat = sigma * (gamma + 1)
106 | x = x + torch.randn_like(x) * (sigma_hat ** 2 - sigma ** 2) ** 0.5
107 |
108 | # Compute conditioned derivative
109 | d = differential_fn(mixture=mixture, x=x, sigma=sigma_hat, denoise_fn=denoise_fn)
110 |
111 | # Update integral
112 | x = x + d * (sigma_next - sigma_hat)
113 |
114 | # Renoise if not last resample step
115 | if r < num_resamples - 1:
116 | x = x + sqrt(sigma ** 2 - sigma_next ** 2) * torch.randn_like(x)
117 |
118 | return x.cpu().detach()
119 |
120 |
121 | # -----------------------------------------------------------------------------
122 | def save_separation(
123 | separated_tracks: Mapping[str, torch.Tensor],
124 | sample_rate: int,
125 | chunk_path: Path,
126 | ):
127 | for stem, separated_track in separated_tracks.items():
128 | torchaudio.save(chunk_path / f"{stem}.wav", separated_track.cpu(), sample_rate=sample_rate)
129 |
130 |
131 | if __name__ == '__main__':
132 |
133 | pl.seed_everything(12345)
134 |
135 | with hydra.initialize(config_path=".."):
136 | cfg = hydra.compose(config_name="exp/base_vctk_k_none.yaml")
137 |
138 | model = hydra.utils.instantiate(cfg['model']).cuda()
139 | vctk_checkpoint = torch.load('/home/emilian/PycharmProjects/multi-speaker-diff-sep/data/epoch=117-valid_loss=0.015.ckpt',
140 | map_location='cuda')
141 | model.load_state_dict(vctk_checkpoint['state_dict'])
142 | diffusion_schedule = hydra.utils.instantiate(cfg['callbacks']['audio_samples_logger']['diffusion_schedule']).cuda()
143 | separator = WeaklyMSDMSeparator(stem_to_model = {"voice_1": model.cuda(),
144 | "voice_2": model.cuda()},
145 | sigma_schedule=diffusion_schedule,
146 | use_tqdm=True)
147 |
148 | dataset = VCTK_092(root="/home/emilian/PycharmProjects/multi-speaker-diff-sep/data/vctk", download=True)
149 | split = [1.0 - 0.01, 0.01]
150 | _, data_val = fractional_random_split(dataset, split)
151 | dataloader = DataLoader(dataset=data_val,
152 | batch_size=2,
153 | num_workers=0,
154 | pin_memory=False,
155 | drop_last=True,
156 | shuffle=False,
157 | collate_fn=partial(vctk_collate, mix_k=None))
158 | data_iter = iter(dataloader)
159 | _ = next(data_iter)
160 | _ = next(data_iter)
161 | batch = next(data_iter)
162 |
163 | mix = batch.sum(dim=0, keepdim=True)
164 | torchaudio.save("./source_0.wav", batch[0].cpu(), sample_rate=22050)
165 | torchaudio.save("./source_1.wav", batch[1].cpu(), sample_rate=22050)
166 | separated_tracks = separator.separate(mixture=mix, num_steps=500)
167 | save_separation({k: v[0] for k, v in separated_tracks.items()},
168 | sample_rate=22050,
169 | chunk_path=Path('.'))
170 | torchaudio.save("./mix.wav", mix[0].cpu(), sample_rate=22050)
171 |
--------------------------------------------------------------------------------
/mlc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from typing import List, Callable
6 | from torchaudio.transforms import Spectrogram
7 | from functools import reduce, partial
8 |
9 |
10 | class MLC(nn.Module):
11 | def __init__(
12 | self,
13 | n_fft: int,
14 | sr: int,
15 | gammas: List[float],
16 | hop_size: int,
17 | Hipass_f: float = 50,
18 | Lowpass_t=0.24,
19 | **kwargs,
20 | ):
21 | super().__init__()
22 | self.n_fft = n_fft
23 | self.hop_size = hop_size
24 |
25 | self.stft = Spectrogram(
26 | n_fft=n_fft,
27 | hop_length=hop_size,
28 | power=2,
29 | normalized=True,
30 | onesided=False,
31 | **kwargs,
32 | )
33 | hpi = int(Hipass_f * n_fft / sr) + 1
34 | lpi = int(Lowpass_t * sr / 1000) + 1
35 | layers = [lambda ceps, spec: (ceps, spec ** gammas[0])]
36 |
37 | def gamm_trsfm(x, g, i):
38 | x = torch.fft.fft(x, norm="ortho").real
39 | x[..., :i] = x[..., -i:] = 0
40 | return F.relu(x) ** g
41 |
42 | self.num_spec = 1
43 | self.num_ceps = 0
44 |
45 | for d, gamma in enumerate(gammas[1:]):
46 | if d % 2:
47 |                 layers.append(lambda ceps, _, g=gamma: (ceps, gamm_trsfm(ceps, g, hpi)))  # bind gamma at definition time
48 | self.num_spec += 1
49 | else:
50 |                 layers.append(lambda _, spec, g=gamma: (gamm_trsfm(spec, g, lpi), spec))  # bind gamma at definition time
51 | self.num_ceps += 1
52 |
53 | self.compute = partial(reduce, lambda x, f: f(*x), layers)
54 |
55 | def forward(self, x):
56 | return self.compute((None, self.stft(x).transpose(-1, -2)))
57 |
58 |
59 | class Sparse_Pitch_Profile(nn.Module):
60 | def __init__(self, in_channels, sr, harms_range=24, division=1, norm=False):
61 | """
62 |
63 | Parameters
64 | ----------
65 | in_channels: int
66 | window size
67 | sr: int
68 | sample rate
69 | harms_range: int
70 | The extended area above (or below) the piano pitch range (in semitones)
71 | 25 : though somewhat larger, to ensure the coverage is large enough (if division=1, 24 is sufficient)
72 | division: int
73 | The division number for filterbank frequency resolution. The frequency resolution is 1 / division (semitone)
74 | norm: bool
75 | If set to True, normalize each filterbank so the weight of each filterbank sum to 1.
76 | """
77 | super().__init__()
78 | step = 1 / division
79 | # midi_num shape = (88 + harms_range) * division + 2
80 |         # this implementation makes sure that if we group midi_num with a size of division,
81 |         # each group will be centered at the piano pitch numbers and the extra pitch range
82 |         # E.g., division = 2, midi_num = [20.25, 20.75, 21.25, ....]
83 |         # division = 3, midi_num = [20.33, 20.67, 21, 21.33, ...]
84 | midi_num = np.arange(
85 | 20.5 - step / 2 - harms_range, 108.5 + step + harms_range, step
86 | )
87 | self.midi_num = midi_num
88 |
89 | fd = 440 * np.power(2, (midi_num - 69) / 12)
90 | self.fd = fd
91 |
92 | self.effected_dim = in_channels // 2 + 1
93 | # // 2 : the spectrum/ cepstrum are symmetric
94 |
95 | x = np.arange(self.effected_dim)
96 | freq_f = x * sr / in_channels
97 | freq_t = sr / x[1:]
98 | # avoid explosion; x[0] is always 0 for cepstrum
99 |
100 | inter_value = np.array([0, 1, 0])
101 | idxs = np.digitize(freq_f, fd)
102 |
103 | cols, rows, values = [], [], []
104 | for i in range(harms_range * division, (88 + 2 * harms_range) * division):
105 | idx = np.where((idxs == i + 1) | (idxs == i + 2))[0]
106 | c = idx
107 | r = np.broadcast_to(i - harms_range * division, idx.shape)
108 | x = np.interp(freq_f[idx], fd[i : i + 3], inter_value).astype(np.float32)
109 | if norm and len(idx):
110 | # x /= (fd[i + 2] - fd[i]) / sr * in_channels
111 | x /= x.sum() # energy normalization
112 |
113 | if len(idx) == 0 and len(values) and len(values[-1]):
114 |             # low resolution in the lower frequency (for spec) / higher frequency (for ceps),
115 | # some filterbanks will not get any bin index, so we copy the indexes from the previous iteration
116 | c = cols[-1].copy()
117 | r = rows[-1].copy()
118 | r[:] = i - harms_range * division
119 | x = values[-1].copy()
120 |
121 | cols.append(c)
122 | rows.append(r)
123 | values.append(x)
124 |
125 | cols, rows, values = (
126 | np.concatenate(cols),
127 | np.concatenate(rows),
128 | np.concatenate(values),
129 | )
130 | self.filters_f_idx = (rows, cols)
131 | self.filters_f_values = nn.Parameter(torch.tensor(values), requires_grad=False)
132 |
133 | idxs = np.digitize(freq_t, fd)
134 | cols, rows, values = [], [], []
135 | for i in range((88 + harms_range) * division - 1, -1, -1):
136 | idx = np.where((idxs == i + 1) | (idxs == i + 2))[0]
137 | c = idx + 1
138 | r = np.broadcast_to(i, idx.shape)
139 | x = np.interp(freq_t[idx], fd[i : i + 3], inter_value).astype(np.float32)
140 | if norm and len(idx):
141 | # x /= (1 / fd[i] - 1 / fd[i + 2]) * sr
142 | x /= x.sum()
143 |
144 | if len(idx) == 0 and len(values) and len(values[-1]):
145 | c = cols[-1].copy()
146 | r = rows[-1].copy()
147 | r[:] = i
148 | x = values[-1].copy()
149 |
150 | cols.append(c)
151 | rows.append(r)
152 | values.append(x)
153 |
154 | cols, rows, values = (
155 | np.concatenate(cols),
156 | np.concatenate(rows),
157 | np.concatenate(values),
158 | )
159 | self.filters_t_idx = (rows, cols)
160 | self.filters_t_values = nn.Parameter(torch.tensor(values), requires_grad=False)
161 | self.filter_size = torch.Size(
162 | ((88 + harms_range) * division, self.effected_dim)
163 | )
164 |
165 | def forward(self, ceps, spec):
166 | ceps, spec = ceps[..., : self.effected_dim], spec[..., : self.effected_dim]
167 | batch_dim, steps, _ = ceps.size()
168 | filter_f = torch.sparse_coo_tensor(
169 | self.filters_f_idx, self.filters_f_values, self.filter_size
170 | )
171 | filter_t = torch.sparse_coo_tensor(
172 | self.filters_t_idx, self.filters_t_values, self.filter_size
173 | )
174 | ppt = filter_t @ ceps.transpose(0, 2).contiguous().view(self.effected_dim, -1)
175 | ppf = filter_f @ spec.transpose(0, 2).contiguous().view(self.effected_dim, -1)
176 | return ppt.view(-1, steps, batch_dim).transpose(0, 2), ppf.view(
177 | -1, steps, batch_dim
178 | ).transpose(0, 2)
179 |
--------------------------------------------------------------------------------
/main/module_vckt.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, List, Optional
2 |
3 | import librosa
4 | import plotly.graph_objs as go
5 | import pytorch_lightning as pl
6 | import torch
7 | import torchaudio
8 | import wandb
9 | from audio_data_pytorch.utils import fractional_random_split
10 | from audio_diffusion_pytorch import AudioDiffusionModel, Sampler, Schedule
11 | from einops import rearrange
12 |
13 | # from ema_pytorch import EMA
14 | from pytorch_lightning import Callback, Trainer
15 | from pytorch_lightning.loggers import WandbLogger
16 | from torch import Tensor, nn
17 | from torch.utils.data import DataLoader
18 |
19 | """ Model """
20 |
21 |
22 | class Model(pl.LightningModule):
23 | def __init__(
24 | self,
25 | lr: float,
26 | lr_beta1: float,
27 | lr_beta2: float,
28 | lr_eps: float,
29 | lr_weight_decay: float,
30 | # ema_beta: float,
31 | # ema_power: float,
32 | model: nn.Module,
33 | ):
34 | super().__init__()
35 | self.lr = lr
36 | self.lr_beta1 = lr_beta1
37 | self.lr_beta2 = lr_beta2
38 | self.lr_eps = lr_eps
39 | self.lr_weight_decay = lr_weight_decay
40 | self.model = model
41 | # self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power)
42 |
43 | @property
44 | def device(self):
45 | return next(self.model.parameters()).device
46 |
47 | def configure_optimizers(self):
48 | optimizer = torch.optim.AdamW(
49 | list(self.model.parameters()),
50 | lr=self.lr,
51 | betas=(self.lr_beta1, self.lr_beta2),
52 | eps=self.lr_eps,
53 | weight_decay=self.lr_weight_decay,
54 | )
55 | return optimizer
56 |
57 | def training_step(self, batch, batch_idx):
58 | waveforms = batch
59 | loss = self.model(waveforms)
60 | self.log("train_loss", loss)
61 | # Update EMA model and log decay
62 | # self.model_ema.update()
63 | # self.log("ema_decay", self.model_ema.get_current_decay())
64 | return loss
65 |
66 | def validation_step(self, batch, batch_idx):
67 | waveforms = batch
68 | loss = self.model(waveforms)
69 | self.log("valid_loss", loss)
70 | return loss
71 |
72 |
73 | """ Datamodule """
74 |
75 |
76 | class Datamodule(pl.LightningDataModule):
77 | def __init__(
78 | self,
79 | dataset,
80 | collate_fn,
81 | *,
82 | val_split: float,
83 | batch_size: int,
84 | num_workers: int,
85 | pin_memory: bool = False,
86 | **kwargs: int,
87 | ) -> None:
88 | super().__init__()
89 | self.dataset = dataset
90 | self.collate_fn = collate_fn
91 | self.val_split = val_split
92 | self.batch_size = batch_size
93 | self.num_workers = num_workers
94 | self.pin_memory = pin_memory
95 | self.data_train: Any = None
96 | self.data_val: Any = None
97 |
98 | def setup(self, stage: Any = None) -> None:
99 | split = [1.0 - self.val_split, self.val_split]
100 | self.data_train, self.data_val = fractional_random_split(self.dataset, split)
101 |
102 | def train_dataloader(self) -> DataLoader:
103 | return DataLoader(
104 | dataset=self.data_train,
105 | collate_fn=self.collate_fn,
106 | batch_size=self.batch_size,
107 | num_workers=self.num_workers,
108 | pin_memory=self.pin_memory,
109 | drop_last=True,
110 | shuffle=True,
111 | )
112 |
113 | def val_dataloader(self) -> DataLoader:
114 | return DataLoader(
115 | dataset=self.data_val,
116 | collate_fn=self.collate_fn,
117 | batch_size=self.batch_size,
118 | num_workers=self.num_workers,
119 | pin_memory=self.pin_memory,
120 | drop_last=True,
121 | shuffle=True,
122 | )
123 |
124 |
125 | """ Callbacks """
126 |
127 |
128 | def get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]:
129 | """Safely get Weights&Biases logger from Trainer."""
130 |
131 | if isinstance(trainer.logger, WandbLogger):
132 | return trainer.logger
133 |
134 | for logger in trainer.loggers:
135 | if isinstance(logger, WandbLogger):
136 | return logger
137 |
138 | print("WandbLogger not found.")
139 | return None
140 |
141 |
142 | def log_wandb_audio_batch(
143 | logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = ""
144 | ):
145 | num_items = samples.shape[0]
146 | samples = rearrange(samples, "b c t -> b t c").detach().cpu().numpy()
147 | logger.log(
148 | {
149 | f"sample_{idx}_{id}": wandb.Audio(
150 | samples[idx],
151 | caption=caption,
152 | sample_rate=sampling_rate,
153 | )
154 | for idx in range(num_items)
155 | }
156 | )
157 |
158 |
159 | def log_wandb_audio_spectrogram(
160 | logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = ""
161 | ):
162 | num_items = samples.shape[0]
163 | samples = samples.detach().cpu()
164 | transform = torchaudio.transforms.MelSpectrogram(
165 | sample_rate=sampling_rate,
166 | n_fft=1024,
167 | hop_length=512,
168 | n_mels=80,
169 | center=True,
170 | norm="slaney",
171 | )
172 |
173 | def get_spectrogram_image(x):
174 | spectrogram = transform(x[0])
175 | image = librosa.power_to_db(spectrogram)
176 | trace = [go.Heatmap(z=image, colorscale="viridis")]
177 | layout = go.Layout(
178 | yaxis=dict(title="Mel Bin (Log Frequency)"),
179 | xaxis=dict(title="Frame"),
180 | title_text=caption,
181 | title_font_size=10,
182 | )
183 | fig = go.Figure(data=trace, layout=layout)
184 | return fig
185 |
186 | logger.log(
187 | {
188 | f"mel_spectrogram_{idx}_{id}": get_spectrogram_image(samples[idx])
189 | for idx in range(num_items)
190 | }
191 | )
192 |
193 |
194 | class SampleLogger(Callback):
195 | def __init__(
196 | self,
197 | num_items: int,
198 | channels: int,
199 | sampling_rate: int,
200 | length: int,
201 | sampling_steps: List[int],
202 | diffusion_schedule: Schedule,
203 | diffusion_sampler: Sampler,
204 | # use_ema_model: bool,
205 | ) -> None:
206 | self.num_items = num_items
207 | self.channels = channels
208 | self.sampling_rate = sampling_rate
209 | self.length = length
210 | self.sampling_steps = sampling_steps
211 | self.diffusion_schedule = diffusion_schedule
212 | self.diffusion_sampler = diffusion_sampler
213 | # self.use_ema_model = use_ema_model
214 | self.log_next = False
215 |
216 | def on_validation_epoch_start(self, trainer, pl_module):
217 | self.log_next = True
218 |
219 | def on_validation_batch_start(
220 | self, trainer, pl_module, batch, batch_idx, dataloader_idx
221 | ):
222 | if self.log_next:
223 | self.log_sample(trainer, pl_module, batch)
224 | self.log_next = False
225 |
226 | @torch.no_grad()
227 | def log_sample(self, trainer, pl_module, batch):
228 | is_train = pl_module.training
229 | if is_train:
230 | pl_module.eval()
231 |
232 | wandb_logger = get_wandb_logger(trainer).experiment
233 |
234 | diffusion_model = pl_module.model
235 | # if self.use_ema_model:
236 | # diffusion_model = pl_module.model_ema.ema_model
237 |
238 | # Draw the starting diffusion noise
239 | noise = torch.randn(
240 | (self.num_items, self.channels, self.length), device=pl_module.device
241 | )
242 |
243 | for steps in self.sampling_steps:
244 | samples = diffusion_model.sample(
245 | noise=noise,
246 | sampler=self.diffusion_sampler,
247 | sigma_schedule=self.diffusion_schedule,
248 | num_steps=steps,
249 | )
250 | log_wandb_audio_batch(
251 | logger=wandb_logger,
252 | id="sample",
253 | samples=samples,
254 | sampling_rate=self.sampling_rate,
255 | caption=f"Sampled in {steps} steps",
256 | )
257 | log_wandb_audio_spectrogram(
258 | logger=wandb_logger,
259 | id="sample",
260 | samples=samples,
261 | sampling_rate=self.sampling_rate,
262 | caption=f"Sampled in {steps} steps",
263 | )
264 |
265 | if is_train:
266 | pl_module.train()
267 |
--------------------------------------------------------------------------------
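The Datamodule above holds a single dataset, carves off a fixed validation fraction in setup(), and builds train/val loaders with a shared collate function. Below is a minimal, self-contained sketch of that split-then-loader flow; torch.utils.data.random_split stands in for fractional_random_split, and the dataset and collate function are toy placeholders, not part of this repository:

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Toy stand-in for the waveform dataset: 100 mono clips of 1 s at 24 kHz.
dataset = TensorDataset(torch.randn(100, 1, 24000))

# Same idea as Datamodule.setup(): split one dataset by a validation fraction.
val_split = 0.1
n_val = int(len(dataset) * val_split)
data_train, data_val = random_split(dataset, [len(dataset) - n_val, n_val])

def collate_fn(batch):
    # The real collate_fn comes from the datasets submodule; here we just stack waveforms.
    return torch.stack([item[0] for item in batch])

train_loader = DataLoader(data_train, batch_size=8, collate_fn=collate_fn,
                          shuffle=True, drop_last=True)
val_loader = DataLoader(data_val, batch_size=8, collate_fn=collate_fn,
                        shuffle=True, drop_last=True)  # the module above also shuffles validation

print(len(train_loader), len(val_loader))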
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import hydra
3 | import os
4 | import torch
5 | import pathlib
6 | from tqdm import tqdm
7 | from typing import Callable, Optional
8 | from math import sqrt
9 | from asteroid.metrics import get_metrics
10 | from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
11 | import soundfile as sf
12 |
13 | from medley_vox import MedleyVox
14 |
15 | COMPUTE_METRICS = ["si_sdr", "sdr"]
16 |
17 |
18 | @torch.no_grad()
19 | def separate_mixture(
20 | mixture: torch.Tensor,
21 | noises: torch.Tensor,
22 | denoise_fn: Callable,
23 | sigmas: torch.Tensor,
24 | cond: Optional[torch.Tensor] = None,
25 | cond_index: int = 0,
26 | s_churn: float = 40.0, # > 0 to add randomness
27 | num_resamples: int = 2,
28 | use_tqdm: bool = False,
29 | gaussian: bool = False,
30 | ):
31 | # Set initial noise
32 | x = sigmas[0] * noises  # [num_sources, num_samples]
33 |
34 | for i in tqdm(range(len(sigmas) - 1), disable=not use_tqdm):
35 | sigma, sigma_next = sigmas[i], sigmas[i + 1]
36 |
37 | for r in range(num_resamples):
38 | # Inject randomness
39 | gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
40 | sigma_hat = sigma * (gamma + 1)
41 | x = x + torch.randn_like(x) * (sigma_hat**2 - sigma**2) ** 0.5
42 |
43 | if cond is not None:
44 | noisy_cond = cond + torch.randn_like(cond) * sigma
45 | x[:, :cond_index] = noisy_cond
46 |
47 | # Compute conditioned derivative
48 | if not gaussian:
49 | x[:1] = mixture - x[1:].sum(dim=0, keepdim=True)
50 | score = (x - denoise_fn(x, sigma=sigma)) / sigma
51 | if gaussian:
52 | d = score + sigma / (2 * gamma**2) * (mixture - x.sum(dim=0))
53 | x += d * (sigma_next - sigma_hat)
54 | else:
55 | ds = score[1:] - score[:1]
56 |
57 | # Update integral
58 | x[1:] += ds * (sigma_next - sigma_hat)
59 |
60 | # Renoise if not last resample step
61 | if r < num_resamples - 1:
62 | x = x + torch.sqrt(sigma**2 - sigma_next**2) * torch.randn_like(x)
63 |
64 | return x
65 |
66 |
67 | if __name__ == "__main__":
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument("config", type=str, help="Path to config file")
70 | parser.add_argument("ckpt", type=str, help="Path to checkpoint file")
71 | parser.add_argument("medleyvox", type=str, help="Path to MedleyVox dataset")
72 | parser.add_argument("-T", default=100, type=int, help="Number of diffusion steps")
73 | parser.add_argument("-S", default=40.0, type=float, help="S churn")
74 | parser.add_argument("--out", type=str, help="Output directory")
75 | parser.add_argument("--cond", action="store_true", help="Use conditioning")
76 | parser.add_argument(
77 | "--self-cond", action="store_true", help="Use self conditioning"
78 | )
79 | parser.add_argument("--hop-length", type=int, help="Hop length")
80 | parser.add_argument("--window", type=int, help="Window size")
81 | parser.add_argument("--full-duet", action="store_true", help="Drop duet songs")
82 | parser.add_argument("--retry", type=int, default=0, help="Retry")
83 | parser.add_argument("--outer-retry", type=int, default=0, help="Outer retry")
84 |
85 | args = parser.parse_args()
86 |
87 | config_path, config_name = os.path.split(args.config)
88 |
89 | with hydra.initialize(config_path=config_path):
90 | cfg = hydra.compose(config_name=config_name)
91 |
92 | sr = cfg.sampling_rate
93 | length = cfg.length
94 |
95 | model = hydra.utils.instantiate(cfg.model)
96 | vctk_checkpoint = torch.load(
97 | args.ckpt,
98 | map_location="cpu",
99 | )
100 | model.load_state_dict(vctk_checkpoint["state_dict"])
101 | diffusion_schedule = hydra.utils.instantiate(
102 | cfg.callbacks.audio_samples_logger.diffusion_schedule
103 | )
104 |
105 | model = model.cuda()
106 | model.eval()
107 | diffusion_schedule = diffusion_schedule.cuda()
108 |
109 | inner_denoise_fn = model.model.diffusion.diffusion.denoise_fn
110 |
111 | def denoise_fn(x, sigma):
112 | x = x.unsqueeze(1)
113 | return inner_denoise_fn(x, sigma=sigma).squeeze(1)
114 |
115 | sigmas = diffusion_schedule(args.T, "cuda")
116 |
117 | dataset = MedleyVox(
118 | args.medleyvox,
119 | sample_rate=sr,
120 | drop_duet=not args.full_duet,
121 | )
122 |
123 | hop_length = length // 2 if args.hop_length is None else args.hop_length
124 | window_size = length if args.window is None else args.window
125 | loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
126 | s_churn = args.S
127 |
128 | if args.cond:
129 | print(window_size / sr, hop_length / sr)
130 |
131 | accumulate_metrics_mean = {}
132 |
133 | with tqdm(dataset) as pbar:
134 | for mix_num, (x_cpu, y_cpu, ids) in enumerate(pbar):
135 | x = x_cpu.cuda()
136 | y = y_cpu.cuda()
137 | n = y.shape[0]
138 |
139 | outer_trials = []
140 | for i in range(args.outer_retry + 1):
141 | if args.cond:
142 | cond = torch.zeros(n, window_size).cuda()
143 | sub_m = torch.zeros(window_size).cuda()
144 | result = []
145 | for sub_x, sub_y in zip(
146 | torch.split(x, hop_length), torch.split(y, hop_length, 1)
147 | ):
148 | noise = torch.randn(n, window_size).cuda()
149 | overlap_size = window_size - sub_x.numel()
150 | if overlap_size > 0:
151 | sub_m = torch.cat([sub_m[-overlap_size:], sub_x])
152 | cond = cond[:, -overlap_size:]
153 | else:
154 | sub_m = sub_x
155 | cond = None
156 |
157 | trials = []
158 | for i in range(args.retry + 1):
159 | pred = separate_mixture(
160 | sub_m,
161 | noise,
162 | denoise_fn,
163 | sigmas,
164 | s_churn=s_churn,
165 | cond=cond,
166 | cond_index=overlap_size,
167 | use_tqdm=False,
168 | gaussian=False,
169 | )
170 | sub_pred = pred[:, -sub_x.numel() :]
171 |
172 | if args.retry > 0:
173 | loss, align_pred = loss_func(
174 | sub_pred.unsqueeze(0),
175 | sub_y.unsqueeze(0),
176 | return_est=True,
177 | )
178 | trials.append((loss, align_pred.squeeze()))
179 | else:
180 | trials.append((0, sub_pred))
181 |
182 | _, sub_pred = min(trials, key=lambda x: x[0])
183 | cond = torch.cat(
184 | ([] if cond is None else [cond])
185 | + [sub_pred if args.self_cond else sub_y],
186 | dim=1,
187 | )
188 | result.append(sub_pred)
189 |
190 | result = torch.cat(result, dim=1)
191 | else:
192 | original_length = x.numel()
193 | padding = length - (original_length % length)
194 | if padding < length:
195 | x = torch.cat([x, x.new_zeros(padding)], dim=0)
196 |
197 | result = separate_mixture(
198 | x,
199 | torch.randn(n, x.numel()).cuda(),
200 | denoise_fn,
201 | sigmas,
202 | s_churn=s_churn,
203 | use_tqdm=False,
204 | )[:, :original_length]
205 |
206 | loss, reordered_sources = loss_func(
207 | result.unsqueeze(0), y.unsqueeze(0), return_est=True
208 | )
209 | outer_trials.append((loss, reordered_sources))
210 |
211 | _, reordered_sources = min(outer_trials, key=lambda x: x[0])
212 | est = reordered_sources.squeeze().cpu().numpy()
213 |
214 | utt_metrics = get_metrics(
215 | x_cpu.numpy(),
216 | y_cpu.numpy(),
217 | est,
218 | sample_rate=sr,
219 | metrics_list=COMPUTE_METRICS,
220 | )
221 |
222 | # calculate improvement
223 | for metric in COMPUTE_METRICS:
224 | v = utt_metrics.pop("input_" + metric)
225 | utt_metrics[metric + "i"] = utt_metrics[metric] - v
226 |
227 | for k, v in utt_metrics.items():
228 | if k not in accumulate_metrics_mean:
229 | accumulate_metrics_mean[k] = 0
230 |
231 | accumulate_metrics_mean[k] += (v - accumulate_metrics_mean[k]) / (
232 | mix_num + 1
233 | )
234 |
235 | pbar.set_postfix(accumulate_metrics_mean)
236 |
237 | if args.out is not None:
238 | out_dir = pathlib.Path(args.out) / f"medleyvox_{mix_num}"
239 | out_dir.mkdir(parents=True, exist_ok=True)
240 |
241 | sf.write(
242 | out_dir / "mixture.wav",
243 | x_cpu.numpy(),
244 | sr,
245 | "PCM_16",
246 | )
247 |
248 | for i, s in enumerate(est):
249 | out_path = out_dir / f"{ids[i]}.wav"
250 | sf.write(out_path, s, sr, "PCM_16")
251 |
252 | print(accumulate_metrics_mean)
253 |
--------------------------------------------------------------------------------
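separate_mixture above runs an EDM-style sampler in which source 0 is never a free variable: at every step it is pinned to the mixture minus the sum of the other sources, and only the remaining sources are integrated with the score difference. The stripped-down toy below is not the repository's method verbatim: it drops the churn, resampling and conditioning logic, and replaces the trained network with a closed-form denoiser for an i.i.d. standard-normal prior, just to make the constrained update runnable and check mixture consistency at the end:

import torch

def toy_denoiser(x, sigma):
    # Stand-in for the trained model: the MMSE denoiser under an i.i.d. N(0, 1) prior,
    # E[s | x] = x / (1 + sigma**2). It only exists to make the loop runnable.
    return x / (1 + sigma**2)

torch.manual_seed(0)
num_sources, length = 2, 16
mixture = torch.randn(1, length)

sigmas = torch.linspace(1.0, 0.0, steps=20)
x = sigmas[0] * torch.randn(num_sources, length)

for i in range(len(sigmas) - 1):
    sigma, sigma_next = sigmas[i], sigmas[i + 1]
    # Dirac constraint, as in separate_mixture: pin source 0 to the residual of the mixture.
    x[:1] = mixture - x[1:].sum(dim=0, keepdim=True)
    score = (x - toy_denoiser(x, sigma)) / sigma
    # Only the free sources are updated, using the score difference (non-Gaussian branch).
    x[1:] += (score[1:] - score[:1]) * (sigma_next - sigma)

x[:1] = mixture - x[1:].sum(dim=0, keepdim=True)
print(torch.allclose(x.sum(dim=0, keepdim=True), mixture))  # sources add up to the mixture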
/eval_nmf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import hydra
3 | import os
4 | import torch
5 | import pathlib
6 | from tqdm import tqdm
7 | from typing import Callable, Optional
8 | from itertools import combinations, accumulate
9 | from functools import reduce
10 | import numpy as np
11 | from math import sqrt
12 | from asteroid.metrics import get_metrics
13 | from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr
14 | import soundfile as sf
15 | from torchnmf.nmf import NMF
16 | import torch.nn.functional as F
17 | from torchaudio.transforms import Spectrogram, InverseSpectrogram
18 |
19 | from mlc import Sparse_Pitch_Profile, MLC
20 | from medley_vox import MedleyVox
21 | from eval import COMPUTE_METRICS
22 |
23 |
24 | def get_harmonics(n_fft, freqs, window_fn=torch.hann_window):
25 | window = window_fn(n_fft)
26 | tmp = []
27 | for f in freqs:
28 | h = torch.arange(int(0.5 * sr / f)) + 1
29 | # h = h[:40]
30 | ch = (
31 | f
32 | * 27
33 | / 4
34 | * (
35 | torch.exp(-1j * torch.pi * h)
36 | + 2 * (1 + 2 * torch.exp(-1j * torch.pi * h)) / (1j * torch.pi * h)
37 | - 6 * (1 - torch.exp(-1j * torch.pi * h)) / (1j * torch.pi * h) ** 2
38 | )
39 | )
40 | ch /= torch.abs(ch).max()
41 | t = torch.arange(n_fft) / sr
42 | eu = ch @ torch.exp(2j * torch.pi * h[:, None] * f * t)
43 | eu /= torch.linalg.norm(eu)
44 | tmp.append(torch.abs(torch.fft.fft(eu * window)[: n_fft // 2 + 1]))
45 |
46 | noise = torch.ones(n_fft // 2 + 1)
47 | tmp.append(noise)
48 |
49 | W_f0 = torch.stack(tmp).T
50 | W_f0 /= W_f0.max(0).values
51 | return W_f0
52 |
53 |
54 | def get_streams(
55 | cfp: torch.Tensor,
56 | freqs: np.ndarray,
57 | n: int = 2,
58 | thresh_ratio: float = 0.1,
59 | ):
60 | top_values, top_indices = torch.topk(cfp, n * 3, dim=1)
61 | thresh = top_values.max() * thresh_ratio
62 | top_freqs = freqs[top_indices]
63 |
64 | # remove zero values
65 | top_zipped = [
66 | tuple(x[1] for x in filter(lambda x: x[0] > thresh, zip(*pair)))
67 | for pair in zip(top_values.tolist(), top_freqs.tolist())
68 | ]
69 |
70 | # init states
71 |
72 | def cases(states: tuple, possible_states: tuple):
73 | possible_states = tuple(sorted(possible_states))
74 | is_none = lambda x: x is None
75 | none_mapper = map(is_none, states)
76 | none_states = list(filter(is_none, states))
77 | valid_states = list(filter(lambda x: not is_none(x), states))
78 |
79 | if len(valid_states) == 0:
80 | return possible_states[:n] + (None,) * max(0, n - len(possible_states))
81 | elif len(possible_states) == 0:
82 | return (None,) * n
83 |
84 | def get_dist(curr, incoming):
85 | diff = abs(curr - incoming)
86 | return diff**2
87 |
88 | # first, possible_states are more than valid_states
89 | if len(possible_states) > len(valid_states):
90 | permutes = list(
91 | combinations(range(len(possible_states)), len(valid_states))
92 | )
93 | permuted_states = [
94 | [possible_states[i] for i in permute] for permute in permutes
95 | ]
96 | dists = map(
97 | lambda permute: reduce(
98 | lambda acc, pair: acc + get_dist(*pair),
99 | zip(valid_states, permute),
100 | 0,
101 | ),
102 | permuted_states,
103 | )
104 | min_permute_index = min(zip(dists, permutes), key=lambda x: x[0])[1]
105 | new_valid_states = tuple(possible_states[i] for i in min_permute_index)
106 | new_none_states = tuple(
107 | possible_states[i]
108 | for i in range(len(possible_states))
109 | if i not in min_permute_index
110 | )
111 | else:
112 | permutes = list(
113 | combinations(range(len(valid_states)), len(possible_states))
114 | )
115 | permuted_valid_states = [
116 | [valid_states[i] for i in permute] for permute in permutes
117 | ]
118 | dists = map(
119 | lambda permute: reduce(
120 | lambda acc, pair: acc + get_dist(*pair),
121 | zip(permute, possible_states),
122 | 0,
123 | ),
124 | permuted_valid_states,
125 | )
126 | min_permute_index = min(zip(dists, permutes), key=lambda x: x[0])[1]
127 | new_valid_states = ()
128 | for i in range(len(valid_states)):
129 | if i in min_permute_index:
130 | new_valid_states += (possible_states[min_permute_index.index(i)],)
131 | else:
132 | new_valid_states += (None,)
133 |
134 | new_none_states = ()
135 | new_none_states = (
136 | new_none_states[: len(none_states)]
137 | if len(new_none_states) > len(none_states)
138 | else new_none_states + (None,) * (len(none_states) - len(new_none_states))
139 | )
140 |
141 | # merge states
142 | new_states, *_ = reduce(
143 | lambda acc, is_none: (acc[0] + acc[1][:1], acc[1][1:], acc[2])
144 | if is_none
145 | else (acc[0] + acc[2][:1], acc[1], acc[2][1:]),
146 | none_mapper,
147 | ((), new_none_states, new_valid_states),
148 | )
149 | return new_states
150 |
151 | state_changes = list(accumulate(top_zipped, func=cases, initial=(None,) * n))[1:]
152 |
153 | return list(zip(*state_changes))
154 |
155 |
156 | if __name__ == "__main__":
157 | parser = argparse.ArgumentParser()
158 | parser.add_argument("medleyvox", type=str, help="Path to MedleyVox dataset")
159 | parser.add_argument("--n_fft", default=8192, type=int, help="N FFT")
160 | parser.add_argument("--hop_length", default=256, type=int, help="Hop size")
161 | parser.add_argument("--win_length", default=2048, type=int, help="Window size")
162 | parser.add_argument(
163 | "--window",
164 | choices=["hann", "blackman", "hamming"],
165 | help="Window type",
166 | default="hann",
167 | )
168 | parser.add_argument("--out", type=str, help="Output directory")
169 | parser.add_argument("--full-duet", action="store_true", help="Drop duet songs")
170 | parser.add_argument("--divisions", type=int, default=8, help="Divisions")
171 | parser.add_argument("--min-f0", type=float, default=50, help="Min F0")
172 | parser.add_argument("--max-f0", type=float, default=1000, help="Max F0")
173 | parser.add_argument(
174 | "--gammas", type=float, nargs="+", help="Gammas", default=[0.2, 0.6, 0.8]
175 | )
176 | parser.add_argument(
177 | "--thresh", type=float, default=0.0, help="Salience threshold ratio"
178 | )
179 | parser.add_argument("--beta", type=float, default=1.0, help="Beta")
180 | parser.add_argument("--kernel-size", type=int, default=3, help="Kernel size")
181 |
182 | args = parser.parse_args()
183 |
184 | window_fn = {
185 | "hann": torch.hann_window,
186 | "blackman": torch.blackman_window,
187 | "hamming": torch.hamming_window,
188 | }[args.window]
189 |
190 | sr = 24000
191 |
192 | dataset = MedleyVox(
193 | args.medleyvox,
194 | sample_rate=sr,
195 | drop_duet=not args.full_duet,
196 | )
197 |
198 | mlc = MLC(
199 | args.n_fft,
200 | sr,
201 | args.gammas,
202 | args.hop_length,
203 | win_length=args.win_length,
204 | window_fn=window_fn,
205 | )
206 |
207 | pitch_profiler = Sparse_Pitch_Profile(
208 | args.n_fft, sr, 0, division=args.divisions, norm=True
209 | )
210 | freqs = pitch_profiler.fd[1:-1]
211 |
212 | idx_low = (freqs >= args.min_f0).nonzero()[0][0]
213 | idx_high = (freqs <= args.max_f0).nonzero()[0][-1]
214 | selection_slice = slice(idx_low, idx_high + 1)
215 | freqs = freqs[selection_slice]
216 |
217 | W_f0 = get_harmonics(args.win_length, freqs, window_fn=window_fn)
218 |
219 | spec = Spectrogram(
220 | n_fft=args.win_length,
221 | hop_length=args.hop_length,
222 | power=None,
223 | window_fn=window_fn,
224 | )
225 | inv_spec = InverseSpectrogram(
226 | n_fft=args.win_length,
227 | hop_length=args.hop_length,
228 | window_fn=window_fn,
229 | )
230 |
231 | loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
232 |
233 | accumulate_metrics_mean = {}
234 |
235 | with tqdm(dataset) as pbar:
236 | for mix_num, (x, y, ids) in enumerate(pbar):
237 | ppt, ppf = pitch_profiler(*mlc(x.unsqueeze(0)))
238 |
239 | # smoothing
240 | kernel_size = args.kernel_size
241 | ppt = F.avg_pool1d(ppt, kernel_size, stride=1, padding=kernel_size // 2)
242 | ppf = F.avg_pool1d(ppf, kernel_size, stride=1, padding=kernel_size // 2)
243 |
244 | cfp = (ppt * ppf).squeeze(0)
245 |
246 | # peak picking
247 | cfp = torch.where((cfp > cfp.roll(1, 1)) & (cfp > cfp.roll(-1, 1)), cfp, 0)
248 | cfp = cfp[:, selection_slice]
249 | stream1, stream2 = get_streams(
250 | cfp, np.arange(len(freqs)), n=2, thresh_ratio=args.thresh
251 | )
252 |
253 | singer1_H = torch.zeros(cfp.shape[0], len(freqs) + 1)
254 | singer2_H = singer1_H.clone()
255 | singer1_H.scatter_(
256 | 1,
257 | torch.tensor(
258 | list(
259 | map(
260 | lambda x: [len(freqs)] * 3
261 | if x is None
262 | else [x - 1, x, x + 1],
263 | stream1,
264 | )
265 | )
266 | ),
267 | 1,
268 | )
269 | singer2_H.scatter_(
270 | 1,
271 | torch.tensor(
272 | list(
273 | map(
274 | lambda x: [len(freqs)] * 3
275 | if x is None
276 | else [x - 1, x, x + 1],
277 | stream2,
278 | )
279 | )
280 | ),
281 | 1,
282 | )
283 |
284 | nmf = NMF(
285 | W=torch.cat([W_f0, W_f0], dim=1),
286 | H=torch.cat([singer1_H, singer2_H], dim=1),
287 | trainable_W=False,
288 | )
289 |
290 | X = spec(x)
291 |
292 | nmf.fit(X.abs().T, beta=args.beta, alpha=1e-6)
293 |
294 | with torch.no_grad():
295 | H1, H2 = nmf.H.chunk(2, dim=1)
296 | W1, W2 = nmf.W.chunk(2, dim=1)
297 | recon_singer1 = H1 @ W1.T
298 | recon_singer2 = H2 @ W2.T
299 | recon = recon_singer1 + recon_singer2
300 | mask1 = recon_singer1 / recon
301 | mask2 = recon_singer2 / recon
302 |
303 | y1 = inv_spec(X * mask1.T)
304 | y2 = inv_spec(X * mask2.T)
305 |
306 | result = torch.stack([y1, y2], dim=0)
307 |
308 | if result.shape[1] < y.shape[1]:
309 | result = F.pad(
310 | result.unsqueeze(0),
311 | (0, y.shape[1] - result.shape[1]),
312 | ).squeeze(0)
313 |
314 | loss, reordered_sources = loss_func(
315 | result.unsqueeze(0), y.unsqueeze(0), return_est=True
316 | )
317 | est = reordered_sources.squeeze().cpu().numpy()
318 |
319 | utt_metrics = get_metrics(
320 | x.numpy(),
321 | y.numpy(),
322 | est,
323 | sample_rate=sr,
324 | metrics_list=COMPUTE_METRICS,
325 | )
326 |
327 | # calculate improvement
328 | for metric in COMPUTE_METRICS:
329 | v = utt_metrics.pop("input_" + metric)
330 | utt_metrics[metric + "i"] = utt_metrics[metric] - v
331 |
332 | for k, v in utt_metrics.items():
333 | if k not in accumulate_metrics_mean:
334 | accumulate_metrics_mean[k] = 0
335 |
336 | accumulate_metrics_mean[k] += (v - accumulate_metrics_mean[k]) / (
337 | mix_num + 1
338 | )
339 |
340 | pbar.set_postfix(accumulate_metrics_mean)
341 |
342 | if args.out is not None:
343 | out_dir = pathlib.Path(args.out) / f"medleyvox_{mix_num}"
344 | out_dir.mkdir(parents=True, exist_ok=True)
345 |
346 | sf.write(
347 | out_dir / "mixture.wav",
348 | x.numpy(),
349 | sr,
350 | "PCM_16",
351 | )
352 |
353 | for i, s in enumerate(est):
354 | out_path = out_dir / f"{ids[i]}.wav"
355 | sf.write(out_path, s, sr, "PCM_16")
356 |
357 | print(accumulate_metrics_mean)
358 |
--------------------------------------------------------------------------------
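In eval_nmf.py, each tracked pitch stream is turned into a sparse activation matrix H by scattering ones at the detected F0 bin and its two neighbours, with the last column reserved for the broadband noise basis (used whenever a frame is unvoiced). A small toy of that construction, with made-up frame indices:

import torch

stream = [4, 5, None, 6]   # hypothetical per-frame F0 bin indices; None = unvoiced frame
num_bins = 10              # number of F0 candidates; column num_bins is the noise basis

H = torch.zeros(len(stream), num_bins + 1)
index = torch.tensor(
    [[num_bins] * 3 if f is None else [f - 1, f, f + 1] for f in stream]
)
H.scatter_(1, index, 1.0)  # widen each detection to bin-1, bin, bin+1, or hit the noise column
print(H)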
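The final stage of eval_nmf.py applies soft ratio masks, built from the per-singer NMF reconstructions, to the complex mixture STFT and inverts them back to waveforms. The sketch below reproduces only that masking/resynthesis step, with oracle masks computed from two known sinusoidal "singers" (the NMF fit itself is omitted), using the same torchaudio Spectrogram / InverseSpectrogram transforms:

import torch
from torchaudio.transforms import Spectrogram, InverseSpectrogram

sr, n_fft, hop = 24000, 2048, 256
spec = Spectrogram(n_fft=n_fft, hop_length=hop, power=None)   # complex STFT, as above
inv_spec = InverseSpectrogram(n_fft=n_fft, hop_length=hop)

# Two toy "singers": sinusoids at different pitches, mixed into one signal.
t = torch.arange(sr) / sr
s1 = torch.sin(2 * torch.pi * 220.0 * t)
s2 = torch.sin(2 * torch.pi * 330.0 * t)
x = s1 + s2

X = spec(x)
# Oracle ratio masks from the (here known) source magnitude spectrograms; eval_nmf.py builds
# the same ratios from the NMF reconstructions H1 @ W1.T and H2 @ W2.T instead.
M1, M2 = spec(s1).abs(), spec(s2).abs()
mask1 = M1 / (M1 + M2 + 1e-8)
mask2 = M2 / (M1 + M2 + 1e-8)

y1 = inv_spec(X * mask1, length=x.shape[-1])
y2 = inv_spec(X * mask2, length=x.shape[-1])
print(y1.shape, y2.shape)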