├── .gitignore ├── LICENSE ├── README.md ├── assets └── pipeline.png ├── config ├── acc │ ├── fp16_gpus1.yaml │ └── fp16_gpus8.yaml ├── dataset │ ├── dnr.yaml │ └── mix_6k.yaml ├── debug.yaml ├── default.yaml ├── discriminator │ └── default.yaml ├── model │ ├── sdcodec_16k.yaml │ ├── sdcodec_16k_pretrainDAC.yaml │ ├── sdcodec_16k_shard4.yaml │ └── sdcodec_16k_shard8.yaml ├── run_config │ ├── slurm_1.yaml │ └── slurm_debug.yaml └── training │ ├── bs2_iter5k.yaml │ └── bs8_iter400k.yaml ├── environment.yml ├── eval_dnr.py ├── install_visqol.md ├── main.py ├── prepare ├── mani_dnr.py ├── mani_dns_clean.py ├── mani_dns_noise.py ├── mani_jamendo.py ├── mani_musan.py └── mani_wham.py └── src ├── datasets ├── __init__.py └── audio_dataset.py ├── metrics ├── __init__.py ├── adv.py ├── sdr.py ├── spectrum.py └── visqol.py ├── models ├── __init__.py ├── discriminator.py └── sdcodec.py ├── modules ├── __init__.py ├── base_dac.py ├── layers.py └── quantize.py ├── optim ├── __init__.py ├── cosine_lr_scheduler.py ├── exponential_lr_scheduler.py ├── inverse_sqrt_lr_scheduler.py ├── linear_warmup_lr_scheduler.py ├── polynomial_decay_lr_scheduler.py └── reduce_plateau_lr_scheduler.py ├── trainer.py └── utils ├── __init__.py ├── audio_process.py ├── torch_utils.py ├── utils.py └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | ## VsCode 2 | .vscode/ 3 | 4 | ## Python output 5 | **/*.pyc 6 | 7 | ## outputs 8 | manifest/ 9 | pretrained/ 10 | output/ 11 | local_scripts/ 12 | stdout/ 13 | stderr/ 14 | *.sh 15 | try.py 16 | plot*.py 17 | *.out 18 | *.wav 19 | *.ipynb -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Xiaoyu Bie 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SD-Codec 2 | 3 | **Learning Source Disentanglement in Neural Audio Codec, ICASSP 2025** 4 | 5 | [Xiaoyu Bie](https://xiaoyubie1994.github.io/), [Xubo Liu](https://liuxubo717.github.io/), [Gaël Richard](https://www.telecom-paris.fr/gael-richard?l=en) 6 | 7 | **[[arXiv](https://arxiv.org/abs/2409.11228)]**, **[[Project](https://xiaoyubie1994.github.io/sdcodec/)]** 8 | 9 |
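The overall pipeline is shown above. For reference, a trained run (see the Training and Evaluation sections below) is restored the way `eval_dnr.py` does it; a minimal sketch, with the output directory left as a placeholder:

```python
from pathlib import Path
import torch
from omegaconf import OmegaConf
from src import models

ret_dir = Path('output/PATH_TO_MODEL')                    # placeholder: a training output directory
cfg = OmegaConf.load(ret_dir / '.hydra' / 'config.yaml')  # Hydra config saved at training time
model_name = cfg.model.pop('name')                        # e.g. 'SDCodec'
model = getattr(models, model_name)(sample_rate=cfg.sampling_rate, **cfg.model)
state_dict = torch.load(ret_dir / 'ckpt_final' / 'ckpt_model_final.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```

The same directory is what `--ret-dir` expects in the evaluation commands below.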

10 | 11 | The pretrained models can be downloaded via [Google Drive](https://drive.google.com/drive/folders/1-OjiNmtFdTUGwQF17FDMjzZgoBJJkHpG?usp=drive_link) 12 | 13 | ## Environment Setup 14 | All our models are trained on 8 A100 80GB GPUs. 15 | 16 | ``` 17 | conda env create -f environment.yml 18 | conda activate gen_audio 19 | ``` 20 | 21 | The code was tested on Python 3.11.7 and PyTorch 2.1.2. 22 | 23 | To install [VisQOL](https://github.com/google/visqol), the simplest way is to use [Bazelisk](https://github.com/bazelbuild/bazelisk?tab=readme-ov-file), especially if you want to install it on a cluster where you do not have `sudo` rights. 24 | Here is an [example](https://github.com/XiaoyuBIE1994/SDCodec/blob/main/install_visqol.md). 25 | 26 | ## Dataset Preparation 27 | 28 | We use the following datasets: 29 | - [Divide and Remaster (DnR)](https://zenodo.org/records/6949108) 30 | - [DNS-Challenge 5](https://github.com/microsoft/DNS-Challenge) 31 | - [MTG-Jamendo](https://mtg.github.io/mtg-jamendo-dataset/) 32 | - [MUSAN](https://www.openslr.org/17/) 33 | - [WHAM!](http://wham.whisper.ai/) 34 | 35 | ```bash 36 | mkdir manifest 37 | 38 | # DnR 39 | python prepare/mani_dnr.py --data-dir PATH_TO_DnR 40 | 41 | # DNS Challenge 5 42 | python prepare/mani_dns_clean.py --data-dir PATH_TO_DNS_CLEAN # or by partition 43 | python prepare/mani_dns_noise.py --data-dir PATH_TO_DNS_NOISE 44 | 45 | # Jamendo 46 | python prepare/mani_jamendo.py --data-dir PATH_TO_JAMENDO # or by partition 47 | 48 | # MUSAN 49 | python prepare/mani_musan.py --data-dir PATH_TO_MUSAN 50 | 51 | # WHAM 52 | python prepare/mani_wham.py --data-dir PATH_TO_WHAM 53 | ``` 54 | 55 | ## Training 56 | ``` 57 | # debug on a single GPU 58 | accelerate launch --config_file config/acc/fp16_gpus1.yaml main.py --config-name debug +run_config=slurm_debug 59 | 60 | # training on 8 GPUs 61 | accelerate launch --config_file config/acc/fp16_gpus8.yaml main.py --config-name default +run_config=slurm_1 62 | ``` 63 | 64 | ## Evaluation 65 | By default, we use the last checkpoint for evaluation. 66 | ``` 67 | model_dir=PATH_TO_MODEL 68 | nohup python eval_dnr.py --ret-dir ${model_dir} --csv-path ./manifest/val.csv --length 5 > ${model_dir}/val.log 2>&1 & 69 | nohup python eval_dnr.py --ret-dir ${model_dir} --csv-path ./manifest/test.csv --length 10 > ${model_dir}/test.log 2>&1 & 70 | ``` 71 | 72 | 73 | ## Citation 74 | If you find this project useful in your research, please consider citing: 75 | ```BibTeX 76 | @inproceedings{bie2025sdcodec, 77 | author={Bie, Xiaoyu and Liu, Xubo and Richard, Ga{\"e}l}, 78 | title={Learning Source Disentanglement in Neural Audio Codec}, 79 | booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 80 | year={2025}, 81 | } 82 | ``` 83 | 84 | 85 | ## Acknowledgments 86 | Some of the code in this project is inspired by or modified from the following projects: 87 | - [AudioCraft](https://github.com/facebookresearch/audiocraft) 88 | - [DAC](https://github.com/descriptinc/descript-audio-codec) 89 | -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XiaoyuBIE1994/SDCodec/62026eb7c1fc1079ef55f1c170f1066f383b349e/assets/pipeline.png -------------------------------------------------------------------------------- /config/acc/fp16_gpus1.yaml: -------------------------------------------------------------------------------- 1
| compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: fp16 9 | num_machines: 1 10 | num_processes: 1 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false -------------------------------------------------------------------------------- /config/acc/fp16_gpus8.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: fp16 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /config/dataset/dnr.yaml: -------------------------------------------------------------------------------- 1 | # Dataset config 2 | trainset_cfg: 3 | speech: 4 | - manifest/speech_dnr.csv 5 | music: 6 | - manifest/music_dnr.csv 7 | sfx: 8 | - manifest/sfx_dnr.csv 9 | n_examples: 10000000 10 | chunk_size: 2.0 11 | trim_silence: False 12 | 13 | 14 | valset_cfg: 15 | tsv_filepath: manifest/val.csv 16 | chunk_size: 5.0 17 | 18 | testset_cfg: 19 | tsv_filepath: manifest/test.csv 20 | chunk_size: 10.0 21 | 22 | -------------------------------------------------------------------------------- /config/dataset/mix_6k.yaml: -------------------------------------------------------------------------------- 1 | # Dataset config 2 | trainset_cfg: 3 | speech: 4 | - manifest/speech_dns5_emotional_speech.csv 5 | - manifest/speech_dns5_french_speech.csv 6 | - manifest/speech_dns5_german_speech.csv 7 | - manifest/speech_dns5_italian_speech.csv 8 | - manifest/speech_dns5_read_speech.csv 9 | - manifest/speech_dns5_russian_speech.csv 10 | - manifest/speech_dns5_spanish_speech.csv 11 | - manifest/speech_dns5_vctk_wav48_silence_trimmed.csv 12 | - manifest/speech_dns5_VocalSet_48kHz_mono.csv 13 | - manifest/speech_musan.csv 14 | music: 15 | - manifest/music_jamendo.csv 16 | - manifest/music_musan.csv 17 | sfx: 18 | - manifest/sfx_dns5.csv 19 | - manifest/sfx_musan.csv 20 | - manifest/sfx_wham.csv 21 | n_examples: 10000000 22 | chunk_size: 2.0 23 | trim_silence: False 24 | 25 | 26 | valset_cfg: 27 | tsv_filepath: manifest/val_dnr.csv 28 | chunk_size: 5.0 29 | 30 | testset_cfg: 31 | tsv_filepath: manifest/test_dnr.csv 32 | chunk_size: 10.0 33 | -------------------------------------------------------------------------------- /config/debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: sdcodec_16k 3 | - discriminator: default 4 | - dataset: dnr 5 | - training: bs2_iter5k 6 | - _self_ 7 | 8 | sampling_rate: 16000 9 | resume: True 10 | resume_dir: null 11 | backup_code: True -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: sdcodec_16k 3 | - discriminator: default 4 | - dataset: mix_6k 5 | - training: bs8_iter400k 6 | - _self_ 7 | 8 | sampling_rate: 16000 9 | resume: True 10 | resume_dir: null 11 | backup_code: True 12 | 
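# Illustrative note: each entry under `defaults` above is a Hydra config group that maps to a
# folder in config/, so a variant can be selected from the command line without editing this file,
# e.g. (assuming the same launch command as in the README):
#   accelerate launch --config_file config/acc/fp16_gpus8.yaml main.py --config-name default model=sdcodec_16k_shard4 +run_config=slurm_1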
-------------------------------------------------------------------------------- /config/discriminator/default.yaml: -------------------------------------------------------------------------------- 1 | # Basic 2 | name: Discriminator 3 | rates: [] 4 | periods: [2, 3, 5, 7, 11] 5 | fft_sizes: [2048, 1024, 512] 6 | bands: 7 | - [0.0, 0.1] 8 | - [0.1, 0.25] 9 | - [0.25, 0.5] 10 | - [0.5, 0.75] 11 | - [0.75, 1.0] 12 | -------------------------------------------------------------------------------- /config/model/sdcodec_16k.yaml: -------------------------------------------------------------------------------- 1 | # Basic 2 | name: SDCodec 3 | latent_dim: 1024 4 | tracks: ['speech', 'music', 'sfx'] 5 | enc_params: 6 | name: DACEncoder 7 | d_model: 64 8 | strides: [2, 4, 5, 8] 9 | dec_params: 10 | name: DACDecoder 11 | d_model: 1536 12 | strides: [8, 5, 4, 2] 13 | quant_params: 14 | name: MultiSourceRVQ 15 | n_codebooks: [12, 12, 12] 16 | codebook_size: [1024, 1024, 1024] 17 | codebook_dim: [8, 8, 8] 18 | quantizer_dropout: 0.0 19 | code_jit_prob: [0.0, 0.0, 0.0] 20 | code_jit_size: [3, 5, 3] 21 | shared_codebooks: 0 22 | pretrain: {} 23 | 24 | -------------------------------------------------------------------------------- /config/model/sdcodec_16k_pretrainDAC.yaml: -------------------------------------------------------------------------------- 1 | # Basic 2 | name: SDCodec 3 | latent_dim: 1024 4 | tracks: ['speech', 'music', 'sfx'] 5 | enc_params: 6 | name: DACEncoder 7 | d_model: 64 8 | strides: [2, 4, 5, 8] 9 | dec_params: 10 | name: DACDecoder 11 | d_model: 1536 12 | strides: [8, 5, 4, 2] 13 | quant_params: 14 | name: MultiSourceRVQ 15 | n_codebooks: [12, 12, 12] 16 | codebook_size: [1024, 1024, 1024] 17 | codebook_dim: [8, 8, 8] 18 | quantizer_dropout: 0.0 19 | code_jit_prob: [0.0, 0.0, 0.0] 20 | code_jit_size: [3, 5, 3] 21 | shared_codebooks: 1 22 | pretrain: 23 | load_pretrained: './pretrained/weights_16khz_8kbps_0.0.5_new.pth' 24 | ignore_args: ['n_codebooks', 'codebook_size', 'codebook_dim'] 25 | ignore_modules: [] 26 | freeze_modules: [] 27 | 28 | -------------------------------------------------------------------------------- /config/model/sdcodec_16k_shard4.yaml: -------------------------------------------------------------------------------- 1 | # Basic 2 | name: SDCodec 3 | latent_dim: 1024 4 | tracks: ['speech', 'music', 'sfx'] 5 | enc_params: 6 | name: DACEncoder 7 | d_model: 64 8 | strides: [2, 4, 5, 8] 9 | dec_params: 10 | name: DACDecoder 11 | d_model: 1536 12 | strides: [8, 5, 4, 2] 13 | quant_params: 14 | name: MultiSourceRVQ 15 | n_codebooks: [12, 12, 12] 16 | codebook_size: [1024, 1024, 1024] 17 | codebook_dim: [8, 8, 8] 18 | quantizer_dropout: 0.0 19 | code_jit_prob: [0.0, 0.0, 0.0] 20 | code_jit_size: [3, 5, 3] 21 | shared_codebooks: 4 22 | pretrain: {} 23 | 24 | -------------------------------------------------------------------------------- /config/model/sdcodec_16k_shard8.yaml: -------------------------------------------------------------------------------- 1 | # Basic 2 | name: SDCodec 3 | latent_dim: 1024 4 | tracks: ['speech', 'music', 'sfx'] 5 | enc_params: 6 | name: DACEncoder 7 | d_model: 64 8 | strides: [2, 4, 5, 8] 9 | dec_params: 10 | name: DACDecoder 11 | d_model: 1536 12 | strides: [8, 5, 4, 2] 13 | quant_params: 14 | name: MultiSourceRVQ 15 | n_codebooks: [12, 12, 12] 16 | codebook_size: [1024, 1024, 1024] 17 | codebook_dim: [8, 8, 8] 18 | quantizer_dropout: 0.0 19 | code_jit_prob: [0.0, 0.0, 0.0] 20 | code_jit_size: [3, 5, 3] 21 | shared_codebooks: 8 22 
| pretrain: {} 23 | 24 | -------------------------------------------------------------------------------- /config/run_config/slurm_1.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | job: 4 | config: 5 | override_dirname: 6 | kv_sep: '_' 7 | item_sep: '/' 8 | exclude_keys: 9 | - config_name 10 | - run_config 11 | run: 12 | # dir: output/debug # job's output dir 13 | dir: output/${hydra.job.config_name}/${hydra.job.override_dirname} # job's output dir 14 | # dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} 15 | job_logging: 16 | formatters: 17 | simple: 18 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 19 | datefmt: '%Y-%m-%d %H:%M:%S' -------------------------------------------------------------------------------- /config/run_config/slurm_debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | job: 4 | config: 5 | override_dirname: 6 | kv_sep: ':' 7 | item_sep: '/' 8 | exclude_keys: 9 | - run_config 10 | - distributed_training.distributed_port 11 | - distributed_training.distributed_world_size 12 | - model.pretrained_model_path 13 | - model.target_network_path 14 | - next_script 15 | - task.cache_in_scratch 16 | - task.data 17 | - checkpoint.save_interval_updates 18 | - checkpoint.keep_interval_updates 19 | - checkpoint.save_on_overflow 20 | - common.log_interval 21 | - common.user_dir 22 | run: 23 | dir: output/debug # job's output dir 24 | # dir: output/${hydra.job.config_name}/${hydra.job.override_dirname} # job's output dir 25 | # dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname} 26 | job_logging: 27 | formatters: 28 | simple: 29 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 30 | datefmt: '%Y-%m-%d %H:%M:%S' -------------------------------------------------------------------------------- /config/training/bs2_iter5k.yaml: -------------------------------------------------------------------------------- 1 | 2 | total_steps: &total 5000 3 | warmup_steps: &warm 500 4 | print_steps: 50 5 | eval_steps: 500 6 | vis_steps: 500 7 | test_steps: 20000 8 | 9 | # total_steps: &total 200000 10 | # warmup_steps: &warm 10000 11 | # print_steps: 500 12 | # eval_steps: 5000 13 | # vis_steps: 10000 14 | # test_steps: 2000000 15 | 16 | early_stop: 50 # count for each eval 17 | grad_clip: 10.0 # no clip if < 0 18 | save_iters: [1000, 1500] 19 | vis_idx: [100, 200, 300] 20 | 21 | seed: 42 22 | 23 | # data transformation and augmentation 24 | transform: 25 | lufs_norm_db: 26 | speech: -17 27 | music: -24 28 | sfx: -21 29 | mix: -27 30 | var: 2 31 | peak_norm_db: -0.5 32 | random_num_sources: [0.2, 0.2, 0.6] 33 | random_swap_prob: 0.5 34 | 35 | # Optimizer 36 | optimizer: 37 | name: AdamW 38 | lr: 1e-4 39 | betas: [0.8, 0.99] 40 | # weight_decay: 1e-5 41 | 42 | # Scheduler 43 | scheduler: 44 | name: ExponentialLRScheduler 45 | total_steps: *total 46 | warmup_steps: *warm 47 | lr_min_ratio: 0.0 48 | gamma: 0.999 49 | # gamma: 0.999996 50 | 51 | # Loss 52 | loss: 53 | MultiScaleSTFTLoss: 54 | window_lengths: [2048, 512] 55 | MelSpectrogramLoss: 56 | n_mels: [5, 10, 20, 40, 80, 160, 320] 57 | window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 58 | mel_fmin: [0, 0, 0, 0, 0, 0, 0] 59 | mel_fmax: [null, null, null, null, null, null, null] 60 | pow: 1.0 61 | clamp_eps: 
1.0e-5 62 | mag_weight: 0.0 63 | lambdas: 64 | mel/loss: 15.0 65 | adv/feat_loss: 2.0 66 | adv/gen_loss: 1.0 67 | vq/commitment_loss: 0.25 68 | vq/codebook_loss: 1.0 69 | 70 | # Dataloader config 71 | dataloader: 72 | num_workers: 8 73 | train_bs: 2 74 | eval_bs: 32 75 | -------------------------------------------------------------------------------- /config/training/bs8_iter400k.yaml: -------------------------------------------------------------------------------- 1 | total_steps: &total 400000 2 | warmup_steps: &warm 10000 3 | print_steps: 500 4 | eval_steps: 5000 5 | vis_steps: 10000 6 | test_steps: 30000 7 | 8 | 9 | early_stop: 50 # count for each eval 10 | grad_clip: 10.0 # no clip if < 0 11 | save_iters: [50000, 100000, 150000, 200000, 250000, 300000, 350000] 12 | vis_idx: [100, 200, 300] 13 | 14 | seed: 42 15 | 16 | # data transformation and augmentation 17 | transform: 18 | lufs_norm_db: 19 | speech: -17 20 | music: -24 21 | sfx: -21 22 | mix: -27 23 | var: 2 24 | peak_norm_db: -0.5 25 | random_num_sources: [0.6, 0.2, 0.2] 26 | random_swap_prob: 0.5 27 | 28 | # Optimizer 29 | optimizer: 30 | name: AdamW 31 | lr: 1e-4 32 | betas: [0.8, 0.99] 33 | # weight_decay: 1e-5 34 | 35 | # Scheduler 36 | scheduler: 37 | name: ExponentialLRScheduler 38 | total_steps: *total 39 | warmup_steps: *warm 40 | lr_min_ratio: 0.0 41 | gamma: 0.999996 42 | 43 | # Loss 44 | loss: 45 | MultiScaleSTFTLoss: 46 | window_lengths: [2048, 512] 47 | MelSpectrogramLoss: 48 | n_mels: [5, 10, 20, 40, 80, 160, 320] 49 | window_lengths: [32, 64, 128, 256, 512, 1024, 2048] 50 | mel_fmin: [0, 0, 0, 0, 0, 0, 0] 51 | mel_fmax: [null, null, null, null, null, null, null] 52 | pow: 1.0 53 | clamp_eps: 1.0e-5 54 | mag_weight: 0.0 55 | lambdas: 56 | mel/loss: 15.0 57 | adv/feat_loss: 2.0 58 | adv/gen_loss: 1.0 59 | vq/commitment_loss: 0.25 60 | vq/codebook_loss: 1.0 61 | 62 | # Dataloader config 63 | dataloader: 64 | num_workers: 8 65 | train_bs: 8 66 | eval_bs: 32 67 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gen_audio 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - binutils_impl_linux-64=2.38=h2a08ee3_1 11 | - blas=1.0=mkl 12 | - brotli-python=1.0.9=py311h6a678d5_7 13 | - bzip2=1.0.8=h7b6447c_0 14 | - ca-certificates=2024.3.11=h06a4308_0 15 | - certifi=2024.2.2=py311h06a4308_0 16 | - cffi=1.16.0=py311h5eee18b_0 17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 18 | - cryptography=41.0.7=py311hdda0065_0 19 | - cuda-cudart=11.8.89=0 20 | - cuda-cupti=11.8.87=0 21 | - cuda-libraries=11.8.0=0 22 | - cuda-nvrtc=11.8.89=0 23 | - cuda-nvtx=11.8.86=0 24 | - cuda-runtime=11.8.0=0 25 | - ffmpeg=4.3=hf484d3e_0 26 | - filelock=3.13.1=py311h06a4308_0 27 | - freetype=2.12.1=h4a9f257_0 28 | - gcc=12.1.0=h9ea6d83_10 29 | - gcc_impl_linux-64=12.1.0=hea43390_17 30 | - giflib=5.2.1=h5eee18b_3 31 | - gmp=6.2.1=h295c915_3 32 | - gmpy2=2.1.2=py311hc9b5ff0_0 33 | - gnutls=3.6.15=he1e5248_0 34 | - gxx_impl_linux-64=12.1.0=hea43390_17 35 | - idna=3.4=py311h06a4308_0 36 | - intel-openmp=2023.1.0=hdb19cb5_46306 37 | - jinja2=3.1.2=py311h06a4308_0 38 | - jpeg=9e=h5eee18b_1 39 | - kernel-headers_linux-64=2.6.32=he073ed8_17 40 | - lame=3.100=h7b6447c_0 41 | - lcms2=2.12=h3be6417_0 42 | - ld_impl_linux-64=2.38=h1181459_1 43 | - lerc=3.0=h295c915_0 44 | - libcublas=11.11.3.6=0 45 | - 
libcufft=10.9.0.58=0 46 | - libcufile=1.8.1.2=0 47 | - libcurand=10.3.4.107=0 48 | - libcusolver=11.4.1.48=0 49 | - libcusparse=11.7.5.86=0 50 | - libdeflate=1.17=h5eee18b_1 51 | - libffi=3.4.4=h6a678d5_0 52 | - libgcc-devel_linux-64=12.1.0=h1ec3361_17 53 | - libgcc-ng=13.2.0=h807b86a_5 54 | - libgomp=13.2.0=h807b86a_5 55 | - libiconv=1.16=h7f8727e_2 56 | - libidn2=2.3.4=h5eee18b_0 57 | - libjpeg-turbo=2.0.0=h9bf148f_0 58 | - libnpp=11.8.0.86=0 59 | - libnvjpeg=11.9.0.86=0 60 | - libpng=1.6.39=h5eee18b_0 61 | - libsanitizer=12.1.0=ha89aaad_17 62 | - libstdcxx-devel_linux-64=12.1.0=h1ec3361_17 63 | - libstdcxx-ng=13.2.0=h7e041cc_5 64 | - libtasn1=4.19.0=h5eee18b_0 65 | - libtiff=4.5.1=h6a678d5_0 66 | - libunistring=0.9.10=h27cfd23_0 67 | - libuuid=1.41.5=h5eee18b_0 68 | - libwebp=1.3.2=h11a3e52_0 69 | - libwebp-base=1.3.2=h5eee18b_0 70 | - llvm-openmp=14.0.6=h9e868ea_0 71 | - lz4-c=1.9.4=h6a678d5_0 72 | - markupsafe=2.1.3=py311h5eee18b_0 73 | - mkl=2023.1.0=h213fc3f_46344 74 | - mkl-service=2.4.0=py311h5eee18b_1 75 | - mkl_fft=1.3.8=py311h5eee18b_0 76 | - mkl_random=1.2.4=py311hdb19cb5_0 77 | - mpc=1.1.0=h10f8cd9_1 78 | - mpfr=4.0.2=hb69a4c5_1 79 | - mpmath=1.3.0=py311h06a4308_0 80 | - ncurses=6.4=h6a678d5_0 81 | - nettle=3.7.3=hbbd107a_1 82 | - networkx=3.1=py311h06a4308_0 83 | - openh264=2.1.1=h4ff587b_0 84 | - openjpeg=2.4.0=h3ad879b_0 85 | - openssl=3.2.1=hd590300_1 86 | - pillow=10.0.1=py311ha6cbd5a_0 87 | - pycparser=2.21=pyhd3eb1b0_0 88 | - pyopenssl=23.2.0=py311h06a4308_0 89 | - pysocks=1.7.1=py311h06a4308_0 90 | - python=3.11.7=h955ad1f_0 91 | - pytorch=2.1.2=py3.11_cuda11.8_cudnn8.7.0_0 92 | - pytorch-cuda=11.8=h7e8668a_5 93 | - pytorch-mutex=1.0=cuda 94 | - pyyaml=6.0.1=py311h5eee18b_0 95 | - readline=8.2=h5eee18b_0 96 | - requests=2.31.0=py311h06a4308_0 97 | - sqlite=3.41.2=h5eee18b_0 98 | - sympy=1.12=py311h06a4308_0 99 | - sysroot_linux-64=2.12=he073ed8_17 100 | - tbb=2021.8.0=hdb19cb5_0 101 | - tk=8.6.12=h1ccaba5_0 102 | - torchaudio=2.1.2=py311_cu118 103 | - torchtriton=2.1.0=py311 104 | - torchvision=0.16.2=py311_cu118 105 | - typing_extensions=4.7.1=py311h06a4308_0 106 | - urllib3=1.26.18=py311h06a4308_0 107 | - wheel=0.41.2=py311h06a4308_0 108 | - xz=5.4.5=h5eee18b_0 109 | - yaml=0.2.5=h7b6447c_0 110 | - zlib=1.2.13=h5eee18b_0 111 | - zstd=1.5.5=hc292b87_0 112 | - pip: 113 | - absl-py==2.0.0 114 | - accelerate==0.28.0 115 | - aiohttp==3.9.3 116 | - aiosignal==1.3.1 117 | - antlr4-python3-runtime==4.9.3 118 | - appdirs==1.4.4 119 | - argbind==0.3.7 120 | - argparse==1.4.0 121 | - asttokens==2.4.1 122 | - attrs==23.2.0 123 | - audioread==3.0.1 124 | - braceexpand==0.1.7 125 | - cachetools==5.3.2 126 | - click==8.1.7 127 | - contourpy==1.2.0 128 | - cycler==0.12.1 129 | - datasets==2.18.0 130 | - decorator==5.1.1 131 | - descript-audio-codec==1.0.0 132 | - descript-audiotools==0.7.2 133 | - diffusers==0.27.2 134 | - dill==0.3.8 135 | - docker-pycreds==0.4.0 136 | - docstring-parser==0.16 137 | - einops==0.7.0 138 | - executing==2.0.1 139 | - ffmpy==0.3.2 140 | - fire==0.6.0 141 | - flatten-dict==0.4.2 142 | - fonttools==4.47.0 143 | - frozenlist==1.4.1 144 | - fsspec==2023.12.2 145 | - ftfy==6.2.0 146 | - future==0.18.3 147 | - gitdb==4.0.11 148 | - gitpython==3.1.43 149 | - google-auth==2.26.1 150 | - google-auth-oauthlib==1.2.0 151 | - grpcio==1.60.0 152 | - h5py==3.11.0 153 | - huggingface-hub==0.20.2 154 | - hydra-core==1.3.2 155 | - importlib-metadata==7.0.2 156 | - importlib-resources==6.4.0 157 | - ipython==8.23.0 158 | - jedi==0.19.1 159 | - joblib==1.3.2 160 | - 
julius==0.2.7 161 | - kiwisolver==1.4.5 162 | - laion-clap==1.1.4 163 | - lazy-loader==0.3 164 | - librosa==0.10.1 165 | - llvmlite==0.41.1 166 | - markdown==3.5.1 167 | - markdown-it-py==3.0.0 168 | - markdown2==2.4.13 169 | - matplotlib==3.8.3 170 | - matplotlib-inline==0.1.6 171 | - mdurl==0.1.2 172 | - msgpack==1.0.7 173 | - multidict==6.0.5 174 | - multiprocess==0.70.16 175 | - numba==0.58.1 176 | - numpy==1.23.5 177 | - oauthlib==3.2.2 178 | - omegaconf==2.3.0 179 | - packaging==23.2 180 | - pandas==2.1.4 181 | - parso==0.8.3 182 | - pesq==0.0.4 183 | - pexpect==4.9.0 184 | - pip==24.0 185 | - platformdirs==4.1.0 186 | - pooch==1.8.0 187 | - progressbar==2.5 188 | - prompt-toolkit==3.0.43 189 | - protobuf==3.19.6 190 | - psutil==5.9.7 191 | - ptyprocess==0.7.0 192 | - pure-eval==0.2.2 193 | - pyarrow==15.0.1 194 | - pyarrow-hotfix==0.6 195 | - pyasn1==0.5.1 196 | - pyasn1-modules==0.3.0 197 | - pygments==2.17.2 198 | - pyloudnorm==0.1.1 199 | - pyparsing==3.1.1 200 | - pysndfx==0.3.6 201 | - pystoi==0.4.1 202 | - python-dateutil==2.8.2 203 | - pytz==2023.3.post1 204 | - randomname==0.2.1 205 | - regex==2023.12.25 206 | - requests-oauthlib==1.3.1 207 | - resampy==0.4.3 208 | - rich==13.7.1 209 | - rsa==4.9 210 | - safetensors==0.4.1 211 | - scikit-learn==1.3.2 212 | - scipy==1.12.0 213 | - sentry-sdk==2.0.1 214 | - setproctitle==1.3.3 215 | - setuptools==69.0.3 216 | - six==1.16.0 217 | - smmap==5.0.1 218 | - soundfile==0.12.1 219 | - soxr==0.3.7 220 | - stack-data==0.6.3 221 | - tensorboard==2.16.2 222 | - tensorboard-data-server==0.7.2 223 | - termcolor==2.4.0 224 | - threadpoolctl==3.2.0 225 | - tokenizers==0.13.3 226 | - torch-stoi==0.2.1 227 | - torchlibrosa==0.1.0 228 | - tqdm==4.66.1 229 | - traitlets==5.14.2 230 | - transformers==4.30.0 231 | - tzdata==2023.4 232 | - visqol==3.3.3 233 | - wandb==0.16.6 234 | - wcwidth==0.2.13 235 | - webdataset==0.2.86 236 | - werkzeug==3.0.1 237 | - wget==3.2 238 | - xxhash==3.4.1 239 | - yarl==1.9.4 240 | - zipp==3.18.1 241 | prefix: /home/xbie/anaconda3/envs/gen_audio 242 | -------------------------------------------------------------------------------- /eval_dnr.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import json 4 | import argparse 5 | import importlib 6 | import math 7 | import julius 8 | import pandas as pd 9 | from pathlib import Path 10 | from omegaconf import OmegaConf 11 | from tqdm import tqdm 12 | import numpy as np 13 | from collections import namedtuple 14 | 15 | from src import utils 16 | from src.metrics import ( 17 | VisqolMetric, 18 | SingleSrcNegSDR, 19 | MultiScaleSTFTLoss, 20 | MelSpectrogramLoss, 21 | ) 22 | 23 | import torch 24 | import torchaudio 25 | from accelerate import Accelerator 26 | accelerator = Accelerator() 27 | 28 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset', 29 | add_help=True, 30 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | 32 | parser.add_argument('--ret-dir', type=str, default='output/debug', help='Training result directory') 33 | parser.add_argument('--csv-path', type=str, default='./manifest/test.csv', help='csv file to test') 34 | parser.add_argument('--data-sr', type=int, default=[44100], nargs='+', help='list of sampling rate in test files') 35 | parser.add_argument('--length', type=int, default=10, help='audio length') 36 | parser.add_argument('--visqol-mode', type=str, default='speech', choices=['speech', 'audio'], help='visqol mode') 37 | 
parser.add_argument('--threshold', type=float, default=0.4, help='threshold of silence part to drop audio') 38 | parser.add_argument('--fast', action='store_true', help='fast eval, disable visqol computation') 39 | 40 | # parse 41 | args = parser.parse_args() 42 | ret_dir = Path(args.ret_dir) 43 | csv_path = Path(args.csv_path) 44 | length = args.length 45 | visqol_mode = args.visqol_mode 46 | threshold = args.threshold 47 | use_visqol = not args.fast 48 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 49 | 50 | # read config 51 | cfg_filepath = ret_dir / '.hydra' / 'config.yaml' 52 | cfg = OmegaConf.load(cfg_filepath) 53 | sample_rate = cfg.sampling_rate 54 | chunk_len = sample_rate * length 55 | 56 | # init julius resample 57 | resample_pool = dict() 58 | for sr in args.data_sr: 59 | old_sr = sr 60 | new_sr = sample_rate 61 | gcd = math.gcd(old_sr, new_sr) 62 | old_sr = old_sr // gcd 63 | new_sr = new_sr // gcd 64 | resample_pool[sr] = julius.ResampleFrac(old_sr=old_sr, new_sr=new_sr) 65 | 66 | # import lib 67 | model_name = cfg.model.pop('name') 68 | module_path = str(ret_dir / 'backup_src' / 'models').replace('/', '.') 69 | try: 70 | load_model = importlib.import_module(module_path) 71 | net_class = getattr(load_model, f'{model_name}') 72 | print('Load model from ckpt') 73 | except: 74 | from src import models 75 | net_class = getattr(models, f'{model_name}') 76 | print('Load model from source code') 77 | 78 | # load model and weigth 79 | model_cfg = cfg.model 80 | model = net_class(sample_rate=sample_rate, **model_cfg) 81 | total_params = sum(p.numel() for p in model.parameters()) / 1e6 82 | print(f'Total params: {total_params:.2f} Mb') 83 | print('Model sampling rate: {} Hz'.format(model.sample_rate)) 84 | 85 | ckpt_finalpath = ret_dir / 'ckpt_final' / 'ckpt_model_final.pth' 86 | state_dict = torch.load(ckpt_finalpath, map_location=torch.device('cpu')) 87 | model.load_state_dict(state_dict) 88 | model = model.to(device) 89 | model.eval() 90 | print(f'ckpt path: {ckpt_finalpath}') 91 | print(f'Model weights load successfully...') 92 | 93 | # prepare metrics 94 | loss_cfg = cfg.training.loss 95 | metric_stft = MultiScaleSTFTLoss(**loss_cfg.MultiScaleSTFTLoss) 96 | metric_mel = MelSpectrogramLoss(**loss_cfg.MelSpectrogramLoss) 97 | metric_sisdr = SingleSrcNegSDR(sdr_type='sisdr') 98 | metric_visqol = VisqolMetric(mode=visqol_mode) 99 | 100 | # prepare data transform 101 | transform_cfg = cfg.training.transform 102 | volume_norm = utils.VolumeNorm(sample_rate=sample_rate) 103 | def _data_transform(batch, transform_cfg, valid_tracks=['speech'], norm_var=0): 104 | peak_norm = utils.db_to_gain(transform_cfg.peak_norm_db) 105 | mix_max_peak = torch.zeros_like(batch['speech'])[...,:1] # (bs, C, 1) 106 | 107 | # volume norm for each track 108 | for track in valid_tracks: 109 | batch[track] = volume_norm(signal=batch[track], 110 | target_loudness=transform_cfg.lufs_norm_db[track], 111 | var=norm_var) 112 | # peak value 113 | peak = batch[track].abs().max(dim=-1, keepdims=True)[0] 114 | mix_max_peak = torch.maximum(peak, mix_max_peak) 115 | 116 | # peak norm 117 | peak_gain = torch.ones_like(mix_max_peak) # (bs, C, 1) 118 | peak_gain[mix_max_peak > peak_norm] = peak_norm / mix_max_peak[mix_max_peak > peak_norm] 119 | 120 | # build mix 121 | batch['mix'] = torch.zeros_like(batch['speech']) 122 | for track in valid_tracks: 123 | batch[track] *= peak_gain 124 | batch['mix'] += batch[track] 125 | 126 | # mix volum norm 127 | batch['mix'], mix_gain = volume_norm(signal=batch['mix'], 128 | 
target_loudness=transform_cfg.lufs_norm_db['mix'], 129 | var=norm_var, 130 | return_gain=True) 131 | 132 | # norm each track 133 | for track in valid_tracks: 134 | batch[track] *= mix_gain[:, None, None] 135 | 136 | batch['valid_tracks'] = valid_tracks 137 | batch['random_swap'] = False 138 | 139 | return batch 140 | 141 | 142 | # define mask separation 143 | sep_norm = utils.WavSepMagNorm() 144 | 145 | # define STFT params 146 | STFTParams = namedtuple( 147 | "STFTParams", 148 | ["window_length", "hop_length", "window_type", "padding_type"], 149 | ) 150 | stft_params = STFTParams( 151 | window_length=1024, 152 | hop_length=256, 153 | window_type="hann", 154 | padding_type="reflect", 155 | ) 156 | 157 | # run eval 158 | tracks = model.tracks 159 | print('Model tracks: {}'.format(tracks)) 160 | test_tracks = ['mix'] + [f'{t}_rec' for t in tracks] + [f'{t}_sep' for t in tracks] + [f'{t}_sep_mask' for t in tracks] 161 | test_results = {t: {} for t in test_tracks} 162 | metadata = pd.read_csv(csv_path) 163 | 164 | for i in tqdm(range(len(metadata)), desc='Eval'): 165 | # for i in tqdm(range(20), desc='Eval'): 166 | wav_info = metadata.iloc[i] 167 | audio_id = wav_info['id'] 168 | start = wav_info['start'] 169 | end = wav_info['end'] 170 | batch = {} 171 | # read data 172 | for t in tracks: 173 | x, sr = torchaudio.load(wav_info[t]) 174 | x = x.mean(dim=0)[..., start: end] 175 | if sr != sample_rate: 176 | x = resample_pool[sr](x) 177 | batch[t] = x 178 | audio_len = x.shape[-1] 179 | 180 | # clip audio 181 | for j, k in enumerate(range(0, audio_len-chunk_len+1, chunk_len)): 182 | clip_id = f'{audio_id}_{j}' 183 | eval_batch = {} 184 | mask = {} 185 | for t in tracks: 186 | audio_clip = batch[t][k:k+chunk_len] 187 | 188 | # silent audio detection 189 | audio_energy = torch.stft(audio_clip, n_fft=stft_params.window_length, hop_length=stft_params.hop_length, 190 | win_length=stft_params.window_length, 191 | window=torch.hann_window(stft_params.window_length, device='cpu'), 192 | pad_mode=stft_params.padding_type, center=True, onesided=True, return_complex=True).abs().sum(dim=0) 193 | count = sum(1 for item in audio_energy if item > 1e-6) 194 | silence_detect = count < threshold * len(audio_energy) 195 | mask[f'{t}_rec'] = silence_detect 196 | mask[f'{t}_sep'] = silence_detect 197 | mask[f'{t}_sep_mask'] = silence_detect 198 | 199 | eval_batch[t] = audio_clip.reshape(1,1,-1).to(device) 200 | 201 | mask['mix'] = all(mask.values()) 202 | 203 | # data transform 204 | # eval_batch = _data_transform(eval_batch, transform_cfg=transform_cfg, valid_tracks=tracks, norm_var=0) 205 | eval_batch['mix'] = eval_batch['speech']+eval_batch['music']+eval_batch['sfx'] 206 | eval_batch['valid_tracks'] = tracks 207 | eval_batch['random_swap'] = False 208 | 209 | # mixture forward 210 | with torch.no_grad(): 211 | output_audio = model.evaluate(input_audio=eval_batch['mix'], 212 | output_tracks=['mix']+tracks) 213 | # eval_batch = model(eval_batch) 214 | # output_audio = eval_batch['recon'][:,:,0] 215 | 216 | # Eval mix reconstruction 217 | est = output_audio[:, 0].unsqueeze(1) 218 | ref = eval_batch['mix'] 219 | test_results['mix'][clip_id] = {} 220 | if mask['mix']: 221 | test_results['mix'][clip_id]['stft'] = None 222 | test_results['mix'][clip_id]['mel'] = None 223 | test_results['mix'][clip_id]['sisdr'] = None 224 | if use_visqol: 225 | test_results['mix'][clip_id]['visqol'] = None 226 | else: 227 | test_results['mix'][clip_id]['stft'] = metric_stft(est=est, ref=ref).item() 228 | 
test_results['mix'][clip_id]['mel'] = metric_mel(est=est, ref=ref).item() 229 | test_results['mix'][clip_id]['sisdr'] = - metric_sisdr(est=est, ref=ref).item() 230 | if use_visqol: 231 | test_results['mix'][clip_id]['visqol'] = metric_visqol(est=est, ref=ref, sr=sample_rate) 232 | 233 | # Eval separation using synthesizer (decoder) 234 | for p, t in enumerate(tracks): 235 | est = output_audio[:,p+1].unsqueeze(1) 236 | ref = eval_batch[t] 237 | test_results[f'{t}_sep'][clip_id] = {} 238 | if mask[f'{t}_sep']: 239 | test_results[f'{t}_sep'][clip_id]['stft'] = None 240 | test_results[f'{t}_sep'][clip_id]['mel'] = None 241 | test_results[f'{t}_sep'][clip_id]['sisdr'] = None 242 | if use_visqol: 243 | test_results[f'{t}_sep'][clip_id]['visqol'] = None 244 | else: 245 | test_results[f'{t}_sep'][clip_id]['stft'] = metric_stft(est=est, ref=ref).item() 246 | test_results[f'{t}_sep'][clip_id]['mel'] = metric_mel(est=est, ref=ref).item() 247 | test_results[f'{t}_sep'][clip_id]['sisdr'] = - metric_sisdr(est=est, ref=ref).item() 248 | if use_visqol: 249 | test_results[f'{t}_sep'][clip_id]['visqol'] = metric_visqol(est=est, ref=ref, sr=sample_rate) 250 | 251 | # Eval separation using mask 252 | mix = eval_batch['mix'].unsqueeze(2) 253 | signal_sep = output_audio[:,1:].unsqueeze(2) 254 | all_sep_mask_norm = sep_norm(mix, signal_sep) 255 | for p, t in enumerate(tracks): 256 | est = all_sep_mask_norm[:,p] 257 | ref = eval_batch[t] 258 | ref = ref[...,:est.shape[-1]] # stft + istft, shorter 259 | # breakpoint() 260 | test_results[f'{t}_sep_mask'][clip_id] = {} 261 | if mask[f'{t}_sep_mask']: 262 | test_results[f'{t}_sep_mask'][clip_id]['stft'] = None 263 | test_results[f'{t}_sep_mask'][clip_id]['mel'] = None 264 | test_results[f'{t}_sep_mask'][clip_id]['sisdr'] = None 265 | if use_visqol: 266 | test_results[f'{t}_sep_mask'][clip_id]['visqol'] = None 267 | else: 268 | test_results[f'{t}_sep_mask'][clip_id]['stft'] = metric_stft(est=est, ref=ref).item() 269 | test_results[f'{t}_sep_mask'][clip_id]['mel'] = metric_mel(est=est, ref=ref).item() 270 | test_results[f'{t}_sep_mask'][clip_id]['sisdr'] = - metric_sisdr(est=est, ref=ref).item() 271 | if use_visqol: 272 | test_results[f'{t}_sep_mask'][clip_id]['visqol'] = metric_visqol(est=est, ref=ref, sr=sample_rate) 273 | 274 | # Evaluate reconstruction of single track 275 | for p, t in enumerate(tracks): 276 | # single track forward 277 | with torch.no_grad(): 278 | output_audio = model.evaluate(input_audio=eval_batch[t], 279 | output_tracks=[t]) 280 | est = output_audio 281 | ref = eval_batch[t] 282 | test_results[f'{t}_rec'][clip_id] = {} 283 | if mask[f'{t}_rec']: 284 | test_results[f'{t}_rec'][clip_id]['stft'] = None 285 | test_results[f'{t}_rec'][clip_id]['mel'] = None 286 | test_results[f'{t}_rec'][clip_id]['sisdr'] = None 287 | if use_visqol: 288 | test_results[f'{t}_rec'][clip_id]['visqol'] = None 289 | else: 290 | test_results[f'{t}_rec'][clip_id]['stft'] = metric_stft(est=est, ref=ref).item() 291 | test_results[f'{t}_rec'][clip_id]['mel'] = metric_mel(est=est, ref=ref).item() 292 | test_results[f'{t}_rec'][clip_id]['sisdr'] = - metric_sisdr(est=est, ref=ref).item() 293 | if use_visqol: 294 | test_results[f'{t}_rec'][clip_id]['visqol'] = metric_visqol(est=est, ref=ref, sr=sample_rate) 295 | 296 | 297 | test_results['summary'] = {} 298 | for track in test_tracks: 299 | test_results['summary'][track] = {} 300 | list_stft = [] 301 | list_mel = [] 302 | list_sisdr = [] 303 | if use_visqol: 304 | list_visqol = [] 305 | 306 | for metrics in
test_results[track].values(): 307 | list_stft.append(metrics['stft']) 308 | list_mel.append(metrics['mel']) 309 | list_sisdr.append(metrics['sisdr']) 310 | if use_visqol: 311 | list_visqol.append(metrics['visqol']) 312 | 313 | np_stft = np.array([x for x in list_stft if x is not None]) 314 | np_mel = np.array([x for x in list_mel if x is not None]) 315 | np_sisdr = np.array([x for x in list_sisdr if x is not None]) 316 | if use_visqol: 317 | np_visqol = np.array([x for x in list_visqol if x is not None]) 318 | 319 | stft_m, stft_std = np.mean(np_stft), np.std(np_stft) 320 | mel_m, mel_std = np.mean(np_mel), np.std(np_mel) 321 | sisdr_m, sisdr_std = np.mean(np_sisdr), np.std(np_sisdr) 322 | if use_visqol: 323 | visqol_m, visqol_std = np.mean(np_visqol), np.std(np_visqol) 324 | 325 | print('='*80) 326 | print(f'{track}') 327 | print('Valid datapoints: {}/{}'.format(len(np_stft), len(list_stft))) 328 | print('Distance STFT: {:.2f} +/- {:.2f}'.format(stft_m, stft_std)) 329 | print('Distance Mel: {:.2f} +/- {:.2f}'.format(mel_m, mel_std)) 330 | print('SI-SDR: {:.2f} +/- {:.2f}'.format(sisdr_m, sisdr_std)) 331 | if use_visqol: 332 | print('VisQOL: {:.2f} +/- {:.2f}'.format(visqol_m, visqol_std)) 333 | 334 | test_results['summary'][track]['tot_seq'] = len(list_stft) 335 | test_results['summary'][track]['valid_seq'] = len(np_stft) 336 | test_results['summary'][track]['stft'] = {'mean': stft_m, 'std': stft_std} 337 | test_results['summary'][track]['mel'] = {'mean': mel_m, 'std': mel_std} 338 | test_results['summary'][track]['sisdr'] = {'mean': sisdr_m, 'std': sisdr_std} 339 | if use_visqol: 340 | test_results['summary'][track]['visqol'] = {'mean': visqol_m, 'std': visqol_std} 341 | 342 | 343 | # save to json 344 | json_filename = ret_dir / '{}_{}s.json'.format(csv_path.stem, length) 345 | with open(json_filename, 'w') as f: 346 | json.dump(test_results, f, indent=1) -------------------------------------------------------------------------------- /install_visqol.md: -------------------------------------------------------------------------------- 1 | ## 1. Install Bazel 2 | 3 | Here is an example; you may need to change the download link to the latest version and adapt the paths to your conda environment. 4 | 5 | ```bash 6 | # Linux for example 7 | 8 | # Check the architecture of your computer/cluster 9 | dpkg --print-architecture 10 | 11 | # Download the right version of the Bazelisk binary release, e.g. linux-amd64 12 | # check the latest version: https://github.com/bazelbuild/bazelisk/releases 13 | wget https://github.com/bazelbuild/bazelisk/releases/download/v1.19.0/bazelisk-linux-amd64 14 | 15 | # Rename and move to your conda env, e.g. torch2 16 | chmod +x bazelisk-linux-amd64 17 | mv bazelisk-linux-amd64 /home/ids/xbie/anaconda3/envs/torch2/bin/bazel 18 | 19 | # Now you can run bazel by activating the conda env 20 | ``` 21 | 22 | ## 2. Install VisQOL 23 | ```bash 24 | # Clone the VisQOL repo 25 | git clone https://github.com/google/visqol.git 26 | 27 | # Build the binary (requires GCC 12) 28 | cd visqol 29 | bazel build :visqol -c opt --jobs=10 30 | 31 | # Install Python API 32 | pip install . 33 | ``` 34 | 35 | ## 3. Potential Problems 36 | Here are some problems I have encountered during installation and how I solved them. 37 | ```bash 38 | # You might encounter a gcc/g++ compilation problem, even if you have successfully built visqol 39 | # e.g.
libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/xbie/anaconda3/envs/torch2/lib/python3.11/site-packages/visqol/visqol_lib_py.so) 40 | # in that case, make sure your gcc/g++ version is >= 12 41 | conda install conda-forge::gxx_impl_linux-64 42 | conda install -c conda-forge gcc=12.1.0 43 | 44 | # In case your Bazel is installed in the system path 45 | # then you should update gcc/g++ system-wide 46 | sudo apt install --reinstall gcc-12 47 | sudo apt install --reinstall g++-12 48 | sudo ln -s -f /usr/bin/gcc-12 /usr/bin/gcc 49 | sudo ln -s -f /usr/bin/g++-12 /usr/bin/g++ 50 | 51 | # In case importing `visqol_lib_py` fails with an unknown path 52 | # build it manually, make sure you are in the VisQOL root path 53 | # if on the cluster, use `--jobs=10` to limit the parallel jobs 54 | # then re-run the installation to link the .so library 55 | bazel build -c opt //python:visqol_lib_py.so --jobs=10 56 | pip install . 57 | 58 | # If you still have problems, 59 | # remove all Bazel cache files and re-run the installation 60 | rm -rf ~/.cache/bazel 61 | rm -rf ~/.cache/bazelisk 62 | ``` 63 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Copyright (c) 2024 by Telecom-Paris 5 | Authored by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr) 6 | License agreement in LICENSE.txt 7 | """ 8 | 9 | import os 10 | import sys 11 | import shutil 12 | import logging 13 | import hydra 14 | from omegaconf import DictConfig, OmegaConf 15 | 16 | import torch 17 | 18 | @hydra.main(version_base=None, config_path="config", config_name="default") 19 | def main(cfg: DictConfig) -> None: 20 | 21 | from src.trainer import Trainer 22 | trainer = Trainer(cfg) 23 | trainer.run() 24 | 25 | 26 | if __name__ == "__main__": 27 | main() -------------------------------------------------------------------------------- /prepare/mani_dnr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 by Telecom-Paris 3 | Authored by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr) 4 | License agreement in LICENSE.txt 5 | """ 6 | import os 7 | import argparse 8 | import librosa 9 | import torchaudio 10 | from tqdm import tqdm 11 | from pathlib import Path 12 | from collections import namedtuple 13 | import torch 14 | 15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset', 16 | add_help=True, 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | 19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/dnr_v2', help='Audio Dataset Path') 20 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest') 21 | parser.add_argument('--ext', type=str, default='wav', choices=['wav', 'mp3', 'flac'], help='Audio format') 22 | 23 | args = parser.parse_args() 24 | data_dir = Path(args.data_dir) 25 | out_dir = Path(args.out_dir) 26 | ext = args.ext 27 | 28 | train_dir = data_dir / 'tr' 29 | val_dir = data_dir / 'cv' 30 | test_dir = data_dir / 'tt' 31 | tracks = ['speech', 'music', 'sfx'] 32 | 33 | # define STFT params 34 | STFTParams = namedtuple( 35 | "STFTParams", 36 | ["window_length", "hop_length", "window_type", "padding_type"], 37 | ) 38 | stft_params = STFTParams( 39 | window_length=1024, 40 | hop_length=256, 41 | window_type="hann", 42 | padding_type="reflect", 43 | ) 44 | 45 | ## train 46 | for
track in tracks: 47 | with open(out_dir / f'{track}_dnr.csv', 'w') as f: 48 | f.write('id,filepath,sr,length,start,end\n') 49 | for subdir in tqdm(sorted(train_dir.iterdir()), desc=f'train_{track}'): 50 | if subdir.is_dir(): 51 | audio_id = subdir.name 52 | audio_filepath = subdir / f'{track}.{ext}' 53 | x, sr = torchaudio.load(audio_filepath) 54 | length = x.shape[-1] 55 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30) 56 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe) 57 | f.write(line) 58 | 59 | 60 | ## val with no silence 61 | length = 5 62 | threshold = 0.5 63 | with open(out_dir / 'val_dnr.csv', 'w') as f: 64 | f.write('id,mix,speech,music,sfx,sr,length,start,end\n') 65 | for subdir in tqdm(sorted(val_dir.iterdir()), desc='val'): 66 | if subdir.is_dir(): 67 | audio_id = subdir.name 68 | mix_filepath = subdir / f'mix.{ext}' 69 | speech_filepath = subdir / f'speech.{ext}' 70 | music_filepath = subdir / f'music.{ext}' 71 | sfx_filepath = subdir / f'sfx.{ext}' 72 | 73 | x_mix, sr = torchaudio.load(mix_filepath) 74 | x_speech, _ = torchaudio.load(speech_filepath) 75 | x_music, _ = torchaudio.load(music_filepath) 76 | x_sfx, _ = torchaudio.load(sfx_filepath) 77 | 78 | chunk_len = sr * length 79 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x_mix.numpy(), top_db=30) 80 | for j, k in enumerate(range(trim30dBs, trim30dBe-chunk_len, chunk_len)): 81 | chunk_mix = x_mix[..., k:k+chunk_len] 82 | chunk_speech = x_speech[..., k:k+chunk_len] 83 | chunk_music = x_music[..., k:k+chunk_len] 84 | chunk_sfx = x_sfx[..., k:k+chunk_len] 85 | 86 | # detect silence length 87 | is_active = True 88 | for audio_clip in [chunk_mix, chunk_speech, chunk_music, chunk_sfx]: 89 | audio_clip = audio_clip.reshape(-1) 90 | audio_energy = torch.stft(audio_clip, n_fft=stft_params.window_length, 91 | hop_length=stft_params.hop_length, win_length=stft_params.window_length, 92 | window=torch.hann_window(stft_params.window_length, device='cpu'), 93 | pad_mode=stft_params.padding_type, center=True, onesided=True, return_complex=True 94 | ).abs().sum(dim=0) 95 | count = sum(1 for item in audio_energy if item > 1e-6) 96 | if count < threshold * len(audio_energy): 97 | is_active = False 98 | 99 | # save if no silence detect 100 | if is_active: 101 | clip_id = f'{audio_id}_{j}' 102 | start = k 103 | end = k + chunk_len 104 | line = '{},{},{},{},{},{},{},{},{}\n'.format(clip_id, mix_filepath, speech_filepath, music_filepath, sfx_filepath, \ 105 | sr, chunk_len, start, end) 106 | f.write(line) 107 | 108 | ## test with no silence 109 | length = 10 110 | threshold = 0.5 111 | with open(out_dir / 'test_dnr.csv', 'w') as f: 112 | f.write('id,mix,speech,music,sfx,sr,length,start,end\n') 113 | for subdir in tqdm(sorted(test_dir.iterdir()), desc='test'): 114 | if subdir.is_dir(): 115 | audio_id = subdir.name 116 | mix_filepath = subdir / f'mix.{ext}' 117 | speech_filepath = subdir / f'speech.{ext}' 118 | music_filepath = subdir / f'music.{ext}' 119 | sfx_filepath = subdir / f'sfx.{ext}' 120 | 121 | x_mix, sr = torchaudio.load(mix_filepath) 122 | x_speech, _ = torchaudio.load(speech_filepath) 123 | x_music, _ = torchaudio.load(music_filepath) 124 | x_sfx, _ = torchaudio.load(sfx_filepath) 125 | 126 | chunk_len = sr * length 127 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x_mix.numpy(), top_db=30) 128 | for j, k in enumerate(range(trim30dBs, trim30dBe-chunk_len, chunk_len)): 129 | chunk_mix = x_mix[..., k:k+chunk_len] 130 | chunk_speech = x_speech[...,
k:k+chunk_len] 131 | chunk_music = x_music[..., k:k+chunk_len] 132 | chunk_sfx = x_sfx[..., k:k+chunk_len] 133 | 134 | # detect silence length 135 | is_active = True 136 | for audio_clip in [chunk_mix, chunk_speech, chunk_music, chunk_sfx]: 137 | audio_clip = audio_clip.reshape(-1) 138 | audio_energy = torch.stft(audio_clip, n_fft=stft_params.window_length, 139 | hop_length=stft_params.hop_length, win_length=stft_params.window_length, 140 | window=torch.hann_window(stft_params.window_length, device='cpu'), 141 | pad_mode=stft_params.padding_type, center=True, onesided=True, return_complex=True 142 | ).abs().sum(dim=0) 143 | count = sum(1 for item in audio_energy if item > 1e-6) 144 | if count < threshold * len(audio_energy): 145 | is_active = False 146 | 147 | # save if no silence detect 148 | if is_active: 149 | clip_id = f'{audio_id}_{j}' 150 | start = k 151 | end = k + chunk_len 152 | line = '{},{},{},{},{},{},{},{},{}\n'.format(clip_id, mix_filepath, speech_filepath, music_filepath, sfx_filepath, \ 153 | sr, chunk_len, start, end) 154 | f.write(line) -------------------------------------------------------------------------------- /prepare/mani_dns_clean.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 by Telecom-Paris 3 | Authoried by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr) 4 | License agreement in LICENSE.txt 5 | """ 6 | import os 7 | import argparse 8 | import librosa 9 | import torchaudio 10 | from tqdm import tqdm 11 | from pathlib import Path 12 | from collections import namedtuple 13 | import torch 14 | 15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset, DNS-challenge-5 clean speech', 16 | add_help=True, 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | 19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/DNS-Challenge/', help='Audio Dataset Path') 20 | parser.add_argument('--partition', type=str, default='all', help='Audio partition to create manifest') 21 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest') 22 | parser.add_argument('--threshold', type=float, default=0.5, help='Remove audio files that are too short') 23 | parser.add_argument('--ext', type=str, default='wav', choices=['wav', 'mp3', 'flac'], help='Audio format') 24 | 25 | args = parser.parse_args() 26 | data_dir = Path(args.data_dir) 27 | out_dir = Path(args.out_dir) 28 | partition = args.partition 29 | threshold = args.threshold 30 | ext = args.ext 31 | 32 | if partition == 'all': 33 | audio_sources = ['emotional_speech', 34 | 'read_speech', 35 | 'vctk_wav48_silence_trimmed', 36 | 'VocalSet_48kHz_mono', 37 | 'french_speech', 38 | 'german_speech', 39 | 'italian_speech', 40 | 'russian_speech', 41 | 'spanish_speech'] 42 | else: 43 | audio_sources = [partition] 44 | 45 | for audio_source in audio_sources: 46 | audio_dir = data_dir / 'datasets_fullband/clean_fullband' / audio_source 47 | audio_manifest = out_dir / f'speech_dns5_{audio_source}.csv' 48 | audio_len = 0 49 | with open(audio_manifest, 'w') as f: 50 | f.write('id,filepath,sr,length,start,end\n') 51 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'clean/{audio_source}'): 52 | audio_id = audio_filepath.stem 53 | x, sr = torchaudio.load(audio_filepath) 54 | length = x.shape[-1] 55 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30) 56 | utt_len = (trim30dBe - trim30dBs) / sr 57 | if utt_len >= threshold: 58 | audio_len += length / sr 
59 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe) 60 | f.write(line) 61 | print('Source: {}. audio len: {:.2f}h'.format(audio_source, audio_len/3600)) 62 | 63 | -------------------------------------------------------------------------------- /prepare/mani_dns_noise.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 by Telecom-Paris 3 | Authoried by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr) 4 | License agreement in LICENSE.txt 5 | """ 6 | import os 7 | import argparse 8 | import librosa 9 | import torchaudio 10 | from tqdm import tqdm 11 | from pathlib import Path 12 | from collections import namedtuple 13 | import torch 14 | 15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset, DNS-challenge-5 noise data', 16 | add_help=True, 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | 19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/DNS-Challenge/', help='Audio Dataset Path') 20 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest') 21 | parser.add_argument('--threshold', type=float, default=0.5, help='Remove audio files that are too short') 22 | parser.add_argument('--ext', type=str, default='wav', choices=['wav', 'mp3', 'flac'], help='Audio format') 23 | 24 | args = parser.parse_args() 25 | data_dir = Path(args.data_dir) 26 | out_dir = Path(args.out_dir) 27 | threshold = args.threshold 28 | ext = args.ext 29 | 30 | # noise 31 | audio_dir = data_dir / 'datasets_fullband/noise_fullband' 32 | audio_manifest = out_dir / f'sfx_dns5.csv' 33 | audio_len = 0 34 | with open(audio_manifest, 'w') as f: 35 | f.write('id,filepath,sr,length,start,end\n') 36 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'noise'): 37 | audio_id = audio_filepath.stem 38 | x, sr = torchaudio.load(audio_filepath) 39 | length = x.shape[-1] 40 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30) 41 | utt_len = (trim30dBe - trim30dBs) / sr 42 | if utt_len >= threshold: 43 | audio_len += length / sr 44 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe) 45 | f.write(line) 46 | print('Source: noise. 
audio len: {:.2f}h'.format(audio_len/3600)) -------------------------------------------------------------------------------- /prepare/mani_jamendo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 by Telecom-Paris 3 | Authoried by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr) 4 | License agreement in LICENSE.txt 5 | """ 6 | import os 7 | import argparse 8 | import librosa 9 | import torchaudio 10 | from tqdm import tqdm 11 | from pathlib import Path 12 | from collections import namedtuple 13 | import torch 14 | 15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset, MTG-Jamendo', 16 | add_help=True, 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | 19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/mtg-jamendo/', help='Audio Dataset Path') 20 | parser.add_argument('--partition', type=int, default=-1, help='Audio partition to create manifest') 21 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest') 22 | parser.add_argument('--threshold', type=float, default=0.5, help='Remove audio files that are too short') 23 | parser.add_argument('--ext', type=str, default='mp3', choices=['wav', 'mp3', 'flac'], help='Audio format') 24 | 25 | args = parser.parse_args() 26 | data_dir = Path(args.data_dir) 27 | out_dir = Path(args.out_dir) 28 | partition = args.partition 29 | threshold = args.threshold 30 | ext = args.ext 31 | 32 | if partition == -1: 33 | audio_sources = list(range(100)) 34 | audio_manifest = out_dir / f'music_jamendo.csv' 35 | else: 36 | audio_sources = list(range(10*partition, 10*partition+10)) 37 | audio_manifest = out_dir / f'music_jamendo_{partition}.csv' 38 | 39 | with open(audio_manifest, 'w') as f: 40 | f.write('id,filepath,sr,length,start,end\n') 41 | audio_len = 0 42 | for audio_source in audio_sources: 43 | audio_dir = data_dir / f"{audio_source:0>2}" 44 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f"jamendo_{audio_source:0>2}"): 45 | audio_id = audio_filepath.stem 46 | x, sr = torchaudio.load(audio_filepath) 47 | length = x.shape[-1] 48 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30) 49 | utt_len = (trim30dBe - trim30dBs) / sr 50 | if utt_len >= threshold: 51 | audio_len += length / sr 52 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe) 53 | f.write(line) 54 | print('Source: music. 
audio len: {:.2f}h'.format(audio_len/3600)) 55 | 56 | -------------------------------------------------------------------------------- /prepare/mani_musan.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 by Telecom-Paris 3 | Authoried by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr) 4 | License agreement in LICENSE.txt 5 | """ 6 | import os 7 | import argparse 8 | import librosa 9 | import torchaudio 10 | from tqdm import tqdm 11 | from pathlib import Path 12 | from collections import namedtuple 13 | import torch 14 | 15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset, DNS-challenge-5', 16 | add_help=True, 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | 19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/Libri-light/', help='Audio Dataset Path') 20 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest') 21 | parser.add_argument('--threshold', type=float, default=0.5, help='Remove audio files that are too short') 22 | parser.add_argument('--ext', type=str, default='wav', choices=['wav', 'mp3', 'flac'], help='Audio format') 23 | args = parser.parse_args() 24 | data_dir = Path(args.data_dir) 25 | out_dir = Path(args.out_dir) 26 | threshold = args.threshold 27 | ext = args.ext 28 | 29 | # speech 30 | audio_dir = data_dir / 'speech' 31 | audio_manifest = out_dir / f'speech_musan.csv' 32 | audio_len = 0 33 | with open(audio_manifest, 'w') as f: 34 | f.write('id,filepath,sr,length,start,end\n') 35 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'speech'): 36 | audio_id = audio_filepath.stem 37 | x, sr = torchaudio.load(audio_filepath) 38 | length = x.shape[-1] 39 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30) 40 | utt_len = (trim30dBe - trim30dBs) / sr 41 | if utt_len >= threshold: 42 | audio_len += length / sr 43 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe) 44 | f.write(line) 45 | print('Source: noise. audio len: {:.2f}h'.format(audio_len/3600)) 46 | 47 | # music 48 | audio_dir = data_dir / 'music' 49 | audio_manifest = out_dir / f'music_musan.csv' 50 | audio_len = 0 51 | with open(audio_manifest, 'w') as f: 52 | f.write('id,filepath,sr,length,start,end\n') 53 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'music'): 54 | audio_id = audio_filepath.stem 55 | x, sr = torchaudio.load(audio_filepath) 56 | length = x.shape[-1] 57 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30) 58 | utt_len = (trim30dBe - trim30dBs) / sr 59 | if utt_len >= threshold: 60 | audio_len += length / sr 61 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe) 62 | f.write(line) 63 | print('Source: noise. 
audio len: {:.2f}h'.format(audio_len/3600)) 64 | 65 | # noise 66 | audio_dir = data_dir / 'noise' 67 | audio_manifest = out_dir / f'sfx_musan.csv' 68 | audio_len = 0 69 | with open(audio_manifest, 'w') as f: 70 | f.write('id,filepath,sr,length,start,end\n') 71 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'noise'): 72 | audio_id = audio_filepath.stem 73 | x, sr = torchaudio.load(audio_filepath) 74 | length = x.shape[-1] 75 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30) 76 | utt_len = (trim30dBe - trim30dBs) / sr 77 | if utt_len >= threshold: 78 | audio_len += length / sr 79 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe) 80 | f.write(line) 81 | print('Source: noise. audio len: {:.2f}h'.format(audio_len/3600)) -------------------------------------------------------------------------------- /prepare/mani_wham.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 by Telecom-Paris 3 | Authoried by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr) 4 | License agreement in LICENSE.txt 5 | """ 6 | import os 7 | import argparse 8 | import librosa 9 | import torchaudio 10 | from tqdm import tqdm 11 | from pathlib import Path 12 | from collections import namedtuple 13 | import torch 14 | 15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset, DNS-challenge-5', 16 | add_help=True, 17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | 19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/high_res_wham/', help='Audio Dataset Path') 20 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest') 21 | parser.add_argument('--threshold', type=float, default=0.5, help='Remove audio files that are too short') 22 | parser.add_argument('--ext', type=str, default='wav', choices=['wav', 'mp3', 'flac'], help='Audio format') 23 | 24 | args = parser.parse_args() 25 | data_dir = Path(args.data_dir) 26 | out_dir = Path(args.out_dir) 27 | threshold = args.threshold 28 | ext = args.ext 29 | 30 | # noise 31 | audio_dir = data_dir / 'audio' 32 | audio_manifest = out_dir / f'sfx_wham.csv' 33 | audio_len = 0 34 | with open(audio_manifest, 'w') as f: 35 | f.write('id,filepath,sr,length,start,end\n') 36 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'noise'): 37 | audio_id = audio_filepath.stem 38 | x, sr = torchaudio.load(audio_filepath) 39 | length = x.shape[-1] 40 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30) 41 | utt_len = (trim30dBe - trim30dBs) / sr 42 | if utt_len >= threshold: 43 | audio_len += length / sr 44 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe) 45 | f.write(line) 46 | print('Source: noise. 
audio len: {:.2f}h'.format(audio_len/3600)) -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # from .audio_dnr import DatasetDnR 2 | from .audio_dataset import DatasetAudioTrain, DatasetAudioVal -------------------------------------------------------------------------------- /src/datasets/audio_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 by Telecom-Paris 3 | Authoried by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr) 4 | License agreement in LICENSE.txt 5 | """ 6 | 7 | import math 8 | import numpy as np 9 | import pandas as pd 10 | import julius 11 | from pathlib import Path 12 | from omegaconf import DictConfig 13 | from typing import List 14 | 15 | import torch 16 | import torchaudio 17 | from torch.utils.data import Dataset 18 | from accelerate.logging import get_logger 19 | 20 | logger = get_logger(__name__) 21 | 22 | 23 | 24 | class DatasetAudioTrain(Dataset): 25 | def __init__(self, 26 | sample_rate: int, 27 | speech: List[str], 28 | music: List[str], 29 | sfx: List[str], 30 | n_examples: int = 10000000, 31 | chunk_size: float = 2.0, 32 | trim_silence: bool = False, 33 | **kwargs 34 | ) -> None: 35 | super().__init__() 36 | 37 | # init 38 | self.EPS = 1e-8 39 | self.sample_rate = sample_rate # target sampling rate 40 | self.length = n_examples # pseudo dataset length 41 | self.chunk_size = chunk_size # negative for entire sentence 42 | self.trim_silence = trim_silence 43 | 44 | # manifest 45 | self.csv_files = {} 46 | self.csv_files['speech'] = [Path(filepath) for filepath in speech] 47 | self.csv_files['music'] = [Path(filepath) for filepath in music] 48 | self.csv_files['sfx'] = [Path(filepath) for filepath in sfx] 49 | 50 | # check valid samples 51 | self.resample_pool = dict() 52 | self.metadata_dict = dict() 53 | self.lens_dict = dict() 54 | for track, files in self.csv_files.items(): 55 | logger.info(f"Track: {track}") 56 | orig_utt, orig_len, drop_utt, drop_len = 0, 0, 0, 0 57 | metadata_list = [] 58 | for tsv_filepath in files: 59 | if not tsv_filepath.is_file(): 60 | logger.error('No tsv file found in: {}'.format(tsv_filepath)) 61 | continue 62 | else: 63 | logger.info(f'Manifest filepath: {tsv_filepath}') 64 | metadata = pd.read_csv(tsv_filepath) 65 | if self.trim_silence: 66 | wav_lens = (metadata['end'] - metadata['start']) / metadata['sr'] 67 | else: 68 | wav_lens = metadata['length'] / metadata['sr'] 69 | # remove wav files that too short 70 | orig_utt += len(metadata) 71 | drop_rows = [] 72 | for row_idx in range(len(wav_lens)): 73 | orig_len += wav_lens[row_idx] 74 | if wav_lens[row_idx] < self.chunk_size: 75 | drop_rows.append(row_idx) 76 | drop_utt += 1 77 | drop_len += wav_lens[row_idx] 78 | else: 79 | # prepare julius resample 80 | sr = int(metadata.at[row_idx, 'sr']) 81 | if sr not in self.resample_pool.keys(): 82 | old_sr = sr 83 | new_sr = self.sample_rate 84 | gcd = math.gcd(old_sr, new_sr) 85 | old_sr = old_sr // gcd 86 | new_sr = new_sr // gcd 87 | self.resample_pool[sr] = julius.ResampleFrac(old_sr=old_sr, new_sr=new_sr) 88 | 89 | metadata = metadata.drop(drop_rows) 90 | metadata_list.append(metadata) 91 | 92 | self.metadata_dict[track] = pd.concat(metadata_list, axis=0) 93 | self.lens_dict[track] = len(self.metadata_dict[track]) 94 | 95 | logger.info("Drop {}/{} utterances ({:.2f}/{:.2f}h), shorter than {:.2f}s".format( 96 | drop_utt, 
orig_utt, drop_len / 3600, orig_len / 3600, self.chunk_size 97 | )) 98 | logger.info('Used data: {} utterances, ({:.2f} h)'.format( 99 | self.lens_dict[track], (orig_len-drop_len) / 3600 100 | )) 101 | 102 | logger.info('Resample pool: {}'.format(list(self.resample_pool.keys()))) 103 | 104 | 105 | def __len__(self): 106 | return self.length # can be any number 107 | 108 | 109 | def __getitem__(self, idx:int): 110 | 111 | batch = {} 112 | for track in self.csv_files.keys(): 113 | idx = np.random.randint(self.lens_dict[track]) 114 | wav_info = self.metadata_dict[track].iloc[idx] 115 | chunk_len = int(wav_info['sr'] * self.chunk_size) 116 | 117 | # slice wav files 118 | if self.trim_silence: 119 | start = np.random.randint(int(wav_info['start']), int(wav_info['end']) - chunk_len + 1) 120 | else: 121 | start = np.random.randint(0, int(wav_info['length']) - chunk_len + 1) 122 | 123 | # load file 124 | x, sr = torchaudio.load(wav_info['filepath'], 125 | frame_offset=start, 126 | num_frames=chunk_len) 127 | 128 | # single channel 129 | x = x.mean(dim=0, keepdim=True) 130 | 131 | # resample 132 | if sr != self.sample_rate: 133 | x = self.resample_pool[sr](x) 134 | 135 | batch[track] = x 136 | 137 | return batch 138 | 139 | 140 | 141 | class DatasetAudioVal(Dataset): 142 | def __init__(self, 143 | sample_rate: int, 144 | tsv_filepath: str, 145 | chunk_size: float = 5.0, 146 | **kwargs 147 | ) -> None: 148 | super().__init__() 149 | 150 | # init 151 | self.EPS = 1e-8 152 | self.sample_rate = sample_rate # target sampling rate 153 | self.tsv_filepath = Path(tsv_filepath) 154 | self.chunk_size = chunk_size # negative for entire sentence 155 | self.resample_pool = dict() 156 | 157 | # read manifest tsv file 158 | if self.tsv_filepath.is_file(): 159 | metadata = pd.read_csv(self.tsv_filepath) 160 | logger.info(f'Manifest filepath: {self.tsv_filepath}') 161 | else: 162 | logger.error('No tsv file found in: {}'.format(self.tsv_filepath)) 163 | 164 | # audio lengths 165 | wav_lens = (metadata['end'] - metadata['start']) / metadata['sr'] 166 | 167 | # remove wav files that too short 168 | orig_utt = len(metadata) 169 | orig_len, drop_utt, drop_len = 0, 0, 0 170 | drop_rows = [] 171 | for row_idx in range(len(wav_lens)): 172 | orig_len += wav_lens[row_idx] 173 | if wav_lens[row_idx] < self.chunk_size: 174 | drop_rows.append(row_idx) 175 | drop_utt += 1 176 | drop_len += wav_lens[row_idx] 177 | else: 178 | # prepare julius resample 179 | sr = int(metadata.at[row_idx, 'sr']) 180 | if sr not in self.resample_pool.keys(): 181 | old_sr = sr 182 | new_sr = self.sample_rate 183 | gcd = math.gcd(old_sr, new_sr) 184 | old_sr = old_sr // gcd 185 | new_sr = new_sr // gcd 186 | self.resample_pool[sr] = julius.ResampleFrac(old_sr=old_sr, new_sr=new_sr) 187 | 188 | logger.info("Drop {}/{} utts ({:.2f}/{:.2f}h), shorter than {:.2f}s".format( 189 | drop_utt, orig_utt, drop_len / 3600, orig_len / 3600, self.chunk_size 190 | )) 191 | logger.info('Actual data size: {} utterance, ({:.2f} h)'.format( 192 | orig_utt-drop_utt, (orig_len-drop_len) / 3600 193 | )) 194 | logger.info('Resample pool: {}'.format(list(self.resample_pool.keys()))) 195 | 196 | self.metadata = metadata.drop(drop_rows) 197 | 198 | 199 | def __len__(self): 200 | return len(self.metadata) 201 | 202 | def __getitem__(self, idx:int): 203 | 204 | wav_info = self.metadata.iloc[idx] 205 | chunk_len = int(wav_info['sr'] * self.chunk_size) 206 | start = wav_info['start'] 207 | 208 | # Load wav files and resample if needed 209 | batch = {} 210 | for track in ['mix', 
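The resample pool built above keys a `julius.ResampleFrac` module by source sample rate, after reducing the old/new rates by their greatest common divisor so the fractional resampler works with small integer factors. A minimal sketch of what one pool entry does, assuming a 44.1 kHz source and the 16 kHz target used in the configs:

```python
# Illustrative only: one entry of the resample pool (44.1 kHz -> 16 kHz).
import math
import torch
import julius

old_sr, new_sr = 44100, 16000
gcd = math.gcd(old_sr, new_sr)                                 # 100
resampler = julius.ResampleFrac(old_sr // gcd, new_sr // gcd)  # ResampleFrac(441, 160)
x = torch.randn(1, 44100)                                      # 1 s of audio at 44.1 kHz
y = resampler(x)
print(y.shape)                                                 # torch.Size([1, 16000])
```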
'speech', 'music', 'sfx']: 211 | x, sr = torchaudio.load(wav_info[track], 212 | frame_offset=start, 213 | num_frames=chunk_len) 214 | x = x.mean(dim=0, keepdim=True) 215 | # resample 216 | if sr != self.sample_rate: 217 | x = self.resample_pool[sr](x) 218 | batch[track] = x 219 | 220 | return batch 221 | 222 | 223 | 224 | if __name__ == '__main__': 225 | import time 226 | from tqdm import tqdm 227 | from torch.utils.data import DataLoader 228 | from accelerate import Accelerator 229 | from omegaconf import OmegaConf 230 | accelerator = Accelerator() 231 | 232 | # train 233 | cfg = OmegaConf.create() 234 | cfg.sample_rate = 16000 235 | cfg.speech = [ 236 | './manifest/speech_dnr.csv' 237 | ] 238 | cfg.music = [ 239 | './manifest/music_dnr.csv' 240 | ] 241 | cfg.sfx = [ 242 | './manifest/sfx_dnr.csv' 243 | ] 244 | cfg.general = [ 245 | './manifest/mix_dnr.csv' 246 | ] 247 | cfg.chunk_size = 2.0 248 | cfg.trim_silence = False 249 | train_dataset = DatasetAudioTrain(**cfg) 250 | 251 | print('Train data: {}'.format(len(train_dataset))) 252 | 253 | idx = np.random.randint(train_dataset.__len__()) 254 | data_ = train_dataset.__getitem__(idx) 255 | for k, v in data_.items(): 256 | print('audio idx: {} audio: {}, length: {}'.format(idx, k, v.shape)) 257 | 258 | # val 259 | cfg = OmegaConf.create() 260 | cfg.sample_rate = 16000 261 | cfg.tsv_filepath = './manifest/val.csv' 262 | cfg.chunk_size = 5.0 263 | val_dataset = DatasetAudioVal(**cfg) 264 | 265 | print('Validation data: {}'.format(len(val_dataset))) 266 | idx = np.random.randint(val_dataset.__len__()) 267 | data_ = val_dataset.__getitem__(idx) 268 | for k, v in data_.items(): 269 | print('audio idx: {} audio: {}, length: {}'.format(idx, k, v.shape)) 270 | 271 | # test 272 | cfg = OmegaConf.create() 273 | cfg.sample_rate = 16000 274 | cfg.tsv_filepath = './manifest/test.csv' 275 | cfg.chunk_size = 10.0 276 | test_dataset = DatasetAudioVal(**cfg) 277 | 278 | print('Test data: {}'.format(len(test_dataset))) 279 | idx = np.random.randint(test_dataset.__len__()) 280 | data_ = test_dataset.__getitem__(idx) 281 | for k, v in data_.items(): 282 | print('audio idx: {} audio: {}, length: {}'.format(idx, k, v.shape)) 283 | 284 | # dataloader 285 | train_dataset.length = 10000 286 | train_loader = DataLoader(dataset=train_dataset, 287 | batch_size=32, num_workers=8, 288 | shuffle=True, drop_last=True) 289 | 290 | total_seq = 0 291 | start_time = time.time() 292 | for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)): 293 | mix_audio = batch['speech'] 294 | total_seq += mix_audio.shape[0] 295 | # print(mix_audio.shape) 296 | # breakpoint() 297 | 298 | elapsed_time = time.time() - start_time 299 | tpf = elapsed_time / total_seq 300 | tpb = elapsed_time / (i+1) 301 | 302 | print(f"Read pure data time cost {tpf:.3f}s per file") 303 | print(f"Read pure data time cost {tpb:.3f}s per batch") 304 | 305 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn import L1Loss, MSELoss 3 | from .sdr import ( 4 | SingleSrcNegSDR, 5 | ) 6 | from .spectrum import ( 7 | MultiScaleSTFTLoss, 8 | MelSpectrogramLoss, 9 | ) 10 | 11 | from .adv import ( 12 | GANLoss, 13 | ) 14 | 15 | from .visqol import VisqolMetric -------------------------------------------------------------------------------- /src/metrics/adv.py: -------------------------------------------------------------------------------- 1 | 2 | 
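The metrics package above groups everything needed for the codec objective: waveform losses (`L1Loss`, `SingleSrcNegSDR`), spectral losses (`MultiScaleSTFTLoss`, `MelSpectrogramLoss`), the adversarial `GANLoss`, and ViSQOL for evaluation. A minimal sketch of combining the reconstruction terms, run from the repository root; the loss weights below are placeholders, not the values used in the training configs:

```python
# Illustrative only: a weighted reconstruction loss built from src.metrics.
# The 15.0 / 1.0 / 0.1 weights are placeholders, not the project's configured values.
import torch
from src.metrics import MelSpectrogramLoss, MultiScaleSTFTLoss, SingleSrcNegSDR

mel_loss = MelSpectrogramLoss(sr=16000)
stft_loss = MultiScaleSTFTLoss()
neg_sisdr = SingleSrcNegSDR("sisdr")        # returns the *negative* SI-SDR, so lower is better

ref = torch.randn(2, 1, 16000)              # (B, C, T) reference track
est = torch.randn(2, 1, 16000)              # (B, C, T) reconstruction
loss = 15.0 * mel_loss(est, ref) + 1.0 * stft_loss(est, ref) + 0.1 * neg_sisdr(est, ref)
print(loss.item())
```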
import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | class GANLoss(nn.Module): 7 | """ 8 | Computes a discriminator loss, given a discriminator on 9 | generated waveforms/spectrograms compared to ground truth 10 | waveforms/spectrograms. Computes the loss for both the 11 | discriminator and the generator in separate functions. 12 | 13 | Implementation modified from: h 14 | """ 15 | 16 | def __init__(self, discriminator): 17 | super().__init__() 18 | self.discriminator = discriminator 19 | 20 | def forward(self, fake, real): 21 | d_fake = self.discriminator(fake) 22 | d_real = self.discriminator(real) 23 | return d_fake, d_real 24 | 25 | def discriminator_loss(self, fake, real): 26 | d_fake, d_real = self.forward(fake.clone().detach(), real) 27 | 28 | loss_d = 0 29 | for x_fake, x_real in zip(d_fake, d_real): 30 | loss_d += torch.mean(x_fake[-1] ** 2) 31 | loss_d += torch.mean((1 - x_real[-1]) ** 2) 32 | return loss_d 33 | 34 | def generator_loss(self, fake, real): 35 | d_fake, d_real = self.forward(fake, real) 36 | 37 | loss_g = 0 38 | for x_fake in d_fake: 39 | loss_g += torch.mean((1 - x_fake[-1]) ** 2) 40 | 41 | loss_feature = 0 42 | 43 | for i in range(len(d_fake)): 44 | for j in range(len(d_fake[i]) - 1): 45 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) 46 | return loss_g, loss_feature -------------------------------------------------------------------------------- /src/metrics/sdr.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.nn.modules.loss import _Loss 8 | 9 | class SingleSrcNegSDR(_Loss): 10 | r"""Base class for single-source negative SI-SDR, SD-SDR and SNR. 11 | 12 | Args: 13 | sdr_type (str): choose between ``snr`` for plain SNR, ``sisdr`` for 14 | SI-SDR and ``sdsdr`` for SD-SDR [1]. 15 | zero_mean (bool, optional): by default it zero mean the ref and 16 | estimate before computing the loss. 17 | take_log (bool, optional): by default the log10 of sdr is returned. 18 | reduction (string, optional): Specifies the reduction to apply to 19 | the output: 20 | ``'none'``: no reduction will be applied, 21 | ``'sum'``: the sum of the output 22 | ``'mean'``: the sum of the output will be divided by the number of 23 | elements in the output. 24 | 25 | Shape: 26 | - ests : :math:`(batch, time)`. 27 | - refs: :math:`(batch, time)`. 28 | 29 | Returns: 30 | :class:`torch.Tensor`: with shape :math:`(batch)` if ``reduction='none'`` else 31 | [] scalar if ``reduction='mean'``. 32 | 33 | Examples 34 | >>> import torch 35 | >>> from asteroid.losses import PITLossWrapper 36 | >>> refs = torch.randn(10, 2, 32000) 37 | >>> ests = torch.randn(10, 2, 32000) 38 | >>> loss_func = PITLossWrapper(SingleSrcNegSDR("sisdr"), 39 | >>> pit_from='pw_pt') 40 | >>> loss = loss_func(ests, refs) 41 | 42 | References 43 | [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE 44 | International Conference on Acoustics, Speech and Signal 45 | Processing (ICASSP) 2019. 
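`GANLoss` simply wraps whatever discriminator it is given: `discriminator_loss` trains the discriminator with least-squares targets, while `generator_loss` returns both the adversarial term and an L1 feature-matching term over the intermediate feature maps. A minimal sketch pairing it with the project's `Discriminator`, run from the repository root; the waveforms are arbitrary stand-ins:

```python
# Illustrative only: least-squares GAN and feature-matching losses around the Discriminator.
import torch
from src.metrics import GANLoss
from src.models import Discriminator

disc = Discriminator(sample_rate=16000)              # MPD + MRD sub-discriminators by default
gan = GANLoss(disc)

real = torch.randn(2, 1, 16000)                      # ground-truth waveform
fake = torch.randn(2, 1, 16000)                      # stand-in for the codec reconstruction

loss_d = gan.discriminator_loss(fake, real)          # backprop this into the discriminator
loss_g, loss_feat = gan.generator_loss(fake, real)   # backprop these into the generator
print(loss_d.item(), loss_g.item(), loss_feat.item())
```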
46 | 47 | Implementation modified from Astroid project: https://github.com/asteroid-team/asteroid 48 | """ 49 | def __init__(self, 50 | sdr_type : str, 51 | zero_mean: bool = True, 52 | take_log: bool = True, 53 | reduction: str ="mean", 54 | EPS: float = 1e-8): 55 | super().__init__(reduction=reduction) 56 | 57 | assert sdr_type in ["snr", "sisdr", "sdsdr"] 58 | self.sdr_type = sdr_type 59 | self.zero_mean = zero_mean 60 | self.take_log = take_log 61 | self.EPS = 1e-8 62 | 63 | def forward(self, est, ref): 64 | 65 | assert est.shape == ref.shape, ( 66 | 'expected same shape, get est: {}, ref: {} instead'.format(est.shape, ref.shape) 67 | ) 68 | assert len(est.shape) == len(ref.shape) == 3, ( 69 | 'expected BxCxN, get est: {}, ref: {} instead'.format(est.shape, ref.shape) 70 | ) 71 | 72 | assert est.shape[1] == 1, ( 73 | 'expected mono channel, get {} channels'.format(est.shape[1]) 74 | ) 75 | 76 | B, C, T = est.shape 77 | est = est.reshape(-1, T) 78 | ref = ref.reshape(-1, T) 79 | 80 | # Step 1. Zero-mean norm 81 | if self.zero_mean: 82 | mean_source = torch.mean(ref, dim=1, keepdim=True) 83 | mean_estimate = torch.mean(est, dim=1, keepdim=True) 84 | ref = ref - mean_source 85 | est = est - mean_estimate 86 | # Step 2. Pair-wise SI-SDR. 87 | if self.sdr_type in ["sisdr", "sdsdr"]: 88 | # [batch, 1] 89 | dot = torch.sum(est * ref, dim=1, keepdim=True) 90 | # [batch, 1] 91 | s_ref_energy = torch.sum(ref**2, dim=1, keepdim=True) + self.EPS 92 | # [batch, time] 93 | scaled_ref = dot * ref / s_ref_energy 94 | else: 95 | # [batch, time] 96 | scaled_ref = ref 97 | if self.sdr_type in ["sdsdr", "snr"]: 98 | e_noise = est - ref 99 | else: 100 | e_noise = est - scaled_ref 101 | 102 | self.cache = torch.cat((ref, est, scaled_ref, e_noise), dim=0).detach().cpu() 103 | 104 | # [batch] 105 | losses = torch.sum(scaled_ref**2, dim=1) / (torch.sum(e_noise**2, dim=1) + self.EPS) 106 | 107 | if self.take_log: 108 | losses = 10 * torch.log10(losses + self.EPS) 109 | 110 | if self.reduction == "mean": 111 | losses = losses.mean() 112 | elif self.reduction == "sum": 113 | losses = losses.sum() 114 | else: 115 | losses = losses 116 | 117 | return -losses 118 | 119 | 120 | SingleSISDRLoss = SingleSrcNegSDR("sisdr") 121 | SingleSDSDRLoss = SingleSrcNegSDR("sdsdr") 122 | SingleSNRLoss = SingleSrcNegSDR("snr") 123 | -------------------------------------------------------------------------------- /src/metrics/spectrum.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import List, Callable 3 | from collections import namedtuple 4 | import numpy as np 5 | from librosa.filters import mel 6 | 7 | import torch 8 | from torch import nn 9 | from torch import Tensor 10 | 11 | STFTParams = namedtuple( 12 | "STFTParams", 13 | ["window_length", "hop_length", "window_type", "padding_type"], 14 | ) 15 | 16 | 17 | def get_window(window_type: str, window_length: int, device: str): 18 | if window_type == "average": 19 | window = torch.ones(window_length) / window_length 20 | elif window_type == "sqrt_hann": 21 | window = torch.hann_window(window_length).sqrt() 22 | else: 23 | win_fn = getattr(torch, f'{window_type}_window') 24 | window = win_fn(window_length) 25 | 26 | window = window.to(device) 27 | return window 28 | 29 | 30 | class MultiScaleSTFTLoss(nn.Module): 31 | """Computes the multi-scale STFT loss from [1]. 
32 | 33 | Parameters 34 | ---------- 35 | window_lengths : List[int], optional 36 | Length of each window of each STFT, by default [2048, 512] 37 | loss_fn : typing.Callable, optional 38 | How to compare each loss, by default nn.L1Loss() 39 | clamp_eps : float, optional 40 | Clamp on the log magnitude, below, by default 1e-5 41 | mag_weight : float, optional 42 | Weight of raw magnitude portion of loss, by default 1.0 43 | log_weight : float, optional 44 | Weight of log magnitude portion of loss, by default 1.0 45 | pow : float, optional 46 | Power to raise magnitude to before taking log, by default 2.0 47 | window_type : str, optional 48 | Type of window to use, by default ``hann``. 49 | match_stride : bool, optional 50 | Whether to match the stride of convolutional layers, by default False 51 | 52 | References 53 | ---------- 54 | 55 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. 56 | "DDSP: Differentiable Digital Signal Processing." 57 | International Conference on Learning Representations. 2019. 58 | 59 | Implementation modified from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 60 | """ 61 | 62 | def __init__( 63 | self, 64 | window_lengths: List[int] = [2048, 512], 65 | loss_fn: Callable = nn.L1Loss(), 66 | clamp_eps: float = 1e-5, 67 | mag_weight: float = 1.0, 68 | log_weight: float = 1.0, 69 | pow: float = 2.0, 70 | window_type: str = "hann", 71 | padding_type: str = "reflect", 72 | ): 73 | super().__init__() 74 | self.stft_params = [ 75 | STFTParams( 76 | window_length=w, 77 | hop_length=w // 4, 78 | window_type=window_type, 79 | padding_type=padding_type 80 | ) 81 | for w in window_lengths 82 | ] 83 | self.loss_fn = loss_fn 84 | self.clamp_eps = clamp_eps 85 | self.mag_weight = mag_weight 86 | self.log_weight = log_weight 87 | self.pow = pow 88 | 89 | 90 | def forward(self, est: Tensor, ref : Tensor): 91 | """Computes multi-scale STFT between an estimate and a reference 92 | signal. 93 | 94 | Parameters 95 | ---------- 96 | est : torch.Tensor [B, C, T] 97 | Estimate signal 98 | ref : torch.Tensor [B, C, T] 99 | Reference signal 100 | 101 | Returns 102 | ------- 103 | torch.Tensor 104 | Multi-scale STFT loss. 
105 | """ 106 | device = est.device 107 | if ref.device != est.device: 108 | est.to(device) 109 | 110 | assert est.shape == ref.shape, ( 111 | 'expected same shape, get est: {}, ref: {} instead'.format(est.shape, ref.shape) 112 | ) 113 | assert len(est.shape) == len(ref.shape) == 3, ( 114 | 'expected BxCxN, get est: {}, ref: {} instead'.format(est.shape, ref.shape) 115 | ) 116 | 117 | B, C, T = est.shape 118 | est = est.reshape(-1, T) 119 | ref = ref.reshape(-1, T) 120 | 121 | loss = 0.0 122 | for s in self.stft_params: 123 | est_spec = torch.stft(est, n_fft=s.window_length, hop_length=s.hop_length, win_length=s.window_length, 124 | window=get_window(s.window_type, s.window_length, device), 125 | pad_mode=s.padding_type, center=True, onesided=True, return_complex=True) 126 | ref_spec = torch.stft(ref, n_fft=s.window_length, hop_length=s.hop_length, win_length=s.window_length, 127 | window=get_window(s.window_type, s.window_length, device), 128 | pad_mode=s.padding_type, center=True, onesided=True, return_complex=True) 129 | 130 | loss += self.log_weight * self.loss_fn( 131 | est_spec.abs().clamp(self.clamp_eps).pow(self.pow).log10(), 132 | ref_spec.abs().clamp(self.clamp_eps).pow(self.pow).log10(), 133 | ) 134 | loss += self.mag_weight * self.loss_fn(est_spec.abs(), ref_spec.abs()) 135 | return loss 136 | 137 | 138 | class MelSpectrogramLoss(nn.Module): 139 | """Compute distance between mel spectrograms. Can be used 140 | in a multi-scale way. 141 | 142 | Parameters 143 | ---------- 144 | n_mels : List[int] 145 | Number of mels per STFT, by default [150, 80], 146 | window_lengths : List[int], optional 147 | Length of each window of each STFT, by default [2048, 512] 148 | loss_fn : typing.Callable, optional 149 | How to compare each loss, by default nn.L1Loss() 150 | clamp_eps : float, optional 151 | Clamp on the log magnitude, below, by default 1e-5 152 | mag_weight : float, optional 153 | Weight of raw magnitude portion of loss, by default 1.0 154 | log_weight : float, optional 155 | Weight of log magnitude portion of loss, by default 1.0 156 | pow : float, optional 157 | Power to raise magnitude to before taking log, by default 2.0 158 | window_type : str, optional 159 | Type of window to use, by default ``hann``. 
160 | match_stride : bool, optional 161 | Whether to match the stride of convolutional layers, by default False 162 | 163 | Implementation modified from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py 164 | """ 165 | 166 | def __init__( 167 | self, 168 | sr: int = 44100, 169 | n_mels: List[int] = [150, 80], 170 | mel_fmin: List[float] = [0.0, 0.0], 171 | mel_fmax: List[float] = [None, None], 172 | window_lengths: List[int] = [2048, 512], 173 | loss_fn: Callable = nn.L1Loss(), 174 | clamp_eps: float = 1e-5, 175 | mag_weight: float = 1.0, 176 | log_weight: float = 1.0, 177 | pow: float = 2.0, 178 | window_type: str = "hann", 179 | padding_type: str = "reflect", 180 | match_stride: bool = False, 181 | ): 182 | assert len(n_mels) == len(window_lengths) == len(mel_fmin) == len(mel_fmax), \ 183 | f'lengths are different, n_mels: {n_mels}, window_lengths: {window_lengths}, mel_fmin: {mel_fmin}, mel_fmax: {mel_fmax}' 184 | super().__init__() 185 | self.stft_params = [ 186 | STFTParams( 187 | window_length=w, 188 | hop_length=w // 4, 189 | window_type=window_type, 190 | padding_type=padding_type, 191 | ) 192 | for w in window_lengths 193 | ] 194 | self.sr = sr 195 | self.n_mels = n_mels 196 | self.loss_fn = loss_fn 197 | self.clamp_eps = clamp_eps 198 | self.log_weight = log_weight 199 | self.mag_weight = mag_weight 200 | self.mel_fmin = mel_fmin 201 | self.mel_fmax = mel_fmax 202 | self.pow = pow 203 | 204 | def forward(self, est: Tensor, ref : Tensor): 205 | """Computes multi-scale mel loss between an estimate and 206 | a reference signal. 207 | 208 | Parameters 209 | ---------- 210 | est : torch.Tensor [B, C, T] 211 | Estimate signal 212 | ref : torch.Tensor [B, C, T] 213 | Reference signal 214 | 215 | Returns 216 | ------- 217 | torch.Tensor 218 | Multi-scale Mel loss. 
219 | """ 220 | device = est.device 221 | if ref.device != est.device: 222 | est.to(device) 223 | 224 | assert est.shape == ref.shape, ( 225 | 'expected same shape, get est: {}, ref: {} instead'.format(est.shape, ref.shape) 226 | ) 227 | assert len(est.shape) == len(ref.shape) == 3, ( 228 | 'expected BxCxN, get est: {}, ref: {} instead'.format(est.shape, ref.shape) 229 | ) 230 | 231 | B, C, T = est.shape 232 | est = est.reshape(-1, T) 233 | ref = ref.reshape(-1, T) 234 | 235 | loss = 0.0 236 | for n_mels, fmin, fmax, s in zip( 237 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params 238 | ): 239 | est_spec = torch.stft(est, n_fft=s.window_length, hop_length=s.hop_length, win_length=s.window_length, 240 | window=get_window(s.window_type, s.window_length, device), 241 | pad_mode=s.padding_type, center=True, onesided=True, return_complex=True) 242 | ref_spec = torch.stft(ref, n_fft=s.window_length, hop_length=s.hop_length, win_length=s.window_length, 243 | window=get_window(s.window_type, s.window_length, device), 244 | pad_mode=s.padding_type, center=True, onesided=True, return_complex=True) 245 | 246 | # convert to mel 247 | est_mag = est_spec.abs() 248 | ref_mag = ref_spec.abs() 249 | 250 | mel_basis = mel(sr=self.sr, n_fft=s.window_length, n_mels=n_mels, fmin=fmin, fmax=fmax) 251 | mel_basis = torch.from_numpy(mel_basis).to(device) 252 | 253 | est_mel = (est_mag.transpose(-1, -2) @ mel_basis.T).transpose(-1, -2) 254 | ref_mel = (ref_mag.transpose(-1, -2) @ mel_basis.T).transpose(-1, -2) 255 | 256 | 257 | loss += self.log_weight * self.loss_fn( 258 | est_mel.clamp(self.clamp_eps).pow(self.pow).log10(), 259 | ref_mel.clamp(self.clamp_eps).pow(self.pow).log10(), 260 | ) 261 | loss += self.mag_weight * self.loss_fn(est_mel, ref_mel) 262 | 263 | return loss -------------------------------------------------------------------------------- /src/metrics/visqol.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import soxr 4 | import torch 5 | import numpy as np 6 | from visqol import visqol_lib_py 7 | from visqol.pb2 import visqol_config_pb2 8 | from visqol.pb2 import similarity_result_pb2 9 | 10 | class VisqolMetric(): 11 | 12 | def __init__(self, 13 | mode='audio', 14 | reduction='mean'): 15 | 16 | self.reduction = reduction 17 | config = visqol_config_pb2.VisqolConfig() 18 | if mode == "audio": 19 | self.fs = 48000 20 | config.audio.sample_rate = self.fs 21 | config.options.use_speech_scoring = False 22 | svr_model_path = "libsvm_nu_svr_model.txt" 23 | elif mode == "speech": 24 | self.fs = 16000 25 | config.audio.sample_rate = self.fs 26 | config.options.use_speech_scoring = True 27 | svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite" 28 | else: 29 | raise ValueError(f"Unrecognized mode: {mode}") 30 | 31 | config.options.svr_model_path = os.path.join( 32 | os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path) 33 | 34 | self.api = visqol_lib_py.VisqolApi() 35 | self.api.Create(config) 36 | 37 | def __call__(self, est, ref, sr=44100): 38 | assert est.shape == ref.shape, 'expected same shape, get est: {}, ref: {} instead'.format(est.shape, ref.shape) 39 | assert len(est.shape) == len(ref.shape) == 3, 'expected BxCxN, get est: {}, ref: {} instead'.format(est.shape, ref.shape) 40 | B, C, T = est.shape 41 | est = est.reshape(B*C, -1).detach().cpu().numpy().astype(np.float64) 42 | ref = ref.reshape(B*C, -1).detach().cpu().numpy().astype(np.float64) 43 | 44 | if sr != 
self.fs: 45 | est_list = [] 46 | ref_list = [] 47 | for i in range(est.shape[0]): 48 | est_list.append(soxr.resample(est[i], sr, self.fs)) 49 | ref_list.append(soxr.resample(ref[i], sr, self.fs)) 50 | est = np.array(est_list) 51 | ref = np.array(ref_list) 52 | 53 | ret = [] 54 | for i in range(est.shape[0]): 55 | ret.append(self.api.Measure(ref[i], est[i]).moslqo) 56 | 57 | if self.reduction == "mean": 58 | ret = np.mean(ret) 59 | elif self.reduction == "sum": 60 | ret = np.sum(ret) 61 | else: 62 | ret = ret 63 | return ret 64 | 65 | 66 | if __name__ == '__main__': 67 | eval_visqol = VisqolMetric(mode='audio') 68 | 69 | import soundfile as sf 70 | audio_name = '001_fragment_10.5' 71 | ref_audio_path = f'/home/xbie/Data/starnet/starnet_frag_16kHz/{audio_name}.wav' 72 | target_audio_path = ref_audio_path 73 | reference, fs_ref = sf.read(ref_audio_path) 74 | degraded, fs_target = sf.read(target_audio_path) 75 | print(f'Audio file ref: {fs_ref} Hz, target: {fs_target} Hz') 76 | 77 | est = torch.from_numpy(degraded).reshape(1, 1, -1) 78 | ref = torch.from_numpy(reference).reshape(1, 1, -1) 79 | metric_visqol = VisqolMetric(mode='audio') 80 | visqol = metric_visqol(est=est, ref=ref, sr=fs_ref) 81 | print(f'Visqol: {visqol:.2f}') 82 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .sdcodec import SDCodec 2 | from .discriminator import Discriminator -------------------------------------------------------------------------------- /src/models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | import julius # a little bit worse than soxr, but compatible with torch and cuda 7 | from einops import rearrange 8 | from collections import namedtuple 9 | 10 | from ..modules import WNConv1d, WNConv2d 11 | 12 | def get_window(window_type: str, window_length: int, device: str): 13 | if window_type == "average": 14 | window = torch.ones(window_length) / window_length 15 | elif window_type == "sqrt_hann": 16 | window = torch.hann_window(window_length).sqrt() 17 | else: 18 | win_fn = getattr(torch, f'{window_type}_window') 19 | window = win_fn(window_length) 20 | 21 | window = window.to(device) 22 | return window 23 | 24 | 25 | class MSD(nn.Module): 26 | def __init__(self, rate: int = 1, sample_rate: int = 44100): 27 | super().__init__() 28 | self.convs = nn.ModuleList( 29 | [ 30 | WNConv1d(1, 16, 15, 1, padding=7, act=True), 31 | WNConv1d(16, 64, 41, 4, groups=4, padding=20, act=True), 32 | WNConv1d(64, 256, 41, 4, groups=16, padding=20, act=True), 33 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20, act=True), 34 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20, act=True), 35 | WNConv1d(1024, 1024, 5, 1, padding=2, act=True), 36 | ] 37 | ) 38 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) 39 | 40 | # julius resample 41 | # https://adefossez.github.io/julius/julius/resample.html 42 | old_sr = sample_rate 43 | new_sr = sample_rate // rate 44 | gcd = math.gcd(old_sr, new_sr) 45 | old_sr = old_sr // gcd 46 | new_sr = new_sr // gcd 47 | self.resample = julius.ResampleFrac(old_sr=old_sr, new_sr=new_sr) 48 | 49 | def forward(self, x): 50 | x = self.resample(x) 51 | fmap = [] 52 | 53 | for l in self.convs: 54 | x = l(x) 55 | fmap.append(x) 56 | x = self.conv_post(x) 57 | fmap.append(x) 58 | 59 | 
return fmap 60 | 61 | 62 | class MPD(nn.Module): 63 | def __init__(self, period): 64 | super().__init__() 65 | self.period = period 66 | self.convs = nn.ModuleList( 67 | [ 68 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0), act=True), 69 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0), act=True), 70 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0), act=True), 71 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0), act=True), 72 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0), act=True), 73 | ] 74 | ) 75 | self.conv_post = WNConv2d( 76 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False 77 | ) 78 | 79 | def pad_to_period(self, x): 80 | t = x.shape[-1] 81 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect") 82 | return x 83 | 84 | def forward(self, x): 85 | fmap = [] 86 | 87 | x = self.pad_to_period(x) 88 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period) 89 | 90 | for layer in self.convs: 91 | x = layer(x) 92 | fmap.append(x) 93 | 94 | x = self.conv_post(x) 95 | fmap.append(x) 96 | 97 | return fmap 98 | 99 | 100 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] 101 | STFTParams = namedtuple( 102 | "STFTParams", 103 | ["window_length", "hop_length"], 104 | ) 105 | 106 | STFTParams = namedtuple( 107 | "STFTParams", 108 | ["window_length", "hop_length", "window_type", "padding_type", "match_stride"], 109 | ) 110 | 111 | class MRD(nn.Module): 112 | def __init__( 113 | self, 114 | window_length: int, 115 | hop_factor: float = 0.25, 116 | sample_rate: int = 44100, 117 | bands: list = BANDS, 118 | ): 119 | """Complex multi-band spectrogram discriminator. 120 | Parameters 121 | ---------- 122 | window_length : int 123 | Window length of STFT. 124 | hop_factor : float, optional 125 | Hop factor of the STFT, defaults to ``0.25 * window_length``. 126 | sample_rate : int, optional 127 | Sampling rate of audio in Hz, by default 44100 128 | bands : list, optional 129 | Bands to run discriminator over. 
130 | """ 131 | super().__init__() 132 | 133 | self.window_length = window_length 134 | self.hop_factor = hop_factor 135 | self.sample_rate = sample_rate 136 | self.stft_params = STFTParams( 137 | window_length=window_length, 138 | hop_length=int(window_length * hop_factor), 139 | window_type='hann', 140 | padding_type='reflect', 141 | match_stride=True, 142 | ) 143 | 144 | n_fft = window_length // 2 + 1 145 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 146 | self.bands = bands 147 | 148 | ch = 32 149 | convs = lambda: nn.ModuleList( 150 | [ 151 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4), act=True), 152 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4), act=True), 153 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4), act=True), 154 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4), act=True), 155 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1), act=True), 156 | ] 157 | ) 158 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 159 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) 160 | 161 | def spectrogram(self, x): 162 | B, C, T = x.shape 163 | x = torch.stft(x.reshape(-1, T), 164 | n_fft=self.stft_params.window_length, 165 | hop_length=self.stft_params.hop_length, 166 | win_length=self.stft_params.window_length, 167 | window=torch.hann_window(self.stft_params.window_length).to(x.device), 168 | pad_mode='reflect', center=True, onesided=True, return_complex=True) 169 | x = torch.view_as_real(x) 170 | x = rearrange(x, "n f t c -> n c t f") 171 | # Split into bands 172 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 173 | return x_bands 174 | 175 | def forward(self, x): 176 | x_bands = self.spectrogram(x) 177 | fmap = [] 178 | 179 | x = [] 180 | for band, stack in zip(x_bands, self.band_convs): 181 | for layer in stack: 182 | band = layer(band) 183 | fmap.append(band) 184 | x.append(band) 185 | 186 | x = torch.cat(x, dim=-1) 187 | x = self.conv_post(x) 188 | fmap.append(x) 189 | 190 | return fmap 191 | 192 | 193 | class Discriminator(nn.Module): 194 | def __init__( 195 | self, 196 | sample_rate: int, 197 | rates: list = [], 198 | periods: list = [2, 3, 5, 7, 11], 199 | fft_sizes: list = [2048, 1024, 512], 200 | bands: list = BANDS, 201 | ): 202 | """Discriminator that combines multiple discriminators. 203 | 204 | Parameters 205 | ---------- 206 | sample_rate : int, needed 207 | Sampling rate of audio in Hz 208 | rates : list, optional 209 | MSD, sampling rates (in Hz), by default [] 210 | If empty, MSD is not used. 
211 | periods : list, optional 212 | MPD, periods of samples, by default [2, 3, 5, 7, 11] 213 | fft_sizes : list, optional 214 | MRD, window sizes of the FFT, by default [2048, 1024, 512] 215 | bands : list, optional 216 | MRD, bands, by default `BANDS` 217 | """ 218 | super().__init__() 219 | discs = [] 220 | discs += [MSD(r, sample_rate=sample_rate) for r in rates] 221 | discs += [MPD(p) for p in periods] 222 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] 223 | self.discriminators = nn.ModuleList(discs) 224 | 225 | def preprocess(self, y): 226 | # Remove DC offset 227 | y = y - y.mean(dim=-1, keepdims=True) 228 | # Peak normalize the volume of input audio 229 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 230 | return y 231 | 232 | def forward(self, x): 233 | x = self.preprocess(x) 234 | fmaps = [d(x) for d in self.discriminators] 235 | return fmaps 236 | 237 | 238 | if __name__ == "__main__": 239 | disc = Discriminator(sample_rate=44100, 240 | rates=[], 241 | periods=[2, 3, 5, 7, 11], 242 | fft_sizes=[2048, 1024, 512], 243 | bands=[[0.0, 0.1], [0.1, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1.0]]).to('cuda') 244 | x = torch.zeros(1, 1, 44100).to('cuda') 245 | 246 | total_params = sum(p.numel() for p in disc.parameters()) / 1e6 247 | print(f'Total params: {total_params:.2f} Mb') 248 | 249 | results = disc(x) 250 | print(len(results)) 251 | 252 | for i, result in enumerate(results): 253 | print(f"disc{i}") 254 | for i, r in enumerate(result): 255 | print(r.shape, r.mean(), r.min(), r.max()) 256 | print() 257 | -------------------------------------------------------------------------------- /src/models/sdcodec.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResDQ DAC-based Residual VQ-VAE model. 3 | """ 4 | import math 5 | import numpy as np 6 | from typing import List, Union, Tuple 7 | from typing import Optional 8 | import random 9 | from omegaconf import DictConfig 10 | 11 | import torch 12 | from torch import nn 13 | import torch.nn.functional as F 14 | from accelerate.logging import get_logger 15 | 16 | from ..modules import CodecMixin 17 | from .. import modules 18 | 19 | logger = get_logger(__name__) 20 | 21 | def init_weights(m): 22 | if isinstance(m, nn.Conv1d): 23 | nn.init.trunc_normal_(m.weight, std=0.02) # default kaiming_uniform_ 24 | nn.init.constant_(m.bias, 0) # default uniform_ 25 | 26 | 27 | class SDCodec(nn.Module, CodecMixin): 28 | """Source-aware Disentangled Neural Audio Codec. 
29 | Args: 30 | 31 | """ 32 | def __init__( 33 | self, 34 | sample_rate: int, 35 | latent_dim: int = None, 36 | tracks: List[str] = ['speech', 'music', 'sfx'], 37 | enc_params: DictConfig = {'name': 'DACEncoder'}, 38 | dec_params: DictConfig = {'name': 'DACDecoder'}, 39 | quant_params: DictConfig = {'name': 'DACDecoder'}, 40 | pretrain: dict = {}, 41 | ): 42 | super().__init__() 43 | 44 | self.sample_rate = sample_rate 45 | self.tracks = tracks 46 | self.enc_params = enc_params 47 | self.dec_params = dec_params 48 | self.quant_params = quant_params 49 | self.pretrain = pretrain 50 | 51 | 52 | if latent_dim is None: 53 | latent_dim = self.enc_params.d_model * (2 ** len(self.enc_params.strides)) 54 | self.latent_dim = latent_dim 55 | self.hop_length = np.prod(self.enc_params.strides) 56 | 57 | 58 | # Define encoder and decoder 59 | enc_net = getattr(modules, self.enc_params.pop('name')) 60 | dec_net = getattr(modules, self.dec_params.pop('name')) 61 | self.encoder = enc_net(**self.enc_params) 62 | self.decoder = dec_net(**self.dec_params) 63 | 64 | # Define quantizer 65 | quant_net = getattr(modules, self.quant_params.pop('name')) 66 | self.quantizer = quant_net(tracks=self.tracks, **self.quant_params) 67 | 68 | # Init 69 | self.apply(init_weights) 70 | self.delay = self.get_delay() 71 | 72 | # Load pretrained 73 | load_pretrained = self.pretrain.get('load_pretrained', False) 74 | if load_pretrained: 75 | pretrained_dict = torch.load(load_pretrained) 76 | hyparam = pretrained_dict['metadata']['kwargs'] 77 | ignore_modules = self.pretrain.get('ignore_modules', []) 78 | freeze_modues = self.pretrain.get('freeze_modules', []) 79 | 80 | is_match = self._check_hyparam(hyparam) 81 | if is_match: 82 | self._load_pretrained(pretrained_dict['state_dict'], ignore_modules) 83 | self._freeze(freeze_modues) 84 | logger.info('Pretrain models load success from {}'.format(load_pretrained)) 85 | logger.info('-> modules ignored: {}'.format(ignore_modules)) 86 | logger.info('-> modules freezed: {}'.format(freeze_modues)) 87 | else: 88 | logger.info(f'Pretrain param do not match model, load pretrained failed...') 89 | logger.info('Pretrain params:') 90 | logger.info(hyparam) 91 | logger.info('Model params:') 92 | logger.info('Encoder: {}'.format(self.enc_params)) 93 | logger.info('Decoder: {}'.format(self.dec_params)) 94 | 95 | @property 96 | def device(self): 97 | """Gets the device the model is on by looking at the device of 98 | the first parameter. May not be valid if model is split across 99 | multiple devices. 
100 | """ 101 | return list(self.parameters())[0].device 102 | 103 | 104 | def _check_hyparam(self, hyparam): 105 | return (hyparam['encoder_dim'] == self.enc_params['d_model']) \ 106 | and (hyparam['encoder_rates'] == self.enc_params['strides']) \ 107 | and (hyparam['decoder_dim'] == self.dec_params['d_model']) \ 108 | and (hyparam['sample_rate'] == self.sample_rate) 109 | 110 | 111 | def _load_pretrained(self, pretrained_state, ignored_modules=[]): 112 | own_state = self.state_dict() 113 | pretrained_state = {k: v for k, v in pretrained_state.items() if k.split('.')[0] not in ignored_modules} 114 | for k in own_state.keys(): 115 | if k in pretrained_state: 116 | own_state[k] = pretrained_state[k] 117 | elif 'quantizer' not in ignored_modules: 118 | own_key_list = k.split('.') 119 | if own_key_list[0] == 'quantizer': 120 | if own_key_list[1] == 'jitter_dict': 121 | continue 122 | elif own_key_list[1] == 'shared_rvq': 123 | own_key_list.pop(1) # shared_rvq 124 | else: 125 | own_key_list.pop(1) # rvq_dict 126 | own_key_list.pop(1) # (speech, music, sfx) 127 | pretrained_key = '.'.join(own_key_list) 128 | if pretrained_key not in pretrained_state: 129 | print(k) 130 | print(pretrained_key) 131 | breakpoint() 132 | own_state[k] = pretrained_state[pretrained_key] 133 | 134 | 135 | def _freeze(self, freeze_modues=[]): 136 | for module in freeze_modues: 137 | child = getattr(self, module) 138 | for param in child.parameters(): 139 | param.requires_grad = False 140 | 141 | 142 | def preprocess(self, audio_data, sample_rate) -> torch.Tensor: 143 | if sample_rate is None: 144 | sample_rate = self.sample_rate 145 | assert sample_rate == self.sample_rate 146 | 147 | length = audio_data.shape[-1] 148 | right_pad = math.ceil(length / self.hop_length) * self.hop_length - length 149 | audio_data = F.pad(audio_data, (0, right_pad)) 150 | 151 | return audio_data 152 | 153 | 154 | def encode(self, audio_data: torch.Tensor) -> torch.Tensor: 155 | """Encode given audio data and return quantized latent codes 156 | 157 | Parameters 158 | ---------- 159 | audio_data : Tensor[B x 1 x T] 160 | Audio data to encode 161 | 162 | Returns 163 | ------- 164 | "feats" : Tensor[B x D x T] 165 | Continuous features before quantization 166 | """ 167 | return self.encoder(audio_data) 168 | 169 | 170 | def decode(self, z: torch.Tensor) -> torch.Tensor: 171 | """Decode given latent codes and return audio data 172 | 173 | Parameters 174 | ---------- 175 | z : Tensor[B x D x T] 176 | Quantized continuous representation of input 177 | 178 | Returns 179 | ------- 180 | "audio" : Tensor[B x 1 x length] 181 | Decoded audio data. 182 | """ 183 | return self.decoder(z) 184 | 185 | 186 | def quantize( 187 | self, 188 | feats: torch.Tensor, 189 | track: str = 'speech', 190 | n_quantizers: int = None, 191 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 192 | """Encode given audio data and return quantized latent codes 193 | 194 | Parameters 195 | ---------- 196 | feats : Tensor[B x D x T] 197 | Continuous features before quantization 198 | track: str 199 | Specify which quantizer to be used 200 | n_quantizers : int, optional 201 | Number of quantizers to use, by default None 202 | If None, all quantizers are used. 
203 | 204 | Returns 205 | ------- 206 | "z" : Tensor[B x D x T] 207 | Quantized continuous representation of input 208 | "codes" : Tensor[B x N x T] 209 | Codebook indices for each codebook 210 | (quantized discrete representation of input) 211 | "latents" : Tensor[B x N*D' x T] 212 | Projected latents (continuous representation of input before quantization) 213 | "vq/commitment_loss" : Tensor[1] 214 | Commitment loss to train encoder to predict vectors closer to codebook 215 | entries 216 | "vq/codebook_loss" : Tensor[1] 217 | Codebook loss to update the codebook 218 | """ 219 | assert track in self.tracks, 'f{track} not included in quantizer' 220 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer( 221 | track, feats, n_quantizers 222 | ) 223 | return z, codes, latents, commitment_loss, codebook_loss 224 | 225 | 226 | def forward( 227 | self, 228 | batch: Optional[dict], 229 | sample_rate: int = None, 230 | n_quantizers: int = None, 231 | ): 232 | """Model forward pass 233 | 234 | Parameters 235 | ---------- 236 | batch : dict of Tensor[B x 1 x T] 237 | Batch input with audio data 238 | sample_rate : int, optional 239 | Sample rate of audio data in Hz, by default None 240 | If None, defaults to `self.sample_rate` 241 | n_quantizers : int, optional 242 | Number of quantizers to use, by default None. 243 | If None, all quantizers are used. 244 | Returns 245 | ------- 246 | dict 247 | A dictionary with the following keys: 248 | "track/z" : Tensor[B x D x T] 249 | Quantized continuous representation of input 250 | "track/codes" : Tensor[B x N x T] 251 | Codebook indices for each codebook 252 | (quantized discrete representation of input) 253 | "track/latents" : Tensor[B x N*D x T] 254 | Projected latents (continuous representation of input before quantization) 255 | "vq/commitment_loss" : Tensor[1] 256 | Commitment loss to train encoder to predict vectors closer to codebook 257 | entries 258 | "vq/codebook_loss" : Tensor[1] 259 | Codebook loss to update the codebook 260 | "length" : int 261 | Number of samples in input audio 262 | "ref" : Tensor[B x (K+1) x 1 x length] 263 | Decoded audio data. 264 | "recon" : Tensor[B x (K+1) x 1 x length] 265 | Decoded audio data. 
266 | """ 267 | 268 | # mix by masking 269 | audio_data = batch['mix'] 270 | bs, _, length = audio_data.shape 271 | valid_tracks = batch['valid_tracks'] 272 | mask = [1 if t in valid_tracks else 0 for t in self.tracks] 273 | mask = torch.tensor(mask, device=audio_data.device) 274 | 275 | # preprocess, zero-padding to proper length 276 | audio_data = self.preprocess(audio_data, sample_rate) 277 | 278 | # encoder 279 | feats = self.encode(audio_data) 280 | 281 | # quantize 282 | dict_z = {} 283 | dict_commit = {} 284 | dict_cb = {} 285 | # for i, track in enumerate(self.tracks): 286 | for i, track in enumerate(self.tracks): 287 | # quantize 288 | z, codes, latents, commitment_loss, codebook_loss = self.quantize( 289 | feats, track, n_quantizers 290 | ) 291 | # ppl 292 | probs = F.one_hot(codes.detach(), num_classes=self.quant_params.codebook_size[i]).float() # (B, N, T, num_class) 293 | avg_probs = probs.mean(dim=(0,2)) # (N, num_class) 294 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)) # (N,) 295 | # save retsults 296 | batch[f'{track}/z'] = z 297 | batch[f'{track}/codes'] = codes 298 | batch[f'{track}/latents'] = latents 299 | batch[f'{track}/ppl'] = perplexity 300 | dict_z[track] = z 301 | dict_commit[track] = commitment_loss 302 | dict_cb[track] = codebook_loss 303 | 304 | # commit and codebook loss 305 | commit_loss = (torch.stack([dict_commit[t] for t in self.tracks]) * mask).sum() / len(valid_tracks) 306 | cb_loss = (torch.stack([dict_cb[t] for t in self.tracks]) * mask).sum() / len(valid_tracks) 307 | 308 | # re-mix and decode 309 | if batch['random_swap']: 310 | z_mix = torch.stack([dict_z[t][batch['shuffle_list'][t]] for t in self.tracks]) 311 | else: 312 | z_mix = torch.stack([dict_z[t] for t in self.tracks]) 313 | z_mix = (z_mix * mask[:, None, None, None]).sum(dim=0) 314 | x_remix = self.decode(z_mix) # (B, 1, T) 315 | 316 | # collect data 317 | sep_out = [self.decode(dict_z[t]) for t in valid_tracks] 318 | audio_recon = torch.stack([x_remix] + sep_out, dim=1) # (B, K, 1, T) 319 | 320 | batch['recon'] = audio_recon[..., :length] 321 | batch['length'] = length 322 | batch['vq/commitment_loss'] = commit_loss 323 | batch['vq/codebook_loss'] = cb_loss 324 | 325 | return batch 326 | 327 | 328 | def evaluate( 329 | self, 330 | input_audio: torch.Tensor, 331 | sample_rate: int = None, 332 | n_quantizers: int = None, 333 | output_tracks: list[str] = ['mix'], 334 | ) -> torch.Tensor: 335 | """Model evaluation 336 | Parameters 337 | ---------- 338 | input_audio : Tensor[B x 1 x T] 339 | Audio data to encode 340 | sample_rate : int, optional 341 | Sample rate of audio data in Hz, by default None 342 | If None, defaults to `self.sample_rate` 343 | n_quantizers : int, optional 344 | Number of quantizers to use, by default None. 345 | If None, all quantizers are used. 
346 | output_tracks : List[str] 347 | List of track to return 348 | 349 | Returns 350 | ------- 351 | output_audio : Tensor[B x K x T] 352 | Output audio with K tracks 353 | """ 354 | assert all((t in self.tracks) or (t=='mix') for t in output_tracks); \ 355 | "output tracks {} not included in model tracks {}".format(output_tracks, self.tracks) 356 | 357 | bs, _, length = input_audio.shape 358 | audio_data = self.preprocess(input_audio, sample_rate) # (B, 1, T) 359 | 360 | # encoder 361 | feats = self.encode(audio_data) 362 | 363 | # quantization 364 | latent_dict = {} 365 | for track in self.tracks: 366 | z, codes, latents, commitment_loss, codebook_loss = self.quantize( 367 | feats, track, n_quantizers 368 | ) 369 | latent_dict[track] = z 370 | 371 | # decoder 372 | list_out = [] 373 | for track in output_tracks: 374 | if track == 'mix': 375 | z_mix = torch.stack(list(latent_dict.values()), dim=0).sum(dim=0) 376 | x_out = self.decode(z_mix) 377 | list_out.append(x_out) 378 | else: 379 | x_out = self.decode(latent_dict[track]) 380 | list_out.append(x_out) 381 | 382 | output_audio = torch.cat(list_out, dim=1)[...,:length] 383 | 384 | return output_audio 385 | 386 | 387 | 388 | if __name__ == "__main__": 389 | from accelerate import Accelerator 390 | from omegaconf import OmegaConf 391 | import numpy as np 392 | from functools import partial 393 | 394 | accelerator = Accelerator() 395 | logger = get_logger(__name__) 396 | 397 | cfg = OmegaConf.load('/home/xbie/Code/ResDQ/config/model/sdcodec_16k.yaml') 398 | sample_rate = 16000 399 | tracks = ['speech', 'music', 'sfx'] 400 | mask = [True, False, False] 401 | show_params = False 402 | 403 | model_name = cfg.pop('name') 404 | model = SDCodec(sample_rate=sample_rate, **cfg).to("cuda") 405 | 406 | if show_params: 407 | for n, m in model.named_modules(): 408 | o = m.extra_repr() 409 | p = sum(p.numel() for p in m.parameters()) 410 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 
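`evaluate` is the inference entry point: it encodes the mixture once, quantizes the shared features with each track-specific quantizer, and decodes whichever tracks are requested (plus the re-synthesized mix). A minimal sketch with the same config assumptions as above and a placeholder input path; without loading trained weights the outputs are of course meaningless:

```python
# Illustrative only: track-wise resynthesis with SDCodec.evaluate().
import torch
import torchaudio
from omegaconf import OmegaConf
from src.models import SDCodec

cfg = OmegaConf.load("config/model/sdcodec_16k.yaml")
cfg.pop("name")
model = SDCodec(sample_rate=16000, **cfg).eval()   # load trained weights here in practice

x, sr = torchaudio.load("mixture.wav")             # placeholder path, expected at 16 kHz
x = x.mean(dim=0, keepdim=True).unsqueeze(0)       # -> (B=1, C=1, T)
with torch.no_grad():
    out = model.evaluate(x, output_tracks=["mix", "speech", "music", "sfx"])
print(out.shape)                                   # (1, 4, T)
```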
411 | setattr(m, "extra_repr", partial(fn, o=o, p=p)) 412 | print(model) 413 | else: 414 | print("Total # of params: {:.2f} M".format(sum(p.numel() for p in model.parameters())/1e6)) 415 | 416 | print('Show model device') 417 | print(model.encoder.block[0].weight.device) 418 | print(model.quantizer_dict['speech'].quantizers[0].in_proj.weight.device) 419 | 420 | 421 | # test on random input 422 | bs = 2 423 | length = 1 424 | batch = {} 425 | length = int(sample_rate * length) 426 | for t in tracks: 427 | batch[t] = torch.randn(bs, 1, length).to(model.device) 428 | batch[t].requires_grad_(True) 429 | batch[t].retain_grad() 430 | 431 | # Make a forward pass 432 | batch = model(batch, 433 | track_mask=mask, 434 | random_swap=True) 435 | print("Input shape:", batch['ref'].shape) 436 | print("Output shape:", batch['recon'].shape) 437 | 438 | # Create gradient variable 439 | B, N,_, T = batch['recon'].shape 440 | model_out = batch['recon'].reshape(-1, 1, T) 441 | grad = torch.zeros_like(model_out) 442 | grad[:, :, grad.shape[-1] // 2] = 1 443 | 444 | # Make a backward pass 445 | model_out.backward(grad) 446 | 447 | # Check non-zero values 448 | gradmap = batch['speech'].grad.squeeze(0) 449 | gradmap = (gradmap != 0).sum(0) # sum across features 450 | rf = (gradmap != 0).sum() 451 | 452 | print(f"Receptive field: {rf.item()}") 453 | 454 | 455 | ## test on real data 456 | import pandas as pd 457 | import torchaudio 458 | 459 | idx = 0 460 | tsv_filepath = '/home/xbie/Data/dnr_v2/val_mini_2s.tsv' 461 | 462 | metadata = pd.read_csv(tsv_filepath) 463 | wav_info = metadata.iloc[idx] 464 | 465 | x_speech, sr = torchaudio.load(wav_info['speech']) 466 | x_music, sr = torchaudio.load(wav_info['music']) 467 | x_sfx, sr = torchaudio.load(wav_info['sfx']) 468 | x_mix, sr = torchaudio.load(wav_info['mix']) 469 | 470 | x_speech = x_speech[..., wav_info['trim30dBs']: wav_info['trim30dBe']].reshape(1, 1, -1).to(model.device) 471 | x_music = x_music[..., wav_info['trim30dBs']: wav_info['trim30dBe']].reshape(1, 1, -1).to(model.device) 472 | x_sfx = x_sfx[..., wav_info['trim30dBs']: wav_info['trim30dBe']].reshape(1, 1, -1).to(model.device) 473 | 474 | batch = {'speech': x_speech, 475 | 'music': x_music, 476 | 'sfx': x_sfx} 477 | 478 | batch = model(batch, 479 | track_mask=mask, 480 | random_swap=True) 481 | 482 | print('Reference audio shape: {}'.format(batch['ref'].shape)) 483 | print('Recon audio shape: {}'.format(batch['recon'].shape)) 484 | 485 | # batch['ref'] = batch['ref'].detach().cpu() 486 | # batch['recon'] = batch['recon'].detach().cpu() 487 | 488 | # # for i, c in enumerate(['speech', 'music', 'sfx', 'mix']): 489 | # # path = '{}_{}_refRaw.wav'.format(wav_info['id'], c) 490 | # # torchaudio.save(path, batch['ref'][0, i], 44100) 491 | 492 | # # for i, c in enumerate(['speech', 'music', 'sfx', 'mix']): 493 | # # path = '{}_{}_reconRaw.wav'.format(wav_info['id'], c) 494 | # # torchaudio.save(path, batch['recon'][0, i], 44100) 495 | # # breakpoint() 496 | # breakpoint() -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | """Modules used for building the models.""" 4 | 5 | 6 | from .layers import ( 7 | WNConv1d, 8 | WNConv2d, 9 | WNConvTranspose1d, 10 | Snake1d, 11 | SLSTM, 12 | Jitter, 13 | ) 14 | 15 | from .base_dac import ( 16 | DACEncoder, 17 | DACDecoder, 18 | DACEncoderTrans, 19 | DACDecoderTrans, 20 | CodecMixin, 21 | ) 22 | 23 | 24 | from .quantize import 
( 25 | VectorQuantize, 26 | ResidualVectorQuantize, 27 | MultiSourceRVQ, 28 | ) -------------------------------------------------------------------------------- /src/modules/base_dac.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic NN module from DAC 3 | """ 4 | import math 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from .layers import Snake1d, WNConv1d, WNConvTranspose1d, TransformerSentenceEncoderLayer 10 | 11 | 12 | class ResidualUnit(nn.Module): 13 | def __init__(self, dim: int = 16, dilation: int = 1): 14 | super().__init__() 15 | k = 7 # kernal size for the first conv 16 | pad = ((k - 1) * dilation) // 2 # 2*p - d*(k-1) 17 | self.block = nn.Sequential( 18 | Snake1d(dim), 19 | WNConv1d(dim, dim, kernel_size=k, dilation=dilation, padding=pad), 20 | Snake1d(dim), 21 | WNConv1d(dim, dim, kernel_size=1), 22 | ) 23 | 24 | def forward(self, x): 25 | y = self.block(x) 26 | pad = (x.shape[-1] - y.shape[-1]) // 2 # identical in-out channel 27 | if pad > 0: 28 | x = x[..., pad:-pad] 29 | return x + y 30 | 31 | 32 | class EncoderBlock(nn.Module): 33 | def __init__(self, dim: int = 16, stride: int = 1): 34 | super().__init__() 35 | self.block = nn.Sequential( 36 | ResidualUnit(dim // 2, dilation=1), 37 | ResidualUnit(dim // 2, dilation=3), 38 | ResidualUnit(dim // 2, dilation=9), 39 | Snake1d(dim // 2), 40 | WNConv1d( 41 | dim // 2, 42 | dim, 43 | kernel_size=2 * stride, 44 | stride=stride, 45 | padding=math.ceil(stride / 2), 46 | ), 47 | ) 48 | 49 | def forward(self, x): 50 | return self.block(x) 51 | 52 | 53 | class DecoderBlock(nn.Module): 54 | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): 55 | super().__init__() 56 | self.block = nn.Sequential( 57 | Snake1d(input_dim), 58 | WNConvTranspose1d( 59 | input_dim, 60 | output_dim, 61 | kernel_size=2 * stride, 62 | stride=stride, 63 | padding=math.floor(stride / 2), # ceil() -> floor(), for 16kHz 64 | ), 65 | ResidualUnit(output_dim, dilation=1), 66 | ResidualUnit(output_dim, dilation=3), 67 | ResidualUnit(output_dim, dilation=9), 68 | ) 69 | 70 | def forward(self, x): 71 | return self.block(x) 72 | 73 | 74 | class DACEncoder(nn.Module): 75 | def __init__( 76 | self, 77 | d_model: int = 64, 78 | strides: list = [2, 4, 8, 8], 79 | d_latent: int = 1024, 80 | ): 81 | super().__init__() 82 | # Create first convolution 83 | layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)] 84 | 85 | # Create EncoderBlockd_models that double channels as they downsample by `stride` 86 | for stride in strides: 87 | d_model *= 2 88 | layers += [EncoderBlock(d_model, stride=stride)] 89 | 90 | # Create last convolution 91 | layers += [ 92 | Snake1d(d_model), 93 | WNConv1d(d_model, d_latent, kernel_size=3, padding=1), 94 | ] 95 | 96 | # Wrap black into nn.Sequential 97 | self.block = nn.Sequential(*layers) 98 | self.enc_dim = d_model 99 | 100 | def forward(self, x): 101 | return self.block(x) 102 | 103 | 104 | class DACEncoderTrans(nn.Module): 105 | def __init__( 106 | self, 107 | d_model: int = 64, 108 | strides: list = [2, 4, 8, 8], 109 | att_d_model: int = 512, 110 | att_nhead: int = 8, 111 | att_ff: int = 2048, 112 | att_norm_first: bool = False, 113 | att_layers: int = 1, 114 | d_latent: int = 1024, 115 | ): 116 | super().__init__() 117 | # Create first convolution 118 | layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)] 119 | 120 | # Create EncoderBlockd_models that double channels as they downsample by `stride` 121 | for stride in strides: 122 | d_model *= 2 
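# Illustrative note, assuming the default strides [2, 4, 8, 8] and d_model = 64 shown above: the channel
# width doubles at every stage (128 -> 256 -> 512 -> 1024) while the time axis is downsampled by the
# product of the strides, i.e. 2 * 4 * 8 * 8 = 512 input samples per latent frame.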
123 | layers += [EncoderBlock(d_model, stride=stride)] 124 | 125 | # Create convolution 126 | layers += [ 127 | Snake1d(d_model), 128 | WNConv1d(d_model, att_d_model, kernel_size=3, padding=1), 129 | ] 130 | 131 | # Create attention layers 132 | layers += [ 133 | TransformerSentenceEncoderLayer(d_model=att_d_model, nhead=att_nhead, 134 | dim_feedforward=att_ff, norm_first=att_norm_first, 135 | num_layers=att_layers) 136 | ] 137 | 138 | # Create last convolution 139 | layers += [ 140 | Snake1d(att_d_model), 141 | WNConv1d(att_d_model, d_latent, kernel_size=3, padding=1), 142 | ] 143 | 144 | # Wrap black into nn.Sequential 145 | self.block = nn.Sequential(*layers) 146 | self.enc_dim = d_model 147 | 148 | def forward(self, x): 149 | return self.block(x) 150 | 151 | 152 | class DACDecoder(nn.Module): 153 | def __init__( 154 | self, 155 | d_model: int = 1536, 156 | strides: list = [8, 8, 4, 2], 157 | d_latent: int = 1024, 158 | d_out: int = 1, 159 | ): 160 | super().__init__() 161 | 162 | # Add first conv layer 163 | layers = [WNConv1d(d_latent, d_model, kernel_size=7, padding=3)] 164 | 165 | # Add upsampling + MRF blocks (from HiFi GAN) 166 | for stride in strides: 167 | layers += [DecoderBlock(d_model, d_model//2, stride)] 168 | d_model = d_model // 2 169 | 170 | # Add final conv layer 171 | layers += [ 172 | Snake1d(d_model), 173 | WNConv1d(d_model, d_out, kernel_size=7, padding=3), 174 | nn.Tanh(), 175 | ] 176 | 177 | self.model = nn.Sequential(*layers) 178 | 179 | def forward(self, x): 180 | return self.model(x) 181 | 182 | 183 | class DACDecoderTrans(nn.Module): 184 | def __init__( 185 | self, 186 | d_model: int = 1536, 187 | strides: list = [8, 8, 4, 2], 188 | d_latent: int = 1024, 189 | att_d_model: int = 512, 190 | att_nhead: int = 8, 191 | att_ff: int = 2048, 192 | att_norm_first: bool = False, 193 | att_layers: int = 1, 194 | d_out: int = 1, 195 | ): 196 | super().__init__() 197 | 198 | # Add first conv layer 199 | layers = [WNConv1d(d_latent, att_d_model, kernel_size=7, padding=3)] 200 | 201 | # Add attention layer 202 | layers += [ 203 | TransformerSentenceEncoderLayer(d_model=att_d_model, nhead=att_nhead, 204 | dim_feedforward=att_ff, norm_first=att_norm_first, 205 | num_layers=att_layers) 206 | ] 207 | 208 | # Add conv layer 209 | layers += [ 210 | Snake1d(att_d_model), 211 | WNConv1d(att_d_model, d_model, kernel_size=7, padding=3) 212 | ] 213 | 214 | # Add upsampling + MRF blocks (from HiFi GAN) 215 | for stride in strides: 216 | layers += [DecoderBlock(d_model, d_model//2, stride)] 217 | d_model = d_model // 2 218 | 219 | # Add final conv layer 220 | layers += [ 221 | Snake1d(d_model), 222 | WNConv1d(d_model, d_out, kernel_size=7, padding=3), 223 | nn.Tanh(), 224 | ] 225 | 226 | self.model = nn.Sequential(*layers) 227 | 228 | def forward(self, x): 229 | return self.model(x) 230 | 231 | 232 | 233 | class CodecMixin: 234 | """Truncated version of DAC CodecMixin 235 | """ 236 | def get_delay(self): 237 | # Any number works here, delay is invariant to input length 238 | l_out = self.get_output_length(0) 239 | L = l_out 240 | 241 | layers = [] 242 | for layer in self.modules(): 243 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 244 | layers.append(layer) 245 | 246 | for layer in reversed(layers): 247 | d = layer.dilation[0] 248 | k = layer.kernel_size[0] 249 | s = layer.stride[0] 250 | 251 | if isinstance(layer, nn.ConvTranspose1d): 252 | L = ((L - d * (k - 1) - 1) / s) + 1 253 | elif isinstance(layer, nn.Conv1d): 254 | L = (L - 1) * s + d * (k - 1) + 1 255 | 256 | L = 
math.ceil(L) 257 | 258 | l_in = L 259 | 260 | return (l_in - l_out) // 2 261 | 262 | def get_output_length(self, input_length): 263 | L = input_length 264 | # Calculate output length 265 | for layer in self.modules(): 266 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): 267 | d = layer.dilation[0] 268 | k = layer.kernel_size[0] 269 | s = layer.stride[0] 270 | 271 | if isinstance(layer, nn.Conv1d): 272 | L = ((L - d * (k - 1) - 1) / s) + 1 273 | elif isinstance(layer, nn.ConvTranspose1d): 274 | L = (L - 1) * s + d * (k - 1) + 1 275 | 276 | L = math.floor(L) 277 | return L -------------------------------------------------------------------------------- /src/modules/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import TransformerEncoderLayer 4 | from torch.nn.utils.parametrizations import weight_norm 5 | from torch.distributions import Categorical 6 | 7 | def WNConv1d(*args, **kwargs): 8 | act = kwargs.pop("act", False) 9 | conv = weight_norm(nn.Conv1d(*args, **kwargs)) 10 | if act: 11 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 12 | else: 13 | return conv 14 | 15 | 16 | def WNConvTranspose1d(*args, **kwargs): 17 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) 18 | 19 | 20 | 21 | def WNConv2d(*args, **kwargs): 22 | act = kwargs.pop("act", False) 23 | conv = weight_norm(nn.Conv2d(*args, **kwargs)) 24 | if act: 25 | return nn.Sequential(conv, nn.LeakyReLU(0.1)) 26 | else: 27 | return conv 28 | 29 | # Scripting this brings model speed up 1.4x 30 | @torch.jit.script 31 | def snake(x, alpha): 32 | shape = x.shape 33 | x = x.reshape(shape[0], shape[1], -1) 34 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) 35 | x = x.reshape(shape) 36 | return x 37 | 38 | class Snake1d(nn.Module): 39 | def __init__(self, channels): 40 | super().__init__() 41 | self.alpha = nn.Parameter(torch.ones(1, channels, 1)) 42 | 43 | def forward(self, x): 44 | return snake(x, self.alpha) 45 | 46 | 47 | class TransformerSentenceEncoderLayer(nn.Module): 48 | """ 49 | Stack of TransformerEncoderLayer blocks applied to the conv layer output 50 | """ 51 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 52 | norm_first=False, bias=True, num_layers=1): 53 | super().__init__() 54 | layers = [TransformerEncoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout, 55 | batch_first=False, norm_first=norm_first, bias=bias) 56 | for _ in range(num_layers)] 57 | self.blocks = nn.Sequential(*layers) 58 | 59 | def forward(self, x): 60 | assert len(x.shape) == 3 61 | x = x.permute(2, 0, 1) # (bs, C, L) -> (L, bs, C) 62 | x = self.blocks(x) 63 | x = x.permute(1, 2, 0) # (L, bs, C) -> (bs, C, L) 64 | return x 65 | 66 | 67 | class SLSTM(nn.Module): 68 | """ 69 | LSTM without worrying about the hidden state, nor the layout of the data. 70 | Expects input as convolutional layout. 
71 | Modified from the EnCodec: https://github.com/facebookresearch/encodec/blob/main/encodec/modules/lstm.py 72 | """ 73 | def __init__(self, dimension: int, num_layers: int = 2, bidirectional: bool = True, skip: bool = True): 74 | super().__init__() 75 | self.skip = skip 76 | self.bidirectional = bidirectional 77 | self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=self.bidirectional) 78 | if bidirectional: 79 | self.linear_ = nn.Linear(dimension*2, dimension) 80 | 81 | def forward(self, x): 82 | x = x.permute(2, 0, 1) 83 | y, _ = self.lstm(x) 84 | if self.bidirectional: 85 | y = self.linear_(y) 86 | if self.skip: 87 | y = y + x 88 | y = y.permute(1, 2, 0) 89 | return y 90 | 91 | 92 | 93 | class Jitter(nn.Module): 94 | """ 95 | Shuffule the input by randomly swapping with neighborhood 96 | Modified from the SQ-VAE speech: https://github.com/sony/sqvae/blob/main/speech/model.py#L76 97 | Args: 98 | p: probability to shuffle code 99 | size: kernel size to shuffle code 100 | """ 101 | def __init__(self, p, size=3): 102 | super().__init__() 103 | self.p = p 104 | self.size = size 105 | prob = torch.ones(size) * p / (size - 1) 106 | prob[size//2] = 1 - p 107 | self.register_buffer("prob", prob) 108 | 109 | def forward(self, x): 110 | if not self.training or self.p == 0.0: 111 | return x 112 | else: 113 | batch_size, dim, T = x.size() 114 | 115 | dist = Categorical(probs=self.prob) 116 | index = dist.sample(torch.Size([batch_size, T])) - len(self.prob)//2 117 | index += torch.arange(T, device=x.device) 118 | index.clamp_(0, T-1) 119 | x = torch.gather(x, -1, index.unsqueeze(1).expand(-1, dim, -1)) 120 | 121 | return x 122 | 123 | 124 | if __name__ == "__main__": 125 | 126 | jit = Jitter(p=0.8, size=5) 127 | print(jit.prob) 128 | 129 | x = torch.arange(20).reshape(1,1,20) 130 | x = torch.cat((x, x+100), dim=0) 131 | print(x) 132 | print(jit(x)) -------------------------------------------------------------------------------- /src/modules/quantize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from DAC: https://github.com/descriptinc/descript-audio-codec 3 | Correct some arguement description 4 | """ 5 | from typing import Union 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from einops import rearrange 12 | 13 | from .layers import WNConv1d, Jitter 14 | 15 | 16 | class VectorQuantize(nn.Module): 17 | """ 18 | Implementation of VQ similar to Karpathy's repo: 19 | https://github.com/karpathy/deep-vector-quantization 20 | Additionally uses following tricks from Improved VQGAN 21 | (https://arxiv.org/pdf/2110.04627.pdf): 22 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space 23 | for improved codebook usage 24 | 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which 25 | improves training stability 26 | """ 27 | 28 | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): 29 | super().__init__() 30 | self.codebook_size = codebook_size 31 | self.codebook_dim = codebook_dim 32 | 33 | self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) 34 | self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) 35 | self.codebook = nn.Embedding(codebook_size, codebook_dim) 36 | 37 | def forward(self, z): 38 | """Quantized the input tensor using a fixed codebook and returns 39 | the corresponding codebook vectors 40 | 41 | Parameters 42 | ---------- 43 | z : Tensor[B x D x T] 44 | 45 | Returns 46 | ------- 47 | "z_q" : Tensor[B x D x T] 48 | Quantized continuous representation of input 49 | "vq/commitment_loss" : Tensor[1] 50 | Commitment loss to train encoder to predict vectors closer to codebook 51 | entries 52 | "vq/codebook_loss" : Tensor[1] 53 | Codebook loss to update the codebook 54 | "codes" : Tensor[B x T] 55 | Codebook indices (quantized discrete representation of input) 56 | "latents" : Tensor[B x D' x T] 57 | Projected latents (continuous representation of input before quantization) 58 | """ 59 | 60 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space 61 | z_e = self.in_proj(z) # z_e : (B x D x T) 62 | z_q, indices = self.decode_latents(z_e) 63 | 64 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) 65 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) 66 | 67 | z_q = ( 68 | z_e + (z_q - z_e).detach() 69 | ) # noop in forward pass, straight-through gradient estimator in backward pass 70 | 71 | z_q = self.out_proj(z_q) 72 | 73 | return z_q, commitment_loss, codebook_loss, indices, z_e 74 | 75 | def embed_code(self, embed_id): 76 | return F.embedding(embed_id, self.codebook.weight) 77 | 78 | def decode_code(self, embed_id): 79 | return self.embed_code(embed_id).transpose(1, 2) 80 | 81 | def decode_latents(self, latents): 82 | encodings = rearrange(latents, "b d t -> (b t) d") 83 | codebook = self.codebook.weight # codebook: (N x D) 84 | 85 | # L2 normalize encodings and codebook (ViT-VQGAN) 86 | encodings = F.normalize(encodings) 87 | codebook = F.normalize(codebook) 88 | 89 | # Compute euclidean distance with codebook 90 | dist = ( 91 | encodings.pow(2).sum(1, keepdim=True) 92 | - 2 * encodings @ codebook.t() 93 | + codebook.pow(2).sum(1, keepdim=True).t() 94 | ) 95 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) 96 | z_q = self.decode_code(indices) 97 | return z_q, indices 98 | 99 | 100 | class ResidualVectorQuantize(nn.Module): 101 | """ 102 | Introduced in SoundStream: An end2end neural audio codec 103 | https://arxiv.org/abs/2107.03312 104 | """ 105 | 106 | def __init__( 107 | self, 108 | input_dim: int = 1024, 109 | n_codebooks: int = 9, 110 | codebook_size: int = 1024, 111 | codebook_dim: Union[int, list] = 8, 112 | quantizer_dropout: float = 0.0, 113 | ): 114 | super().__init__() 115 | if isinstance(codebook_dim, int): 116 | codebook_dim = [codebook_dim for _ in range(n_codebooks)] 117 | 118 | self.n_codebooks = n_codebooks 119 | self.codebook_dim = codebook_dim 120 | self.codebook_size = codebook_size 121 | 122 | self.quantizers = nn.ModuleList( 123 | [ 124 | VectorQuantize(input_dim, codebook_size, codebook_dim[i]) 125 | for i in range(n_codebooks) 126 | ] 127 | ) 128 | self.quantizer_dropout = quantizer_dropout 129 | 130 | def 
forward(self, z, n_quantizers: int = None): 131 | """Quantized the input tensor using a fixed set of `n` codebooks and returns 132 | the corresponding codebook vectors 133 | Parameters 134 | ---------- 135 | z : Tensor[B x D x T] 136 | n_quantizers : int, optional 137 | No. of quantizers to use 138 | (n_quantizers < self.n_codebooks ex: for quantizer dropout) 139 | Note: if `self.quantizer_dropout` is True, this argument is ignored 140 | when in training mode, and a random number of quantizers is used. 141 | 142 | Returns 143 | ------- 144 | "z_q" : Tensor[B x D x T] 145 | Quantized continuous representation of input 146 | "codes" : Tensor[B x N x T] 147 | Codebook indices for each codebook 148 | "latents" : Tensor[B x N*D' x T] 149 | Concatenated projected latents (continuous representation of input before quantization) 150 | "vq/commitment_loss" : Tensor[1] 151 | Commitment loss to train encoder to predict vectors closer to codebook entries 152 | "vq/codebook_loss" : Tensor[1] 153 | Codebook loss to update the codebook 154 | """ 155 | z_q = 0 156 | residual = z 157 | commitment_loss = 0 158 | codebook_loss = 0 159 | 160 | codebook_indices = [] 161 | latents = [] 162 | 163 | if n_quantizers is None: 164 | n_quantizers = self.n_codebooks 165 | if self.training: 166 | n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 167 | dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) 168 | n_dropout = int(z.shape[0] * self.quantizer_dropout) 169 | n_quantizers[:n_dropout] = dropout[:n_dropout] 170 | n_quantizers = n_quantizers.to(z.device) 171 | 172 | for i, quantizer in enumerate(self.quantizers): 173 | if self.training is False and i >= n_quantizers: 174 | break 175 | 176 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( 177 | residual 178 | ) 179 | 180 | # Create mask to apply quantizer dropout 181 | mask = ( 182 | torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers 183 | ) 184 | z_q = z_q + z_q_i * mask[:, None, None] 185 | residual = residual - z_q_i 186 | 187 | # Sum losses 188 | commitment_loss += (commitment_loss_i * mask).mean() 189 | codebook_loss += (codebook_loss_i * mask).mean() 190 | 191 | codebook_indices.append(indices_i) 192 | latents.append(z_e_i) 193 | 194 | codes = torch.stack(codebook_indices, dim=1) 195 | latents = torch.cat(latents, dim=1) 196 | 197 | return z_q, codes, latents, commitment_loss, codebook_loss 198 | 199 | def from_codes(self, codes: torch.Tensor): 200 | """Given the quantized codes, reconstruct the continuous representation 201 | Parameters 202 | ---------- 203 | codes : Tensor[B x N x T] 204 | Codebook indices for each codebook 205 | 206 | Returns 207 | ------- 208 | "z_q" : Tensor[B x D x T] 209 | Quantized continuous representation of input 210 | "z_p" : Tensor[B x N*D' x T] 211 | Concatenated quantized codes before up-projection 212 | """ 213 | z_q = 0.0 214 | z_p = [] 215 | n_codebooks = codes.shape[1] 216 | for i in range(n_codebooks): 217 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) 218 | z_p.append(z_p_i) 219 | 220 | z_q_i = self.quantizers[i].out_proj(z_p_i) 221 | z_q = z_q + z_q_i 222 | return z_q, torch.cat(z_p, dim=1), codes 223 | 224 | def from_latents(self, latents: torch.Tensor): 225 | """Given the unquantized latents, reconstruct the 226 | continuous representation after quantization. 
227 | 228 | Parameters 229 | ---------- 230 | latents : Tensor[B x N*D x T] 231 | Concatenated projected latents (continuous representation of input before quantization) 232 | 233 | Returns 234 | ------- 235 | "z_q" : Tensor[B x D x T] 236 | Quantized representation of full-projected space 237 | "z_p" : Tensor[B x N*D' x T] 238 | Concatenated quantized codes before up-projection 239 | "codes" : Tensor[B x N x T] 240 | Codebook indices for each codebook 241 | """ 242 | z_q = 0 243 | z_p = [] 244 | codes = [] 245 | dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) 246 | 247 | n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ 248 | 0 249 | ] 250 | for i in range(n_codebooks): 251 | j, k = dims[i], dims[i + 1] 252 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) 253 | z_p.append(z_p_i) 254 | codes.append(codes_i) 255 | 256 | z_q_i = self.quantizers[i].out_proj(z_p_i) 257 | z_q = z_q + z_q_i 258 | 259 | return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) 260 | 261 | 262 | class MultiSourceRVQ(nn.Module): 263 | """ 264 | Parallele RVQ for multiple sources 265 | """ 266 | def __init__( 267 | self, 268 | tracks: list[str] = ['speech', 'music', 'sfx'], 269 | input_dim: int = 1024, 270 | n_codebooks: list[int] = [12, 12, 12], 271 | codebook_size: list[int] = [1024, 1024, 1024], 272 | codebook_dim: list[int] = [8, 8, 8], 273 | quantizer_dropout: float = 0.0, 274 | code_jit_prob: list[float] = [0.0, 0.0, 0.0], 275 | code_jit_size: list[int] = [3, 5, 3], 276 | shared_codebooks: int = 8, 277 | ): 278 | super().__init__() 279 | self.tracks = tracks 280 | self.input_dim = input_dim 281 | self.n_codebooks = n_codebooks 282 | self.codebook_size = codebook_size 283 | self.codebook_dim = codebook_dim 284 | self.quantizer_dropout = quantizer_dropout 285 | self.code_jit_prob = code_jit_prob 286 | self.code_jit_size = code_jit_size 287 | self.shared_codebooks = shared_codebooks 288 | 289 | self.rvq_dict = nn.ModuleDict() 290 | self.jitter_dict = nn.ModuleDict() 291 | 292 | for i, t in enumerate(self.tracks): 293 | self.rvq_dict[t] = ResidualVectorQuantize( 294 | input_dim=self.input_dim, 295 | n_codebooks=self.n_codebooks[i]-self.shared_codebooks, 296 | codebook_size=self.codebook_size[i], 297 | codebook_dim=self.codebook_dim[i], 298 | quantizer_dropout=self.quantizer_dropout, 299 | ) 300 | self.jitter_dict[t] = Jitter( 301 | p=self.code_jit_prob[i], 302 | size=self.code_jit_size[i], 303 | ) 304 | 305 | if shared_codebooks > 0: 306 | self.shared_rvq = ResidualVectorQuantize( 307 | input_dim=self.input_dim, 308 | n_codebooks=self.shared_codebooks, 309 | codebook_size=self.codebook_size[0], 310 | codebook_dim=self.codebook_dim[0], 311 | quantizer_dropout=self.quantizer_dropout, 312 | ) 313 | 314 | def forward(self, track_name, feats, n_quantizers: int = None): 315 | assert track_name in self.tracks, '{} not in model tracks: {}'.format(track_name, self.tracks) 316 | 317 | # number of quantizer to be used 318 | if n_quantizers is None: 319 | n_quantizers = self.n_codebooks[self.tracks.index(track_name)] 320 | 321 | # quantize 322 | z, codes, latents, commitment_loss, codebook_loss = self.rvq_dict[track_name]( 323 | feats, n_quantizers 324 | ) 325 | 326 | # shared codebook 327 | if self.shared_codebooks > 0 and n_quantizers > self.rvq_dict[track_name].n_codebooks: 328 | z_shard, codes_shard, latents_shard, commitment_loss_shard, codebook_loss_shard = self.shared_rvq( 329 | feats-z, n_quantizers-self.rvq_dict[track_name].n_codebooks 330 | ) 
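# Illustrative arithmetic, assuming the defaults above (n_codebooks = [12, 12, 12], shared_codebooks = 8):
# each track-specific RVQ then holds 12 - 8 = 4 codebooks, so a request for n_quantizers = 12 quantizes the
# residual (feats - z) with the remaining 8 levels of the RVQ shared across all tracks.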
331 | z = z + z_shard 332 | codes = torch.cat((codes, codes_shard), dim=1) # (B, N, T) 333 | latents = torch.cat((latents, latents_shard), dim=1) # (B, N*D', T) 334 | commitment_loss += commitment_loss_shard 335 | codebook_loss += codebook_loss_shard 336 | 337 | # jitter, ignored if prob <= 0 338 | z = self.jitter_dict[track_name](z) # (B, D, T) 339 | 340 | return z, codes, latents, commitment_loss, codebook_loss 341 | 342 | 343 | if __name__ == "__main__": 344 | rvq = MultiSourceRVQ(quantizer_dropout=0.0) 345 | feats = torch.randn(16, 1024, 20) 346 | z_q, codes, latents, commitment_loss, codebook_loss = rvq('speech', feats) 347 | print(z_q.shape) 348 | print(codes.shape) 349 | print(latents.shape) 350 | print(commitment_loss) 351 | print(codebook_loss) 352 | 353 | # z_q, z_p, codes = rvq.from_latents(latents) 354 | # print(z_q[0,0]) 355 | # print(codes[0,0]) -------------------------------------------------------------------------------- /src/optim/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scheduler modified from AudioCraft project https://github.com/facebookresearch/audiocraft/tree/main 3 | """ 4 | 5 | # flake8: noqa 6 | from .cosine_lr_scheduler import CosineLRScheduler 7 | from .exponential_lr_scheduler import ExponentialLRScheduler 8 | from .inverse_sqrt_lr_scheduler import InverseSquareRootLRScheduler 9 | from .linear_warmup_lr_scheduler import LinearWarmupLRScheduler 10 | from .polynomial_decay_lr_scheduler import PolynomialDecayLRScheduler 11 | from .reduce_plateau_lr_scheduler import ReducePlateauLRScheduler -------------------------------------------------------------------------------- /src/optim/cosine_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | from torch.optim import Optimizer 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | 12 | 13 | class CosineLRScheduler(_LRScheduler): 14 | """Cosine LR scheduler. 15 | 16 | Args: 17 | optimizer (Optimizer): Torch optimizer. 18 | warmup_steps (int): Number of warmup steps. 19 | total_steps (int): Total number of steps. 20 | lr_min_ratio (float): Minimum learning rate. 21 | cycle_length (float): Cycle length. 22 | """ 23 | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int, 24 | lr_min_ratio: float = 0.0, cycle_length: float = 1.0): 25 | self.warmup_steps = warmup_steps 26 | assert self.warmup_steps >= 0 27 | self.total_steps = total_steps 28 | assert self.total_steps >= warmup_steps 29 | self.lr_min_ratio = lr_min_ratio 30 | self.cycle_length = cycle_length 31 | super().__init__(optimizer) 32 | 33 | def _get_sched_lr(self, lr: float, step: int): 34 | if step < self.warmup_steps: 35 | lr_ratio = step / self.warmup_steps 36 | lr = lr_ratio * lr 37 | elif step <= self.total_steps: 38 | s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) 39 | lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \ 40 | (1. 
+ math.cos(math.pi * s / self.cycle_length)) 41 | lr = lr_ratio * lr 42 | else: 43 | lr_ratio = self.lr_min_ratio 44 | lr = lr_ratio * lr 45 | return lr 46 | 47 | def get_lr(self): 48 | return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs] -------------------------------------------------------------------------------- /src/optim/exponential_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | from torch.optim import Optimizer 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | 12 | 13 | class ExponentialLRScheduler(_LRScheduler): 14 | """Exponential LR scheduler. 15 | 16 | Args: 17 | optimizer (Optimizer): Torch optimizer. 18 | warmup_steps (int): Number of warmup steps. 19 | total_steps (int): Total number of steps. 20 | lr_min_ratio (float): Minimum learning rate. 21 | gamma (float): Exponential decay factor applied at each step after warmup. 22 | """ 23 | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int, 24 | lr_min_ratio: float = 0.0, gamma: float = 0.99999): 25 | self.warmup_steps = warmup_steps 26 | assert self.warmup_steps >= 0 27 | self.total_steps = total_steps 28 | assert self.total_steps >= warmup_steps 29 | self.lr_min_ratio = lr_min_ratio 30 | self.cumprod = 1 31 | self.gamma = gamma 32 | super().__init__(optimizer) 33 | 34 | def _get_sched_lr(self, lr: float, step: int): 35 | if step < self.warmup_steps: 36 | lr_ratio = step / self.warmup_steps 37 | lr = lr_ratio * lr 38 | elif step <= self.total_steps: 39 | self.cumprod *= self.gamma 40 | lr_ratio = self.cumprod 41 | lr = lr_ratio * lr 42 | else: 43 | lr_ratio = self.lr_min_ratio 44 | lr = lr_ratio * lr 45 | return lr 46 | 47 | def get_lr(self): 48 | return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs] -------------------------------------------------------------------------------- /src/optim/inverse_sqrt_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import typing as tp 8 | 9 | from torch.optim import Optimizer 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | 12 | 13 | class InverseSquareRootLRScheduler(_LRScheduler): 14 | """Inverse square root LR scheduler. 15 | 16 | Args: 17 | optimizer (Optimizer): Torch optimizer. 18 | warmup_steps (int): Number of warmup steps. 19 | warmup_init_lr (tp.Optional[float]): Initial learning rate 20 | during warmup phase. When not set, use the provided learning rate. 
21 | """ 22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0): 23 | self.warmup_steps = warmup_steps 24 | self.warmup_init_lr = warmup_init_lr 25 | super().__init__(optimizer) 26 | 27 | def _get_sched_lr(self, lr: float, step: int): 28 | if step < self.warmup_steps: 29 | warmup_init_lr = self.warmup_init_lr or 0 30 | lr_step = (lr - warmup_init_lr) / self.warmup_steps 31 | lr = warmup_init_lr + step * lr_step 32 | else: 33 | decay_factor = lr * self.warmup_steps**0.5 34 | lr = decay_factor * step**-0.5 35 | return lr 36 | 37 | def get_lr(self): 38 | return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /src/optim/linear_warmup_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import typing as tp 8 | 9 | from torch.optim import Optimizer 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | 12 | 13 | class LinearWarmupLRScheduler(_LRScheduler): 14 | """Linear warmup LR scheduler. 15 | 16 | Args: 17 | optimizer (Optimizer): Torch optimizer. 18 | warmup_steps (int): Number of warmup steps. 19 | warmup_init_lr (tp.Optional[float]): Initial learning rate 20 | during warmup phase. When not set, use the provided learning rate. 21 | """ 22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0): 23 | self.warmup_steps = warmup_steps 24 | self.warmup_init_lr = warmup_init_lr 25 | super().__init__(optimizer) 26 | 27 | def _get_sched_lr(self, lr: float, step: int): 28 | if step < self.warmup_steps: 29 | warmup_init_lr = self.warmup_init_lr or 0 30 | lr_step = (lr - warmup_init_lr) / self.warmup_steps 31 | lr = warmup_init_lr + step * lr_step 32 | return lr 33 | 34 | def get_lr(self): 35 | return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /src/optim/polynomial_decay_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import _LRScheduler 9 | 10 | 11 | class PolynomialDecayLRScheduler(_LRScheduler): 12 | """Polynomial decay LR scheduler. 13 | 14 | Args: 15 | optimizer (Optimizer): Torch optimizer. 16 | warmup_steps (int): Number of warmup steps. 17 | total_steps (int): Total number of steps. 18 | end_lr (float): Final learning rate to achieve over total number of steps. 19 | zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0. 20 | power (float): Decay exponent. 
21 | """ 22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int, 23 | end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.): 24 | self.warmup_steps = warmup_steps 25 | self.total_steps = total_steps 26 | self.end_lr = end_lr 27 | self.zero_lr_warmup_steps = zero_lr_warmup_steps 28 | self.power = power 29 | super().__init__(optimizer) 30 | 31 | def _get_sched_lr(self, lr: float, step: int): 32 | if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps: 33 | lr = 0 34 | elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps: 35 | lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps) 36 | lr = lr_ratio * lr 37 | elif step >= self.total_steps: 38 | lr = self.end_lr 39 | else: 40 | total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps 41 | lr_range = lr - self.end_lr 42 | pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps) 43 | lr = lr_range * pct_remaining ** self.power + self.end_lr 44 | return lr 45 | 46 | def get_lr(self): 47 | return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /src/optim/reduce_plateau_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import math 4 | 5 | from torch.optim import Optimizer 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | 9 | class ReducePlateauLRScheduler(_LRScheduler): 10 | """Reduce-on-plateau LR scheduler. 11 | Add warmup steps for torch.optim.lr_scheduler.ReduceLROnPlateau 12 | 13 | Args: 14 | optimizer (Optimizer): Torch optimizer. 15 | warmup_steps (int): Number of warmup steps. 16 | total_steps (int): Total number of steps. 17 | lr_min_ratio (float): Minimum learning rate. 18 | cycle_length (float): Cycle length. 
19 | """ 20 | 21 | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int, 22 | lr_min_ratio: float = 0.0, cycle_length: float = 1.0): 23 | # TODO 24 | pass -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2024 by Telecom-Paris 3 | Authoried by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr) 4 | License agreement in LICENSE.txt 5 | """ 6 | import sys 7 | import shutil 8 | import numpy as np 9 | from pathlib import Path 10 | from omegaconf import DictConfig, OmegaConf 11 | from hydra.core.hydra_config import HydraConfig 12 | from collections import OrderedDict, defaultdict 13 | import random 14 | from einops import rearrange 15 | 16 | import torch 17 | from torch.utils.data import DataLoader 18 | import torch.nn.functional as F 19 | from accelerate import Accelerator 20 | from accelerate.logging import get_logger 21 | from accelerate.utils import set_seed 22 | from torch.distributions import Categorical 23 | 24 | from src import datasets, models, optim, utils 25 | from .metrics import ( 26 | VisqolMetric, 27 | SingleSrcNegSDR, 28 | MultiScaleSTFTLoss, 29 | MelSpectrogramLoss, 30 | GANLoss, 31 | ) 32 | 33 | 34 | logger = get_logger(__name__) # avoid douplicated print, params defined by Hydra 35 | accelerator = Accelerator(project_dir=HydraConfig.get().runtime.output_dir, 36 | step_scheduler_with_optimizer=False, 37 | log_with="tensorboard") 38 | 39 | class Trainer(object): 40 | def __init__(self, cfg: DictConfig): 41 | 42 | # Init 43 | self.cfg = cfg 44 | OmegaConf.set_struct(self.cfg, False) # enable config.pop() 45 | self.project_dir = Path(accelerator.project_dir) 46 | self.device = accelerator.device 47 | logger.info('Init Trainer') 48 | 49 | # Fix random 50 | seed = self.cfg.training.get('seed', False) 51 | if seed: 52 | set_seed(seed) 53 | 54 | # Backup code, only on main process 55 | if self.cfg.backup_code and accelerator.is_main_process: 56 | back_dir = self.project_dir / 'backup_src' 57 | logger.info(f'Backup code at: {back_dir}') 58 | cwd = HydraConfig.get().runtime.cwd 59 | src_dir = Path(cwd) / 'src' 60 | if back_dir.exists(): 61 | shutil.rmtree(back_dir) 62 | shutil.copytree(src=src_dir, dst=back_dir) 63 | 64 | # Checkpoint 65 | self.ckpt_best = self.project_dir / 'ckpt_best' 66 | self.ckptdir = self.project_dir / 'checkpoints' 67 | self.ckptdir.mkdir(exist_ok=True) 68 | self.ckpt_final = self.project_dir / 'ckpt_final' 69 | self.ckpt_final.mkdir(exist_ok=True) 70 | 71 | # Check resume 72 | if self.cfg.resume: 73 | resume_dir = self.cfg.get('resume_dir', None) 74 | if resume_dir is None: 75 | logger.info(f'No resume_dir provided, try to resume from best ckpt directory...') 76 | self.ckpt_resume = self.ckpt_best 77 | else: 78 | self.ckpt_resume = self.project_dir / resume_dir 79 | if not self.ckpt_best.is_dir(): 80 | self.cfg.resume = False 81 | logger.info(f'Resume FAILED, no ckpt dir at: {self.ckpt_best}') 82 | 83 | # Tensorboard tracker 84 | accelerator.init_trackers(project_name='tb') 85 | self.tracker = accelerator.get_tracker("tensorboard") 86 | logger.info('Tracker backend: tensorboard') 87 | 88 | # Prepare dataset 89 | self.sr = self.cfg.sampling_rate 90 | logger.info('=====> Training dataloader') 91 | self.train_loader = self._get_data(self.cfg.dataset.trainset_cfg, self.cfg.training.dataloader, 92 | is_train=True, sample_rate=self.sr) 93 | logger.info('=====> Validation dataloader') 94 | 
self.val_loader = self._get_data(self.cfg.dataset.valset_cfg, self.cfg.training.dataloader, 95 | is_train=False, sample_rate=self.sr) 96 | logger.info('=====> Test dataloader') 97 | self.test_loader = self._get_data(self.cfg.dataset.testset_cfg, self.cfg.training.dataloader, 98 | is_train=False, sample_rate=self.sr) 99 | 100 | # Prepare generator 101 | model_name = self.cfg.model.pop('name') 102 | optim_name = self.cfg.training.optimizer.pop('name') 103 | scheduler_name = self.cfg.training.scheduler.pop('name') 104 | self.model = self._get_model(model_name, self.cfg.model, self.sr) 105 | self.optimizer_g = self._get_optim(self.model.parameters(), optim_name, self.cfg.training.optimizer) 106 | self.scheduler_g = self._get_scheduler(self.optimizer_g, scheduler_name, self.cfg.training.scheduler) 107 | 108 | # Prepare discriminator 109 | dis_name = self.cfg.discriminator.pop('name') 110 | self.discriminator = self._get_model(dis_name, self.cfg.discriminator, sample_rate=self.sr) 111 | self.optimizer_d = self._get_optim(self.discriminator.parameters(), optim_name, self.cfg.training.optimizer) 112 | self.scheduler_d = self._get_scheduler(self.optimizer_d, scheduler_name, self.cfg.training.scheduler) 113 | 114 | # Prepare training recording 115 | self.tm = utils.TrainMonitor() 116 | 117 | # Accelerator preparation 118 | self._acc_prepare() 119 | 120 | # Define the loss function 121 | self._metric_prepare(self.cfg.training.loss) 122 | self.lambdas = self.cfg.training.loss.lambdas 123 | 124 | # Define the audio transform function 125 | self._transform_prepare(self.cfg.training.transform) 126 | 127 | # Resume 128 | if self.cfg.resume: 129 | logger.info(f'Resume training from: {self.ckpt_resume}') 130 | accelerator.load_state(self.ckpt_resume) 131 | self.tm.nb_step += 1 132 | logger.info(f'Training re-start from iter: {self.tm.nb_step}') 133 | else: 134 | logger.info(f'Experiment workdir: {self.project_dir}') 135 | logger.info(f'num_processes: {accelerator.num_processes}') 136 | logger.info(f'batch size per gpu for train: {self.cfg.training.dataloader.train_bs}') 137 | logger.info(f'batch size per gpu for validation: {self.cfg.training.dataloader.eval_bs}') 138 | logger.info(f'mixed_precision: {accelerator.mixed_precision}') 139 | # logger.info(OmegaConf.to_yaml(self.cfg)) 140 | logger.info('Trainer init finish') 141 | 142 | # Basic info 143 | self.tracks = self.cfg.model.tracks # ['speech', 'music', 'sfx'] 144 | self.eval_tracks = ['mix_rec'] + [f'{t}_rec' for t in self.tracks] + \ 145 | [f'{t}_sep' for t in self.tracks] + [f'{t}_sep_mask' for t in self.tracks] 146 | logger.info('Used audio tracks: {}'.format(self.tracks)) 147 | logger.info('Eval audio tracks: {}'.format(self.eval_tracks)) 148 | logger.info(f'Target sampling rate: {self.sr} Hz') 149 | logger.info(f'Random seed: {seed}') 150 | 151 | 152 | def _get_data(self, dataset_cfg, dataloader_cfg, is_train=True, sample_rate=44100): 153 | 154 | num_workers = dataloader_cfg.num_workers 155 | 156 | if is_train: 157 | batch_size = dataloader_cfg.train_bs 158 | shuffle = True 159 | drop_last = True 160 | data_class = getattr(datasets, f'DatasetAudioTrain') 161 | else: 162 | batch_size = dataloader_cfg.eval_bs 163 | shuffle = False 164 | drop_last = False 165 | data_class = getattr(datasets, f'DatasetAudioVal') 166 | 167 | dataset = data_class(sample_rate=sample_rate, **dataset_cfg) 168 | dataloader = DataLoader(dataset=dataset, 169 | batch_size=batch_size, num_workers=num_workers, 170 | shuffle=shuffle, drop_last=drop_last) 171 | return dataloader 
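# Usage sketch (it mirrors the calls in __init__ above, e.g.
#   self.train_loader = self._get_data(self.cfg.dataset.trainset_cfg, self.cfg.training.dataloader,
#                                      is_train=True, sample_rate=self.sr)
# ): the DataLoader returned here is later wrapped by accelerator.prepare(...) in _acc_prepare(),
# which distributes the batches across the available GPUs.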
172 | 173 | 174 | def _get_model(self, model_name, model_cfg, sample_rate=44100): 175 | logger.info(f"Model: {model_name}") 176 | net_class = getattr(models, f'{model_name}') 177 | model = net_class(sample_rate=sample_rate, **model_cfg) 178 | total_params = sum(p.numel() for p in model.parameters()) / 1e6 179 | logger.info(f'Total params: {total_params:.2f} Mb') 180 | total_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 181 | logger.info(f'Total trainable params: {total_train_params:.2f} Mb') 182 | return model 183 | 184 | 185 | def _get_optim(self, params, optim_name, optim_cfg): 186 | logger.info(f"Optimizer: {optim_name}") 187 | optim_class = getattr(torch.optim, optim_name) 188 | optimizer = optim_class(filter(lambda p: p.requires_grad, params), **optim_cfg) 189 | return optimizer 190 | 191 | 192 | def _get_scheduler(self, optimizer, scheduler_name, scheduler_cfg): 193 | logger.info(f"Scheduler: {scheduler_name}") 194 | sche_class = getattr(optim, scheduler_name) 195 | scheduler = sche_class(optimizer, **scheduler_cfg) 196 | return scheduler 197 | 198 | 199 | def _metric_prepare(self, loss_cfg): 200 | self.stft_loss = MultiScaleSTFTLoss(**loss_cfg.MultiScaleSTFTLoss) 201 | self.mel_loss = MelSpectrogramLoss(**loss_cfg.MelSpectrogramLoss) 202 | self.gan_loss = GANLoss(self.discriminator) 203 | self.eval_sisdr = SingleSrcNegSDR(sdr_type='sisdr') 204 | # self.eval_visqol = VisqolMetric(mode='speech', reduction=None) 205 | 206 | 207 | def _transform_prepare(self, transform_cfg): 208 | self.volume_norm = utils.VolumeNorm(sample_rate=self.sr) 209 | self.sep_norm = utils.WavSepMagNorm() 210 | self.peak_norm = utils.db_to_gain(transform_cfg.peak_norm_db) 211 | 212 | 213 | def _acc_prepare(self): 214 | self.model = accelerator.prepare(self.model) 215 | self.discriminator = accelerator.prepare(self.discriminator) 216 | self.optimizer_g = accelerator.prepare(self.optimizer_g) 217 | self.optimizer_d = accelerator.prepare(self.optimizer_d) 218 | self.scheduler_g = accelerator.prepare(self.scheduler_g) 219 | self.scheduler_d = accelerator.prepare(self.scheduler_d) 220 | self.train_loader = accelerator.prepare(self.train_loader) 221 | self.val_loader = accelerator.prepare(self.val_loader) 222 | self.test_loader = accelerator.prepare(self.test_loader) 223 | accelerator.register_for_checkpointing(self.tm) 224 | self.model_eval_func = self.model.module.evaluate if accelerator.use_distributed \ 225 | else self.model.evaluate 226 | logger.info('{} iterations per epoch'.format(len(self.train_loader))) 227 | 228 | 229 | def _data_transform(self, batch, transform_cfg, nb_step=-1, is_eval=False): 230 | 231 | # re-build data 232 | if is_eval: 233 | batch['valid_tracks'] = self.tracks 234 | norm_var = 0 235 | else: 236 | # random drop 0-2 tracks 237 | dist = Categorical(probs=torch.tensor(transform_cfg.random_num_sources)) 238 | num_sources = dist.sample() + 1 239 | batch['valid_tracks'] = random.sample(self.tracks, num_sources) 240 | norm_var = transform_cfg.lufs_norm_db['var'] 241 | 242 | batch['in_sources'] = len(batch['valid_tracks']) 243 | # build mix 244 | mix_max_peak = torch.zeros_like(batch['speech'])[...,:1] # (bs, C, 1) 245 | for track in batch['valid_tracks']: 246 | # volume norm 247 | batch[track] = self.volume_norm(signal=batch[track], 248 | target_loudness=transform_cfg.lufs_norm_db[track], 249 | var=norm_var) 250 | # peak value 251 | peak = batch[track].abs().max(dim=-1, keepdims=True)[0] 252 | mix_max_peak = torch.maximum(peak, mix_max_peak) 253 | 254 | # 
peak norm 255 | peak_gain = torch.ones_like(mix_max_peak) # (bs, C, 1) 256 | peak_gain[mix_max_peak > self.peak_norm] = self.peak_norm / mix_max_peak[mix_max_peak > self.peak_norm] 257 | batch['mix'] = torch.zeros_like(batch['speech']) 258 | for track in batch['valid_tracks']: 259 | batch[track] *= peak_gain 260 | batch['mix'] += batch[track] 261 | # mix volum norm 262 | batch['mix'], mix_gain = self.volume_norm(signal=batch['mix'], 263 | target_loudness=transform_cfg.lufs_norm_db['mix'], 264 | var=norm_var, 265 | return_gain=True) 266 | 267 | # norm each track 268 | for track in batch['valid_tracks']: 269 | batch[track] *= mix_gain[:, None, None] 270 | 271 | # random swap tracks 272 | batch['random_swap'] = (not is_eval) and (random.random() < transform_cfg.random_swap_prob) 273 | if batch['random_swap']: 274 | bs = batch['mix'].shape[0] 275 | mix_ref = torch.zeros_like(batch['mix']) 276 | batch['shuffle_list'] = {} 277 | for track in self.tracks: 278 | shuffle_list = list(range(bs)) 279 | random.shuffle(shuffle_list) 280 | batch['shuffle_list'][track] = shuffle_list 281 | if track in batch['valid_tracks']: 282 | mix_ref += batch[track][shuffle_list] 283 | else: 284 | mix_ref = batch['mix'].clone() 285 | 286 | batch['ref'] = torch.stack([mix_ref]+[batch[t].clone() for t in batch['valid_tracks']], dim=1) # (B, K, C, T) 287 | 288 | return batch 289 | 290 | 291 | def _print_logs(self, log_dict, title='Train', nb_step=0, use_tracker=True): 292 | msg = f"{title} iter {nb_step:d}" 293 | 294 | if title == 'Train': 295 | for k, v in log_dict.items(): 296 | k = '/'.join(k.split('/')[1:]) 297 | if k == 'lr': 298 | msg += f' {k}: {v:.8f}' 299 | else: 300 | msg += f' {k}: {v:.2f}' 301 | logger.info(msg) 302 | else: 303 | logger.info(msg) 304 | for c in self.eval_tracks: 305 | msg = f"--> {c}:" 306 | select_keys = filter(lambda k: k.split('/')[0] == f'eval_{c}' 307 | or k.split('/')[0] == f'test_{c}', log_dict.keys()) 308 | for k in select_keys: 309 | v = log_dict[k] 310 | k = '/'.join(k.split('/')[1:]) 311 | msg += ' {}: {:.2f}'.format(k, v) 312 | logger.info(msg) 313 | 314 | if use_tracker: 315 | self.tracker.log(log_dict, step=nb_step) # tracker automatically discard plt and audio 316 | 317 | 318 | def run(self): 319 | total_steps = self.cfg.training.total_steps 320 | print_steps = self.cfg.training.print_steps 321 | eval_steps = self.cfg.training.eval_steps 322 | vis_steps = self.cfg.training.vis_steps 323 | test_steps = self.cfg.training.test_steps 324 | early_stop = self.cfg.training.early_stop 325 | grad_clip = self.cfg.training.grad_clip 326 | save_iters = self.cfg.training.save_iters 327 | 328 | self.model.train() 329 | logger.info('Training...') 330 | while self.tm.nb_step <= total_steps: 331 | for batch in self.train_loader: 332 | 333 | # data transform and augmentation 334 | batch = self._data_transform(batch, self.cfg.training.transform, self.tm.nb_step) 335 | 336 | # train one step 337 | with accelerator.autocast(): 338 | train_log_dict = self.train_one_step(batch, grad_clip) 339 | 340 | # print log 341 | if self.tm.nb_step % print_steps == 0: 342 | self._print_logs(train_log_dict, title='Train', nb_step=self.tm.nb_step) 343 | 344 | # eval 345 | if self.tm.nb_step % eval_steps == 0: 346 | val_log_dict = self.run_eval() 347 | self._print_logs(val_log_dict, title='Validation', nb_step=self.tm.nb_step) 348 | 349 | # save best val 350 | if val_log_dict['val'] < self.tm.best_eval: 351 | self.tm.best_eval = val_log_dict['val'] 352 | self.tm.best_step = self.tm.nb_step 353 | 
logger.info("\t-->Validation improved!!! Save best!!!") 354 | accelerator.save_state(output_dir=self.ckpt_best, safe_serialization=False) # otherwise can't reload correctly 355 | # early stop 356 | else: 357 | self.tm.early_stop += 1 358 | if self.tm.early_stop >= early_stop: 359 | logger.info(f"\t--> Validation no imporved for {early_stop} times") 360 | logger.info(f"\t--> Training finished by early stop") 361 | logger.info(f"\t--> Best model saved at iter: {self.tm.best_step}") 362 | logger.info(f"\t--> Final test, load best ckpt") 363 | accelerator.load_state(self.ckpt_best) 364 | # save state and end 365 | unwrapped_model = accelerator.unwrap_model(self.model) 366 | torch.save(unwrapped_model.state_dict(), self.ckpt_final / 'ckpt_model_final.pth') 367 | unwrapped_discriminator = accelerator.unwrap_model(self.discriminator) 368 | torch.save(unwrapped_discriminator.state_dict(), self.ckpt_final / 'ckpt_discriminator_final.pth') 369 | logger.info(f"\t--> Final ckpt saved in {self.ckpt_final}") 370 | # final test 371 | test_log_dict = self.run_test() 372 | self._print_logs(test_log_dict, title='Final Test', nb_step=self.tm.best_step, use_tracker=False) 373 | accelerator.end_training() 374 | return 375 | 376 | # save model 377 | if self.tm.nb_step in save_iters: 378 | unwrapped_model = accelerator.unwrap_model(self.model) 379 | torch.save(unwrapped_model.state_dict(), self.ckptdir / 'ckpt_model_iter{}.pth'.format(self.tm.nb_step)) 380 | unwrapped_discriminator = accelerator.unwrap_model(self.discriminator) 381 | torch.save(unwrapped_discriminator.state_dict(), self.ckptdir / 'ckpt_discriminator_iter{}.pth'.format(self.tm.nb_step)) 382 | logger.info('\t--> Checkpoints saved for iteration: {}'.format(self.tm.nb_step)) 383 | 384 | # vis 385 | if self.tm.nb_step % vis_steps == 0: 386 | self.run_vis(nb_step=self.tm.nb_step) 387 | 388 | # test set 389 | if self.tm.nb_step % test_steps == 0: 390 | test_log_dict = self.run_test() 391 | self._print_logs(test_log_dict, title='Test', nb_step=self.tm.nb_step) 392 | 393 | self.tm.nb_step += 1 394 | 395 | # training end due to maximum train iters 396 | if self.tm.nb_step > total_steps: 397 | logger.info(f"\t--> Training finished by reaching max iterations") 398 | logger.info(f"\t--> Best model saved at iter: {self.tm.best_step}") 399 | logger.info(f"\t--> Final test, load best ckpt") 400 | accelerator.load_state(self.ckpt_best) 401 | # save state and end 402 | unwrapped_model = accelerator.unwrap_model(self.model) 403 | torch.save(unwrapped_model.state_dict(), self.ckpt_final / 'ckpt_model_final.pth') 404 | unwrapped_discriminator = accelerator.unwrap_model(self.discriminator) 405 | torch.save(unwrapped_discriminator.state_dict(), self.ckpt_final / 'ckpt_discriminator_final.pth') 406 | logger.info(f"\t--> Final ckpt saved in {self.ckpt_final}") 407 | # final test 408 | test_log_dict = self.run_test() 409 | self._print_logs(test_log_dict, title='Final Test', nb_step=self.tm.best_step, use_tracker=False) 410 | accelerator.end_training() 411 | return 412 | 413 | # breakpoint() 414 | 415 | 416 | def train_one_step(self, batch, grad_clip): 417 | # Forward, AMP automatically set by Accelerator 418 | batch = self.model(batch) 419 | 420 | # Reshape in/out 421 | audio_recon = rearrange(batch['recon'], 'b k c t -> (b k) c t') 422 | audio_ref = rearrange(batch['ref'], 'b k c t -> (b k) c t') 423 | 424 | # Train discriminator 425 | batch["adv/disc_loss"] = self.gan_loss.discriminator_loss(fake=audio_recon, real=audio_ref) 426 | self.optimizer_d.zero_grad() 427 | 
accelerator.backward(batch["adv/disc_loss"]) 428 | if accelerator.sync_gradients: 429 | grad_norm_d = accelerator.clip_grad_norm_(self.discriminator.parameters(), max_norm=grad_clip) 430 | self.optimizer_d.step() 431 | self.scheduler_d.step() 432 | 433 | # Train generator 434 | batch["stft/loss"] = self.stft_loss(est=audio_recon, ref=audio_ref) 435 | batch["mel/loss"] = self.mel_loss(est=audio_recon, ref=audio_ref) 436 | ( 437 | batch["adv/gen_loss"], 438 | batch["adv/feat_loss"], 439 | ) = self.gan_loss.generator_loss(fake=audio_recon, real=audio_ref) 440 | 441 | loss = sum([v * batch[k] for k, v in self.lambdas.items() if k in batch]) 442 | 443 | # debugging NaN error 444 | if loss.isnan().any(): 445 | logger.error('NaN detected, debugging...') 446 | ckpt_debug = self.project_dir / 'ckpt_debug' 447 | ckpt_debug.mkdir(exist_ok=True) 448 | data_debug = ckpt_debug / f'batch_data_{accelerator.process_index}.pth' 449 | accelerator.save_state(output_dir=ckpt_debug, safe_serialization=False) 450 | torch.save(batch, data_debug) 451 | logger.info(f"\t--> Debug state saved in {ckpt_debug}") 452 | accelerator.wait_for_everyone() 453 | accelerator.end_training() 454 | sys.exit() 455 | return 456 | # breakpoint() 457 | 458 | 459 | # Generator gradient descent 460 | self.optimizer_g.zero_grad() 461 | accelerator.backward(loss) 462 | if accelerator.sync_gradients: 463 | grad_norm_g = accelerator.clip_grad_norm_(self.model.parameters(), max_norm=grad_clip) 464 | self.optimizer_g.step() 465 | self.scheduler_g.step() 466 | 467 | # Mean reduce across all GPUs 468 | log_dict = OrderedDict() 469 | log_dict['train/lr'] = self.scheduler_g.get_last_lr()[0] 470 | log_dict['train/loss_d'] = accelerator.reduce(batch["adv/disc_loss"], reduction="mean").item() 471 | log_dict['train/loss_g'] = accelerator.reduce(loss, reduction="mean").item() 472 | log_dict['train/grad_norm_d'] = accelerator.reduce(grad_norm_d, reduction="mean").item() 473 | log_dict['train/grad_norm_g'] = accelerator.reduce(grad_norm_g, reduction="mean").item() 474 | log_dict['train/stft'] = accelerator.reduce(batch["stft/loss"], reduction="mean").item() 475 | log_dict['train/mel'] = accelerator.reduce(batch['mel/loss'], reduction="mean").item() 476 | log_dict['train/gen'] = accelerator.reduce(batch['adv/gen_loss'], reduction="mean").item() 477 | log_dict['train/feat'] = accelerator.reduce(batch['adv/feat_loss'], reduction="mean").item() 478 | log_dict['train/commit_loss'] = accelerator.reduce(batch['vq/commitment_loss'], reduction="mean").item() 479 | log_dict['train/cb_loss'] = accelerator.reduce(batch['vq/codebook_loss'], reduction="mean").item() 480 | 481 | return log_dict 482 | 483 | 484 | @torch.no_grad() 485 | def run_eval(self): 486 | """Distributed evaluation 487 | for inputs, targets in validation_dataloader: 488 | predictions = model(inputs) 489 | # Gather all predictions and targets 490 | all_predictions, all_targets = accelerator.gather_for_metrics((predictions, targets)) 491 | # Example of use with a *Datasets.Metric* 492 | metric.add_batch(all_predictions, all_targets) 493 | """ 494 | self.model.eval() 495 | am_dis_stft = {k: utils.AverageMeter() for k in self.eval_tracks} 496 | am_dis_mel = {k: utils.AverageMeter() for k in self.eval_tracks} 497 | am_sisdr = {k: utils.AverageMeter() for k in self.eval_tracks} 498 | am_ppl_rvq = {k: [utils.AverageMeter() for _ in range(self.cfg.model.quant_params.n_codebooks[i])] \ 499 | for i, k in enumerate(self.tracks) 500 | } 501 | 502 | for batch in self.val_loader: 503 | 504 | # data transform and 
augmentation 505 | batch = self._data_transform(batch, self.cfg.training.transform, self.tm.nb_step, is_eval=True) 506 | 507 | # Forward 508 | batch = self.model(batch) 509 | 510 | # Distributed evaluation 511 | all_recon = accelerator.gather_for_metrics(batch['recon']) # valide for nested list/tuple/dict 512 | all_ref = accelerator.gather_for_metrics(batch['ref']) 513 | 514 | # Codebook perplexity 515 | for t, am_ppl_list in am_ppl_rvq.items(): 516 | ppl_rvq = accelerator.gather_for_metrics(batch[f'{t}/ppl'].unsqueeze(0)) 517 | ppl_rvq = ppl_rvq.mean(dim=0) 518 | for i, am_ppl in enumerate(am_ppl_list): 519 | am_ppl.update(ppl_rvq[i].item()) 520 | 521 | # Eval mix reconstruction 522 | est = all_recon[:,0] 523 | ref = all_ref[:,0] 524 | am_dis_stft['mix_rec'].update(self.stft_loss(est=est, ref=ref).item()) 525 | am_dis_mel['mix_rec'].update(self.mel_loss(est=est, ref=ref).item()) 526 | am_sisdr['mix_rec'].update(- self.eval_sisdr(est=est, ref=ref).item()) 527 | 528 | # Eval separation using synthesizer (decoder) 529 | for i, t in enumerate(self.tracks): 530 | est = all_recon[:,i+1] 531 | ref = all_ref[:,i+1] 532 | am_dis_stft[f'{t}_sep'].update(self.stft_loss(est=est, ref=ref).item()) 533 | am_dis_mel[f'{t}_sep'].update(self.mel_loss(est=est, ref=ref).item()) 534 | am_sisdr[f'{t}_sep'].update(- self.eval_sisdr(est=est, ref=ref).item()) 535 | 536 | # Eval separation using mask 537 | all_sep_mask_norm = self.sep_norm(mix=all_ref[:,0:1], signal_sep=all_recon[:,1:]) 538 | for i, t in enumerate(self.tracks): 539 | est = all_sep_mask_norm[:,i] 540 | ref = all_ref[:,i+1] 541 | ref = ref[...,:est.shape[-1]] # stft + istft. shorter 542 | am_dis_stft[f'{t}_sep_mask'].update(self.stft_loss(est=est, ref=ref).item()) 543 | am_dis_mel[f'{t}_sep_mask'].update(self.mel_loss(est=est, ref=ref).item()) 544 | am_sisdr[f'{t}_sep_mask'].update(- self.eval_sisdr(est=est, ref=ref).item()) 545 | 546 | # Evaluate reconstruction of single track 547 | for i, t in enumerate(self.tracks): 548 | out_audio = self.model_eval_func(batch[t], output_tracks=[t]) 549 | all_recon = accelerator.gather_for_metrics(out_audio) 550 | all_ref = accelerator.gather_for_metrics(batch[t]) 551 | am_dis_stft[f'{t}_rec'].update(self.stft_loss(est=all_recon, ref=all_ref).item()) 552 | am_dis_mel[f'{t}_rec'].update(self.mel_loss(est=all_recon, ref=all_ref).item()) 553 | am_sisdr[f'{t}_rec'].update(- self.eval_sisdr(est=all_recon, ref=all_ref).item()) 554 | 555 | log_dict = OrderedDict() 556 | for t in self.eval_tracks: 557 | log_dict[f'eval_{t}/stft'] = am_dis_stft[t].avg 558 | log_dict[f'eval_{t}/mel'] = am_dis_mel[t].avg 559 | log_dict[f'eval_{t}/sisdr'] = am_sisdr[t].avg 560 | for t in self.tracks: 561 | for i, am_ppl in enumerate(am_ppl_rvq[t]): 562 | log_dict[f'eval_ppl/{t}_q{i}'] = am_ppl.avg 563 | 564 | log_dict['val'] = - np.mean([log_dict[f'eval_{k}/sisdr'] for k in self.eval_tracks]) # key to update best model 565 | 566 | self.model.train() 567 | 568 | return log_dict 569 | 570 | 571 | @torch.no_grad() 572 | def run_test(self): 573 | self.model.eval() 574 | 575 | am_dis_stft= {k: utils.AverageMeter() for k in self.eval_tracks} 576 | am_dis_mel = {k: utils.AverageMeter() for k in self.eval_tracks} 577 | am_sisdr = {k: utils.AverageMeter() for k in self.eval_tracks} 578 | # am_visqol = {k: utils.AverageMeter() for k in self.eval_tracks} 579 | am_ppl_rvq = {k: [utils.AverageMeter() for _ in range(self.cfg.model.quant_params.n_codebooks[i])] \ 580 | for i, k in enumerate(self.tracks) 581 | } 582 | 583 | for batch in self.val_loader: 584 
585 |             # data transform and augmentation
586 |             batch = self._data_transform(batch, self.cfg.training.transform, self.tm.nb_step, is_eval=True)
587 | 
588 |             # Forward
589 |             batch = self.model(batch)
590 | 
591 |             # Distributed evaluation
592 |             all_recon = accelerator.gather_for_metrics(batch['recon']) # valid for nested list/tuple/dict
593 |             all_ref = accelerator.gather_for_metrics(batch['ref'])
594 | 
595 |             # Codebook perplexity
596 |             for t, am_ppl_list in am_ppl_rvq.items():
597 |                 ppl_rvq = accelerator.gather_for_metrics(batch[f'{t}/ppl'].unsqueeze(0))
598 |                 ppl_rvq = ppl_rvq.mean(dim=0)
599 |                 for i, am_ppl in enumerate(am_ppl_list):
600 |                     am_ppl.update(ppl_rvq[i].item())
601 | 
602 |             # Eval mix reconstruction
603 |             est = all_recon[:,0]
604 |             ref = all_ref[:,0]
605 |             am_dis_stft['mix_rec'].update(self.stft_loss(est=est, ref=ref).item())
606 |             am_dis_mel['mix_rec'].update(self.mel_loss(est=est, ref=ref).item())
607 |             am_sisdr['mix_rec'].update(- self.eval_sisdr(est=est, ref=ref).item())
608 |             ## distributed compute visqol
609 |             # list_visqol = self.eval_visqol(est=batch['recon'][:,0], ref=batch['ref'][:,0], sr=self.sr)
610 |             # scores_visqol = accelerator.gather_for_metrics(list_visqol)
611 |             # am_visqol['mix'].update(np.mean(scores_visqol))
612 | 
613 |             # Eval separation using synthesizer (decoder)
614 |             for i, t in enumerate(self.tracks):
615 |                 est = all_recon[:,i+1]
616 |                 ref = all_ref[:,i+1]
617 |                 am_dis_stft[f'{t}_sep'].update(self.stft_loss(est=est, ref=ref).item())
618 |                 am_dis_mel[f'{t}_sep'].update(self.mel_loss(est=est, ref=ref).item())
619 |                 am_sisdr[f'{t}_sep'].update(- self.eval_sisdr(est=est, ref=ref).item())
620 |                 ## distributed compute visqol
621 |                 # list_visqol = self.eval_visqol(est=batch['recon'][:,i+1], ref=batch['ref'][:,i+1], sr=self.sr)
622 |                 # scores_visqol = accelerator.gather_for_metrics(list_visqol)
623 |                 # am_visqol[f'{t}_sep'].update(np.mean(scores_visqol))
624 | 
625 |             # Eval separation using mask
626 |             all_sep_mask_norm = self.sep_norm(mix=all_ref[:,0:1], signal_sep=all_recon[:,1:])
627 |             for i, t in enumerate(self.tracks):
628 |                 est = all_sep_mask_norm[:,i]
629 |                 ref = all_ref[:,i+1]
630 |                 ref = ref[...,:est.shape[-1]] # stft + istft output is slightly shorter
631 |                 am_dis_stft[f'{t}_sep_mask'].update(self.stft_loss(est=est, ref=ref).item())
632 |                 am_dis_mel[f'{t}_sep_mask'].update(self.mel_loss(est=est, ref=ref).item())
633 |                 am_sisdr[f'{t}_sep_mask'].update(- self.eval_sisdr(est=est, ref=ref).item())
634 | 
635 |             # Evaluate reconstruction on each individual track
636 |             for i, t in enumerate(self.tracks):
637 |                 out_audio = self.model_eval_func(batch[t], output_tracks=[t])
638 |                 all_recon = accelerator.gather_for_metrics(out_audio)
639 |                 all_ref = accelerator.gather_for_metrics(batch[t])
640 |                 am_dis_stft[f'{t}_rec'].update(self.stft_loss(est=all_recon, ref=all_ref).item())
641 |                 am_dis_mel[f'{t}_rec'].update(self.mel_loss(est=all_recon, ref=all_ref).item())
642 |                 am_sisdr[f'{t}_rec'].update(- self.eval_sisdr(est=all_recon, ref=all_ref).item())
643 |                 ## distributed compute visqol
644 |                 # list_visqol = self.eval_visqol(est=out_audio, ref=batch[t], sr=self.sr)
645 |                 # scores_visqol = accelerator.gather_for_metrics(list_visqol)
646 |                 # am_visqol[f'{t}_rec'].update(np.mean(scores_visqol))
647 | 
648 |         log_dict = OrderedDict()
649 |         for t in self.eval_tracks:
650 |             log_dict[f'test_{t}/stft'] = am_dis_stft[t].avg
651 |             log_dict[f'test_{t}/mel'] = am_dis_mel[t].avg
652 |             log_dict[f'test_{t}/sisdr'] = am_sisdr[t].avg
653 |             # log_dict[f'test_{t}/visqol'] = am_visqol[t].avg
654 |         for t in self.tracks:
655 |             for i, am_ppl in enumerate(am_ppl_rvq[t]):
656 |                 log_dict[f'test_ppl/{t}_q{i}'] = am_ppl.avg
657 | 
658 |         self.model.train()
659 | 
660 |         return log_dict
661 | 
662 | 
663 |     @torch.no_grad()
664 |     @accelerator.on_main_process
665 |     def run_vis(self, nb_step):
666 |         self.model.eval()
667 | 
668 |         ret_dict = defaultdict(list)
669 | 
670 |         writer = self.tracker.writer
671 | 
672 |         vis_idx = self.cfg.training.get('vis_idx', [])
673 |         for idx in vis_idx:
674 |             # get data
675 |             batch = self.val_loader.dataset.__getitem__(idx)
676 |             for t in self.tracks:
677 |                 batch[t] = batch[t].unsqueeze(0).to(accelerator.device) # (1, 1, T)
678 |             # data transform and augmentation
679 |             batch = self._data_transform(batch, self.cfg.training.transform, self.tm.nb_step, is_eval=True)
680 |             # single track recon
681 |             for t in self.tracks:
682 |                 ret_dict[f'{t}_orig'].append(batch[t][0])
683 |                 ret_dict[f'{t}_recon'].append(self.model_eval_func(batch[t], output_tracks=[t])[0])
684 | 
685 |             # mix recon and separation
686 |             sep_out = self.model_eval_func(batch['mix'], output_tracks= ['mix'] + self.tracks)
687 |             ret_dict['mix_orig'].append(batch['mix'][0])
688 |             ret_dict['mix_recon'].append(sep_out[:, 0]) # (1, T)
689 |             for i, t in enumerate(self.tracks):
690 |                 ret_dict[f'{t}_sep'].append(sep_out[:, i+1])
691 | 
692 |             # separation using FFT-mask
693 |             mix = batch['mix'].unsqueeze(2)
694 |             signal_sep = sep_out[:,1:].unsqueeze(2)
695 |             all_sep_mask_norm = self.sep_norm(mix, signal_sep)
696 |             for p, t in enumerate(self.tracks):
697 |                 est = all_sep_mask_norm[0,p]
698 |                 right_pad = mix.shape[-1] - est.shape[-1]
699 |                 est = F.pad(est, (0, right_pad))
700 |                 ret_dict[f'{t}_sep_mask'].append(est)
701 | 
702 | 
703 |         # mix
704 |         audio_mix_orig = torch.cat(ret_dict['mix_orig'], dim=-1).detach().cpu()
705 |         audio_mix_recon = torch.cat(ret_dict['mix_recon'], dim=-1).detach().cpu()
706 |         writer.add_audio('mix_orig', audio_mix_orig, global_step=nb_step, sample_rate=self.sr)
707 |         writer.add_audio('mix_recon', audio_mix_recon, global_step=nb_step, sample_rate=self.sr)
708 |         audio_signal = torch.cat((audio_mix_orig, audio_mix_recon), dim=0).numpy()
709 |         fig = utils.vis_spec(audio_signal, fs=self.sr, fig_width=8*len(vis_idx),
710 |                              tight_layout=False, save_fig=None)
711 |         writer.add_figure('mix', fig, global_step=nb_step)
712 | 
713 |         # track
714 |         for t in self.tracks:
715 |             audio_orig = torch.cat(ret_dict[f'{t}_orig'], dim=-1).detach().cpu()
716 |             audio_recon = torch.cat(ret_dict[f'{t}_recon'], dim=-1).detach().cpu()
717 |             audio_sep = torch.cat(ret_dict[f'{t}_sep'], dim=-1).detach().cpu()
718 |             audio_sep_mask = torch.cat(ret_dict[f'{t}_sep_mask'], dim=-1).detach().cpu()
719 |             writer.add_audio(f'{t}_orig', audio_orig, global_step=nb_step, sample_rate=self.sr)
720 |             writer.add_audio(f'{t}_recon', audio_recon, global_step=nb_step, sample_rate=self.sr)
721 |             writer.add_audio(f'{t}_sep', audio_sep, global_step=nb_step, sample_rate=self.sr)
722 |             writer.add_audio(f'{t}_sep_mask', audio_sep_mask, global_step=nb_step, sample_rate=self.sr)
723 |             audio_signal = torch.cat((audio_orig, audio_recon, audio_sep, audio_sep_mask), dim=0).numpy()
724 |             fig = utils.vis_spec(audio_signal, fs=self.sr, fig_width=8*len(vis_idx),
725 |                                  tight_layout=False, save_fig=None)
726 |             writer.add_figure(f'{t}', fig, global_step=nb_step)
727 | 
728 |         self.model.train()
729 | 
730 |         return
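The evaluation methods above all reduce to the same bookkeeping: per-batch metric values are pushed into `AverageMeter` objects, and in `run_eval` the key `val` is the negative mean of the per-track SI-SDR entries; since the trainer marks it as "key to update best model" and `TrainMonitor.best_eval` starts at `+inf`, this reads as a lower-is-better selection criterion. A minimal, self-contained sketch of that pattern, with made-up track names and scores, assuming the repository root is on `PYTHONPATH`:

```python
import numpy as np
from collections import OrderedDict
from src.utils import AverageMeter

eval_tracks = ['mix_rec', 'speech_sep', 'music_sep']      # hypothetical keys
am_sisdr = {k: AverageMeter() for k in eval_tracks}

per_batch_scores = {                                      # fake per-batch SI-SDR values (dB)
    'mix_rec': [9.1, 9.4],
    'speech_sep': [7.8, 8.0],
    'music_sep': [5.2, 5.5],
}
for k, scores in per_batch_scores.items():
    for s in scores:
        am_sisdr[k].update(s)                             # running average per track

log_dict = OrderedDict()
for k in eval_tracks:
    log_dict[f'eval_{k}/sisdr'] = am_sisdr[k].avg
log_dict['val'] = - np.mean([log_dict[f'eval_{k}/sisdr'] for k in eval_tracks])
print(log_dict)   # 'val' is the quantity compared against the best checkpoint so far
```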
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .audio_process import (
2 |     normalize_mean_var_np,
3 |     normalize_max_norm_np,
4 |     normalize_mean_var,
5 |     normalize_max_norm,
6 |     db_to_gain,
7 |     VolumeNorm,
8 |     WavSepMagNorm,
9 | )
10 | from .torch_utils import (
11 |     warmup_learning_rate,
12 |     get_scheduler,
13 |     Configure_AdamW,
14 | )
15 | from .utils import (
16 |     warm_up,
17 |     AverageMeter,
18 |     TrainMonitor,
19 | )
20 | 
21 | from .vis import (
22 |     vis_spec,
23 | )
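The helpers above are all re-exported at the package level, so callers only import from `src.utils`. A quick sketch of that import surface with made-up tensors, assuming the repository root is on `PYTHONPATH`:

```python
import torch
from src.utils import db_to_gain, normalize_mean_var, AverageMeter

print(db_to_gain(-6.0))            # ~0.501: -6 dB expressed as a linear gain
wav = normalize_mean_var(torch.randn(2, 1, 16000))
print(wav.std(-1).mean().item())   # ~1.0 after mean/variance normalization
meter = AverageMeter()
meter.update(0.5)
meter.update(1.5)
print(meter.avg)                   # 1.0 (running average of the two values)
```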
--------------------------------------------------------------------------------
/src/utils/audio_process.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import nn
4 | import torchaudio
5 | from collections import namedtuple
6 | from einops import rearrange
7 | 
8 | def normalize_mean_var_np(wav, eps=1e-8, std=None):
9 |     mean = np.mean(wav, axis=-1, keepdims=True)
10 |     if std is None:
11 |         std = np.std(wav, axis=-1, keepdims=True)
12 |     return (wav - mean) / (std + eps)
13 | 
14 | def normalize_max_norm_np(wav):
15 |     return wav / np.max(np.abs(wav), axis=-1, keepdims=True)
16 | 
17 | 
18 | def normalize_mean_var(wav, eps=1e-8, std=None):
19 |     mean = wav.mean(-1, keepdim=True)
20 |     if std is None:
21 |         std = wav.std(-1, keepdim=True)
22 |     return (wav - mean) / (std + eps)
23 | 
24 | def normalize_max_norm(wav):
25 |     return wav / wav.abs().max(-1, keepdim=True).values # .values: Tensor.max over a dim returns (values, indices)
26 | 
27 | def db_to_gain(db):
28 |     return np.power(10.0, db/20.0)
29 | 
30 | 
31 | 
32 | class VolumeNorm:
33 |     """
34 |     Volume normalization to a specific loudness [LUFS standard]
35 |     """
36 |     def __init__(self, sample_rate=16000):
37 |         self.lufs_meter = torchaudio.transforms.Loudness(sample_rate)
38 | 
39 |     def __call__(self, signal, target_loudness=-30, var=0, return_gain=False):
40 |         """
41 |         signal: torch.Tensor [B, channels, L]
42 |         """
43 |         bs = signal.shape[0]
44 |         # LUFS diff
45 |         lufs_ref = self.lufs_meter(signal)
46 |         lufs_target = (target_loudness + (torch.rand(bs) * 2 - 1) * var).to(lufs_ref.device)
47 |         # db to gain
48 |         gain = torch.exp((lufs_target - lufs_ref) * np.log(10) / 20)
49 |         gain[gain.isnan()] = 0 # zero gain for silent audio
50 |         # norm
51 |         signal *= gain[:, None, None]
52 | 
53 |         if return_gain:
54 |             return signal, gain
55 |         else:
56 |             return signal
57 | 
58 | 
59 | STFTParams = namedtuple(
60 |     "STFTParams",
61 |     ["window_length", "hop_length", "window_type", "padding_type"],
62 | )
63 | STFT_PARAMS = STFTParams(
64 |     window_length=1024,
65 |     hop_length=256,
66 |     window_type="hann",
67 |     padding_type="reflect",
68 | )
69 | 
70 | 
71 | class WavSepMagNorm:
72 |     """
73 |     Normalize the separation results using the magnitude
74 |     X_i = (|X_i| / sum_k |X_k|) * |X_mix| * exp(j arg(X_mix))
75 |     """
76 |     def __init__(self):
77 |         self.stft_params = STFT_PARAMS
78 | 
79 |     def __call__(self, mix, signal_sep):
80 |         """
81 |         Parameters
82 |         ----------
83 |         mix: torch.Tensor [B, 1, channels, L]
84 |             Mixture signal
85 |         signal_sep: torch.Tensor [B, K, channels, L]
86 |             Separation results without normalization
87 |         Returns
88 |         -------
89 |         ret: torch.Tensor [B, K, channels, L']
90 |             Separation results
91 |         """
92 |         bs, K, channels, _ = signal_sep.shape
93 |         mix = rearrange(mix, 'b k c l -> (b k c) l')
94 |         signal_sep = rearrange(signal_sep, 'b k c l -> (b k c) l')
95 | 
96 |         mix_spec = torch.stft(mix, n_fft=self.stft_params.window_length, hop_length=self.stft_params.hop_length,
97 |                               win_length=self.stft_params.window_length,
98 |                               window=torch.hann_window(self.stft_params.window_length, device=mix.device),
99 |                               pad_mode=self.stft_params.padding_type, center=True, onesided=True, return_complex=True)
100 |         signal_sep_spec = torch.stft(signal_sep, n_fft=self.stft_params.window_length, hop_length=self.stft_params.hop_length,
101 |                                      win_length=self.stft_params.window_length,
102 |                                      window=torch.hann_window(self.stft_params.window_length, device=signal_sep.device),
103 |                                      pad_mode=self.stft_params.padding_type, center=True, onesided=True, return_complex=True)
104 | 
105 |         mix_spec = rearrange(mix_spec, '(b k c) n t -> b k c n t', k=1, c=channels)
106 |         signal_sep_spec = rearrange(signal_sep_spec, '(b k c) n t -> b k c n t', k=K, c=channels)
107 | 
108 |         signal_sep_mag = signal_sep_spec.abs() # (B, K, C, N, T)
109 |         ratio = signal_sep_mag / signal_sep_mag.sum(dim=1, keepdim=True)
110 |         ret_spec = torch.polar(mix_spec.abs() * ratio, mix_spec.angle())
111 | 
112 |         ret_spec = rearrange(ret_spec, 'b k c n t -> (b k c) n t')
113 |         ret = torch.istft(ret_spec, n_fft=self.stft_params.window_length, hop_length=self.stft_params.hop_length,
114 |                           win_length=self.stft_params.window_length,
115 |                           window=torch.hann_window(self.stft_params.window_length, device=mix.device),
116 |                           center=True)
117 |         ret = rearrange(ret, '(b k c) l -> b k c l', k=K, c=channels)
118 | 
119 |         return ret
120 | 
121 | 
122 | if __name__ == '__main__':
123 | 
124 |     import torch
125 |     sep_norm = WavSepMagNorm()
126 |     mix = torch.randn(32, 1, 1, 16000)
127 |     signal_sep = torch.randn(32, 3, 1, 16000)
128 |     ret = sep_norm(mix, signal_sep)
129 |     print(ret.shape)
130 |     print(((mix[...,:ret.shape[-1]] - ret.sum(dim=1, keepdim=True)).abs() < 1e-4).all())
131 | 
132 | 
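The renormalization behind `WavSepMagNorm` can be checked per time-frequency bin: each separated magnitude is rescaled so the tracks sum to the mixture magnitude while the mixture phase is reused, i.e. X_i = (|S_i| / sum_k |S_k|) * |X_mix| * exp(j arg(X_mix)). A small numeric sketch with toy complex bins (not the class itself, which additionally handles the STFT/iSTFT round trip):

```python
import torch

mix = torch.complex(torch.randn(5), torch.randn(5))           # 5 mixture TF bins
seps = torch.complex(torch.randn(3, 5), torch.randn(3, 5))    # 3 separated tracks, same bins

ratio = seps.abs() / seps.abs().sum(dim=0, keepdim=True)      # per-bin magnitude mask
renorm = torch.polar(mix.abs() * ratio, mix.angle().expand(3, 5))

# after renormalization the tracks sum exactly back to the mixture bins
print(torch.allclose(renorm.sum(dim=0), mix, atol=1e-5))      # True
```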
--------------------------------------------------------------------------------
/src/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim import lr_scheduler
3 | 
4 | def warmup_learning_rate(optimizer, nb_iter, warmup_iter, max_lr):
5 |     """warmup learning rate"""
6 |     lr = max_lr * nb_iter / warmup_iter
7 |     for param_group in optimizer.param_groups:
8 |         param_group['lr'] = lr
9 |     return lr
10 | 
11 | 
12 | def get_scheduler(optimizer, args):
13 |     if args.policy == 'linear':
14 |         scheduler = lr_scheduler.LinearLR(optimizer, total_iters=args.total_iter) # factor 0.33-1
15 |     elif args.policy == 'cosine':
16 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.total_iter)
17 |     elif args.policy == 'step':
18 |         scheduler = lr_scheduler.StepLR(
19 |             optimizer, step_size=args.decay_step, gamma=0.1)
20 |     elif args.policy == 'multistep':
21 |         scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=0.05)
22 |     elif args.policy == 'plateau':
23 |         scheduler = lr_scheduler.ReduceLROnPlateau(
24 |             optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
25 |     else:
26 |         raise NotImplementedError(f'learning rate policy [{args.policy}] is not implemented')
27 |     return scheduler
28 | 
29 | 
30 | def Configure_AdamW(model, weight_decay, learning_rate):
31 | 
32 |     all_params = set(model.parameters())
33 |     decay = set()
34 |     whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv1d)
35 | 
36 |     for m in model.modules():
37 |         if isinstance(m, whitelist_weight_modules):
38 |             decay.add(m.weight)
39 |     no_decay = all_params - decay
40 | 
41 |     # create the pytorch optimizer object
42 |     optim_groups = [
43 |         {"params": list(decay), "weight_decay": weight_decay},
44 |         {"params": list(no_decay), "weight_decay": 0.0},
45 |     ]
46 |     optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate)
47 |     return optimizer
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import json
5 | from typing import List, Union
6 | 
7 | def warm_up(step, warmup_steps):
8 |     if step < warmup_steps:
9 |         warmup_ratio = step / warmup_steps
10 |     else:
11 |         warmup_ratio = 1
12 |     return warmup_ratio
13 | 
14 | 
15 | class AverageMeter(object):
16 |     """Computes and stores the average and current value"""
17 | 
18 |     def __init__(self):
19 |         self.hist = []
20 |         self.reset()
21 | 
22 |     def reset(self):
23 |         self.val = 0
24 |         self.count = 0
25 |         self.avg = 0
26 | 
27 |     def update(self, val, n=1):
28 |         self.val = val
29 |         self.count += n
30 |         self.avg += (self.val - self.avg) * n / self.count
31 | 
32 |     def save_log(self):
33 |         self.hist.append(self.avg)
34 |         self.reset()
35 | 
36 | 
37 | class TrainMonitor(object):
38 |     """Record training"""
39 | 
40 |     def __init__(self, nb_step=1, best_eval=math.inf, best_step=1, early_stop=0):
41 |         self.nb_step = nb_step
42 |         self.best_eval = best_eval
43 |         self.best_step = best_step
44 |         self.early_stop = early_stop
45 | 
46 | 
47 |     def state_dict(self):
48 |         sd = {'nb_step': self.nb_step,
49 |               'best_eval': self.best_eval,
50 |               'best_step': self.best_step,
51 |               'early_stop': self.early_stop,
52 |               }
53 |         return sd
54 | 
55 | 
56 |     def load_state_dict(self, state_dict: dict):
57 |         self.nb_step = state_dict['nb_step']
58 |         self.best_eval = state_dict['best_eval']
59 |         self.best_step = state_dict['best_step']
60 |         self.early_stop = state_dict['early_stop']
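A short sketch of how the two optimizer helpers in `torch_utils.py` are meant to be wired together, using a toy model and made-up hyper-parameters (the real values live in the training configs); it assumes the repository root is on `PYTHONPATH`:

```python
from types import SimpleNamespace
import torch
from src.utils import Configure_AdamW, get_scheduler

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 1),
)
# weight decay is applied only to Linear/Conv1d weights; biases fall in the no-decay group
optimizer = Configure_AdamW(model, weight_decay=1e-2, learning_rate=1e-4)
# get_scheduler only needs an object with a .policy field plus that policy's parameters
scheduler = get_scheduler(optimizer, SimpleNamespace(policy='cosine', total_iter=1000))

for _ in range(3):                                   # a few toy optimization steps
    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())                       # learning rate after the cosine steps
```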
--------------------------------------------------------------------------------
/src/utils/vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import matplotlib.pylab as pylab
4 | params = {'legend.fontsize': 'xx-large',
5 |           'axes.labelsize': 'xx-large',
6 |           'axes.titlesize':'xx-large',
7 |           'xtick.labelsize':'xx-large',
8 |           'ytick.labelsize':'xx-large'}
9 | pylab.rcParams.update(params)
10 | 
11 | 
12 | 
13 | def vis_spec(signal,
14 |              fs=44100, nfft=1024, hop=512,
15 |              fig_width=10, fig_height_per_plot=6, cmap='inferno', vmin=-150, vmax=None,
16 |              use_colorbar=True, tight_layout=False, save_fig='./vis.png'):
17 | 
18 |     # plt.clf()
19 | 
20 |     # pre-plot
21 |     if len(signal.shape) == 1:
22 |         signal = signal.reshape(1, -1)
23 | 
24 |     num_sig = signal.shape[0]
25 |     fig_size = (fig_width, fig_height_per_plot*num_sig)
26 |     fig, axes = plt.subplots(nrows=num_sig, ncols=1, sharex=True, figsize=fig_size)
27 |     if num_sig == 1:
28 |         axes = [axes]
29 | 
30 |     # iterative plot
31 |     for i, ax in enumerate(axes):
32 |         x = signal[i] + 1e-9
33 |         Pxx, freqs, bins, im = ax.specgram(x, scale='dB',
34 |                                            Fs=fs, NFFT=nfft, noverlap=hop, # NB: noverlap is the overlap (nfft - hop); equal to hop here since hop = nfft // 2 by default
35 |                                            vmin=vmin, vmax=vmax, cmap=cmap)
36 |         if i == num_sig//2:
37 |             ax.set_ylabel('Frequency (Hz)')
38 |         if i == num_sig-1:
39 |             ax.set_xlabel('Time (s)')
40 | 
41 |     # post-plot
42 |     if use_colorbar:
43 |         plt.colorbar(im, ax=axes, format="%+2.f dB")
44 | 
45 |     if tight_layout:
46 |         plt.tight_layout()
47 | 
48 |     if save_fig:
49 |         plt.savefig(save_fig)
50 |         plt.close(fig)
51 |         return
52 |     else:
53 |         return fig
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     import torchaudio
58 | 
59 |     filepath = '/home/xbie/Data/dnr_v2/tr/1002/mix.wav'
60 |     x, fs = torchaudio.load(filepath)
61 |     x = x.numpy().reshape(-1)
62 | 
63 |     import librosa
64 |     sr = 44100
65 |     clip = sr * 3
66 |     _, (trim30dBs,_) = librosa.effects.trim(x[:clip], top_db=30)
67 |     print(trim30dBs)
68 |     breakpoint()
69 |     # plt.clf()
70 |     # plt.plot(x)
71 |     # plt.savefig('./tmp.png')
72 |     # plt.close()
73 |     s = 500000
74 |     e = 500000 + 51200
75 |     x_clip = x[s:e] # x was flattened to 1-D above
76 |     vis_spec(x_clip)
77 | 
78 | 
--------------------------------------------------------------------------------
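A quick usage sketch for `vis_spec` with a synthetic sweep and a hypothetical output path; passing `save_fig=None` makes it return the matplotlib figure, which is how `run_vis()` feeds it to TensorBoard's `add_figure`. The example assumes the repository root is on `PYTHONPATH`:

```python
import numpy as np
from src.utils import vis_spec

sr = 16000
t = np.arange(2 * sr) / sr                             # 2 s of audio
sweep = np.sin(2 * np.pi * (200.0 + 400.0 * t) * t)    # synthetic frequency sweep
stack = np.stack([sweep, 0.5 * sweep])                 # two rows -> two stacked spectrograms

fig = vis_spec(stack, fs=sr, save_fig=None)            # returns the figure instead of writing a file
fig.savefig('vis_demo.png')                            # hypothetical output file
```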