├── .gitmodules ├── .gitignore ├── .env.tmp ├── requirements.txt ├── config.yaml ├── .pre-commit-config.yaml ├── LICENSE ├── main ├── dataset.py ├── utils.py ├── inference.py └── module_vckt.py ├── exp ├── base_vctk_k_2.yaml ├── base_vctk_k_none.yaml └── singing.yaml ├── train.py ├── README.md ├── medley_vox.py ├── mlc.py ├── eval.py └── eval_nmf.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "datasets"] 2 | path = datasets 3 | url = https://github.com/yoyololicon/pytorch-wav-datasets.git 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .mypy_cache 3 | .env 4 | .DS_Store 5 | .DS_Store/ 6 | .hydra 7 | venv/ 8 | logs/ 9 | data/ 10 | train.log 11 | .idea/ 12 | .vscode/ 13 | est/ -------------------------------------------------------------------------------- /.env.tmp: -------------------------------------------------------------------------------- 1 | DIR_LOGS=/logs 2 | DIR_DATA=/data 3 | 4 | # Required if using wandb logger 5 | WANDB_PROJECT=wandbprojectname 6 | WANDB_ENTITY=wandbuser 7 | WANDB_API_KEY=wandbapikey 8 | 9 | # Required if using Common Voice dataset 10 | HUGGINGFACE_TOKEN=huggingfacetoken 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | pytorch-lightning==1.9.5 3 | python-dotenv 4 | hydra-core 5 | hydra-colorlog 6 | wandb 7 | auraloss 8 | yt-dlp 9 | datasets 10 | pyloudnorm 11 | einops 12 | omegaconf 13 | rich 14 | plotly 15 | librosa 16 | transformers 17 | eng-to-ipa 18 | ema-pytorch 19 | py7zr 20 | 21 | audio-diffusion-pytorch==0.0.92 22 | audio-encoders-pytorch 23 | audio-data-pytorch 24 | quantizer-pytorch 25 | difformer-pytorch 26 | a-transformers-pytorch 27 | asteroid -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - exp: null 4 | - override hydra/hydra_logging: colorlog 5 | - override hydra/job_logging: colorlog 6 | 7 | seed: 12345 8 | train: True 9 | ignore_warnings: True 10 | print_config: False # Prints tree with all configurations 11 | work_dir: ${hydra:runtime.cwd} # This is the root of the project 12 | logs_dir: ${work_dir}${oc.env:DIR_LOGS} # This is the root for all logs 13 | data_dir: ${work_dir}${oc.env:DIR_DATA} # This is the root for all data 14 | ckpt_dir: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S} 15 | # Hydra experiment configs log dir 16 | hydra: 17 | run: 18 | dir: ${logs_dir}/runs/${now:%Y-%m-%d-%H-%M-%S} 19 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | 8 | # Formats code correctly 9 | - repo: https://github.com/psf/black 10 | rev: 21.12b0 11 | hooks: 12 | - id: black 13 | args: [ 14 | '--experimental-string-processing' 15 | ] 16 | 17 | # Checks types 18 | - repo: https://github.com/pre-commit/mirrors-mypy 19 | rev: 'v0.971' 20 | hooks: 21 | - id: mypy 22 | additional_dependencies: [data-science-types>=0.2, torch>=1.6] 23 | 24 | # Sorts 
imports 25 | - repo: https://github.com/pycqa/isort 26 | rev: 5.10.1 27 | hooks: 28 | - id: isort 29 | name: isort (python) 30 | args: ["--profile", "black"] 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /main/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | from torchaudio.datasets.vctk import VCTK_092 8 | from torchaudio.functional import resample 9 | 10 | 11 | def vctk_collate(batch, 12 | target_rate=22050, 13 | length=131072, 14 | mix_k = None): 15 | def pad_tensor(t): 16 | current_length = t.shape[1] 17 | padding_length = length - current_length 18 | return F.pad(t, (0, padding_length)) 19 | 20 | original_sr = batch[0][1] 21 | waveforms = [pad_tensor(resample(waveform=data[0], 22 | orig_freq=original_sr, 23 | new_freq=target_rate)) for data in batch] 24 | if not mix_k: 25 | return torch.stack(waveforms) 26 | if len(waveforms) % mix_k != 0: 27 | raise ValueError("Batch size must be divisble by mix_k") 28 | random.shuffle(waveforms) 29 | partitioned_waveforms = [waveforms[i:i + mix_k] for i in range(0, len(waveforms), mix_k)] 30 | summed_list = [torch.sum(torch.stack(subset), dim=0) for subset in partitioned_waveforms] 31 | return torch.stack(summed_list) 32 | 33 | 34 | if __name__ == '__main__': 35 | 36 | dataset = VCTK_092(root="/home/emilian/PycharmProjects/multi-speaker-diff-sep/data/vctk", download=True) 37 | dataloader = DataLoader(dataset=dataset, 38 | batch_size=16, 39 | collate_fn=partial(vctk_collate, mix_k=2), 40 | pin_memory=True, 41 | num_workers=0) 42 | output = next(iter(dataloader)) -------------------------------------------------------------------------------- /exp/base_vctk_k_2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # To pick dataset path execute with: ++datamodule.dataset.path=/your_wav_root/ 4 | 5 | sampling_rate: 22050 6 | length: 131072 7 | channels: 1 8 | log_every_n_steps: 1000 9 | 10 | model: 11 | _target_: main.module_vckt.Model 12 | lr: 1e-4 13 | lr_beta1: 0.95 14 | 
lr_beta2: 0.999 15 | lr_eps: 1e-6 16 | lr_weight_decay: 1e-3 17 | # ema_beta: 0.9999 18 | # ema_power: 0.7 19 | 20 | model: 21 | _target_: audio_diffusion_pytorch.AudioDiffusionModel 22 | in_channels: ${channels} 23 | channels: 256 24 | patch_size: 16 25 | resnet_groups: 8 26 | kernel_multiplier_downsample: 2 27 | multipliers: [ 1, 2, 4, 4, 4, 4, 4 ] 28 | factors: [ 4, 4, 4, 2, 2, 2 ] 29 | num_blocks: [ 2, 2, 2, 2, 2, 2 ] 30 | attentions: [ 0, 0, 0, 1, 1, 1, 1 ] 31 | attention_heads: 8 32 | attention_features: 128 33 | attention_multiplier: 2 34 | use_nearest_upsample: False 35 | use_skip_scale: True 36 | diffusion_type: k 37 | diffusion_sigma_distribution: 38 | _target_: audio_diffusion_pytorch.LogNormalDistribution 39 | mean: -3.0 40 | std: 1.0 41 | diffusion_sigma_data: 0.2 42 | 43 | 44 | datamodule: 45 | _target_: main.module_vckt.Datamodule 46 | dataset: 47 | _target_: torchaudio.datasets.vctk.VCTK_092 48 | root: null 49 | download: True 50 | 51 | collate_fn: 52 | _target_: main.dataset.vctk_collate 53 | _partial_: True 54 | target_rate: ${sampling_rate} 55 | length: ${length} 56 | mix_k: 2 57 | 58 | val_split: 0.01 59 | batch_size: 32 60 | num_workers: 8 61 | pin_memory: True 62 | 63 | callbacks: 64 | rich_progress_bar: 65 | _target_: pytorch_lightning.callbacks.RichProgressBar 66 | 67 | model_checkpoint: 68 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 69 | monitor: "valid_loss" # name of the logged metric which determines when model is improving 70 | save_top_k: 1 # save k best models (determined by above metric) 71 | save_last: True # additionaly always save model from last epoch 72 | mode: "min" # can be "max" or "min" 73 | verbose: False 74 | dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S} 75 | filename: '{epoch:02d}-{valid_loss:.3f}' 76 | 77 | model_summary: 78 | _target_: pytorch_lightning.callbacks.RichModelSummary 79 | max_depth: 2 80 | 81 | audio_samples_logger: 82 | _target_: main.module_vckt.SampleLogger 83 | num_items: 4 84 | channels: ${channels} 85 | sampling_rate: ${sampling_rate} 86 | length: ${length} 87 | sampling_steps: [100] 88 | # use_ema_model: True 89 | diffusion_sampler: 90 | _target_: audio_diffusion_pytorch.ADPM2Sampler 91 | rho: 1.0 92 | diffusion_schedule: 93 | _target_: audio_diffusion_pytorch.KarrasSchedule 94 | sigma_min: 0.0001 95 | sigma_max: 5.0 96 | rho: 9.0 97 | 98 | loggers: 99 | wandb: 100 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 101 | project: ${oc.env:WANDB_PROJECT} 102 | entity: ${oc.env:WANDB_ENTITY} 103 | # offline: False # set True to store all logs only locally 104 | job_type: "train" 105 | group: "" 106 | save_dir: ${logs_dir} 107 | 108 | trainer: 109 | _target_: pytorch_lightning.Trainer 110 | gpus: 0 # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0` 111 | precision: 32 # Precision used for tensors, default `32` 112 | accelerator: null # `ddp` GPUs train individually and sync gradients, default `None` 113 | min_epochs: 0 114 | max_epochs: -1 115 | enable_model_summary: False 116 | log_every_n_steps: 1 # Logs metrics every N batches 117 | check_val_every_n_epoch: null 118 | val_check_interval: ${log_every_n_steps} 119 | -------------------------------------------------------------------------------- /exp/base_vctk_k_none.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # To pick dataset path execute with: ++datamodule.dataset.path=/your_wav_root/ 4 | 5 | sampling_rate: 22050 6 | length: 131072 7 
| channels: 1 8 | log_every_n_steps: 1000 9 | 10 | model: 11 | _target_: main.module_vckt.Model 12 | lr: 1e-4 13 | lr_beta1: 0.95 14 | lr_beta2: 0.999 15 | lr_eps: 1e-6 16 | lr_weight_decay: 1e-3 17 | # ema_beta: 0.9999 18 | # ema_power: 0.7 19 | 20 | model: 21 | _target_: audio_diffusion_pytorch.AudioDiffusionModel 22 | in_channels: ${channels} 23 | channels: 256 24 | patch_size: 16 25 | resnet_groups: 8 26 | kernel_multiplier_downsample: 2 27 | multipliers: [ 1, 2, 4, 4, 4, 4, 4 ] 28 | factors: [ 4, 4, 4, 2, 2, 2 ] 29 | num_blocks: [ 2, 2, 2, 2, 2, 2 ] 30 | attentions: [ 0, 0, 0, 1, 1, 1, 1 ] 31 | attention_heads: 8 32 | attention_features: 128 33 | attention_multiplier: 2 34 | use_nearest_upsample: False 35 | use_skip_scale: True 36 | diffusion_type: k 37 | diffusion_sigma_distribution: 38 | _target_: audio_diffusion_pytorch.LogNormalDistribution 39 | mean: -3.0 40 | std: 1.0 41 | diffusion_sigma_data: 0.2 42 | 43 | 44 | datamodule: 45 | _target_: main.module_vckt.Datamodule 46 | dataset: 47 | _target_: torchaudio.datasets.vctk.VCTK_092 48 | root: null 49 | download: True 50 | 51 | collate_fn: 52 | _target_: main.dataset.vctk_collate 53 | _partial_: True 54 | target_rate: ${sampling_rate} 55 | length: ${length} 56 | mix_k: null 57 | 58 | val_split: 0.01 59 | batch_size: 32 60 | num_workers: 8 61 | pin_memory: True 62 | 63 | callbacks: 64 | rich_progress_bar: 65 | _target_: pytorch_lightning.callbacks.RichProgressBar 66 | 67 | model_checkpoint: 68 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 69 | monitor: "valid_loss" # name of the logged metric which determines when model is improving 70 | save_top_k: 1 # save k best models (determined by above metric) 71 | save_last: True # additionaly always save model from last epoch 72 | mode: "min" # can be "max" or "min" 73 | verbose: False 74 | dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S} 75 | filename: '{epoch:02d}-{valid_loss:.3f}' 76 | 77 | model_summary: 78 | _target_: pytorch_lightning.callbacks.RichModelSummary 79 | max_depth: 2 80 | 81 | audio_samples_logger: 82 | _target_: main.module_vckt.SampleLogger 83 | num_items: 4 84 | channels: ${channels} 85 | sampling_rate: ${sampling_rate} 86 | length: ${length} 87 | sampling_steps: [100] 88 | # use_ema_model: True 89 | diffusion_sampler: 90 | _target_: audio_diffusion_pytorch.ADPM2Sampler 91 | rho: 1.0 92 | diffusion_schedule: 93 | _target_: audio_diffusion_pytorch.KarrasSchedule 94 | sigma_min: 0.0001 95 | sigma_max: 5.0 96 | rho: 9.0 97 | 98 | loggers: 99 | wandb: 100 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 101 | project: ${oc.env:WANDB_PROJECT} 102 | entity: ${oc.env:WANDB_ENTITY} 103 | # offline: False # set True to store all logs only locally 104 | job_type: "train" 105 | group: "" 106 | save_dir: ${logs_dir} 107 | 108 | trainer: 109 | _target_: pytorch_lightning.Trainer 110 | gpus: 0 # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0` 111 | precision: 32 # Precision used for tensors, default `32` 112 | accelerator: null # `ddp` GPUs train individually and sync gradients, default `None` 113 | min_epochs: 0 114 | max_epochs: -1 115 | enable_model_summary: False 116 | log_every_n_steps: 1 # Logs metrics every N batches 117 | check_val_every_n_epoch: null 118 | val_check_interval: ${log_every_n_steps} 119 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import dotenv 4 | 
import hydra 5 | import pytorch_lightning as pl 6 | from main import utils 7 | from omegaconf import DictConfig, open_dict 8 | 9 | # Load environment variables from `.env`. 10 | dotenv.load_dotenv(override=True) 11 | log = utils.get_logger(__name__) 12 | 13 | 14 | @hydra.main(config_path=".", config_name="config.yaml", version_base=None) 15 | def main(config: DictConfig) -> None: 16 | 17 | # Logs config tree 18 | utils.extras(config) 19 | 20 | # Apply seed for reproducibility 21 | pl.seed_everything(config.seed) 22 | 23 | # Initialize datamodule 24 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>.") 25 | datamodule = hydra.utils.instantiate(config.datamodule, _convert_="partial") 26 | 27 | # Initialize model 28 | log.info(f"Instantiating model <{config.model._target_}>.") 29 | model = hydra.utils.instantiate(config.model, _convert_="partial") 30 | 31 | # Initialize all callbacks (e.g. checkpoints, early stopping) 32 | callbacks = [] 33 | 34 | # If save is provided add callback that saves and stops, to be used with +ckpt 35 | if "save" in config: 36 | # Ignore loggers and other callbacks 37 | with open_dict(config): 38 | config.pop("loggers") 39 | config.pop("callbacks") 40 | config.trainer.num_sanity_val_steps = 0 41 | attribute, path = config.get("save"), config.get("ckpt_dir") 42 | filename = os.path.join(path, f"{attribute}.pt") 43 | callbacks += [utils.SavePytorchModelAndStopCallback(filename, attribute)] 44 | 45 | if "callbacks" in config: 46 | for _, cb_conf in config["callbacks"].items(): 47 | if "_target_" in cb_conf: 48 | log.info(f"Instantiating callback <{cb_conf._target_}>.") 49 | callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial")) 50 | 51 | # Initialize loggers (e.g. wandb) 52 | loggers = [] 53 | if "loggers" in config: 54 | for _, lg_conf in config["loggers"].items(): 55 | if "_target_" in lg_conf: 56 | log.info(f"Instantiating logger <{lg_conf._target_}>.") 57 | # Sometimes wandb throws error if slow connection... 
58 | logger = utils.retry_if_error(
59 | lambda: hydra.utils.instantiate(lg_conf, _convert_="partial")
60 | )
61 | loggers.append(logger)
62 |
63 | # Initialize trainer
64 | log.info(f"Instantiating trainer <{config.trainer._target_}>.")
65 | trainer = hydra.utils.instantiate(
66 | config.trainer, callbacks=callbacks, logger=loggers, _convert_="partial"
67 | )
68 |
69 | # Send some parameters from config to all lightning loggers
70 | log.info("Logging hyperparameters!")
71 | utils.log_hyperparameters(
72 | config=config,
73 | model=model,
74 | datamodule=datamodule,
75 | trainer=trainer,
76 | callbacks=callbacks,
77 | logger=loggers,
78 | )
79 |
80 | # Train with checkpoint if present, otherwise from start
81 | if "ckpt" in config:
82 | ckpt = config.get("ckpt")
83 | log.info(f"Starting training from {ckpt}")
84 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt)
85 | else:
86 | log.info("Starting training.")
87 | trainer.fit(model=model, datamodule=datamodule)
88 |
89 | # Make sure everything closed properly
90 | log.info("Finalizing!")
91 | utils.finish(
92 | config=config,
93 | model=model,
94 | datamodule=datamodule,
95 | trainer=trainer,
96 | callbacks=callbacks,
97 | logger=loggers,
98 | )
99 |
100 | # Print path to best checkpoint
101 | if (
102 | not config.trainer.get("fast_dev_run")
103 | and config.get("train")
104 | and not config.get("save")
105 | ):
106 | log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}")
107 |
108 |
109 | if __name__ == "__main__":
110 | main()
111 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unsupervised Singing Voice Separation
2 |
3 | Source code for the paper [Zero-Shot Duet Singing Voices Separation with
4 | Diffusion Models](https://sdx-workshop.github.io/papers/Yu.pdf), presented at the SDX Workshop 2023.
5 |
6 | ## Setup
7 |
8 | Install requirements
9 |
10 | ```bash
11 | pip install -r requirements.txt
12 | ```
13 |
14 | Add environment variables: rename `.env.tmp` to `.env` and replace the values with your own (the example values below are random)
15 | ```bash
16 | DIR_LOGS=/logs
17 | DIR_DATA=/data
18 |
19 | # Required if using wandb logger
20 | WANDB_PROJECT=audioproject
21 | WANDB_ENTITY=johndoe
22 | WANDB_API_KEY=a21dzbqlybbzccqla4txa21dzbqlybbzccqla4tx
23 | ```
24 |
25 | ## Training
26 |
27 | The config we used for the paper is [`exp/singing.yaml`](exp/singing.yaml); you can run it with
28 | ```bash
29 | python train.py exp=singing
30 | ```
31 | You'll need to download the relevant datasets and resample them to 24 kHz.
32 | Then, modify the `datamodule` section of the config to point to the right paths.
33 |
34 | Resume a run from a checkpoint
35 |
36 | ```bash
37 | python train.py exp=singing +ckpt=/logs/ckpts/2022-08-17-01-22-18/'last.ckpt'
38 | ```
39 |
40 | ## Evaluation
41 |
42 | First, download the [MedleyVox](https://github.com/jeonchangbin49/MedleyVox?tab=readme-ov-file) dataset.
43 | Then, run the following command to evaluate the model on the `duet` subset of the dataset.
44 |
45 | ```bash
46 | python eval.py logs/runs/XXXX/.hydra/config.yaml logs/ckpts/XXXX/last.ckpt /your/path/to/MedleyVox -T 100 --cond --hop-length 32768 --self-cond --retry 2
47 | ```
48 |
49 | Some important arguments:
50 |
51 | 1. `-T`: number of diffusion steps
52 | 2. `--cond`: use auto-regressive conditioning on the ground truth (teacher forcing). Without this flag, the model will generate the full-length audio at once
53 | 3. `--self-cond`: perform auto-regressive conditioning on the generated audio if used together with `--cond`
54 | 4. `--hop-length`: the hop length of the moving window
55 | 5. `--window`: the size of the moving window. Defaults to the same length as the training data
56 | 6. `--retry`: number of retries for each auto-regressive step. The algorithm will generate `retry + 1` candidates and pick the one most similar to the ground truth. Defaults to 0
57 |
58 | For other arguments, please check out the code.
59 |
60 | ### NMF baseline
61 |
62 | This baseline depends on `torchnmf`.
63 |
64 | ```bash
65 | python eval_nmf.py /your/path/to/MedleyVox/ --thresh 0.08 --division 10 --kernel-size 7
66 | ```
67 |
68 | ### Checkpoint/Logs
69 |
70 | Our pre-trained singing voice diffusion model can be downloaded [here](https://drive.google.com/drive/folders/1nAj0JDiG70ddr_7UnhszpIiVCh4SzqgW?usp=sharing).
71 | You can find the training logs and unconditional singing samples generated during training on [wandb](https://api.wandb.ai/links/aimless/fqtcyjke).
72 |
73 | ## FAQ
74 |
75 |
76 | How do I load the model once I'm done training? 77 | 78 | If you want to load the checkpoint to restore training with the trainer you can do `python train.py exp=my_experiment +ckpt=/logs/ckpts/2022-08-17-01-22-18/'last.ckpt'`. 79 | 80 | Otherwise if you want to instantiate a model from the checkpoint: 81 | ```py 82 | from main.mymodule import Model 83 | model = Model.load_from_checkpoint( 84 | checkpoint_path='my_checkpoint.ckpt', 85 | learning_rate=1e-4, 86 | beta1=0.9, 87 | beta2=0.99, 88 | in_channels=1, 89 | patch_size=16, 90 | all_other_paratemeters_here... 91 | ) 92 | ``` 93 | to get only the PyTorch `.pt` checkpoint you can save the internal model weights as `torch.save(model.model.state_dict(), 'torchckpt.pt')`. 94 | 95 |
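If you later want to use that exported `.pt` file outside Lightning, here is a minimal sketch for restoring it (assuming the inner `audio_diffusion_pytorch.AudioDiffusionModel` is rebuilt from the same experiment config via Hydra, mirroring `main/inference.py`; the config name and file paths below are placeholders):
```py
import torch
import hydra

# Rebuild the inner diffusion network with the same arguments used for training.
with hydra.initialize(config_path="."):
    cfg = hydra.compose(config_name="exp/base_vctk_k_none.yaml")
diffusion_model = hydra.utils.instantiate(cfg["model"]["model"])

# Restore the exported weights and switch to inference mode.
diffusion_model.load_state_dict(torch.load("torchckpt.pt", map_location="cpu"))
diffusion_model.eval()
```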
96 | 97 | 98 |
99 | Why no checkpoint is created at the end of the epoch? 100 | 101 | If the epoch is shorter than `log_every_n_steps` it doesn't save the checkpoint at the end of the epoch, but after the provided number of steps. If you want to checkpoint more frequently you can add `every_n_train_steps` to the ModelCheckpoint e.g.: 102 | ```yaml 103 | model_checkpoint: 104 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 105 | monitor: "valid_loss" # name of the logged metric which determines when model is improving 106 | save_top_k: 1 # save k best models (determined by above metric) 107 | save_last: True # additionaly always save model from last epoch 108 | mode: "min" # can be "max" or "min" 109 | verbose: False 110 | dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S} 111 | filename: '{epoch:02d}-{valid_loss:.3f}' 112 | every_n_train_steps: 10 113 | ``` 114 | Note that logging the checkpoint so frequently is not recommended in general, since it takes a bit of time to store the file. 115 | 116 |
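The same setting can also be overridden from the command line instead of editing the YAML, using the Hydra override syntax already used elsewhere in this repo (the step count below is only an example):
```bash
python train.py exp=singing ++callbacks.model_checkpoint.every_n_train_steps=10000
```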
117 | -------------------------------------------------------------------------------- /exp/singing.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # To pick dataset path execute with: ++datamodule.dataset.path=/your_wav_root/ 4 | 5 | sampling_rate: 24000 6 | length: 131072 7 | channels: 1 8 | log_every_n_steps: 3000 9 | overlap: 65536 10 | 11 | model: 12 | _target_: main.module_vckt.Model 13 | lr: 1e-4 14 | lr_beta1: 0.95 15 | lr_beta2: 0.999 16 | lr_eps: 1e-6 17 | lr_weight_decay: 1e-3 18 | # ema_beta: 0.9999 19 | # ema_power: 0.7 20 | 21 | model: 22 | _target_: audio_diffusion_pytorch.AudioDiffusionModel 23 | in_channels: ${channels} 24 | channels: 256 25 | patch_size: 12 26 | resnet_groups: 8 27 | kernel_multiplier_downsample: 2 28 | multipliers: [ 1, 2, 4, 4, 4, 4, 4 ] 29 | factors: [ 4, 4, 4, 2, 2, 2 ] 30 | num_blocks: [ 2, 2, 2, 2, 2, 2 ] 31 | attentions: [ 0, 0, 0, 1, 1, 1, 1 ] 32 | attention_heads: 8 33 | attention_features: 128 34 | attention_multiplier: 2 35 | use_nearest_upsample: False 36 | use_skip_scale: True 37 | diffusion_type: k 38 | diffusion_sigma_distribution: 39 | _target_: audio_diffusion_pytorch.LogNormalDistribution 40 | mean: -3.0 41 | std: 1.0 42 | diffusion_sigma_data: 0.2 43 | 44 | 45 | datamodule: 46 | _target_: main.module_vckt.Datamodule 47 | dataset: 48 | _target_: torch.utils.data.ConcatDataset 49 | datasets: 50 | - _target_: datasets.wav.WAVDataset 51 | data_dir: /import/c4dm-datasets-ext/m4singer-24k/ 52 | segment: ${length} 53 | overlap: ${overlap} 54 | mono: True 55 | keepdim: True 56 | - _target_: datasets.wav.WAVDataset 57 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/OpenCPOP-24k/ 58 | segment: ${length} 59 | overlap: ${overlap} 60 | mono: True 61 | keepdim: True 62 | - _target_: datasets.wav.WAVDataset 63 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/VocalSet-24k/ 64 | segment: ${length} 65 | overlap: ${overlap} 66 | mono: True 67 | keepdim: True 68 | - _target_: datasets.wav.WAVDataset 69 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/OpenSinger-24k/ 70 | segment: ${length} 71 | overlap: ${overlap} 72 | mono: True 73 | keepdim: True 74 | - _target_: datasets.wav.WAVDataset 75 | data_dir: /import/c4dm-datasets-ext/jvs_music_ver1/ 76 | segment: ${length} 77 | overlap: ${overlap} 78 | mono: True 79 | keepdim: True 80 | - _target_: datasets.wav.WAVDataset 81 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/CSD-24k/ 82 | segment: ${length} 83 | overlap: ${overlap} 84 | mono: True 85 | keepdim: True 86 | - _target_: datasets.wav.WAVDataset 87 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/NUS-24k/ 88 | segment: ${length} 89 | overlap: ${overlap} 90 | mono: True 91 | keepdim: True 92 | - _target_: datasets.wav.WAVDataset 93 | data_dir: /import/c4dm-datasets-ext/ycy_artefacts/PJS-24k/ 94 | segment: ${length} 95 | overlap: ${overlap} 96 | mono: True 97 | keepdim: True 98 | 99 | 100 | collate_fn: null 101 | 102 | val_split: 0.01 103 | batch_size: 32 104 | num_workers: 8 105 | pin_memory: True 106 | 107 | callbacks: 108 | rich_progress_bar: 109 | _target_: pytorch_lightning.callbacks.RichProgressBar 110 | 111 | model_checkpoint: 112 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 113 | monitor: "valid_loss" # name of the logged metric which determines when model is improving 114 | save_top_k: 1 # save k best models (determined by above metric) 115 | save_last: True # additionaly always save model from last epoch 116 | mode: "min" # can be "max" or "min" 
117 | verbose: False 118 | dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S} 119 | filename: '{epoch:02d}-{valid_loss:.3f}' 120 | 121 | model_summary: 122 | _target_: pytorch_lightning.callbacks.RichModelSummary 123 | max_depth: 2 124 | 125 | audio_samples_logger: 126 | _target_: main.module_vckt.SampleLogger 127 | num_items: 4 128 | channels: ${channels} 129 | sampling_rate: ${sampling_rate} 130 | length: ${length} 131 | sampling_steps: [100] 132 | # use_ema_model: True 133 | diffusion_sampler: 134 | _target_: audio_diffusion_pytorch.ADPM2Sampler 135 | rho: 1.0 136 | diffusion_schedule: 137 | _target_: audio_diffusion_pytorch.KarrasSchedule 138 | sigma_min: 0.0001 139 | sigma_max: 5.0 140 | rho: 9.0 141 | 142 | loggers: 143 | wandb: 144 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 145 | project: ${oc.env:WANDB_PROJECT} 146 | entity: ${oc.env:WANDB_ENTITY} 147 | # offline: False # set True to store all logs only locally 148 | job_type: "train" 149 | group: "" 150 | save_dir: ${logs_dir} 151 | 152 | trainer: 153 | _target_: pytorch_lightning.Trainer 154 | gpus: -1 # Set `1` to train on GPU, `0` to train on CPU only, and `-1` to train on all GPUs, default `0` 155 | precision: 32 # Precision used for tensors, default `32` 156 | accelerator: null # `ddp` GPUs train individually and sync gradients, default `None` 157 | min_epochs: 0 158 | max_epochs: -1 159 | enable_model_summary: False 160 | log_every_n_steps: 1 # Logs metrics every N batches 161 | check_val_every_n_epoch: null 162 | val_check_interval: ${log_every_n_steps} 163 | -------------------------------------------------------------------------------- /medley_vox.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | import librosa 9 | 10 | exclude_duet = { 11 | "CatMartino_IPromise", 12 | "TleilaxEnsemble_Late", 13 | "TleilaxEnsemble_MelancholyFlowers", 14 | } 15 | 16 | 17 | class MedleyVox(Dataset): 18 | """Dataset class for MedleyVox source separation tasks. 19 | 20 | Args: 21 | task (str): One of ``'unison'``, ``'duet'``, ``'main_vs_rest'`` or 22 | ``'total'`` : 23 | * ``'unison'`` for unison vocal separation. 24 | * ``'duet'`` for duet vocal separation. 25 | * ``'main_vs_rest'`` for main vs. rest vocal separation (main vs rest). 26 | * ``'n_singing'`` for N-singing separation. We will use all of the duet, unison, and main vs. rest data. 27 | 28 | sample_rate (int) : The sample rate of the sources and mixtures. 29 | n_src (int) : The number of sources in the mixture. Actually, this is fixed to 2 for our tasks. Need to be specified for N-singing training (future work). 30 | segment (int, optional) : The desired sources and mixtures length in s. 
31 | """ 32 | 33 | dataset_name = "MedleyVox" 34 | 35 | def __init__( 36 | self, 37 | root_dir, 38 | metadata_dir=None, 39 | task="duet", 40 | sample_rate=24000, 41 | n_src=2, 42 | segment=None, 43 | return_id=True, 44 | drop_duet=False, 45 | ): 46 | self.root_dir = root_dir # /path/to/data/test_medleyDB 47 | self.metadata_dir = metadata_dir # ./testset/testset_config 48 | self.task = task.lower() 49 | self.return_id = return_id 50 | # Get the csv corresponding to the task 51 | if self.task == "unison": 52 | self.total_segments_list = glob.glob(f"{self.root_dir}/unison/*/*") 53 | elif self.task == "duet": 54 | self.total_segments_list = glob.glob(f"{self.root_dir}/duet/*/*") 55 | if drop_duet: 56 | total_segments_list = list( 57 | filter( 58 | lambda x: x.split("/")[-2] not in exclude_duet, 59 | self.total_segments_list, 60 | ) 61 | ) 62 | print( 63 | f"Drop {len(self.total_segments_list) - len(total_segments_list)} duet songs." 64 | ) 65 | self.total_segments_list = total_segments_list 66 | 67 | elif self.task == "main_vs_rest": 68 | self.total_segments_list = glob.glob(f"{self.root_dir}/rest/*/*") 69 | elif self.task == "n_singing": 70 | self.total_segments_list = ( 71 | glob.glob(f"{self.root_dir}/unison/*/*") 72 | + glob.glob(f"{self.root_dir}/duet/*/*") 73 | + glob.glob(f"{self.root_dir}/rest/*/*") 74 | ) 75 | self.segment = segment 76 | self.sample_rate = sample_rate 77 | self.n_src = n_src 78 | 79 | def __len__(self): 80 | return len(self.total_segments_list) 81 | 82 | def __getitem__(self, idx): 83 | song_name = self.total_segments_list[idx].split("/")[-2] 84 | segment_name = self.total_segments_list[idx].split("/")[-1] 85 | mixture_path = ( 86 | f"{self.total_segments_list[idx]}/mix/{song_name} - {segment_name}.wav" 87 | ) 88 | self.mixture_path = mixture_path 89 | sources_path_list = glob.glob(f"{self.total_segments_list[idx]}/gt/*.wav") 90 | 91 | if self.task == "main_vs_rest" or self.task == "n_singing": 92 | if os.path.exists( 93 | f"{self.metadata_dir}/V1_rest_vocals_only_config/{song_name}.json" 94 | ): 95 | metadata_json_path = ( 96 | f"{self.metadata_dir}/V1_rest_vocals_only_config/{song_name}.json" 97 | ) 98 | elif os.path.exists( 99 | f"{self.metadata_dir}/V2_vocals_only_config/{song_name}.json" 100 | ): 101 | metadata_json_path = ( 102 | f"{self.metadata_dir}/V2_vocals_only_config/{song_name}.json" 103 | ) 104 | else: 105 | print("main vs. 
rest metadata not found.") 106 | raise AttributeError 107 | with open(metadata_json_path, "r") as json_file: 108 | metadata_json = json.load(json_file) 109 | 110 | # Read sources 111 | sources_list = [] 112 | ids = [] 113 | if self.task == "main_vs_rest" or self.task == "n_singing": 114 | gt_main_name = metadata_json[segment_name]["main_vocal"] 115 | gt_source, sr = librosa.load( 116 | f"{self.total_segments_list[idx]}/gt/{gt_main_name} - {segment_name}.wav", 117 | sr=self.sample_rate, 118 | ) 119 | gt_rest_list = metadata_json[segment_name]["other_vocals"] 120 | ids.append(f"{gt_main_name} - {segment_name}") 121 | 122 | rest_sources_list = [] 123 | for other_vocal_name in gt_rest_list: 124 | s, sr = librosa.load( 125 | f"{self.total_segments_list[idx]}/gt/{other_vocal_name} - {segment_name}.wav", 126 | sr=self.sample_rate, 127 | ) 128 | rest_sources_list.append(s) 129 | ids.append(f"{other_vocal_name} - {segment_name}") 130 | rest_sources_list = np.stack(rest_sources_list, axis=0) 131 | gt_rest = rest_sources_list.sum(0) 132 | 133 | sources_list.append(gt_source) 134 | sources_list.append(gt_rest) 135 | else: # self.task == 'unison' or self.task == 'duet' 136 | for i, source_path in enumerate(sources_path_list): 137 | s, sr = librosa.load(source_path, sr=self.sample_rate) 138 | sources_list.append(s) 139 | ids.append(os.path.basename(source_path).replace(".wav", "")) 140 | # Read the mixture 141 | mixture, sr = librosa.load(mixture_path, sr=self.sample_rate) 142 | # Convert to torch tensor 143 | mixture = torch.as_tensor(mixture, dtype=torch.float32) 144 | # Stack sources 145 | sources = np.vstack(sources_list) 146 | # Convert sources to tensor 147 | sources = torch.as_tensor(sources, dtype=torch.float32) 148 | if not self.return_id: 149 | return mixture, sources 150 | # 5400-34479-0005_4973-24515-0007.wav 151 | # id1, id2 = mixture_path.split("/")[-1].split(".")[0].split("_") 152 | 153 | return mixture, sources, ids 154 | -------------------------------------------------------------------------------- /main/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import warnings 4 | from typing import Callable, List, Optional, Sequence 5 | 6 | import pkg_resources # type: ignore 7 | import pytorch_lightning as pl 8 | import rich.syntax 9 | import rich.tree 10 | import torch 11 | from omegaconf import DictConfig, OmegaConf 12 | from pytorch_lightning import Callback 13 | from pytorch_lightning.utilities import rank_zero_only 14 | 15 | 16 | def get_logger(name=__name__) -> logging.Logger: 17 | """Initializes multi-GPU-friendly python command line logger.""" 18 | 19 | logger = logging.getLogger(name) 20 | 21 | # this ensures all logging levels get marked with the rank zero decorator 22 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 23 | for level in ( 24 | "debug", 25 | "info", 26 | "warning", 27 | "error", 28 | "exception", 29 | "fatal", 30 | "critical", 31 | ): 32 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 33 | 34 | return logger 35 | 36 | 37 | log = get_logger(__name__) 38 | 39 | 40 | def extras(config: DictConfig) -> None: 41 | """Applies optional utilities, controlled by config flags. 42 | Utilities: 43 | - Ignoring python warnings 44 | - Rich config printing 45 | """ 46 | 47 | # disable python warnings if 48 | if config.get("ignore_warnings"): 49 | log.info("Disabling python warnings! 
") 50 | warnings.filterwarnings("ignore") 51 | 52 | # pretty print config tree using Rich library if 53 | if config.get("print_config"): 54 | log.info("Printing config tree with Rich! ") 55 | print_config(config, resolve=True) 56 | 57 | 58 | @rank_zero_only 59 | def print_config( 60 | config: DictConfig, 61 | print_order: Sequence[str] = ( 62 | "datamodule", 63 | "model", 64 | "callbacks", 65 | "logger", 66 | "trainer", 67 | ), 68 | resolve: bool = True, 69 | ) -> None: 70 | """Prints content of DictConfig using Rich library and its tree structure. 71 | Args: 72 | config (DictConfig): Configuration composed by Hydra. 73 | print_order (Sequence[str], optional): Determines in what order config components are printed. 74 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 75 | """ 76 | 77 | style = "dim" 78 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 79 | 80 | quee = [] 81 | 82 | for field in print_order: 83 | quee.append(field) if field in config else log.info( 84 | f"Field '{field}' not found in config" 85 | ) 86 | 87 | for field in config: 88 | if field not in quee: 89 | quee.append(field) 90 | 91 | for field in quee: 92 | branch = tree.add(field, style=style, guide_style=style) 93 | 94 | config_group = config[field] 95 | if isinstance(config_group, DictConfig): 96 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 97 | else: 98 | branch_content = str(config_group) 99 | 100 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 101 | 102 | rich.print(tree) 103 | 104 | with open("config_tree.log", "w") as file: 105 | rich.print(tree, file=file) 106 | 107 | 108 | @rank_zero_only 109 | def log_hyperparameters( 110 | config: DictConfig, 111 | model: pl.LightningModule, 112 | datamodule: pl.LightningDataModule, 113 | trainer: pl.Trainer, 114 | callbacks: List[pl.Callback], 115 | logger: List[pl.loggers.Logger], 116 | ) -> None: 117 | """Controls which config parts are saved by Lightning loggers. 118 | Additionaly saves: 119 | - number of model parameters 120 | """ 121 | 122 | if not trainer.logger: 123 | return 124 | 125 | hparams = {} 126 | 127 | # choose which parts of hydra config will be saved to loggers 128 | hparams["model"] = config["model"] 129 | 130 | # save number of model parameters 131 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 132 | hparams["model/params/trainable"] = sum( 133 | p.numel() for p in model.parameters() if p.requires_grad 134 | ) 135 | hparams["model/params/non_trainable"] = sum( 136 | p.numel() for p in model.parameters() if not p.requires_grad 137 | ) 138 | 139 | hparams["datamodule"] = config["datamodule"] 140 | hparams["trainer"] = config["trainer"] 141 | 142 | if "seed" in config: 143 | hparams["seed"] = config["seed"] 144 | if "callbacks" in config: 145 | hparams["callbacks"] = config["callbacks"] 146 | 147 | hparams["pacakges"] = get_packages_list() 148 | 149 | # send hparams to all loggers 150 | trainer.logger.log_hyperparams(hparams) 151 | 152 | 153 | def finish( 154 | config: DictConfig, 155 | model: pl.LightningModule, 156 | datamodule: pl.LightningDataModule, 157 | trainer: pl.Trainer, 158 | callbacks: List[pl.Callback], 159 | logger: List[pl.loggers.Logger], 160 | ) -> None: 161 | """Makes sure everything closed properly.""" 162 | 163 | # without this sweeps with wandb logger might crash! 
164 | for lg in logger: 165 | if isinstance(lg, pl.loggers.wandb.WandbLogger): 166 | import wandb 167 | 168 | wandb.finish() 169 | 170 | 171 | def get_packages_list() -> List[str]: 172 | return [f"{p.project_name}=={p.version}" for p in pkg_resources.working_set] 173 | 174 | 175 | def retry_if_error(fn: Callable, num_attemps: int = 10): 176 | for attempt in range(num_attemps): 177 | try: 178 | return fn() 179 | except: 180 | print(f"Retrying, attempt {attempt+1}") 181 | pass 182 | return fn() 183 | 184 | 185 | class SavePytorchModelAndStopCallback(Callback): 186 | def __init__(self, path: str, attribute: Optional[str] = None): 187 | self.path = path 188 | self.attribute = attribute 189 | 190 | def on_train_start(self, trainer, pl_module): 191 | model, path = pl_module, self.path 192 | if self.attribute is not None: 193 | assert_message = "provided model attribute not found in pl_module" 194 | assert hasattr(pl_module, self.attribute), assert_message 195 | model = getattr( 196 | pl_module, self.attribute, hasattr(pl_module, self.attribute) 197 | ) 198 | # Make dir if not existent 199 | os.makedirs(os.path.split(path)[0], exist_ok=True) 200 | # Save model 201 | torch.save(model, path) 202 | log.info(f"PyTorch model saved at: {path}") 203 | # Stop trainer 204 | trainer.should_stop = True 205 | -------------------------------------------------------------------------------- /main/inference.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from functools import partial 3 | from pathlib import Path 4 | from typing import List, Optional, Callable, Mapping 5 | import pytorch_lightning as pl 6 | 7 | import torch 8 | import torchaudio 9 | import tqdm 10 | from math import sqrt, ceil 11 | import hydra 12 | from audio_data_pytorch.utils import fractional_random_split 13 | 14 | from audio_diffusion_pytorch.diffusion import Schedule 15 | from torch.utils.data import DataLoader 16 | from torchaudio.datasets import VCTK_092 17 | 18 | from main.dataset import vctk_collate 19 | from main.module_vckt import Model 20 | 21 | 22 | class Separator(torch.nn.Module, abc.ABC): 23 | def __init__(self): 24 | super().__init__() 25 | 26 | @abc.abstractmethod 27 | def separate(mixture, num_steps) -> Mapping[str, torch.Tensor]: 28 | ... 
29 | 30 | 31 | class WeaklyMSDMSeparator(Separator): 32 | def __init__(self, stem_to_model: Mapping[str, Model], sigma_schedule, **kwargs): 33 | super().__init__() 34 | self.stem_to_model = stem_to_model 35 | self.separation_kwargs = kwargs 36 | self.sigma_schedule = sigma_schedule 37 | 38 | def separate(self, mixture: torch.Tensor, num_steps: int): 39 | stems = self.stem_to_model.keys() 40 | models = [self.stem_to_model[s] for s in stems] 41 | fns = [m.model.diffusion.diffusion.denoise_fn for m in models] 42 | 43 | # get device of models 44 | devices = {m.device for m in models} 45 | assert len(devices) == 1, devices 46 | (device,) = devices 47 | 48 | mixture = mixture.to(device) 49 | batch_size, _, length_samples = mixture.shape 50 | 51 | def denoise_fn(x, sigma): 52 | xs = [x[:, i:i + 1] for i in range(4)] 53 | xs = [fn(x, sigma=sigma) for fn, x in zip(fns, xs)] 54 | return torch.cat(xs, dim=1) 55 | 56 | y = separate_mixture( 57 | mixture=mixture, 58 | denoise_fn=denoise_fn, 59 | sigmas=self.sigma_schedule(num_steps, device), 60 | noises=torch.randn(batch_size, len(stems), length_samples).to(device), 61 | **self.separation_kwargs, 62 | ) 63 | return {stem: y[:, i:i + 1, :] for i, stem in enumerate(stems)} 64 | 65 | 66 | # Algorithms ------------------------------------------------------------------ 67 | 68 | def differential_with_dirac(x, sigma, denoise_fn, mixture, source_id=0): 69 | num_sources = x.shape[1] 70 | x[:, [source_id], :] = mixture - (x.sum(dim=1, keepdim=True) - x[:, [source_id], :]) 71 | score = (x - denoise_fn(x, sigma=sigma)) / sigma 72 | scores = [score[:, si] for si in range(num_sources)] 73 | ds = [s - score[:, source_id] for s in scores] 74 | return torch.stack(ds, dim=1) 75 | 76 | 77 | def differential_with_gaussian(x, sigma, denoise_fn, mixture, gamma_fn=None): 78 | gamma = sigma if gamma_fn is None else gamma_fn(sigma) 79 | d = (x - denoise_fn(x, sigma=sigma)) / sigma 80 | d = d - sigma / (2* gamma ** 2) * (mixture - x.sum(dim=[1], keepdim=True)) 81 | return d 82 | 83 | 84 | @torch.no_grad() 85 | def separate_mixture( 86 | mixture: torch.Tensor, 87 | denoise_fn: Callable, 88 | sigmas: torch.Tensor, 89 | noises: Optional[torch.Tensor], 90 | differential_fn: Callable = differential_with_dirac, 91 | s_churn: float = 40.0, # > 0 to add randomness 92 | num_resamples: int = 2, 93 | use_tqdm: bool = False, 94 | ): 95 | # Set initial noise 96 | x = sigmas[0] * noises # [batch_size, num-sources, sample-length] 97 | 98 | vis_wrapper = tqdm.tqdm if use_tqdm else lambda x: x 99 | for i in vis_wrapper(range(len(sigmas) - 1)): 100 | sigma, sigma_next = sigmas[i], sigmas[i + 1] 101 | 102 | for r in range(num_resamples): 103 | # Inject randomness 104 | gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) 105 | sigma_hat = sigma * (gamma + 1) 106 | x = x + torch.randn_like(x) * (sigma_hat ** 2 - sigma ** 2) ** 0.5 107 | 108 | # Compute conditioned derivative 109 | d = differential_fn(mixture=mixture, x=x, sigma=sigma_hat, denoise_fn=denoise_fn) 110 | 111 | # Update integral 112 | x = x + d * (sigma_next - sigma_hat) 113 | 114 | # Renoise if not last resample step 115 | if r < num_resamples - 1: 116 | x = x + sqrt(sigma ** 2 - sigma_next ** 2) * torch.randn_like(x) 117 | 118 | return x.cpu().detach() 119 | 120 | 121 | # ----------------------------------------------------------------------------- 122 | def save_separation( 123 | separated_tracks: Mapping[str, torch.Tensor], 124 | sample_rate: int, 125 | chunk_path: Path, 126 | ): 127 | for stem, separated_track in 
separated_tracks.items(): 128 | torchaudio.save(chunk_path / f"{stem}.wav", separated_track.cpu(), sample_rate=sample_rate) 129 | 130 | 131 | if __name__ == '__main__': 132 | 133 | pl.seed_everything(12345) 134 | 135 | with hydra.initialize(config_path=".."): 136 | cfg = hydra.compose(config_name="exp/base_vctk_k_none.yaml") 137 | 138 | model = hydra.utils.instantiate(cfg['model']).cuda() 139 | vctk_checkpoint = torch.load('/home/emilian/PycharmProjects/multi-speaker-diff-sep/data/epoch=117-valid_loss=0.015.ckpt', 140 | map_location='cuda') 141 | model.load_state_dict(vctk_checkpoint['state_dict']) 142 | diffusion_schedule = hydra.utils.instantiate(cfg['callbacks']['audio_samples_logger']['diffusion_schedule']).cuda() 143 | separator = WeaklyMSDMSeparator(stem_to_model = {"voice_1": model.cuda(), 144 | "voice_2": model.cuda()}, 145 | sigma_schedule=diffusion_schedule, 146 | use_tqdm=True) 147 | 148 | dataset = VCTK_092(root="/home/emilian/PycharmProjects/multi-speaker-diff-sep/data/vctk", download=True) 149 | split = [1.0 - 0.01, 0.01] 150 | _, data_val = fractional_random_split(dataset, split) 151 | dataloader = DataLoader(dataset=data_val, 152 | batch_size=2, 153 | num_workers=0, 154 | pin_memory=False, 155 | drop_last=True, 156 | shuffle=False, 157 | collate_fn=partial(vctk_collate, mix_k=None)) 158 | data_iter = iter(dataloader) 159 | _ = next(data_iter) 160 | _ = next(data_iter) 161 | batch = next(data_iter) 162 | 163 | mix = batch.sum(dim=0, keepdim=True) 164 | torchaudio.save("./source_0.wav", batch[0].cpu(), sample_rate=22050) 165 | torchaudio.save("./source_1.wav", batch[1].cpu(), sample_rate=22050) 166 | separated_tracks = separator.separate(mixture=mix, num_steps=500) 167 | save_separation({k: v[0] for k, v in separated_tracks.items()}, 168 | sample_rate=22050, 169 | chunk_path=Path('.')) 170 | torchaudio.save("./mix.wav", mix[0].cpu(), sample_rate=22050) 171 | -------------------------------------------------------------------------------- /mlc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from typing import List, Callable 6 | from torchaudio.transforms import Spectrogram 7 | from functools import reduce, partial 8 | 9 | 10 | class MLC(nn.Module): 11 | def __init__( 12 | self, 13 | n_fft: int, 14 | sr: int, 15 | gammas: List[float], 16 | hop_size: int, 17 | Hipass_f: float = 50, 18 | Lowpass_t=0.24, 19 | **kwargs, 20 | ): 21 | super().__init__() 22 | self.n_fft = n_fft 23 | self.hop_size = hop_size 24 | 25 | self.stft = Spectrogram( 26 | n_fft=n_fft, 27 | hop_length=hop_size, 28 | power=2, 29 | normalized=True, 30 | onesided=False, 31 | **kwargs, 32 | ) 33 | hpi = int(Hipass_f * n_fft / sr) + 1 34 | lpi = int(Lowpass_t * sr / 1000) + 1 35 | layers = [lambda ceps, spec: (ceps, spec ** gammas[0])] 36 | 37 | def gamm_trsfm(x, g, i): 38 | x = torch.fft.fft(x, norm="ortho").real 39 | x[..., :i] = x[..., -i:] = 0 40 | return F.relu(x) ** g 41 | 42 | self.num_spec = 1 43 | self.num_ceps = 0 44 | 45 | for d, gamma in enumerate(gammas[1:]): 46 | if d % 2: 47 | layers.append(lambda ceps, _: (ceps, gamm_trsfm(ceps, gamma, hpi))) 48 | self.num_spec += 1 49 | else: 50 | layers.append(lambda _, spec: (gamm_trsfm(spec, gamma, lpi), spec)) 51 | self.num_ceps += 1 52 | 53 | self.compute = partial(reduce, lambda x, f: f(*x), layers) 54 | 55 | def forward(self, x): 56 | return self.compute((None, self.stft(x).transpose(-1, -2))) 57 | 58 | 59 | class 
Sparse_Pitch_Profile(nn.Module): 60 | def __init__(self, in_channels, sr, harms_range=24, division=1, norm=False): 61 | """ 62 | 63 | Parameters 64 | ---------- 65 | in_channels: int 66 | window size 67 | sr: int 68 | sample rate 69 | harms_range: int 70 | The extended area above (or below) the piano pitch range (in semitones) 71 | 25 : though somewhat larger, to ensure the coverage is large enough (if division=1, 24 is sufficient) 72 | division: int 73 | The division number for filterbank frequency resolution. The frequency resolution is 1 / division (semitone) 74 | norm: bool 75 | If set to True, normalize each filterbank so the weight of each filterbank sum to 1. 76 | """ 77 | super().__init__() 78 | step = 1 / division 79 | # midi_num shape = (88 + harms_range) * division + 2 80 | # this implementation make sure if we group midi_num with a size of division 81 | # each group will center at the piano pitch number and the extra pitch range 82 | # E.g., division = 2, midi_num = [20.25, 20.75, 21.25, ....] 83 | # dividion = 3, midi_num = [20.33, 20.67, 21, 21.33, ...] 84 | midi_num = np.arange( 85 | 20.5 - step / 2 - harms_range, 108.5 + step + harms_range, step 86 | ) 87 | self.midi_num = midi_num 88 | 89 | fd = 440 * np.power(2, (midi_num - 69) / 12) 90 | self.fd = fd 91 | 92 | self.effected_dim = in_channels // 2 + 1 93 | # // 2 : the spectrum/ cepstrum are symmetric 94 | 95 | x = np.arange(self.effected_dim) 96 | freq_f = x * sr / in_channels 97 | freq_t = sr / x[1:] 98 | # avoid explosion; x[0] is always 0 for cepstrum 99 | 100 | inter_value = np.array([0, 1, 0]) 101 | idxs = np.digitize(freq_f, fd) 102 | 103 | cols, rows, values = [], [], [] 104 | for i in range(harms_range * division, (88 + 2 * harms_range) * division): 105 | idx = np.where((idxs == i + 1) | (idxs == i + 2))[0] 106 | c = idx 107 | r = np.broadcast_to(i - harms_range * division, idx.shape) 108 | x = np.interp(freq_f[idx], fd[i : i + 3], inter_value).astype(np.float32) 109 | if norm and len(idx): 110 | # x /= (fd[i + 2] - fd[i]) / sr * in_channels 111 | x /= x.sum() # energy normalization 112 | 113 | if len(idx) == 0 and len(values) and len(values[-1]): 114 | # low resolution in the lower frequency (for spec)/ highter frequency (for ceps), 115 | # some filterbanks will not get any bin index, so we copy the indexes from the previous iteration 116 | c = cols[-1].copy() 117 | r = rows[-1].copy() 118 | r[:] = i - harms_range * division 119 | x = values[-1].copy() 120 | 121 | cols.append(c) 122 | rows.append(r) 123 | values.append(x) 124 | 125 | cols, rows, values = ( 126 | np.concatenate(cols), 127 | np.concatenate(rows), 128 | np.concatenate(values), 129 | ) 130 | self.filters_f_idx = (rows, cols) 131 | self.filters_f_values = nn.Parameter(torch.tensor(values), requires_grad=False) 132 | 133 | idxs = np.digitize(freq_t, fd) 134 | cols, rows, values = [], [], [] 135 | for i in range((88 + harms_range) * division - 1, -1, -1): 136 | idx = np.where((idxs == i + 1) | (idxs == i + 2))[0] 137 | c = idx + 1 138 | r = np.broadcast_to(i, idx.shape) 139 | x = np.interp(freq_t[idx], fd[i : i + 3], inter_value).astype(np.float32) 140 | if norm and len(idx): 141 | # x /= (1 / fd[i] - 1 / fd[i + 2]) * sr 142 | x /= x.sum() 143 | 144 | if len(idx) == 0 and len(values) and len(values[-1]): 145 | c = cols[-1].copy() 146 | r = rows[-1].copy() 147 | r[:] = i 148 | x = values[-1].copy() 149 | 150 | cols.append(c) 151 | rows.append(r) 152 | values.append(x) 153 | 154 | cols, rows, values = ( 155 | np.concatenate(cols), 156 | 
np.concatenate(rows), 157 | np.concatenate(values), 158 | ) 159 | self.filters_t_idx = (rows, cols) 160 | self.filters_t_values = nn.Parameter(torch.tensor(values), requires_grad=False) 161 | self.filter_size = torch.Size( 162 | ((88 + harms_range) * division, self.effected_dim) 163 | ) 164 | 165 | def forward(self, ceps, spec): 166 | ceps, spec = ceps[..., : self.effected_dim], spec[..., : self.effected_dim] 167 | batch_dim, steps, _ = ceps.size() 168 | filter_f = torch.sparse_coo_tensor( 169 | self.filters_f_idx, self.filters_f_values, self.filter_size 170 | ) 171 | filter_t = torch.sparse_coo_tensor( 172 | self.filters_t_idx, self.filters_t_values, self.filter_size 173 | ) 174 | ppt = filter_t @ ceps.transpose(0, 2).contiguous().view(self.effected_dim, -1) 175 | ppf = filter_f @ spec.transpose(0, 2).contiguous().view(self.effected_dim, -1) 176 | return ppt.view(-1, steps, batch_dim).transpose(0, 2), ppf.view( 177 | -1, steps, batch_dim 178 | ).transpose(0, 2) 179 | -------------------------------------------------------------------------------- /main/module_vckt.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, List, Optional 2 | 3 | import librosa 4 | import plotly.graph_objs as go 5 | import pytorch_lightning as pl 6 | import torch 7 | import torchaudio 8 | import wandb 9 | from audio_data_pytorch.utils import fractional_random_split 10 | from audio_diffusion_pytorch import AudioDiffusionModel, Sampler, Schedule 11 | from einops import rearrange 12 | 13 | # from ema_pytorch import EMA 14 | from pytorch_lightning import Callback, Trainer 15 | from pytorch_lightning.loggers import WandbLogger 16 | from torch import Tensor, nn 17 | from torch.utils.data import DataLoader 18 | 19 | """ Model """ 20 | 21 | 22 | class Model(pl.LightningModule): 23 | def __init__( 24 | self, 25 | lr: float, 26 | lr_beta1: float, 27 | lr_beta2: float, 28 | lr_eps: float, 29 | lr_weight_decay: float, 30 | # ema_beta: float, 31 | # ema_power: float, 32 | model: nn.Module, 33 | ): 34 | super().__init__() 35 | self.lr = lr 36 | self.lr_beta1 = lr_beta1 37 | self.lr_beta2 = lr_beta2 38 | self.lr_eps = lr_eps 39 | self.lr_weight_decay = lr_weight_decay 40 | self.model = model 41 | # self.model_ema = EMA(self.model, beta=ema_beta, power=ema_power) 42 | 43 | @property 44 | def device(self): 45 | return next(self.model.parameters()).device 46 | 47 | def configure_optimizers(self): 48 | optimizer = torch.optim.AdamW( 49 | list(self.model.parameters()), 50 | lr=self.lr, 51 | betas=(self.lr_beta1, self.lr_beta2), 52 | eps=self.lr_eps, 53 | weight_decay=self.lr_weight_decay, 54 | ) 55 | return optimizer 56 | 57 | def training_step(self, batch, batch_idx): 58 | waveforms = batch 59 | loss = self.model(waveforms) 60 | self.log("train_loss", loss) 61 | # Update EMA model and log decay 62 | # self.model_ema.update() 63 | # self.log("ema_decay", self.model_ema.get_current_decay()) 64 | return loss 65 | 66 | def validation_step(self, batch, batch_idx): 67 | waveforms = batch 68 | loss = self.model(waveforms) 69 | self.log("valid_loss", loss) 70 | return loss 71 | 72 | 73 | """ Datamodule """ 74 | 75 | 76 | class Datamodule(pl.LightningDataModule): 77 | def __init__( 78 | self, 79 | dataset, 80 | collate_fn, 81 | *, 82 | val_split: float, 83 | batch_size: int, 84 | num_workers: int, 85 | pin_memory: bool = False, 86 | **kwargs: int, 87 | ) -> None: 88 | super().__init__() 89 | self.dataset = dataset 90 | self.collate_fn = collate_fn 91 | self.val_split = 
val_split 92 | self.batch_size = batch_size 93 | self.num_workers = num_workers 94 | self.pin_memory = pin_memory 95 | self.data_train: Any = None 96 | self.data_val: Any = None 97 | 98 | def setup(self, stage: Any = None) -> None: 99 | split = [1.0 - self.val_split, self.val_split] 100 | self.data_train, self.data_val = fractional_random_split(self.dataset, split) 101 | 102 | def train_dataloader(self) -> DataLoader: 103 | return DataLoader( 104 | dataset=self.data_train, 105 | collate_fn=self.collate_fn, 106 | batch_size=self.batch_size, 107 | num_workers=self.num_workers, 108 | pin_memory=self.pin_memory, 109 | drop_last=True, 110 | shuffle=True, 111 | ) 112 | 113 | def val_dataloader(self) -> DataLoader: 114 | return DataLoader( 115 | dataset=self.data_val, 116 | collate_fn=self.collate_fn, 117 | batch_size=self.batch_size, 118 | num_workers=self.num_workers, 119 | pin_memory=self.pin_memory, 120 | drop_last=True, 121 | shuffle=True, 122 | ) 123 | 124 | 125 | """ Callbacks """ 126 | 127 | 128 | def get_wandb_logger(trainer: Trainer) -> Optional[WandbLogger]: 129 | """Safely get Weights&Biases logger from Trainer.""" 130 | 131 | if isinstance(trainer.logger, WandbLogger): 132 | return trainer.logger 133 | 134 | for logger in trainer.loggers: 135 | if isinstance(logger, WandbLogger): 136 | return logger 137 | 138 | print("WandbLogger not found.") 139 | return None 140 | 141 | 142 | def log_wandb_audio_batch( 143 | logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = "" 144 | ): 145 | num_items = samples.shape[0] 146 | samples = rearrange(samples, "b c t -> b t c").detach().cpu().numpy() 147 | logger.log( 148 | { 149 | f"sample_{idx}_{id}": wandb.Audio( 150 | samples[idx], 151 | caption=caption, 152 | sample_rate=sampling_rate, 153 | ) 154 | for idx in range(num_items) 155 | } 156 | ) 157 | 158 | 159 | def log_wandb_audio_spectrogram( 160 | logger: WandbLogger, id: str, samples: Tensor, sampling_rate: int, caption: str = "" 161 | ): 162 | num_items = samples.shape[0] 163 | samples = samples.detach().cpu() 164 | transform = torchaudio.transforms.MelSpectrogram( 165 | sample_rate=sampling_rate, 166 | n_fft=1024, 167 | hop_length=512, 168 | n_mels=80, 169 | center=True, 170 | norm="slaney", 171 | ) 172 | 173 | def get_spectrogram_image(x): 174 | spectrogram = transform(x[0]) 175 | image = librosa.power_to_db(spectrogram) 176 | trace = [go.Heatmap(z=image, colorscale="viridis")] 177 | layout = go.Layout( 178 | yaxis=dict(title="Mel Bin (Log Frequency)"), 179 | xaxis=dict(title="Frame"), 180 | title_text=caption, 181 | title_font_size=10, 182 | ) 183 | fig = go.Figure(data=trace, layout=layout) 184 | return fig 185 | 186 | logger.log( 187 | { 188 | f"mel_spectrogram_{idx}_{id}": get_spectrogram_image(samples[idx]) 189 | for idx in range(num_items) 190 | } 191 | ) 192 | 193 | 194 | class SampleLogger(Callback): 195 | def __init__( 196 | self, 197 | num_items: int, 198 | channels: int, 199 | sampling_rate: int, 200 | length: int, 201 | sampling_steps: List[int], 202 | diffusion_schedule: Schedule, 203 | diffusion_sampler: Sampler, 204 | # use_ema_model: bool, 205 | ) -> None: 206 | self.num_items = num_items 207 | self.channels = channels 208 | self.sampling_rate = sampling_rate 209 | self.length = length 210 | self.sampling_steps = sampling_steps 211 | self.diffusion_schedule = diffusion_schedule 212 | self.diffusion_sampler = diffusion_sampler 213 | # self.use_aa_model = use_ema_model 214 | self.log_next = False 215 | 216 | def on_validation_epoch_start(self, 
trainer, pl_module): 217 | self.log_next = True 218 | 219 | def on_validation_batch_start( 220 | self, trainer, pl_module, batch, batch_idx, dataloader_idx 221 | ): 222 | if self.log_next: 223 | self.log_sample(trainer, pl_module, batch) 224 | self.log_next = False 225 | 226 | @torch.no_grad() 227 | def log_sample(self, trainer, pl_module, batch): 228 | is_train = pl_module.training 229 | if is_train: 230 | pl_module.eval() 231 | 232 | wandb_logger = get_wandb_logger(trainer).experiment 233 | 234 | diffusion_model = pl_module.model 235 | # if self.use_ema_model: 236 | # diffusion_model = pl_module.model_ema.ema_model 237 | 238 | # Get start diffusion noise 239 | noise = torch.randn( 240 | (self.num_items, self.channels, self.length), device=pl_module.device 241 | ) 242 | 243 | for steps in self.sampling_steps: 244 | samples = diffusion_model.sample( 245 | noise=noise, 246 | sampler=self.diffusion_sampler, 247 | sigma_schedule=self.diffusion_schedule, 248 | num_steps=steps, 249 | ) 250 | log_wandb_audio_batch( 251 | logger=wandb_logger, 252 | id="sample", 253 | samples=samples, 254 | sampling_rate=self.sampling_rate, 255 | caption=f"Sampled in {steps} steps", 256 | ) 257 | log_wandb_audio_spectrogram( 258 | logger=wandb_logger, 259 | id="sample", 260 | samples=samples, 261 | sampling_rate=self.sampling_rate, 262 | caption=f"Sampled in {steps} steps", 263 | ) 264 | 265 | if is_train: 266 | pl_module.train() 267 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hydra 3 | import os 4 | import torch 5 | import pathlib 6 | from tqdm import tqdm 7 | from typing import Callable, Optional 8 | from math import sqrt 9 | from asteroid.metrics import get_metrics 10 | from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr 11 | import soundfile as sf 12 | 13 | from medley_vox import MedleyVox 14 | 15 | COMPUTE_METRICS = ["si_sdr", "sdr"] 16 | 17 | 18 | @torch.no_grad() 19 | def separate_mixture( 20 | mixture: torch.Tensor, 21 | noises: torch.Tensor, 22 | denoise_fn: Callable, 23 | sigmas: torch.Tensor, 24 | cond: Optional[torch.Tensor] = None, 25 | cond_index: int = 0, 26 | s_churn: float = 40.0, # > 0 to add randomness 27 | num_resamples: int = 2, 28 | use_tqdm: bool = False, 29 | gaussian: bool = False, 30 | ): 31 | # Set initial noise 32 | x = sigmas[0] * noises # [batch_size, num-sources, sample-length] 33 | 34 | for i in tqdm(range(len(sigmas) - 1), disable=not use_tqdm): 35 | sigma, sigma_next = sigmas[i], sigmas[i + 1] 36 | 37 | for r in range(num_resamples): 38 | # Inject randomness 39 | gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) 40 | sigma_hat = sigma * (gamma + 1) 41 | x = x + torch.randn_like(x) * (sigma_hat**2 - sigma**2) ** 0.5 42 | 43 | if cond is not None: 44 | noisey_cond = cond + torch.randn_like(cond) * sigma 45 | x[:, :cond_index] = noisey_cond 46 | 47 | # Compute conditioned derivative 48 | if not gaussian: 49 | x[:1] = mixture - x[1:].sum(dim=0, keepdim=True) 50 | score = (x - denoise_fn(x, sigma=sigma)) / sigma 51 | if gaussian: 52 | d = score + sigma / (2 * gamma**2) * (mixture - x.sum(dim=0)) 53 | x += d * (sigma_next - sigma_hat) 54 | else: 55 | ds = score[1:] - score[:1] 56 | 57 | # Update integral 58 | x[1:] += ds * (sigma_next - sigma_hat) 59 | 60 | # Renoise if not last resample step 61 | if r < num_resamples - 1: 62 | x = x + torch.sqrt(sigma**2 - sigma_next**2) * torch.randn_like(x) 63 | 64 | return x 65 
| 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument("config", type=str, help="Path to config file") 70 | parser.add_argument("ckpt", type=str, help="Path to checkpoint file") 71 | parser.add_argument("medleyvox", type=str, help="Path to MedleyVox dataset") 72 | parser.add_argument("-T", default=100, type=int, help="Number of diffusion steps") 73 | parser.add_argument("-S", default=40.0, type=float, help="S churn") 74 | parser.add_argument("--out", type=str, help="Output directory") 75 | parser.add_argument("--cond", action="store_true", help="Use conditioning") 76 | parser.add_argument( 77 | "--self-cond", action="store_true", help="Use self conditioning" 78 | ) 79 | parser.add_argument("--hop-length", type=int, help="Hop length") 80 | parser.add_argument("--window", type=int, help="Window size") 81 | parser.add_argument("--full-duet", action="store_true", help="Drop duet songs") 82 | parser.add_argument("--retry", type=int, default=0, help="Retry") 83 | parser.add_argument("--outer-retry", type=int, default=0, help="Outer retry") 84 | 85 | args = parser.parse_args() 86 | 87 | config_path, config_name = os.path.split(args.config) 88 | 89 | with hydra.initialize(config_path=config_path): 90 | cfg = hydra.compose(config_name=config_name) 91 | 92 | sr = cfg.sampling_rate 93 | length = cfg.length 94 | 95 | model = hydra.utils.instantiate(cfg.model) 96 | vctk_checkpoint = torch.load( 97 | args.ckpt, 98 | map_location="cpu", 99 | ) 100 | model.load_state_dict(vctk_checkpoint["state_dict"]) 101 | diffusion_schedule = hydra.utils.instantiate( 102 | cfg.callbacks.audio_samples_logger.diffusion_schedule 103 | ) 104 | 105 | model = model.cuda() 106 | model.eval() 107 | diffusion_schedule = diffusion_schedule.cuda() 108 | 109 | inner_denoise_fn = model.model.diffusion.diffusion.denoise_fn 110 | 111 | def denoise_fn(x, sigma): 112 | x = x.unsqueeze(1) 113 | return inner_denoise_fn(x, sigma=sigma).squeeze(1) 114 | 115 | sigmas = diffusion_schedule(args.T, "cuda") 116 | 117 | dataset = MedleyVox( 118 | args.medleyvox, 119 | sample_rate=sr, 120 | drop_duet=not args.full_duet, 121 | ) 122 | 123 | hop_length = length // 2 if args.hop_length is None else args.hop_length 124 | window_size = length if args.window is None else args.window 125 | loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx") 126 | s_churn = args.S 127 | 128 | if args.cond: 129 | print(window_size / sr, hop_length / sr) 130 | 131 | accumulate_metrics_mean = {} 132 | 133 | with tqdm(dataset) as pbar: 134 | for mix_num, (x_cpu, y_cpu, ids) in enumerate(pbar): 135 | x = x_cpu.cuda() 136 | y = y_cpu.cuda() 137 | n = y.shape[0] 138 | 139 | outer_trials = [] 140 | for i in range(args.outer_retry + 1): 141 | if args.cond: 142 | cond = torch.zeros(n, window_size).cuda() 143 | sub_m = torch.zeros(window_size).cuda() 144 | result = [] 145 | for sub_x, sub_y in zip( 146 | torch.split(x, hop_length), torch.split(y, hop_length, 1) 147 | ): 148 | noise = torch.randn(n, window_size).cuda() 149 | overlap_size = window_size - sub_x.numel() 150 | if overlap_size > 0: 151 | sub_m = torch.cat([sub_m[-overlap_size:], sub_x]) 152 | cond = cond[:, -overlap_size:] 153 | else: 154 | sub_m = sub_x 155 | cond = None 156 | 157 | trials = [] 158 | for i in range(args.retry + 1): 159 | pred = separate_mixture( 160 | sub_m, 161 | noise, 162 | denoise_fn, 163 | sigmas, 164 | s_churn=s_churn, 165 | cond=cond, 166 | cond_index=overlap_size, 167 | use_tqdm=False, 168 | gaussian=False, 169 | ) 170 | sub_pred = 
pred[:, -sub_x.numel() :] 171 | 172 | if args.retry > 0: 173 | loss, align_pred = loss_func( 174 | sub_pred.unsqueeze(0), 175 | sub_y.unsqueeze(0), 176 | return_est=True, 177 | ) 178 | trials.append((loss, align_pred.squeeze())) 179 | else: 180 | trials.append((0, sub_pred)) 181 | 182 | _, sub_pred = min(trials, key=lambda x: x[0]) 183 | cond = torch.cat( 184 | ([] if cond is None else [cond]) 185 | + [sub_pred if args.self_cond else sub_y], 186 | dim=1, 187 | ) 188 | result.append(sub_pred) 189 | 190 | result = torch.cat(result, dim=1) 191 | else: 192 | original_length = x.numel() 193 | padding = length - (original_length % length) 194 | if padding < length: 195 | x = torch.cat([x, x.new_zeros(padding)], dim=0) 196 | 197 | result = separate_mixture( 198 | x, 199 | torch.randn(n, x.numel()).cuda(), 200 | denoise_fn, 201 | sigmas, 202 | s_churn=s_churn, 203 | use_tqdm=False, 204 | )[:, :original_length] 205 | 206 | loss, reordered_sources = loss_func( 207 | result.unsqueeze(0), y.unsqueeze(0), return_est=True 208 | ) 209 | outer_trials.append((loss, reordered_sources)) 210 | 211 | _, reordered_sources = min(outer_trials, key=lambda x: x[0]) 212 | est = reordered_sources.squeeze().cpu().numpy() 213 | 214 | utt_metrics = get_metrics( 215 | x_cpu.numpy(), 216 | y_cpu.numpy(), 217 | est, 218 | sample_rate=sr, 219 | metrics_list=COMPUTE_METRICS, 220 | ) 221 | 222 | # calculate improvement 223 | for metric in COMPUTE_METRICS: 224 | v = utt_metrics.pop("input_" + metric) 225 | utt_metrics[metric + "i"] = utt_metrics[metric] - v 226 | 227 | for k, v in utt_metrics.items(): 228 | if k not in accumulate_metrics_mean: 229 | accumulate_metrics_mean[k] = 0 230 | 231 | accumulate_metrics_mean[k] += (v - accumulate_metrics_mean[k]) / ( 232 | mix_num + 1 233 | ) 234 | 235 | pbar.set_postfix(accumulate_metrics_mean) 236 | 237 | if args.out is not None: 238 | out_dir = pathlib.Path(args.out) / f"medleyvox_{mix_num}" 239 | out_dir.mkdir(parents=True, exist_ok=True) 240 | 241 | sf.write( 242 | out_dir / "mixture.wav", 243 | x_cpu.numpy(), 244 | sr, 245 | "PCM_16", 246 | ) 247 | 248 | for i, s in enumerate(est): 249 | out_path = out_dir / f"{ids[i]}.wav" 250 | sf.write(out_path, s, sr, "PCM_16") 251 | 252 | print(accumulate_metrics_mean) 253 | -------------------------------------------------------------------------------- /eval_nmf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hydra 3 | import os 4 | import torch 5 | import pathlib 6 | from tqdm import tqdm 7 | from typing import Callable, Optional 8 | from itertools import combinations, accumulate 9 | from functools import reduce 10 | import numpy as np 11 | from math import sqrt 12 | from asteroid.metrics import get_metrics 13 | from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr 14 | import soundfile as sf 15 | from torchnmf.nmf import NMF 16 | import torch.nn.functional as F 17 | from torchaudio.transforms import Spectrogram, InverseSpectrogram 18 | 19 | from mlc import Sparse_Pitch_Profile, MLC 20 | from medley_vox import MedleyVox 21 | from eval import COMPUTE_METRICS 22 | 23 | 24 | def get_harmonics(n_fft, freqs, window_fn=torch.hann_window): 25 | window = window_fn(n_fft) 26 | tmp = [] 27 | for f in freqs: 28 | h = torch.arange(int(0.5 * sr / f)) + 1 29 | # h = h[:40] 30 | ch = ( 31 | f 32 | * 27 33 | / 4 34 | * ( 35 | torch.exp(-1j * torch.pi * h) 36 | + 2 * (1 + 2 * torch.exp(-1j * torch.pi * h)) / (1j * torch.pi * h) 37 | - 6 * (1 - torch.exp(-1j * torch.pi * h)) / (1j 
* torch.pi * h) ** 2 38 | ) 39 | ) 40 | ch /= torch.abs(ch).max() 41 | t = torch.arange(n_fft) / sr 42 | eu = ch @ torch.exp(2j * torch.pi * h[:, None] * f * t) 43 | eu /= torch.linalg.norm(eu) 44 | tmp.append(torch.abs(torch.fft.fft(eu * window.numpy())[: n_fft // 2 + 1])) 45 | 46 | noise = torch.ones(n_fft // 2 + 1) 47 | tmp.append(noise) 48 | 49 | W_f0 = torch.stack(tmp).T 50 | W_f0 /= W_f0.max(0).values 51 | return W_f0 52 | 53 | 54 | def get_streams( 55 | cfp: torch.Tensor, 56 | freqs: np.ndarray, 57 | n: int = 2, 58 | thresh_ratio: float = 0.1, 59 | ): 60 | top_values, top_indices = torch.topk(cfp, n * 3, dim=1) 61 | thresh = top_values.max() * thresh_ratio 62 | top_freqs = freqs[top_indices] 63 | 64 | # remove zero values 65 | top_zipped = [ 66 | tuple(x[1] for x in filter(lambda x: x[0] > thresh, zip(*pair))) 67 | for pair in zip(top_values.tolist(), top_freqs.tolist()) 68 | ] 69 | 70 | # init states 71 | 72 | def cases(states: tuple, possible_states: tuple): 73 | possible_states = tuple(sorted(possible_states)) 74 | is_none = lambda x: x is None 75 | none_mapper = map(is_none, states) 76 | none_states = list(filter(is_none, states)) 77 | valid_states = list(filter(lambda x: not is_none(x), states)) 78 | 79 | if len(valid_states) == 0: 80 | return possible_states[:n] + (None,) * max(0, n - len(possible_states)) 81 | elif len(possible_states) == 0: 82 | return (None,) * n 83 | 84 | def get_dist(curr, incoming): 85 | diff = abs(curr - incoming) 86 | return diff**2 87 | 88 | # first, possible_states are more than valid_states 89 | if len(possible_states) > len(valid_states): 90 | permutes = list( 91 | combinations(range(len(possible_states)), len(valid_states)) 92 | ) 93 | permuted_states = [ 94 | [possible_states[i] for i in permute] for permute in permutes 95 | ] 96 | dists = map( 97 | lambda permute: reduce( 98 | lambda acc, pair: acc + get_dist(*pair), 99 | zip(valid_states, permute), 100 | 0, 101 | ), 102 | permuted_states, 103 | ) 104 | min_permute_index = min(zip(dists, permutes), key=lambda x: x[0])[1] 105 | new_valid_states = tuple(possible_states[i] for i in min_permute_index) 106 | new_none_states = tuple( 107 | possible_states[i] 108 | for i in range(len(possible_states)) 109 | if i not in min_permute_index 110 | ) 111 | else: 112 | permutes = list( 113 | combinations(range(len(valid_states)), len(possible_states)) 114 | ) 115 | permuted_valid_states = [ 116 | [valid_states[i] for i in permute] for permute in permutes 117 | ] 118 | dists = map( 119 | lambda permute: reduce( 120 | lambda acc, pair: acc + get_dist(*pair), 121 | zip(permute, possible_states), 122 | 0, 123 | ), 124 | permuted_valid_states, 125 | ) 126 | min_permute_index = min(zip(dists, permutes), key=lambda x: x[0])[1] 127 | new_valid_states = () 128 | for i in range(len(valid_states)): 129 | if i in min_permute_index: 130 | new_valid_states += (possible_states[min_permute_index.index(i)],) 131 | else: 132 | new_valid_states += (None,) 133 | 134 | new_none_states = () 135 | new_none_states = ( 136 | new_none_states[: len(none_states)] 137 | if len(new_none_states) > len(none_states) 138 | else new_none_states + (None,) * (len(none_states) - len(new_none_states)) 139 | ) 140 | 141 | # merge states 142 | new_states, *_ = reduce( 143 | lambda acc, is_none: (acc[0] + acc[1][:1], acc[1][1:], acc[2]) 144 | if is_none 145 | else (acc[0] + acc[2][:1], acc[1], acc[2][1:]), 146 | none_mapper, 147 | ((), new_none_states, new_valid_states), 148 | ) 149 | return new_states 150 | 151 | state_changes = 
list(accumulate(top_zipped, func=cases, initial=(None,) * n))[1:] 152 | 153 | return list(zip(*state_changes)) 154 | 155 | 156 | if __name__ == "__main__": 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("medleyvox", type=str, help="Path to MedleyVox dataset") 159 | parser.add_argument("--n_fft", default=8192, type=int, help="N FFT") 160 | parser.add_argument("--hop_length", default=256, type=int, help="Hop size") 161 | parser.add_argument("--win_length", default=2048, type=int, help="Window size") 162 | parser.add_argument( 163 | "--window", 164 | choices=["hann", "blackman", "hamming"], 165 | help="Window type", 166 | default="hann", 167 | ) 168 | parser.add_argument("--out", type=str, help="Output directory") 169 | parser.add_argument("--full-duet", action="store_true", help="Drop duet songs") 170 | parser.add_argument("--divisions", type=int, default=8, help="Divisions") 171 | parser.add_argument("--min-f0", type=float, default=50, help="Min F0") 172 | parser.add_argument("--max-f0", type=float, default=1000, help="Max F0") 173 | parser.add_argument( 174 | "--gammas", type=float, nargs="+", help="Gammas", default=[0.2, 0.6, 0.8] 175 | ) 176 | parser.add_argument( 177 | "--thresh", type=float, default=0.0, help="Salience threshold ratio" 178 | ) 179 | parser.add_argument("--beta", type=float, default=1.0, help="Beta") 180 | parser.add_argument("--kernel-size", type=int, default=3, help="Kernel size") 181 | 182 | args = parser.parse_args() 183 | 184 | window_fn = { 185 | "hann": torch.hann_window, 186 | "blackman": torch.blackman_window, 187 | "hamming": torch.hamming_window, 188 | }[args.window] 189 | 190 | sr = 24000 191 | 192 | dataset = MedleyVox( 193 | args.medleyvox, 194 | sample_rate=sr, 195 | drop_duet=not args.full_duet, 196 | ) 197 | 198 | mlc = MLC( 199 | args.n_fft, 200 | sr, 201 | args.gammas, 202 | args.hop_length, 203 | win_length=args.win_length, 204 | window_fn=window_fn, 205 | ) 206 | 207 | pitch_profiler = Sparse_Pitch_Profile( 208 | args.n_fft, sr, 0, division=args.divisions, norm=True 209 | ) 210 | freqs = pitch_profiler.fd[1:-1] 211 | 212 | idx_low = (freqs >= args.min_f0).nonzero()[0][0] 213 | idx_high = (freqs <= args.max_f0).nonzero()[0][-1] 214 | selection_slice = slice(idx_low, idx_high + 1) 215 | freqs = freqs[selection_slice] 216 | 217 | W_f0 = get_harmonics(args.win_length, freqs, window_fn=window_fn) 218 | 219 | spec = Spectrogram( 220 | n_fft=args.win_length, 221 | hop_length=args.hop_length, 222 | power=None, 223 | window_fn=window_fn, 224 | ) 225 | inv_spec = InverseSpectrogram( 226 | n_fft=args.win_length, 227 | hop_length=args.hop_length, 228 | window_fn=window_fn, 229 | ) 230 | 231 | loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx") 232 | 233 | accumulate_metrics_mean = {} 234 | 235 | with tqdm(dataset) as pbar: 236 | for mix_num, (x, y, ids) in enumerate(pbar): 237 | ppt, ppf = pitch_profiler(*mlc(x.unsqueeze(0))) 238 | 239 | # smoothing 240 | kernel_size = args.kernel_size 241 | ppt = F.avg_pool1d(ppt, kernel_size, stride=1, padding=kernel_size // 2) 242 | ppf = F.avg_pool1d(ppf, kernel_size, stride=1, padding=kernel_size // 2) 243 | 244 | cfp = (ppt * ppf).squeeze(0) 245 | 246 | # peak picking 247 | cfp = torch.where((cfp > cfp.roll(1, 1)) & (cfp > cfp.roll(-1, 1)), cfp, 0) 248 | cfp = cfp[:, selection_slice] 249 | stream1, stream2 = get_streams( 250 | cfp, np.arange(len(freqs)), n=2, thresh_ratio=args.thresh 251 | ) 252 | 253 | singer1_H = torch.zeros(cfp.shape[0], len(freqs) + 1) 254 | singer2_H = 
singer1_H.clone() 255 | singer1_H.scatter_( 256 | 1, 257 | torch.tensor( 258 | list( 259 | map( 260 | lambda x: [len(freqs)] * 3 261 | if x is None 262 | else [x - 1, x, x + 1], 263 | stream1, 264 | ) 265 | ) 266 | ), 267 | 1, 268 | ) 269 | singer2_H.scatter_( 270 | 1, 271 | torch.tensor( 272 | list( 273 | map( 274 | lambda x: [len(freqs)] * 3 275 | if x is None 276 | else [x - 1, x, x + 1], 277 | stream2, 278 | ) 279 | ) 280 | ), 281 | 1, 282 | ) 283 | 284 | nmf = NMF( 285 | W=torch.cat([W_f0, W_f0], dim=1), 286 | H=torch.cat([singer1_H, singer2_H], dim=1), 287 | trainable_W=False, 288 | ) 289 | 290 | X = spec(x) 291 | 292 | nmf.fit(X.abs().T, beta=args.beta, alpha=1e-6) 293 | 294 | with torch.no_grad(): 295 | H1, H2 = nmf.H.chunk(2, dim=1) 296 | W1, W2 = nmf.W.chunk(2, dim=1) 297 | recon_singer1 = H1 @ W1.T 298 | recon_singer2 = H2 @ W2.T 299 | recon = recon_singer1 + recon_singer2 300 | mask1 = recon_singer1 / recon 301 | mask2 = recon_singer2 / recon 302 | 303 | y1 = inv_spec(X * mask1.T) 304 | y2 = inv_spec(X * mask2.T) 305 | 306 | result = torch.stack([y1, y2], dim=0) 307 | 308 | if result.shape[1] < y.shape[1]: 309 | result = F.pad( 310 | result.unsqueeze(0), 311 | (0, y.shape[1] - result.shape[1]), 312 | ).squeeze(0) 313 | 314 | loss, reordered_sources = loss_func( 315 | result.unsqueeze(0), y.unsqueeze(0), return_est=True 316 | ) 317 | est = reordered_sources.squeeze().cpu().numpy() 318 | 319 | utt_metrics = get_metrics( 320 | x.numpy(), 321 | y.numpy(), 322 | est, 323 | sample_rate=sr, 324 | metrics_list=COMPUTE_METRICS, 325 | ) 326 | 327 | # calculate improvement 328 | for metric in COMPUTE_METRICS: 329 | v = utt_metrics.pop("input_" + metric) 330 | utt_metrics[metric + "i"] = utt_metrics[metric] - v 331 | 332 | for k, v in utt_metrics.items(): 333 | if k not in accumulate_metrics_mean: 334 | accumulate_metrics_mean[k] = 0 335 | 336 | accumulate_metrics_mean[k] += (v - accumulate_metrics_mean[k]) / ( 337 | mix_num + 1 338 | ) 339 | 340 | pbar.set_postfix(accumulate_metrics_mean) 341 | 342 | if args.out is not None: 343 | out_dir = pathlib.Path(args.out) / f"medleyvox_{mix_num}" 344 | out_dir.mkdir(parents=True, exist_ok=True) 345 | 346 | sf.write( 347 | out_dir / "mixture.wav", 348 | x.numpy(), 349 | sr, 350 | "PCM_16", 351 | ) 352 | 353 | for i, s in enumerate(est): 354 | out_path = out_dir / f"{ids[i]}.wav" 355 | sf.write(out_path, s, sr, "PCM_16") 356 | 357 | print(accumulate_metrics_mean) 358 | --------------------------------------------------------------------------------
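Note on the constrained sampler: the core of eval.py's separate_mixture is a mixture-consistency trick, where at every diffusion step the first source is re-projected so that all estimated sources sum to the observed mixture, and only the remaining sources are integrated using the difference of scores. The following is a minimal, self-contained sketch of that update under simplified assumptions: toy 1-D signals, a made-up shrinkage "denoiser" standing in for the trained model's denoise_fn, a hand-picked sigma schedule, and no churn or resampling. It is meant only to illustrate the mechanics, not to reproduce the repository's evaluation.

import torch

def toy_denoise_fn(x, sigma):
    # Hypothetical stand-in for the trained diffusion denoiser: it just
    # shrinks the input towards zero, which is enough to exercise the
    # update; a real run would call model.model.diffusion.diffusion.denoise_fn.
    return x / (1.0 + sigma)

def constrained_step(x, mixture, sigma, sigma_next, denoise_fn=toy_denoise_fn):
    # x: [num_sources, length], mixture: [1, length]
    x = x.clone()
    # Re-project source 0 so the sources sum to the mixture, as in the
    # non-gaussian branch of separate_mixture.
    x[:1] = mixture - x[1:].sum(dim=0, keepdim=True)
    score = (x - denoise_fn(x, sigma)) / sigma
    ds = score[1:] - score[:1]              # relative score between sources
    x[1:] = x[1:] + ds * (sigma_next - sigma)
    # Final projection added here only so the demo check below holds exactly;
    # eval.py instead re-projects at the start of the next step.
    x[:1] = mixture - x[1:].sum(dim=0, keepdim=True)
    return x

if __name__ == "__main__":
    torch.manual_seed(0)
    mixture = torch.randn(1, 16)
    x = 10.0 * torch.randn(2, 16)           # two sources, started from noise
    for sigma, sigma_next in zip([10.0, 5.0, 1.0], [5.0, 1.0, 0.1]):
        x = constrained_step(x, mixture, sigma, sigma_next)
    # The separated toy sources add back up to the observed mixture.
    print(torch.allclose(x.sum(dim=0, keepdim=True), mixture, atol=1e-5))

In the actual script this loop additionally injects churn noise (s_churn), optionally conditions overlapping windows on previous estimates (--cond / --self-cond), and repeats each chunk args.retry times, keeping the trial with the best PIT SI-SDR.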