├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── pipeline.png
├── config
│   ├── acc
│   │   ├── fp16_gpus1.yaml
│   │   └── fp16_gpus8.yaml
│   ├── dataset
│   │   ├── dnr.yaml
│   │   └── mix_6k.yaml
│   ├── debug.yaml
│   ├── default.yaml
│   ├── discriminator
│   │   └── default.yaml
│   ├── model
│   │   ├── sdcodec_16k.yaml
│   │   ├── sdcodec_16k_pretrainDAC.yaml
│   │   ├── sdcodec_16k_shard4.yaml
│   │   └── sdcodec_16k_shard8.yaml
│   ├── run_config
│   │   ├── slurm_1.yaml
│   │   └── slurm_debug.yaml
│   └── training
│       ├── bs2_iter5k.yaml
│       └── bs8_iter400k.yaml
├── environment.yml
├── eval_dnr.py
├── install_visqol.md
├── main.py
├── prepare
│   ├── mani_dnr.py
│   ├── mani_dns_clean.py
│   ├── mani_dns_noise.py
│   ├── mani_jamendo.py
│   ├── mani_musan.py
│   └── mani_wham.py
└── src
    ├── datasets
    │   ├── __init__.py
    │   └── audio_dataset.py
    ├── metrics
    │   ├── __init__.py
    │   ├── adv.py
    │   ├── sdr.py
    │   ├── spectrum.py
    │   └── visqol.py
    ├── models
    │   ├── __init__.py
    │   ├── discriminator.py
    │   └── sdcodec.py
    ├── modules
    │   ├── __init__.py
    │   ├── base_dac.py
    │   ├── layers.py
    │   └── quantize.py
    ├── optim
    │   ├── __init__.py
    │   ├── cosine_lr_scheduler.py
    │   ├── exponential_lr_scheduler.py
    │   ├── inverse_sqrt_lr_scheduler.py
    │   ├── linear_warmup_lr_scheduler.py
    │   ├── polynomial_decay_lr_scheduler.py
    │   └── reduce_plateau_lr_scheduler.py
    ├── trainer.py
    └── utils
        ├── __init__.py
        ├── audio_process.py
        ├── torch_utils.py
        ├── utils.py
        └── vis.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ## VsCode
2 | .vscode/
3 |
4 | ## Python output
5 | **/*.pyc
6 |
7 | ## outputs
8 | manifest/
9 | pretrained/
10 | output/
11 | local_scripts/
12 | stdout/
13 | stderr/
14 | *.sh
15 | try.py
16 | plot*.py
17 | *.out
18 | *.wav
19 | *.ipynb
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Xiaoyu Bie
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SD-Codec
2 |
3 | **Learning Source Disentanglement in Neural Audio Codec, ICASSP 2025**
4 |
5 | [Xiaoyu Bie](https://xiaoyubie1994.github.io/), [Xubo Liu](https://liuxubo717.github.io/), [Gaël Richard](https://www.telecom-paris.fr/gael-richard?l=en)
6 |
7 | **[[arXiv](https://arxiv.org/abs/2409.11228)]**, **[[Project](https://xiaoyubie1994.github.io/sdcodec/)]**
8 |
9 |
10 |
11 | The pretrained models can be downloaded via [Google Drive](https://drive.google.com/drive/folders/1-OjiNmtFdTUGwQF17FDMjzZgoBJJkHpG?usp=drive_link)
12 |
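If the downloaded folder keeps the layout produced by training (`.hydra/config.yaml` plus `ckpt_final/ckpt_model_final.pth`, which is what `eval_dnr.py` reads), loading the weights can look like this minimal sketch:

```python
from pathlib import Path
from omegaconf import OmegaConf
import torch
from src import models

ret_dir = Path('PATH_TO_MODEL')  # downloaded or trained checkpoint folder
cfg = OmegaConf.load(ret_dir / '.hydra' / 'config.yaml')

# build the model from the saved Hydra config, then load the final weights
model_cfg = cfg.model
model_name = model_cfg.pop('name')
model = getattr(models, model_name)(sample_rate=cfg.sampling_rate, **model_cfg)
state_dict = torch.load(ret_dir / 'ckpt_final' / 'ckpt_model_final.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```

Inference on a mixture then follows `eval_dnr.py`, e.g. `model.evaluate(input_audio=mix, output_tracks=['mix', 'speech', 'music', 'sfx'])`.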
13 | ## Environment Setup
14 | All our models are trained on 8 A100 80GB GPUs.
15 |
16 | ```
17 | conda env create -f environment.yml
18 | conda activate gen_audio
19 | ```
20 |
21 | The code was tested on Python 3.11.7 and PyTorch 2.1.2
22 |
23 | To install [VisQol](https://github.com/google/visqol), the simplest way is to use [Bazelisk](https://github.com/bazelbuild/bazelisk?tab=readme-ov-file), especially if you want to install it on a cluster where you do not have `sudo` rights.
24 | Here is an [example](https://github.com/XiaoyuBIE1994/SDCodec/blob/main/install_visqol.md)
25 |
26 | ## Dataset Preparation
27 |
28 | We use the following datasets:
29 | - [Divide and Remaster (DnR)](https://zenodo.org/records/6949108)
30 | - [DNS-Challenge 5](https://github.com/microsoft/DNS-Challenge)
31 | - [MTG-Jamendo](https://mtg.github.io/mtg-jamendo-dataset/)
32 | - [MUSAN](https://www.openslr.org/17/)
33 | - [WHAM!](http://wham.whisper.ai/)
34 |
35 | ```bash
36 | mkdir manifest
37 |
38 | # DnR
39 | python prepare/mani_dnr.py --data-dir PATH_TO_DnR
40 |
41 | # DNS Challenge 5
42 | python prepare/mani_dns_clean.py --data-dir PATH_TO_DNS_CLEAN # or by partition
43 | python prepare/mani_dns_noise.py --data-dir PATH_TO_DNS_NOISE
44 |
45 | # Jamendo
46 | python prepare/mani_jamendo.py --data-dir PATH_TO_JAMENDO # or by partition
47 |
48 | # MUSAN
49 | python prepare/mani_musan.py --data-dir PATH_TO_MUSAN
50 |
51 | # WHAM
52 | python prepare/mani_wham.py --data-dir PATH_TO_WHAM
53 | ```
54 |
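Each script writes one CSV per source into `manifest/`. The training manifests use the columns `id,filepath,sr,length,start,end` (with `start`/`end` coming from a 30 dB silence trim), while the DnR validation/test manifests also store the per-track paths (`id,mix,speech,music,sfx,sr,length,start,end`). A quick sanity check, assuming the default `./manifest` output directory:

```python
import pandas as pd

# inspect one of the generated manifests
df = pd.read_csv('manifest/speech_dnr.csv')
print(df.columns.tolist())  # ['id', 'filepath', 'sr', 'length', 'start', 'end']
print(len(df), 'files')
```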
55 | ## Training
56 | ```
57 | # debug on single GPU
58 | accelerate launch --config_file config/acc/fp16_gpus1.yaml main.py --config-name debug +run_config=slurm_debug
59 |
60 | # training on 8 GPUs
61 | accelerate launch --config_file config/acc/fp16_gpus8.yaml main.py --config-name default +run_config=slurm_1
62 | ```
63 |
64 | ## Evaluation
65 | By default, we use the last checkpoint for evaluation.
66 | ```
67 | model_dir=PATH_TO_MODEL
68 | nohup python eval_dnr.py --ret-dir ${model_dir} --csv-path ./manifest/val.csv --length 5 > ${model_dir}/val.log 2>&1 &
69 | nohup python eval_dnr.py --ret-dir ${model_dir} --csv-path ./manifest/test.csv --length 10 > ${model_dir}/test.log 2>&1 &
70 | ```
71 |
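`eval_dnr.py` writes the per-clip metrics and a per-track `summary` (STFT and mel distances, SI-SDR, and VisQOL unless `--fast` is set) to `<csv_stem>_<length>s.json` in the model directory. A short sketch to print the summary (the filename below assumes the test command above):

```python
import json

with open('PATH_TO_MODEL/test_10s.json') as f:
    results = json.load(f)

# each track entry holds {'mean': ..., 'std': ...} per metric,
# plus the total and valid clip counts
for track, stats in results['summary'].items():
    print(track, stats)
```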
72 |
73 | ## Citation
74 | If you find this project useful in your research, please consider citing:
75 | ```BibTeX
76 | @inproceedings{bie2025sdcodec,
77 | author={Bie, Xiaoyu and Liu, Xubo and Richard, Ga{\"e}l},
78 | title={Learning Source Disentanglement in Neural Audio Codec},
79 | booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
80 | year={2025},
81 | }
82 | ```
83 |
84 |
85 | ## Acknowledgments
86 | Some of the code in this project is inspired by or modified from the following projects:
87 | - [AudioCraft](https://github.com/facebookresearch/audiocraft)
88 | - [DAC](https://github.com/descriptinc/descript-audio-codec)
89 |
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XiaoyuBIE1994/SDCodec/62026eb7c1fc1079ef55f1c170f1066f383b349e/assets/pipeline.png
--------------------------------------------------------------------------------
/config/acc/fp16_gpus1.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: 'NO'
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: fp16
9 | num_machines: 1
10 | num_processes: 1
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
--------------------------------------------------------------------------------
/config/acc/fp16_gpus8.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: fp16
9 | num_machines: 1
10 | num_processes: 8
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/config/dataset/dnr.yaml:
--------------------------------------------------------------------------------
1 | # Dataset config
2 | trainset_cfg:
3 | speech:
4 | - manifest/speech_dnr.csv
5 | music:
6 | - manifest/music_dnr.csv
7 | sfx:
8 | - manifest/sfx_dnr.csv
9 | n_examples: 10000000
10 | chunk_size: 2.0
11 | trim_silence: False
12 |
13 |
14 | valset_cfg:
15 | tsv_filepath: manifest/val.csv
16 | chunk_size: 5.0
17 |
18 | testset_cfg:
19 | tsv_filepath: manifest/test.csv
20 | chunk_size: 10.0
21 |
22 |
--------------------------------------------------------------------------------
/config/dataset/mix_6k.yaml:
--------------------------------------------------------------------------------
1 | # Dataset config
2 | trainset_cfg:
3 | speech:
4 | - manifest/speech_dns5_emotional_speech.csv
5 | - manifest/speech_dns5_french_speech.csv
6 | - manifest/speech_dns5_german_speech.csv
7 | - manifest/speech_dns5_italian_speech.csv
8 | - manifest/speech_dns5_read_speech.csv
9 | - manifest/speech_dns5_russian_speech.csv
10 | - manifest/speech_dns5_spanish_speech.csv
11 | - manifest/speech_dns5_vctk_wav48_silence_trimmed.csv
12 | - manifest/speech_dns5_VocalSet_48kHz_mono.csv
13 | - manifest/speech_musan.csv
14 | music:
15 | - manifest/music_jamendo.csv
16 | - manifest/music_musan.csv
17 | sfx:
18 | - manifest/sfx_dns5.csv
19 | - manifest/sfx_musan.csv
20 | - manifest/sfx_wham.csv
21 | n_examples: 10000000
22 | chunk_size: 2.0
23 | trim_silence: False
24 |
25 |
26 | valset_cfg:
27 | tsv_filepath: manifest/val_dnr.csv
28 | chunk_size: 5.0
29 |
30 | testset_cfg:
31 | tsv_filepath: manifest/test_dnr.csv
32 | chunk_size: 10.0
33 |
--------------------------------------------------------------------------------
/config/debug.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model: sdcodec_16k
3 | - discriminator: default
4 | - dataset: dnr
5 | - training: bs2_iter5k
6 | - _self_
7 |
8 | sampling_rate: 16000
9 | resume: True
10 | resume_dir: null
11 | backup_code: True
--------------------------------------------------------------------------------
/config/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - model: sdcodec_16k
3 | - discriminator: default
4 | - dataset: mix_6k
5 | - training: bs8_iter400k
6 | - _self_
7 |
8 | sampling_rate: 16000
9 | resume: True
10 | resume_dir: null
11 | backup_code: True
12 |
--------------------------------------------------------------------------------
/config/discriminator/default.yaml:
--------------------------------------------------------------------------------
1 | # Basic
2 | name: Discriminator
3 | rates: []
4 | periods: [2, 3, 5, 7, 11]
5 | fft_sizes: [2048, 1024, 512]
6 | bands:
7 | - [0.0, 0.1]
8 | - [0.1, 0.25]
9 | - [0.25, 0.5]
10 | - [0.5, 0.75]
11 | - [0.75, 1.0]
12 |
--------------------------------------------------------------------------------
/config/model/sdcodec_16k.yaml:
--------------------------------------------------------------------------------
1 | # Basic
2 | name: SDCodec
3 | latent_dim: 1024
4 | tracks: ['speech', 'music', 'sfx']
5 | enc_params:
6 | name: DACEncoder
7 | d_model: 64
8 | strides: [2, 4, 5, 8]
9 | dec_params:
10 | name: DACDecoder
11 | d_model: 1536
12 | strides: [8, 5, 4, 2]
13 | quant_params:
14 | name: MultiSourceRVQ
15 | n_codebooks: [12, 12, 12]
16 | codebook_size: [1024, 1024, 1024]
17 | codebook_dim: [8, 8, 8]
18 | quantizer_dropout: 0.0
19 | code_jit_prob: [0.0, 0.0, 0.0]
20 | code_jit_size: [3, 5, 3]
21 | shared_codebooks: 0
22 | pretrain: {}
23 |
24 |
--------------------------------------------------------------------------------
/config/model/sdcodec_16k_pretrainDAC.yaml:
--------------------------------------------------------------------------------
1 | # Basic
2 | name: SDCodec
3 | latent_dim: 1024
4 | tracks: ['speech', 'music', 'sfx']
5 | enc_params:
6 | name: DACEncoder
7 | d_model: 64
8 | strides: [2, 4, 5, 8]
9 | dec_params:
10 | name: DACDecoder
11 | d_model: 1536
12 | strides: [8, 5, 4, 2]
13 | quant_params:
14 | name: MultiSourceRVQ
15 | n_codebooks: [12, 12, 12]
16 | codebook_size: [1024, 1024, 1024]
17 | codebook_dim: [8, 8, 8]
18 | quantizer_dropout: 0.0
19 | code_jit_prob: [0.0, 0.0, 0.0]
20 | code_jit_size: [3, 5, 3]
21 | shared_codebooks: 1
22 | pretrain:
23 | load_pretrained: './pretrained/weights_16khz_8kbps_0.0.5_new.pth'
24 | ignore_args: ['n_codebooks', 'codebook_size', 'codebook_dim']
25 | ignore_modules: []
26 | freeze_modules: []
27 |
28 |
--------------------------------------------------------------------------------
/config/model/sdcodec_16k_shard4.yaml:
--------------------------------------------------------------------------------
1 | # Basic
2 | name: SDCodec
3 | latent_dim: 1024
4 | tracks: ['speech', 'music', 'sfx']
5 | enc_params:
6 | name: DACEncoder
7 | d_model: 64
8 | strides: [2, 4, 5, 8]
9 | dec_params:
10 | name: DACDecoder
11 | d_model: 1536
12 | strides: [8, 5, 4, 2]
13 | quant_params:
14 | name: MultiSourceRVQ
15 | n_codebooks: [12, 12, 12]
16 | codebook_size: [1024, 1024, 1024]
17 | codebook_dim: [8, 8, 8]
18 | quantizer_dropout: 0.0
19 | code_jit_prob: [0.0, 0.0, 0.0]
20 | code_jit_size: [3, 5, 3]
21 | shared_codebooks: 4
22 | pretrain: {}
23 |
24 |
--------------------------------------------------------------------------------
/config/model/sdcodec_16k_shard8.yaml:
--------------------------------------------------------------------------------
1 | # Basic
2 | name: SDCodec
3 | latent_dim: 1024
4 | tracks: ['speech', 'music', 'sfx']
5 | enc_params:
6 | name: DACEncoder
7 | d_model: 64
8 | strides: [2, 4, 5, 8]
9 | dec_params:
10 | name: DACDecoder
11 | d_model: 1536
12 | strides: [8, 5, 4, 2]
13 | quant_params:
14 | name: MultiSourceRVQ
15 | n_codebooks: [12, 12, 12]
16 | codebook_size: [1024, 1024, 1024]
17 | codebook_dim: [8, 8, 8]
18 | quantizer_dropout: 0.0
19 | code_jit_prob: [0.0, 0.0, 0.0]
20 | code_jit_size: [3, 5, 3]
21 | shared_codebooks: 8
22 | pretrain: {}
23 |
24 |
--------------------------------------------------------------------------------
/config/run_config/slurm_1.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | hydra:
3 | job:
4 | config:
5 | override_dirname:
6 | kv_sep: '_'
7 | item_sep: '/'
8 | exclude_keys:
9 | - config_name
10 | - run_config
11 | run:
12 | # dir: output/debug # job's output dir
13 | dir: output/${hydra.job.config_name}/${hydra.job.override_dirname} # job's output dir
14 | # dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
15 | job_logging:
16 | formatters:
17 | simple:
18 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
19 | datefmt: '%Y-%m-%d %H:%M:%S'
--------------------------------------------------------------------------------
/config/run_config/slurm_debug.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | hydra:
3 | job:
4 | config:
5 | override_dirname:
6 | kv_sep: ':'
7 | item_sep: '/'
8 | exclude_keys:
9 | - run_config
10 | - distributed_training.distributed_port
11 | - distributed_training.distributed_world_size
12 | - model.pretrained_model_path
13 | - model.target_network_path
14 | - next_script
15 | - task.cache_in_scratch
16 | - task.data
17 | - checkpoint.save_interval_updates
18 | - checkpoint.keep_interval_updates
19 | - checkpoint.save_on_overflow
20 | - common.log_interval
21 | - common.user_dir
22 | run:
23 | dir: output/debug # job's output dir
24 | # dir: output/${hydra.job.config_name}/${hydra.job.override_dirname} # job's output dir
25 | # dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}_${hydra.launcher.gpus_per_node}/${hydra.job.override_dirname}
26 | job_logging:
27 | formatters:
28 | simple:
29 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
30 | datefmt: '%Y-%m-%d %H:%M:%S'
--------------------------------------------------------------------------------
/config/training/bs2_iter5k.yaml:
--------------------------------------------------------------------------------
1 |
2 | total_steps: &total 5000
3 | warmup_steps: &warm 500
4 | print_steps: 50
5 | eval_steps: 500
6 | vis_steps: 500
7 | test_steps: 20000
8 |
9 | # total_steps: &total 200000
10 | # warmup_steps: &warm 10000
11 | # print_steps: 500
12 | # eval_steps: 5000
13 | # vis_steps: 10000
14 | # test_steps: 2000000
15 |
16 | early_stop: 50 # count for each eval
17 | grad_clip: 10.0 # no clip if < 0
18 | save_iters: [1000, 1500]
19 | vis_idx: [100, 200, 300]
20 |
21 | seed: 42
22 |
23 | # data transformation and augmentation
24 | transform:
25 | lufs_norm_db:
26 | speech: -17
27 | music: -24
28 | sfx: -21
29 | mix: -27
30 | var: 2
31 | peak_norm_db: -0.5
32 | random_num_sources: [0.2, 0.2, 0.6]
33 | random_swap_prob: 0.5
34 |
35 | # Optimizer
36 | optimizer:
37 | name: AdamW
38 | lr: 1e-4
39 | betas: [0.8, 0.99]
40 | # weight_decay: 1e-5
41 |
42 | # Scheduler
43 | scheduler:
44 | name: ExponentialLRScheduler
45 | total_steps: *total
46 | warmup_steps: *warm
47 | lr_min_ratio: 0.0
48 | gamma: 0.999
49 | # gamma: 0.999996
50 |
51 | # Loss
52 | loss:
53 | MultiScaleSTFTLoss:
54 | window_lengths: [2048, 512]
55 | MelSpectrogramLoss:
56 | n_mels: [5, 10, 20, 40, 80, 160, 320]
57 | window_lengths: [32, 64, 128, 256, 512, 1024, 2048]
58 | mel_fmin: [0, 0, 0, 0, 0, 0, 0]
59 | mel_fmax: [null, null, null, null, null, null, null]
60 | pow: 1.0
61 | clamp_eps: 1.0e-5
62 | mag_weight: 0.0
63 | lambdas:
64 | mel/loss: 15.0
65 | adv/feat_loss: 2.0
66 | adv/gen_loss: 1.0
67 | vq/commitment_loss: 0.25
68 | vq/codebook_loss: 1.0
69 |
70 | # Dataloader config
71 | dataloader:
72 | num_workers: 8
73 | train_bs: 2
74 | eval_bs: 32
75 |
--------------------------------------------------------------------------------
/config/training/bs8_iter400k.yaml:
--------------------------------------------------------------------------------
1 | total_steps: &total 400000
2 | warmup_steps: &warm 10000
3 | print_steps: 500
4 | eval_steps: 5000
5 | vis_steps: 10000
6 | test_steps: 30000
7 |
8 |
9 | early_stop: 50 # count for each eval
10 | grad_clip: 10.0 # no clip if < 0
11 | save_iters: [50000, 100000, 150000, 200000, 250000, 300000, 350000]
12 | vis_idx: [100, 200, 300]
13 |
14 | seed: 42
15 |
16 | # data transformation and augmentation
17 | transform:
18 | lufs_norm_db:
19 | speech: -17
20 | music: -24
21 | sfx: -21
22 | mix: -27
23 | var: 2
24 | peak_norm_db: -0.5
25 | random_num_sources: [0.6, 0.2, 0.2]
26 | random_swap_prob: 0.5
27 |
28 | # Optimizer
29 | optimizer:
30 | name: AdamW
31 | lr: 1e-4
32 | betas: [0.8, 0.99]
33 | # weight_decay: 1e-5
34 |
35 | # Scheduler
36 | scheduler:
37 | name: ExponentialLRScheduler
38 | total_steps: *total
39 | warmup_steps: *warm
40 | lr_min_ratio: 0.0
41 | gamma: 0.999996
42 |
43 | # Loss
44 | loss:
45 | MultiScaleSTFTLoss:
46 | window_lengths: [2048, 512]
47 | MelSpectrogramLoss:
48 | n_mels: [5, 10, 20, 40, 80, 160, 320]
49 | window_lengths: [32, 64, 128, 256, 512, 1024, 2048]
50 | mel_fmin: [0, 0, 0, 0, 0, 0, 0]
51 | mel_fmax: [null, null, null, null, null, null, null]
52 | pow: 1.0
53 | clamp_eps: 1.0e-5
54 | mag_weight: 0.0
55 | lambdas:
56 | mel/loss: 15.0
57 | adv/feat_loss: 2.0
58 | adv/gen_loss: 1.0
59 | vq/commitment_loss: 0.25
60 | vq/codebook_loss: 1.0
61 |
62 | # Dataloader config
63 | dataloader:
64 | num_workers: 8
65 | train_bs: 8
66 | eval_bs: 32
67 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gen_audio
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=2_gnu
10 | - binutils_impl_linux-64=2.38=h2a08ee3_1
11 | - blas=1.0=mkl
12 | - brotli-python=1.0.9=py311h6a678d5_7
13 | - bzip2=1.0.8=h7b6447c_0
14 | - ca-certificates=2024.3.11=h06a4308_0
15 | - certifi=2024.2.2=py311h06a4308_0
16 | - cffi=1.16.0=py311h5eee18b_0
17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
18 | - cryptography=41.0.7=py311hdda0065_0
19 | - cuda-cudart=11.8.89=0
20 | - cuda-cupti=11.8.87=0
21 | - cuda-libraries=11.8.0=0
22 | - cuda-nvrtc=11.8.89=0
23 | - cuda-nvtx=11.8.86=0
24 | - cuda-runtime=11.8.0=0
25 | - ffmpeg=4.3=hf484d3e_0
26 | - filelock=3.13.1=py311h06a4308_0
27 | - freetype=2.12.1=h4a9f257_0
28 | - gcc=12.1.0=h9ea6d83_10
29 | - gcc_impl_linux-64=12.1.0=hea43390_17
30 | - giflib=5.2.1=h5eee18b_3
31 | - gmp=6.2.1=h295c915_3
32 | - gmpy2=2.1.2=py311hc9b5ff0_0
33 | - gnutls=3.6.15=he1e5248_0
34 | - gxx_impl_linux-64=12.1.0=hea43390_17
35 | - idna=3.4=py311h06a4308_0
36 | - intel-openmp=2023.1.0=hdb19cb5_46306
37 | - jinja2=3.1.2=py311h06a4308_0
38 | - jpeg=9e=h5eee18b_1
39 | - kernel-headers_linux-64=2.6.32=he073ed8_17
40 | - lame=3.100=h7b6447c_0
41 | - lcms2=2.12=h3be6417_0
42 | - ld_impl_linux-64=2.38=h1181459_1
43 | - lerc=3.0=h295c915_0
44 | - libcublas=11.11.3.6=0
45 | - libcufft=10.9.0.58=0
46 | - libcufile=1.8.1.2=0
47 | - libcurand=10.3.4.107=0
48 | - libcusolver=11.4.1.48=0
49 | - libcusparse=11.7.5.86=0
50 | - libdeflate=1.17=h5eee18b_1
51 | - libffi=3.4.4=h6a678d5_0
52 | - libgcc-devel_linux-64=12.1.0=h1ec3361_17
53 | - libgcc-ng=13.2.0=h807b86a_5
54 | - libgomp=13.2.0=h807b86a_5
55 | - libiconv=1.16=h7f8727e_2
56 | - libidn2=2.3.4=h5eee18b_0
57 | - libjpeg-turbo=2.0.0=h9bf148f_0
58 | - libnpp=11.8.0.86=0
59 | - libnvjpeg=11.9.0.86=0
60 | - libpng=1.6.39=h5eee18b_0
61 | - libsanitizer=12.1.0=ha89aaad_17
62 | - libstdcxx-devel_linux-64=12.1.0=h1ec3361_17
63 | - libstdcxx-ng=13.2.0=h7e041cc_5
64 | - libtasn1=4.19.0=h5eee18b_0
65 | - libtiff=4.5.1=h6a678d5_0
66 | - libunistring=0.9.10=h27cfd23_0
67 | - libuuid=1.41.5=h5eee18b_0
68 | - libwebp=1.3.2=h11a3e52_0
69 | - libwebp-base=1.3.2=h5eee18b_0
70 | - llvm-openmp=14.0.6=h9e868ea_0
71 | - lz4-c=1.9.4=h6a678d5_0
72 | - markupsafe=2.1.3=py311h5eee18b_0
73 | - mkl=2023.1.0=h213fc3f_46344
74 | - mkl-service=2.4.0=py311h5eee18b_1
75 | - mkl_fft=1.3.8=py311h5eee18b_0
76 | - mkl_random=1.2.4=py311hdb19cb5_0
77 | - mpc=1.1.0=h10f8cd9_1
78 | - mpfr=4.0.2=hb69a4c5_1
79 | - mpmath=1.3.0=py311h06a4308_0
80 | - ncurses=6.4=h6a678d5_0
81 | - nettle=3.7.3=hbbd107a_1
82 | - networkx=3.1=py311h06a4308_0
83 | - openh264=2.1.1=h4ff587b_0
84 | - openjpeg=2.4.0=h3ad879b_0
85 | - openssl=3.2.1=hd590300_1
86 | - pillow=10.0.1=py311ha6cbd5a_0
87 | - pycparser=2.21=pyhd3eb1b0_0
88 | - pyopenssl=23.2.0=py311h06a4308_0
89 | - pysocks=1.7.1=py311h06a4308_0
90 | - python=3.11.7=h955ad1f_0
91 | - pytorch=2.1.2=py3.11_cuda11.8_cudnn8.7.0_0
92 | - pytorch-cuda=11.8=h7e8668a_5
93 | - pytorch-mutex=1.0=cuda
94 | - pyyaml=6.0.1=py311h5eee18b_0
95 | - readline=8.2=h5eee18b_0
96 | - requests=2.31.0=py311h06a4308_0
97 | - sqlite=3.41.2=h5eee18b_0
98 | - sympy=1.12=py311h06a4308_0
99 | - sysroot_linux-64=2.12=he073ed8_17
100 | - tbb=2021.8.0=hdb19cb5_0
101 | - tk=8.6.12=h1ccaba5_0
102 | - torchaudio=2.1.2=py311_cu118
103 | - torchtriton=2.1.0=py311
104 | - torchvision=0.16.2=py311_cu118
105 | - typing_extensions=4.7.1=py311h06a4308_0
106 | - urllib3=1.26.18=py311h06a4308_0
107 | - wheel=0.41.2=py311h06a4308_0
108 | - xz=5.4.5=h5eee18b_0
109 | - yaml=0.2.5=h7b6447c_0
110 | - zlib=1.2.13=h5eee18b_0
111 | - zstd=1.5.5=hc292b87_0
112 | - pip:
113 | - absl-py==2.0.0
114 | - accelerate==0.28.0
115 | - aiohttp==3.9.3
116 | - aiosignal==1.3.1
117 | - antlr4-python3-runtime==4.9.3
118 | - appdirs==1.4.4
119 | - argbind==0.3.7
120 | - argparse==1.4.0
121 | - asttokens==2.4.1
122 | - attrs==23.2.0
123 | - audioread==3.0.1
124 | - braceexpand==0.1.7
125 | - cachetools==5.3.2
126 | - click==8.1.7
127 | - contourpy==1.2.0
128 | - cycler==0.12.1
129 | - datasets==2.18.0
130 | - decorator==5.1.1
131 | - descript-audio-codec==1.0.0
132 | - descript-audiotools==0.7.2
133 | - diffusers==0.27.2
134 | - dill==0.3.8
135 | - docker-pycreds==0.4.0
136 | - docstring-parser==0.16
137 | - einops==0.7.0
138 | - executing==2.0.1
139 | - ffmpy==0.3.2
140 | - fire==0.6.0
141 | - flatten-dict==0.4.2
142 | - fonttools==4.47.0
143 | - frozenlist==1.4.1
144 | - fsspec==2023.12.2
145 | - ftfy==6.2.0
146 | - future==0.18.3
147 | - gitdb==4.0.11
148 | - gitpython==3.1.43
149 | - google-auth==2.26.1
150 | - google-auth-oauthlib==1.2.0
151 | - grpcio==1.60.0
152 | - h5py==3.11.0
153 | - huggingface-hub==0.20.2
154 | - hydra-core==1.3.2
155 | - importlib-metadata==7.0.2
156 | - importlib-resources==6.4.0
157 | - ipython==8.23.0
158 | - jedi==0.19.1
159 | - joblib==1.3.2
160 | - julius==0.2.7
161 | - kiwisolver==1.4.5
162 | - laion-clap==1.1.4
163 | - lazy-loader==0.3
164 | - librosa==0.10.1
165 | - llvmlite==0.41.1
166 | - markdown==3.5.1
167 | - markdown-it-py==3.0.0
168 | - markdown2==2.4.13
169 | - matplotlib==3.8.3
170 | - matplotlib-inline==0.1.6
171 | - mdurl==0.1.2
172 | - msgpack==1.0.7
173 | - multidict==6.0.5
174 | - multiprocess==0.70.16
175 | - numba==0.58.1
176 | - numpy==1.23.5
177 | - oauthlib==3.2.2
178 | - omegaconf==2.3.0
179 | - packaging==23.2
180 | - pandas==2.1.4
181 | - parso==0.8.3
182 | - pesq==0.0.4
183 | - pexpect==4.9.0
184 | - pip==24.0
185 | - platformdirs==4.1.0
186 | - pooch==1.8.0
187 | - progressbar==2.5
188 | - prompt-toolkit==3.0.43
189 | - protobuf==3.19.6
190 | - psutil==5.9.7
191 | - ptyprocess==0.7.0
192 | - pure-eval==0.2.2
193 | - pyarrow==15.0.1
194 | - pyarrow-hotfix==0.6
195 | - pyasn1==0.5.1
196 | - pyasn1-modules==0.3.0
197 | - pygments==2.17.2
198 | - pyloudnorm==0.1.1
199 | - pyparsing==3.1.1
200 | - pysndfx==0.3.6
201 | - pystoi==0.4.1
202 | - python-dateutil==2.8.2
203 | - pytz==2023.3.post1
204 | - randomname==0.2.1
205 | - regex==2023.12.25
206 | - requests-oauthlib==1.3.1
207 | - resampy==0.4.3
208 | - rich==13.7.1
209 | - rsa==4.9
210 | - safetensors==0.4.1
211 | - scikit-learn==1.3.2
212 | - scipy==1.12.0
213 | - sentry-sdk==2.0.1
214 | - setproctitle==1.3.3
215 | - setuptools==69.0.3
216 | - six==1.16.0
217 | - smmap==5.0.1
218 | - soundfile==0.12.1
219 | - soxr==0.3.7
220 | - stack-data==0.6.3
221 | - tensorboard==2.16.2
222 | - tensorboard-data-server==0.7.2
223 | - termcolor==2.4.0
224 | - threadpoolctl==3.2.0
225 | - tokenizers==0.13.3
226 | - torch-stoi==0.2.1
227 | - torchlibrosa==0.1.0
228 | - tqdm==4.66.1
229 | - traitlets==5.14.2
230 | - transformers==4.30.0
231 | - tzdata==2023.4
232 | - visqol==3.3.3
233 | - wandb==0.16.6
234 | - wcwidth==0.2.13
235 | - webdataset==0.2.86
236 | - werkzeug==3.0.1
237 | - wget==3.2
238 | - xxhash==3.4.1
239 | - yarl==1.9.4
240 | - zipp==3.18.1
241 | prefix: /home/xbie/anaconda3/envs/gen_audio
242 |
--------------------------------------------------------------------------------
/eval_dnr.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | import json
4 | import argparse
5 | import importlib
6 | import math
7 | import julius
8 | import pandas as pd
9 | from pathlib import Path
10 | from omegaconf import OmegaConf
11 | from tqdm import tqdm
12 | import numpy as np
13 | from collections import namedtuple
14 |
15 | from src import utils
16 | from src.metrics import (
17 | VisqolMetric,
18 | SingleSrcNegSDR,
19 | MultiScaleSTFTLoss,
20 | MelSpectrogramLoss,
21 | )
22 |
23 | import torch
24 | import torchaudio
25 | from accelerate import Accelerator
26 | accelerator = Accelerator()
27 |
28 | parser = argparse.ArgumentParser(description='Evaluate a trained model on the DnR manifests',
29 | add_help=True,
30 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
31 |
32 | parser.add_argument('--ret-dir', type=str, default='output/debug', help='Training result directory')
33 | parser.add_argument('--csv-path', type=str, default='./manifest/test.csv', help='csv file to test')
34 | parser.add_argument('--data-sr', type=int, default=[44100], nargs='+', help='list of sampling rate in test files')
35 | parser.add_argument('--length', type=int, default=10, help='audio length')
36 | parser.add_argument('--visqol-mode', type=str, default='speech', choices=['speech', 'audio'], help='visqol mode')
37 | parser.add_argument('--threshold', type=float, default=0.4, help='threshold of silence part to drop audio')
38 | parser.add_argument('--fast', action='store_true', help='fast eval, disable visqol computation')
39 |
40 | # parse
41 | args = parser.parse_args()
42 | ret_dir = Path(args.ret_dir)
43 | csv_path = Path(args.csv_path)
44 | length = args.length
45 | visqol_mode = args.visqol_mode
46 | threshold = args.threshold
47 | use_visqol = not args.fast
48 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
49 |
50 | # read config
51 | cfg_filepath = ret_dir / '.hydra' / 'config.yaml'
52 | cfg = OmegaConf.load(cfg_filepath)
53 | sample_rate = cfg.sampling_rate
54 | chunk_len = sample_rate * length
55 |
56 | # init julius resample
57 | resample_pool = dict()
58 | for sr in args.data_sr:
59 | old_sr = sr
60 | new_sr = sample_rate
61 | gcd = math.gcd(old_sr, new_sr)
62 | old_sr = old_sr // gcd
63 | new_sr = new_sr // gcd
64 | resample_pool[sr] = julius.ResampleFrac(old_sr=old_sr, new_sr=new_sr)
65 |
66 | # import lib
67 | model_name = cfg.model.pop('name')
68 | module_path = str(ret_dir / 'backup_src' / 'models').replace('/', '.')
69 | try:
70 | load_model = importlib.import_module(module_path)
71 | net_class = getattr(load_model, f'{model_name}')
72 | print('Load model from ckpt')
73 | except:
74 | from src import models
75 | net_class = getattr(models, f'{model_name}')
76 | print('Load model from source code')
77 |
78 | # load model and weights
79 | model_cfg = cfg.model
80 | model = net_class(sample_rate=sample_rate, **model_cfg)
81 | total_params = sum(p.numel() for p in model.parameters()) / 1e6
82 | print(f'Total params: {total_params:.2f} M')
83 | print('Model sampling rate: {} Hz'.format(model.sample_rate))
84 |
85 | ckpt_finalpath = ret_dir / 'ckpt_final' / 'ckpt_model_final.pth'
86 | state_dict = torch.load(ckpt_finalpath, map_location=torch.device('cpu'))
87 | model.load_state_dict(state_dict)
88 | model = model.to(device)
89 | model.eval()
90 | print(f'ckpt path: {ckpt_finalpath}')
91 | print(f'Model weights loaded successfully...')
92 |
93 | # prepare metrics
94 | loss_cfg = cfg.training.loss
95 | metric_stft = MultiScaleSTFTLoss(**loss_cfg.MultiScaleSTFTLoss)
96 | metric_mel = MelSpectrogramLoss(**loss_cfg.MelSpectrogramLoss)
97 | metric_sisdr = SingleSrcNegSDR(sdr_type='sisdr')
98 | metric_visqol = VisqolMetric(mode=visqol_mode)
99 |
100 | # prepare data transform
101 | transform_cfg = cfg.training.transform
102 | volume_norm = utils.VolumeNorm(sample_rate=sample_rate)
103 | def _data_transform(batch, transform_cfg, valid_tracks=['speech'], norm_var=0):
104 | peak_norm = utils.db_to_gain(transform_cfg.peak_norm_db)
105 | mix_max_peak = torch.zeros_like(batch['speech'])[...,:1] # (bs, C, 1)
106 |
107 | # volume norm for each track
108 | for track in valid_tracks:
109 | batch[track] = volume_norm(signal=batch[track],
110 | target_loudness=transform_cfg.lufs_norm_db[track],
111 | var=norm_var)
112 | # peak value
113 | peak = batch[track].abs().max(dim=-1, keepdims=True)[0]
114 | mix_max_peak = torch.maximum(peak, mix_max_peak)
115 |
116 | # peak norm
117 | peak_gain = torch.ones_like(mix_max_peak) # (bs, C, 1)
118 | peak_gain[mix_max_peak > peak_norm] = peak_norm / mix_max_peak[mix_max_peak > peak_norm]
119 |
120 | # build mix
121 | batch['mix'] = torch.zeros_like(batch['speech'])
122 | for track in valid_tracks:
123 | batch[track] *= peak_gain
124 | batch['mix'] += batch[track]
125 |
126 | # mix volume norm
127 | batch['mix'], mix_gain = volume_norm(signal=batch['mix'],
128 | target_loudness=transform_cfg.lufs_norm_db['mix'],
129 | var=norm_var,
130 | return_gain=True)
131 |
132 | # norm each track
133 | for track in valid_tracks:
134 | batch[track] *= mix_gain[:, None, None]
135 |
136 | batch['valid_tracks'] = valid_tracks
137 | batch['random_swap'] = False
138 |
139 | return batch
140 |
141 |
142 | # define mask separation
143 | sep_norm = utils.WavSepMagNorm()
144 |
145 | # define STFT params
146 | STFTParams = namedtuple(
147 | "STFTParams",
148 | ["window_length", "hop_length", "window_type", "padding_type"],
149 | )
150 | stft_params = STFTParams(
151 | window_length=1024,
152 | hop_length=256,
153 | window_type="hann",
154 | padding_type="reflect",
155 | )
156 |
157 | # run eval
158 | tracks = model.tracks
159 | print('Model tracks: {}'.format(tracks))
160 | test_tracks = ['mix'] + [f'{t}_rec' for t in tracks] + [f'{t}_sep' for t in tracks] + [f'{t}_sep_mask' for t in tracks]
161 | test_results = {t: {} for t in test_tracks}
162 | metadata = pd.read_csv(csv_path)
163 |
164 | for i in tqdm(range(len(metadata)), desc='Eval'):
165 | # for i in tqdm(range(20), desc='Eval'):
166 | wav_info = metadata.iloc[i]
167 | audio_id = wav_info['id']
168 | start = wav_info['start']
169 | end = wav_info['end']
170 | batch = {}
171 | # read data
172 | for t in tracks:
173 | x, sr = torchaudio.load(wav_info[t])
174 | x = x.mean(dim=0)[..., start: end]
175 | if sr != sample_rate:
176 | x = resample_pool[sr](x)
177 | batch[t] = x
178 | audio_len = x.shape[-1]
179 |
180 | # clip audio
181 | for j, k in enumerate(range(0, audio_len-chunk_len+1, chunk_len)):
182 | clip_id = f'{audio_id}_{j}'
183 | eval_batch = {}
184 | mask = {}
185 | for t in tracks:
186 | audio_clip = batch[t][k:k+chunk_len]
187 |
188 | # silent audio detection
189 | audio_energy = torch.stft(audio_clip, n_fft=stft_params.window_length, hop_length=stft_params.hop_length,
190 | win_length=stft_params.window_length,
191 | window=torch.hann_window(stft_params.window_length, device='cpu'),
192 | pad_mode=stft_params.padding_type, center=True, onesided=True, return_complex=True).abs().sum(dim=0)
193 | count = sum(1 for item in audio_energy if item > 1e-6)
194 | silence_detect = count < threshold * len(audio_energy)
195 | mask[f'{t}_rec'] = silence_detect
196 | mask[f'{t}_sep'] = silence_detect
197 | mask[f'{t}_sep_mask'] = silence_detect
198 |
199 | eval_batch[t] = audio_clip.reshape(1,1,-1).to(device)
200 |
201 | mask['mix'] = all(mask.values())
202 |
203 | # data transform
204 | # eval_batch = _data_transform(eval_batch, transform_cfg=transform_cfg, valid_tracks=tracks, norm_var=0)
205 | eval_batch['mix'] = eval_batch['speech']+eval_batch['music']+eval_batch['sfx']
206 | eval_batch['valid_tracks'] = tracks
207 | eval_batch['random_swap'] = False
208 |
209 | # mixture forward
210 | with torch.no_grad():
211 | output_audio = model.evaluate(input_audio=eval_batch['mix'],
212 | output_tracks=['mix']+tracks)
213 | # eval_batch = model(eval_batch)
214 | # output_audio = eval_batch['recon'][:,:,0]
215 |
216 | # Eval mix reconstruction
217 | est = output_audio[:, 0].unsqueeze(1)
218 | ref = eval_batch['mix']
219 | test_results['mix'][clip_id] = {}
220 | if mask['mix']:
221 | test_results['mix'][clip_id]['stft'] = None
222 | test_results['mix'][clip_id]['mel'] = None
223 | test_results['mix'][clip_id]['sisdr'] = None
224 | if use_visqol:
225 | test_results['mix'][clip_id]['visqol'] = None
226 | else:
227 | test_results['mix'][clip_id]['stft'] = metric_stft(est=est, ref=ref).item()
228 | test_results['mix'][clip_id]['mel'] = metric_mel(est=est, ref=ref).item()
229 | test_results['mix'][clip_id]['sisdr'] = - metric_sisdr(est=est, ref=ref).item()
230 | if use_visqol:
231 | test_results['mix'][clip_id]['visqol'] = metric_visqol(est=est, ref=ref, sr=sample_rate)
232 |
233 | # Eval separation using synthesizer (decoder)
234 | for p, t in enumerate(tracks):
235 | est = output_audio[:,p+1].unsqueeze(1)
236 | ref = eval_batch[t]
237 | test_results[f'{t}_sep'][clip_id] = {}
238 | if mask[f'{t}_sep']:
239 | test_results[f'{t}_sep'][clip_id]['stft'] = None
240 | test_results[f'{t}_sep'][clip_id]['mel'] = None
241 | test_results[f'{t}_sep'][clip_id]['sisdr'] = None
242 | if use_visqol:
243 | test_results[f'{t}_sep'][clip_id]['visqol'] = None
244 | else:
245 | test_results[f'{t}_sep'][clip_id]['stft'] = metric_stft(est=est, ref=ref).item()
246 | test_results[f'{t}_sep'][clip_id]['mel'] = metric_mel(est=est, ref=ref).item()
247 | test_results[f'{t}_sep'][clip_id]['sisdr'] = - metric_sisdr(est=est, ref=ref).item()
248 | if use_visqol:
249 | test_results[f'{t}_sep'][clip_id]['visqol'] = metric_visqol(est=est, ref=ref, sr=sample_rate)
250 |
251 | # Eval separation using mask
252 | mix = eval_batch['mix'].unsqueeze(2)
253 | signal_sep = output_audio[:,1:].unsqueeze(2)
254 | all_sep_mask_norm = sep_norm(mix, signal_sep)
255 | for p, t in enumerate(tracks):
256 | est = all_sep_mask_norm[:,p]
257 | ref = eval_batch[t]
258 | ref = ref[...,:est.shape[-1]] # stft + istft. shorter
259 | # breakpoint()
260 | test_results[f'{t}_sep_mask'][clip_id] = {}
261 | if mask[f'{t}_sep_mask']:
262 | test_results[f'{t}_sep_mask'][clip_id]['stft'] = None
263 | test_results[f'{t}_sep_mask'][clip_id]['mel'] = None
264 | test_results[f'{t}_sep_mask'][clip_id]['sisdr'] = None
265 | if use_visqol:
266 | test_results[f'{t}_sep_mask'][clip_id]['visqol'] = None
267 | else:
268 | test_results[f'{t}_sep_mask'][clip_id]['stft'] = metric_stft(est=est, ref=ref).item()
269 | test_results[f'{t}_sep_mask'][clip_id]['mel'] = metric_mel(est=est, ref=ref).item()
270 | test_results[f'{t}_sep_mask'][clip_id]['sisdr'] = - metric_sisdr(est=est, ref=ref).item()
271 | if use_visqol:
272 | test_results[f'{t}_sep_mask'][clip_id]['visqol'] = metric_visqol(est=est, ref=ref, sr=sample_rate)
273 |
274 | # Evaluate reconstruction of single track
275 | for p, t in enumerate(tracks):
276 | # single track forward
277 | with torch.no_grad():
278 | output_audio = model.evaluate(input_audio=eval_batch[t],
279 | output_tracks=[t])
280 | est = output_audio
281 | ref = eval_batch[t]
282 | test_results[f'{t}_rec'][clip_id] = {}
283 | if mask[f'{t}_rec']:
284 | test_results[f'{t}_rec'][clip_id]['stft'] = None
285 | test_results[f'{t}_rec'][clip_id]['mel'] = None
286 | test_results[f'{t}_rec'][clip_id]['sisdr'] = None
287 | if use_visqol:
288 | test_results[f'{t}_rec'][clip_id]['visqol'] = None
289 | else:
290 | test_results[f'{t}_rec'][clip_id]['stft'] = metric_stft(est=est, ref=ref).item()
291 | test_results[f'{t}_rec'][clip_id]['mel'] = metric_mel(est=est, ref=ref).item()
292 | test_results[f'{t}_rec'][clip_id]['sisdr'] = - metric_sisdr(est=est, ref=ref).item()
293 | if use_visqol:
294 | test_results[f'{t}_rec'][clip_id]['visqol'] = metric_visqol(est=est, ref=ref, sr=sample_rate)
295 |
296 |
297 | test_results['summary'] = {}
298 | for track in test_tracks:
299 | test_results['summary'][track] = {}
300 | list_stft = []
301 | list_mel = []
302 | list_sisdr = []
303 | if use_visqol:
304 | list_visqol = []
305 |
306 | for metrics in test_results[track].values():
307 | list_stft.append(metrics['stft'])
308 | list_mel.append(metrics['mel'])
309 | list_sisdr.append(metrics['sisdr'])
310 | if use_visqol:
311 | list_visqol.append(metrics['visqol'])
312 |
313 | np_stft = np.array([x for x in list_stft if x is not None])
314 | np_mel = np.array([x for x in list_mel if x is not None])
315 | np_sisdr = np.array([x for x in list_sisdr if x is not None])
316 | if use_visqol:
317 | np_visqol = np.array([x for x in list_visqol if x is not None])
318 |
319 | stft_m, stft_std = np.mean(np_stft), np.std(np_stft)
320 | mel_m, mel_std = np.mean(np_mel), np.std(np_mel)
321 | sisdr_m, sisdr_std = np.mean(np_sisdr), np.std(np_sisdr)
322 | if use_visqol:
323 | visqol_m, visqol_std = np.mean(np_visqol), np.std(np_visqol)
324 |
325 | print('='*80)
326 | print(f'{track}')
327 | print('Valid datapoint: {}/{}'.format(len(np_stft), len(list_stft)))
328 | print('Distance STFT: {:.2f} +/- {:.2f}'.format(stft_m, stft_std))
329 | print('Distance Mel: {:.2f} +/- {:.2f}'.format(mel_m, mel_std))
330 | print('SI-SDR: {:.2f} +/- {:.2f}'.format(sisdr_m, sisdr_std))
331 | if use_visqol:
332 | print('VisQOL: {:.2f} +/- {:.2f}'.format(visqol_m, visqol_std))
333 |
334 | test_results['summary'][track]['tot_seq'] = len(list_stft)
335 | test_results['summary'][track]['valid_seq'] = len(np_stft)
336 | test_results['summary'][track]['stft'] = {'mean': stft_m, 'std': stft_std}
337 | test_results['summary'][track]['mel'] = {'mean': mel_m, 'std': mel_std}
338 | test_results['summary'][track]['sisdr'] = {'mean': sisdr_m, 'std': sisdr_std}
339 | if use_visqol:
340 | test_results['summary'][track]['visqol'] = {'mean': visqol_m, 'std': visqol_std}
341 |
342 |
343 | # save to json
344 | json_filename = ret_dir / '{}_{}s.json'.format(csv_path.stem, length)
345 | with open(json_filename, 'w') as f:
346 | json.dump(test_results, f, indent=1)
--------------------------------------------------------------------------------
/install_visqol.md:
--------------------------------------------------------------------------------
1 | ## 1. Install Bazel
2 |
3 | Here is an example; you may need to change the download link to the latest version and adapt the paths to your conda env.
4 |
5 | ```bash
6 | # Linux for example
7 |
8 | # Check the architecture of your computer/cluster
9 | dpkg --print-architecture
10 |
11 | # Download the right version of the Bazelisk binary release, e.g. linux-amd64
12 | # check the latest version: https://github.com/bazelbuild/bazelisk/releases
13 | wget https://github.com/bazelbuild/bazelisk/releases/download/v1.19.0/bazelisk-linux-amd64
14 |
15 | # Rename and move to your conda env, e.g. torch2
16 | chmod +x bazelisk-linux-amd64
17 | mv bazelisk-linux-amd64 /home/ids/xbie/anaconda3/envs/torch2/bin/bazel
18 |
19 | # Now you can run bazel by activating the conda env
20 | ```
21 |
22 | ## 2. Install VisQOL
23 | ```bash
24 | # Clone the VisQOL repo
25 | git clone https://github.com/google/visqol.git
26 |
27 | # Install bin, need GCC 12
28 | cd visqol
29 | bazel build :visqol -c opt --jobs=10
30 |
31 | # Install Python API
32 | pip install .
33 | ```
34 |
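Once the Python API is installed, a quick import check catches the `GLIBCXX` problem described below early. This is only a sketch based on the upstream VisQOL Python example; field names may differ across versions:

```python
from visqol import visqol_lib_py          # fails here if the .so / libstdc++ link is broken
from visqol.pb2 import visqol_config_pb2

# build a minimal 16 kHz speech-mode config, following the upstream example
config = visqol_config_pb2.VisqolConfig()
config.audio.sample_rate = 16000
config.options.use_speech_scoring = True
print('VisQOL Python API imported successfully')
```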
35 | ## 3. Potential Problems
36 | Here are some problems I encountered during installation and how I solved them:
37 | ```bash
38 | # You might encounter a gcc/g++ compilation problem, even if you have successfully built visqol
39 | # e.g. libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/xbie/anaconda3/envs/torch2/lib/python3.11/site-packages/visqol/visqol_lib_py.so)
40 | # in that case, make sure your gcc/g++ version is >= 12
41 | conda install conda-forge::gxx_impl_linux-64
42 | conda install -c conda-forge gcc=12.1.0
43 |
44 | # In case your Bazel is installed in the system path
45 | # then you should update gcc/g++ in system-wise
46 | sudo apt install --reinstall gcc-12
47 | sudo apt install --reinstall g++-12
48 | sudo ln -s -f /usr/bin/gcc-12 /usr/bin/gcc
49 | sudo ln -s -f /usr/bin/g++-12 /usr/bin/g++
50 |
51 | # In case importing `visqol_lib_py` fails with an unknown path,
52 | # install it manually; make sure you are in the Visqol root path
53 | # if on the cluster, use `--jobs=10` to limit the parallel jobs
54 | # then we need to re-run the installation to link the .so library
55 | bazel build -c opt //python:visqol_lib_py.so --jobs=10
56 | pip install .
57 |
58 | # If you still have problems,
59 | # remove all Bazel cache files and re-run the installation
60 | rm -rf ~/.cache/bazel
61 | rm -rf ~/.cache/bazelisk
62 | ```
63 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Copyright (c) 2024 by Telecom-Paris
5 | Authored by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr)
6 | License agreement in LICENSE.txt
7 | """
8 |
9 | import os
10 | import sys
11 | import shutil
12 | import logging
13 | import hydra
14 | from omegaconf import DictConfig, OmegaConf
15 |
16 | import torch
17 |
18 | @hydra.main(version_base=None, config_path="config", config_name="default")
19 | def main(cfg: DictConfig) -> None:
20 |
21 | from src.trainer import Trainer
22 | trainer = Trainer(cfg)
23 | trainer.run()
24 |
25 |
26 | if __name__ == "__main__":
27 | main()
--------------------------------------------------------------------------------
/prepare/mani_dnr.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024 by Telecom-Paris
3 | Authored by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr)
4 | License agreement in LICENSE.txt
5 | """
6 | import os
7 | import argparse
8 | import librosa
9 | import torchaudio
10 | from tqdm import tqdm
11 | from pathlib import Path
12 | from collections import namedtuple
13 | import torch
14 |
15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset',
16 | add_help=True,
17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 |
19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/dnr_v2', help='Audio Dataset Path')
20 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest')
21 | parser.add_argument('--ext', type=str, default='wav', choices=['wav', 'mp3', 'flac'], help='Audio format')
22 |
23 | args = parser.parse_args()
24 | data_dir = Path(args.data_dir)
25 | out_dir = Path(args.out_dir)
26 | ext = args.ext
27 |
28 | train_dir = data_dir / 'tr'
29 | val_dir = data_dir / 'cv'
30 | test_dir = data_dir / 'tt'
31 | tracks = ['speech', 'music', 'sfx']
32 |
33 | # define STFT params
34 | STFTParams = namedtuple(
35 | "STFTParams",
36 | ["window_length", "hop_length", "window_type", "padding_type"],
37 | )
38 | stft_params = STFTParams(
39 | window_length=1024,
40 | hop_length=256,
41 | window_type="hann",
42 | padding_type="reflect",
43 | )
44 |
45 | ## train
46 | for track in tracks:
47 | with open(out_dir / f'{track}_dnr.csv', 'w') as f:
48 | f.write('id,filepath,sr,length,start,end\n')
49 | for subdir in tqdm(sorted(train_dir.iterdir()), desc=f'train_{track}'):
50 | if subdir.is_dir():
51 | audio_id = subdir.name
52 | audio_filepath = subdir / f'{track}.{ext}'
53 | x, sr = torchaudio.load(audio_filepath)
54 | length = x.shape[-1]
55 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30)
56 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe)
57 | f.write(line)
58 |
59 |
60 | ## val with no silence
61 | length = 5
62 | threshold = 0.5
63 | with open(out_dir / 'val_dnr.csv', 'w') as f:
64 | f.write('id,mix,speech,music,sfx,sr,length,start,end\n')
65 | for subdir in tqdm(sorted(val_dir.iterdir()), desc='val'):
66 | if subdir.is_dir():
67 | audio_id = subdir.name
68 | mix_filepath = subdir / f'mix.{ext}'
69 | speech_filepath = subdir / f'speech.{ext}'
70 | music_filepath = subdir / f'music.{ext}'
71 | sfx_filepath = subdir / f'sfx.{ext}'
72 |
73 | x_mix, sr = torchaudio.load(mix_filepath)
74 | x_speech, _ = torchaudio.load(speech_filepath)
75 | x_music, _ = torchaudio.load(music_filepath)
76 | x_sfx, _ = torchaudio.load(sfx_filepath)
77 |
78 | chunk_len = sr * length
79 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x_mix.numpy(), top_db=30)
80 | for j, k in enumerate(range(trim30dBs, trim30dBe-chunk_len, chunk_len)):
81 | chunk_mix = x_mix[..., k:k+chunk_len]
82 | chunk_speech = x_speech[..., k:k+chunk_len]
83 | chunk_music = x_music[..., k:k+chunk_len]
84 | chunk_sfx = x_sfx[..., k:k+chunk_len]
85 |
86 | # detect silence length
87 | is_active = True
88 | for audio_clip in [chunk_mix, chunk_speech, chunk_music, chunk_sfx]:
89 | audio_clip = audio_clip.reshape(-1)
90 | audio_energy = torch.stft(audio_clip, n_fft=stft_params.window_length,
91 | hop_length=stft_params.hop_length, win_length=stft_params.window_length,
92 | window=torch.hann_window(stft_params.window_length, device='cpu'),
93 | pad_mode=stft_params.padding_type, center=True, onesided=True, return_complex=True
94 | ).abs().sum(dim=0)
95 | count = sum(1 for item in audio_energy if item > 1e-6)
96 | if count < threshold * len(audio_energy):
97 | is_active = False
98 |
99 | # save if no silence detect
100 | if is_active:
101 | clip_id = f'{audio_id}_{j}'
102 | start = k
103 | end = k + chunk_len
104 | line = '{},{},{},{},{},{},{},{},{}\n'.format(clip_id, mix_filepath, speech_filepath, music_filepath, sfx_filepath, \
105 | sr, chunk_len, start, end)
106 | f.write(line)
107 |
108 | ## test with no silence
109 | length = 10
110 | threshold = 0.5
111 | with open(out_dir / 'test_dnr.csv', 'w') as f:
112 | f.write('id,mix,speech,music,sfx,sr,length,start,end\n')
113 | for subdir in tqdm(sorted(test_dir.iterdir()), desc='test'):
114 | if subdir.is_dir():
115 | audio_id = subdir.name
116 | mix_filepath = subdir / f'mix.{ext}'
117 | speech_filepath = subdir / f'speech.{ext}'
118 | music_filepath = subdir / f'music.{ext}'
119 | sfx_filepath = subdir / f'sfx.{ext}'
120 |
121 | x_mix, sr = torchaudio.load(mix_filepath)
122 | x_speech, _ = torchaudio.load(speech_filepath)
123 | x_music, _ = torchaudio.load(music_filepath)
124 | x_sfx, _ = torchaudio.load(sfx_filepath)
125 |
126 | chunk_len = sr * length
127 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x_mix.numpy(), top_db=30)
128 | for j, k in enumerate(range(trim30dBs, trim30dBe-chunk_len, chunk_len)):
129 | chunk_mix = x_mix[..., k:k+chunk_len]
130 | chunk_speech = x_speech[..., k:k+chunk_len]
131 | chunk_music = x_music[..., k:k+chunk_len]
132 | chunk_sfx = x_sfx[..., k:k+chunk_len]
133 |
134 | # detect silence length
135 | is_active = True
136 | for audio_clip in [chunk_mix, chunk_speech, chunk_music, chunk_sfx]:
137 | audio_clip = audio_clip.reshape(-1)
138 | audio_energy = torch.stft(audio_clip, n_fft=stft_params.window_length,
139 | hop_length=stft_params.hop_length, win_length=stft_params.window_length,
140 | window=torch.hann_window(stft_params.window_length, device='cpu'),
141 | pad_mode=stft_params.padding_type, center=True, onesided=True, return_complex=True
142 | ).abs().sum(dim=0)
143 | count = sum(1 for item in audio_energy if item > 1e-6)
144 | if count < threshold * len(audio_energy):
145 | is_active = False
146 |
147 | # save if no silence detect
148 | if is_active:
149 | clip_id = f'{audio_id}_{j}'
150 | start = k
151 | end = k + chunk_len
152 | line = '{},{},{},{},{},{},{},{},{}\n'.format(clip_id, mix_filepath, speech_filepath, music_filepath, sfx_filepath, \
153 | sr, chunk_len, start, end)
154 | f.write(line)
--------------------------------------------------------------------------------
/prepare/mani_dns_clean.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024 by Telecom-Paris
3 | Authored by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr)
4 | License agreement in LICENSE.txt
5 | """
6 | import os
7 | import argparse
8 | import librosa
9 | import torchaudio
10 | from tqdm import tqdm
11 | from pathlib import Path
12 | from collections import namedtuple
13 | import torch
14 |
15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset, DNS-challenge-5 clean speech',
16 | add_help=True,
17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 |
19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/DNS-Challenge/', help='Audio Dataset Path')
20 | parser.add_argument('--partition', type=str, default='all', help='Audio partition to create manifest')
21 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest')
22 | parser.add_argument('--threshold', type=float, default=0.5, help='Remove audio files that are too short')
23 | parser.add_argument('--ext', type=str, default='wav', choices=['wav', 'mp3', 'flac'], help='Audio format')
24 |
25 | args = parser.parse_args()
26 | data_dir = Path(args.data_dir)
27 | out_dir = Path(args.out_dir)
28 | partition = args.partition
29 | threshold = args.threshold
30 | ext = args.ext
31 |
32 | if partition == 'all':
33 | audio_sources = ['emotional_speech',
34 | 'read_speech',
35 | 'vctk_wav48_silence_trimmed',
36 | 'VocalSet_48kHz_mono',
37 | 'french_speech',
38 | 'german_speech',
39 | 'italian_speech',
40 | 'russian_speech',
41 | 'spanish_speech']
42 | else:
43 | audio_sources = [partition]
44 |
45 | for audio_source in audio_sources:
46 | audio_dir = data_dir / 'datasets_fullband/clean_fullband' / audio_source
47 | audio_manifest = out_dir / f'speech_dns5_{audio_source}.csv'
48 | audio_len = 0
49 | with open(audio_manifest, 'w') as f:
50 | f.write('id,filepath,sr,length,start,end\n')
51 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'clean/{audio_source}'):
52 | audio_id = audio_filepath.stem
53 | x, sr = torchaudio.load(audio_filepath)
54 | length = x.shape[-1]
55 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30)
56 | utt_len = (trim30dBe - trim30dBs) / sr
57 | if utt_len >= threshold:
58 | audio_len += length / sr
59 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe)
60 | f.write(line)
61 | print('Source: {}. audio len: {:.2f}h'.format(audio_source, audio_len/3600))
62 |
63 |
--------------------------------------------------------------------------------
/prepare/mani_dns_noise.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024 by Telecom-Paris
3 | Authored by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr)
4 | License agreement in LICENSE.txt
5 | """
6 | import os
7 | import argparse
8 | import librosa
9 | import torchaudio
10 | from tqdm import tqdm
11 | from pathlib import Path
12 | from collections import namedtuple
13 | import torch
14 |
15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset, DNS-challenge-5 noise data',
16 | add_help=True,
17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 |
19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/DNS-Challenge/', help='Audio Dataset Path')
20 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest')
21 | parser.add_argument('--threshold', type=float, default=0.5, help='Remove audio files that are too short')
22 | parser.add_argument('--ext', type=str, default='wav', choices=['wav', 'mp3', 'flac'], help='Audio format')
23 |
24 | args = parser.parse_args()
25 | data_dir = Path(args.data_dir)
26 | out_dir = Path(args.out_dir)
27 | threshold = args.threshold
28 | ext = args.ext
29 |
30 | # noise
31 | audio_dir = data_dir / 'datasets_fullband/noise_fullband'
32 | audio_manifest = out_dir / f'sfx_dns5.csv'
33 | audio_len = 0
34 | with open(audio_manifest, 'w') as f:
35 | f.write('id,filepath,sr,length,start,end\n')
36 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'noise'):
37 | audio_id = audio_filepath.stem
38 | x, sr = torchaudio.load(audio_filepath)
39 | length = x.shape[-1]
40 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30)
41 | utt_len = (trim30dBe - trim30dBs) / sr
42 | if utt_len >= threshold:
43 | audio_len += length / sr
44 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe)
45 | f.write(line)
46 | print('Source: noise. audio len: {:.2f}h'.format(audio_len/3600))
--------------------------------------------------------------------------------
/prepare/mani_jamendo.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024 by Telecom-Paris
3 | Authored by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr)
4 | License agreement in LICENSE.txt
5 | """
6 | import os
7 | import argparse
8 | import librosa
9 | import torchaudio
10 | from tqdm import tqdm
11 | from pathlib import Path
12 | from collections import namedtuple
13 | import torch
14 |
15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset, MTG-Jamendo',
16 | add_help=True,
17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 |
19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/mtg-jamendo/', help='Audio Dataset Path')
20 | parser.add_argument('--partition', type=int, default=-1, help='Audio partition to create manifest')
21 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest')
22 | parser.add_argument('--threshold', type=float, default=0.5, help='Remove audio files that are too short')
23 | parser.add_argument('--ext', type=str, default='mp3', choices=['wav', 'mp3', 'flac'], help='Audio format')
24 |
25 | args = parser.parse_args()
26 | data_dir = Path(args.data_dir)
27 | out_dir = Path(args.out_dir)
28 | partition = args.partition
29 | threshold = args.threshold
30 | ext = args.ext
31 |
32 | if partition == -1:
33 | audio_sources = list(range(100))
34 | audio_manifest = out_dir / f'music_jamendo.csv'
35 | else:
36 | audio_sources = list(range(10*partition, 10*partition+10))
37 | audio_manifest = out_dir / f'music_jamendo_{partition}.csv'
38 |
39 | with open(audio_manifest, 'w') as f:
40 | f.write('id,filepath,sr,length,start,end\n')
41 | audio_len = 0
42 | for audio_source in audio_sources:
43 | audio_dir = data_dir / f"{audio_source:0>2}"
44 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f"jamendo_{audio_source:0>2}"):
45 | audio_id = audio_filepath.stem
46 | x, sr = torchaudio.load(audio_filepath)
47 | length = x.shape[-1]
48 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30)
49 | utt_len = (trim30dBe - trim30dBs) / sr
50 | if utt_len >= threshold:
51 | audio_len += length / sr
52 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe)
53 | f.write(line)
54 | print('Source: music. audio len: {:.2f}h'.format(audio_len/3600))
55 |
56 |
--------------------------------------------------------------------------------
/prepare/mani_musan.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024 by Telecom-Paris
3 | Authored by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr)
4 | License agreement in LICENSE.txt
5 | """
6 | import os
7 | import argparse
8 | import librosa
9 | import torchaudio
10 | from tqdm import tqdm
11 | from pathlib import Path
12 | from collections import namedtuple
13 | import torch
14 |
15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset, MUSAN',
16 | add_help=True,
17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 |
19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/Libri-light/', help='Audio Dataset Path')
20 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest')
21 | parser.add_argument('--threshold', type=float, default=0.5, help='Remove audio files that are too short')
22 | parser.add_argument('--ext', type=str, default='wav', choices=['wav', 'mp3', 'flac'], help='Audio format')
23 | args = parser.parse_args()
24 | data_dir = Path(args.data_dir)
25 | out_dir = Path(args.out_dir)
26 | threshold = args.threshold
27 | ext = args.ext
28 |
29 | # speech
30 | audio_dir = data_dir / 'speech'
31 | audio_manifest = out_dir / f'speech_musan.csv'
32 | audio_len = 0
33 | with open(audio_manifest, 'w') as f:
34 | f.write('id,filepath,sr,length,start,end\n')
35 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'speech'):
36 | audio_id = audio_filepath.stem
37 | x, sr = torchaudio.load(audio_filepath)
38 | length = x.shape[-1]
39 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30)
40 | utt_len = (trim30dBe - trim30dBs) / sr
41 | if utt_len >= threshold:
42 | audio_len += length / sr
43 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe)
44 | f.write(line)
45 | print('Source: speech. audio len: {:.2f}h'.format(audio_len/3600))
46 |
47 | # music
48 | audio_dir = data_dir / 'music'
49 | audio_manifest = out_dir / f'music_musan.csv'
50 | audio_len = 0
51 | with open(audio_manifest, 'w') as f:
52 | f.write('id,filepath,sr,length,start,end\n')
53 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'music'):
54 | audio_id = audio_filepath.stem
55 | x, sr = torchaudio.load(audio_filepath)
56 | length = x.shape[-1]
57 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30)
58 | utt_len = (trim30dBe - trim30dBs) / sr
59 | if utt_len >= threshold:
60 | audio_len += length / sr
61 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe)
62 | f.write(line)
63 | print('Source: music. audio len: {:.2f}h'.format(audio_len/3600))
64 |
65 | # noise
66 | audio_dir = data_dir / 'noise'
67 | audio_manifest = out_dir / f'sfx_musan.csv'
68 | audio_len = 0
69 | with open(audio_manifest, 'w') as f:
70 | f.write('id,filepath,sr,length,start,end\n')
71 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'noise'):
72 | audio_id = audio_filepath.stem
73 | x, sr = torchaudio.load(audio_filepath)
74 | length = x.shape[-1]
75 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30)
76 | utt_len = (trim30dBe - trim30dBs) / sr
77 | if utt_len >= threshold:
78 | audio_len += length / sr
79 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe)
80 | f.write(line)
81 | print('Source: noise. audio len: {:.2f}h'.format(audio_len/3600))
--------------------------------------------------------------------------------
/prepare/mani_wham.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024 by Telecom-Paris
3 | Authoried by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr)
4 | License agreement in LICENSE.txt
5 | """
6 | import os
7 | import argparse
8 | import librosa
9 | import torchaudio
10 | from tqdm import tqdm
11 | from pathlib import Path
12 | from collections import namedtuple
13 | import torch
14 |
15 | parser = argparse.ArgumentParser(description='Generate manifest for audio dataset, WHAM! noise',
16 | add_help=True,
17 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 |
19 | parser.add_argument('--data-dir', type=str, default='/home/xbie/Data/high_res_wham/', help='Audio Dataset Path')
20 | parser.add_argument('--out-dir', type=str, default='./manifest', help='Path to write manifest')
21 | parser.add_argument('--threshold', type=float, default=0.5, help='Remove audio files that are too short')
22 | parser.add_argument('--ext', type=str, default='wav', choices=['wav', 'mp3', 'flac'], help='Audio format')
23 |
24 | args = parser.parse_args()
25 | data_dir = Path(args.data_dir)
26 | out_dir = Path(args.out_dir)
27 | threshold = args.threshold
28 | ext = args.ext
29 |
30 | # noise
31 | audio_dir = data_dir / 'audio'
32 | audio_manifest = out_dir / f'sfx_wham.csv'
33 | audio_len = 0
34 | with open(audio_manifest, 'w') as f:
35 | f.write('id,filepath,sr,length,start,end\n')
36 | for audio_filepath in tqdm(list(audio_dir.glob(f'**/*.{ext}')), desc=f'noise'):
37 | audio_id = audio_filepath.stem
38 | x, sr = torchaudio.load(audio_filepath)
39 | length = x.shape[-1]
40 | _, (trim30dBs,trim30dBe) = librosa.effects.trim(x.numpy(), top_db=30)
41 | utt_len = (trim30dBe - trim30dBs) / sr
42 | if utt_len >= threshold:
43 | audio_len += length / sr
44 | line = '{},{},{},{},{},{}\n'.format(audio_id, audio_filepath, sr, length, trim30dBs, trim30dBe)
45 | f.write(line)
46 | print('Source: noise. audio len: {:.2f}h'.format(audio_len/3600))
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # from .audio_dnr import DatasetDnR
2 | from .audio_dataset import DatasetAudioTrain, DatasetAudioVal
--------------------------------------------------------------------------------
/src/datasets/audio_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024 by Telecom-Paris
3 | Authored by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr)
4 | License agreement in LICENSE.txt
5 | """
6 |
7 | import math
8 | import numpy as np
9 | import pandas as pd
10 | import julius
11 | from pathlib import Path
12 | from omegaconf import DictConfig
13 | from typing import List
14 |
15 | import torch
16 | import torchaudio
17 | from torch.utils.data import Dataset
18 | from accelerate.logging import get_logger
19 |
20 | logger = get_logger(__name__)
21 |
22 |
23 |
24 | class DatasetAudioTrain(Dataset):
25 | def __init__(self,
26 | sample_rate: int,
27 | speech: List[str],
28 | music: List[str],
29 | sfx: List[str],
30 | n_examples: int = 10000000,
31 | chunk_size: float = 2.0,
32 | trim_silence: bool = False,
33 | **kwargs
34 | ) -> None:
35 | super().__init__()
36 |
37 | # init
38 | self.EPS = 1e-8
39 | self.sample_rate = sample_rate # target sampling rate
40 | self.length = n_examples # pseudo dataset length
41 | self.chunk_size = chunk_size # negative for entire sentence
42 | self.trim_silence = trim_silence
43 |
44 | # manifest
45 | self.csv_files = {}
46 | self.csv_files['speech'] = [Path(filepath) for filepath in speech]
47 | self.csv_files['music'] = [Path(filepath) for filepath in music]
48 | self.csv_files['sfx'] = [Path(filepath) for filepath in sfx]
49 |
50 | # check valid samples
51 | self.resample_pool = dict()
52 | self.metadata_dict = dict()
53 | self.lens_dict = dict()
54 | for track, files in self.csv_files.items():
55 | logger.info(f"Track: {track}")
56 | orig_utt, orig_len, drop_utt, drop_len = 0, 0, 0, 0
57 | metadata_list = []
58 | for tsv_filepath in files:
59 | if not tsv_filepath.is_file():
60 | logger.error('No tsv file found in: {}'.format(tsv_filepath))
61 | continue
62 | else:
63 | logger.info(f'Manifest filepath: {tsv_filepath}')
64 | metadata = pd.read_csv(tsv_filepath)
65 | if self.trim_silence:
66 | wav_lens = (metadata['end'] - metadata['start']) / metadata['sr']
67 | else:
68 | wav_lens = metadata['length'] / metadata['sr']
69 |                 # remove wav files that are too short
70 | orig_utt += len(metadata)
71 | drop_rows = []
72 | for row_idx in range(len(wav_lens)):
73 | orig_len += wav_lens[row_idx]
74 | if wav_lens[row_idx] < self.chunk_size:
75 | drop_rows.append(row_idx)
76 | drop_utt += 1
77 | drop_len += wav_lens[row_idx]
78 | else:
79 | # prepare julius resample
80 | sr = int(metadata.at[row_idx, 'sr'])
81 | if sr not in self.resample_pool.keys():
82 | old_sr = sr
83 | new_sr = self.sample_rate
84 | gcd = math.gcd(old_sr, new_sr)
85 | old_sr = old_sr // gcd
86 | new_sr = new_sr // gcd
87 | self.resample_pool[sr] = julius.ResampleFrac(old_sr=old_sr, new_sr=new_sr)
88 |
89 | metadata = metadata.drop(drop_rows)
90 | metadata_list.append(metadata)
91 |
92 | self.metadata_dict[track] = pd.concat(metadata_list, axis=0)
93 | self.lens_dict[track] = len(self.metadata_dict[track])
94 |
95 | logger.info("Drop {}/{} utterances ({:.2f}/{:.2f}h), shorter than {:.2f}s".format(
96 | drop_utt, orig_utt, drop_len / 3600, orig_len / 3600, self.chunk_size
97 | ))
98 | logger.info('Used data: {} utterances, ({:.2f} h)'.format(
99 | self.lens_dict[track], (orig_len-drop_len) / 3600
100 | ))
101 |
102 | logger.info('Resample pool: {}'.format(list(self.resample_pool.keys())))
103 |
104 |
105 | def __len__(self):
106 | return self.length # can be any number
107 |
108 |
109 | def __getitem__(self, idx:int):
110 |
111 | batch = {}
112 | for track in self.csv_files.keys():
113 | idx = np.random.randint(self.lens_dict[track])
114 | wav_info = self.metadata_dict[track].iloc[idx]
115 | chunk_len = int(wav_info['sr'] * self.chunk_size)
116 |
117 | # slice wav files
118 | if self.trim_silence:
119 | start = np.random.randint(int(wav_info['start']), int(wav_info['end']) - chunk_len + 1)
120 | else:
121 | start = np.random.randint(0, int(wav_info['length']) - chunk_len + 1)
122 |
123 | # load file
124 | x, sr = torchaudio.load(wav_info['filepath'],
125 | frame_offset=start,
126 | num_frames=chunk_len)
127 |
128 | # single channel
129 | x = x.mean(dim=0, keepdim=True)
130 |
131 | # resample
132 | if sr != self.sample_rate:
133 | x = self.resample_pool[sr](x)
134 |
135 | batch[track] = x
136 |
137 | return batch
138 |
139 |
140 |
141 | class DatasetAudioVal(Dataset):
142 | def __init__(self,
143 | sample_rate: int,
144 | tsv_filepath: str,
145 | chunk_size: float = 5.0,
146 | **kwargs
147 | ) -> None:
148 | super().__init__()
149 |
150 | # init
151 | self.EPS = 1e-8
152 | self.sample_rate = sample_rate # target sampling rate
153 | self.tsv_filepath = Path(tsv_filepath)
154 | self.chunk_size = chunk_size # negative for entire sentence
155 | self.resample_pool = dict()
156 |
157 | # read manifest tsv file
158 | if self.tsv_filepath.is_file():
159 | metadata = pd.read_csv(self.tsv_filepath)
160 | logger.info(f'Manifest filepath: {self.tsv_filepath}')
161 | else:
162 |             raise FileNotFoundError('No tsv file found in: {}'.format(self.tsv_filepath))
163 |
164 | # audio lengths
165 | wav_lens = (metadata['end'] - metadata['start']) / metadata['sr']
166 |
167 |         # remove wav files that are too short
168 | orig_utt = len(metadata)
169 | orig_len, drop_utt, drop_len = 0, 0, 0
170 | drop_rows = []
171 | for row_idx in range(len(wav_lens)):
172 | orig_len += wav_lens[row_idx]
173 | if wav_lens[row_idx] < self.chunk_size:
174 | drop_rows.append(row_idx)
175 | drop_utt += 1
176 | drop_len += wav_lens[row_idx]
177 | else:
178 | # prepare julius resample
179 | sr = int(metadata.at[row_idx, 'sr'])
180 | if sr not in self.resample_pool.keys():
181 | old_sr = sr
182 | new_sr = self.sample_rate
183 | gcd = math.gcd(old_sr, new_sr)
184 | old_sr = old_sr // gcd
185 | new_sr = new_sr // gcd
186 | self.resample_pool[sr] = julius.ResampleFrac(old_sr=old_sr, new_sr=new_sr)
187 |
188 | logger.info("Drop {}/{} utts ({:.2f}/{:.2f}h), shorter than {:.2f}s".format(
189 | drop_utt, orig_utt, drop_len / 3600, orig_len / 3600, self.chunk_size
190 | ))
191 | logger.info('Actual data size: {} utterance, ({:.2f} h)'.format(
192 | orig_utt-drop_utt, (orig_len-drop_len) / 3600
193 | ))
194 | logger.info('Resample pool: {}'.format(list(self.resample_pool.keys())))
195 |
196 | self.metadata = metadata.drop(drop_rows)
197 |
198 |
199 | def __len__(self):
200 | return len(self.metadata)
201 |
202 | def __getitem__(self, idx:int):
203 |
204 | wav_info = self.metadata.iloc[idx]
205 | chunk_len = int(wav_info['sr'] * self.chunk_size)
206 | start = wav_info['start']
207 |
208 | # Load wav files and resample if needed
209 | batch = {}
210 | for track in ['mix', 'speech', 'music', 'sfx']:
211 | x, sr = torchaudio.load(wav_info[track],
212 | frame_offset=start,
213 | num_frames=chunk_len)
214 | x = x.mean(dim=0, keepdim=True)
215 | # resample
216 | if sr != self.sample_rate:
217 | x = self.resample_pool[sr](x)
218 | batch[track] = x
219 |
220 | return batch
221 |
222 |
223 |
224 | if __name__ == '__main__':
225 | import time
226 | from tqdm import tqdm
227 | from torch.utils.data import DataLoader
228 | from accelerate import Accelerator
229 | from omegaconf import OmegaConf
230 | accelerator = Accelerator()
231 |
232 | # train
233 | cfg = OmegaConf.create()
234 | cfg.sample_rate = 16000
235 | cfg.speech = [
236 | './manifest/speech_dnr.csv'
237 | ]
238 | cfg.music = [
239 | './manifest/music_dnr.csv'
240 | ]
241 | cfg.sfx = [
242 | './manifest/sfx_dnr.csv'
243 | ]
244 | cfg.general = [
245 | './manifest/mix_dnr.csv'
246 | ]
247 | cfg.chunk_size = 2.0
248 | cfg.trim_silence = False
249 | train_dataset = DatasetAudioTrain(**cfg)
250 |
251 | print('Train data: {}'.format(len(train_dataset)))
252 |
253 | idx = np.random.randint(train_dataset.__len__())
254 | data_ = train_dataset.__getitem__(idx)
255 | for k, v in data_.items():
256 | print('audio idx: {} audio: {}, length: {}'.format(idx, k, v.shape))
257 |
258 | # val
259 | cfg = OmegaConf.create()
260 | cfg.sample_rate = 16000
261 | cfg.tsv_filepath = './manifest/val.csv'
262 | cfg.chunk_size = 5.0
263 | val_dataset = DatasetAudioVal(**cfg)
264 |
265 | print('Validation data: {}'.format(len(val_dataset)))
266 | idx = np.random.randint(val_dataset.__len__())
267 | data_ = val_dataset.__getitem__(idx)
268 | for k, v in data_.items():
269 | print('audio idx: {} audio: {}, length: {}'.format(idx, k, v.shape))
270 |
271 | # test
272 | cfg = OmegaConf.create()
273 | cfg.sample_rate = 16000
274 | cfg.tsv_filepath = './manifest/test.csv'
275 | cfg.chunk_size = 10.0
276 | test_dataset = DatasetAudioVal(**cfg)
277 |
278 | print('Test data: {}'.format(len(test_dataset)))
279 | idx = np.random.randint(test_dataset.__len__())
280 | data_ = test_dataset.__getitem__(idx)
281 | for k, v in data_.items():
282 | print('audio idx: {} audio: {}, length: {}'.format(idx, k, v.shape))
283 |
284 | # dataloader
285 | train_dataset.length = 10000
286 | train_loader = DataLoader(dataset=train_dataset,
287 | batch_size=32, num_workers=8,
288 | shuffle=True, drop_last=True)
289 |
290 | total_seq = 0
291 | start_time = time.time()
292 | for i, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
293 | mix_audio = batch['speech']
294 | total_seq += mix_audio.shape[0]
295 | # print(mix_audio.shape)
296 | # breakpoint()
297 |
298 | elapsed_time = time.time() - start_time
299 | tpf = elapsed_time / total_seq
300 | tpb = elapsed_time / (i+1)
301 |
302 | print(f"Read pure data time cost {tpf:.3f}s per file")
303 | print(f"Read pure data time cost {tpb:.3f}s per batch")
304 |
305 |
--------------------------------------------------------------------------------
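Note: `DatasetAudioTrain` and `DatasetAudioVal` above both cache one `julius.ResampleFrac` per source sampling rate, reducing the old/new ratio by its greatest common divisor so the fractional resampler stays small. A standalone sketch of that step (the 44.1 kHz to 16 kHz pair is only an example, not a statement about the training data):

import math
import torch
import julius

old_sr, new_sr = 44100, 16000
g = math.gcd(old_sr, new_sr)                               # 100
resampler = julius.ResampleFrac(old_sr // g, new_sr // g)  # ResampleFrac(441, 160)

x = torch.randn(1, 2 * old_sr)   # 2 s of mono audio at 44.1 kHz
y = resampler(x)
print(y.shape)                   # torch.Size([1, 32000]) -> 2 s at 16 kHz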
/src/metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.nn import L1Loss, MSELoss
3 | from .sdr import (
4 | SingleSrcNegSDR,
5 | )
6 | from .spectrum import (
7 | MultiScaleSTFTLoss,
8 | MelSpectrogramLoss,
9 | )
10 |
11 | from .adv import (
12 | GANLoss,
13 | )
14 |
15 | from .visqol import VisqolMetric
--------------------------------------------------------------------------------
/src/metrics/adv.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 | class GANLoss(nn.Module):
7 | """
8 | Computes a discriminator loss, given a discriminator on
9 | generated waveforms/spectrograms compared to ground truth
10 | waveforms/spectrograms. Computes the loss for both the
11 | discriminator and the generator in separate functions.
12 |
13 | Implementation modified from: h
14 | """
15 |
16 | def __init__(self, discriminator):
17 | super().__init__()
18 | self.discriminator = discriminator
19 |
20 | def forward(self, fake, real):
21 | d_fake = self.discriminator(fake)
22 | d_real = self.discriminator(real)
23 | return d_fake, d_real
24 |
25 | def discriminator_loss(self, fake, real):
26 | d_fake, d_real = self.forward(fake.clone().detach(), real)
27 |
28 | loss_d = 0
29 | for x_fake, x_real in zip(d_fake, d_real):
30 | loss_d += torch.mean(x_fake[-1] ** 2)
31 | loss_d += torch.mean((1 - x_real[-1]) ** 2)
32 | return loss_d
33 |
34 | def generator_loss(self, fake, real):
35 | d_fake, d_real = self.forward(fake, real)
36 |
37 | loss_g = 0
38 | for x_fake in d_fake:
39 | loss_g += torch.mean((1 - x_fake[-1]) ** 2)
40 |
41 | loss_feature = 0
42 |
43 | for i in range(len(d_fake)):
44 | for j in range(len(d_fake[i]) - 1):
45 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
46 | return loss_g, loss_feature
--------------------------------------------------------------------------------
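A minimal usage sketch of `GANLoss` paired with the `Discriminator` from `src/models/discriminator.py` (the batch size, sample rate and waveform length are arbitrary, and the imports assume the repository root is on `PYTHONPATH`):

import torch
from src.models.discriminator import Discriminator
from src.metrics.adv import GANLoss

disc = Discriminator(sample_rate=16000, rates=[], periods=[2, 3, 5, 7, 11], fft_sizes=[2048, 1024, 512])
gan_loss = GANLoss(disc)

real = torch.randn(2, 1, 16000)   # reference waveforms, B x 1 x T
fake = torch.randn(2, 1, 16000)   # reconstructed waveforms

# Least-squares GAN objectives: the discriminator pushes D(real) -> 1 and D(fake) -> 0,
# while the generator pushes D(fake) -> 1 and additionally matches intermediate feature maps (L1).
loss_d = gan_loss.discriminator_loss(fake, real)
loss_g, loss_feat = gan_loss.generator_loss(fake, real)
print(loss_d.item(), loss_g.item(), loss_feat.item())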
/src/metrics/sdr.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from torch.nn.modules.loss import _Loss
8 |
9 | class SingleSrcNegSDR(_Loss):
10 | r"""Base class for single-source negative SI-SDR, SD-SDR and SNR.
11 |
12 | Args:
13 | sdr_type (str): choose between ``snr`` for plain SNR, ``sisdr`` for
14 | SI-SDR and ``sdsdr`` for SD-SDR [1].
15 | zero_mean (bool, optional): by default it zero mean the ref and
16 | estimate before computing the loss.
17 | take_log (bool, optional): by default the log10 of sdr is returned.
18 | reduction (string, optional): Specifies the reduction to apply to
19 | the output:
20 | ``'none'``: no reduction will be applied,
21 | ``'sum'``: the sum of the output
22 | ``'mean'``: the sum of the output will be divided by the number of
23 | elements in the output.
24 |
25 | Shape:
26 | - ests : :math:`(batch, time)`.
27 | - refs: :math:`(batch, time)`.
28 |
29 | Returns:
30 | :class:`torch.Tensor`: with shape :math:`(batch)` if ``reduction='none'`` else
31 | [] scalar if ``reduction='mean'``.
32 |
33 | Examples
34 | >>> import torch
35 | >>> from asteroid.losses import PITLossWrapper
36 | >>> refs = torch.randn(10, 2, 32000)
37 | >>> ests = torch.randn(10, 2, 32000)
38 | >>> loss_func = PITLossWrapper(SingleSrcNegSDR("sisdr"),
39 | >>> pit_from='pw_pt')
40 | >>> loss = loss_func(ests, refs)
41 |
42 | References
43 | [1] Le Roux, Jonathan, et al. "SDR half-baked or well done." IEEE
44 | International Conference on Acoustics, Speech and Signal
45 | Processing (ICASSP) 2019.
46 |
47 |     Implementation modified from Asteroid project: https://github.com/asteroid-team/asteroid
48 | """
49 | def __init__(self,
50 | sdr_type : str,
51 | zero_mean: bool = True,
52 | take_log: bool = True,
53 | reduction: str ="mean",
54 | EPS: float = 1e-8):
55 | super().__init__(reduction=reduction)
56 |
57 | assert sdr_type in ["snr", "sisdr", "sdsdr"]
58 | self.sdr_type = sdr_type
59 | self.zero_mean = zero_mean
60 | self.take_log = take_log
61 |         self.EPS = EPS
62 |
63 | def forward(self, est, ref):
64 |
65 | assert est.shape == ref.shape, (
66 | 'expected same shape, get est: {}, ref: {} instead'.format(est.shape, ref.shape)
67 | )
68 | assert len(est.shape) == len(ref.shape) == 3, (
69 | 'expected BxCxN, get est: {}, ref: {} instead'.format(est.shape, ref.shape)
70 | )
71 |
72 | assert est.shape[1] == 1, (
73 | 'expected mono channel, get {} channels'.format(est.shape[1])
74 | )
75 |
76 | B, C, T = est.shape
77 | est = est.reshape(-1, T)
78 | ref = ref.reshape(-1, T)
79 |
80 | # Step 1. Zero-mean norm
81 | if self.zero_mean:
82 | mean_source = torch.mean(ref, dim=1, keepdim=True)
83 | mean_estimate = torch.mean(est, dim=1, keepdim=True)
84 | ref = ref - mean_source
85 | est = est - mean_estimate
86 | # Step 2. Pair-wise SI-SDR.
87 | if self.sdr_type in ["sisdr", "sdsdr"]:
88 | # [batch, 1]
89 | dot = torch.sum(est * ref, dim=1, keepdim=True)
90 | # [batch, 1]
91 | s_ref_energy = torch.sum(ref**2, dim=1, keepdim=True) + self.EPS
92 | # [batch, time]
93 | scaled_ref = dot * ref / s_ref_energy
94 | else:
95 | # [batch, time]
96 | scaled_ref = ref
97 | if self.sdr_type in ["sdsdr", "snr"]:
98 | e_noise = est - ref
99 | else:
100 | e_noise = est - scaled_ref
101 |
102 | self.cache = torch.cat((ref, est, scaled_ref, e_noise), dim=0).detach().cpu()
103 |
104 | # [batch]
105 | losses = torch.sum(scaled_ref**2, dim=1) / (torch.sum(e_noise**2, dim=1) + self.EPS)
106 |
107 | if self.take_log:
108 | losses = 10 * torch.log10(losses + self.EPS)
109 |
110 | if self.reduction == "mean":
111 | losses = losses.mean()
112 | elif self.reduction == "sum":
113 | losses = losses.sum()
114 | else:
115 | losses = losses
116 |
117 | return -losses
118 |
119 |
120 | SingleSISDRLoss = SingleSrcNegSDR("sisdr")
121 | SingleSDSDRLoss = SingleSrcNegSDR("sdsdr")
122 | SingleSNRLoss = SingleSrcNegSDR("snr")
123 |
--------------------------------------------------------------------------------
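A quick sanity check of `SingleSrcNegSDR` following the B x C x T shape convention asserted above (values are random; the import assumes the repository root is on `PYTHONPATH`):

import torch
from src.metrics.sdr import SingleSrcNegSDR

si_sdr_loss = SingleSrcNegSDR('sisdr', reduction='mean')

ref = torch.randn(4, 1, 16000)             # reference, B x 1 x T
est = ref + 0.1 * torch.randn_like(ref)    # estimate = reference + 10% noise

# forward() returns the *negative* SI-SDR so it can be minimized as a loss;
# -loss is the SI-SDR in dB (roughly 20 dB for this noise level).
loss = si_sdr_loss(est, ref)
print(f'SI-SDR: {-loss.item():.2f} dB')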
/src/metrics/spectrum.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import List, Callable
3 | from collections import namedtuple
4 | import numpy as np
5 | from librosa.filters import mel
6 |
7 | import torch
8 | from torch import nn
9 | from torch import Tensor
10 |
11 | STFTParams = namedtuple(
12 | "STFTParams",
13 | ["window_length", "hop_length", "window_type", "padding_type"],
14 | )
15 |
16 |
17 | def get_window(window_type: str, window_length: int, device: str):
18 | if window_type == "average":
19 | window = torch.ones(window_length) / window_length
20 | elif window_type == "sqrt_hann":
21 | window = torch.hann_window(window_length).sqrt()
22 | else:
23 | win_fn = getattr(torch, f'{window_type}_window')
24 | window = win_fn(window_length)
25 |
26 | window = window.to(device)
27 | return window
28 |
29 |
30 | class MultiScaleSTFTLoss(nn.Module):
31 | """Computes the multi-scale STFT loss from [1].
32 |
33 | Parameters
34 | ----------
35 | window_lengths : List[int], optional
36 | Length of each window of each STFT, by default [2048, 512]
37 | loss_fn : typing.Callable, optional
38 | How to compare each loss, by default nn.L1Loss()
39 | clamp_eps : float, optional
40 | Clamp on the log magnitude, below, by default 1e-5
41 | mag_weight : float, optional
42 | Weight of raw magnitude portion of loss, by default 1.0
43 | log_weight : float, optional
44 | Weight of log magnitude portion of loss, by default 1.0
45 | pow : float, optional
46 | Power to raise magnitude to before taking log, by default 2.0
47 | window_type : str, optional
48 | Type of window to use, by default ``hann``.
49 | match_stride : bool, optional
50 | Whether to match the stride of convolutional layers, by default False
51 |
52 | References
53 | ----------
54 |
55 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
56 | "DDSP: Differentiable Digital Signal Processing."
57 | International Conference on Learning Representations. 2019.
58 |
59 | Implementation modified from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
60 | """
61 |
62 | def __init__(
63 | self,
64 | window_lengths: List[int] = [2048, 512],
65 | loss_fn: Callable = nn.L1Loss(),
66 | clamp_eps: float = 1e-5,
67 | mag_weight: float = 1.0,
68 | log_weight: float = 1.0,
69 | pow: float = 2.0,
70 | window_type: str = "hann",
71 | padding_type: str = "reflect",
72 | ):
73 | super().__init__()
74 | self.stft_params = [
75 | STFTParams(
76 | window_length=w,
77 | hop_length=w // 4,
78 | window_type=window_type,
79 | padding_type=padding_type
80 | )
81 | for w in window_lengths
82 | ]
83 | self.loss_fn = loss_fn
84 | self.clamp_eps = clamp_eps
85 | self.mag_weight = mag_weight
86 | self.log_weight = log_weight
87 | self.pow = pow
88 |
89 |
90 | def forward(self, est: Tensor, ref : Tensor):
91 | """Computes multi-scale STFT between an estimate and a reference
92 | signal.
93 |
94 | Parameters
95 | ----------
96 | est : torch.Tensor [B, C, T]
97 | Estimate signal
98 | ref : torch.Tensor [B, C, T]
99 | Reference signal
100 |
101 | Returns
102 | -------
103 | torch.Tensor
104 | Multi-scale STFT loss.
105 | """
106 | device = est.device
107 | if ref.device != est.device:
108 |             ref = ref.to(device)
109 |
110 | assert est.shape == ref.shape, (
111 | 'expected same shape, get est: {}, ref: {} instead'.format(est.shape, ref.shape)
112 | )
113 | assert len(est.shape) == len(ref.shape) == 3, (
114 | 'expected BxCxN, get est: {}, ref: {} instead'.format(est.shape, ref.shape)
115 | )
116 |
117 | B, C, T = est.shape
118 | est = est.reshape(-1, T)
119 | ref = ref.reshape(-1, T)
120 |
121 | loss = 0.0
122 | for s in self.stft_params:
123 | est_spec = torch.stft(est, n_fft=s.window_length, hop_length=s.hop_length, win_length=s.window_length,
124 | window=get_window(s.window_type, s.window_length, device),
125 | pad_mode=s.padding_type, center=True, onesided=True, return_complex=True)
126 | ref_spec = torch.stft(ref, n_fft=s.window_length, hop_length=s.hop_length, win_length=s.window_length,
127 | window=get_window(s.window_type, s.window_length, device),
128 | pad_mode=s.padding_type, center=True, onesided=True, return_complex=True)
129 |
130 | loss += self.log_weight * self.loss_fn(
131 | est_spec.abs().clamp(self.clamp_eps).pow(self.pow).log10(),
132 | ref_spec.abs().clamp(self.clamp_eps).pow(self.pow).log10(),
133 | )
134 | loss += self.mag_weight * self.loss_fn(est_spec.abs(), ref_spec.abs())
135 | return loss
136 |
137 |
138 | class MelSpectrogramLoss(nn.Module):
139 | """Compute distance between mel spectrograms. Can be used
140 | in a multi-scale way.
141 |
142 | Parameters
143 | ----------
144 | n_mels : List[int]
145 | Number of mels per STFT, by default [150, 80],
146 | window_lengths : List[int], optional
147 | Length of each window of each STFT, by default [2048, 512]
148 | loss_fn : typing.Callable, optional
149 | How to compare each loss, by default nn.L1Loss()
150 | clamp_eps : float, optional
151 | Clamp on the log magnitude, below, by default 1e-5
152 | mag_weight : float, optional
153 | Weight of raw magnitude portion of loss, by default 1.0
154 | log_weight : float, optional
155 | Weight of log magnitude portion of loss, by default 1.0
156 | pow : float, optional
157 | Power to raise magnitude to before taking log, by default 2.0
158 | window_type : str, optional
159 | Type of window to use, by default ``hann``.
160 | match_stride : bool, optional
161 | Whether to match the stride of convolutional layers, by default False
162 |
163 | Implementation modified from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
164 | """
165 |
166 | def __init__(
167 | self,
168 | sr: int = 44100,
169 | n_mels: List[int] = [150, 80],
170 | mel_fmin: List[float] = [0.0, 0.0],
171 | mel_fmax: List[float] = [None, None],
172 | window_lengths: List[int] = [2048, 512],
173 | loss_fn: Callable = nn.L1Loss(),
174 | clamp_eps: float = 1e-5,
175 | mag_weight: float = 1.0,
176 | log_weight: float = 1.0,
177 | pow: float = 2.0,
178 | window_type: str = "hann",
179 | padding_type: str = "reflect",
180 | match_stride: bool = False,
181 | ):
182 | assert len(n_mels) == len(window_lengths) == len(mel_fmin) == len(mel_fmax), \
183 | f'lengths are different, n_mels: {n_mels}, window_lengths: {window_lengths}, mel_fmin: {mel_fmin}, mel_fmax: {mel_fmax}'
184 | super().__init__()
185 | self.stft_params = [
186 | STFTParams(
187 | window_length=w,
188 | hop_length=w // 4,
189 | window_type=window_type,
190 | padding_type=padding_type,
191 | )
192 | for w in window_lengths
193 | ]
194 | self.sr = sr
195 | self.n_mels = n_mels
196 | self.loss_fn = loss_fn
197 | self.clamp_eps = clamp_eps
198 | self.log_weight = log_weight
199 | self.mag_weight = mag_weight
200 | self.mel_fmin = mel_fmin
201 | self.mel_fmax = mel_fmax
202 | self.pow = pow
203 |
204 | def forward(self, est: Tensor, ref : Tensor):
205 | """Computes multi-scale mel loss between an estimate and
206 | a reference signal.
207 |
208 | Parameters
209 | ----------
210 | est : torch.Tensor [B, C, T]
211 | Estimate signal
212 | ref : torch.Tensor [B, C, T]
213 | Reference signal
214 |
215 | Returns
216 | -------
217 | torch.Tensor
218 | Multi-scale Mel loss.
219 | """
220 | device = est.device
221 | if ref.device != est.device:
222 |             ref = ref.to(device)
223 |
224 | assert est.shape == ref.shape, (
225 | 'expected same shape, get est: {}, ref: {} instead'.format(est.shape, ref.shape)
226 | )
227 | assert len(est.shape) == len(ref.shape) == 3, (
228 | 'expected BxCxN, get est: {}, ref: {} instead'.format(est.shape, ref.shape)
229 | )
230 |
231 | B, C, T = est.shape
232 | est = est.reshape(-1, T)
233 | ref = ref.reshape(-1, T)
234 |
235 | loss = 0.0
236 | for n_mels, fmin, fmax, s in zip(
237 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
238 | ):
239 | est_spec = torch.stft(est, n_fft=s.window_length, hop_length=s.hop_length, win_length=s.window_length,
240 | window=get_window(s.window_type, s.window_length, device),
241 | pad_mode=s.padding_type, center=True, onesided=True, return_complex=True)
242 | ref_spec = torch.stft(ref, n_fft=s.window_length, hop_length=s.hop_length, win_length=s.window_length,
243 | window=get_window(s.window_type, s.window_length, device),
244 | pad_mode=s.padding_type, center=True, onesided=True, return_complex=True)
245 |
246 | # convert to mel
247 | est_mag = est_spec.abs()
248 | ref_mag = ref_spec.abs()
249 |
250 | mel_basis = mel(sr=self.sr, n_fft=s.window_length, n_mels=n_mels, fmin=fmin, fmax=fmax)
251 | mel_basis = torch.from_numpy(mel_basis).to(device)
252 |
253 | est_mel = (est_mag.transpose(-1, -2) @ mel_basis.T).transpose(-1, -2)
254 | ref_mel = (ref_mag.transpose(-1, -2) @ mel_basis.T).transpose(-1, -2)
255 |
256 |
257 | loss += self.log_weight * self.loss_fn(
258 | est_mel.clamp(self.clamp_eps).pow(self.pow).log10(),
259 | ref_mel.clamp(self.clamp_eps).pow(self.pow).log10(),
260 | )
261 | loss += self.mag_weight * self.loss_fn(est_mel, ref_mel)
262 |
263 | return loss
--------------------------------------------------------------------------------
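Both spectral losses above share the same `(est, ref)` interface; a minimal sketch at 16 kHz (the window sizes, mel band counts and fmin/fmax values below are illustrative, not the training configuration):

import torch
from src.metrics.spectrum import MultiScaleSTFTLoss, MelSpectrogramLoss

stft_loss = MultiScaleSTFTLoss(window_lengths=[2048, 512])
mel_loss = MelSpectrogramLoss(sr=16000,
                              n_mels=[80, 40],
                              mel_fmin=[0.0, 0.0],
                              mel_fmax=[None, None],
                              window_lengths=[2048, 512])

ref = torch.randn(2, 1, 16000)              # reference, B x 1 x T
est = ref + 0.05 * torch.randn_like(ref)    # slightly degraded estimate

print(stft_loss(est, ref).item(), mel_loss(est, ref).item())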
/src/metrics/visqol.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import soxr
4 | import torch
5 | import numpy as np
6 | from visqol import visqol_lib_py
7 | from visqol.pb2 import visqol_config_pb2
8 | from visqol.pb2 import similarity_result_pb2
9 |
10 | class VisqolMetric():
11 |
12 | def __init__(self,
13 | mode='audio',
14 | reduction='mean'):
15 |
16 | self.reduction = reduction
17 | config = visqol_config_pb2.VisqolConfig()
18 | if mode == "audio":
19 | self.fs = 48000
20 | config.audio.sample_rate = self.fs
21 | config.options.use_speech_scoring = False
22 | svr_model_path = "libsvm_nu_svr_model.txt"
23 | elif mode == "speech":
24 | self.fs = 16000
25 | config.audio.sample_rate = self.fs
26 | config.options.use_speech_scoring = True
27 | svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
28 | else:
29 | raise ValueError(f"Unrecognized mode: {mode}")
30 |
31 | config.options.svr_model_path = os.path.join(
32 | os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path)
33 |
34 | self.api = visqol_lib_py.VisqolApi()
35 | self.api.Create(config)
36 |
37 | def __call__(self, est, ref, sr=44100):
38 | assert est.shape == ref.shape, 'expected same shape, get est: {}, ref: {} instead'.format(est.shape, ref.shape)
39 | assert len(est.shape) == len(ref.shape) == 3, 'expected BxCxN, get est: {}, ref: {} instead'.format(est.shape, ref.shape)
40 | B, C, T = est.shape
41 | est = est.reshape(B*C, -1).detach().cpu().numpy().astype(np.float64)
42 | ref = ref.reshape(B*C, -1).detach().cpu().numpy().astype(np.float64)
43 |
44 | if sr != self.fs:
45 | est_list = []
46 | ref_list = []
47 | for i in range(est.shape[0]):
48 | est_list.append(soxr.resample(est[i], sr, self.fs))
49 | ref_list.append(soxr.resample(ref[i], sr, self.fs))
50 | est = np.array(est_list)
51 | ref = np.array(ref_list)
52 |
53 | ret = []
54 | for i in range(est.shape[0]):
55 | ret.append(self.api.Measure(ref[i], est[i]).moslqo)
56 |
57 | if self.reduction == "mean":
58 | ret = np.mean(ret)
59 | elif self.reduction == "sum":
60 | ret = np.sum(ret)
61 | else:
62 | ret = ret
63 | return ret
64 |
65 |
66 | if __name__ == '__main__':
67 | eval_visqol = VisqolMetric(mode='audio')
68 |
69 | import soundfile as sf
70 | audio_name = '001_fragment_10.5'
71 | ref_audio_path = f'/home/xbie/Data/starnet/starnet_frag_16kHz/{audio_name}.wav'
72 | target_audio_path = ref_audio_path
73 | reference, fs_ref = sf.read(ref_audio_path)
74 | degraded, fs_target = sf.read(target_audio_path)
75 | print(f'Audio file ref: {fs_ref} Hz, target: {fs_target} Hz')
76 |
77 | est = torch.from_numpy(degraded).reshape(1, 1, -1)
78 | ref = torch.from_numpy(reference).reshape(1, 1, -1)
79 | metric_visqol = VisqolMetric(mode='audio')
80 | visqol = metric_visqol(est=est, ref=ref, sr=fs_ref)
81 | print(f'Visqol: {visqol:.2f}')
82 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .sdcodec import SDCodec
2 | from .discriminator import Discriminator
--------------------------------------------------------------------------------
/src/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import math
6 | import julius # a little bit worse than soxr, but compatible with torch and cuda
7 | from einops import rearrange
8 | from collections import namedtuple
9 |
10 | from ..modules import WNConv1d, WNConv2d
11 |
12 | def get_window(window_type: str, window_length: int, device: str):
13 | if window_type == "average":
14 | window = torch.ones(window_length) / window_length
15 | elif window_type == "sqrt_hann":
16 | window = torch.hann_window(window_length).sqrt()
17 | else:
18 | win_fn = getattr(torch, f'{window_type}_window')
19 | window = win_fn(window_length)
20 |
21 | window = window.to(device)
22 | return window
23 |
24 |
25 | class MSD(nn.Module):
26 | def __init__(self, rate: int = 1, sample_rate: int = 44100):
27 | super().__init__()
28 | self.convs = nn.ModuleList(
29 | [
30 | WNConv1d(1, 16, 15, 1, padding=7, act=True),
31 | WNConv1d(16, 64, 41, 4, groups=4, padding=20, act=True),
32 | WNConv1d(64, 256, 41, 4, groups=16, padding=20, act=True),
33 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20, act=True),
34 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20, act=True),
35 | WNConv1d(1024, 1024, 5, 1, padding=2, act=True),
36 | ]
37 | )
38 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
39 |
40 | # julius resample
41 | # https://adefossez.github.io/julius/julius/resample.html
42 | old_sr = sample_rate
43 | new_sr = sample_rate // rate
44 | gcd = math.gcd(old_sr, new_sr)
45 | old_sr = old_sr // gcd
46 | new_sr = new_sr // gcd
47 | self.resample = julius.ResampleFrac(old_sr=old_sr, new_sr=new_sr)
48 |
49 | def forward(self, x):
50 | x = self.resample(x)
51 | fmap = []
52 |
53 | for l in self.convs:
54 | x = l(x)
55 | fmap.append(x)
56 | x = self.conv_post(x)
57 | fmap.append(x)
58 |
59 | return fmap
60 |
61 |
62 | class MPD(nn.Module):
63 | def __init__(self, period):
64 | super().__init__()
65 | self.period = period
66 | self.convs = nn.ModuleList(
67 | [
68 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0), act=True),
69 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0), act=True),
70 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0), act=True),
71 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0), act=True),
72 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0), act=True),
73 | ]
74 | )
75 | self.conv_post = WNConv2d(
76 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
77 | )
78 |
79 | def pad_to_period(self, x):
80 | t = x.shape[-1]
81 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
82 | return x
83 |
84 | def forward(self, x):
85 | fmap = []
86 |
87 | x = self.pad_to_period(x)
88 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
89 |
90 | for layer in self.convs:
91 | x = layer(x)
92 | fmap.append(x)
93 |
94 | x = self.conv_post(x)
95 | fmap.append(x)
96 |
97 | return fmap
98 |
99 |
100 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
101 | STFTParams = namedtuple(
102 | "STFTParams",
103 | ["window_length", "hop_length"],
104 | )
105 |
106 | STFTParams = namedtuple(
107 | "STFTParams",
108 | ["window_length", "hop_length", "window_type", "padding_type", "match_stride"],
109 | )
110 |
111 | class MRD(nn.Module):
112 | def __init__(
113 | self,
114 | window_length: int,
115 | hop_factor: float = 0.25,
116 | sample_rate: int = 44100,
117 | bands: list = BANDS,
118 | ):
119 | """Complex multi-band spectrogram discriminator.
120 | Parameters
121 | ----------
122 | window_length : int
123 | Window length of STFT.
124 | hop_factor : float, optional
125 |             Hop factor of the STFT (``hop_length = hop_factor * window_length``), by default 0.25.
126 | sample_rate : int, optional
127 | Sampling rate of audio in Hz, by default 44100
128 | bands : list, optional
129 | Bands to run discriminator over.
130 | """
131 | super().__init__()
132 |
133 | self.window_length = window_length
134 | self.hop_factor = hop_factor
135 | self.sample_rate = sample_rate
136 | self.stft_params = STFTParams(
137 | window_length=window_length,
138 | hop_length=int(window_length * hop_factor),
139 | window_type='hann',
140 | padding_type='reflect',
141 | match_stride=True,
142 | )
143 |
144 | n_fft = window_length // 2 + 1
145 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
146 | self.bands = bands
147 |
148 | ch = 32
149 | convs = lambda: nn.ModuleList(
150 | [
151 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4), act=True),
152 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4), act=True),
153 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4), act=True),
154 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4), act=True),
155 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1), act=True),
156 | ]
157 | )
158 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
159 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
160 |
161 | def spectrogram(self, x):
162 | B, C, T = x.shape
163 | x = torch.stft(x.reshape(-1, T),
164 | n_fft=self.stft_params.window_length,
165 | hop_length=self.stft_params.hop_length,
166 | win_length=self.stft_params.window_length,
167 | window=torch.hann_window(self.stft_params.window_length).to(x.device),
168 | pad_mode='reflect', center=True, onesided=True, return_complex=True)
169 | x = torch.view_as_real(x)
170 | x = rearrange(x, "n f t c -> n c t f")
171 | # Split into bands
172 | x_bands = [x[..., b[0] : b[1]] for b in self.bands]
173 | return x_bands
174 |
175 | def forward(self, x):
176 | x_bands = self.spectrogram(x)
177 | fmap = []
178 |
179 | x = []
180 | for band, stack in zip(x_bands, self.band_convs):
181 | for layer in stack:
182 | band = layer(band)
183 | fmap.append(band)
184 | x.append(band)
185 |
186 | x = torch.cat(x, dim=-1)
187 | x = self.conv_post(x)
188 | fmap.append(x)
189 |
190 | return fmap
191 |
192 |
193 | class Discriminator(nn.Module):
194 | def __init__(
195 | self,
196 | sample_rate: int,
197 | rates: list = [],
198 | periods: list = [2, 3, 5, 7, 11],
199 | fft_sizes: list = [2048, 1024, 512],
200 | bands: list = BANDS,
201 | ):
202 | """Discriminator that combines multiple discriminators.
203 |
204 | Parameters
205 | ----------
206 | sample_rate : int, needed
207 | Sampling rate of audio in Hz
208 | rates : list, optional
209 |         MSD, downsampling rates (each scale resamples to ``sample_rate // rate``), by default []
210 | If empty, MSD is not used.
211 | periods : list, optional
212 | MPD, periods of samples, by default [2, 3, 5, 7, 11]
213 | fft_sizes : list, optional
214 | MRD, window sizes of the FFT, by default [2048, 1024, 512]
215 | bands : list, optional
216 | MRD, bands, by default `BANDS`
217 | """
218 | super().__init__()
219 | discs = []
220 | discs += [MSD(r, sample_rate=sample_rate) for r in rates]
221 | discs += [MPD(p) for p in periods]
222 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
223 | self.discriminators = nn.ModuleList(discs)
224 |
225 | def preprocess(self, y):
226 | # Remove DC offset
227 | y = y - y.mean(dim=-1, keepdims=True)
228 | # Peak normalize the volume of input audio
229 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
230 | return y
231 |
232 | def forward(self, x):
233 | x = self.preprocess(x)
234 | fmaps = [d(x) for d in self.discriminators]
235 | return fmaps
236 |
237 |
238 | if __name__ == "__main__":
239 | disc = Discriminator(sample_rate=44100,
240 | rates=[],
241 | periods=[2, 3, 5, 7, 11],
242 | fft_sizes=[2048, 1024, 512],
243 | bands=[[0.0, 0.1], [0.1, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1.0]]).to('cuda')
244 | x = torch.zeros(1, 1, 44100).to('cuda')
245 |
246 | total_params = sum(p.numel() for p in disc.parameters()) / 1e6
247 |     print(f'Total params: {total_params:.2f} M')
248 |
249 | results = disc(x)
250 | print(len(results))
251 |
252 | for i, result in enumerate(results):
253 | print(f"disc{i}")
254 | for i, r in enumerate(result):
255 | print(r.shape, r.mean(), r.min(), r.max())
256 | print()
257 |
--------------------------------------------------------------------------------
/src/models/sdcodec.py:
--------------------------------------------------------------------------------
1 | """
2 | ResDQ DAC-based Residual VQ-VAE model.
3 | """
4 | import math
5 | import numpy as np
6 | from typing import List, Union, Tuple
7 | from typing import Optional
8 | import random
9 | from omegaconf import DictConfig
10 |
11 | import torch
12 | from torch import nn
13 | import torch.nn.functional as F
14 | from accelerate.logging import get_logger
15 |
16 | from ..modules import CodecMixin
17 | from .. import modules
18 |
19 | logger = get_logger(__name__)
20 |
21 | def init_weights(m):
22 | if isinstance(m, nn.Conv1d):
23 | nn.init.trunc_normal_(m.weight, std=0.02) # default kaiming_uniform_
24 | nn.init.constant_(m.bias, 0) # default uniform_
25 |
26 |
27 | class SDCodec(nn.Module, CodecMixin):
28 | """Source-aware Disentangled Neural Audio Codec.
29 | Args:
30 |
31 | """
32 | def __init__(
33 | self,
34 | sample_rate: int,
35 | latent_dim: int = None,
36 | tracks: List[str] = ['speech', 'music', 'sfx'],
37 | enc_params: DictConfig = {'name': 'DACEncoder'},
38 | dec_params: DictConfig = {'name': 'DACDecoder'},
39 | quant_params: DictConfig = {'name': 'DACDecoder'},
40 | pretrain: dict = {},
41 | ):
42 | super().__init__()
43 |
44 | self.sample_rate = sample_rate
45 | self.tracks = tracks
46 | self.enc_params = enc_params
47 | self.dec_params = dec_params
48 | self.quant_params = quant_params
49 | self.pretrain = pretrain
50 |
51 |
52 | if latent_dim is None:
53 | latent_dim = self.enc_params.d_model * (2 ** len(self.enc_params.strides))
54 | self.latent_dim = latent_dim
55 | self.hop_length = np.prod(self.enc_params.strides)
56 |
57 |
58 | # Define encoder and decoder
59 | enc_net = getattr(modules, self.enc_params.pop('name'))
60 | dec_net = getattr(modules, self.dec_params.pop('name'))
61 | self.encoder = enc_net(**self.enc_params)
62 | self.decoder = dec_net(**self.dec_params)
63 |
64 | # Define quantizer
65 | quant_net = getattr(modules, self.quant_params.pop('name'))
66 | self.quantizer = quant_net(tracks=self.tracks, **self.quant_params)
67 |
68 | # Init
69 | self.apply(init_weights)
70 | self.delay = self.get_delay()
71 |
72 | # Load pretrained
73 | load_pretrained = self.pretrain.get('load_pretrained', False)
74 | if load_pretrained:
75 | pretrained_dict = torch.load(load_pretrained)
76 | hyparam = pretrained_dict['metadata']['kwargs']
77 | ignore_modules = self.pretrain.get('ignore_modules', [])
78 |             freeze_modules = self.pretrain.get('freeze_modules', [])
79 |
80 | is_match = self._check_hyparam(hyparam)
81 | if is_match:
82 | self._load_pretrained(pretrained_dict['state_dict'], ignore_modules)
83 |                 self._freeze(freeze_modules)
84 |                 logger.info('Pretrained model loaded successfully from {}'.format(load_pretrained))
85 |                 logger.info('-> modules ignored: {}'.format(ignore_modules))
86 |                 logger.info('-> modules frozen: {}'.format(freeze_modules))
87 | else:
88 |                 logger.info('Pretrained params do not match the model, loading pretrained weights failed...')
89 | logger.info('Pretrain params:')
90 | logger.info(hyparam)
91 | logger.info('Model params:')
92 | logger.info('Encoder: {}'.format(self.enc_params))
93 | logger.info('Decoder: {}'.format(self.dec_params))
94 |
95 | @property
96 | def device(self):
97 | """Gets the device the model is on by looking at the device of
98 | the first parameter. May not be valid if model is split across
99 | multiple devices.
100 | """
101 | return list(self.parameters())[0].device
102 |
103 |
104 | def _check_hyparam(self, hyparam):
105 | return (hyparam['encoder_dim'] == self.enc_params['d_model']) \
106 | and (hyparam['encoder_rates'] == self.enc_params['strides']) \
107 | and (hyparam['decoder_dim'] == self.dec_params['d_model']) \
108 | and (hyparam['sample_rate'] == self.sample_rate)
109 |
110 |
111 | def _load_pretrained(self, pretrained_state, ignored_modules=[]):
112 | own_state = self.state_dict()
113 | pretrained_state = {k: v for k, v in pretrained_state.items() if k.split('.')[0] not in ignored_modules}
114 | for k in own_state.keys():
115 | if k in pretrained_state:
116 | own_state[k] = pretrained_state[k]
117 | elif 'quantizer' not in ignored_modules:
118 | own_key_list = k.split('.')
119 | if own_key_list[0] == 'quantizer':
120 | if own_key_list[1] == 'jitter_dict':
121 | continue
122 | elif own_key_list[1] == 'shared_rvq':
123 | own_key_list.pop(1) # shared_rvq
124 | else:
125 | own_key_list.pop(1) # rvq_dict
126 | own_key_list.pop(1) # (speech, music, sfx)
127 | pretrained_key = '.'.join(own_key_list)
128 | if pretrained_key not in pretrained_state:
129 | print(k)
130 | print(pretrained_key)
131 | breakpoint()
132 | own_state[k] = pretrained_state[pretrained_key]
133 |
134 |
135 |     def _freeze(self, freeze_modules=[]):
136 |         for module in freeze_modules:
137 | child = getattr(self, module)
138 | for param in child.parameters():
139 | param.requires_grad = False
140 |
141 |
142 | def preprocess(self, audio_data, sample_rate) -> torch.Tensor:
143 | if sample_rate is None:
144 | sample_rate = self.sample_rate
145 | assert sample_rate == self.sample_rate
146 |
147 | length = audio_data.shape[-1]
148 | right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
149 | audio_data = F.pad(audio_data, (0, right_pad))
150 |
151 | return audio_data
152 |
153 |
154 | def encode(self, audio_data: torch.Tensor) -> torch.Tensor:
155 | """Encode given audio data and return quantized latent codes
156 |
157 | Parameters
158 | ----------
159 | audio_data : Tensor[B x 1 x T]
160 | Audio data to encode
161 |
162 | Returns
163 | -------
164 | "feats" : Tensor[B x D x T]
165 | Continuous features before quantization
166 | """
167 | return self.encoder(audio_data)
168 |
169 |
170 | def decode(self, z: torch.Tensor) -> torch.Tensor:
171 | """Decode given latent codes and return audio data
172 |
173 | Parameters
174 | ----------
175 | z : Tensor[B x D x T]
176 | Quantized continuous representation of input
177 |
178 | Returns
179 | -------
180 | "audio" : Tensor[B x 1 x length]
181 | Decoded audio data.
182 | """
183 | return self.decoder(z)
184 |
185 |
186 | def quantize(
187 | self,
188 | feats: torch.Tensor,
189 | track: str = 'speech',
190 | n_quantizers: int = None,
191 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
192 |         """Quantize continuous encoder features and return quantized latent codes
193 |
194 | Parameters
195 | ----------
196 | feats : Tensor[B x D x T]
197 | Continuous features before quantization
198 | track: str
199 | Specify which quantizer to be used
200 | n_quantizers : int, optional
201 | Number of quantizers to use, by default None
202 | If None, all quantizers are used.
203 |
204 | Returns
205 | -------
206 | "z" : Tensor[B x D x T]
207 | Quantized continuous representation of input
208 | "codes" : Tensor[B x N x T]
209 | Codebook indices for each codebook
210 | (quantized discrete representation of input)
211 | "latents" : Tensor[B x N*D' x T]
212 | Projected latents (continuous representation of input before quantization)
213 | "vq/commitment_loss" : Tensor[1]
214 | Commitment loss to train encoder to predict vectors closer to codebook
215 | entries
216 | "vq/codebook_loss" : Tensor[1]
217 | Codebook loss to update the codebook
218 | """
219 |         assert track in self.tracks, f'{track} not included in quantizer'
220 | z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
221 | track, feats, n_quantizers
222 | )
223 | return z, codes, latents, commitment_loss, codebook_loss
224 |
225 |
226 | def forward(
227 | self,
228 | batch: Optional[dict],
229 | sample_rate: int = None,
230 | n_quantizers: int = None,
231 | ):
232 | """Model forward pass
233 |
234 | Parameters
235 | ----------
236 | batch : dict of Tensor[B x 1 x T]
237 | Batch input with audio data
238 | sample_rate : int, optional
239 | Sample rate of audio data in Hz, by default None
240 | If None, defaults to `self.sample_rate`
241 | n_quantizers : int, optional
242 | Number of quantizers to use, by default None.
243 | If None, all quantizers are used.
244 | Returns
245 | -------
246 | dict
247 | A dictionary with the following keys:
248 | "track/z" : Tensor[B x D x T]
249 | Quantized continuous representation of input
250 | "track/codes" : Tensor[B x N x T]
251 | Codebook indices for each codebook
252 | (quantized discrete representation of input)
253 | "track/latents" : Tensor[B x N*D x T]
254 | Projected latents (continuous representation of input before quantization)
255 | "vq/commitment_loss" : Tensor[1]
256 | Commitment loss to train encoder to predict vectors closer to codebook
257 | entries
258 | "vq/codebook_loss" : Tensor[1]
259 | Codebook loss to update the codebook
260 | "length" : int
261 | Number of samples in input audio
262 | "ref" : Tensor[B x (K+1) x 1 x length]
263 | Decoded audio data.
264 | "recon" : Tensor[B x (K+1) x 1 x length]
265 | Decoded audio data.
266 | """
267 |
268 | # mix by masking
269 | audio_data = batch['mix']
270 | bs, _, length = audio_data.shape
271 | valid_tracks = batch['valid_tracks']
272 | mask = [1 if t in valid_tracks else 0 for t in self.tracks]
273 | mask = torch.tensor(mask, device=audio_data.device)
274 |
275 | # preprocess, zero-padding to proper length
276 | audio_data = self.preprocess(audio_data, sample_rate)
277 |
278 | # encoder
279 | feats = self.encode(audio_data)
280 |
281 | # quantize
282 | dict_z = {}
283 | dict_commit = {}
284 | dict_cb = {}
285 | # for i, track in enumerate(self.tracks):
286 | for i, track in enumerate(self.tracks):
287 | # quantize
288 | z, codes, latents, commitment_loss, codebook_loss = self.quantize(
289 | feats, track, n_quantizers
290 | )
291 | # ppl
292 | probs = F.one_hot(codes.detach(), num_classes=self.quant_params.codebook_size[i]).float() # (B, N, T, num_class)
293 | avg_probs = probs.mean(dim=(0,2)) # (N, num_class)
294 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)) # (N,)
295 |             # save results
296 | batch[f'{track}/z'] = z
297 | batch[f'{track}/codes'] = codes
298 | batch[f'{track}/latents'] = latents
299 | batch[f'{track}/ppl'] = perplexity
300 | dict_z[track] = z
301 | dict_commit[track] = commitment_loss
302 | dict_cb[track] = codebook_loss
303 |
304 | # commit and codebook loss
305 | commit_loss = (torch.stack([dict_commit[t] for t in self.tracks]) * mask).sum() / len(valid_tracks)
306 | cb_loss = (torch.stack([dict_cb[t] for t in self.tracks]) * mask).sum() / len(valid_tracks)
307 |
308 | # re-mix and decode
309 | if batch['random_swap']:
310 | z_mix = torch.stack([dict_z[t][batch['shuffle_list'][t]] for t in self.tracks])
311 | else:
312 | z_mix = torch.stack([dict_z[t] for t in self.tracks])
313 | z_mix = (z_mix * mask[:, None, None, None]).sum(dim=0)
314 | x_remix = self.decode(z_mix) # (B, 1, T)
315 |
316 | # collect data
317 | sep_out = [self.decode(dict_z[t]) for t in valid_tracks]
318 | audio_recon = torch.stack([x_remix] + sep_out, dim=1) # (B, K, 1, T)
319 |
320 | batch['recon'] = audio_recon[..., :length]
321 | batch['length'] = length
322 | batch['vq/commitment_loss'] = commit_loss
323 | batch['vq/codebook_loss'] = cb_loss
324 |
325 | return batch
326 |
327 |
328 | def evaluate(
329 | self,
330 | input_audio: torch.Tensor,
331 | sample_rate: int = None,
332 | n_quantizers: int = None,
333 | output_tracks: list[str] = ['mix'],
334 | ) -> torch.Tensor:
335 | """Model evaluation
336 | Parameters
337 | ----------
338 | input_audio : Tensor[B x 1 x T]
339 | Audio data to encode
340 | sample_rate : int, optional
341 | Sample rate of audio data in Hz, by default None
342 | If None, defaults to `self.sample_rate`
343 | n_quantizers : int, optional
344 | Number of quantizers to use, by default None.
345 | If None, all quantizers are used.
346 | output_tracks : List[str]
347 | List of track to return
348 |
349 | Returns
350 | -------
351 | output_audio : Tensor[B x K x T]
352 | Output audio with K tracks
353 | """
354 |         assert all((t in self.tracks) or (t == 'mix') for t in output_tracks), \
355 |             "output tracks {} not included in model tracks {}".format(output_tracks, self.tracks)
356 |
357 | bs, _, length = input_audio.shape
358 | audio_data = self.preprocess(input_audio, sample_rate) # (B, 1, T)
359 |
360 | # encoder
361 | feats = self.encode(audio_data)
362 |
363 | # quantization
364 | latent_dict = {}
365 | for track in self.tracks:
366 | z, codes, latents, commitment_loss, codebook_loss = self.quantize(
367 | feats, track, n_quantizers
368 | )
369 | latent_dict[track] = z
370 |
371 | # decoder
372 | list_out = []
373 | for track in output_tracks:
374 | if track == 'mix':
375 | z_mix = torch.stack(list(latent_dict.values()), dim=0).sum(dim=0)
376 | x_out = self.decode(z_mix)
377 | list_out.append(x_out)
378 | else:
379 | x_out = self.decode(latent_dict[track])
380 | list_out.append(x_out)
381 |
382 | output_audio = torch.cat(list_out, dim=1)[...,:length]
383 |
384 | return output_audio
385 |
386 |
387 |
388 | if __name__ == "__main__":
389 | from accelerate import Accelerator
390 | from omegaconf import OmegaConf
391 | import numpy as np
392 | from functools import partial
393 |
394 | accelerator = Accelerator()
395 | logger = get_logger(__name__)
396 |
397 | cfg = OmegaConf.load('/home/xbie/Code/ResDQ/config/model/sdcodec_16k.yaml')
398 | sample_rate = 16000
399 | tracks = ['speech', 'music', 'sfx']
400 | mask = [True, False, False]
401 | show_params = False
402 |
403 | model_name = cfg.pop('name')
404 | model = SDCodec(sample_rate=sample_rate, **cfg).to("cuda")
405 |
406 | if show_params:
407 | for n, m in model.named_modules():
408 | o = m.extra_repr()
409 | p = sum(p.numel() for p in m.parameters())
410 | fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
411 | setattr(m, "extra_repr", partial(fn, o=o, p=p))
412 | print(model)
413 | else:
414 | print("Total # of params: {:.2f} M".format(sum(p.numel() for p in model.parameters())/1e6))
415 |
416 | print('Show model device')
417 | print(model.encoder.block[0].weight.device)
418 | print(model.quantizer_dict['speech'].quantizers[0].in_proj.weight.device)
419 |
420 |
421 | # test on random input
422 | bs = 2
423 | length = 1
424 | batch = {}
425 | length = int(sample_rate * length)
426 | for t in tracks:
427 | batch[t] = torch.randn(bs, 1, length).to(model.device)
428 | batch[t].requires_grad_(True)
429 | batch[t].retain_grad()
430 |
431 | # Make a forward pass
432 | batch = model(batch,
433 | track_mask=mask,
434 | random_swap=True)
435 | print("Input shape:", batch['ref'].shape)
436 | print("Output shape:", batch['recon'].shape)
437 |
438 | # Create gradient variable
439 |     B, N, _, T = batch['recon'].shape
440 | model_out = batch['recon'].reshape(-1, 1, T)
441 | grad = torch.zeros_like(model_out)
442 | grad[:, :, grad.shape[-1] // 2] = 1
443 |
444 | # Make a backward pass
445 | model_out.backward(grad)
446 |
447 | # Check non-zero values
448 | gradmap = batch['speech'].grad.squeeze(0)
449 | gradmap = (gradmap != 0).sum(0) # sum across features
450 | rf = (gradmap != 0).sum()
451 |
452 | print(f"Receptive field: {rf.item()}")
453 |
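    # Minimal usage sketch of SDCodec.evaluate on a random mixture (shapes follow the
    # evaluate() docstring above): decode the re-mixed signal plus each disentangled
    # track in one call. Values here are illustrative, not a real test.
    with torch.no_grad():
        rand_mix = torch.randn(bs, 1, length, device=model.device)
        eval_out = model.evaluate(rand_mix, output_tracks=['mix'] + tracks)
    print("Evaluate output shape:", eval_out.shape)  # expected (2, 4, 16000)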
454 |
455 | ## test on real data
456 | import pandas as pd
457 | import torchaudio
458 |
459 | idx = 0
460 | tsv_filepath = '/home/xbie/Data/dnr_v2/val_mini_2s.tsv'
461 |
462 | metadata = pd.read_csv(tsv_filepath)
463 | wav_info = metadata.iloc[idx]
464 |
465 | x_speech, sr = torchaudio.load(wav_info['speech'])
466 | x_music, sr = torchaudio.load(wav_info['music'])
467 | x_sfx, sr = torchaudio.load(wav_info['sfx'])
468 | x_mix, sr = torchaudio.load(wav_info['mix'])
469 |
470 | x_speech = x_speech[..., wav_info['trim30dBs']: wav_info['trim30dBe']].reshape(1, 1, -1).to(model.device)
471 | x_music = x_music[..., wav_info['trim30dBs']: wav_info['trim30dBe']].reshape(1, 1, -1).to(model.device)
472 | x_sfx = x_sfx[..., wav_info['trim30dBs']: wav_info['trim30dBe']].reshape(1, 1, -1).to(model.device)
473 |
474 | batch = {'speech': x_speech,
475 | 'music': x_music,
476 | 'sfx': x_sfx}
477 |
478 | batch = model(batch,
479 | track_mask=mask,
480 | random_swap=True)
481 |
482 | print('Reference audio shape: {}'.format(batch['ref'].shape))
483 | print('Recon audio shape: {}'.format(batch['recon'].shape))
484 |
485 | # batch['ref'] = batch['ref'].detach().cpu()
486 | # batch['recon'] = batch['recon'].detach().cpu()
487 |
488 | # # for i, c in enumerate(['speech', 'music', 'sfx', 'mix']):
489 | # # path = '{}_{}_refRaw.wav'.format(wav_info['id'], c)
490 | # # torchaudio.save(path, batch['ref'][0, i], 44100)
491 |
492 | # # for i, c in enumerate(['speech', 'music', 'sfx', 'mix']):
493 | # # path = '{}_{}_reconRaw.wav'.format(wav_info['id'], c)
494 | # # torchaudio.save(path, batch['recon'][0, i], 44100)
495 | # # breakpoint()
496 | # breakpoint()
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | """Modules used for building the models."""
4 |
5 |
6 | from .layers import (
7 | WNConv1d,
8 | WNConv2d,
9 | WNConvTranspose1d,
10 | Snake1d,
11 | SLSTM,
12 | Jitter,
13 | )
14 |
15 | from .base_dac import (
16 | DACEncoder,
17 | DACDecoder,
18 | DACEncoderTrans,
19 | DACDecoderTrans,
20 | CodecMixin,
21 | )
22 |
23 |
24 | from .quantize import (
25 | VectorQuantize,
26 | ResidualVectorQuantize,
27 | MultiSourceRVQ,
28 | )
--------------------------------------------------------------------------------
/src/modules/base_dac.py:
--------------------------------------------------------------------------------
1 | """
2 | Basic NN module from DAC
3 | """
4 | import math
5 |
6 | import torch
7 | from torch import nn
8 |
9 | from .layers import Snake1d, WNConv1d, WNConvTranspose1d, TransformerSentenceEncoderLayer
10 |
11 |
12 | class ResidualUnit(nn.Module):
13 | def __init__(self, dim: int = 16, dilation: int = 1):
14 | super().__init__()
15 |         k = 7 # kernel size for the first conv
16 |         pad = ((k - 1) * dilation) // 2 # "same" padding: p = d * (k - 1) / 2
17 | self.block = nn.Sequential(
18 | Snake1d(dim),
19 | WNConv1d(dim, dim, kernel_size=k, dilation=dilation, padding=pad),
20 | Snake1d(dim),
21 | WNConv1d(dim, dim, kernel_size=1),
22 | )
23 |
24 | def forward(self, x):
25 | y = self.block(x)
26 |         pad = (x.shape[-1] - y.shape[-1]) // 2 # center-crop input if the block shortened the output
27 | if pad > 0:
28 | x = x[..., pad:-pad]
29 | return x + y
30 |
31 |
32 | class EncoderBlock(nn.Module):
33 | def __init__(self, dim: int = 16, stride: int = 1):
34 | super().__init__()
35 | self.block = nn.Sequential(
36 | ResidualUnit(dim // 2, dilation=1),
37 | ResidualUnit(dim // 2, dilation=3),
38 | ResidualUnit(dim // 2, dilation=9),
39 | Snake1d(dim // 2),
40 | WNConv1d(
41 | dim // 2,
42 | dim,
43 | kernel_size=2 * stride,
44 | stride=stride,
45 | padding=math.ceil(stride / 2),
46 | ),
47 | )
48 |
49 | def forward(self, x):
50 | return self.block(x)
51 |
52 |
53 | class DecoderBlock(nn.Module):
54 | def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
55 | super().__init__()
56 | self.block = nn.Sequential(
57 | Snake1d(input_dim),
58 | WNConvTranspose1d(
59 | input_dim,
60 | output_dim,
61 | kernel_size=2 * stride,
62 | stride=stride,
63 | padding=math.floor(stride / 2), # ceil() -> floor(), for 16kHz
64 | ),
65 | ResidualUnit(output_dim, dilation=1),
66 | ResidualUnit(output_dim, dilation=3),
67 | ResidualUnit(output_dim, dilation=9),
68 | )
69 |
70 | def forward(self, x):
71 | return self.block(x)
72 |
73 |
74 | class DACEncoder(nn.Module):
75 | def __init__(
76 | self,
77 | d_model: int = 64,
78 | strides: list = [2, 4, 8, 8],
79 | d_latent: int = 1024,
80 | ):
81 | super().__init__()
82 | # Create first convolution
83 | layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
84 |
85 |         # Create EncoderBlocks that double channels as they downsample by `stride`
86 | for stride in strides:
87 | d_model *= 2
88 | layers += [EncoderBlock(d_model, stride=stride)]
89 |
90 | # Create last convolution
91 | layers += [
92 | Snake1d(d_model),
93 | WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
94 | ]
95 |
96 |         # Wrap block into nn.Sequential
97 | self.block = nn.Sequential(*layers)
98 | self.enc_dim = d_model
99 |
100 | def forward(self, x):
101 | return self.block(x)
102 |
103 |
104 | class DACEncoderTrans(nn.Module):
105 | def __init__(
106 | self,
107 | d_model: int = 64,
108 | strides: list = [2, 4, 8, 8],
109 | att_d_model: int = 512,
110 | att_nhead: int = 8,
111 | att_ff: int = 2048,
112 | att_norm_first: bool = False,
113 | att_layers: int = 1,
114 | d_latent: int = 1024,
115 | ):
116 | super().__init__()
117 | # Create first convolution
118 | layers = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
119 |
120 |         # Create EncoderBlocks that double channels as they downsample by `stride`
121 | for stride in strides:
122 | d_model *= 2
123 | layers += [EncoderBlock(d_model, stride=stride)]
124 |
125 | # Create convolution
126 | layers += [
127 | Snake1d(d_model),
128 | WNConv1d(d_model, att_d_model, kernel_size=3, padding=1),
129 | ]
130 |
131 | # Create attention layers
132 | layers += [
133 | TransformerSentenceEncoderLayer(d_model=att_d_model, nhead=att_nhead,
134 | dim_feedforward=att_ff, norm_first=att_norm_first,
135 | num_layers=att_layers)
136 | ]
137 |
138 | # Create last convolution
139 | layers += [
140 | Snake1d(att_d_model),
141 | WNConv1d(att_d_model, d_latent, kernel_size=3, padding=1),
142 | ]
143 |
144 |         # Wrap block into nn.Sequential
145 | self.block = nn.Sequential(*layers)
146 | self.enc_dim = d_model
147 |
148 | def forward(self, x):
149 | return self.block(x)
150 |
151 |
152 | class DACDecoder(nn.Module):
153 | def __init__(
154 | self,
155 | d_model: int = 1536,
156 | strides: list = [8, 8, 4, 2],
157 | d_latent: int = 1024,
158 | d_out: int = 1,
159 | ):
160 | super().__init__()
161 |
162 | # Add first conv layer
163 | layers = [WNConv1d(d_latent, d_model, kernel_size=7, padding=3)]
164 |
165 | # Add upsampling + MRF blocks (from HiFi GAN)
166 | for stride in strides:
167 | layers += [DecoderBlock(d_model, d_model//2, stride)]
168 | d_model = d_model // 2
169 |
170 | # Add final conv layer
171 | layers += [
172 | Snake1d(d_model),
173 | WNConv1d(d_model, d_out, kernel_size=7, padding=3),
174 | nn.Tanh(),
175 | ]
176 |
177 | self.model = nn.Sequential(*layers)
178 |
179 | def forward(self, x):
180 | return self.model(x)
181 |
182 |
183 | class DACDecoderTrans(nn.Module):
184 | def __init__(
185 | self,
186 | d_model: int = 1536,
187 | strides: list = [8, 8, 4, 2],
188 | d_latent: int = 1024,
189 | att_d_model: int = 512,
190 | att_nhead: int = 8,
191 | att_ff: int = 2048,
192 | att_norm_first: bool = False,
193 | att_layers: int = 1,
194 | d_out: int = 1,
195 | ):
196 | super().__init__()
197 |
198 | # Add first conv layer
199 | layers = [WNConv1d(d_latent, att_d_model, kernel_size=7, padding=3)]
200 |
201 | # Add attention layer
202 | layers += [
203 | TransformerSentenceEncoderLayer(d_model=att_d_model, nhead=att_nhead,
204 | dim_feedforward=att_ff, norm_first=att_norm_first,
205 | num_layers=att_layers)
206 | ]
207 |
208 | # Add conv layer
209 | layers += [
210 | Snake1d(att_d_model),
211 | WNConv1d(att_d_model, d_model, kernel_size=7, padding=3)
212 | ]
213 |
214 | # Add upsampling + MRF blocks (from HiFi GAN)
215 | for stride in strides:
216 | layers += [DecoderBlock(d_model, d_model//2, stride)]
217 | d_model = d_model // 2
218 |
219 | # Add final conv layer
220 | layers += [
221 | Snake1d(d_model),
222 | WNConv1d(d_model, d_out, kernel_size=7, padding=3),
223 | nn.Tanh(),
224 | ]
225 |
226 | self.model = nn.Sequential(*layers)
227 |
228 | def forward(self, x):
229 | return self.model(x)
230 |
231 |
232 |
233 | class CodecMixin:
234 | """Truncated version of DAC CodecMixin
235 | """
236 | def get_delay(self):
237 | # Any number works here, delay is invariant to input length
238 | l_out = self.get_output_length(0)
239 | L = l_out
240 |
241 | layers = []
242 | for layer in self.modules():
243 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
244 | layers.append(layer)
245 |
246 | for layer in reversed(layers):
247 | d = layer.dilation[0]
248 | k = layer.kernel_size[0]
249 | s = layer.stride[0]
250 |
251 | if isinstance(layer, nn.ConvTranspose1d):
252 | L = ((L - d * (k - 1) - 1) / s) + 1
253 | elif isinstance(layer, nn.Conv1d):
254 | L = (L - 1) * s + d * (k - 1) + 1
255 |
256 | L = math.ceil(L)
257 |
258 | l_in = L
259 |
260 | return (l_in - l_out) // 2
261 |
262 | def get_output_length(self, input_length):
263 | L = input_length
264 | # Calculate output length
265 | for layer in self.modules():
266 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
267 | d = layer.dilation[0]
268 | k = layer.kernel_size[0]
269 | s = layer.stride[0]
270 |
271 | if isinstance(layer, nn.Conv1d):
272 | L = ((L - d * (k - 1) - 1) / s) + 1
273 | elif isinstance(layer, nn.ConvTranspose1d):
274 | L = (L - 1) * s + d * (k - 1) + 1
275 |
276 | L = math.floor(L)
277 | return L
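
if __name__ == "__main__":
    # Minimal sanity-check sketch for the default DACEncoder configuration: the hop
    # size is the product of the strides (2 * 4 * 8 * 8 = 512), so 1 s of 16 kHz
    # audio should map to 16000 // 512 = 31 latent frames.
    enc = DACEncoder(d_model=64, strides=[2, 4, 8, 8], d_latent=1024)
    x = torch.randn(1, 1, 16000)
    print(enc(x).shape)  # expected: torch.Size([1, 1024, 31])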
--------------------------------------------------------------------------------
/src/modules/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import TransformerEncoderLayer
4 | from torch.nn.utils.parametrizations import weight_norm
5 | from torch.distributions import Categorical
6 |
7 | def WNConv1d(*args, **kwargs):
8 | act = kwargs.pop("act", False)
9 | conv = weight_norm(nn.Conv1d(*args, **kwargs))
10 | if act:
11 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
12 | else:
13 | return conv
14 |
15 |
16 | def WNConvTranspose1d(*args, **kwargs):
17 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
18 |
19 |
20 |
21 | def WNConv2d(*args, **kwargs):
22 | act = kwargs.pop("act", False)
23 | conv = weight_norm(nn.Conv2d(*args, **kwargs))
24 | if act:
25 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
26 | else:
27 | return conv
28 |
29 | # Scripting this brings model speed up 1.4x
30 | @torch.jit.script
31 | def snake(x, alpha):
32 | shape = x.shape
33 | x = x.reshape(shape[0], shape[1], -1)
34 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
35 | x = x.reshape(shape)
36 | return x
37 |
38 | class Snake1d(nn.Module):
39 | def __init__(self, channels):
40 | super().__init__()
41 | self.alpha = nn.Parameter(torch.ones(1, channels, 1))
42 |
43 | def forward(self, x):
44 | return snake(x, self.alpha)
45 |
46 |
47 | class TransformerSentenceEncoderLayer(nn.Module):
48 | """
49 |     Stack of TransformerEncoderLayer blocks applied to the output of conv layers
50 | """
51 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
52 | norm_first=False, bias=True, num_layers=1):
53 | super().__init__()
54 |         layers = [TransformerEncoderLayer(d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout,
55 | batch_first=False, norm_first=norm_first, bias=bias)
56 | for _ in range(num_layers)]
57 | self.blocks = nn.Sequential(*layers)
58 |
59 | def forward(self, x):
60 | assert len(x.shape) == 3
61 | x = x.permute(2, 0, 1) # (bs, C, L) -> (L, bs, C)
62 | x = self.blocks(x)
63 | x = x.permute(1, 2, 0) # (L, bs, C) -> (bs, C, L)
64 | return x
65 |
66 |
67 | class SLSTM(nn.Module):
68 | """
69 | LSTM without worrying about the hidden state, nor the layout of the data.
70 | Expects input as convolutional layout.
71 | Modified from the EnCodec: https://github.com/facebookresearch/encodec/blob/main/encodec/modules/lstm.py
72 | """
73 | def __init__(self, dimension: int, num_layers: int = 2, bidirectional: bool = True, skip: bool = True):
74 | super().__init__()
75 | self.skip = skip
76 | self.bidirectional = bidirectional
77 | self.lstm = nn.LSTM(dimension, dimension, num_layers, bidirectional=self.bidirectional)
78 | if bidirectional:
79 | self.linear_ = nn.Linear(dimension*2, dimension)
80 |
81 | def forward(self, x):
82 | x = x.permute(2, 0, 1)
83 | y, _ = self.lstm(x)
84 | if self.bidirectional:
85 | y = self.linear_(y)
86 | if self.skip:
87 | y = y + x
88 | y = y.permute(1, 2, 0)
89 | return y
90 |
91 |
92 |
93 | class Jitter(nn.Module):
94 | """
95 |     Shuffle the input by randomly swapping codes with their neighbors
96 | Modified from the SQ-VAE speech: https://github.com/sony/sqvae/blob/main/speech/model.py#L76
97 | Args:
98 | p: probability to shuffle code
99 | size: kernel size to shuffle code
100 | """
101 | def __init__(self, p, size=3):
102 | super().__init__()
103 | self.p = p
104 | self.size = size
105 | prob = torch.ones(size) * p / (size - 1)
106 | prob[size//2] = 1 - p
107 | self.register_buffer("prob", prob)
108 |
109 | def forward(self, x):
110 | if not self.training or self.p == 0.0:
111 | return x
112 | else:
113 | batch_size, dim, T = x.size()
114 |
115 | dist = Categorical(probs=self.prob)
116 | index = dist.sample(torch.Size([batch_size, T])) - len(self.prob)//2
117 | index += torch.arange(T, device=x.device)
118 | index.clamp_(0, T-1)
119 | x = torch.gather(x, -1, index.unsqueeze(1).expand(-1, dim, -1))
120 |
121 | return x
122 |
123 |
124 | if __name__ == "__main__":
125 |
126 | jit = Jitter(p=0.8, size=5)
127 | print(jit.prob)
128 |
129 | x = torch.arange(20).reshape(1,1,20)
130 | x = torch.cat((x, x+100), dim=0)
131 | print(x)
132 | print(jit(x))
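
    # Quick shape checks for the other layers (an illustrative sketch): Snake1d and
    # SLSTM both preserve the convolutional (B, C, T) layout.
    snk = Snake1d(channels=4)
    slstm = SLSTM(dimension=4, num_layers=1)
    y = torch.randn(2, 4, 50)
    print(snk(y).shape, slstm(y).shape)  # both expected: torch.Size([2, 4, 50])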
--------------------------------------------------------------------------------
/src/modules/quantize.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from DAC: https://github.com/descriptinc/descript-audio-codec
3 | Corrects some argument descriptions
4 | """
5 | from typing import Union
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from einops import rearrange
12 |
13 | from .layers import WNConv1d, Jitter
14 |
15 |
16 | class VectorQuantize(nn.Module):
17 | """
18 | Implementation of VQ similar to Karpathy's repo:
19 | https://github.com/karpathy/deep-vector-quantization
20 | Additionally uses following tricks from Improved VQGAN
21 | (https://arxiv.org/pdf/2110.04627.pdf):
22 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
23 | for improved codebook usage
24 | 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
25 | improves training stability
26 | """
27 |
28 | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
29 | super().__init__()
30 | self.codebook_size = codebook_size
31 | self.codebook_dim = codebook_dim
32 |
33 | self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
34 | self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
35 | self.codebook = nn.Embedding(codebook_size, codebook_dim)
36 |
37 | def forward(self, z):
38 |         """Quantize the input tensor using a fixed codebook and return
39 | the corresponding codebook vectors
40 |
41 | Parameters
42 | ----------
43 | z : Tensor[B x D x T]
44 |
45 | Returns
46 | -------
47 | "z_q" : Tensor[B x D x T]
48 | Quantized continuous representation of input
49 | "vq/commitment_loss" : Tensor[1]
50 | Commitment loss to train encoder to predict vectors closer to codebook
51 | entries
52 | "vq/codebook_loss" : Tensor[1]
53 | Codebook loss to update the codebook
54 | "codes" : Tensor[B x T]
55 | Codebook indices (quantized discrete representation of input)
56 | "latents" : Tensor[B x D' x T]
57 | Projected latents (continuous representation of input before quantization)
58 | """
59 |
60 |         # Factorized codes (ViT-VQGAN): project input into a low-dimensional space
61 | z_e = self.in_proj(z) # z_e : (B x D x T)
62 | z_q, indices = self.decode_latents(z_e)
63 |
64 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
65 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
66 |
67 | z_q = (
68 | z_e + (z_q - z_e).detach()
69 | ) # noop in forward pass, straight-through gradient estimator in backward pass
70 |
71 | z_q = self.out_proj(z_q)
72 |
73 | return z_q, commitment_loss, codebook_loss, indices, z_e
74 |
75 | def embed_code(self, embed_id):
76 | return F.embedding(embed_id, self.codebook.weight)
77 |
78 | def decode_code(self, embed_id):
79 | return self.embed_code(embed_id).transpose(1, 2)
80 |
81 | def decode_latents(self, latents):
82 | encodings = rearrange(latents, "b d t -> (b t) d")
83 | codebook = self.codebook.weight # codebook: (N x D)
84 |
85 | # L2 normalize encodings and codebook (ViT-VQGAN)
86 | encodings = F.normalize(encodings)
87 | codebook = F.normalize(codebook)
88 |
89 | # Compute euclidean distance with codebook
90 | dist = (
91 | encodings.pow(2).sum(1, keepdim=True)
92 | - 2 * encodings @ codebook.t()
93 | + codebook.pow(2).sum(1, keepdim=True).t()
94 | )
95 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
96 | z_q = self.decode_code(indices)
97 | return z_q, indices
98 |
99 |
100 | class ResidualVectorQuantize(nn.Module):
101 | """
102 | Introduced in SoundStream: An end2end neural audio codec
103 | https://arxiv.org/abs/2107.03312
104 | """
105 |
106 | def __init__(
107 | self,
108 | input_dim: int = 1024,
109 | n_codebooks: int = 9,
110 | codebook_size: int = 1024,
111 | codebook_dim: Union[int, list] = 8,
112 | quantizer_dropout: float = 0.0,
113 | ):
114 | super().__init__()
115 | if isinstance(codebook_dim, int):
116 | codebook_dim = [codebook_dim for _ in range(n_codebooks)]
117 |
118 | self.n_codebooks = n_codebooks
119 | self.codebook_dim = codebook_dim
120 | self.codebook_size = codebook_size
121 |
122 | self.quantizers = nn.ModuleList(
123 | [
124 | VectorQuantize(input_dim, codebook_size, codebook_dim[i])
125 | for i in range(n_codebooks)
126 | ]
127 | )
128 | self.quantizer_dropout = quantizer_dropout
129 |
130 | def forward(self, z, n_quantizers: int = None):
131 |         """Quantize the input tensor using a fixed set of `n` codebooks and return
132 | the corresponding codebook vectors
133 | Parameters
134 | ----------
135 | z : Tensor[B x D x T]
136 | n_quantizers : int, optional
137 | No. of quantizers to use
138 | (n_quantizers < self.n_codebooks ex: for quantizer dropout)
139 | Note: if `self.quantizer_dropout` is True, this argument is ignored
140 | when in training mode, and a random number of quantizers is used.
141 |
142 | Returns
143 | -------
144 | "z_q" : Tensor[B x D x T]
145 | Quantized continuous representation of input
146 | "codes" : Tensor[B x N x T]
147 | Codebook indices for each codebook
148 | "latents" : Tensor[B x N*D' x T]
149 | Concatenated projected latents (continuous representation of input before quantization)
150 | "vq/commitment_loss" : Tensor[1]
151 | Commitment loss to train encoder to predict vectors closer to codebook entries
152 | "vq/codebook_loss" : Tensor[1]
153 | Codebook loss to update the codebook
154 | """
155 | z_q = 0
156 | residual = z
157 | commitment_loss = 0
158 | codebook_loss = 0
159 |
160 | codebook_indices = []
161 | latents = []
162 |
163 | if n_quantizers is None:
164 | n_quantizers = self.n_codebooks
165 | if self.training:
166 | n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
167 | dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
168 | n_dropout = int(z.shape[0] * self.quantizer_dropout)
169 | n_quantizers[:n_dropout] = dropout[:n_dropout]
170 | n_quantizers = n_quantizers.to(z.device)
171 |
172 | for i, quantizer in enumerate(self.quantizers):
173 | if self.training is False and i >= n_quantizers:
174 | break
175 |
176 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
177 | residual
178 | )
179 |
180 | # Create mask to apply quantizer dropout
181 | mask = (
182 | torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
183 | )
184 | z_q = z_q + z_q_i * mask[:, None, None]
185 | residual = residual - z_q_i
186 |
187 | # Sum losses
188 | commitment_loss += (commitment_loss_i * mask).mean()
189 | codebook_loss += (codebook_loss_i * mask).mean()
190 |
191 | codebook_indices.append(indices_i)
192 | latents.append(z_e_i)
193 |
194 | codes = torch.stack(codebook_indices, dim=1)
195 | latents = torch.cat(latents, dim=1)
196 |
197 | return z_q, codes, latents, commitment_loss, codebook_loss
198 |
199 | def from_codes(self, codes: torch.Tensor):
200 | """Given the quantized codes, reconstruct the continuous representation
201 | Parameters
202 | ----------
203 | codes : Tensor[B x N x T]
204 | Codebook indices for each codebook
205 |
206 | Returns
207 | -------
208 | "z_q" : Tensor[B x D x T]
209 | Quantized continuous representation of input
210 | "z_p" : Tensor[B x N*D' x T]
211 | Concatenated quantized codes before up-projection
212 | """
213 | z_q = 0.0
214 | z_p = []
215 | n_codebooks = codes.shape[1]
216 | for i in range(n_codebooks):
217 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
218 | z_p.append(z_p_i)
219 |
220 | z_q_i = self.quantizers[i].out_proj(z_p_i)
221 | z_q = z_q + z_q_i
222 | return z_q, torch.cat(z_p, dim=1), codes
223 |
224 | def from_latents(self, latents: torch.Tensor):
225 | """Given the unquantized latents, reconstruct the
226 | continuous representation after quantization.
227 |
228 | Parameters
229 | ----------
230 | latents : Tensor[B x N*D x T]
231 | Concatenated projected latents (continuous representation of input before quantization)
232 |
233 | Returns
234 | -------
235 | "z_q" : Tensor[B x D x T]
236 | Quantized representation of full-projected space
237 | "z_p" : Tensor[B x N*D' x T]
238 | Concatenated quantized codes before up-projection
239 | "codes" : Tensor[B x N x T]
240 | Codebook indices for each codebook
241 | """
242 | z_q = 0
243 | z_p = []
244 | codes = []
245 | dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
246 |
247 | n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
248 | 0
249 | ]
250 | for i in range(n_codebooks):
251 | j, k = dims[i], dims[i + 1]
252 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
253 | z_p.append(z_p_i)
254 | codes.append(codes_i)
255 |
256 | z_q_i = self.quantizers[i].out_proj(z_p_i)
257 | z_q = z_q + z_q_i
258 |
259 | return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
260 |
261 |
262 | class MultiSourceRVQ(nn.Module):
263 | """
264 |     Parallel RVQ for multiple sources
265 | """
266 | def __init__(
267 | self,
268 | tracks: list[str] = ['speech', 'music', 'sfx'],
269 | input_dim: int = 1024,
270 | n_codebooks: list[int] = [12, 12, 12],
271 | codebook_size: list[int] = [1024, 1024, 1024],
272 | codebook_dim: list[int] = [8, 8, 8],
273 | quantizer_dropout: float = 0.0,
274 | code_jit_prob: list[float] = [0.0, 0.0, 0.0],
275 | code_jit_size: list[int] = [3, 5, 3],
276 | shared_codebooks: int = 8,
277 | ):
278 | super().__init__()
279 | self.tracks = tracks
280 | self.input_dim = input_dim
281 | self.n_codebooks = n_codebooks
282 | self.codebook_size = codebook_size
283 | self.codebook_dim = codebook_dim
284 | self.quantizer_dropout = quantizer_dropout
285 | self.code_jit_prob = code_jit_prob
286 | self.code_jit_size = code_jit_size
287 | self.shared_codebooks = shared_codebooks
288 |
289 | self.rvq_dict = nn.ModuleDict()
290 | self.jitter_dict = nn.ModuleDict()
291 |
292 | for i, t in enumerate(self.tracks):
293 | self.rvq_dict[t] = ResidualVectorQuantize(
294 | input_dim=self.input_dim,
295 | n_codebooks=self.n_codebooks[i]-self.shared_codebooks,
296 | codebook_size=self.codebook_size[i],
297 | codebook_dim=self.codebook_dim[i],
298 | quantizer_dropout=self.quantizer_dropout,
299 | )
300 | self.jitter_dict[t] = Jitter(
301 | p=self.code_jit_prob[i],
302 | size=self.code_jit_size[i],
303 | )
304 |
305 | if shared_codebooks > 0:
306 | self.shared_rvq = ResidualVectorQuantize(
307 | input_dim=self.input_dim,
308 | n_codebooks=self.shared_codebooks,
309 | codebook_size=self.codebook_size[0],
310 | codebook_dim=self.codebook_dim[0],
311 | quantizer_dropout=self.quantizer_dropout,
312 | )
313 |
314 | def forward(self, track_name, feats, n_quantizers: int = None):
315 | assert track_name in self.tracks, '{} not in model tracks: {}'.format(track_name, self.tracks)
316 |
317 | # number of quantizer to be used
318 | if n_quantizers is None:
319 | n_quantizers = self.n_codebooks[self.tracks.index(track_name)]
320 |
321 | # quantize
322 | z, codes, latents, commitment_loss, codebook_loss = self.rvq_dict[track_name](
323 | feats, n_quantizers
324 | )
325 |
326 | # shared codebook
327 | if self.shared_codebooks > 0 and n_quantizers > self.rvq_dict[track_name].n_codebooks:
328 | z_shard, codes_shard, latents_shard, commitment_loss_shard, codebook_loss_shard = self.shared_rvq(
329 | feats-z, n_quantizers-self.rvq_dict[track_name].n_codebooks
330 | )
331 | z = z + z_shard
332 | codes = torch.cat((codes, codes_shard), dim=1) # (B, N, T)
333 | latents = torch.cat((latents, latents_shard), dim=1) # (B, N*D', T)
334 | commitment_loss += commitment_loss_shard
335 | codebook_loss += codebook_loss_shard
336 |
337 | # jitter, ignored if prob <= 0
338 | z = self.jitter_dict[track_name](z) # (B, D, T)
339 |
340 | return z, codes, latents, commitment_loss, codebook_loss
341 |
342 |
343 | if __name__ == "__main__":
344 | rvq = MultiSourceRVQ(quantizer_dropout=0.0)
345 | feats = torch.randn(16, 1024, 20)
346 | z_q, codes, latents, commitment_loss, codebook_loss = rvq('speech', feats)
347 | print(z_q.shape)
348 | print(codes.shape)
349 | print(latents.shape)
350 | print(commitment_loss)
351 | print(codebook_loss)
352 |
353 | # z_q, z_p, codes = rvq.from_latents(latents)
354 | # print(z_q[0,0])
355 | # print(codes[0,0])
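
    # Round-trip sketch (assumes the default layout produced above: track-specific
    # codes first, shared codes last, and jitter disabled): rebuild z_q from the
    # discrete codes alone with from_codes().
    n_track = rvq.rvq_dict['speech'].n_codebooks
    z_track, _, _ = rvq.rvq_dict['speech'].from_codes(codes[:, :n_track])
    z_shared, _, _ = rvq.shared_rvq.from_codes(codes[:, n_track:])
    print(torch.allclose(z_track + z_shared, z_q, atol=1e-4))  # expected: True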
--------------------------------------------------------------------------------
/src/optim/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Scheduler modified from AudioCraft project https://github.com/facebookresearch/audiocraft/tree/main
3 | """
4 |
5 | # flake8: noqa
6 | from .cosine_lr_scheduler import CosineLRScheduler
7 | from .exponential_lr_scheduler import ExponentialLRScheduler
8 | from .inverse_sqrt_lr_scheduler import InverseSquareRootLRScheduler
9 | from .linear_warmup_lr_scheduler import LinearWarmupLRScheduler
10 | from .polynomial_decay_lr_scheduler import PolynomialDecayLRScheduler
11 | from .reduce_plateau_lr_scheduler import ReducePlateauLRScheduler
--------------------------------------------------------------------------------
/src/optim/cosine_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | from torch.optim import Optimizer
10 | from torch.optim.lr_scheduler import _LRScheduler
11 |
12 |
13 | class CosineLRScheduler(_LRScheduler):
14 | """Cosine LR scheduler.
15 |
16 | Args:
17 | optimizer (Optimizer): Torch optimizer.
18 | warmup_steps (int): Number of warmup steps.
19 | total_steps (int): Total number of steps.
20 | lr_min_ratio (float): Minimum learning rate.
21 | cycle_length (float): Cycle length.
22 | """
23 | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
24 | lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
25 | self.warmup_steps = warmup_steps
26 | assert self.warmup_steps >= 0
27 | self.total_steps = total_steps
28 | assert self.total_steps >= warmup_steps
29 | self.lr_min_ratio = lr_min_ratio
30 | self.cycle_length = cycle_length
31 | super().__init__(optimizer)
32 |
33 | def _get_sched_lr(self, lr: float, step: int):
34 | if step < self.warmup_steps:
35 | lr_ratio = step / self.warmup_steps
36 | lr = lr_ratio * lr
37 | elif step <= self.total_steps:
38 | s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
39 | lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
40 | (1. + math.cos(math.pi * s / self.cycle_length))
41 | lr = lr_ratio * lr
42 | else:
43 | lr_ratio = self.lr_min_ratio
44 | lr = lr_ratio * lr
45 | return lr
46 |
47 | def get_lr(self):
48 | return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
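
if __name__ == "__main__":
    # Minimal usage sketch (parameter values are illustrative): linear warmup for
    # 10 steps, then cosine decay towards lr_min_ratio * base_lr over 100 steps.
    import torch

    opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
    sched = CosineLRScheduler(opt, total_steps=100, warmup_steps=10, lr_min_ratio=0.1)
    for _ in range(100):
        opt.step()
        sched.step()
    print(sched.get_last_lr()[0])  # ~1e-4, i.e. lr_min_ratio * base_lr at the last step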
--------------------------------------------------------------------------------
/src/optim/exponential_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | from torch.optim import Optimizer
10 | from torch.optim.lr_scheduler import _LRScheduler
11 |
12 |
13 | class ExponentialLRScheduler(_LRScheduler):
14 | """Exponential LR scheduler.
15 |
16 | Args:
17 | optimizer (Optimizer): Torch optimizer.
18 | warmup_steps (int): Number of warmup steps.
19 | total_steps (int): Total number of steps.
20 |         lr_min_ratio (float): Minimum learning rate ratio.
21 |         gamma (float): Exponential decay factor applied at each step.
22 | """
23 | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
24 | lr_min_ratio: float = 0.0, gamma: float = 0.99999):
25 | self.warmup_steps = warmup_steps
26 | assert self.warmup_steps >= 0
27 | self.total_steps = total_steps
28 | assert self.total_steps >= warmup_steps
29 | self.lr_min_ratio = lr_min_ratio
30 | self.cumprod = 1
31 | self.gamma = gamma
32 | super().__init__(optimizer)
33 |
34 | def _get_sched_lr(self, lr: float, step: int):
35 | if step < self.warmup_steps:
36 | lr_ratio = step / self.warmup_steps
37 | lr = lr_ratio * lr
38 | elif step <= self.total_steps:
39 | self.cumprod *= self.gamma
40 | lr_ratio = self.cumprod
41 | lr = lr_ratio * lr
42 | else:
43 | lr_ratio = self.lr_min_ratio
44 | lr = lr_ratio * lr
45 | return lr
46 |
47 | def get_lr(self):
48 | return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
--------------------------------------------------------------------------------
/src/optim/inverse_sqrt_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import typing as tp
8 |
9 | from torch.optim import Optimizer
10 | from torch.optim.lr_scheduler import _LRScheduler
11 |
12 |
13 | class InverseSquareRootLRScheduler(_LRScheduler):
14 | """Inverse square root LR scheduler.
15 |
16 | Args:
17 | optimizer (Optimizer): Torch optimizer.
18 | warmup_steps (int): Number of warmup steps.
19 | warmup_init_lr (tp.Optional[float]): Initial learning rate
20 | during warmup phase. When not set, use the provided learning rate.
21 | """
22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
23 | self.warmup_steps = warmup_steps
24 | self.warmup_init_lr = warmup_init_lr
25 | super().__init__(optimizer)
26 |
27 | def _get_sched_lr(self, lr: float, step: int):
28 | if step < self.warmup_steps:
29 | warmup_init_lr = self.warmup_init_lr or 0
30 | lr_step = (lr - warmup_init_lr) / self.warmup_steps
31 | lr = warmup_init_lr + step * lr_step
32 | else:
33 | decay_factor = lr * self.warmup_steps**0.5
34 | lr = decay_factor * step**-0.5
35 | return lr
36 |
37 | def get_lr(self):
38 | return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs]
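
if __name__ == "__main__":
    # Illustrative check of the closed form after warmup: lr(step) scales as
    # base_lr * sqrt(warmup_steps / step), so the LR is halved at 4x warmup_steps.
    import torch

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
    sched = InverseSquareRootLRScheduler(opt, warmup_steps=100)
    print(sched._get_sched_lr(1e-3, 400))  # expected: 5e-4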
--------------------------------------------------------------------------------
/src/optim/linear_warmup_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import typing as tp
8 |
9 | from torch.optim import Optimizer
10 | from torch.optim.lr_scheduler import _LRScheduler
11 |
12 |
13 | class LinearWarmupLRScheduler(_LRScheduler):
14 |     """Linear warmup LR scheduler.
15 |
16 | Args:
17 | optimizer (Optimizer): Torch optimizer.
18 | warmup_steps (int): Number of warmup steps.
19 | warmup_init_lr (tp.Optional[float]): Initial learning rate
20 | during warmup phase. When not set, use the provided learning rate.
21 | """
22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
23 | self.warmup_steps = warmup_steps
24 | self.warmup_init_lr = warmup_init_lr
25 | super().__init__(optimizer)
26 |
27 | def _get_sched_lr(self, lr: float, step: int):
28 | if step < self.warmup_steps:
29 | warmup_init_lr = self.warmup_init_lr or 0
30 | lr_step = (lr - warmup_init_lr) / self.warmup_steps
31 | lr = warmup_init_lr + step * lr_step
32 | return lr
33 |
34 | def get_lr(self):
35 | return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
--------------------------------------------------------------------------------
/src/optim/polynomial_decay_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from torch.optim import Optimizer
8 | from torch.optim.lr_scheduler import _LRScheduler
9 |
10 |
11 | class PolynomialDecayLRScheduler(_LRScheduler):
12 | """Polynomial decay LR scheduler.
13 |
14 | Args:
15 | optimizer (Optimizer): Torch optimizer.
16 | warmup_steps (int): Number of warmup steps.
17 | total_steps (int): Total number of steps.
18 | end_lr (float): Final learning rate to achieve over total number of steps.
19 | zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0.
20 | power (float): Decay exponent.
21 | """
22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int,
23 | end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.):
24 | self.warmup_steps = warmup_steps
25 | self.total_steps = total_steps
26 | self.end_lr = end_lr
27 | self.zero_lr_warmup_steps = zero_lr_warmup_steps
28 | self.power = power
29 | super().__init__(optimizer)
30 |
31 | def _get_sched_lr(self, lr: float, step: int):
32 | if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps:
33 | lr = 0
34 | elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps:
35 | lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps)
36 | lr = lr_ratio * lr
37 | elif step >= self.total_steps:
38 | lr = self.end_lr
39 | else:
40 | total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps
41 | lr_range = lr - self.end_lr
42 | pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps)
43 | lr = lr_range * pct_remaining ** self.power + self.end_lr
44 | return lr
45 |
46 | def get_lr(self):
47 | return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
--------------------------------------------------------------------------------
/src/optim/reduce_plateau_lr_scheduler.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import math
4 |
5 | from torch.optim import Optimizer
6 | from torch.optim.lr_scheduler import _LRScheduler
7 |
8 |
9 | class ReducePlateauLRScheduler(_LRScheduler):
10 |     """Reduce-on-plateau LR scheduler.
11 | Add warmup steps for torch.optim.lr_scheduler.ReduceLROnPlateau
12 |
13 | Args:
14 | optimizer (Optimizer): Torch optimizer.
15 | warmup_steps (int): Number of warmup steps.
16 | total_steps (int): Total number of steps.
17 | lr_min_ratio (float): Minimum learning rate.
18 | cycle_length (float): Cycle length.
19 | """
20 |
21 | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
22 | lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
23 | # TODO
24 | pass
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2024 by Telecom-Paris
3 | Authored by Xiaoyu BIE (xiaoyu.bie@telecom-paris.fr)
4 | License agreement in LICENSE.txt
5 | """
6 | import sys
7 | import shutil
8 | import numpy as np
9 | from pathlib import Path
10 | from omegaconf import DictConfig, OmegaConf
11 | from hydra.core.hydra_config import HydraConfig
12 | from collections import OrderedDict, defaultdict
13 | import random
14 | from einops import rearrange
15 |
16 | import torch
17 | from torch.utils.data import DataLoader
18 | import torch.nn.functional as F
19 | from accelerate import Accelerator
20 | from accelerate.logging import get_logger
21 | from accelerate.utils import set_seed
22 | from torch.distributions import Categorical
23 |
24 | from src import datasets, models, optim, utils
25 | from .metrics import (
26 | VisqolMetric,
27 | SingleSrcNegSDR,
28 | MultiScaleSTFTLoss,
29 | MelSpectrogramLoss,
30 | GANLoss,
31 | )
32 |
33 |
34 | logger = get_logger(__name__) # avoid duplicated prints, params defined by Hydra
35 | accelerator = Accelerator(project_dir=HydraConfig.get().runtime.output_dir,
36 | step_scheduler_with_optimizer=False,
37 | log_with="tensorboard")
38 |
39 | class Trainer(object):
40 | def __init__(self, cfg: DictConfig):
41 |
42 | # Init
43 | self.cfg = cfg
44 | OmegaConf.set_struct(self.cfg, False) # enable config.pop()
45 | self.project_dir = Path(accelerator.project_dir)
46 | self.device = accelerator.device
47 | logger.info('Init Trainer')
48 |
49 | # Fix random
50 | seed = self.cfg.training.get('seed', False)
51 | if seed:
52 | set_seed(seed)
53 |
54 | # Backup code, only on main process
55 | if self.cfg.backup_code and accelerator.is_main_process:
56 | back_dir = self.project_dir / 'backup_src'
57 | logger.info(f'Backup code at: {back_dir}')
58 | cwd = HydraConfig.get().runtime.cwd
59 | src_dir = Path(cwd) / 'src'
60 | if back_dir.exists():
61 | shutil.rmtree(back_dir)
62 | shutil.copytree(src=src_dir, dst=back_dir)
63 |
64 | # Checkpoint
65 | self.ckpt_best = self.project_dir / 'ckpt_best'
66 | self.ckptdir = self.project_dir / 'checkpoints'
67 | self.ckptdir.mkdir(exist_ok=True)
68 | self.ckpt_final = self.project_dir / 'ckpt_final'
69 | self.ckpt_final.mkdir(exist_ok=True)
70 |
71 | # Check resume
72 | if self.cfg.resume:
73 | resume_dir = self.cfg.get('resume_dir', None)
74 | if resume_dir is None:
75 | logger.info(f'No resume_dir provided, try to resume from best ckpt directory...')
76 | self.ckpt_resume = self.ckpt_best
77 | else:
78 | self.ckpt_resume = self.project_dir / resume_dir
79 | if not self.ckpt_best.is_dir():
80 | self.cfg.resume = False
81 | logger.info(f'Resume FAILED, no ckpt dir at: {self.ckpt_best}')
82 |
83 | # Tensorboard tracker
84 | accelerator.init_trackers(project_name='tb')
85 | self.tracker = accelerator.get_tracker("tensorboard")
86 | logger.info('Tracker backend: tensorboard')
87 |
88 | # Prepare dataset
89 | self.sr = self.cfg.sampling_rate
90 | logger.info('=====> Training dataloader')
91 | self.train_loader = self._get_data(self.cfg.dataset.trainset_cfg, self.cfg.training.dataloader,
92 | is_train=True, sample_rate=self.sr)
93 | logger.info('=====> Validation dataloader')
94 | self.val_loader = self._get_data(self.cfg.dataset.valset_cfg, self.cfg.training.dataloader,
95 | is_train=False, sample_rate=self.sr)
96 | logger.info('=====> Test dataloader')
97 | self.test_loader = self._get_data(self.cfg.dataset.testset_cfg, self.cfg.training.dataloader,
98 | is_train=False, sample_rate=self.sr)
99 |
100 | # Prepare generator
101 | model_name = self.cfg.model.pop('name')
102 | optim_name = self.cfg.training.optimizer.pop('name')
103 | scheduler_name = self.cfg.training.scheduler.pop('name')
104 | self.model = self._get_model(model_name, self.cfg.model, self.sr)
105 | self.optimizer_g = self._get_optim(self.model.parameters(), optim_name, self.cfg.training.optimizer)
106 | self.scheduler_g = self._get_scheduler(self.optimizer_g, scheduler_name, self.cfg.training.scheduler)
107 |
108 | # Prepare discriminator
109 | dis_name = self.cfg.discriminator.pop('name')
110 | self.discriminator = self._get_model(dis_name, self.cfg.discriminator, sample_rate=self.sr)
111 | self.optimizer_d = self._get_optim(self.discriminator.parameters(), optim_name, self.cfg.training.optimizer)
112 | self.scheduler_d = self._get_scheduler(self.optimizer_d, scheduler_name, self.cfg.training.scheduler)
113 |
114 | # Prepare training recording
115 | self.tm = utils.TrainMonitor()
116 |
117 | # Accelerator preparation
118 | self._acc_prepare()
119 |
120 | # Define the loss function
121 | self._metric_prepare(self.cfg.training.loss)
122 | self.lambdas = self.cfg.training.loss.lambdas
123 |
124 | # Define the audio transform function
125 | self._transform_prepare(self.cfg.training.transform)
126 |
127 | # Resume
128 | if self.cfg.resume:
129 | logger.info(f'Resume training from: {self.ckpt_resume}')
130 | accelerator.load_state(self.ckpt_resume)
131 | self.tm.nb_step += 1
132 | logger.info(f'Training re-start from iter: {self.tm.nb_step}')
133 | else:
134 | logger.info(f'Experiment workdir: {self.project_dir}')
135 | logger.info(f'num_processes: {accelerator.num_processes}')
136 | logger.info(f'batch size per gpu for train: {self.cfg.training.dataloader.train_bs}')
137 | logger.info(f'batch size per gpu for validation: {self.cfg.training.dataloader.eval_bs}')
138 | logger.info(f'mixed_precision: {accelerator.mixed_precision}')
139 | # logger.info(OmegaConf.to_yaml(self.cfg))
140 | logger.info('Trainer init finish')
141 |
142 | # Basic info
143 | self.tracks = self.cfg.model.tracks # ['speech', 'music', 'sfx']
144 | self.eval_tracks = ['mix_rec'] + [f'{t}_rec' for t in self.tracks] + \
145 | [f'{t}_sep' for t in self.tracks] + [f'{t}_sep_mask' for t in self.tracks]
146 | logger.info('Used audio tracks: {}'.format(self.tracks))
147 | logger.info('Eval audio tracks: {}'.format(self.eval_tracks))
148 | logger.info(f'Target sampling rate: {self.sr} Hz')
149 | logger.info(f'Random seed: {seed}')
150 |
151 |
152 | def _get_data(self, dataset_cfg, dataloader_cfg, is_train=True, sample_rate=44100):
153 |
154 | num_workers = dataloader_cfg.num_workers
155 |
156 | if is_train:
157 | batch_size = dataloader_cfg.train_bs
158 | shuffle = True
159 | drop_last = True
160 | data_class = getattr(datasets, f'DatasetAudioTrain')
161 | else:
162 | batch_size = dataloader_cfg.eval_bs
163 | shuffle = False
164 | drop_last = False
165 | data_class = getattr(datasets, f'DatasetAudioVal')
166 |
167 | dataset = data_class(sample_rate=sample_rate, **dataset_cfg)
168 | dataloader = DataLoader(dataset=dataset,
169 | batch_size=batch_size, num_workers=num_workers,
170 | shuffle=shuffle, drop_last=drop_last)
171 | return dataloader
172 |
173 |
174 | def _get_model(self, model_name, model_cfg, sample_rate=44100):
175 | logger.info(f"Model: {model_name}")
176 | net_class = getattr(models, f'{model_name}')
177 | model = net_class(sample_rate=sample_rate, **model_cfg)
178 | total_params = sum(p.numel() for p in model.parameters()) / 1e6
179 |         logger.info(f'Total params: {total_params:.2f} M')
180 |         total_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
181 |         logger.info(f'Total trainable params: {total_train_params:.2f} M')
182 | return model
183 |
184 |
185 | def _get_optim(self, params, optim_name, optim_cfg):
186 | logger.info(f"Optimizer: {optim_name}")
187 | optim_class = getattr(torch.optim, optim_name)
188 | optimizer = optim_class(filter(lambda p: p.requires_grad, params), **optim_cfg)
189 | return optimizer
190 |
191 |
192 | def _get_scheduler(self, optimizer, scheduler_name, scheduler_cfg):
193 | logger.info(f"Scheduler: {scheduler_name}")
194 | sche_class = getattr(optim, scheduler_name)
195 | scheduler = sche_class(optimizer, **scheduler_cfg)
196 | return scheduler
197 |
198 |
199 | def _metric_prepare(self, loss_cfg):
200 | self.stft_loss = MultiScaleSTFTLoss(**loss_cfg.MultiScaleSTFTLoss)
201 | self.mel_loss = MelSpectrogramLoss(**loss_cfg.MelSpectrogramLoss)
202 | self.gan_loss = GANLoss(self.discriminator)
203 | self.eval_sisdr = SingleSrcNegSDR(sdr_type='sisdr')
204 | # self.eval_visqol = VisqolMetric(mode='speech', reduction=None)
205 |
206 |
207 | def _transform_prepare(self, transform_cfg):
208 | self.volume_norm = utils.VolumeNorm(sample_rate=self.sr)
209 | self.sep_norm = utils.WavSepMagNorm()
210 | self.peak_norm = utils.db_to_gain(transform_cfg.peak_norm_db)
211 |
212 |
213 | def _acc_prepare(self):
214 | self.model = accelerator.prepare(self.model)
215 | self.discriminator = accelerator.prepare(self.discriminator)
216 | self.optimizer_g = accelerator.prepare(self.optimizer_g)
217 | self.optimizer_d = accelerator.prepare(self.optimizer_d)
218 | self.scheduler_g = accelerator.prepare(self.scheduler_g)
219 | self.scheduler_d = accelerator.prepare(self.scheduler_d)
220 | self.train_loader = accelerator.prepare(self.train_loader)
221 | self.val_loader = accelerator.prepare(self.val_loader)
222 | self.test_loader = accelerator.prepare(self.test_loader)
223 | accelerator.register_for_checkpointing(self.tm)
224 | self.model_eval_func = self.model.module.evaluate if accelerator.use_distributed \
225 | else self.model.evaluate
226 | logger.info('{} iterations per epoch'.format(len(self.train_loader)))
227 |
228 |
229 | def _data_transform(self, batch, transform_cfg, nb_step=-1, is_eval=False):
230 |
231 | # re-build data
232 | if is_eval:
233 | batch['valid_tracks'] = self.tracks
234 | norm_var = 0
235 | else:
236 | # random drop 0-2 tracks
237 | dist = Categorical(probs=torch.tensor(transform_cfg.random_num_sources))
238 | num_sources = dist.sample() + 1
239 | batch['valid_tracks'] = random.sample(self.tracks, num_sources)
240 | norm_var = transform_cfg.lufs_norm_db['var']
241 |
242 | batch['in_sources'] = len(batch['valid_tracks'])
243 | # build mix
244 | mix_max_peak = torch.zeros_like(batch['speech'])[...,:1] # (bs, C, 1)
245 | for track in batch['valid_tracks']:
246 | # volume norm
247 | batch[track] = self.volume_norm(signal=batch[track],
248 | target_loudness=transform_cfg.lufs_norm_db[track],
249 | var=norm_var)
250 | # peak value
251 | peak = batch[track].abs().max(dim=-1, keepdims=True)[0]
252 | mix_max_peak = torch.maximum(peak, mix_max_peak)
253 |
254 | # peak norm
255 | peak_gain = torch.ones_like(mix_max_peak) # (bs, C, 1)
256 | peak_gain[mix_max_peak > self.peak_norm] = self.peak_norm / mix_max_peak[mix_max_peak > self.peak_norm]
257 | batch['mix'] = torch.zeros_like(batch['speech'])
258 | for track in batch['valid_tracks']:
259 | batch[track] *= peak_gain
260 | batch['mix'] += batch[track]
261 |         # mix volume norm
262 | batch['mix'], mix_gain = self.volume_norm(signal=batch['mix'],
263 | target_loudness=transform_cfg.lufs_norm_db['mix'],
264 | var=norm_var,
265 | return_gain=True)
266 |
267 | # norm each track
268 | for track in batch['valid_tracks']:
269 | batch[track] *= mix_gain[:, None, None]
270 |
271 | # random swap tracks
272 | batch['random_swap'] = (not is_eval) and (random.random() < transform_cfg.random_swap_prob)
273 | if batch['random_swap']:
274 | bs = batch['mix'].shape[0]
275 | mix_ref = torch.zeros_like(batch['mix'])
276 | batch['shuffle_list'] = {}
277 | for track in self.tracks:
278 | shuffle_list = list(range(bs))
279 | random.shuffle(shuffle_list)
280 | batch['shuffle_list'][track] = shuffle_list
281 | if track in batch['valid_tracks']:
282 | mix_ref += batch[track][shuffle_list]
283 | else:
284 | mix_ref = batch['mix'].clone()
285 |
286 | batch['ref'] = torch.stack([mix_ref]+[batch[t].clone() for t in batch['valid_tracks']], dim=1) # (B, K, C, T)
287 |
288 | return batch
289 |
290 |
291 | def _print_logs(self, log_dict, title='Train', nb_step=0, use_tracker=True):
292 | msg = f"{title} iter {nb_step:d}"
293 |
294 | if title == 'Train':
295 | for k, v in log_dict.items():
296 | k = '/'.join(k.split('/')[1:])
297 | if k == 'lr':
298 | msg += f' {k}: {v:.8f}'
299 | else:
300 | msg += f' {k}: {v:.2f}'
301 | logger.info(msg)
302 | else:
303 | logger.info(msg)
304 | for c in self.eval_tracks:
305 | msg = f"--> {c}:"
306 | select_keys = filter(lambda k: k.split('/')[0] == f'eval_{c}'
307 | or k.split('/')[0] == f'test_{c}', log_dict.keys())
308 | for k in select_keys:
309 | v = log_dict[k]
310 | k = '/'.join(k.split('/')[1:])
311 | msg += ' {}: {:.2f}'.format(k, v)
312 | logger.info(msg)
313 |
314 | if use_tracker:
315 | self.tracker.log(log_dict, step=nb_step) # tracker automatically discard plt and audio
316 |
317 |
318 | def run(self):
319 | total_steps = self.cfg.training.total_steps
320 | print_steps = self.cfg.training.print_steps
321 | eval_steps = self.cfg.training.eval_steps
322 | vis_steps = self.cfg.training.vis_steps
323 | test_steps = self.cfg.training.test_steps
324 | early_stop = self.cfg.training.early_stop
325 | grad_clip = self.cfg.training.grad_clip
326 | save_iters = self.cfg.training.save_iters
327 |
328 | self.model.train()
329 | logger.info('Training...')
330 | while self.tm.nb_step <= total_steps:
331 | for batch in self.train_loader:
332 |
333 | # data transform and augmentation
334 | batch = self._data_transform(batch, self.cfg.training.transform, self.tm.nb_step)
335 |
336 | # train one step
337 | with accelerator.autocast():
338 | train_log_dict = self.train_one_step(batch, grad_clip)
339 |
340 | # print log
341 | if self.tm.nb_step % print_steps == 0:
342 | self._print_logs(train_log_dict, title='Train', nb_step=self.tm.nb_step)
343 |
344 | # eval
345 | if self.tm.nb_step % eval_steps == 0:
346 | val_log_dict = self.run_eval()
347 | self._print_logs(val_log_dict, title='Validation', nb_step=self.tm.nb_step)
348 |
349 | # save best val
350 | if val_log_dict['val'] < self.tm.best_eval:
351 | self.tm.best_eval = val_log_dict['val']
352 | self.tm.best_step = self.tm.nb_step
353 | logger.info("\t-->Validation improved!!! Save best!!!")
354 | accelerator.save_state(output_dir=self.ckpt_best, safe_serialization=False) # otherwise can't reload correctly
355 | # early stop
356 | else:
357 | self.tm.early_stop += 1
358 | if self.tm.early_stop >= early_stop:
359 |                             logger.info(f"\t--> Validation not improved for {early_stop} evaluations")
360 |                             logger.info(f"\t--> Training finished by early stopping")
361 | logger.info(f"\t--> Best model saved at iter: {self.tm.best_step}")
362 | logger.info(f"\t--> Final test, load best ckpt")
363 | accelerator.load_state(self.ckpt_best)
364 | # save state and end
365 | unwrapped_model = accelerator.unwrap_model(self.model)
366 | torch.save(unwrapped_model.state_dict(), self.ckpt_final / 'ckpt_model_final.pth')
367 | unwrapped_discriminator = accelerator.unwrap_model(self.discriminator)
368 | torch.save(unwrapped_discriminator.state_dict(), self.ckpt_final / 'ckpt_discriminator_final.pth')
369 | logger.info(f"\t--> Final ckpt saved in {self.ckpt_final}")
370 | # final test
371 | test_log_dict = self.run_test()
372 | self._print_logs(test_log_dict, title='Final Test', nb_step=self.tm.best_step, use_tracker=False)
373 | accelerator.end_training()
374 | return
375 |
376 | # save model
377 | if self.tm.nb_step in save_iters:
378 | unwrapped_model = accelerator.unwrap_model(self.model)
379 | torch.save(unwrapped_model.state_dict(), self.ckptdir / 'ckpt_model_iter{}.pth'.format(self.tm.nb_step))
380 | unwrapped_discriminator = accelerator.unwrap_model(self.discriminator)
381 | torch.save(unwrapped_discriminator.state_dict(), self.ckptdir / 'ckpt_discriminator_iter{}.pth'.format(self.tm.nb_step))
382 | logger.info('\t--> Checkpoints saved for iteration: {}'.format(self.tm.nb_step))
383 |
384 | # vis
385 | if self.tm.nb_step % vis_steps == 0:
386 | self.run_vis(nb_step=self.tm.nb_step)
387 |
388 | # test set
389 | if self.tm.nb_step % test_steps == 0:
390 | test_log_dict = self.run_test()
391 | self._print_logs(test_log_dict, title='Test', nb_step=self.tm.nb_step)
392 |
393 | self.tm.nb_step += 1
394 |
395 | # training end due to maximum train iters
396 | if self.tm.nb_step > total_steps:
397 | logger.info(f"\t--> Training finished by reaching max iterations")
398 | logger.info(f"\t--> Best model saved at iter: {self.tm.best_step}")
399 | logger.info(f"\t--> Final test, load best ckpt")
400 | accelerator.load_state(self.ckpt_best)
401 | # save state and end
402 | unwrapped_model = accelerator.unwrap_model(self.model)
403 | torch.save(unwrapped_model.state_dict(), self.ckpt_final / 'ckpt_model_final.pth')
404 | unwrapped_discriminator = accelerator.unwrap_model(self.discriminator)
405 | torch.save(unwrapped_discriminator.state_dict(), self.ckpt_final / 'ckpt_discriminator_final.pth')
406 | logger.info(f"\t--> Final ckpt saved in {self.ckpt_final}")
407 | # final test
408 | test_log_dict = self.run_test()
409 | self._print_logs(test_log_dict, title='Final Test', nb_step=self.tm.best_step, use_tracker=False)
410 | accelerator.end_training()
411 | return
412 |
413 | # breakpoint()
414 |
415 |
416 | def train_one_step(self, batch, grad_clip):
417 | # Forward, AMP automatically set by Accelerator
418 | batch = self.model(batch)
419 |
420 | # Reshape in/out
421 | audio_recon = rearrange(batch['recon'], 'b k c t -> (b k) c t')
422 | audio_ref = rearrange(batch['ref'], 'b k c t -> (b k) c t')
423 |
424 | # Train discriminator
425 | batch["adv/disc_loss"] = self.gan_loss.discriminator_loss(fake=audio_recon, real=audio_ref)
426 | self.optimizer_d.zero_grad()
427 | accelerator.backward(batch["adv/disc_loss"])
428 | if accelerator.sync_gradients:
429 | grad_norm_d = accelerator.clip_grad_norm_(self.discriminator.parameters(), max_norm=grad_clip)
430 | self.optimizer_d.step()
431 | self.scheduler_d.step()
432 |
433 | # Train generator
434 | batch["stft/loss"] = self.stft_loss(est=audio_recon, ref=audio_ref)
435 | batch["mel/loss"] = self.mel_loss(est=audio_recon, ref=audio_ref)
436 | (
437 | batch["adv/gen_loss"],
438 | batch["adv/feat_loss"],
439 | ) = self.gan_loss.generator_loss(fake=audio_recon, real=audio_ref)
440 |
441 | loss = sum([v * batch[k] for k, v in self.lambdas.items() if k in batch])
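        # Note: `self.lambdas` (from cfg.training.loss.lambdas) maps the loss keys
        # computed above (e.g. 'mel/loss', 'adv/gen_loss', 'adv/feat_loss',
        # 'vq/commitment_loss', 'vq/codebook_loss') to scalar weights; keys missing
        # from `batch` are simply skipped.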
442 |
443 | # debugging nan error
444 | if loss.isnan().any():
445 | logger.error('Nan detect, debugging...')
446 | ckpt_debug = self.project_dir / 'ckpt_debug'
447 | ckpt_debug.mkdir(exist_ok=True)
448 | data_debug = ckpt_debug / f'batch_data_{accelerator.process_index}.pth'
449 | accelerator.save_state(output_dir=ckpt_debug, safe_serialization=False)
450 | torch.save(batch, data_debug)
451 | logger.info(f"\t--> Debug state saved in {ckpt_debug}")
452 | accelerator.wait_for_everyone()
453 | accelerator.end_training()
454 |             sys.exit()
455 | return
456 | # breakpoint()
457 |
458 |
459 | # Generator gradient descent
460 | self.optimizer_g.zero_grad()
461 | accelerator.backward(loss)
462 | if accelerator.sync_gradients:
463 | grad_norm_g = accelerator.clip_grad_norm_(self.model.parameters(), max_norm=grad_clip)
464 | self.optimizer_g.step()
465 | self.scheduler_g.step()
466 |
467 | # Mean reduce across all GPUs
468 | log_dict = OrderedDict()
469 | log_dict['train/lr'] = self.scheduler_g.get_last_lr()[0]
470 | log_dict['train/loss_d'] = accelerator.reduce(batch["adv/disc_loss"], reduction="mean").item()
471 | log_dict['train/loss_g'] = accelerator.reduce(loss, reduction="mean").item()
472 | log_dict['train/grad_norm_d'] = accelerator.reduce(grad_norm_d, reduction="mean").item()
473 | log_dict['train/grad_norm_g'] = accelerator.reduce(grad_norm_g, reduction="mean").item()
474 | log_dict['train/stft'] = accelerator.reduce(batch["stft/loss"], reduction="mean").item()
475 | log_dict['train/mel'] = accelerator.reduce(batch['mel/loss'], reduction="mean").item()
476 | log_dict['train/gen'] = accelerator.reduce(batch['adv/gen_loss'], reduction="mean").item()
477 | log_dict['train/feat'] = accelerator.reduce(batch['adv/feat_loss'], reduction="mean").item()
478 | log_dict['train/commit_loss'] = accelerator.reduce(batch['vq/commitment_loss'], reduction="mean").item()
479 | log_dict['train/cb_loss'] = accelerator.reduce(batch['vq/codebook_loss'], reduction="mean").item()
480 |
481 | return log_dict
482 |
483 |
484 | @torch.no_grad()
485 | def run_eval(self):
486 | """Distributed evaluation
487 | for inputs, targets in validation_dataloader:
488 | predictions = model(inputs)
489 | # Gather all predictions and targets
490 | all_predictions, all_targets = accelerator.gather_for_metrics((predictions, targets))
491 | # Example of use with a *Datasets.Metric*
492 | metric.add_batch(all_predictions, all_targets)
493 | """
494 | self.model.eval()
495 |         am_dis_stft = {k: utils.AverageMeter() for k in self.eval_tracks}
496 | am_dis_mel = {k: utils.AverageMeter() for k in self.eval_tracks}
497 | am_sisdr = {k: utils.AverageMeter() for k in self.eval_tracks}
498 | am_ppl_rvq = {k: [utils.AverageMeter() for _ in range(self.cfg.model.quant_params.n_codebooks[i])] \
499 | for i, k in enumerate(self.tracks)
500 | }
501 |
502 | for batch in self.val_loader:
503 |
504 | # data transform and augmentation
505 | batch = self._data_transform(batch, self.cfg.training.transform, self.tm.nb_step, is_eval=True)
506 |
507 | # Forward
508 | batch = self.model(batch)
509 |
510 | # Distributed evaluation
511 |             all_recon = accelerator.gather_for_metrics(batch['recon'])  # valid for nested lists/tuples/dicts
512 | all_ref = accelerator.gather_for_metrics(batch['ref'])
513 |
514 | # Codebook perplexity
515 | for t, am_ppl_list in am_ppl_rvq.items():
516 | ppl_rvq = accelerator.gather_for_metrics(batch[f'{t}/ppl'].unsqueeze(0))
517 | ppl_rvq = ppl_rvq.mean(dim=0)
518 | for i, am_ppl in enumerate(am_ppl_list):
519 | am_ppl.update(ppl_rvq[i].item())
520 |
521 | # Eval mix reconstruction
522 | est = all_recon[:,0]
523 | ref = all_ref[:,0]
524 | am_dis_stft['mix_rec'].update(self.stft_loss(est=est, ref=ref).item())
525 | am_dis_mel['mix_rec'].update(self.mel_loss(est=est, ref=ref).item())
526 | am_sisdr['mix_rec'].update(- self.eval_sisdr(est=est, ref=ref).item())
527 |
528 | # Eval separation using synthesizer (decoder)
529 | for i, t in enumerate(self.tracks):
530 | est = all_recon[:,i+1]
531 | ref = all_ref[:,i+1]
532 | am_dis_stft[f'{t}_sep'].update(self.stft_loss(est=est, ref=ref).item())
533 | am_dis_mel[f'{t}_sep'].update(self.mel_loss(est=est, ref=ref).item())
534 | am_sisdr[f'{t}_sep'].update(- self.eval_sisdr(est=est, ref=ref).item())
535 |
536 | # Eval separation using mask
537 | all_sep_mask_norm = self.sep_norm(mix=all_ref[:,0:1], signal_sep=all_recon[:,1:])
538 | for i, t in enumerate(self.tracks):
539 | est = all_sep_mask_norm[:,i]
540 | ref = all_ref[:,i+1]
541 |                 ref = ref[..., :est.shape[-1]]  # stft + istft output is slightly shorter
542 | am_dis_stft[f'{t}_sep_mask'].update(self.stft_loss(est=est, ref=ref).item())
543 | am_dis_mel[f'{t}_sep_mask'].update(self.mel_loss(est=est, ref=ref).item())
544 | am_sisdr[f'{t}_sep_mask'].update(- self.eval_sisdr(est=est, ref=ref).item())
545 |
546 | # Evaluate reconstruction of single track
547 | for i, t in enumerate(self.tracks):
548 | out_audio = self.model_eval_func(batch[t], output_tracks=[t])
549 | all_recon = accelerator.gather_for_metrics(out_audio)
550 | all_ref = accelerator.gather_for_metrics(batch[t])
551 | am_dis_stft[f'{t}_rec'].update(self.stft_loss(est=all_recon, ref=all_ref).item())
552 | am_dis_mel[f'{t}_rec'].update(self.mel_loss(est=all_recon, ref=all_ref).item())
553 | am_sisdr[f'{t}_rec'].update(- self.eval_sisdr(est=all_recon, ref=all_ref).item())
554 |
555 | log_dict = OrderedDict()
556 | for t in self.eval_tracks:
557 | log_dict[f'eval_{t}/stft'] = am_dis_stft[t].avg
558 | log_dict[f'eval_{t}/mel'] = am_dis_mel[t].avg
559 | log_dict[f'eval_{t}/sisdr'] = am_sisdr[t].avg
560 | for t in self.tracks:
561 | for i, am_ppl in enumerate(am_ppl_rvq[t]):
562 | log_dict[f'eval_ppl/{t}_q{i}'] = am_ppl.avg
563 |
564 | log_dict['val'] = - np.mean([log_dict[f'eval_{k}/sisdr'] for k in self.eval_tracks]) # key to update best model
565 |
566 | self.model.train()
567 |
568 | return log_dict
569 |
570 |
571 | @torch.no_grad()
572 | def run_test(self):
573 | self.model.eval()
574 |
575 |         am_dis_stft = {k: utils.AverageMeter() for k in self.eval_tracks}
576 | am_dis_mel = {k: utils.AverageMeter() for k in self.eval_tracks}
577 | am_sisdr = {k: utils.AverageMeter() for k in self.eval_tracks}
578 | # am_visqol = {k: utils.AverageMeter() for k in self.eval_tracks}
579 | am_ppl_rvq = {k: [utils.AverageMeter() for _ in range(self.cfg.model.quant_params.n_codebooks[i])] \
580 | for i, k in enumerate(self.tracks)
581 | }
582 |
583 | for batch in self.val_loader:
584 |
585 | # data transform and augmentation
586 | batch = self._data_transform(batch, self.cfg.training.transform, self.tm.nb_step, is_eval=True)
587 |
588 | # Forward
589 | batch = self.model(batch)
590 |
591 | # Distributed evaluation
592 |             all_recon = accelerator.gather_for_metrics(batch['recon'])  # valid for nested lists/tuples/dicts
593 | all_ref = accelerator.gather_for_metrics(batch['ref'])
594 |
595 | # Codebook perplexity
596 | for t, am_ppl_list in am_ppl_rvq.items():
597 | ppl_rvq = accelerator.gather_for_metrics(batch[f'{t}/ppl'].unsqueeze(0))
598 | ppl_rvq = ppl_rvq.mean(dim=0)
599 | for i, am_ppl in enumerate(am_ppl_list):
600 | am_ppl.update(ppl_rvq[i].item())
601 |
602 | # Eval mix reconstruction
603 | est = all_recon[:,0]
604 | ref = all_ref[:,0]
605 | am_dis_stft['mix_rec'].update(self.stft_loss(est=est, ref=ref).item())
606 | am_dis_mel['mix_rec'].update(self.mel_loss(est=est, ref=ref).item())
607 | am_sisdr['mix_rec'].update(- self.eval_sisdr(est=est, ref=ref).item())
608 | ## distributed compute visqol
609 | # list_visqol = self.eval_visqol(est=batch['recon'][:,0], ref=batch['ref'][:,0], sr=self.sr)
610 | # scores_visqol = accelerator.gather_for_metrics(list_visqol)
611 | # am_visqol['mix'].update(np.mean(scores_visqol))
612 |
613 | # Eval separation using synthesizer (decoder)
614 | for i, t in enumerate(self.tracks):
615 | est = all_recon[:,i+1]
616 | ref = all_ref[:,i+1]
617 | am_dis_stft[f'{t}_sep'].update(self.stft_loss(est=est, ref=ref).item())
618 | am_dis_mel[f'{t}_sep'].update(self.mel_loss(est=est, ref=ref).item())
619 | am_sisdr[f'{t}_sep'].update(- self.eval_sisdr(est=est, ref=ref).item())
620 | ## distributed compute visqol
621 | # list_visqol = self.eval_visqol(est=batch['recon'][:,i+1], ref=batch['ref'][:,i+1], sr=self.sr)
622 | # scores_visqol = accelerator.gather_for_metrics(list_visqol)
623 | # am_visqol[f'{t}_sep'].update(np.mean(scores_visqol))
624 |
625 | # Eval separation using mask
626 | all_sep_mask_norm = self.sep_norm(mix=all_ref[:,0:1], signal_sep=all_recon[:,1:])
627 | for i, t in enumerate(self.tracks):
628 | est = all_sep_mask_norm[:,i]
629 | ref = all_ref[:,i+1]
630 |                 ref = ref[..., :est.shape[-1]]  # stft + istft output is slightly shorter
631 | am_dis_stft[f'{t}_sep_mask'].update(self.stft_loss(est=est, ref=ref).item())
632 | am_dis_mel[f'{t}_sep_mask'].update(self.mel_loss(est=est, ref=ref).item())
633 | am_sisdr[f'{t}_sep_mask'].update(- self.eval_sisdr(est=est, ref=ref).item())
634 |
635 | # Evaluate reconstruction on each individual track
636 | for i, t in enumerate(self.tracks):
637 | out_audio = self.model_eval_func(batch[t], output_tracks=[t])
638 | all_recon = accelerator.gather_for_metrics(out_audio)
639 | all_ref = accelerator.gather_for_metrics(batch[t])
640 | am_dis_stft[f'{t}_rec'].update(self.stft_loss(est=all_recon, ref=all_ref).item())
641 | am_dis_mel[f'{t}_rec'].update(self.mel_loss(est=all_recon, ref=all_ref).item())
642 | am_sisdr[f'{t}_rec'].update(- self.eval_sisdr(est=all_recon, ref=all_ref).item())
643 | ## distributed compute visqol
644 | # list_visqol = self.eval_visqol(est=out_audio, ref=batch[t], sr=self.sr)
645 | # scores_visqol = accelerator.gather_for_metrics(list_visqol)
646 | # am_visqol[f'{t}_rec'].update(np.mean(scores_visqol))
647 |
648 | log_dict = OrderedDict()
649 | for t in self.eval_tracks:
650 | log_dict[f'test_{t}/stft'] = am_dis_stft[t].avg
651 | log_dict[f'test_{t}/mel'] = am_dis_mel[t].avg
652 | log_dict[f'test_{t}/sisdr'] = am_sisdr[t].avg
653 | # log_dict[f'test_{t}/visqol'] = am_visqol[t].avg
654 | for t in self.tracks:
655 | for i, am_ppl in enumerate(am_ppl_rvq[t]):
656 | log_dict[f'test_ppl/{t}_q{i}'] = am_ppl.avg
657 |
658 | self.model.train()
659 |
660 | return log_dict
661 |
662 |
663 | @torch.no_grad()
664 | @accelerator.on_main_process
665 | def run_vis(self, nb_step):
666 | self.model.eval()
667 |
668 | ret_dict = defaultdict(list)
669 |
670 | writer = self.tracker.writer
671 |
672 | vis_idx = self.cfg.training.get('vis_idx', [])
673 | for idx in vis_idx:
674 | # get data
675 |             batch = self.val_loader.dataset[idx]
676 | for t in self.tracks:
677 | batch[t] = batch[t].unsqueeze(0).to(accelerator.device) # (1, 1, T)
678 | # data transform and augmentation
679 | batch = self._data_transform(batch, self.cfg.training.transform, self.tm.nb_step, is_eval=True)
680 | # single track recon
681 | for t in self.tracks:
682 | ret_dict[f'{t}_orig'].append(batch[t][0])
683 | ret_dict[f'{t}_recon'].append(self.model_eval_func(batch[t], output_tracks=[t])[0])
684 |
685 | # mix recon and separation
686 | sep_out = self.model_eval_func(batch['mix'], output_tracks= ['mix'] + self.tracks)
687 | ret_dict['mix_orig'].append(batch['mix'][0])
688 | ret_dict['mix_recon'].append(sep_out[:, 0]) # (1, T)
689 | for i, t in enumerate(self.tracks):
690 | ret_dict[f'{t}_sep'].append(sep_out[:, i+1])
691 |
692 | # separation using FFT-mask
693 | mix = batch['mix'].unsqueeze(2)
694 | signal_sep = sep_out[:,1:].unsqueeze(2)
695 | all_sep_mask_norm = self.sep_norm(mix, signal_sep)
696 | for p, t in enumerate(self.tracks):
697 | est = all_sep_mask_norm[0,p]
698 | right_pad = mix.shape[-1] - est.shape[-1]
699 | est = F.pad(est, (0, right_pad))
700 | ret_dict[f'{t}_sep_mask'].append(est)
701 |
702 |
703 | # mix
704 | audio_mix_orig = torch.cat(ret_dict['mix_orig'], dim=-1).detach().cpu()
705 | audio_mix_recon = torch.cat(ret_dict['mix_recon'], dim=-1).detach().cpu()
706 | writer.add_audio('mix_orig', audio_mix_orig, global_step=nb_step, sample_rate=self.sr)
707 | writer.add_audio('mix_recon', audio_mix_recon, global_step=nb_step, sample_rate=self.sr)
708 | audio_signal = torch.cat((audio_mix_orig, audio_mix_recon), dim=0).numpy()
709 | fig = utils.vis_spec(audio_signal, fs=self.sr, fig_width=8*len(vis_idx),
710 | tight_layout=False, save_fig=None)
711 | writer.add_figure('mix', fig, global_step=nb_step)
712 |
713 | # track
714 | for t in self.tracks:
715 | audio_orig = torch.cat(ret_dict[f'{t}_orig'], dim=-1).detach().cpu()
716 | audio_recon = torch.cat(ret_dict[f'{t}_recon'], dim=-1).detach().cpu()
717 | audio_sep = torch.cat(ret_dict[f'{t}_sep'], dim=-1).detach().cpu()
718 | audio_sep_mask = torch.cat(ret_dict[f'{t}_sep_mask'], dim=-1).detach().cpu()
719 | writer.add_audio(f'{t}_orig', audio_orig, global_step=nb_step, sample_rate=self.sr)
720 | writer.add_audio(f'{t}_recon', audio_recon, global_step=nb_step, sample_rate=self.sr)
721 | writer.add_audio(f'{t}_sep', audio_sep, global_step=nb_step, sample_rate=self.sr)
722 | writer.add_audio(f'{t}_sep_mask', audio_sep_mask, global_step=nb_step, sample_rate=self.sr)
723 | audio_signal = torch.cat((audio_orig, audio_recon, audio_sep, audio_sep_mask), dim=0).numpy()
724 | fig = utils.vis_spec(audio_signal, fs=self.sr, fig_width=8*len(vis_idx),
725 | tight_layout=False, save_fig=None)
726 | writer.add_figure(f'{t}', fig, global_step=nb_step)
727 |
728 | self.model.train()
729 |
730 | return
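731 | 
732 | 
733 | if __name__ == '__main__':
734 |     # Minimal sketch (added for illustration, not part of the original trainer):
735 |     # how `train_one_step` assembles the weighted generator loss. The weights
736 |     # below are hypothetical; the real values come from the training config.
737 |     lambdas = {'mel/loss': 15.0, 'adv/gen_loss': 1.0, 'adv/feat_loss': 2.0}
738 |     batch = {'mel/loss': 0.5, 'adv/gen_loss': 2.0}  # 'adv/feat_loss' absent on purpose
739 |     loss = sum(v * batch[k] for k, v in lambdas.items() if k in batch)
740 |     print(loss)  # 15.0 * 0.5 + 1.0 * 2.0 = 9.5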
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .audio_process import (
2 | normalize_mean_var_np,
3 | normalize_max_norm_np,
4 | normalize_mean_var,
5 | normalize_max_norm,
6 | db_to_gain,
7 | VolumeNorm,
8 | WavSepMagNorm,
9 | )
10 | from .torch_utils import (
11 | warmup_learning_rate,
12 | get_scheduler,
13 | Configure_AdamW,
14 | )
15 | from .utils import (
16 | warm_up,
17 | AverageMeter,
18 | TrainMonitor,
19 | )
20 |
21 | from .vis import (
22 | vis_spec,
23 | )
--------------------------------------------------------------------------------
/src/utils/audio_process.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch import nn
4 | import torchaudio
5 | from collections import namedtuple
6 | from einops import rearrange
7 |
8 | def normalize_mean_var_np(wav, eps=1e-8, std=None):
9 | mean = np.mean(wav, axis=-1, keepdims=True)
10 | if std is None:
11 | std = np.std(wav, axis=-1, keepdims=True)
12 | return (wav - mean) / (std + eps)
13 |
14 | def normalize_max_norm_np(wav):
15 |     return wav / np.max(np.abs(wav), axis=-1, keepdims=True)
16 |
17 |
18 | def normalize_mean_var(wav, eps=1e-8, std=None):
19 | mean = wav.mean(-1, keepdim=True)
20 | if std is None:
21 | std = wav.std(-1, keepdim=True)
22 | return (wav - mean) / (std + eps)
23 |
24 | def normalize_max_norm(wav):
25 |     return wav / wav.abs().max(-1, keepdim=True).values
26 |
27 | def db_to_gain(db):
28 | return np.power(10.0, db/20.0)
29 |
30 |
31 |
32 | class VolumeNorm:
33 | """
34 | Volume normalization to a specific loudness [LUFS standard]
35 | """
36 | def __init__(self, sample_rate=16000):
37 | self.lufs_meter = torchaudio.transforms.Loudness(sample_rate)
38 |
39 | def __call__(self, signal, target_loudness=-30, var=0, return_gain=False):
40 | """
41 | signal: torch.Tensor [B, channels, L]
42 | """
43 | bs = signal.shape[0]
44 | # LUFS diff
45 | lufs_ref = self.lufs_meter(signal)
46 | lufs_target = (target_loudness + (torch.rand(bs) * 2 - 1) * var).to(lufs_ref.device)
47 |         # dB difference to linear gain: 10 ** (dB / 20)
48 | gain = torch.exp((lufs_target - lufs_ref) * np.log(10) / 20)
49 | gain[gain.isnan()] = 0 # zero gain for silent audio
50 | # norm
51 | signal *= gain[:, None, None]
52 |
53 | if return_gain:
54 | return signal, gain
55 | else:
56 | return signal
57 |
58 |
59 | STFTParams = namedtuple(
60 | "STFTParams",
61 | ["window_length", "hop_length", "window_type", "padding_type"],
62 | )
63 | STFT_PARAMS = STFTParams(
64 | window_length=1024,
65 | hop_length=256,
66 | window_type="hann",
67 | padding_type="reflect",
68 | )
69 |
70 |
71 | class WavSepMagNorm:
72 | """
73 | Normalize the separation results using the magnitude
74 |     X_i = (|X_i| / sum_k |X_k|) * |X_mix| * exp(j * angle(X_mix))
75 | """
76 | def __init__(self):
77 | self.stft_params = STFT_PARAMS
78 |
79 | def __call__(self, mix, signal_sep):
80 | """
81 | Parameters
82 | ----------
83 | mix: torch.Tensor [B, 1, channels, L]
84 | Mixture signal
85 |         signal_sep: torch.Tensor [B, K, channels, L]
86 | Separation results without normalization
87 | Returns
88 | -------
89 | ret: torch.Tensor [B, K, channels, L']
90 |         Mask-normalized separation results (L' may be slightly shorter than L after STFT/iSTFT)
91 | """
92 | bs, K, channels, _ = signal_sep.shape
93 | mix = rearrange(mix, 'b k c l -> (b k c) l')
94 | signal_sep = rearrange(signal_sep, 'b k c l -> (b k c) l')
95 |
96 | mix_spec = torch.stft(mix, n_fft=self.stft_params.window_length, hop_length=self.stft_params.hop_length,
97 | win_length=self.stft_params.window_length,
98 | window=torch.hann_window(self.stft_params.window_length, device=mix.device),
99 | pad_mode=self.stft_params.padding_type, center=True, onesided=True, return_complex=True)
100 | signal_sep_spec = torch.stft(signal_sep, n_fft=self.stft_params.window_length, hop_length=self.stft_params.hop_length,
101 | win_length=self.stft_params.window_length,
102 | window=torch.hann_window(self.stft_params.window_length, device=signal_sep.device),
103 | pad_mode=self.stft_params.padding_type, center=True, onesided=True, return_complex=True)
104 |
105 | mix_spec = rearrange(mix_spec, '(b k c) n t -> b k c n t', k=1, c=channels)
106 | signal_sep_spec = rearrange(signal_sep_spec, '(b k c) n t -> b k c n t', k=K, c=channels)
107 |
108 | signal_sep_mag = signal_sep_spec.abs() # (B, K, C, N, T)
109 | ratio = signal_sep_mag / signal_sep_mag.sum(dim=1, keepdim=True)
110 | ret_spec = torch.polar(mix_spec.abs() * ratio, mix_spec.angle())
111 |
112 | ret_spec = rearrange(ret_spec, 'b k c n t -> (b k c) n t')
113 | ret = torch.istft(ret_spec, n_fft=self.stft_params.window_length, hop_length=self.stft_params.hop_length,
114 | win_length=self.stft_params.window_length,
115 | window=torch.hann_window(self.stft_params.window_length, device=mix.device),
116 | center=True)
117 | ret = rearrange(ret, '(b k c) l -> b k c l', k=K, c=channels)
118 |
119 | return ret
120 |
121 |
122 | if __name__ == '__main__':
123 |
124 | import torch
125 | sep_norm = WavSepMagNorm()
126 | mix = torch.randn(32, 1, 1, 16000)
127 | signal_sep = torch.randn(32, 3, 1, 16000)
128 | ret = sep_norm(mix, signal_sep)
129 | print(ret.shape)
130 |     print(((mix[...,:ret.shape[-1]] - ret.sum(dim=1, keepdim=True)).abs() < 1e-6).all())
131 |
132 |
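133 | 
134 |     # Extra illustrative check (added, not in the original demo): VolumeNorm's
135 |     # gain exp(dB * ln(10) / 20) is the same dB-to-linear conversion as
136 |     # db_to_gain, and normalization rescales the batch toward the target LUFS.
137 |     print(np.isclose(db_to_gain(6.0), np.exp(6.0 * np.log(10) / 20.0)))  # True
138 |     vol_norm = VolumeNorm(sample_rate=16000)
139 |     sig = torch.rand(2, 1, 16000) - 0.5
140 |     sig_norm, gain = vol_norm(sig.clone(), target_loudness=-30, return_gain=True)
141 |     print(sig_norm.shape, gain.shape)  # torch.Size([2, 1, 16000]) torch.Size([2])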
--------------------------------------------------------------------------------
/src/utils/torch_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim import lr_scheduler
3 |
4 | def warmup_learning_rate(optimizer, nb_iter, warmup_iter, max_lr):
5 | """warmup learning rate"""
6 | lr = max_lr * nb_iter / warmup_iter
7 | for param_group in optimizer.param_groups:
8 | param_group['lr'] = lr
9 | return lr
10 |
11 |
12 | def get_scheduler(optimizer, args):
13 | if args.policy == 'linear':
14 | scheduler = lr_scheduler.LinearLR(optimizer, total_iters=args.total_iter) # factor 0.33-1
15 | elif args.policy == 'cosine':
16 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.total_iter)
17 | elif args.policy == 'step':
18 | scheduler = lr_scheduler.StepLR(
19 | optimizer, step_size=args.decay_step, gamma=0.1)
20 | elif args.policy == 'multistep':
21 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_scheduler, gamma=0.05)
22 | elif args.policy == 'plateau':
23 | scheduler = lr_scheduler.ReduceLROnPlateau(
24 | optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
25 | else:
26 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % args.policy)
27 | return scheduler
28 |
29 |
30 | def Configure_AdamW(model, weight_decay, learning_rate):
31 |
32 | all_params = set(model.parameters())
33 | decay = set()
34 | whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv1d)
35 |
36 | for m in model.modules():
37 |         if isinstance(m, whitelist_weight_modules):
38 | decay.add(m.weight)
39 | no_decay = all_params - decay
40 |
41 | # create the pytorch optimizer object
42 | optim_groups = [
43 | {"params": list(decay), "weight_decay": weight_decay},
44 | {"params": list(no_decay), "weight_decay": 0.0},
45 | ]
46 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate)
47 | return optimizer
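48 | 
49 | 
50 | if __name__ == '__main__':
51 |     # Minimal sanity check (added for illustration, not in the original file):
52 |     # Configure_AdamW puts Linear/Conv1d weights in the weight-decay group and
53 |     # everything else (e.g. biases) in the no-decay group.
54 |     net = torch.nn.Sequential(torch.nn.Conv1d(1, 4, 3), torch.nn.ReLU(), torch.nn.Linear(4, 2))
55 |     opt = Configure_AdamW(net, weight_decay=1e-2, learning_rate=1e-4)
56 |     print([len(g['params']) for g in opt.param_groups])  # [2, 2]
57 | 
58 |     # And a scheduler from get_scheduler, built from a hypothetical args namespace.
59 |     from types import SimpleNamespace
60 |     sched = get_scheduler(opt, SimpleNamespace(policy='cosine', total_iter=100))
61 |     print(type(sched).__name__)  # CosineAnnealingLR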
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import json
5 | from typing import List, Union
6 |
7 | def warm_up(step, warmup_steps):
8 | if step < warmup_steps:
9 | warmup_ratio = step / warmup_steps
10 | else:
11 | warmup_ratio = 1
12 | return warmup_ratio
13 |
14 |
15 | class AverageMeter(object):
16 | """Computes and stores the average and current value"""
17 |
18 | def __init__(self):
19 | self.hist = []
20 | self.reset()
21 |
22 | def reset(self):
23 | self.val = 0
24 | self.count = 0
25 | self.avg = 0
26 |
27 | def update(self, val, n=1):
28 | self.val = val
29 | self.count += n
30 | self.avg += (self.val - self.avg) * n / self.count
31 |
32 | def save_log(self):
33 | self.hist.append(self.avg)
34 | self.reset()
35 |
36 |
37 | class TrainMonitor(object):
38 | """Record training"""
39 |
40 | def __init__(self, nb_step=1, best_eval=math.inf, best_step=1, early_stop=0):
41 | self.nb_step = nb_step
42 | self.best_eval = best_eval
43 | self.best_step = best_step
44 | self.early_stop = early_stop
45 |
46 |
47 | def state_dict(self):
48 | sd = {'nb_step': self.nb_step,
49 | 'best_eval': self.best_eval,
50 | 'best_step': self.best_step,
51 | 'early_stop': self.early_stop,
52 | }
53 | return sd
54 |
55 |
56 | def load_state_dict(self, state_dict: dict):
57 | self.nb_step = state_dict['nb_step']
58 | self.best_eval = state_dict['best_eval']
59 | self.best_step = state_dict['best_step']
60 | self.early_stop = state_dict['early_stop']
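61 | 
62 | 
63 | if __name__ == '__main__':
64 |     # Quick sanity check (added for illustration, not in the original file):
65 |     # AverageMeter keeps a running mean and TrainMonitor round-trips its state.
66 |     am = AverageMeter()
67 |     for v in [1.0, 2.0, 3.0]:
68 |         am.update(v)
69 |     print(am.avg)  # 2.0
70 |     tm = TrainMonitor(nb_step=100, best_eval=-12.3, best_step=90)
71 |     tm2 = TrainMonitor()
72 |     tm2.load_state_dict(tm.state_dict())
73 |     print(tm2.nb_step, tm2.best_eval, tm2.best_step)  # 100 -12.3 90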
--------------------------------------------------------------------------------
/src/utils/vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import matplotlib.pylab as pylab
4 | params = {'legend.fontsize': 'xx-large',
5 | 'axes.labelsize': 'xx-large',
6 | 'axes.titlesize':'xx-large',
7 | 'xtick.labelsize':'xx-large',
8 | 'ytick.labelsize':'xx-large'}
9 | pylab.rcParams.update(params)
10 |
11 |
12 |
13 | def vis_spec(signal,
14 | fs=44100, nfft=1024, hop=512,
15 | fig_width=10, fig_height_per_plot=6, cmap='inferno', vmin=-150, vmax=None,
16 | use_colorbar=True, tight_layout=False, save_fig='./vis.png'):
17 |
18 | # plt.clf()
19 |
20 | # pre-plot
21 | if len(signal.shape) == 1:
22 | signal = signal.reshape(1, -1)
23 |
24 | num_sig = signal.shape[0]
25 | fig_size = (fig_width, fig_height_per_plot*num_sig)
26 | fig, axes = plt.subplots(nrows=num_sig, ncols=1, sharex=True, figsize=fig_size)
27 | if num_sig == 1:
28 | axes = [axes]
29 |
30 | # iterative plot
31 | for i, ax in enumerate(axes):
32 | x = signal[i] + 1e-9
33 | Pxx, freqs, bins, im = ax.specgram(x, scale='dB',
34 | Fs=fs, NFFT=nfft, noverlap=hop,
35 | vmin=vmin, vmax=vmax, cmap=cmap)
36 | if i == num_sig//2:
37 | ax.set_ylabel('Frequency (Hz)')
38 | if i == num_sig-1:
39 | ax.set_xlabel('Time (s)')
40 |
41 | # post-plot
42 | if use_colorbar:
43 | plt.colorbar(im, ax=axes, format="%+2.f dB")
44 |
45 | if tight_layout:
46 | plt.tight_layout()
47 |
48 | if save_fig:
49 | plt.savefig(save_fig)
50 | plt.close(fig)
51 | return
52 | else:
53 | return fig
54 |
55 |
56 | if __name__ == '__main__':
57 | import torchaudio
58 |
59 | filepath = '/home/xbie/Data/dnr_v2/tr/1002/mix.wav'
60 | x, fs = torchaudio.load(filepath)
61 | x = x.numpy().reshape(-1)
62 |
63 | import librosa
64 | sr = 44100
65 | clip = sr * 3
66 | _, (trim30dBs,_) = librosa.effects.trim(x[:clip], top_db=30)
67 | print(trim30dBs)
68 |     # breakpoint()
69 | # plt.clf()
70 | # plt.plot(x)
71 | # plt.savefig('./tmp.png')
72 | # plt.close()
73 | s = 500000
74 | e = 500000 + 51200
75 |     x_clip = x[s:e]  # x was flattened to 1-D above
76 | vis_spec(x_clip)
77 |
78 |
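79 |     # Extra illustrative demo (added, not in the original file): vis_spec on a
80 |     # synthetic two-tone signal, returning the figure instead of saving it.
81 |     t = np.linspace(0, 1.0, sr, endpoint=False)
82 |     sig = np.stack([np.sin(2 * np.pi * 440 * t), np.sin(2 * np.pi * 880 * t)], axis=0)
83 |     fig = vis_spec(sig, fs=sr, save_fig=None)
84 |     print(type(fig).__name__)  # Figure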
--------------------------------------------------------------------------------