├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── inference.toml
│   └── train.toml
├── inference.sh
├── mertrics.sh
├── run.sh
└── speech_enhance
    ├── __init__.py
    ├── audio_zen
    │   ├── __init__.py
    │   ├── acoustics
    │   │   ├── __init__.py
    │   │   ├── beamforming.py
    │   │   ├── feature.py
    │   │   ├── mask.py
    │   │   └── utils.py
    │   ├── constant.py
    │   ├── dataset
    │   │   ├── __init__.py
    │   │   └── base_dataset.py
    │   ├── inferencer
    │   │   ├── __init__.py
    │   │   └── base_inferencer.py
    │   ├── loss.py
    │   ├── metrics.py
    │   ├── model
    │   │   ├── __init__.py
    │   │   ├── base_model.py
    │   │   └── module
    │   │       ├── __init__.py
    │   │       ├── attention_model.py
    │   │       ├── causal_conv.py
    │   │       ├── feature_norm.py
    │   │       ├── sequence_model.py
    │   │       └── si_module.py
    │   ├── trainer
    │   │   ├── __init__.py
    │   │   └── base_trainer.py
    │   └── utils.py
    ├── fullsubnet
    │   ├── __init__.py
    │   ├── dataset
    │   │   ├── __init__.py
    │   │   ├── dataset_inference.py
    │   │   ├── dataset_train.py
    │   │   └── dataset_validation.py
    │   ├── inferencer
    │   │   ├── __init__.py
    │   │   └── inferencer.py
    │   ├── model
    │   │   ├── __init__.py
    │   │   └── fullsubnet.py
    │   └── trainer
    │       ├── __init__.py
    │       └── trainer.py
    ├── fullsubnet_plus
    │   ├── __init__.py
    │   ├── dataset
    │   │   ├── __init__.py
    │   │   ├── dataset_inference.py
    │   │   ├── dataset_train.py
    │   │   └── dataset_validation.py
    │   ├── inferencer
    │   │   ├── __init__.py
    │   │   └── inferencer.py
    │   ├── model
    │   │   └── fullsubnet_plus.py
    │   └── trainer
    │       ├── __init__.py
    │       └── trainer.py
    ├── inter_subnet
    │   ├── __init__.py
    │   ├── dataset
    │   │   ├── __init__.py
    │   │   ├── dataset_inference.py
    │   │   ├── dataset_train.py
    │   │   └── dataset_validation.py
    │   ├── inferencer
    │   │   ├── __init__.py
    │   │   └── inferencer.py
    │   ├── model
    │   │   ├── Inter_SubNet.py
    │   │   └── __init__.py
    │   └── trainer
    │       ├── __init__.py
    │       ├── joint_trainer.py
    │       └── trainer.py
    ├── subband_model
    │   ├── __init__.py
    │   ├── dataset
    │   │   ├── __init__.py
    │   │   ├── dataset_inference.py
    │   │   ├── dataset_train.py
    │   │   └── dataset_validation.py
    │   ├── inferencer
    │   │   ├── __init__.py
    │   │   └── inferencer.py
    │   ├── model
    │   │   ├── __init__.py
    │   │   └── subband_model.py
    │   └── trainer
    │       ├── __init__.py
    │       ├── joint_trainer.py
    │       └── trainer.py
    ├── tools
    │   ├── __init__.py
    │   ├── analyse.py
    │   ├── calculate_metrics.py
    │   ├── collect_lst.py
    │   ├── dns_mos.py
    │   ├── gen_lst.py
    │   ├── inference.py
    │   ├── noisyspeech_synthesizer.py
    │   ├── resample_dir.py
    │   └── train.py
    └── utils
        ├── __init__.py
        ├── logger.py
        ├── plot.py
        └── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Inter-SubNet
2 | 
3 | The official PyTorch implementation of **"[Inter-SubNet: Speech Enhancement with Subband Interaction](https://arxiv.org/abs/2305.05599)"**, accepted by ICASSP 2023.
4 | 
5 | 📜[[Full Paper](https://arxiv.org/abs/2305.05599)] ▶[[Demo](https://rookiejunchen.github.io/Inter-SubNet_demo/)] 💿[[Checkpoint](https://drive.google.com/file/d/1j9jdXRxPhXLE93XlYppCQtcOqMOJNjdt/view?usp=share_link)]
6 | 
7 | 
8 | 
9 | ## Requirements
10 | 
11 | - Linux or macOS
12 | 
13 | - python>=3.6
14 | 
15 | - Anaconda or Miniconda
16 | 
17 | - NVIDIA GPU + CUDA + cuDNN (CPU is also supported)
18 | 
19 | 
20 | 
21 | ### Environment && Installation
22 | 
23 | Install Anaconda or Miniconda, and then install the conda and pip packages:
24 | 
25 | ```shell
26 | # Create conda environment
27 | conda create --name speech_enhance python=3.8
28 | conda activate speech_enhance
29 | 
30 | # Install conda packages
31 | # Check python=3.8, cudatoolkit=10.2, pytorch=1.7.1, torchaudio=0.7
32 | conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
33 | conda install tensorboard joblib matplotlib
34 | 
35 | # Install pip packages
36 | # Check librosa=0.8
37 | pip install Cython
38 | pip install librosa pesq pypesq pystoi tqdm toml colorful mir_eval torch_complex
39 | 
40 | # (Optional) If you want to load "mp3" format audio in your dataset
41 | conda install -c conda-forge ffmpeg
42 | ```
43 | 
44 | 
45 | 
46 | ### Quick Usage
47 | 
48 | Clone the repository:
49 | 
50 | ```shell
51 | git clone https://github.com/RookieJunChen/Inter-SubNet.git
52 | cd Inter-SubNet
53 | ```
54 | 
55 | Download the [pre-trained checkpoint](https://drive.google.com/file/d/1j9jdXRxPhXLE93XlYppCQtcOqMOJNjdt/view?usp=share_link), then run:
56 | 
57 | ```shell
58 | source activate speech_enhance
59 | python -m speech_enhance.tools.inference \
60 |   -C config/inference.toml \
61 |   -M $MODEL_DIR \
62 |   -I $INPUT_DIR \
63 |   -O $OUTPUT_DIR
64 | ```
65 | 
66 | 
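For example, once the checkpoint has been downloaded, a filled-in command could look like the following. All paths here are placeholders rather than files shipped with this repository: point `-M` at the downloaded `.tar` checkpoint, `-I` at a directory of noisy `.wav` files, and `-O` at a writable output directory (the enhanced audio is written to an `enhanced_<epoch>` subfolder inside it).

```shell
source activate speech_enhance

# Placeholders: -M is the downloaded .tar checkpoint, -I a folder of noisy
# .wav files, -O the output folder (results land in an enhanced_<epoch> subfolder).
python -m speech_enhance.tools.inference \
  -C config/inference.toml \
  -M /path/to/checkpoints/intersubnet.tar \
  -I /path/to/noisy_wavs \
  -O /path/to/enhanced_output
```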
67 | 
68 | ## Start Up
69 | 
70 | ### Clone
71 | 
72 | ```shell
73 | git clone https://github.com/RookieJunChen/Inter-SubNet.git
74 | cd Inter-SubNet
75 | ```
76 | 
77 | 
78 | 
79 | ### Data preparation
80 | 
81 | #### Train data
82 | 
83 | Please prepare your data in the data directory as follows:
84 | 
85 | - data/DNS-Challenge/DNS-Challenge-interspeech2020-master/
86 | - data/DNS-Challenge/DNS-Challenge-master/
87 | 
88 | and set the train dir in the script `run.sh`.
89 | 
90 | Then:
91 | 
92 | ```shell
93 | source activate speech_enhance
94 | bash run.sh 0 # prepare training list or meta file
95 | ```
96 | 
97 | #### Test data
98 | 
99 | Please prepare your test cases directory like `data/test_cases_`, and set the test dir in the script `run.sh`.
100 | 
101 | 
102 | 
103 | ### Training
104 | 
105 | First, modify the configurations in `config/train.toml` for training.
106 | 
107 | Then you can run training:
108 | 
109 | ```shell
110 | source activate speech_enhance
111 | bash run.sh 1
112 | ```
113 | 
114 | 
115 | 
116 | ### Inference
117 | 
118 | After training, you can enhance noisy speech. Before inference, first modify the configuration in `config/inference.toml`.
119 | 
120 | Then you can run inference:
121 | 
122 | ```shell
123 | source activate speech_enhance
124 | bash run.sh 2
125 | ```
126 | 
127 | Or you can just use `inference.sh`:
128 | 
129 | ```shell
130 | source activate speech_enhance
131 | bash inference.sh
132 | ```
133 | 
134 | 
135 | 
136 | 
137 | 
138 | ### Eval
139 | 
140 | Calculate objective metrics (SI_SDR, STOI, WB_PESQ, NB_PESQ, etc.):
141 | 
142 | ```shell
143 | bash metrics.sh
144 | ```
145 | 
146 | For test sets without reference signals, you can obtain subjective scores (e.g., DNS_MOS and NISQA) through [DNSMOS](https://github.com/RookieJunChen/dns_mos_calculate) and [NISQA](https://github.com/RookieJunChen/my_NISQA).
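The evaluation script (`mertrics.sh` in this repository) is a thin wrapper around `speech_enhance/tools/calculate_metrics.py`, which you can also invoke directly. A sketch of such a call, using the same flags as that script but with placeholder directories (the `-S DNS_1` value is kept as in `mertrics.sh`):

```shell
source activate speech_enhance

# Placeholders: -R is the clean reference dir, -E the enhanced dir,
# -M the comma-separated metric list, -D where the results are stored.
python speech_enhance/tools/calculate_metrics.py \
  -R /path/to/test_set/clean \
  -E /path/to/enhanced_wavs \
  -M SI_SDR,STOI,WB_PESQ,NB_PESQ \
  -S DNS_1 \
  -D /path/to/metrics_output/
```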
147 | 148 | 149 | 150 | 151 | ## Citation 152 | If you find our work useful in your research, please consider citing: 153 | ``` 154 | @inproceedings{chen2023inter, 155 | title={Inter-Subnet: Speech Enhancement with Subband Interaction}, 156 | author={Chen, Jun and Rao, Wei and Wang, Zilin and Lin, Jiuxin and Wu, Zhiyong and Wang, Yannan and Shang, Shidong and Meng, Helen}, 157 | booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 158 | pages={1--5}, 159 | year={2023}, 160 | organization={IEEE} 161 | } 162 | ``` 163 | -------------------------------------------------------------------------------- /config/inference.toml: -------------------------------------------------------------------------------- 1 | [acoustics] 2 | n_fft = 512 3 | win_length = 512 4 | sr = 16000 5 | hop_length = 256 6 | 7 | 8 | [inferencer] 9 | path = "inter_subnet.inferencer.inferencer.Inferencer" 10 | type = "full_band_crm_mask" 11 | 12 | [inferencer.args] 13 | n_neighbor = 15 14 | 15 | 16 | [dataset] 17 | path = "inter_subnet.dataset.dataset_inference.Dataset" 18 | [dataset.args] 19 | dataset_dir_list = [ 20 | "/workspace/project-nas-11025-sh/speech_enhance/data/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set/synthetic/with_reverb/noisy" 21 | # "/workspace/project-nas-11025-sh/speech_enhance/data/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set/synthetic/no_reverb/noisy" 22 | ] 23 | sr = 16000 24 | 25 | 26 | [model] 27 | path = "inter_subnet.model.Inter_SubNet.Inter_SubNet" 28 | [model.args] 29 | sb_num_neighbors = 15 30 | num_freqs = 257 31 | look_ahead = 2 32 | sequence_model = "LSTM" 33 | sb_output_activate_function = false 34 | sb_model_hidden_size = 384 35 | weight_init = false 36 | norm_type = "offline_laplace_norm" 37 | num_groups_in_drop_band = 2 38 | sbinter_middle_hidden_times = 0.8 39 | 40 | -------------------------------------------------------------------------------- /config/train.toml: -------------------------------------------------------------------------------- 1 | [meta] 2 | save_dir = "logs/Inter_SubNet" 3 | description = "This is a description of Inter-SubNet experiment." 4 | seed = 0 # set random seed for random, numpy, pytorch-gpu and pytorch-cpu 5 | port = "4396" 6 | keep_reproducibility = false # see https://pytorch.org/docs/stable/notes/randomness.html 7 | use_amp = false # use automatic mixed precision, it will benefits Tensor Core-enabled GPU (e.g. Volta, Turing, Ampere). 
2-3X speedup。 8 | 9 | 10 | [acoustics] 11 | n_fft = 512 12 | win_length = 512 13 | sr = 16000 14 | hop_length = 256 15 | 16 | 17 | [loss_function] 18 | name = "mse_loss" 19 | [loss_function.args] 20 | 21 | 22 | [optimizer] 23 | lr = 0.001 24 | beta1 = 0.9 25 | beta2 = 0.999 26 | 27 | 28 | [train_dataset] 29 | path = "inter_subnet.dataset.dataset_train.Dataset" 30 | [train_dataset.args] 31 | clean_dataset = "train_data_DNS_2021_16k/clean_book.txt" 32 | clean_dataset_limit = false 33 | clean_dataset_offset = 0 34 | noise_dataset = "train_data_DNS_2021_16k/noise.txt" 35 | noise_dataset_limit = false 36 | noise_dataset_offset = 0 37 | num_workers = 36 38 | pre_load_clean_dataset = false 39 | pre_load_noise = false 40 | pre_load_rir = false 41 | reverb_proportion = 0.75 42 | rir_dataset = "train_data_DNS_2021_16k/rir.txt" 43 | rir_dataset_limit = false 44 | rir_dataset_offset = 0 45 | silence_length = 0.2 46 | snr_range = [-5, 20] 47 | sr = 16000 48 | sub_sample_length = 3.072 49 | target_dB_FS = -25 50 | target_dB_FS_floating_value = 10 51 | 52 | 53 | [train_dataset.dataloader] 54 | batch_size = 20 55 | num_workers = 24 56 | drop_last = true 57 | pin_memory = true 58 | 59 | 60 | [validation_dataset] 61 | path = "inter_subnet.dataset.dataset_validation.Dataset" 62 | [validation_dataset.args] 63 | dataset_dir_list = [ 64 | "data/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set/synthetic/with_reverb/", 65 | "data/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set/synthetic/no_reverb/" 66 | # "/dockerdata/thujunchen/data/DNS_Challenge/test_set/synthetic/with_reverb/", 67 | # "/dockerdata/thujunchen/data/DNS_Challenge/test_set/synthetic/no_reverb/" 68 | ] 69 | sr = 16000 70 | 71 | 72 | [model] 73 | path = "inter_subnet.model.Inter_SubNet.Inter_SubNet" 74 | [model.args] 75 | sb_num_neighbors = 15 76 | num_freqs = 257 77 | look_ahead = 2 78 | sequence_model = "LSTM" 79 | sb_output_activate_function = false 80 | sb_model_hidden_size = 384 81 | weight_init = false 82 | norm_type = "offline_laplace_norm" 83 | num_groups_in_drop_band = 2 84 | sbinter_middle_hidden_times = 0.8 85 | 86 | 87 | [trainer] 88 | path = "inter_subnet.trainer.trainer.New_Judge_Trainer" 89 | [trainer.train] 90 | clip_grad_norm_value = 10 91 | epochs = 9999 92 | alpha = 1 93 | save_checkpoint_interval = 1 94 | [trainer.validation] 95 | save_max_metric_score = true 96 | validation_interval = 1 97 | [trainer.visualization] 98 | metrics = ["WB_PESQ", "NB_PESQ", "STOI", "SI_SDR"] 99 | n_samples = 10 100 | num_workers = 12 101 | -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # do enhance(denoise) 4 | CUDA_VISIBLE_DEVICES='5' python -m speech_enhance.tools.inference \ 5 | -C config/inference.toml \ 6 | -M /cjcode/ft_local/intersubnet/intersubnet.tar \ 7 | -I /DNS_2021/eval/testclips \ 8 | -O /enhance_data/dns4_testclips/inter_subnet 9 | 10 | 11 | # Normalized to -6dB (optional) 12 | sdir="enhance_data/dns4_testclips/inter_subnet" 13 | fdir="enhance_data/dns4_testclips/inter_subnet_norm" 14 | 15 | softfiles=$(find $sdir -name "*.wav") 16 | for file in ${softfiles} 17 | do 18 | length=${#sdir}+1 19 | file=${file:$length} 20 | f=$sdir/$file 21 | echo $f 22 | dstfile=$fdir/$file 23 | echo $dstfile 24 | sox $f -b16 $dstfile rate -v -b 99.7 16k norm -6 25 | done 26 | 27 | 28 | 
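# For reference, the loop above converts each enhanced file to 16-bit / 16 kHz and
# peak-normalizes it to -6 dB. A minimal single-file sketch of the same step
# (the file names below are placeholders, not outputs of this script):
in_wav=/path/to/enhanced/example.wav
out_wav=/path/to/enhanced_norm/example.wav
mkdir -p "$(dirname "$out_wav")"

# 16-bit output, high-quality resample to 16 kHz, then peak-normalize to -6 dB
sox "$in_wav" -b16 "$out_wav" rate -v -b 99.7 16k norm -6

# verify the peak level of the normalized file (sox writes stats to stderr)
sox "$out_wav" -n stat 2>&1 | grep "Maximum amplitude"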
-------------------------------------------------------------------------------- /mertrics.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python speech_enhance/tools/calculate_metrics.py \ 4 | -R /workspace/project-nas-11025-sh/speech_enhance/data/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set/synthetic/with_reverb/clean \ 5 | -E /workspace/project-nas-11025-sh/speech_enhance/case/with_reverb/fullsubnet+/enhanced_0194 \ 6 | -M SI_SDR,STOI,WB_PESQ,NB_PESQ \ 7 | -S DNS_1 \ 8 | -D /workspace/project-nas-11025-sh/speech_enhance/egs/DNS-master/s1_16k/mertrics/with_reverb_fullsubnet+/ -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | #set -eux 4 | if test "$#" -eq 1; then 5 | stage=$(($1)) 6 | stop_stage=$(($1)) 7 | elif test "$#" -eq 2; then 8 | stage=$(($1)) 9 | stop_stage=$(($2)) 10 | else 11 | stage=0 12 | stop_stage=10 13 | fi 14 | 15 | # corpus data will be under ${data_dir}/${corpus_name}, should be read only for exp 16 | # train data(features, index) will be under ${data_dir}/${train_data_dir} 17 | cur_dir=$(pwd) 18 | data_dir=../../../data/ 19 | corpus_name=data 20 | train_data_dir=train_data_fsn_dns_master 21 | exp_id=$(pwd | xargs basename) 22 | spkr_id=$(pwd | xargs dirname | xargs basename) 23 | corpus_id=${spkr_id}_${exp_id} 24 | 25 | ############################################# 26 | # prepare files 27 | 28 | if [ ! -d ${corpus_name} ]; then 29 | ln -s ${data_dir} ${corpus_name} 30 | fi 31 | 32 | if [ ! -d ${train_data_dir} ]; then 33 | cpfs_cur_dir=${cur_dir/nas/cpfs} 34 | cpfs_data_dir=${cpfs_cur_dir}/${data_dir} 35 | mkdir -p ${cpfs_data_dir}/${train_data_dir} 36 | ln -s ${cpfs_data_dir}/${train_data_dir} ${train_data_dir} 37 | fi 38 | 39 | ############################################# 40 | # generate list of clean/noise for training, dir of synthesic noisy-clean pair for val 41 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 42 | # gen lst 43 | /usr/bin/python3 -m speech_enhance.tools.gen_lst --dataset_dir ${corpus_name}/DNS-Challenge/DNS-Challenge-master/datasets_16k/clean/ --output_lst ${train_data_dir}/clean.txt 44 | /usr/bin/python3 -m speech_enhance.tools.gen_lst --dataset_dir ${corpus_name}/DNS-Challenge/DNS-Challenge-master/datasets_16k/noise/ --output_lst ${train_data_dir}/noise.txt 45 | cat ${corpus_name}/rir/simulated_rirs_16k/*/rir_list ${corpus_name}/rir/RIRS_NOISES/real_rirs_isotropic_noises/rir_list | awk -F ' ' '{print "data/rir/"$5}' > ${train_data_dir}/rir.txt 46 | perl -pi -e 's#data/rir#data/rir_16k#g' ${train_data_dir}/rir.txt 47 | # just use the book wav as interspeech2020 48 | grep book train_data_fsn_dns_master/clean.txt > train_data_fsn_dns_master/clean_book.txt 49 | # 50 | test_set=${corpus_name}/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set 51 | if [ ! 
-d ${test_set} ]; then 52 | echo "please prepare the ${test_set} from https://github.com/microsoft/DNS-Challenge/tree/interspeech2020/master/datasets/test_set" 53 | exit 54 | fi 55 | fi 56 | 57 | # train 58 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 59 | CUDA_VISIBLE_DEVICES='0,1' python -m speech_enhance.tools.train -C config/train.toml -N 2 60 | fi 61 | 62 | # test 63 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 64 | # input 24k audio 65 | input_dir="/workspace/project-nas-10691-sh/durian/egs/m2voc/s13/logs-s13/acoustic/analysis/gt" 66 | output_dir="/workspace/project-nas-10691-sh/durian/egs/m2voc/s13/logs-s13/acoustic/analysis/gt_24k_enhance_s1_24k" 67 | 68 | input_dir="/workspace/project-nas-10691-sh/durian/egs/m2voc/s10_tst_tsv_male_global_pitch_only1/logs-s10_tst_tsv_male_global_pitch_only1/eval-992914" 69 | output_dir="/workspace/project-nas-10691-sh/durian/egs/m2voc/s10_tst_tsv_male_global_pitch_only1/logs-s10_tst_tsv_male_global_pitch_only1/eval-992914-enhance_dns_master_s1_24k" 70 | 71 | input_dir="data/test_cases_didi/eval-992914/" 72 | output_dir="logs/eval/test_cases_didi/eval-992914/" 73 | 74 | # do enhance(denoise) 75 | #CUDA_VISIBLE_DEVICES=0 \ 76 | python -m speech_enhance.tools.inference \ 77 | -C config/inference.toml \ 78 | -M logs/FullSubNet/train/checkpoints/best_model.tar \ 79 | -I ${input_dir} \ 80 | -O ${output_dir} 81 | #-O logs/FullSubNet/inference/ 82 | 83 | # norm volume to -6db 84 | for f in `ls ${output_dir}/*/*.wav | grep -v norm`; do 85 | echo $f 86 | fid=`basename $f .wav` 87 | fo=`dirname $f`/${fid}_norm-6db.wav 88 | echo $fo 89 | sox $f -b16 $fo rate -v -b 99.7 24k norm -6 90 | done 91 | fi 92 | -------------------------------------------------------------------------------- /speech_enhance/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/__init__.py -------------------------------------------------------------------------------- /speech_enhance/audio_zen/acoustics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/acoustics/__init__.py -------------------------------------------------------------------------------- /speech_enhance/audio_zen/acoustics/beamforming.py: -------------------------------------------------------------------------------- 1 | from torch_complex import ComplexTensor 2 | from torch_complex import functional as FC 3 | 4 | 5 | def apply_crf_filter(cRM_filter: ComplexTensor, mix: ComplexTensor) -> ComplexTensor: 6 | """ 7 | Apply complex Ratio Filter 8 | 9 | Args: 10 | cRM_filter: complex Ratio Filter 11 | mix: mixture 12 | 13 | Returns: 14 | [B, C, F, T] 15 | """ 16 | # [B, F, T, Filter_delay] x [B, C, F, Filter_delay,T] => [B, C, F, T] 17 | es = FC.einsum("bftd, bcfdt -> bcft", [cRM_filter.conj(), mix]) 18 | return es 19 | 20 | 21 | def get_power_spectral_density_matrix(complex_tensor: ComplexTensor) -> ComplexTensor: 22 | """ 23 | Cross-channel power spectral density (PSD) matrix 24 | 25 | Args: 26 | complex_tensor: [..., F, C, T] 27 | 28 | Returns 29 | psd: [..., 
F, C, C] 30 | """ 31 | # outer product: [..., C_1, T] x [..., C_2, T] => [..., T, C_1, C_2] 32 | return FC.einsum("...ct,...et->...tce", [complex_tensor, complex_tensor.conj()]) 33 | 34 | 35 | def apply_beamforming_vector(beamforming_vector: ComplexTensor, mix: ComplexTensor) -> ComplexTensor: 36 | # [..., C] x [..., C, T] => [..., T] 37 | # There's no relationship between frequencies. 38 | es = FC.einsum("bftc, bfct -> bft", [beamforming_vector.conj(), mix]) 39 | return es 40 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/acoustics/mask.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import librosa 4 | import numpy as np 5 | import torch 6 | 7 | from audio_zen.constant import EPSILON 8 | 9 | 10 | def build_ideal_ratio_mask(noisy_mag, clean_mag) -> torch.Tensor: 11 | """ 12 | 13 | Args: 14 | noisy_mag: [B, F, T], noisy magnitude 15 | clean_mag: [B, F, T], clean magnitude 16 | 17 | Returns: 18 | [B, F, T, 1] 19 | """ 20 | # noisy_mag_finetune = torch.sqrt(torch.square(noisy_mag) + EPSILON) 21 | # ratio_mask = clean_mag / noisy_mag_finetune 22 | ratio_mask = clean_mag / (noisy_mag + EPSILON) 23 | ratio_mask = ratio_mask[..., None] 24 | return compress_cIRM(ratio_mask, K=10, C=0.1) 25 | 26 | 27 | def build_complex_ideal_ratio_mask(noisy: torch.complex64, clean: torch.complex64) -> torch.Tensor: 28 | """ 29 | 30 | Args: 31 | noisy: [B, F, T], noisy complex-valued stft coefficients 32 | clean: [B, F, T], clean complex-valued stft coefficients 33 | 34 | Returns: 35 | [B, F, T, 2] 36 | """ 37 | denominator = torch.square(noisy.real) + torch.square(noisy.imag) + EPSILON 38 | 39 | mask_real = (noisy.real * clean.real + noisy.imag * clean.imag) / denominator 40 | mask_imag = (noisy.real * clean.imag - noisy.imag * clean.real) / denominator 41 | 42 | complex_ratio_mask = torch.stack((mask_real, mask_imag), dim=-1) 43 | 44 | return compress_cIRM(complex_ratio_mask, K=10, C=0.1) 45 | 46 | 47 | def compress_cIRM(mask, K=10, C=0.1): 48 | """ 49 | Compress from (-inf, +inf) to [-K ~ K] 50 | """ 51 | if torch.is_tensor(mask): 52 | mask = -100 * (mask <= -100) + mask * (mask > -100) 53 | mask = K * (1 - torch.exp(-C * mask)) / (1 + torch.exp(-C * mask)) 54 | else: 55 | mask = -100 * (mask <= -100) + mask * (mask > -100) 56 | mask = K * (1 - np.exp(-C * mask)) / (1 + np.exp(-C * mask)) 57 | return mask 58 | 59 | 60 | def decompress_cIRM(mask, K=10, limit=9.9): 61 | mask = limit * (mask >= limit) - limit * (mask <= -limit) + mask * (torch.abs(mask) < limit) 62 | mask = -K * torch.log((K - mask) / (K + mask)) 63 | return mask 64 | 65 | 66 | def complex_mul(noisy_r, noisy_i, mask_r, mask_i): 67 | r = noisy_r * mask_r - noisy_i * mask_i 68 | i = noisy_r * mask_i + noisy_i * mask_r 69 | return r, i 70 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/acoustics/utils.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | 3 | 4 | def transform_pesq_range(pesq_score): 5 | """ 6 | transform PESQ metric range from [-0.5 ~ 4.5] to [0 ~ 1] 7 | """ 8 | return (pesq_score + 0.5) / 5 9 | 10 | 11 | def load_wav(path, sr=16000): 12 | return librosa.load(path, sr=sr)[0] -------------------------------------------------------------------------------- /speech_enhance/audio_zen/constant.py: -------------------------------------------------------------------------------- 1 | import math 2 | import 
numpy as np 3 | import torch 4 | 5 | NEG_INF = torch.finfo(torch.float32).min 6 | PI = math.pi 7 | SOUND_SPEED = 343 # m/s 8 | EPSILON = np.finfo(np.float32).eps 9 | MAX_INT16 = np.iinfo(np.int16).max 10 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/dataset/__init__.py -------------------------------------------------------------------------------- /speech_enhance/audio_zen/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data 2 | 3 | 4 | class BaseDataset(data.Dataset): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | @staticmethod 9 | def _offset_and_limit(dataset_list, offset, limit): 10 | dataset_list = dataset_list[offset:] 11 | if limit: 12 | dataset_list = dataset_list[:limit] 13 | return dataset_list 14 | 15 | @staticmethod 16 | def _parse_snr_range(snr_range): 17 | assert len(snr_range) == 2, f"The range of SNR should be [low, high], not {snr_range}." 18 | assert snr_range[0] <= snr_range[-1], f"The low SNR should not larger than high SNR." 19 | 20 | low, high = snr_range 21 | snr_list = [] 22 | for i in range(low, high + 1, 1): 23 | snr_list.append(i) 24 | 25 | return snr_list 26 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/inferencer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/inferencer/__init__.py -------------------------------------------------------------------------------- /speech_enhance/audio_zen/inferencer/base_inferencer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | from pathlib import Path 4 | 5 | import librosa 6 | import numpy as np 7 | import soundfile as sf 8 | import toml 9 | import torch 10 | from torch.nn import functional 11 | from torch.utils.data import DataLoader 12 | from tqdm import tqdm 13 | import time 14 | 15 | from audio_zen.acoustics.feature import stft, istft, mc_stft 16 | from audio_zen.utils import initialize_module, prepare_device, prepare_empty_dir 17 | 18 | # for log 19 | from utils.logger import log 20 | print=log 21 | 22 | class BaseInferencer: 23 | def __init__(self, config, checkpoint_path, output_dir): 24 | checkpoint_path = Path(checkpoint_path).expanduser().absolute() 25 | root_dir = Path(output_dir).expanduser().absolute() 26 | self.device = prepare_device(torch.cuda.device_count()) 27 | 28 | print("Loading inference dataset...") 29 | self.dataloader = self._load_dataloader(config["dataset"]) 30 | print("Loading model...") 31 | 32 | self.model, epoch = self._load_model(config["model"], checkpoint_path, self.device) 33 | # self.model = self._load_model(config["model"], checkpoint_path, self.device) 34 | # epoch = 64 35 | 36 | self.inference_config = config["inferencer"] 37 | 38 | self.enhanced_dir = root_dir / f"enhanced_{str(epoch).zfill(4)}" 39 | prepare_empty_dir([self.enhanced_dir]) 40 | 41 | # Acoustics 42 | self.acoustic_config = config["acoustics"] 43 | 44 | # Supported STFT 45 | self.n_fft = self.acoustic_config["n_fft"] 46 | 
self.hop_length = self.acoustic_config["hop_length"] 47 | self.win_length = self.acoustic_config["win_length"] 48 | self.sr = self.acoustic_config["sr"] 49 | 50 | # See utils_backup.py 51 | self.torch_stft = partial(stft, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length) 52 | self.torch_istft = partial(istft, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length) 53 | self.torch_mc_stft = partial(mc_stft, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length) 54 | self.librosa_stft = partial(librosa.stft, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length) 55 | self.librosa_istft = partial(librosa.istft, hop_length=self.hop_length, win_length=self.win_length) 56 | 57 | print("Configurations are as follows: ") 58 | print(toml.dumps(config)) 59 | with open((root_dir / f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(), "w") as handle: 60 | toml.dump(config, handle) 61 | 62 | @staticmethod 63 | def _load_dataloader(dataset_config): 64 | dataset = initialize_module(dataset_config["path"], args=dataset_config["args"], initialize=True) 65 | dataloader = DataLoader( 66 | dataset=dataset, 67 | batch_size=1, 68 | num_workers=0, 69 | ) 70 | return dataloader 71 | 72 | @staticmethod 73 | def _unfold(input, pad_mode, n_neighbor): 74 | """ 75 | 沿着频率轴,将语谱图划分为多个 overlap 的子频带 76 | 77 | Args: 78 | input: [B, C, F, T] 79 | 80 | Returns: 81 | [B, N, C, F, T], F 为子频带的频率轴大小, e.g. [2, 161, 1, 19, 200] 82 | """ 83 | assert input.dim() == 4, f"The dim of input is {input.dim()}, which should be 4." 84 | batch_size, n_channels, n_freqs, n_frames = input.size() 85 | output = input.reshape(batch_size * n_channels, 1, n_freqs, n_frames) 86 | sub_band_n_freqs = n_neighbor * 2 + 1 87 | 88 | output = functional.pad(output, [0, 0, n_neighbor, n_neighbor], mode=pad_mode) 89 | output = functional.unfold(output, (sub_band_n_freqs, n_frames)) 90 | assert output.shape[-1] == n_freqs, f"n_freqs != N (sub_band), {n_freqs} != {output.shape[-1]}" 91 | 92 | # 拆分 unfold 中间的维度 93 | output = output.reshape(batch_size, n_channels, sub_band_n_freqs, n_frames, n_freqs) 94 | output = output.permute(0, 4, 1, 2, 3).contiguous() # permute 本质上与 reshape 可是不同的 ...,得到的维度相同,但 entity 不同啊 95 | return output 96 | 97 | @staticmethod 98 | def _load_model(model_config, checkpoint_path, device): 99 | model = initialize_module(model_config["path"], args=model_config["args"], initialize=True) 100 | model_checkpoint = torch.load(checkpoint_path, map_location="cpu") 101 | 102 | # model_static_dict = model_checkpoint 103 | model_static_dict = model_checkpoint["model"] 104 | epoch = model_checkpoint["epoch"] 105 | print(f"当前正在处理 tar 格式的模型断点,其 epoch 为:{epoch}.") 106 | 107 | model.load_state_dict(model_static_dict) 108 | model.to(device) 109 | model.eval() 110 | return model, model_checkpoint["epoch"] 111 | # return model 112 | 113 | @torch.no_grad() 114 | def multi_channel_mag_to_mag(self, noisy, inference_args=None): 115 | """ 116 | 模型的输入为带噪语音的 **幅度谱**,输出同样为 **幅度谱** 117 | """ 118 | mixture_stft_coefficients = self.torch_mc_stft(noisy) 119 | mixture_mag = (mixture_stft_coefficients.real ** 2 + mixture_stft_coefficients.imag ** 2) ** 0.5 120 | 121 | enhanced_mag = self.model(mixture_mag) 122 | 123 | # Phase of the reference channel 124 | reference_channel_stft_coefficients = mixture_stft_coefficients[:, 0, ...] 
125 | noisy_phase = torch.atan2(reference_channel_stft_coefficients.imag, reference_channel_stft_coefficients.real) 126 | complex_tensor = torch.stack([(enhanced_mag * torch.cos(noisy_phase)), (enhanced_mag * torch.sin(noisy_phase))], dim=-1) 127 | enhanced = self.torch_istft(complex_tensor, length=noisy.shape[-1]) 128 | 129 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 130 | 131 | return enhanced 132 | 133 | @torch.no_grad() 134 | def __call__(self): 135 | inference_type = self.inference_config["type"] 136 | assert inference_type in dir(self), f"Not implemented Inferencer type: {inference_type}" 137 | 138 | inference_args = self.inference_config["args"] 139 | 140 | for noisy, name in tqdm(self.dataloader, desc="Inference"): 141 | assert len(name) == 1, "The batch size of inference stage must 1." 142 | name = name[0] 143 | 144 | t1 = time.time() 145 | enhanced = getattr(self, inference_type)(noisy.to(self.device), inference_args) 146 | t2 = time.time() 147 | 148 | if (abs(enhanced) > 1).any(): 149 | print(f"Warning: enhanced is not in the range [-1, 1], {name}") 150 | 151 | amp = np.iinfo(np.int16).max 152 | enhanced = np.int16(0.8 * amp * enhanced / np.max(np.abs(enhanced))) 153 | 154 | # cal rtf 155 | rtf = (t2 - t1) / (len(enhanced) * 1.0 / self.acoustic_config["sr"]) 156 | print(f"{name}, rtf: {rtf}") 157 | 158 | # clnsp102_traffic_248091_3_snr0_tl-21_fileid_268 => clean_fileid_0 159 | # name = "clean_" + "_".join(name.split("_")[-2:]) 160 | sf.write(self.enhanced_dir / f"{name}.wav", enhanced, samplerate=self.acoustic_config["sr"]) 161 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | l1_loss = torch.nn.L1Loss 4 | mse_loss = torch.nn.MSELoss 5 | 6 | 7 | def si_snr_loss(): 8 | def si_snr(x, s, eps=1e-8): 9 | """ 10 | 11 | Args: 12 | x: Enhanced fo shape [B, T] 13 | s: Reference of shape [B, T] 14 | eps: 15 | 16 | Returns: 17 | si_snr: [B] 18 | """ 19 | def l2norm(mat, keep_dim=False): 20 | return torch.norm(mat, dim=-1, keepdim=keep_dim) 21 | 22 | if x.shape != s.shape: 23 | raise RuntimeError(f"Dimension mismatch when calculate si_snr, {x.shape} vs {s.shape}") 24 | 25 | x_zm = x - torch.mean(x, dim=-1, keepdim=True) 26 | s_zm = s - torch.mean(s, dim=-1, keepdim=True) 27 | 28 | t = torch.sum(x_zm * s_zm, dim=-1, keepdim=True) * s_zm / (l2norm(s_zm, keep_dim=True) ** 2 + eps) 29 | 30 | return -torch.mean(20 * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps))) 31 | 32 | return si_snr 33 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mir_eval.separation import bss_eval_sources 3 | from pesq import pesq 4 | from pypesq import pesq as nb_pesq 5 | from pystoi.stoi import stoi 6 | import librosa 7 | 8 | def _scale_bss_eval(references, estimate, idx, compute_sir_sar=True): 9 | """ 10 | Helper for scale_bss_eval to avoid infinite recursion loop. 
11 | """ 12 | source = references[..., idx] 13 | source_energy = (source ** 2).sum() 14 | 15 | alpha = ( 16 | source @ estimate / source_energy 17 | ) 18 | 19 | e_true = source 20 | e_res = estimate - e_true 21 | 22 | signal = (e_true ** 2).sum() 23 | noise = (e_res ** 2).sum() 24 | 25 | snr = 10 * np.log10(signal / noise) 26 | 27 | e_true = source * alpha 28 | e_res = estimate - e_true 29 | 30 | signal = (e_true ** 2).sum() 31 | noise = (e_res ** 2).sum() 32 | 33 | si_sdr = 10 * np.log10(signal / noise) 34 | 35 | srr = -10 * np.log10((1 - (1 / alpha)) ** 2) 36 | sd_sdr = snr + 10 * np.log10(alpha ** 2) 37 | 38 | si_sir = np.nan 39 | si_sar = np.nan 40 | 41 | if compute_sir_sar: 42 | references_projection = references.T @ references 43 | 44 | references_onto_residual = np.dot(references.transpose(), e_res) 45 | b = np.linalg.solve(references_projection, references_onto_residual) 46 | 47 | e_interf = np.dot(references, b) 48 | e_artif = e_res - e_interf 49 | 50 | si_sir = 10 * np.log10(signal / (e_interf ** 2).sum()) 51 | si_sar = 10 * np.log10(signal / (e_artif ** 2).sum()) 52 | 53 | return si_sdr, si_sir, si_sar, sd_sdr, snr, srr 54 | 55 | 56 | def SDR(reference, estimation, sr=16000): 57 | sdr, _, _, _ = bss_eval_sources(reference[None, :], estimation[None, :]) 58 | return sdr 59 | 60 | 61 | def SI_SDR(reference, estimation, sr=16000): 62 | """ 63 | Scale-Invariant Signal-to-Distortion Ratio (SI-SDR) 64 | 65 | Args: 66 | reference: numpy.ndarray, [..., T] 67 | estimation: numpy.ndarray, [..., T] 68 | 69 | Returns: 70 | SI-SDR 71 | 72 | References 73 | SDR– Half- Baked or Well Done? (http://www.merl.com/publications/docs/TR2019-013.pdf) 74 | """ 75 | estimation, reference = np.broadcast_arrays(estimation, reference) 76 | reference_energy = np.sum(reference ** 2, axis=-1, keepdims=True) 77 | 78 | optimal_scaling = np.sum(reference * estimation, axis=-1, keepdims=True) / reference_energy 79 | 80 | projection = optimal_scaling * reference 81 | 82 | noise = estimation - projection 83 | 84 | ratio = np.sum(projection ** 2, axis=-1) / np.sum(noise ** 2, axis=-1) 85 | return 10 * np.log10(ratio) 86 | 87 | 88 | def STOI(ref, est, sr=16000): 89 | return stoi(ref, est, sr, extended=False) 90 | 91 | 92 | def WB_PESQ(ref, est, sr=16000): 93 | if sr != 16000: 94 | wb_ref = librosa.resample(ref, sr, 16000) 95 | wb_est = librosa.resample(est, sr, 16000) 96 | else: 97 | wb_ref = ref 98 | wb_est = est 99 | # pesq will not downsample internally 100 | return pesq(16000, wb_ref, wb_est, "wb") 101 | 102 | 103 | def NB_PESQ(ref, est, sr=16000): 104 | if sr != 8000: 105 | nb_ref = librosa.resample(ref, sr, 8000) 106 | nb_est = librosa.resample(est, sr, 8000) 107 | else: 108 | nb_ref = ref 109 | nb_est = est 110 | # nb_pesq downsample to 8000 internally. 111 | return nb_pesq(nb_ref, nb_est, 8000) 112 | 113 | mos_metrics = None 114 | def MOSNET(ref, est, sr=16000): 115 | ## 116 | global mos_metrics 117 | if mos_metrics is None: 118 | import speechmetrics 119 | window_length = 10 # seconds 120 | mos_metrics = speechmetrics.load('mosnet', window_length) 121 | ## 122 | scores = mos_metrics(est, rate=sr) 123 | avg_score = np.mean(scores["mosnet"]) 124 | #print(avg_score) 125 | return avg_score 126 | 127 | # Only registered metric can be used. 
128 | REGISTERED_METRICS = { 129 | "SI_SDR": SI_SDR, 130 | "STOI": STOI, 131 | "WB_PESQ": WB_PESQ, 132 | "NB_PESQ": NB_PESQ, 133 | "MOSNET": MOSNET 134 | } 135 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/model/__init__.py -------------------------------------------------------------------------------- /speech_enhance/audio_zen/model/module/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/model/module/__init__.py -------------------------------------------------------------------------------- /speech_enhance/audio_zen/model/module/causal_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CausalConvBlock(nn.Module): 6 | def __init__(self, in_channels, out_channels, encoder_activate_function, **kwargs): 7 | super().__init__() 8 | self.conv = nn.Conv2d( 9 | in_channels=in_channels, 10 | out_channels=out_channels, 11 | kernel_size=(3, 2), 12 | stride=(2, 1), 13 | padding=(0, 1), 14 | **kwargs # 这里不是左右 pad,而是上下 pad 为 0,左右分别 pad 1... 15 | ) 16 | self.norm = nn.BatchNorm2d(out_channels) 17 | self.activation = getattr(nn, encoder_activate_function)() 18 | 19 | def forward(self, x): 20 | """ 21 | 2D Causal convolution. 22 | 23 | Args: 24 | x: [B, C, F, T] 25 | Returns: 26 | [B, C, F, T] 27 | """ 28 | x = self.conv(x) 29 | x = x[:, :, :, :-1] # chomp size 30 | x = self.norm(x) 31 | x = self.activation(x) 32 | return x 33 | 34 | 35 | class CausalTransConvBlock(nn.Module): 36 | def __init__(self, in_channels, out_channels, is_last=False, output_padding=(0, 0)): 37 | super().__init__() 38 | self.conv = nn.ConvTranspose2d( 39 | in_channels=in_channels, 40 | out_channels=out_channels, 41 | kernel_size=(3, 2), 42 | stride=(2, 1), 43 | output_padding=output_padding 44 | ) 45 | self.norm = nn.BatchNorm2d(out_channels) 46 | if is_last: 47 | self.activation = nn.ReLU() 48 | else: 49 | self.activation = nn.ELU() 50 | 51 | def forward(self, x): 52 | """ 53 | 2D Causal convolution. 
54 | 55 | Args: 56 | x: [B, C, F, T] 57 | Returns: 58 | [B, C, F, T] 59 | """ 60 | x = self.conv(x) 61 | x = x[:, :, :, :-1] # chomp size 62 | x = self.norm(x) 63 | x = self.activation(x) 64 | return x 65 | 66 | 67 | class TCNBlock(nn.Module): 68 | def __init__(self, in_channels=257, hidden_channel=512, out_channels=257, kernel_size=3, dilation=1, 69 | use_skip_connection=True, causal=False): 70 | super().__init__() 71 | self.conv1x1 = nn.Conv1d(in_channels, hidden_channel, 1) 72 | self.prelu1 = nn.PReLU() 73 | self.norm1 = nn.GroupNorm(1, hidden_channel, eps=1e-8) 74 | padding = (dilation * (kernel_size - 1)) // 2 if not causal else ( 75 | dilation * (kernel_size - 1)) 76 | self.depthwise_conv = nn.Conv1d(hidden_channel, hidden_channel, kernel_size=kernel_size, stride=1, 77 | groups=hidden_channel, padding=padding, dilation=dilation) 78 | self.prelu2 = nn.PReLU() 79 | self.norm2 = nn.GroupNorm(1, hidden_channel, eps=1e-8) 80 | self.sconv = nn.Conv1d(hidden_channel, out_channels, 1) 81 | # self.tcn_block = nn.Sequential( 82 | # nn.Conv1d(in_channels, hidden_channel, 1), 83 | # nn.PReLU(), 84 | # nn.GroupNorm(1, hidden_channel, eps=1e-8), 85 | # nn.Conv1d(hidden_channel, hidden_channel, kernel_size=kernel_size, stride=1, 86 | # groups=hidden_channel, padding=padding, dilation=dilation, bias=True), 87 | # nn.PReLU(), 88 | # nn.GroupNorm(1, hidden_channel, eps=1e-8), 89 | # nn.Conv1d(hidden_channel, out_channels, 1) 90 | # ) 91 | 92 | self.causal = causal 93 | self.padding = padding 94 | self.use_skip_connection = use_skip_connection 95 | 96 | def forward(self, x): 97 | """ 98 | x: [channels, T] 99 | """ 100 | if self.use_skip_connection: 101 | y = self.conv1x1(x) 102 | y = self.norm1(self.prelu1(y)) 103 | y = self.depthwise_conv(y) 104 | if self.causal: 105 | y = y[:, :, :-self.padding] 106 | y = self.norm2(self.prelu2(y)) 107 | output = self.sconv(y) 108 | return x + output 109 | else: 110 | y = self.conv1x1(x) 111 | y = self.norm1(self.prelu1(y)) 112 | y = self.depthwise_conv(y) 113 | if self.causal: 114 | y = y[:, :, :-self.padding] 115 | y = self.norm2(self.prelu2(y)) 116 | output = self.sconv(y) 117 | return output 118 | 119 | 120 | class STCNBlock(nn.Module): 121 | def __init__(self, in_channels=257, hidden_channel=512, out_channels=257, kernel_size=3, dilation=1, 122 | use_skip_connection=True, causal=False): 123 | super().__init__() 124 | self.conv1x1 = nn.Conv1d(in_channels, hidden_channel, 1) 125 | self.prelu1 = nn.PReLU() 126 | self.norm1 = nn.GroupNorm(1, hidden_channel, eps=1e-8) 127 | padding = (dilation * (kernel_size - 1)) // 2 if not causal else ( 128 | dilation * (kernel_size - 1)) 129 | self.depthwise_conv = nn.Conv1d(hidden_channel, hidden_channel, kernel_size=kernel_size, stride=1, 130 | groups=hidden_channel, padding=padding, dilation=dilation) 131 | self.prelu2 = nn.PReLU() 132 | self.norm2 = nn.GroupNorm(1, hidden_channel, eps=1e-8) 133 | self.sconv = nn.Conv1d(hidden_channel, out_channels, 1) 134 | # self.tcn_block = nn.Sequential( 135 | # nn.Conv1d(in_channels, hidden_channel, 1), 136 | # nn.PReLU(), 137 | # nn.GroupNorm(1, hidden_channel, eps=1e-8), 138 | # nn.Conv1d(hidden_channel, hidden_channel, kernel_size=kernel_size, stride=1, 139 | # groups=hidden_channel, padding=padding, dilation=dilation, bias=True), 140 | # nn.PReLU(), 141 | # nn.GroupNorm(1, hidden_channel, eps=1e-8), 142 | # nn.Conv1d(hidden_channel, out_channels, 1) 143 | # ) 144 | 145 | self.causal = causal 146 | self.padding = padding 147 | self.use_skip_connection = use_skip_connection 148 | 149 | 
def forward(self, x): 150 | """ 151 | x: [channels, T] 152 | """ 153 | if self.use_skip_connection: 154 | y = self.conv1x1(x) 155 | y = self.norm1(self.prelu1(y)) 156 | y = self.depthwise_conv(y) 157 | if self.causal: 158 | y = y[:, :, :-self.padding] 159 | y = self.norm2(self.prelu2(y)) 160 | output = self.sconv(y) 161 | return x + output 162 | else: 163 | y = self.conv1x1(x) 164 | y = self.norm1(self.prelu1(y)) 165 | y = self.depthwise_conv(y) 166 | if self.causal: 167 | y = y[:, :, :-self.padding] 168 | y = self.norm2(self.prelu2(y)) 169 | output = self.sconv(y) 170 | return output 171 | 172 | 173 | if __name__ == '__main__': 174 | a = torch.rand(2, 1, 19, 200) 175 | l1 = CausalConvBlock(1, 20, kernel_size=(3, 2), stride=(2, 1), padding=(0, 1), ) 176 | l2 = CausalConvBlock(20, 40, kernel_size=(3, 2), stride=(1, 1), padding=1, ) 177 | l3 = CausalConvBlock(40, 40, kernel_size=(3, 2), stride=(2, 1), padding=(0, 1), ) 178 | l4 = CausalConvBlock(40, 40, kernel_size=(3, 2), stride=(1, 1), padding=1, ) 179 | print(l1(a).shape) 180 | print(l4(l3(l2(l1(a)))).shape) 181 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/model/module/feature_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def cumulative_norm(input): 6 | eps = 1e-10 7 | 8 | # [B, C, F, T] 9 | batch_size, n_channels, n_freqs, n_frames = input.size() 10 | device = input.device 11 | data_type = input.dtype 12 | 13 | input = input.reshape(batch_size * n_channels, n_freqs, n_frames) 14 | 15 | step_sum = torch.sum(input, dim=1) # [B, T] 16 | step_pow_sum = torch.sum(torch.square(input), dim=1) 17 | 18 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T] 19 | cumulative_pow_sum = torch.cumsum(step_pow_sum, dim=-1) # [B, T] 20 | 21 | entry_count = torch.arange(n_freqs, n_freqs * n_frames + 1, n_freqs, dtype=data_type, device=device) 22 | entry_count = entry_count.reshape(1, n_frames) # [1, T] 23 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T] 24 | 25 | cum_mean = cumulative_sum / entry_count # B, T 26 | cum_var = (cumulative_pow_sum - 2 * cum_mean * cumulative_sum) / entry_count + cum_mean.pow(2) # B, T 27 | cum_std = (cum_var + eps).sqrt() # B, T 28 | 29 | cum_mean = cum_mean.reshape(batch_size * n_channels, 1, n_frames) 30 | cum_std = cum_std.reshape(batch_size * n_channels, 1, n_frames) 31 | 32 | x = (input - cum_mean) / cum_std 33 | x = x.reshape(batch_size, n_channels, n_freqs, n_frames) 34 | 35 | return x 36 | 37 | 38 | class CumulativeMagSpectralNorm(nn.Module): 39 | def __init__(self, cumulative=False, use_mid_freq_mu=False): 40 | """ 41 | 42 | Args: 43 | cumulative: 是否采用累积的方式计算 mu 44 | use_mid_freq_mu: 仅采用中心频率的 mu 来代替全局 mu 45 | 46 | Notes: 47 | 先算均值再累加 等同于 先累加再算均值 48 | 49 | """ 50 | super().__init__() 51 | self.eps = 1e-6 52 | self.cumulative = cumulative 53 | self.use_mid_freq_mu = use_mid_freq_mu 54 | 55 | def forward(self, input): 56 | assert input.ndim == 4, f"{self.__name__} only support 4D input." 
57 | batch_size, n_channels, n_freqs, n_frames = input.size() 58 | device = input.device 59 | data_type = input.dtype 60 | 61 | input = input.reshape(batch_size * n_channels, n_freqs, n_frames) 62 | 63 | if self.use_mid_freq_mu: 64 | step_sum = input[:, int(n_freqs // 2 - 1), :] # [B * C, F, T] => [B * C, T] 65 | else: 66 | step_sum = torch.mean(input, dim=1) # [B * C, F, T] => [B * C, T] 67 | 68 | if self.cumulative: 69 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T] 70 | entry_count = torch.arange(1, n_frames + 1, dtype=data_type, device=device) 71 | entry_count = entry_count.reshape(1, n_frames) # [1, T] 72 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T] 73 | 74 | mu = cumulative_sum / entry_count # [B * C, T] 75 | mu = mu.reshape(batch_size * n_channels, 1, n_frames) 76 | else: 77 | mu = torch.mean(step_sum, dim=-1) # [B * C] 78 | mu = mu.reshape(batch_size * n_channels, 1, 1) # [B * C, 1, 1] 79 | 80 | input_normed = input / (mu + self.eps) 81 | input_normed = input_normed.reshape(batch_size, n_channels, n_freqs, n_frames) 82 | return input_normed 83 | 84 | 85 | if __name__ == '__main__': 86 | a = torch.rand(2, 1, 160, 200) 87 | ln = CumulativeMagSpectralNorm(cumulative=False, use_mid_freq_mu=False) 88 | ln_1 = CumulativeMagSpectralNorm(cumulative=True, use_mid_freq_mu=False) 89 | ln_2 = CumulativeMagSpectralNorm(cumulative=False, use_mid_freq_mu=False) 90 | ln_3 = CumulativeMagSpectralNorm(cumulative=True, use_mid_freq_mu=False) 91 | print(ln(a).mean()) 92 | print(ln_1(a).mean()) 93 | print(ln_2(a).mean()) 94 | print(ln_3(a).mean()) 95 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/model/module/si_module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class subband_interaction(nn.Module): 10 | def __init__(self, input_size, hidden_size): 11 | super(subband_interaction, self).__init__() 12 | """ 13 | Subband Interaction Module 14 | """ 15 | 16 | self.input_linear = nn.Sequential( 17 | nn.Linear(input_size, hidden_size), 18 | nn.PReLU() 19 | ) 20 | self.mean_linear = nn.Sequential( 21 | nn.Linear(hidden_size, hidden_size), 22 | nn.PReLU() 23 | ) 24 | self.output_linear = nn.Sequential( 25 | nn.Linear(hidden_size * 2, input_size), 26 | nn.PReLU() 27 | ) 28 | self.norm = nn.GroupNorm(1, input_size) 29 | 30 | def forward(self, input): 31 | """ 32 | input: [B, F, F_s, T] 33 | """ 34 | B, G, N, T = input.shape 35 | 36 | # Transform 37 | group_input = input # [B, F, F_s, T] 38 | group_input = group_input.permute(0, 3, 1, 2).contiguous().view(-1, N) # [B * T * F, F_s] 39 | group_output = self.input_linear(group_input).view(B, T, G, -1) # [B, T, F, H] 40 | 41 | # Avg pooling 42 | group_mean = group_output.mean(2).view(B * T, -1) # [B * T, H] 43 | 44 | # Concate and transform 45 | group_output = group_output.view(B * T, G, -1) # [B * T, F, H] 46 | group_mean = self.mean_linear(group_mean).unsqueeze(1).expand_as(group_output).contiguous() # [B * T, F, H] 47 | group_output = torch.cat([group_output, group_mean], 2) # [B * T, F, 2H] 48 | group_output = self.output_linear(group_output.view(-1, group_output.shape[-1])) # [B * T * F, F_s] 49 | group_output = group_output.view(B, T, G, -1).permute(0, 2, 3, 1).contiguous() # [B, F, F_s, T] 50 | group_output = self.norm(group_output.view(B * G, N, T)) # [B 
* F, F_s, T] 51 | output = input + group_output.view(input.shape) # [B, F, F_s, T] 52 | 53 | return output 54 | -------------------------------------------------------------------------------- /speech_enhance/audio_zen/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/trainer/__init__.py -------------------------------------------------------------------------------- /speech_enhance/audio_zen/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import time 4 | from copy import deepcopy 5 | from functools import reduce 6 | 7 | import torch 8 | 9 | # for log 10 | from utils.logger import log 11 | print=log 12 | 13 | def load_checkpoint(checkpoint_path, device): 14 | _, ext = os.path.splitext(os.path.basename(checkpoint_path)) 15 | assert ext in (".pth", ".tar"), "Only support ext and tar extensions of l1 checkpoint." 16 | model_checkpoint = torch.load(os.path.abspath(os.path.expanduser(checkpoint_path)), map_location=device) 17 | 18 | if ext == ".pth": 19 | print(f"Loading {checkpoint_path}.") 20 | return model_checkpoint 21 | else: # load tar 22 | print(f"Loading {checkpoint_path}, epoch = {model_checkpoint['epoch']}.") 23 | return model_checkpoint["l1"] 24 | 25 | 26 | def prepare_empty_dir(dirs, resume=False): 27 | """ 28 | if resume the experiment, assert the dirs exist. If not the resume experiment, set up new dirs. 29 | 30 | Args: 31 | dirs (list): directors list 32 | resume (bool): whether to resume experiment, default is False 33 | """ 34 | for dir_path in dirs: 35 | if resume: 36 | assert dir_path.exists(), "In resume mode, you must be have an old experiment dir." 37 | else: 38 | dir_path.mkdir(parents=True, exist_ok=True) 39 | 40 | 41 | def check_nan(tensor, key=""): 42 | if torch.sum(torch.isnan(tensor)) > 0: 43 | print(f"Found NaN in {key}") 44 | 45 | 46 | class ExecutionTime: 47 | """ 48 | Count execution time. 49 | 50 | Examples: 51 | timer = ExecutionTime() 52 | ... 53 | print(f"Finished in {timer.duration()} seconds.") 54 | """ 55 | 56 | def __init__(self): 57 | self.start_time = time.time() 58 | 59 | def duration(self): 60 | return int(time.time() - self.start_time) 61 | 62 | 63 | def initialize_module(path: str, args: dict = None, initialize: bool = True): 64 | """ 65 | Load module dynamically with "args". 66 | 67 | Args: 68 | path: module path in this project. 69 | args: parameters that passes to the Class or the Function in the module. 70 | initialize: initialize the Class or the Function with args. 71 | 72 | Examples: 73 | Config items are as follows: 74 | 75 | [model] 76 | path = "model.full_sub_net.FullSubNetModel" 77 | [model.args] 78 | n_frames = 32 79 | ... 80 | 81 | This function will: 82 | 1. Load the "model.full_sub_net" module. 83 | 2. Call "FullSubNetModel" Class (or Function) in "model.full_sub_net" module. 84 | 3. If initialize is True: 85 | instantiate (or call) the Class (or the Function) and pass the parameters (in "[model.args]") to it. 
86 | """ 87 | module_path = ".".join(path.split(".")[:-1]) 88 | class_or_function_name = path.split(".")[-1] 89 | 90 | module = importlib.import_module(module_path) 91 | class_or_function = getattr(module, class_or_function_name) 92 | 93 | if initialize: 94 | if args: 95 | return class_or_function(**args) 96 | else: 97 | return class_or_function() 98 | else: 99 | return class_or_function 100 | 101 | 102 | def print_tensor_info(tensor, flag="Tensor"): 103 | def floor_tensor(float_tensor): 104 | return int(float(float_tensor) * 1000) / 1000 105 | 106 | print( 107 | f"{flag}\n" 108 | f"\t" 109 | f"max: {floor_tensor(torch.max(tensor))}, min: {float(torch.min(tensor))}, " 110 | f"mean: {floor_tensor(torch.mean(tensor))}, std: {floor_tensor(torch.std(tensor))}") 111 | 112 | 113 | def set_requires_grad(nets, requires_grad=False): 114 | """ 115 | Args: 116 | nets: list of networks 117 | requires_grad 118 | """ 119 | if not isinstance(nets, list): 120 | nets = [nets] 121 | for net in nets: 122 | if net is not None: 123 | for param in net.parameters(): 124 | param.requires_grad = requires_grad 125 | 126 | 127 | def merge_config(*config_dicts): 128 | """ 129 | Deep merge configuration dicts. 130 | 131 | Args: 132 | *config_dicts: any number of configuration dicts. 133 | 134 | Notes: 135 | 1. The values of item in the later configuration dict(s) will update the ones in the former dict(s). 136 | 2. The key in the later dict must be exist in the former dict. It means that the first dict must consists of all keys. 137 | 138 | Examples: 139 | a = [ 140 | "a": 1, 141 | "b": 2, 142 | "c": { 143 | "d": 1 144 | } 145 | ] 146 | b = [ 147 | "a": 2, 148 | "b": 2, 149 | "c": { 150 | "e": 1 151 | } 152 | ] 153 | c = merge_config(a, b) 154 | c = [ 155 | "a": 2, 156 | "b": 2, 157 | "c": { 158 | "d": 1, 159 | "e": 1 160 | } 161 | ] 162 | 163 | Returns: 164 | New deep-copied configuration dict. 165 | """ 166 | 167 | def merge(older_dict, newer_dict): 168 | for new_key in newer_dict: 169 | if new_key not in older_dict: 170 | # Checks items in custom config must be within common config 171 | raise KeyError(f"Key {new_key} is not exist in the common config.") 172 | 173 | if isinstance(older_dict[new_key], dict): 174 | older_dict[new_key] = merge(older_dict[new_key], newer_dict[new_key]) 175 | else: 176 | older_dict[new_key] = deepcopy(newer_dict[new_key]) 177 | 178 | return older_dict 179 | 180 | return reduce(merge, config_dicts[1:], deepcopy(config_dicts[0])) 181 | 182 | 183 | def prepare_device(n_gpu: int, keep_reproducibility=False): 184 | """ 185 | Choose to use CPU or GPU depend on the value of "n_gpu". 186 | 187 | Args: 188 | n_gpu(int): the number of GPUs used in the experiment. if n_gpu == 0, use CPU; if n_gpu >= 1, use GPU. 189 | keep_reproducibility (bool): if we need to consider the repeatability of experiment, set keep_reproducibility to True. 
190 | 191 | See Also 192 | Reproducibility: https://pytorch.org/docs/stable/notes/randomness.html 193 | """ 194 | if n_gpu == 0: 195 | print("Using CPU in the experiment.") 196 | device = torch.device("cpu") 197 | else: 198 | # possibly at the cost of reduced performance 199 | if keep_reproducibility: 200 | print("Using CuDNN deterministic mode in the experiment.") 201 | torch.backends.cudnn.benchmark = False # ensures that CUDA selects the same convolution algorithm each time 202 | torch.set_deterministic(True) # configures PyTorch only to use deterministic implementation 203 | else: 204 | # causes cuDNN to benchmark multiple convolution algorithms and select the fastest 205 | torch.backends.cudnn.benchmark = True 206 | 207 | device = torch.device("cuda:0") 208 | 209 | return device 210 | 211 | 212 | def expand_path(path): 213 | return os.path.abspath(os.path.expanduser(path)) 214 | 215 | 216 | def basename(path): 217 | filename, ext = os.path.splitext(os.path.basename(path)) 218 | return filename, ext -------------------------------------------------------------------------------- /speech_enhance/fullsubnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet/__init__.py -------------------------------------------------------------------------------- /speech_enhance/fullsubnet/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet/dataset/__init__.py -------------------------------------------------------------------------------- /speech_enhance/fullsubnet/dataset/dataset_inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import librosa 4 | import numpy as np 5 | 6 | from audio_zen.dataset.base_dataset import BaseDataset 7 | from audio_zen.utils import basename 8 | 9 | 10 | class Dataset(BaseDataset): 11 | def __init__(self, 12 | dataset_dir_list, 13 | sr, 14 | ): 15 | """ 16 | Args: 17 | noisy_dataset_dir_list (str or list): noisy dir or noisy dir list 18 | """ 19 | super().__init__() 20 | assert isinstance(dataset_dir_list, list) 21 | self.sr = sr 22 | 23 | noisy_file_path_list = [] 24 | for dataset_dir in dataset_dir_list: 25 | dataset_dir = Path(dataset_dir).expanduser().absolute() 26 | noisy_file_path_list += librosa.util.find_files(dataset_dir.as_posix()) # Sorted 27 | 28 | self.noisy_file_path_list = noisy_file_path_list 29 | self.length = len(self.noisy_file_path_list) 30 | 31 | def __len__(self): 32 | return self.length 33 | 34 | def __getitem__(self, item): 35 | noisy_file_path = self.noisy_file_path_list[item] 36 | noisy_y = librosa.load(noisy_file_path, sr=self.sr)[0] 37 | noisy_y = noisy_y.astype(np.float32) 38 | 39 | return noisy_y, basename(noisy_file_path)[0] 40 | -------------------------------------------------------------------------------- /speech_enhance/fullsubnet/dataset/dataset_train.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from audio_zen.acoustics.feature import norm_amplitude, tailor_dB_FS, is_clipped, load_wav, subsample 5 | from audio_zen.dataset.base_dataset import BaseDataset 6 | from audio_zen.utils import expand_path 7 | from joblib import 
Parallel, delayed 8 | from scipy import signal 9 | from tqdm import tqdm 10 | 11 | 12 | class Dataset(BaseDataset): 13 | def __init__(self, 14 | clean_dataset, 15 | clean_dataset_limit, 16 | clean_dataset_offset, 17 | noise_dataset, 18 | noise_dataset_limit, 19 | noise_dataset_offset, 20 | rir_dataset, 21 | rir_dataset_limit, 22 | rir_dataset_offset, 23 | snr_range, 24 | reverb_proportion, 25 | silence_length, 26 | target_dB_FS, 27 | target_dB_FS_floating_value, 28 | sub_sample_length, 29 | sr, 30 | pre_load_clean_dataset, 31 | pre_load_noise, 32 | pre_load_rir, 33 | num_workers 34 | ): 35 | """ 36 | Dynamic mixing for training 37 | 38 | Args: 39 | clean_dataset_limit: 40 | clean_dataset_offset: 41 | noise_dataset_limit: 42 | noise_dataset_offset: 43 | rir_dataset: 44 | rir_dataset_limit: 45 | rir_dataset_offset: 46 | snr_range: 47 | reverb_proportion: 48 | clean_dataset: scp file 49 | noise_dataset: scp file 50 | sub_sample_length: 51 | sr: 52 | """ 53 | super().__init__() 54 | # acoustics args 55 | self.sr = sr 56 | 57 | # parallel args 58 | self.num_workers = num_workers 59 | 60 | clean_dataset_list = [line.rstrip('\n') for line in open(expand_path(clean_dataset), "r")] 61 | noise_dataset_list = [line.rstrip('\n') for line in open(expand_path(noise_dataset), "r")] 62 | rir_dataset_list = [line.rstrip('\n') for line in open(expand_path(rir_dataset), "r")] 63 | 64 | clean_dataset_list = self._offset_and_limit(clean_dataset_list, clean_dataset_offset, clean_dataset_limit) 65 | noise_dataset_list = self._offset_and_limit(noise_dataset_list, noise_dataset_offset, noise_dataset_limit) 66 | rir_dataset_list = self._offset_and_limit(rir_dataset_list, rir_dataset_offset, rir_dataset_limit) 67 | 68 | if pre_load_clean_dataset: 69 | clean_dataset_list = self._preload_dataset(clean_dataset_list, remark="Clean Dataset") 70 | 71 | if pre_load_noise: 72 | noise_dataset_list = self._preload_dataset(noise_dataset_list, remark="Noise Dataset") 73 | 74 | if pre_load_rir: 75 | rir_dataset_list = self._preload_dataset(rir_dataset_list, remark="RIR Dataset") 76 | 77 | self.clean_dataset_list = clean_dataset_list 78 | self.noise_dataset_list = noise_dataset_list 79 | self.rir_dataset_list = rir_dataset_list 80 | 81 | snr_list = self._parse_snr_range(snr_range) 82 | self.snr_list = snr_list 83 | 84 | assert 0 <= reverb_proportion <= 1, "reverberation proportion should be in [0, 1]" 85 | self.reverb_proportion = reverb_proportion 86 | self.silence_length = silence_length 87 | self.target_dB_FS = target_dB_FS 88 | self.target_dB_FS_floating_value = target_dB_FS_floating_value 89 | self.sub_sample_length = sub_sample_length 90 | 91 | self.length = len(self.clean_dataset_list) 92 | 93 | def __len__(self): 94 | return self.length 95 | 96 | def _preload_dataset(self, file_path_list, remark=""): 97 | waveform_list = Parallel(n_jobs=self.num_workers)( 98 | delayed(load_wav)(f_path) for f_path in tqdm(file_path_list, desc=remark) 99 | ) 100 | return list(zip(file_path_list, waveform_list)) 101 | 102 | @staticmethod 103 | def _random_select_from(dataset_list): 104 | return random.choice(dataset_list) 105 | 106 | def _select_noise_y(self, target_length): 107 | noise_y = np.zeros(0, dtype=np.float32) 108 | silence = np.zeros(int(self.sr * self.silence_length), dtype=np.float32) 109 | remaining_length = target_length 110 | 111 | while remaining_length > 0: 112 | noise_file = self._random_select_from(self.noise_dataset_list) 113 | noise_new_added = load_wav(noise_file, sr=self.sr) 114 | noise_y = np.append(noise_y, 
noise_new_added) 115 | remaining_length -= len(noise_new_added) 116 | 117 | # 如果还需要添加新的噪声,就插入一个小静音段 118 | if remaining_length > 0: 119 | silence_len = min(remaining_length, len(silence)) 120 | noise_y = np.append(noise_y, silence[:silence_len]) 121 | remaining_length -= silence_len 122 | 123 | if len(noise_y) > target_length: 124 | idx_start = np.random.randint(len(noise_y) - target_length) 125 | noise_y = noise_y[idx_start:idx_start + target_length] 126 | 127 | return noise_y 128 | 129 | @staticmethod 130 | def snr_mix(clean_y, noise_y, snr, target_dB_FS, target_dB_FS_floating_value, rir=None, eps=1e-6): 131 | """ 132 | 混合噪声与纯净语音,当 rir 参数不为空时,对纯净语音施加混响效果 133 | 134 | Args: 135 | clean_y: 纯净语音 136 | noise_y: 噪声 137 | snr (int): 信噪比 138 | target_dB_FS (int): 139 | target_dB_FS_floating_value (int): 140 | rir: room impulse response, None 或 np.array 141 | eps: eps 142 | 143 | Returns: 144 | (noisy_y,clean_y) 145 | """ 146 | if rir is not None: 147 | if rir.ndim > 1: 148 | rir_idx = np.random.randint(0, rir.shape[0]) 149 | rir = rir[rir_idx, :] 150 | 151 | clean_y = signal.fftconvolve(clean_y, rir)[:len(clean_y)] 152 | 153 | clean_y, _ = norm_amplitude(clean_y) 154 | clean_y, _, _ = tailor_dB_FS(clean_y, target_dB_FS) 155 | clean_rms = (clean_y ** 2).mean() ** 0.5 156 | 157 | noise_y, _ = norm_amplitude(noise_y) 158 | noise_y, _, _ = tailor_dB_FS(noise_y, target_dB_FS) 159 | noise_rms = (noise_y ** 2).mean() ** 0.5 160 | 161 | snr_scalar = clean_rms / (10 ** (snr / 20)) / (noise_rms + eps) 162 | noise_y *= snr_scalar 163 | noisy_y = clean_y + noise_y 164 | 165 | # Randomly select RMS value of dBFS between -15 dBFS and -35 dBFS and normalize noisy speech with that value 166 | noisy_target_dB_FS = np.random.randint( 167 | target_dB_FS - target_dB_FS_floating_value, 168 | target_dB_FS + target_dB_FS_floating_value 169 | ) 170 | 171 | # 使用 noisy 的 rms 放缩音频 172 | noisy_y, _, noisy_scalar = tailor_dB_FS(noisy_y, noisy_target_dB_FS) 173 | clean_y *= noisy_scalar 174 | 175 | # 合成带噪语音的时候可能会 clipping,虽然极少 176 | # 对 noisy, clean_y, noise_y 稍微进行调整 177 | if is_clipped(noisy_y): 178 | noisy_y_scalar = np.max(np.abs(noisy_y)) / (0.99 - eps) # 相当于除以 1 179 | noisy_y = noisy_y / noisy_y_scalar 180 | clean_y = clean_y / noisy_y_scalar 181 | 182 | return noisy_y, clean_y 183 | 184 | def __getitem__(self, item): 185 | clean_file = self.clean_dataset_list[item] 186 | clean_y = load_wav(clean_file, sr=self.sr) 187 | clean_y = subsample(clean_y, sub_sample_length=int(self.sub_sample_length * self.sr)) 188 | 189 | noise_y = self._select_noise_y(target_length=len(clean_y)) 190 | assert len(clean_y) == len(noise_y), f"Inequality: {len(clean_y)} {len(noise_y)}" 191 | 192 | snr = self._random_select_from(self.snr_list) 193 | use_reverb = bool(np.random.random(1) < self.reverb_proportion) 194 | 195 | noisy_y, clean_y = self.snr_mix( 196 | clean_y=clean_y, 197 | noise_y=noise_y, 198 | snr=snr, 199 | target_dB_FS=self.target_dB_FS, 200 | target_dB_FS_floating_value=self.target_dB_FS_floating_value, 201 | rir=load_wav(self._random_select_from(self.rir_dataset_list), sr=self.sr) if use_reverb else None 202 | ) 203 | 204 | noisy_y = noisy_y.astype(np.float32) 205 | clean_y = clean_y.astype(np.float32) 206 | 207 | return noisy_y, clean_y 208 | -------------------------------------------------------------------------------- /speech_enhance/fullsubnet/dataset/dataset_validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import librosa 5 | 6 
| from audio_zen.dataset.base_dataset import BaseDataset 7 | from audio_zen.acoustics.utils import load_wav 8 | from audio_zen.utils import basename 9 | 10 | 11 | class Dataset(BaseDataset): 12 | def __init__( 13 | self, 14 | dataset_dir_list, 15 | sr, 16 | ): 17 | """ 18 | Construct DNS validation set 19 | 20 | synthetic/ 21 | with_reverb/ 22 | noisy/ 23 | clean_y/ 24 | no_reverb/ 25 | noisy/ 26 | clean_y/ 27 | """ 28 | super(Dataset, self).__init__() 29 | noisy_files_list = [] 30 | 31 | for dataset_dir in dataset_dir_list: 32 | dataset_dir = Path(dataset_dir).expanduser().absolute() 33 | noisy_files_list += librosa.util.find_files((dataset_dir / "noisy").as_posix()) 34 | 35 | self.length = len(noisy_files_list) 36 | self.noisy_files_list = noisy_files_list 37 | self.sr = sr 38 | 39 | def __len__(self): 40 | return self.length 41 | 42 | def __getitem__(self, item): 43 | """ 44 | use the absolute path of the noisy speech to find the corresponding clean speech. 45 | 46 | Notes 47 | with_reverb and no_reverb dirs have same-named files. 48 | If we use `basename`, the problem will be raised (cover) in visualization. 49 | 50 | Returns: 51 | noisy: [waveform...], clean: [waveform...], type: [reverb|no_reverb] + name 52 | """ 53 | noisy_file_path = self.noisy_files_list[item] 54 | parent_dir = Path(noisy_file_path).parents[1].name 55 | noisy_filename, _ = basename(noisy_file_path) 56 | 57 | reverb_remark = "" # When the speech comes from reverb_dir, insert "with_reverb" before the filename 58 | 59 | # speech_type 与 validation 部分要一致,用于区分后续的可视化 60 | if parent_dir == "with_reverb": 61 | speech_type = "With_reverb" 62 | elif parent_dir == "no_reverb": 63 | speech_type = "No_reverb" 64 | elif parent_dir == "dns_2_non_english": 65 | speech_type = "Non_english" 66 | elif parent_dir == "dns_2_emotion": 67 | speech_type = "Emotion" 68 | elif parent_dir == "dns_2_singing": 69 | speech_type = "Singing" 70 | else: 71 | raise NotImplementedError(f"Not supported dir: {parent_dir}") 72 | 73 | # Find the corresponding clean speech using "parent_dir" and "file_id" 74 | file_id = noisy_filename.split("_")[-1] 75 | if parent_dir in ("dns_2_emotion", "dns_2_singing"): 76 | # e.g., synthetic_emotion_1792_snr19_tl-35_fileid_19 => synthetic_emotion_clean_fileid_15 77 | clean_filename = f"synthetic_{speech_type.lower()}_clean_fileid_{file_id}" 78 | elif parent_dir == "dns_2_non_english": 79 | # e.g., synthetic_german_collection044_14_-04_CFQQgBvv2xQ_snr8_tl-21_fileid_121 => synthetic_clean_fileid_121 80 | clean_filename = f"synthetic_clean_fileid_{file_id}" 81 | else: 82 | # e.g., clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300 => clean_fileid_300 83 | if parent_dir == "with_reverb": 84 | reverb_remark = "with_reverb" 85 | clean_filename = f"clean_fileid_{file_id}" 86 | 87 | clean_file_path = noisy_file_path.replace(f"noisy/{noisy_filename}", f"clean/{clean_filename}") 88 | 89 | noisy = load_wav(os.path.abspath(os.path.expanduser(noisy_file_path)), sr=self.sr) 90 | clean = load_wav(os.path.abspath(os.path.expanduser(clean_file_path)), sr=self.sr) 91 | 92 | return noisy, clean, reverb_remark + noisy_filename, speech_type 93 | -------------------------------------------------------------------------------- /speech_enhance/fullsubnet/inferencer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet/inferencer/__init__.py 
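The `__getitem__` of the validation Dataset above pairs each noisy utterance with its clean reference purely by rewriting the path string. A minimal sketch of that pairing rule for the plain `with_reverb` / `no_reverb` case (the example path is hypothetical, and only this branch of the mapping is shown):

```python
from pathlib import Path

def clean_path_for(noisy_file_path: str) -> str:
    # Mirror of the pairing rule in Dataset.__getitem__ for with_reverb / no_reverb:
    # ".../noisy/<stem>.wav" -> ".../clean/clean_fileid_<id>.wav", where <id> is the
    # trailing token of the noisy file stem.
    noisy_filename = Path(noisy_file_path).stem
    file_id = noisy_filename.split("_")[-1]
    return noisy_file_path.replace(f"noisy/{noisy_filename}", f"clean/clean_fileid_{file_id}")

# Hypothetical example:
# clean_path_for("synthetic/no_reverb/noisy/clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300.wav")
# -> "synthetic/no_reverb/clean/clean_fileid_300.wav"
```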
-------------------------------------------------------------------------------- /speech_enhance/fullsubnet/inferencer/inferencer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | from audio_zen.acoustics.feature import mag_phase 5 | from audio_zen.acoustics.mask import decompress_cIRM 6 | from audio_zen.inferencer.base_inferencer import BaseInferencer 7 | 8 | # for log 9 | from utils.logger import log 10 | print=log 11 | 12 | def cumulative_norm(input): 13 | eps = 1e-10 14 | device = input.device 15 | data_type = input.dtype 16 | n_dim = input.ndim 17 | 18 | assert n_dim in (3, 4) 19 | 20 | if n_dim == 3: 21 | n_channels = 1 22 | batch_size, n_freqs, n_frames = input.size() 23 | else: 24 | batch_size, n_channels, n_freqs, n_frames = input.size() 25 | input = input.reshape(batch_size * n_channels, n_freqs, n_frames) 26 | 27 | step_sum = torch.sum(input, dim=1) # [B, T] 28 | step_pow_sum = torch.sum(torch.square(input), dim=1) 29 | 30 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T] 31 | cumulative_pow_sum = torch.cumsum(step_pow_sum, dim=-1) # [B, T] 32 | 33 | entry_count = torch.arange(n_freqs, n_freqs * n_frames + 1, n_freqs, dtype=data_type, device=device) 34 | entry_count = entry_count.reshape(1, n_frames) # [1, T] 35 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T] 36 | 37 | cum_mean = cumulative_sum / entry_count # B, T 38 | cum_var = (cumulative_pow_sum - 2 * cum_mean * cumulative_sum) / entry_count + cum_mean.pow(2) # B, T 39 | cum_std = (cum_var + eps).sqrt() # B, T 40 | 41 | cum_mean = cum_mean.reshape(batch_size * n_channels, 1, n_frames) 42 | cum_std = cum_std.reshape(batch_size * n_channels, 1, n_frames) 43 | 44 | x = (input - cum_mean) / cum_std 45 | 46 | if n_dim == 4: 47 | x = x.reshape(batch_size, n_channels, n_freqs, n_frames) 48 | 49 | return x 50 | 51 | 52 | class Inferencer(BaseInferencer): 53 | def __init__(self, config, checkpoint_path, output_dir): 54 | super().__init__(config, checkpoint_path, output_dir) 55 | 56 | @torch.no_grad() 57 | def mag(self, noisy, inference_args): 58 | noisy_complex = self.torch_stft(noisy) 59 | noisy_mag, noisy_phase = mag_phase(noisy_complex) # [B, F, T] => [B, 1, F, T] 60 | 61 | enhanced_mag = self.model(noisy_mag.unsqueeze(1)).squeeze(1) 62 | 63 | enhanced = self.torch_istft((enhanced_mag, noisy_phase), length=noisy.size(-1), use_mag_phase=True) 64 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 65 | 66 | return enhanced 67 | 68 | @torch.no_grad() 69 | def scaled_mask(self, noisy, inference_args): 70 | noisy_complex = self.torch_stft(noisy) 71 | noisy_mag, noisy_phase = mag_phase(noisy_complex) 72 | 73 | # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2] 74 | noisy_mag = noisy_mag.unsqueeze(1) 75 | scaled_mask = self.model(noisy_mag) 76 | scaled_mask = scaled_mask.permute(0, 2, 3, 1) 77 | 78 | enhanced_complex = noisy_complex * scaled_mask 79 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1), use_mag_phase=False) 80 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 81 | 82 | return enhanced 83 | 84 | @torch.no_grad() 85 | def sub_band_crm_mask(self, noisy, inference_args): 86 | pad_mode = inference_args["pad_mode"] 87 | n_neighbor = inference_args["n_neighbor"] 88 | 89 | noisy = noisy.cpu().numpy().reshape(-1) 90 | noisy_D = self.librosa_stft(noisy) 91 | 92 | noisy_real = torch.tensor(noisy_D.real, device=self.device) 93 | noisy_imag = torch.tensor(noisy_D.imag, 
device=self.device) 94 | noisy_mag = torch.sqrt(torch.square(noisy_real) + torch.square(noisy_imag)) # [F, T] 95 | n_freqs, n_frames = noisy_mag.size() 96 | 97 | noisy_mag = noisy_mag.reshape(1, 1, n_freqs, n_frames) 98 | noisy_mag_padded = self._unfold(noisy_mag, pad_mode, n_neighbor) # [B, N, C, F_s, T] <=> [1, 257, 1, 31, T] 99 | noisy_mag_padded = noisy_mag_padded.squeeze(0).squeeze(1) # [257, 31, 200] <=> [B, F_s, T] 100 | 101 | pred_crm = self.model(noisy_mag_padded).detach() # [B, 2, T] <=> [F, 2, T] 102 | pred_crm = pred_crm.permute(0, 2, 1).contiguous() # [B, T, 2] 103 | 104 | lim = 9.99 105 | pred_crm = lim * (pred_crm >= lim) - lim * (pred_crm <= -lim) + pred_crm * (torch.abs(pred_crm) < lim) 106 | pred_crm = -10 * torch.log((10 - pred_crm) / (10 + pred_crm)) 107 | 108 | enhanced_real = pred_crm[:, :, 0] * noisy_real - pred_crm[:, :, 1] * noisy_imag 109 | enhanced_imag = pred_crm[:, :, 1] * noisy_real + pred_crm[:, :, 0] * noisy_imag 110 | 111 | enhanced_real = enhanced_real.cpu().numpy() 112 | enhanced_imag = enhanced_imag.cpu().numpy() 113 | enhanced = self.librosa_istft(enhanced_real + 1j * enhanced_imag, length=len(noisy)) 114 | return enhanced 115 | 116 | @torch.no_grad() 117 | def full_band_crm_mask(self, noisy, inference_args): 118 | noisy_complex = self.torch_stft(noisy) 119 | noisy_mag, _ = mag_phase(noisy_complex) 120 | 121 | noisy_mag = noisy_mag.unsqueeze(1) 122 | t1 = time.time() 123 | pred_crm = self.model(noisy_mag) 124 | t2 = time.time() 125 | pred_crm = pred_crm.permute(0, 2, 3, 1) 126 | 127 | pred_crm = decompress_cIRM(pred_crm) 128 | enhanced_real = pred_crm[..., 0] * noisy_complex.real - pred_crm[..., 1] * noisy_complex.imag 129 | enhanced_imag = pred_crm[..., 1] * noisy_complex.real + pred_crm[..., 0] * noisy_complex.imag 130 | enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1) 131 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1)) 132 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 133 | 134 | # 135 | rtf = (t2 - t1) / (len(enhanced) * 1.0 / self.acoustic_config["sr"]) 136 | print(f"model rtf: {rtf}") 137 | 138 | return enhanced 139 | 140 | @torch.no_grad() 141 | def overlapped_chunk(self, noisy, inference_args): 142 | sr = self.acoustic_config["sr"] 143 | 144 | noisy = noisy.squeeze(0) 145 | 146 | num_mics = 8 147 | chunk_length = sr * inference_args["chunk_length"] 148 | chunk_hop_length = chunk_length // 2 149 | num_chunks = int(noisy.shape[-1] / chunk_hop_length) + 1 150 | 151 | win = torch.hann_window(chunk_length, device=noisy.device) 152 | 153 | prev = None 154 | enhanced = None 155 | # 模拟语音的静音段,防止一上来就给语音,处理的不好 156 | for chunk_idx in range(num_chunks): 157 | if chunk_idx == 0: 158 | pad = torch.zeros((num_mics, 256), device=noisy.device) 159 | 160 | chunk_start_position = chunk_idx * chunk_hop_length 161 | chunk_end_position = chunk_start_position + chunk_length 162 | 163 | # concat([(8, 256), (..., ... 
+ chunk_length)]) 164 | noisy_chunk = torch.cat((pad, noisy[:, chunk_start_position:chunk_end_position]), dim=1) 165 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0)) 166 | enhanced_chunk = torch.squeeze(enhanced_chunk) 167 | enhanced_chunk = enhanced_chunk[256:] 168 | 169 | # Save the prior half chunk, 170 | cur = enhanced_chunk[:chunk_length // 2] 171 | 172 | # only for the 1st chunk,no overlap for the very 1st chunk prior half 173 | prev = enhanced_chunk[chunk_length // 2:] * win[chunk_length // 2:] 174 | else: 175 | # use the previous noisy data as the pad 176 | pad = noisy[:, (chunk_idx * chunk_hop_length - 256):(chunk_idx * chunk_hop_length)] 177 | 178 | chunk_start_position = chunk_idx * chunk_hop_length 179 | chunk_end_position = chunk_start_position + chunk_length 180 | 181 | noisy_chunk = torch.cat((pad, noisy[:8, chunk_start_position:chunk_end_position]), dim=1) 182 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0)) 183 | enhanced_chunk = torch.squeeze(enhanced_chunk) 184 | enhanced_chunk = enhanced_chunk[256:] 185 | 186 | # 使用这个窗函数来对拼接的位置进行平滑? 187 | enhanced_chunk = enhanced_chunk * win[:len(enhanced_chunk)] 188 | 189 | tmp = enhanced_chunk[:chunk_length // 2] 190 | cur = tmp[:min(len(tmp), len(prev))] + prev[:min(len(tmp), len(prev))] 191 | prev = enhanced_chunk[chunk_length // 2:] 192 | 193 | if enhanced is None: 194 | enhanced = cur 195 | else: 196 | enhanced = torch.cat((enhanced, cur), dim=0) 197 | 198 | enhanced = enhanced[:noisy.shape[1]] 199 | return enhanced.detach().squeeze(0).cpu().numpy() 200 | 201 | @torch.no_grad() 202 | def time_domain(self, noisy, inference_args): 203 | noisy = noisy.to(self.device) 204 | enhanced = self.model(noisy) 205 | return enhanced.detach().squeeze().cpu().numpy() 206 | 207 | 208 | if __name__ == '__main__': 209 | a = torch.rand(10, 2, 161, 200) 210 | print(cumulative_norm(a).shape) 211 | -------------------------------------------------------------------------------- /speech_enhance/fullsubnet/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet/model/__init__.py -------------------------------------------------------------------------------- /speech_enhance/fullsubnet/model/fullsubnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional 3 | 4 | from audio_zen.acoustics.feature import drop_band 5 | from audio_zen.model.base_model import BaseModel 6 | from audio_zen.model.module.sequence_model import SequenceModel 7 | 8 | # for log 9 | from utils.logger import log 10 | print=log 11 | 12 | class Model(BaseModel): 13 | def __init__(self, 14 | num_freqs, 15 | look_ahead, 16 | sequence_model, 17 | fb_num_neighbors, 18 | sb_num_neighbors, 19 | fb_output_activate_function, 20 | sb_output_activate_function, 21 | fb_model_hidden_size, 22 | sb_model_hidden_size, 23 | norm_type="offline_laplace_norm", 24 | num_groups_in_drop_band=2, 25 | weight_init=True, 26 | ): 27 | """ 28 | FullSubNet model (cIRM mask) 29 | 30 | Args: 31 | num_freqs: Frequency dim of the input 32 | sb_num_neighbors: Number of the neighbor frequencies in each side 33 | look_ahead: Number of use of the future frames 34 | sequence_model: Chose one sequence model as the basic model (GRU, LSTM) 35 | """ 36 | super().__init__() 37 | assert sequence_model in ("GRU", "LSTM"), f"{self.__class__.__name__} only support GRU 
and LSTM." 38 | 39 | self.fb_model = SequenceModel( 40 | input_size=num_freqs, 41 | output_size=num_freqs, 42 | hidden_size=fb_model_hidden_size, 43 | num_layers=2, 44 | bidirectional=False, 45 | sequence_model=sequence_model, 46 | output_activate_function=fb_output_activate_function 47 | ) 48 | 49 | self.sb_model = SequenceModel( 50 | input_size=(sb_num_neighbors * 2 + 1) + (fb_num_neighbors * 2 + 1), 51 | output_size=2, 52 | hidden_size=sb_model_hidden_size, 53 | num_layers=2, 54 | bidirectional=False, 55 | sequence_model=sequence_model, 56 | output_activate_function=sb_output_activate_function 57 | ) 58 | 59 | self.sb_num_neighbors = sb_num_neighbors 60 | self.fb_num_neighbors = fb_num_neighbors 61 | self.look_ahead = look_ahead 62 | self.norm = self.norm_wrapper(norm_type) 63 | self.num_groups_in_drop_band = num_groups_in_drop_band 64 | 65 | if weight_init: 66 | self.apply(self.weight_init) 67 | 68 | def forward(self, noisy_mag): 69 | """ 70 | Args: 71 | noisy_mag: noisy magnitude spectrogram 72 | 73 | Returns: 74 | The real part and imag part of the enhanced spectrogram 75 | 76 | Shapes: 77 | noisy_mag: [B, 1, F, T] 78 | return: [B, 2, F, T] 79 | """ 80 | assert noisy_mag.dim() == 4 81 | noisy_mag = functional.pad(noisy_mag, [0, self.look_ahead]) # Pad the look ahead 82 | batch_size, num_channels, num_freqs, num_frames = noisy_mag.size() 83 | assert num_channels == 1, f"{self.__class__.__name__} takes the mag feature as inputs." 84 | 85 | # Fullband model 86 | fb_input = self.norm(noisy_mag).reshape(batch_size, num_channels * num_freqs, num_frames) 87 | fb_output = self.fb_model(fb_input).reshape(batch_size, 1, num_freqs, num_frames) 88 | 89 | # Unfold the output of the fullband model, [B, N=F, C, F_f, T] 90 | fb_output_unfolded = self.unfold(fb_output, num_neighbor=self.fb_num_neighbors) 91 | fb_output_unfolded = fb_output_unfolded.reshape(batch_size, num_freqs, self.fb_num_neighbors * 2 + 1, num_frames) 92 | 93 | # Unfold noisy input, [B, N=F, C, F_s, T] 94 | noisy_mag_unfolded = self.unfold(noisy_mag, num_neighbor=self.sb_num_neighbors) 95 | noisy_mag_unfolded = noisy_mag_unfolded.reshape(batch_size, num_freqs, self.sb_num_neighbors * 2 + 1, num_frames) 96 | 97 | # Concatenation, [B, F, (F_s + F_f), T] 98 | sb_input = torch.cat([noisy_mag_unfolded, fb_output_unfolded], dim=2) 99 | sb_input = self.norm(sb_input) 100 | 101 | # Speeding up training without significant performance degradation. These will be updated to the paper later. 
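        # Shape walk-through of the band-dropping step below (a sketch; it assumes drop_band keeps a
        # strided subset of frequencies so that roughly F // num_groups bands remain, with
        # num_groups = num_groups_in_drop_band):
        #   sb_input                                [B, F, (F_s + F_f), T]
        #   .permute(0, 2, 1, 3)                 -> [B, (F_s + F_f), F, T]
        #   drop_band(..., num_groups)           -> [B, (F_s + F_f), F // num_groups, T]
        #   .permute(0, 2, 1, 3)                 -> [B, F // num_groups, (F_s + F_f), T]
        # The same frequencies must then be dropped from the cIRM target, which is why the trainer
        # also applies drop_band to the ground-truth mask during training.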
102 | if batch_size > 1: 103 | sb_input = drop_band(sb_input.permute(0, 2, 1, 3), num_groups=self.num_groups_in_drop_band) # [B, (F_s + F_f), F//num_groups, T] 104 | num_freqs = sb_input.shape[2] 105 | sb_input = sb_input.permute(0, 2, 1, 3) # [B, F//num_groups, (F_s + F_f), T] 106 | 107 | sb_input = sb_input.reshape( 108 | batch_size * num_freqs, 109 | (self.sb_num_neighbors * 2 + 1) + (self.fb_num_neighbors * 2 + 1), 110 | num_frames 111 | ) 112 | 113 | # [B * F, (F_s + F_f), T] => [B * F, 2, T] => [B, F, 2, T] 114 | sb_mask = self.sb_model(sb_input) 115 | sb_mask = sb_mask.reshape(batch_size, num_freqs, 2, num_frames).permute(0, 2, 1, 3).contiguous() 116 | 117 | output = sb_mask[:, :, :, self.look_ahead:] 118 | return output 119 | 120 | 121 | if __name__ == "__main__": 122 | import datetime 123 | 124 | with torch.no_grad(): 125 | model = Model( 126 | sb_num_neighbors=15, 127 | fb_num_neighbors=0, 128 | num_freqs=257, 129 | look_ahead=2, 130 | sequence_model="LSTM", 131 | fb_output_activate_function="ReLU", 132 | sb_output_activate_function=None, 133 | fb_model_hidden_size=512, 134 | sb_model_hidden_size=384, 135 | weight_init=False, 136 | norm_type="offline_laplace_norm", 137 | num_groups_in_drop_band=2, 138 | ) 139 | # ipt = torch.rand(3, 800) # 1.6s 140 | # ipt_len = ipt.shape[-1] 141 | # # 1000 frames (16s) - 5.65s (35.31%,纯模型) - 5.78s 142 | # # 500 frames (8s) - 3.05s (38.12%,纯模型) - 3.04s 143 | # # 200 frames (3.2s) - 1.19s (37.19%,纯模型) - 1.20s 144 | # # 100 frames (1.6s) - 0.62s (38.75%,纯模型) - 0.65s 145 | # start = datetime.datetime.now() 146 | # 147 | # complex_tensor = torch.stft(ipt, n_fft=512, hop_length=256) 148 | # mag = (complex_tensor.pow(2.).sum(-1) + 1e-8).pow(0.5 * 1.0).unsqueeze(1) 149 | # print(f"STFT: {datetime.datetime.now() - start}, {mag.shape}") 150 | # 151 | # enhanced_complex_tensor = model(mag).detach().permute(0, 2, 3, 1) 152 | # print(enhanced_complex_tensor.shape) 153 | # print(f"Model Inference: {datetime.datetime.now() - start}") 154 | # 155 | # enhanced = torch.istft(enhanced_complex_tensor, 512, 256, length=ipt_len) 156 | # print(f"iSTFT: {datetime.datetime.now() - start}") 157 | # 158 | # print(f"{datetime.datetime.now() - start}") 159 | ipt = torch.rand(3, 1, 257, 200) 160 | print(model(ipt).shape) 161 | -------------------------------------------------------------------------------- /speech_enhance/fullsubnet/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet/trainer/__init__.py -------------------------------------------------------------------------------- /speech_enhance/fullsubnet/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | from torch.cuda.amp import autocast 4 | from tqdm import tqdm 5 | 6 | from audio_zen.acoustics.feature import mag_phase, drop_band 7 | from audio_zen.acoustics.mask import build_complex_ideal_ratio_mask, decompress_cIRM 8 | from audio_zen.trainer.base_trainer import BaseTrainer 9 | from utils.logger import log 10 | 11 | plt.switch_backend('agg') 12 | 13 | 14 | class Trainer(BaseTrainer): 15 | def __init__(self, dist, rank, config, resume, only_validation, model, loss_function, optimizer, train_dataloader, validation_dataloader): 16 | super().__init__(dist, rank, config, resume, only_validation, model, loss_function, optimizer) 17 | 
self.train_dataloader = train_dataloader 18 | self.valid_dataloader = validation_dataloader 19 | 20 | def _train_epoch(self, epoch): 21 | 22 | loss_total = 0.0 23 | progress_bar = None 24 | 25 | if self.rank == 0: 26 | progress_bar = tqdm(total=len(self.train_dataloader), desc=f"Training") 27 | 28 | for noisy, clean in self.train_dataloader: 29 | self.optimizer.zero_grad() 30 | 31 | noisy = noisy.to(self.rank) 32 | clean = clean.to(self.rank) 33 | 34 | noisy_complex = self.torch_stft(noisy) 35 | clean_complex = self.torch_stft(clean) 36 | 37 | noisy_mag, _ = mag_phase(noisy_complex) 38 | ground_truth_cIRM = build_complex_ideal_ratio_mask(noisy_complex, clean_complex) # [B, F, T, 2] 39 | ground_truth_cIRM = drop_band( 40 | ground_truth_cIRM.permute(0, 3, 1, 2), # [B, 2, F ,T] 41 | self.model.module.num_groups_in_drop_band 42 | ).permute(0, 2, 3, 1) 43 | 44 | with autocast(enabled=self.use_amp): 45 | # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2] 46 | noisy_mag = noisy_mag.unsqueeze(1) 47 | cRM = self.model(noisy_mag) 48 | cRM = cRM.permute(0, 2, 3, 1) 49 | loss = self.loss_function(ground_truth_cIRM, cRM) 50 | 51 | self.scaler.scale(loss).backward() 52 | self.scaler.unscale_(self.optimizer) 53 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm_value) 54 | self.scaler.step(self.optimizer) 55 | self.scaler.update() 56 | 57 | loss_total += loss.item() 58 | 59 | if self.rank == 0: 60 | progress_bar.update(1) 61 | progress_bar.refresh() 62 | 63 | if self.rank == 0: 64 | log(f"[Train] Epoch {epoch}, Loss {loss_total / len(self.train_dataloader)}") 65 | self.writer.add_scalar(f"Loss/Train", loss_total / len(self.train_dataloader), epoch) 66 | 67 | @torch.no_grad() 68 | def _validation_epoch(self, epoch): 69 | progress_bar = None 70 | if self.rank == 0: 71 | progress_bar = tqdm(total=len(self.valid_dataloader), desc=f"Validation") 72 | 73 | visualization_n_samples = self.visualization_config["n_samples"] 74 | visualization_num_workers = self.visualization_config["num_workers"] 75 | visualization_metrics = self.visualization_config["metrics"] 76 | 77 | loss_total = 0.0 78 | loss_list = {"With_reverb": 0.0, "No_reverb": 0.0, } 79 | item_idx_list = {"With_reverb": 0, "No_reverb": 0, } 80 | noisy_y_list = {"With_reverb": [], "No_reverb": [], } 81 | clean_y_list = {"With_reverb": [], "No_reverb": [], } 82 | enhanced_y_list = {"With_reverb": [], "No_reverb": [], } 83 | validation_score_list = {"With_reverb": 0.0, "No_reverb": 0.0} 84 | 85 | # speech_type in ("with_reverb", "no_reverb") 86 | for i, (noisy, clean, name, speech_type) in enumerate(self.valid_dataloader): 87 | assert len(name) == 1, "The batch size for the validation stage must be one." 
88 | name = name[0] 89 | speech_type = speech_type[0] 90 | 91 | noisy = noisy.to(self.rank) 92 | clean = clean.to(self.rank) 93 | 94 | noisy_complex = self.torch_stft(noisy) 95 | clean_complex = self.torch_stft(clean) 96 | 97 | noisy_mag, _ = mag_phase(noisy_complex) 98 | cIRM = build_complex_ideal_ratio_mask(noisy_complex, clean_complex) # [B, F, T, 2] 99 | 100 | noisy_mag = noisy_mag.unsqueeze(1) 101 | cRM = self.model(noisy_mag) 102 | cRM = cRM.permute(0, 2, 3, 1) 103 | 104 | loss = self.loss_function(cIRM, cRM) 105 | 106 | cRM = decompress_cIRM(cRM) 107 | 108 | enhanced_real = cRM[..., 0] * noisy_complex.real - cRM[..., 1] * noisy_complex.imag 109 | enhanced_imag = cRM[..., 1] * noisy_complex.real + cRM[..., 0] * noisy_complex.imag 110 | enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1) 111 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1)) 112 | 113 | noisy = noisy.detach().squeeze(0).cpu().numpy() 114 | clean = clean.detach().squeeze(0).cpu().numpy() 115 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 116 | 117 | assert len(noisy) == len(clean) == len(enhanced) 118 | loss_total += loss 119 | 120 | # Separated loss 121 | loss_list[speech_type] += loss 122 | item_idx_list[speech_type] += 1 123 | 124 | if item_idx_list[speech_type] <= visualization_n_samples: 125 | self.spec_audio_visualization(noisy, enhanced, clean, name, epoch, mark=speech_type) 126 | 127 | noisy_y_list[speech_type].append(noisy) 128 | clean_y_list[speech_type].append(clean) 129 | enhanced_y_list[speech_type].append(enhanced) 130 | 131 | if self.rank == 0: 132 | progress_bar.update(1) 133 | 134 | log(f"[Test] Epoch {epoch}, Loss {loss_total / len(self.valid_dataloader)}") 135 | self.writer.add_scalar(f"Loss/Validation_Total", loss_total / len(self.valid_dataloader), epoch) 136 | 137 | for speech_type in ("With_reverb", "No_reverb"): 138 | log(f"[Test] Epoch {epoch}, {speech_type}, Loss {loss_list[speech_type] / len(self.valid_dataloader)}") 139 | self.writer.add_scalar(f"Loss/{speech_type}", loss_list[speech_type] / len(self.valid_dataloader), epoch) 140 | 141 | validation_score_list[speech_type] = self.metrics_visualization( 142 | noisy_y_list[speech_type], clean_y_list[speech_type], enhanced_y_list[speech_type], 143 | visualization_metrics, epoch, visualization_num_workers, mark=speech_type 144 | ) 145 | 146 | return validation_score_list["No_reverb"] 147 | -------------------------------------------------------------------------------- /speech_enhance/fullsubnet_plus/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet_plus/__init__.py -------------------------------------------------------------------------------- /speech_enhance/fullsubnet_plus/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet_plus/dataset/__init__.py -------------------------------------------------------------------------------- /speech_enhance/fullsubnet_plus/dataset/dataset_inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import librosa 4 | import numpy as np 5 | 6 | from audio_zen.dataset.base_dataset import BaseDataset 7 | from audio_zen.utils import basename 8 | 9 | 
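# A minimal usage sketch (the directory path and sampling rate below are illustrative): the dataset
# walks the given directories with librosa.util.find_files and yields (waveform, file stem) pairs,
# so it can be consumed directly, or wrapped in a DataLoader with batch_size=1, for inference.
#
#   dataset = Dataset(dataset_dir_list=["/path/to/noisy_wavs"], sr=16000)
#   noisy_y, stem = dataset[0]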
10 | class Dataset(BaseDataset): 11 | def __init__(self, 12 | dataset_dir_list, 13 | sr, 14 | ): 15 | """ 16 | Args: 17 | noisy_dataset_dir_list (str or list): noisy dir or noisy dir list 18 | """ 19 | super().__init__() 20 | assert isinstance(dataset_dir_list, list) 21 | self.sr = sr 22 | 23 | noisy_file_path_list = [] 24 | for dataset_dir in dataset_dir_list: 25 | dataset_dir = Path(dataset_dir).expanduser().absolute() 26 | noisy_file_path_list += librosa.util.find_files(dataset_dir.as_posix()) # Sorted 27 | 28 | self.noisy_file_path_list = noisy_file_path_list 29 | self.length = len(self.noisy_file_path_list) 30 | 31 | def __len__(self): 32 | return self.length 33 | 34 | def __getitem__(self, item): 35 | noisy_file_path = self.noisy_file_path_list[item] 36 | noisy_y = librosa.load(noisy_file_path, sr=self.sr)[0] 37 | noisy_y = noisy_y.astype(np.float32) 38 | 39 | return noisy_y, basename(noisy_file_path)[0] 40 | -------------------------------------------------------------------------------- /speech_enhance/fullsubnet_plus/dataset/dataset_train.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from audio_zen.acoustics.feature import norm_amplitude, tailor_dB_FS, is_clipped, load_wav, subsample 5 | from audio_zen.dataset.base_dataset import BaseDataset 6 | from audio_zen.utils import expand_path 7 | from joblib import Parallel, delayed 8 | from scipy import signal 9 | from tqdm import tqdm 10 | 11 | 12 | class Dataset(BaseDataset): 13 | def __init__(self, 14 | clean_dataset, 15 | clean_dataset_limit, 16 | clean_dataset_offset, 17 | noise_dataset, 18 | noise_dataset_limit, 19 | noise_dataset_offset, 20 | rir_dataset, 21 | rir_dataset_limit, 22 | rir_dataset_offset, 23 | snr_range, 24 | reverb_proportion, 25 | silence_length, 26 | target_dB_FS, 27 | target_dB_FS_floating_value, 28 | sub_sample_length, 29 | sr, 30 | pre_load_clean_dataset, 31 | pre_load_noise, 32 | pre_load_rir, 33 | num_workers 34 | ): 35 | """ 36 | Dynamic mixing for training 37 | 38 | Args: 39 | clean_dataset_limit: 40 | clean_dataset_offset: 41 | noise_dataset_limit: 42 | noise_dataset_offset: 43 | rir_dataset: 44 | rir_dataset_limit: 45 | rir_dataset_offset: 46 | snr_range: 47 | reverb_proportion: 48 | clean_dataset: scp file 49 | noise_dataset: scp file 50 | sub_sample_length: 51 | sr: 52 | """ 53 | super().__init__() 54 | # acoustics args 55 | self.sr = sr 56 | 57 | # parallel args 58 | self.num_workers = num_workers 59 | 60 | clean_dataset_list = [line.rstrip('\n') for line in open(expand_path(clean_dataset), "r")] 61 | noise_dataset_list = [line.rstrip('\n') for line in open(expand_path(noise_dataset), "r")] 62 | rir_dataset_list = [line.rstrip('\n') for line in open(expand_path(rir_dataset), "r")] 63 | 64 | clean_dataset_list = self._offset_and_limit(clean_dataset_list, clean_dataset_offset, clean_dataset_limit) 65 | noise_dataset_list = self._offset_and_limit(noise_dataset_list, noise_dataset_offset, noise_dataset_limit) 66 | rir_dataset_list = self._offset_and_limit(rir_dataset_list, rir_dataset_offset, rir_dataset_limit) 67 | 68 | if pre_load_clean_dataset: 69 | clean_dataset_list = self._preload_dataset(clean_dataset_list, remark="Clean Dataset") 70 | 71 | if pre_load_noise: 72 | noise_dataset_list = self._preload_dataset(noise_dataset_list, remark="Noise Dataset") 73 | 74 | if pre_load_rir: 75 | rir_dataset_list = self._preload_dataset(rir_dataset_list, remark="RIR Dataset") 76 | 77 | self.clean_dataset_list = 
clean_dataset_list 78 | self.noise_dataset_list = noise_dataset_list 79 | self.rir_dataset_list = rir_dataset_list 80 | 81 | snr_list = self._parse_snr_range(snr_range) 82 | self.snr_list = snr_list 83 | 84 | assert 0 <= reverb_proportion <= 1, "reverberation proportion should be in [0, 1]" 85 | self.reverb_proportion = reverb_proportion 86 | self.silence_length = silence_length 87 | self.target_dB_FS = target_dB_FS 88 | self.target_dB_FS_floating_value = target_dB_FS_floating_value 89 | self.sub_sample_length = sub_sample_length 90 | 91 | self.length = len(self.clean_dataset_list) 92 | 93 | def __len__(self): 94 | return self.length 95 | 96 | def _preload_dataset(self, file_path_list, remark=""): 97 | waveform_list = Parallel(n_jobs=self.num_workers)( 98 | delayed(load_wav)(f_path) for f_path in tqdm(file_path_list, desc=remark) 99 | ) 100 | return list(zip(file_path_list, waveform_list)) 101 | 102 | @staticmethod 103 | def _random_select_from(dataset_list): 104 | return random.choice(dataset_list) 105 | 106 | def _select_noise_y(self, target_length): 107 | noise_y = np.zeros(0, dtype=np.float32) 108 | silence = np.zeros(int(self.sr * self.silence_length), dtype=np.float32) 109 | remaining_length = target_length 110 | 111 | while remaining_length > 0: 112 | noise_file = self._random_select_from(self.noise_dataset_list) 113 | noise_new_added = load_wav(noise_file, sr=self.sr) 114 | noise_y = np.append(noise_y, noise_new_added) 115 | remaining_length -= len(noise_new_added) 116 | 117 | # 如果还需要添加新的噪声,就插入一个小静音段 118 | if remaining_length > 0: 119 | silence_len = min(remaining_length, len(silence)) 120 | noise_y = np.append(noise_y, silence[:silence_len]) 121 | remaining_length -= silence_len 122 | 123 | if len(noise_y) > target_length: 124 | idx_start = np.random.randint(len(noise_y) - target_length) 125 | noise_y = noise_y[idx_start:idx_start + target_length] 126 | 127 | return noise_y 128 | 129 | @staticmethod 130 | def snr_mix(clean_y, noise_y, snr, target_dB_FS, target_dB_FS_floating_value, rir=None, eps=1e-6): 131 | """ 132 | 混合噪声与纯净语音,当 rir 参数不为空时,对纯净语音施加混响效果 133 | 134 | Args: 135 | clean_y: 纯净语音 136 | noise_y: 噪声 137 | snr (int): 信噪比 138 | target_dB_FS (int): 139 | target_dB_FS_floating_value (int): 140 | rir: room impulse response, None 或 np.array 141 | eps: eps 142 | 143 | Returns: 144 | (noisy_y,clean_y) 145 | """ 146 | if rir is not None: 147 | if rir.ndim > 1: 148 | rir_idx = np.random.randint(0, rir.shape[0]) 149 | rir = rir[rir_idx, :] 150 | 151 | clean_y = signal.fftconvolve(clean_y, rir)[:len(clean_y)] 152 | 153 | clean_y, _ = norm_amplitude(clean_y) 154 | clean_y, _, _ = tailor_dB_FS(clean_y, target_dB_FS) 155 | clean_rms = (clean_y ** 2).mean() ** 0.5 156 | 157 | noise_y, _ = norm_amplitude(noise_y) 158 | noise_y, _, _ = tailor_dB_FS(noise_y, target_dB_FS) 159 | noise_rms = (noise_y ** 2).mean() ** 0.5 160 | 161 | snr_scalar = clean_rms / (10 ** (snr / 20)) / (noise_rms + eps) 162 | noise_y *= snr_scalar 163 | noisy_y = clean_y + noise_y 164 | 165 | # Randomly select RMS value of dBFS between -15 dBFS and -35 dBFS and normalize noisy speech with that value 166 | noisy_target_dB_FS = np.random.randint( 167 | target_dB_FS - target_dB_FS_floating_value, 168 | target_dB_FS + target_dB_FS_floating_value 169 | ) 170 | 171 | # 使用 noisy 的 rms 放缩音频 172 | noisy_y, _, noisy_scalar = tailor_dB_FS(noisy_y, noisy_target_dB_FS) 173 | clean_y *= noisy_scalar 174 | 175 | # 合成带噪语音的时候可能会 clipping,虽然极少 176 | # 对 noisy, clean_y, noise_y 稍微进行调整 177 | if is_clipped(noisy_y): 178 | noisy_y_scalar = 
np.max(np.abs(noisy_y)) / (0.99 - eps) # 相当于除以 1 179 | noisy_y = noisy_y / noisy_y_scalar 180 | clean_y = clean_y / noisy_y_scalar 181 | 182 | return noisy_y, clean_y 183 | 184 | def __getitem__(self, item): 185 | clean_file = self.clean_dataset_list[item] 186 | clean_y = load_wav(clean_file, sr=self.sr) 187 | clean_y = subsample(clean_y, sub_sample_length=int(self.sub_sample_length * self.sr)) 188 | 189 | noise_y = self._select_noise_y(target_length=len(clean_y)) 190 | assert len(clean_y) == len(noise_y), f"Inequality: {len(clean_y)} {len(noise_y)}" 191 | 192 | snr = self._random_select_from(self.snr_list) 193 | use_reverb = bool(np.random.random(1) < self.reverb_proportion) 194 | 195 | noisy_y, clean_y = self.snr_mix( 196 | clean_y=clean_y, 197 | noise_y=noise_y, 198 | snr=snr, 199 | target_dB_FS=self.target_dB_FS, 200 | target_dB_FS_floating_value=self.target_dB_FS_floating_value, 201 | rir=load_wav(self._random_select_from(self.rir_dataset_list), sr=self.sr) if use_reverb else None 202 | ) 203 | 204 | noisy_y = noisy_y.astype(np.float32) 205 | clean_y = clean_y.astype(np.float32) 206 | 207 | return noisy_y, clean_y 208 | -------------------------------------------------------------------------------- /speech_enhance/fullsubnet_plus/dataset/dataset_validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import librosa 5 | 6 | from audio_zen.dataset.base_dataset import BaseDataset 7 | from audio_zen.acoustics.utils import load_wav 8 | from audio_zen.utils import basename 9 | 10 | 11 | class Dataset(BaseDataset): 12 | def __init__( 13 | self, 14 | dataset_dir_list, 15 | sr, 16 | ): 17 | """ 18 | Construct DNS validation set 19 | 20 | synthetic/ 21 | with_reverb/ 22 | noisy/ 23 | clean_y/ 24 | no_reverb/ 25 | noisy/ 26 | clean_y/ 27 | """ 28 | super(Dataset, self).__init__() 29 | noisy_files_list = [] 30 | 31 | for dataset_dir in dataset_dir_list: 32 | dataset_dir = Path(dataset_dir).expanduser().absolute() 33 | noisy_files_list += librosa.util.find_files((dataset_dir / "noisy").as_posix()) 34 | 35 | self.length = len(noisy_files_list) 36 | self.noisy_files_list = noisy_files_list 37 | self.sr = sr 38 | 39 | def __len__(self): 40 | return self.length 41 | 42 | def __getitem__(self, item): 43 | """ 44 | use the absolute path of the noisy speech to find the corresponding clean speech. 45 | 46 | Notes 47 | with_reverb and no_reverb dirs have same-named files. 48 | If we use `basename`, the problem will be raised (cover) in visualization. 
49 | 50 | Returns: 51 | noisy: [waveform...], clean: [waveform...], type: [reverb|no_reverb] + name 52 | """ 53 | noisy_file_path = self.noisy_files_list[item] 54 | parent_dir = Path(noisy_file_path).parents[1].name 55 | noisy_filename, _ = basename(noisy_file_path) 56 | 57 | reverb_remark = "" # When the speech comes from reverb_dir, insert "with_reverb" before the filename 58 | 59 | # speech_type 与 validation 部分要一致,用于区分后续的可视化 60 | if parent_dir == "with_reverb": 61 | speech_type = "With_reverb" 62 | elif parent_dir == "no_reverb": 63 | speech_type = "No_reverb" 64 | elif parent_dir == "dns_2_non_english": 65 | speech_type = "Non_english" 66 | elif parent_dir == "dns_2_emotion": 67 | speech_type = "Emotion" 68 | elif parent_dir == "dns_2_singing": 69 | speech_type = "Singing" 70 | else: 71 | raise NotImplementedError(f"Not supported dir: {parent_dir}") 72 | 73 | # Find the corresponding clean speech using "parent_dir" and "file_id" 74 | file_id = noisy_filename.split("_")[-1] 75 | if parent_dir in ("dns_2_emotion", "dns_2_singing"): 76 | # e.g., synthetic_emotion_1792_snr19_tl-35_fileid_19 => synthetic_emotion_clean_fileid_15 77 | clean_filename = f"synthetic_{speech_type.lower()}_clean_fileid_{file_id}" 78 | elif parent_dir == "dns_2_non_english": 79 | # e.g., synthetic_german_collection044_14_-04_CFQQgBvv2xQ_snr8_tl-21_fileid_121 => synthetic_clean_fileid_121 80 | clean_filename = f"synthetic_clean_fileid_{file_id}" 81 | else: 82 | # e.g., clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300 => clean_fileid_300 83 | if parent_dir == "with_reverb": 84 | reverb_remark = "with_reverb" 85 | clean_filename = f"clean_fileid_{file_id}" 86 | 87 | clean_file_path = noisy_file_path.replace(f"noisy/{noisy_filename}", f"clean/{clean_filename}") 88 | 89 | noisy = load_wav(os.path.abspath(os.path.expanduser(noisy_file_path)), sr=self.sr) 90 | clean = load_wav(os.path.abspath(os.path.expanduser(clean_file_path)), sr=self.sr) 91 | 92 | return noisy, clean, reverb_remark + noisy_filename, speech_type 93 | -------------------------------------------------------------------------------- /speech_enhance/fullsubnet_plus/inferencer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet_plus/inferencer/__init__.py -------------------------------------------------------------------------------- /speech_enhance/fullsubnet_plus/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet_plus/trainer/__init__.py -------------------------------------------------------------------------------- /speech_enhance/inter_subnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/inter_subnet/__init__.py -------------------------------------------------------------------------------- /speech_enhance/inter_subnet/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/inter_subnet/dataset/__init__.py -------------------------------------------------------------------------------- 
/speech_enhance/inter_subnet/dataset/dataset_inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import librosa 4 | import numpy as np 5 | 6 | from audio_zen.dataset.base_dataset import BaseDataset 7 | from audio_zen.utils import basename 8 | 9 | 10 | class Dataset(BaseDataset): 11 | def __init__(self, 12 | dataset_dir_list, 13 | sr, 14 | ): 15 | """ 16 | Args: 17 | noisy_dataset_dir_list (str or list): noisy dir or noisy dir list 18 | """ 19 | super().__init__() 20 | assert isinstance(dataset_dir_list, list) 21 | self.sr = sr 22 | 23 | noisy_file_path_list = [] 24 | for dataset_dir in dataset_dir_list: 25 | dataset_dir = Path(dataset_dir).expanduser().absolute() 26 | noisy_file_path_list += librosa.util.find_files(dataset_dir.as_posix()) # Sorted 27 | 28 | self.noisy_file_path_list = noisy_file_path_list 29 | self.length = len(self.noisy_file_path_list) 30 | 31 | def __len__(self): 32 | return self.length 33 | 34 | def __getitem__(self, item): 35 | noisy_file_path = self.noisy_file_path_list[item] 36 | noisy_y = librosa.load(noisy_file_path, sr=self.sr)[0] 37 | noisy_y = noisy_y.astype(np.float32) 38 | 39 | return noisy_y, basename(noisy_file_path)[0] 40 | -------------------------------------------------------------------------------- /speech_enhance/inter_subnet/dataset/dataset_train.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from audio_zen.acoustics.feature import norm_amplitude, tailor_dB_FS, is_clipped, load_wav, subsample 5 | from audio_zen.dataset.base_dataset import BaseDataset 6 | from audio_zen.utils import expand_path 7 | from joblib import Parallel, delayed 8 | from scipy import signal 9 | from tqdm import tqdm 10 | 11 | 12 | class Dataset(BaseDataset): 13 | def __init__(self, 14 | clean_dataset, 15 | clean_dataset_limit, 16 | clean_dataset_offset, 17 | noise_dataset, 18 | noise_dataset_limit, 19 | noise_dataset_offset, 20 | rir_dataset, 21 | rir_dataset_limit, 22 | rir_dataset_offset, 23 | snr_range, 24 | reverb_proportion, 25 | silence_length, 26 | target_dB_FS, 27 | target_dB_FS_floating_value, 28 | sub_sample_length, 29 | sr, 30 | pre_load_clean_dataset, 31 | pre_load_noise, 32 | pre_load_rir, 33 | num_workers 34 | ): 35 | """ 36 | Dynamic mixing for training 37 | 38 | Args: 39 | clean_dataset_limit: 40 | clean_dataset_offset: 41 | noise_dataset_limit: 42 | noise_dataset_offset: 43 | rir_dataset: 44 | rir_dataset_limit: 45 | rir_dataset_offset: 46 | snr_range: 47 | reverb_proportion: 48 | clean_dataset: scp file 49 | noise_dataset: scp file 50 | sub_sample_length: 51 | sr: 52 | """ 53 | super().__init__() 54 | # acoustics args 55 | self.sr = sr 56 | 57 | # parallel args 58 | self.num_workers = num_workers 59 | 60 | clean_dataset_list = [line.rstrip('\n') for line in open(expand_path(clean_dataset), "r")] 61 | noise_dataset_list = [line.rstrip('\n') for line in open(expand_path(noise_dataset), "r")] 62 | rir_dataset_list = [line.rstrip('\n') for line in open(expand_path(rir_dataset), "r")] 63 | 64 | clean_dataset_list = self._offset_and_limit(clean_dataset_list, clean_dataset_offset, clean_dataset_limit) 65 | noise_dataset_list = self._offset_and_limit(noise_dataset_list, noise_dataset_offset, noise_dataset_limit) 66 | rir_dataset_list = self._offset_and_limit(rir_dataset_list, rir_dataset_offset, rir_dataset_limit) 67 | 68 | if pre_load_clean_dataset: 69 | clean_dataset_list = 
self._preload_dataset(clean_dataset_list, remark="Clean Dataset") 70 | 71 | if pre_load_noise: 72 | noise_dataset_list = self._preload_dataset(noise_dataset_list, remark="Noise Dataset") 73 | 74 | if pre_load_rir: 75 | rir_dataset_list = self._preload_dataset(rir_dataset_list, remark="RIR Dataset") 76 | 77 | self.clean_dataset_list = clean_dataset_list 78 | self.noise_dataset_list = noise_dataset_list 79 | self.rir_dataset_list = rir_dataset_list 80 | 81 | snr_list = self._parse_snr_range(snr_range) 82 | self.snr_list = snr_list 83 | 84 | assert 0 <= reverb_proportion <= 1, "reverberation proportion should be in [0, 1]" 85 | self.reverb_proportion = reverb_proportion 86 | self.silence_length = silence_length 87 | self.target_dB_FS = target_dB_FS 88 | self.target_dB_FS_floating_value = target_dB_FS_floating_value 89 | self.sub_sample_length = sub_sample_length 90 | 91 | self.length = len(self.clean_dataset_list) 92 | 93 | def __len__(self): 94 | return self.length 95 | 96 | def _preload_dataset(self, file_path_list, remark=""): 97 | waveform_list = Parallel(n_jobs=self.num_workers)( 98 | delayed(load_wav)(f_path) for f_path in tqdm(file_path_list, desc=remark) 99 | ) 100 | return list(zip(file_path_list, waveform_list)) 101 | 102 | @staticmethod 103 | def _random_select_from(dataset_list): 104 | return random.choice(dataset_list) 105 | 106 | def _select_noise_y(self, target_length): 107 | noise_y = np.zeros(0, dtype=np.float32) 108 | silence = np.zeros(int(self.sr * self.silence_length), dtype=np.float32) 109 | remaining_length = target_length 110 | 111 | while remaining_length > 0: 112 | noise_file = self._random_select_from(self.noise_dataset_list) 113 | noise_new_added = load_wav(noise_file, sr=self.sr) 114 | noise_y = np.append(noise_y, noise_new_added) 115 | remaining_length -= len(noise_new_added) 116 | 117 | # 如果还需要添加新的噪声,就插入一个小静音段 118 | if remaining_length > 0: 119 | silence_len = min(remaining_length, len(silence)) 120 | noise_y = np.append(noise_y, silence[:silence_len]) 121 | remaining_length -= silence_len 122 | 123 | if len(noise_y) > target_length: 124 | idx_start = np.random.randint(len(noise_y) - target_length) 125 | noise_y = noise_y[idx_start:idx_start + target_length] 126 | 127 | return noise_y 128 | 129 | @staticmethod 130 | def snr_mix(clean_y, noise_y, snr, target_dB_FS, target_dB_FS_floating_value, rir=None, eps=1e-6): 131 | """ 132 | 混合噪声与纯净语音,当 rir 参数不为空时,对纯净语音施加混响效果 133 | 134 | Args: 135 | clean_y: 纯净语音 136 | noise_y: 噪声 137 | snr (int): 信噪比 138 | target_dB_FS (int): 139 | target_dB_FS_floating_value (int): 140 | rir: room impulse response, None 或 np.array 141 | eps: eps 142 | 143 | Returns: 144 | (noisy_y,clean_y) 145 | """ 146 | if rir is not None: 147 | if rir.ndim > 1: 148 | rir_idx = np.random.randint(0, rir.shape[0]) 149 | rir = rir[rir_idx, :] 150 | 151 | clean_y = signal.fftconvolve(clean_y, rir)[:len(clean_y)] 152 | 153 | clean_y, _ = norm_amplitude(clean_y) 154 | clean_y, _, _ = tailor_dB_FS(clean_y, target_dB_FS) 155 | clean_rms = (clean_y ** 2).mean() ** 0.5 156 | 157 | noise_y, _ = norm_amplitude(noise_y) 158 | noise_y, _, _ = tailor_dB_FS(noise_y, target_dB_FS) 159 | noise_rms = (noise_y ** 2).mean() ** 0.5 160 | 161 | snr_scalar = clean_rms / (10 ** (snr / 20)) / (noise_rms + eps) 162 | noise_y *= snr_scalar 163 | noisy_y = clean_y + noise_y 164 | 165 | # Randomly select RMS value of dBFS between -15 dBFS and -35 dBFS and normalize noisy speech with that value 166 | noisy_target_dB_FS = np.random.randint( 167 | target_dB_FS - 
target_dB_FS_floating_value, 168 | target_dB_FS + target_dB_FS_floating_value 169 | ) 170 | 171 | # 使用 noisy 的 rms 放缩音频 172 | noisy_y, _, noisy_scalar = tailor_dB_FS(noisy_y, noisy_target_dB_FS) 173 | clean_y *= noisy_scalar 174 | 175 | # 合成带噪语音的时候可能会 clipping,虽然极少 176 | # 对 noisy, clean_y, noise_y 稍微进行调整 177 | if is_clipped(noisy_y): 178 | noisy_y_scalar = np.max(np.abs(noisy_y)) / (0.99 - eps) # 相当于除以 1 179 | noisy_y = noisy_y / noisy_y_scalar 180 | clean_y = clean_y / noisy_y_scalar 181 | 182 | return noisy_y, clean_y 183 | 184 | def __getitem__(self, item): 185 | clean_file = self.clean_dataset_list[item] 186 | clean_y = load_wav(clean_file, sr=self.sr) 187 | clean_y = subsample(clean_y, sub_sample_length=int(self.sub_sample_length * self.sr)) 188 | 189 | noise_y = self._select_noise_y(target_length=len(clean_y)) 190 | assert len(clean_y) == len(noise_y), f"Inequality: {len(clean_y)} {len(noise_y)}" 191 | 192 | snr = self._random_select_from(self.snr_list) 193 | use_reverb = bool(np.random.random(1) < self.reverb_proportion) 194 | 195 | noisy_y, clean_y = self.snr_mix( 196 | clean_y=clean_y, 197 | noise_y=noise_y, 198 | snr=snr, 199 | target_dB_FS=self.target_dB_FS, 200 | target_dB_FS_floating_value=self.target_dB_FS_floating_value, 201 | rir=load_wav(self._random_select_from(self.rir_dataset_list), sr=self.sr) if use_reverb else None 202 | ) 203 | 204 | noisy_y = noisy_y.astype(np.float32) 205 | clean_y = clean_y.astype(np.float32) 206 | 207 | return noisy_y, clean_y 208 | -------------------------------------------------------------------------------- /speech_enhance/inter_subnet/dataset/dataset_validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import librosa 5 | 6 | from audio_zen.dataset.base_dataset import BaseDataset 7 | from audio_zen.acoustics.utils import load_wav 8 | from audio_zen.utils import basename 9 | 10 | 11 | class Dataset(BaseDataset): 12 | def __init__( 13 | self, 14 | dataset_dir_list, 15 | sr, 16 | ): 17 | """ 18 | Construct DNS validation set 19 | 20 | synthetic/ 21 | with_reverb/ 22 | noisy/ 23 | clean_y/ 24 | no_reverb/ 25 | noisy/ 26 | clean_y/ 27 | """ 28 | super(Dataset, self).__init__() 29 | noisy_files_list = [] 30 | 31 | for dataset_dir in dataset_dir_list: 32 | dataset_dir = Path(dataset_dir).expanduser().absolute() 33 | noisy_files_list += librosa.util.find_files((dataset_dir / "noisy").as_posix()) 34 | 35 | self.length = len(noisy_files_list) 36 | self.noisy_files_list = noisy_files_list 37 | self.sr = sr 38 | 39 | def __len__(self): 40 | return self.length 41 | 42 | def __getitem__(self, item): 43 | """ 44 | use the absolute path of the noisy speech to find the corresponding clean speech. 45 | 46 | Notes 47 | with_reverb and no_reverb dirs have same-named files. 48 | If we use `basename`, the problem will be raised (cover) in visualization. 
49 | 50 | Returns: 51 | noisy: [waveform...], clean: [waveform...], type: [reverb|no_reverb] + name 52 | """ 53 | noisy_file_path = self.noisy_files_list[item] 54 | parent_dir = Path(noisy_file_path).parents[1].name 55 | noisy_filename, _ = basename(noisy_file_path) 56 | 57 | reverb_remark = "" # When the speech comes from reverb_dir, insert "with_reverb" before the filename 58 | 59 | # speech_type 与 validation 部分要一致,用于区分后续的可视化 60 | if parent_dir == "with_reverb": 61 | speech_type = "With_reverb" 62 | elif parent_dir == "no_reverb": 63 | speech_type = "No_reverb" 64 | elif parent_dir == "dns_2_non_english": 65 | speech_type = "Non_english" 66 | elif parent_dir == "dns_2_emotion": 67 | speech_type = "Emotion" 68 | elif parent_dir == "dns_2_singing": 69 | speech_type = "Singing" 70 | else: 71 | raise NotImplementedError(f"Not supported dir: {parent_dir}") 72 | 73 | # Find the corresponding clean speech using "parent_dir" and "file_id" 74 | file_id = noisy_filename.split("_")[-1] 75 | if parent_dir in ("dns_2_emotion", "dns_2_singing"): 76 | # e.g., synthetic_emotion_1792_snr19_tl-35_fileid_19 => synthetic_emotion_clean_fileid_15 77 | clean_filename = f"synthetic_{speech_type.lower()}_clean_fileid_{file_id}" 78 | elif parent_dir == "dns_2_non_english": 79 | # e.g., synthetic_german_collection044_14_-04_CFQQgBvv2xQ_snr8_tl-21_fileid_121 => synthetic_clean_fileid_121 80 | clean_filename = f"synthetic_clean_fileid_{file_id}" 81 | else: 82 | # e.g., clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300 => clean_fileid_300 83 | if parent_dir == "with_reverb": 84 | reverb_remark = "with_reverb" 85 | clean_filename = f"clean_fileid_{file_id}" 86 | 87 | clean_file_path = noisy_file_path.replace(f"noisy/{noisy_filename}", f"clean/{clean_filename}") 88 | 89 | noisy = load_wav(os.path.abspath(os.path.expanduser(noisy_file_path)), sr=self.sr) 90 | clean = load_wav(os.path.abspath(os.path.expanduser(clean_file_path)), sr=self.sr) 91 | 92 | return noisy, clean, reverb_remark + noisy_filename, speech_type 93 | -------------------------------------------------------------------------------- /speech_enhance/inter_subnet/inferencer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/inter_subnet/inferencer/__init__.py -------------------------------------------------------------------------------- /speech_enhance/inter_subnet/inferencer/inferencer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | from audio_zen.acoustics.feature import mag_phase 5 | from audio_zen.acoustics.mask import decompress_cIRM 6 | from audio_zen.inferencer.base_inferencer import BaseInferencer 7 | import soundfile as sf 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | # for log 12 | from utils.logger import log 13 | print=log 14 | 15 | def cumulative_norm(input): 16 | eps = 1e-10 17 | device = input.device 18 | data_type = input.dtype 19 | n_dim = input.ndim 20 | 21 | assert n_dim in (3, 4) 22 | 23 | if n_dim == 3: 24 | n_channels = 1 25 | batch_size, n_freqs, n_frames = input.size() 26 | else: 27 | batch_size, n_channels, n_freqs, n_frames = input.size() 28 | input = input.reshape(batch_size * n_channels, n_freqs, n_frames) 29 | 30 | step_sum = torch.sum(input, dim=1) # [B, T] 31 | step_pow_sum = torch.sum(torch.square(input), dim=1) 32 | 33 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T] 34 | 
cumulative_pow_sum = torch.cumsum(step_pow_sum, dim=-1) # [B, T] 35 | 36 | entry_count = torch.arange(n_freqs, n_freqs * n_frames + 1, n_freqs, dtype=data_type, device=device) 37 | entry_count = entry_count.reshape(1, n_frames) # [1, T] 38 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T] 39 | 40 | cum_mean = cumulative_sum / entry_count # B, T 41 | cum_var = (cumulative_pow_sum - 2 * cum_mean * cumulative_sum) / entry_count + cum_mean.pow(2) # B, T 42 | cum_std = (cum_var + eps).sqrt() # B, T 43 | 44 | cum_mean = cum_mean.reshape(batch_size * n_channels, 1, n_frames) 45 | cum_std = cum_std.reshape(batch_size * n_channels, 1, n_frames) 46 | 47 | x = (input - cum_mean) / cum_std 48 | 49 | if n_dim == 4: 50 | x = x.reshape(batch_size, n_channels, n_freqs, n_frames) 51 | 52 | return x 53 | 54 | 55 | class Inferencer(BaseInferencer): 56 | def __init__(self, config, checkpoint_path, output_dir): 57 | super().__init__(config, checkpoint_path, output_dir) 58 | 59 | @torch.no_grad() 60 | def mag(self, noisy, inference_args): 61 | noisy_complex = self.torch_stft(noisy) 62 | noisy_mag, noisy_phase = mag_phase(noisy_complex) # [B, F, T] => [B, 1, F, T] 63 | 64 | enhanced_mag = self.model(noisy_mag.unsqueeze(1)).squeeze(1) 65 | 66 | enhanced = self.torch_istft((enhanced_mag, noisy_phase), length=noisy.size(-1), use_mag_phase=True) 67 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 68 | 69 | return enhanced 70 | 71 | @torch.no_grad() 72 | def scaled_mask(self, noisy, inference_args): 73 | noisy_complex = self.torch_stft(noisy) 74 | noisy_mag, noisy_phase = mag_phase(noisy_complex) 75 | 76 | # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2] 77 | noisy_mag = noisy_mag.unsqueeze(1) 78 | scaled_mask = self.model(noisy_mag) 79 | scaled_mask = scaled_mask.permute(0, 2, 3, 1) 80 | 81 | enhanced_complex = noisy_complex * scaled_mask 82 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1), use_mag_phase=False) 83 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 84 | 85 | return enhanced 86 | 87 | @torch.no_grad() 88 | def sub_band_crm_mask(self, noisy, inference_args): 89 | pad_mode = inference_args["pad_mode"] 90 | n_neighbor = inference_args["n_neighbor"] 91 | 92 | noisy = noisy.cpu().numpy().reshape(-1) 93 | noisy_D = self.librosa_stft(noisy) 94 | 95 | noisy_real = torch.tensor(noisy_D.real, device=self.device) 96 | noisy_imag = torch.tensor(noisy_D.imag, device=self.device) 97 | noisy_mag = torch.sqrt(torch.square(noisy_real) + torch.square(noisy_imag)) # [F, T] 98 | n_freqs, n_frames = noisy_mag.size() 99 | 100 | noisy_mag = noisy_mag.reshape(1, 1, n_freqs, n_frames) 101 | noisy_mag_padded = self._unfold(noisy_mag, pad_mode, n_neighbor) # [B, N, C, F_s, T] <=> [1, 257, 1, 31, T] 102 | noisy_mag_padded = noisy_mag_padded.squeeze(0).squeeze(1) # [257, 31, 200] <=> [B, F_s, T] 103 | 104 | pred_crm = self.model(noisy_mag_padded).detach() # [B, 2, T] <=> [F, 2, T] 105 | pred_crm = pred_crm.permute(0, 2, 1).contiguous() # [B, T, 2] 106 | 107 | lim = 9.99 108 | pred_crm = lim * (pred_crm >= lim) - lim * (pred_crm <= -lim) + pred_crm * (torch.abs(pred_crm) < lim) 109 | pred_crm = -10 * torch.log((10 - pred_crm) / (10 + pred_crm)) 110 | 111 | enhanced_real = pred_crm[:, :, 0] * noisy_real - pred_crm[:, :, 1] * noisy_imag 112 | enhanced_imag = pred_crm[:, :, 1] * noisy_real + pred_crm[:, :, 0] * noisy_imag 113 | 114 | enhanced_real = enhanced_real.cpu().numpy() 115 | enhanced_imag = enhanced_imag.cpu().numpy() 116 | enhanced = 
self.librosa_istft(enhanced_real + 1j * enhanced_imag, length=len(noisy)) 117 | return enhanced 118 | 119 | @torch.no_grad() 120 | def full_band_crm_mask(self, noisy, inference_args): 121 | noisy_complex = self.torch_stft(noisy) 122 | noisy_mag, _ = mag_phase(noisy_complex) 123 | 124 | noisy_mag = noisy_mag.unsqueeze(1) 125 | t1 = time.time() 126 | pred_crm = self.model(noisy_mag) 127 | t2 = time.time() 128 | pred_crm = pred_crm.permute(0, 2, 3, 1) 129 | 130 | pred_crm = decompress_cIRM(pred_crm) 131 | enhanced_real = pred_crm[..., 0] * noisy_complex.real - pred_crm[..., 1] * noisy_complex.imag 132 | enhanced_imag = pred_crm[..., 1] * noisy_complex.real + pred_crm[..., 0] * noisy_complex.imag 133 | enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1) 134 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1)) 135 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 136 | 137 | # rtf计算 138 | rtf = (t2 - t1) / (len(enhanced) * 1.0 / self.acoustic_config["sr"]) 139 | print(f"model rtf: {rtf}") 140 | 141 | return enhanced 142 | 143 | 144 | @torch.no_grad() 145 | def overlapped_chunk(self, noisy, inference_args): 146 | sr = self.acoustic_config["sr"] 147 | 148 | noisy = noisy.squeeze(0) 149 | 150 | num_mics = 8 151 | chunk_length = sr * inference_args["chunk_length"] 152 | chunk_hop_length = chunk_length // 2 153 | num_chunks = int(noisy.shape[-1] / chunk_hop_length) + 1 154 | 155 | win = torch.hann_window(chunk_length, device=noisy.device) 156 | 157 | prev = None 158 | enhanced = None 159 | # 模拟语音的静音段,防止一上来就给语音,处理的不好 160 | for chunk_idx in range(num_chunks): 161 | if chunk_idx == 0: 162 | pad = torch.zeros((num_mics, 256), device=noisy.device) 163 | 164 | chunk_start_position = chunk_idx * chunk_hop_length 165 | chunk_end_position = chunk_start_position + chunk_length 166 | 167 | # concat([(8, 256), (..., ... + chunk_length)]) 168 | noisy_chunk = torch.cat((pad, noisy[:, chunk_start_position:chunk_end_position]), dim=1) 169 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0)) 170 | enhanced_chunk = torch.squeeze(enhanced_chunk) 171 | enhanced_chunk = enhanced_chunk[256:] 172 | 173 | # Save the prior half chunk, 174 | cur = enhanced_chunk[:chunk_length // 2] 175 | 176 | # only for the 1st chunk,no overlap for the very 1st chunk prior half 177 | prev = enhanced_chunk[chunk_length // 2:] * win[chunk_length // 2:] 178 | else: 179 | # use the previous noisy data as the pad 180 | pad = noisy[:, (chunk_idx * chunk_hop_length - 256):(chunk_idx * chunk_hop_length)] 181 | 182 | chunk_start_position = chunk_idx * chunk_hop_length 183 | chunk_end_position = chunk_start_position + chunk_length 184 | 185 | noisy_chunk = torch.cat((pad, noisy[:8, chunk_start_position:chunk_end_position]), dim=1) 186 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0)) 187 | enhanced_chunk = torch.squeeze(enhanced_chunk) 188 | enhanced_chunk = enhanced_chunk[256:] 189 | 190 | # 使用这个窗函数来对拼接的位置进行平滑? 
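                # (i.e., overlap-add: the Hann window tapers the current chunk so that its first half,
                # summed below with the windowed second half of the previous chunk, cross-fades smoothly
                # at the concatenation point instead of leaving an audible seam)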
191 | enhanced_chunk = enhanced_chunk * win[:len(enhanced_chunk)] 192 | 193 | tmp = enhanced_chunk[:chunk_length // 2] 194 | cur = tmp[:min(len(tmp), len(prev))] + prev[:min(len(tmp), len(prev))] 195 | prev = enhanced_chunk[chunk_length // 2:] 196 | 197 | if enhanced is None: 198 | enhanced = cur 199 | else: 200 | enhanced = torch.cat((enhanced, cur), dim=0) 201 | 202 | enhanced = enhanced[:noisy.shape[1]] 203 | return enhanced.detach().squeeze(0).cpu().numpy() 204 | 205 | @torch.no_grad() 206 | def time_domain(self, noisy, inference_args): 207 | noisy = noisy.to(self.device) 208 | enhanced = self.model(noisy) 209 | return enhanced.detach().squeeze().cpu().numpy() 210 | 211 | 212 | if __name__ == '__main__': 213 | a = torch.rand(10, 2, 161, 200) 214 | print(cumulative_norm(a).shape) 215 | -------------------------------------------------------------------------------- /speech_enhance/inter_subnet/model/Inter_SubNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional 3 | 4 | from audio_zen.acoustics.feature import drop_band 5 | from audio_zen.model.base_model import BaseModel 6 | from audio_zen.model.module.sequence_model import stacked_SIL_blocks_SequenceModel 7 | 8 | # for log 9 | from utils.logger import log 10 | 11 | print = log 12 | 13 | 14 | class Inter_SubNet(BaseModel): 15 | def __init__(self, 16 | num_freqs, 17 | look_ahead, 18 | sequence_model, 19 | sb_num_neighbors, 20 | sb_output_activate_function, 21 | sb_model_hidden_size, 22 | norm_type="offline_laplace_norm", 23 | num_groups_in_drop_band=2, 24 | weight_init=True, 25 | sbinter_middle_hidden_times=0.66, 26 | ): 27 | """ 28 | Inter-SubNet model (cIRM mask) 29 | 30 | Args: 31 | num_freqs: Frequency dim of the input 32 | sb_num_neighbors: Number of the neighbor frequencies in each side 33 | sequence_model: Chose one sequence model as the basic model (GRU, LSTM) 34 | """ 35 | super().__init__() 36 | assert sequence_model in ("GRU", "LSTM"), f"{self.__class__.__name__} only support GRU and LSTM." 37 | 38 | subband_input_size = (sb_num_neighbors * 2 + 1) 39 | self.sb_model = stacked_SIL_blocks_SequenceModel( 40 | input_size=subband_input_size, 41 | output_size=2, 42 | hidden_size=sb_model_hidden_size, 43 | num_layers=2, 44 | bidirectional=False, 45 | sequence_model=sequence_model, 46 | output_activate_function=sb_output_activate_function, 47 | middle_tac_hidden_times=sbinter_middle_hidden_times 48 | ) 49 | 50 | self.sb_num_neighbors = sb_num_neighbors 51 | # self.fb_num_neighbors = fb_num_neighbors 52 | self.look_ahead = look_ahead 53 | self.norm = self.norm_wrapper(norm_type) 54 | self.num_groups_in_drop_band = num_groups_in_drop_band 55 | 56 | if weight_init: 57 | self.apply(self.weight_init) 58 | 59 | def forward(self, noisy_mag): 60 | """ 61 | Args: 62 | noisy_mag: noisy magnitude spectrogram 63 | 64 | Returns: 65 | The real part and imag part of the enhanced spectrogram 66 | 67 | Shapes: 68 | noisy_mag: [B, 1, F, T] 69 | return: [B, 2, F, T] 70 | """ 71 | assert noisy_mag.dim() == 4 72 | noisy_mag = functional.pad(noisy_mag, [0, self.look_ahead]) # Pad the look ahead 73 | batch_size, num_channels, num_freqs, num_frames = noisy_mag.size() 74 | assert num_channels == 1, f"{self.__class__.__name__} takes the mag feature as inputs." 
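        # Sub-band decomposition: unfold() below groups every frequency bin with its
        # sb_num_neighbors neighbours on each side, turning [B, 1, F, T] into F sub-band units of
        # width (2 * sb_num_neighbors + 1); after normalization (and drop_band when batch_size > 1)
        # these units are fed jointly to the stacked SIL-block sequence model, which models each
        # sub-band while letting information flow between sub-bands.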
75 | 76 | # Unfold noisy input, [B, N=F, C, F_s, T] 77 | noisy_mag_unfolded = self.unfold(noisy_mag, num_neighbor=self.sb_num_neighbors) 78 | noisy_mag_unfolded = noisy_mag_unfolded.reshape(batch_size, num_freqs, self.sb_num_neighbors * 2 + 1, 79 | num_frames) 80 | 81 | sb_input = self.norm(noisy_mag_unfolded) 82 | 83 | if batch_size > 1: 84 | sb_input = drop_band(sb_input.permute(0, 2, 1, 3), 85 | num_groups=self.num_groups_in_drop_band) # [B, F_s, F//num_groups, T] 86 | num_freqs = sb_input.shape[2] 87 | sb_input = sb_input.permute(0, 2, 1, 3) # [B, F//num_groups, F_s, T] 88 | 89 | # [B, F//num_groups, F_s, T] => [B * F, 2, T] => [B, F, 2, T] 90 | sb_mask = self.sb_model(sb_input) 91 | sb_mask = sb_mask.reshape(batch_size, num_freqs, 2, num_frames).permute(0, 2, 1, 3).contiguous() 92 | 93 | output = sb_mask[:, :, :, self.look_ahead:] 94 | return output 95 | 96 | 97 | if __name__ == "__main__": 98 | model = Inter_SubNet( 99 | num_freqs=257, 100 | look_ahead=2, 101 | sequence_model="LSTM", 102 | sb_num_neighbors=15, 103 | sb_output_activate_function=False, 104 | sb_model_hidden_size=384, 105 | weight_init=False, 106 | norm_type="offline_laplace_norm", 107 | num_groups_in_drop_band=2, 108 | sbinter_middle_hidden_times=0.8 109 | ) 110 | -------------------------------------------------------------------------------- /speech_enhance/inter_subnet/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/inter_subnet/model/__init__.py -------------------------------------------------------------------------------- /speech_enhance/inter_subnet/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/inter_subnet/trainer/__init__.py -------------------------------------------------------------------------------- /speech_enhance/subband_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/subband_model/__init__.py -------------------------------------------------------------------------------- /speech_enhance/subband_model/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/subband_model/dataset/__init__.py -------------------------------------------------------------------------------- /speech_enhance/subband_model/dataset/dataset_inference.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import librosa 4 | import numpy as np 5 | 6 | from audio_zen.dataset.base_dataset import BaseDataset 7 | from audio_zen.utils import basename 8 | 9 | 10 | class Dataset(BaseDataset): 11 | def __init__(self, 12 | dataset_dir_list, 13 | sr, 14 | ): 15 | """ 16 | Args: 17 | noisy_dataset_dir_list (str or list): noisy dir or noisy dir list 18 | """ 19 | super().__init__() 20 | assert isinstance(dataset_dir_list, list) 21 | self.sr = sr 22 | 23 | noisy_file_path_list = [] 24 | for dataset_dir in dataset_dir_list: 25 | dataset_dir = Path(dataset_dir).expanduser().absolute() 26 | noisy_file_path_list += 
librosa.util.find_files(dataset_dir.as_posix()) # Sorted 27 | 28 | self.noisy_file_path_list = noisy_file_path_list 29 | self.length = len(self.noisy_file_path_list) 30 | 31 | def __len__(self): 32 | return self.length 33 | 34 | def __getitem__(self, item): 35 | noisy_file_path = self.noisy_file_path_list[item] 36 | noisy_y = librosa.load(noisy_file_path, sr=self.sr)[0] 37 | noisy_y = noisy_y.astype(np.float32) 38 | 39 | return noisy_y, basename(noisy_file_path)[0] 40 | -------------------------------------------------------------------------------- /speech_enhance/subband_model/dataset/dataset_train.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from audio_zen.acoustics.feature import norm_amplitude, tailor_dB_FS, is_clipped, load_wav, subsample 5 | from audio_zen.dataset.base_dataset import BaseDataset 6 | from audio_zen.utils import expand_path 7 | from joblib import Parallel, delayed 8 | from scipy import signal 9 | from tqdm import tqdm 10 | 11 | 12 | class Dataset(BaseDataset): 13 | def __init__(self, 14 | clean_dataset, 15 | clean_dataset_limit, 16 | clean_dataset_offset, 17 | noise_dataset, 18 | noise_dataset_limit, 19 | noise_dataset_offset, 20 | rir_dataset, 21 | rir_dataset_limit, 22 | rir_dataset_offset, 23 | snr_range, 24 | reverb_proportion, 25 | silence_length, 26 | target_dB_FS, 27 | target_dB_FS_floating_value, 28 | sub_sample_length, 29 | sr, 30 | pre_load_clean_dataset, 31 | pre_load_noise, 32 | pre_load_rir, 33 | num_workers 34 | ): 35 | """ 36 | Dynamic mixing for training 37 | 38 | Args: 39 | clean_dataset_limit: 40 | clean_dataset_offset: 41 | noise_dataset_limit: 42 | noise_dataset_offset: 43 | rir_dataset: 44 | rir_dataset_limit: 45 | rir_dataset_offset: 46 | snr_range: 47 | reverb_proportion: 48 | clean_dataset: scp file 49 | noise_dataset: scp file 50 | sub_sample_length: 51 | sr: 52 | """ 53 | super().__init__() 54 | # acoustics args 55 | self.sr = sr 56 | 57 | # parallel args 58 | self.num_workers = num_workers 59 | 60 | clean_dataset_list = [line.rstrip('\n') for line in open(expand_path(clean_dataset), "r")] 61 | noise_dataset_list = [line.rstrip('\n') for line in open(expand_path(noise_dataset), "r")] 62 | rir_dataset_list = [line.rstrip('\n') for line in open(expand_path(rir_dataset), "r")] 63 | 64 | clean_dataset_list = self._offset_and_limit(clean_dataset_list, clean_dataset_offset, clean_dataset_limit) 65 | noise_dataset_list = self._offset_and_limit(noise_dataset_list, noise_dataset_offset, noise_dataset_limit) 66 | rir_dataset_list = self._offset_and_limit(rir_dataset_list, rir_dataset_offset, rir_dataset_limit) 67 | 68 | if pre_load_clean_dataset: 69 | clean_dataset_list = self._preload_dataset(clean_dataset_list, remark="Clean Dataset") 70 | 71 | if pre_load_noise: 72 | noise_dataset_list = self._preload_dataset(noise_dataset_list, remark="Noise Dataset") 73 | 74 | if pre_load_rir: 75 | rir_dataset_list = self._preload_dataset(rir_dataset_list, remark="RIR Dataset") 76 | 77 | self.clean_dataset_list = clean_dataset_list 78 | self.noise_dataset_list = noise_dataset_list 79 | self.rir_dataset_list = rir_dataset_list 80 | 81 | snr_list = self._parse_snr_range(snr_range) 82 | self.snr_list = snr_list 83 | 84 | assert 0 <= reverb_proportion <= 1, "reverberation proportion should be in [0, 1]" 85 | self.reverb_proportion = reverb_proportion 86 | self.silence_length = silence_length 87 | self.target_dB_FS = target_dB_FS 88 | self.target_dB_FS_floating_value = 
target_dB_FS_floating_value 89 | self.sub_sample_length = sub_sample_length 90 | 91 | self.length = len(self.clean_dataset_list) 92 | 93 | def __len__(self): 94 | return self.length 95 | 96 | def _preload_dataset(self, file_path_list, remark=""): 97 | waveform_list = Parallel(n_jobs=self.num_workers)( 98 | delayed(load_wav)(f_path) for f_path in tqdm(file_path_list, desc=remark) 99 | ) 100 | return list(zip(file_path_list, waveform_list)) 101 | 102 | @staticmethod 103 | def _random_select_from(dataset_list): 104 | return random.choice(dataset_list) 105 | 106 | def _select_noise_y(self, target_length): 107 | noise_y = np.zeros(0, dtype=np.float32) 108 | silence = np.zeros(int(self.sr * self.silence_length), dtype=np.float32) 109 | remaining_length = target_length 110 | 111 | while remaining_length > 0: 112 | noise_file = self._random_select_from(self.noise_dataset_list) 113 | noise_new_added = load_wav(noise_file, sr=self.sr) 114 | noise_y = np.append(noise_y, noise_new_added) 115 | remaining_length -= len(noise_new_added) 116 | 117 | # If more noise is still needed, insert a short silence segment before the next noise clip 118 | if remaining_length > 0: 119 | silence_len = min(remaining_length, len(silence)) 120 | noise_y = np.append(noise_y, silence[:silence_len]) 121 | remaining_length -= silence_len 122 | 123 | if len(noise_y) > target_length: 124 | idx_start = np.random.randint(len(noise_y) - target_length) 125 | noise_y = noise_y[idx_start:idx_start + target_length] 126 | 127 | return noise_y 128 | 129 | @staticmethod 130 | def snr_mix(clean_y, noise_y, snr, target_dB_FS, target_dB_FS_floating_value, rir=None, eps=1e-6): 131 | """ 132 | Mix noise into clean speech; when the rir argument is not None, reverberation is applied to the clean speech first 133 | 134 | Args: 135 | clean_y: clean speech 136 | noise_y: noise 137 | snr (int): signal-to-noise ratio 138 | target_dB_FS (int): 139 | target_dB_FS_floating_value (int): 140 | rir: room impulse response, None or np.array 141 | eps: eps 142 | 143 | Returns: 144 | (noisy_y, clean_y) 145 | """ 146 | if rir is not None: 147 | if rir.ndim > 1: 148 | rir_idx = np.random.randint(0, rir.shape[0]) 149 | rir = rir[rir_idx, :] 150 | 151 | clean_y = signal.fftconvolve(clean_y, rir)[:len(clean_y)] 152 | 153 | clean_y, _ = norm_amplitude(clean_y) 154 | clean_y, _, _ = tailor_dB_FS(clean_y, target_dB_FS) 155 | clean_rms = (clean_y ** 2).mean() ** 0.5 156 | 157 | noise_y, _ = norm_amplitude(noise_y) 158 | noise_y, _, _ = tailor_dB_FS(noise_y, target_dB_FS) 159 | noise_rms = (noise_y ** 2).mean() ** 0.5 160 | 161 | snr_scalar = clean_rms / (10 ** (snr / 20)) / (noise_rms + eps) 162 | noise_y *= snr_scalar 163 | noisy_y = clean_y + noise_y 164 | 165 | # Randomly select RMS value of dBFS between -15 dBFS and -35 dBFS and normalize noisy speech with that value 166 | noisy_target_dB_FS = np.random.randint( 167 | target_dB_FS - target_dB_FS_floating_value, 168 | target_dB_FS + target_dB_FS_floating_value 169 | ) 170 | 171 | # Rescale both signals using the gain derived from the noisy speech 172 | noisy_y, _, noisy_scalar = tailor_dB_FS(noisy_y, noisy_target_dB_FS) 173 | clean_y *= noisy_scalar 174 | 175 | # The synthesized noisy speech may clip, although this is rare; 176 | # slightly rescale noisy_y and clean_y in that case 177 | if is_clipped(noisy_y): 178 | noisy_y_scalar = np.max(np.abs(noisy_y)) / (0.99 - eps) # roughly equivalent to dividing by 1 179 | noisy_y = noisy_y / noisy_y_scalar 180 | clean_y = clean_y / noisy_y_scalar 181 | 182 | return noisy_y, clean_y 183 | 184 | def __getitem__(self, item): 185 | clean_file = self.clean_dataset_list[item] 186 | clean_y = load_wav(clean_file, sr=self.sr) 187 | clean_y = subsample(clean_y, sub_sample_length=int(self.sub_sample_length * self.sr)) 188 | 189 | noise_y = 
self._select_noise_y(target_length=len(clean_y)) 190 | assert len(clean_y) == len(noise_y), f"Inequality: {len(clean_y)} {len(noise_y)}" 191 | 192 | snr = self._random_select_from(self.snr_list) 193 | use_reverb = bool(np.random.random(1) < self.reverb_proportion) 194 | 195 | noisy_y, clean_y = self.snr_mix( 196 | clean_y=clean_y, 197 | noise_y=noise_y, 198 | snr=snr, 199 | target_dB_FS=self.target_dB_FS, 200 | target_dB_FS_floating_value=self.target_dB_FS_floating_value, 201 | rir=load_wav(self._random_select_from(self.rir_dataset_list), sr=self.sr) if use_reverb else None 202 | ) 203 | 204 | noisy_y = noisy_y.astype(np.float32) 205 | clean_y = clean_y.astype(np.float32) 206 | 207 | return noisy_y, clean_y 208 | -------------------------------------------------------------------------------- /speech_enhance/subband_model/dataset/dataset_validation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import librosa 5 | 6 | from audio_zen.dataset.base_dataset import BaseDataset 7 | from audio_zen.acoustics.utils import load_wav 8 | from audio_zen.utils import basename 9 | 10 | 11 | class Dataset(BaseDataset): 12 | def __init__( 13 | self, 14 | dataset_dir_list, 15 | sr, 16 | ): 17 | """ 18 | Construct DNS validation set 19 | 20 | synthetic/ 21 | with_reverb/ 22 | noisy/ 23 | clean_y/ 24 | no_reverb/ 25 | noisy/ 26 | clean_y/ 27 | """ 28 | super(Dataset, self).__init__() 29 | noisy_files_list = [] 30 | 31 | for dataset_dir in dataset_dir_list: 32 | dataset_dir = Path(dataset_dir).expanduser().absolute() 33 | noisy_files_list += librosa.util.find_files((dataset_dir / "noisy").as_posix()) 34 | 35 | self.length = len(noisy_files_list) 36 | self.noisy_files_list = noisy_files_list 37 | self.sr = sr 38 | 39 | def __len__(self): 40 | return self.length 41 | 42 | def __getitem__(self, item): 43 | """ 44 | use the absolute path of the noisy speech to find the corresponding clean speech. 45 | 46 | Notes 47 | with_reverb and no_reverb dirs have same-named files. 48 | If we use `basename`, the problem will be raised (cover) in visualization. 
49 | 50 | Returns: 51 | noisy: [waveform...], clean: [waveform...], type: [reverb|no_reverb] + name 52 | """ 53 | noisy_file_path = self.noisy_files_list[item] 54 | parent_dir = Path(noisy_file_path).parents[1].name 55 | noisy_filename, _ = basename(noisy_file_path) 56 | 57 | reverb_remark = "" # When the speech comes from reverb_dir, insert "with_reverb" before the filename 58 | 59 | # speech_type 与 validation 部分要一致,用于区分后续的可视化 60 | if parent_dir == "with_reverb": 61 | speech_type = "With_reverb" 62 | elif parent_dir == "no_reverb": 63 | speech_type = "No_reverb" 64 | elif parent_dir == "dns_2_non_english": 65 | speech_type = "Non_english" 66 | elif parent_dir == "dns_2_emotion": 67 | speech_type = "Emotion" 68 | elif parent_dir == "dns_2_singing": 69 | speech_type = "Singing" 70 | else: 71 | raise NotImplementedError(f"Not supported dir: {parent_dir}") 72 | 73 | # Find the corresponding clean speech using "parent_dir" and "file_id" 74 | file_id = noisy_filename.split("_")[-1] 75 | if parent_dir in ("dns_2_emotion", "dns_2_singing"): 76 | # e.g., synthetic_emotion_1792_snr19_tl-35_fileid_19 => synthetic_emotion_clean_fileid_15 77 | clean_filename = f"synthetic_{speech_type.lower()}_clean_fileid_{file_id}" 78 | elif parent_dir == "dns_2_non_english": 79 | # e.g., synthetic_german_collection044_14_-04_CFQQgBvv2xQ_snr8_tl-21_fileid_121 => synthetic_clean_fileid_121 80 | clean_filename = f"synthetic_clean_fileid_{file_id}" 81 | else: 82 | # e.g., clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300 => clean_fileid_300 83 | if parent_dir == "with_reverb": 84 | reverb_remark = "with_reverb" 85 | clean_filename = f"clean_fileid_{file_id}" 86 | 87 | clean_file_path = noisy_file_path.replace(f"noisy/{noisy_filename}", f"clean/{clean_filename}") 88 | 89 | noisy = load_wav(os.path.abspath(os.path.expanduser(noisy_file_path)), sr=self.sr) 90 | clean = load_wav(os.path.abspath(os.path.expanduser(clean_file_path)), sr=self.sr) 91 | 92 | return noisy, clean, reverb_remark + noisy_filename, speech_type 93 | -------------------------------------------------------------------------------- /speech_enhance/subband_model/inferencer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/subband_model/inferencer/__init__.py -------------------------------------------------------------------------------- /speech_enhance/subband_model/inferencer/inferencer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | from audio_zen.acoustics.feature import mag_phase 5 | from audio_zen.acoustics.mask import decompress_cIRM 6 | from audio_zen.inferencer.base_inferencer import BaseInferencer 7 | import soundfile as sf 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | # for log 12 | from utils.logger import log 13 | print=log 14 | 15 | def cumulative_norm(input): 16 | eps = 1e-10 17 | device = input.device 18 | data_type = input.dtype 19 | n_dim = input.ndim 20 | 21 | assert n_dim in (3, 4) 22 | 23 | if n_dim == 3: 24 | n_channels = 1 25 | batch_size, n_freqs, n_frames = input.size() 26 | else: 27 | batch_size, n_channels, n_freqs, n_frames = input.size() 28 | input = input.reshape(batch_size * n_channels, n_freqs, n_frames) 29 | 30 | step_sum = torch.sum(input, dim=1) # [B, T] 31 | step_pow_sum = torch.sum(torch.square(input), dim=1) 32 | 33 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T] 34 
| cumulative_pow_sum = torch.cumsum(step_pow_sum, dim=-1) # [B, T] 35 | 36 | entry_count = torch.arange(n_freqs, n_freqs * n_frames + 1, n_freqs, dtype=data_type, device=device) 37 | entry_count = entry_count.reshape(1, n_frames) # [1, T] 38 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T] 39 | 40 | cum_mean = cumulative_sum / entry_count # B, T 41 | cum_var = (cumulative_pow_sum - 2 * cum_mean * cumulative_sum) / entry_count + cum_mean.pow(2) # B, T 42 | cum_std = (cum_var + eps).sqrt() # B, T 43 | 44 | cum_mean = cum_mean.reshape(batch_size * n_channels, 1, n_frames) 45 | cum_std = cum_std.reshape(batch_size * n_channels, 1, n_frames) 46 | 47 | x = (input - cum_mean) / cum_std 48 | 49 | if n_dim == 4: 50 | x = x.reshape(batch_size, n_channels, n_freqs, n_frames) 51 | 52 | return x 53 | 54 | 55 | class Inferencer(BaseInferencer): 56 | def __init__(self, config, checkpoint_path, output_dir): 57 | super().__init__(config, checkpoint_path, output_dir) 58 | 59 | @torch.no_grad() 60 | def mag(self, noisy, inference_args): 61 | noisy_complex = self.torch_stft(noisy) 62 | noisy_mag, noisy_phase = mag_phase(noisy_complex) # [B, F, T] => [B, 1, F, T] 63 | 64 | enhanced_mag = self.model(noisy_mag.unsqueeze(1)).squeeze(1) 65 | 66 | enhanced = self.torch_istft((enhanced_mag, noisy_phase), length=noisy.size(-1), use_mag_phase=True) 67 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 68 | 69 | return enhanced 70 | 71 | @torch.no_grad() 72 | def scaled_mask(self, noisy, inference_args): 73 | noisy_complex = self.torch_stft(noisy) 74 | noisy_mag, noisy_phase = mag_phase(noisy_complex) 75 | 76 | # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2] 77 | noisy_mag = noisy_mag.unsqueeze(1) 78 | scaled_mask = self.model(noisy_mag) 79 | scaled_mask = scaled_mask.permute(0, 2, 3, 1) 80 | 81 | enhanced_complex = noisy_complex * scaled_mask 82 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1), use_mag_phase=False) 83 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 84 | 85 | return enhanced 86 | 87 | @torch.no_grad() 88 | def sub_band_crm_mask(self, noisy, inference_args): 89 | pad_mode = inference_args["pad_mode"] 90 | n_neighbor = inference_args["n_neighbor"] 91 | 92 | noisy = noisy.cpu().numpy().reshape(-1) 93 | noisy_D = self.librosa_stft(noisy) 94 | 95 | noisy_real = torch.tensor(noisy_D.real, device=self.device) 96 | noisy_imag = torch.tensor(noisy_D.imag, device=self.device) 97 | noisy_mag = torch.sqrt(torch.square(noisy_real) + torch.square(noisy_imag)) # [F, T] 98 | n_freqs, n_frames = noisy_mag.size() 99 | 100 | noisy_mag = noisy_mag.reshape(1, 1, n_freqs, n_frames) 101 | noisy_mag_padded = self._unfold(noisy_mag, pad_mode, n_neighbor) # [B, N, C, F_s, T] <=> [1, 257, 1, 31, T] 102 | noisy_mag_padded = noisy_mag_padded.squeeze(0).squeeze(1) # [257, 31, 200] <=> [B, F_s, T] 103 | 104 | pred_crm = self.model(noisy_mag_padded).detach() # [B, 2, T] <=> [F, 2, T] 105 | pred_crm = pred_crm.permute(0, 2, 1).contiguous() # [B, T, 2] 106 | 107 | lim = 9.99 108 | pred_crm = lim * (pred_crm >= lim) - lim * (pred_crm <= -lim) + pred_crm * (torch.abs(pred_crm) < lim) 109 | pred_crm = -10 * torch.log((10 - pred_crm) / (10 + pred_crm)) 110 | 111 | enhanced_real = pred_crm[:, :, 0] * noisy_real - pred_crm[:, :, 1] * noisy_imag 112 | enhanced_imag = pred_crm[:, :, 1] * noisy_real + pred_crm[:, :, 0] * noisy_imag 113 | 114 | enhanced_real = enhanced_real.cpu().numpy() 115 | enhanced_imag = enhanced_imag.cpu().numpy() 116 | enhanced = 
self.librosa_istft(enhanced_real + 1j * enhanced_imag, length=len(noisy)) 117 | return enhanced 118 | 119 | @torch.no_grad() 120 | def full_band_crm_mask(self, noisy, inference_args): 121 | noisy_complex = self.torch_stft(noisy) 122 | noisy_mag, _ = mag_phase(noisy_complex) 123 | 124 | noisy_mag = noisy_mag.unsqueeze(1) 125 | t1 = time.time() 126 | pred_crm = self.model(noisy_mag) 127 | t2 = time.time() 128 | pred_crm = pred_crm.permute(0, 2, 3, 1) 129 | 130 | pred_crm = decompress_cIRM(pred_crm) 131 | enhanced_real = pred_crm[..., 0] * noisy_complex.real - pred_crm[..., 1] * noisy_complex.imag 132 | enhanced_imag = pred_crm[..., 1] * noisy_complex.real + pred_crm[..., 0] * noisy_complex.imag 133 | enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1) 134 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1)) 135 | enhanced = enhanced.detach().squeeze(0).cpu().numpy() 136 | 137 | # rtf计算 138 | rtf = (t2 - t1) / (len(enhanced) * 1.0 / self.acoustic_config["sr"]) 139 | print(f"model rtf: {rtf}") 140 | 141 | return enhanced 142 | 143 | 144 | @torch.no_grad() 145 | def overlapped_chunk(self, noisy, inference_args): 146 | sr = self.acoustic_config["sr"] 147 | 148 | noisy = noisy.squeeze(0) 149 | 150 | num_mics = 8 151 | chunk_length = sr * inference_args["chunk_length"] 152 | chunk_hop_length = chunk_length // 2 153 | num_chunks = int(noisy.shape[-1] / chunk_hop_length) + 1 154 | 155 | win = torch.hann_window(chunk_length, device=noisy.device) 156 | 157 | prev = None 158 | enhanced = None 159 | # 模拟语音的静音段,防止一上来就给语音,处理的不好 160 | for chunk_idx in range(num_chunks): 161 | if chunk_idx == 0: 162 | pad = torch.zeros((num_mics, 256), device=noisy.device) 163 | 164 | chunk_start_position = chunk_idx * chunk_hop_length 165 | chunk_end_position = chunk_start_position + chunk_length 166 | 167 | # concat([(8, 256), (..., ... + chunk_length)]) 168 | noisy_chunk = torch.cat((pad, noisy[:, chunk_start_position:chunk_end_position]), dim=1) 169 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0)) 170 | enhanced_chunk = torch.squeeze(enhanced_chunk) 171 | enhanced_chunk = enhanced_chunk[256:] 172 | 173 | # Save the prior half chunk, 174 | cur = enhanced_chunk[:chunk_length // 2] 175 | 176 | # only for the 1st chunk,no overlap for the very 1st chunk prior half 177 | prev = enhanced_chunk[chunk_length // 2:] * win[chunk_length // 2:] 178 | else: 179 | # use the previous noisy data as the pad 180 | pad = noisy[:, (chunk_idx * chunk_hop_length - 256):(chunk_idx * chunk_hop_length)] 181 | 182 | chunk_start_position = chunk_idx * chunk_hop_length 183 | chunk_end_position = chunk_start_position + chunk_length 184 | 185 | noisy_chunk = torch.cat((pad, noisy[:8, chunk_start_position:chunk_end_position]), dim=1) 186 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0)) 187 | enhanced_chunk = torch.squeeze(enhanced_chunk) 188 | enhanced_chunk = enhanced_chunk[256:] 189 | 190 | # 使用这个窗函数来对拼接的位置进行平滑? 
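                # (i.e., overlap-add: the Hann window lets the overlapping halves of consecutive chunks
                # cross-fade at the concatenation point)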
191 | enhanced_chunk = enhanced_chunk * win[:len(enhanced_chunk)] 192 | 193 | tmp = enhanced_chunk[:chunk_length // 2] 194 | cur = tmp[:min(len(tmp), len(prev))] + prev[:min(len(tmp), len(prev))] 195 | prev = enhanced_chunk[chunk_length // 2:] 196 | 197 | if enhanced is None: 198 | enhanced = cur 199 | else: 200 | enhanced = torch.cat((enhanced, cur), dim=0) 201 | 202 | enhanced = enhanced[:noisy.shape[1]] 203 | return enhanced.detach().squeeze(0).cpu().numpy() 204 | 205 | @torch.no_grad() 206 | def time_domain(self, noisy, inference_args): 207 | noisy = noisy.to(self.device) 208 | enhanced = self.model(noisy) 209 | return enhanced.detach().squeeze().cpu().numpy() 210 | 211 | 212 | if __name__ == '__main__': 213 | a = torch.rand(10, 2, 161, 200) 214 | print(cumulative_norm(a).shape) 215 | -------------------------------------------------------------------------------- /speech_enhance/subband_model/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/subband_model/model/__init__.py -------------------------------------------------------------------------------- /speech_enhance/subband_model/model/subband_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional 3 | 4 | from audio_zen.acoustics.feature import drop_band 5 | from audio_zen.model.base_model import BaseModel 6 | from audio_zen.model.module.sequence_model import SequenceModel 7 | 8 | # for log 9 | from utils.logger import log 10 | 11 | print = log 12 | 13 | 14 | class Subband_model(BaseModel): 15 | def __init__(self, 16 | look_ahead, 17 | sequence_model, 18 | sb_num_neighbors, 19 | sb_output_activate_function, 20 | sb_model_hidden_size, 21 | norm_type="offline_laplace_norm", 22 | num_groups_in_drop_band=2, 23 | weight_init=True, 24 | ): 25 | """ 26 | FullSubNet model (cIRM mask) 27 | 28 | Args: 29 | num_freqs: Frequency dim of the input 30 | sb_num_neighbors: Number of the neighbor frequencies in each side 31 | look_ahead: Number of use of the future frames 32 | sequence_model: Chose one sequence model as the basic model (GRU, LSTM) 33 | """ 34 | super().__init__() 35 | assert sequence_model in ("GRU", "LSTM"), f"{self.__class__.__name__} only support GRU and LSTM." 36 | 37 | self.sb_model = SequenceModel( 38 | input_size=(sb_num_neighbors * 2 + 1), 39 | output_size=2, 40 | hidden_size=sb_model_hidden_size, 41 | num_layers=2, 42 | bidirectional=False, 43 | sequence_model=sequence_model, 44 | output_activate_function=sb_output_activate_function 45 | ) 46 | 47 | self.sb_num_neighbors = sb_num_neighbors 48 | self.look_ahead = look_ahead 49 | self.norm = self.norm_wrapper(norm_type) 50 | self.num_groups_in_drop_band = num_groups_in_drop_band 51 | 52 | if weight_init: 53 | self.apply(self.weight_init) 54 | 55 | def forward(self, noisy_mag): 56 | """ 57 | Args: 58 | noisy_mag: noisy magnitude spectrogram 59 | 60 | Returns: 61 | The real part and imag part of the enhanced spectrogram 62 | 63 | Shapes: 64 | noisy_mag: [B, 1, F, T] 65 | return: [B, 2, F, T] 66 | """ 67 | assert noisy_mag.dim() == 4 68 | noisy_mag = functional.pad(noisy_mag, [0, self.look_ahead]) # Pad the look ahead 69 | batch_size, num_channels, num_freqs, num_frames = noisy_mag.size() 70 | assert num_channels == 1, f"{self.__class__.__name__} takes the mag feature as inputs." 
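        # As in Inter_SubNet above, unfold() groups each frequency bin with its neighbours into
        # sub-band units of width (2 * sb_num_neighbors + 1); here, however, the units are later
        # flattened to [B * F, F_s, T] and processed independently by the shared sequence model,
        # i.e. there is no explicit interaction between sub-bands in this variant.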
71 | 72 | # Unfold noisy input, [B, N=F, C, F_s, T] 73 | noisy_mag_unfolded = self.unfold(noisy_mag, num_neighbor=self.sb_num_neighbors) 74 | noisy_mag_unfolded = noisy_mag_unfolded.reshape(batch_size, num_freqs, self.sb_num_neighbors * 2 + 1, 75 | num_frames) 76 | 77 | sb_input = self.norm(noisy_mag_unfolded) 78 | 79 | # Speeding up training without significant performance degradation. These will be updated to the paper later. 80 | if batch_size > 1: 81 | sb_input = drop_band(sb_input.permute(0, 2, 1, 3), 82 | num_groups=self.num_groups_in_drop_band) # [B, F_s, F//num_groups, T] 83 | num_freqs = sb_input.shape[2] 84 | sb_input = sb_input.permute(0, 2, 1, 3) # [B, F//num_groups, F_s, T] 85 | 86 | sb_input = sb_input.reshape( 87 | batch_size * num_freqs, 88 | (self.sb_num_neighbors * 2 + 1), 89 | num_frames 90 | ) 91 | 92 | # [B * F, F_s, T] => [B * F, 2, T] => [B, F, 2, T] 93 | sb_mask = self.sb_model(sb_input) 94 | sb_mask = sb_mask.reshape(batch_size, num_freqs, 2, num_frames).permute(0, 2, 1, 3).contiguous() 95 | 96 | output = sb_mask[:, :, :, self.look_ahead:] 97 | return output 98 | 99 | 100 | class Subband_model_Large(BaseModel): 101 | def __init__(self, 102 | num_freqs, 103 | look_ahead, 104 | sequence_model, 105 | fb_num_neighbors, 106 | sb_num_neighbors, 107 | fb_output_activate_function, 108 | sb_output_activate_function, 109 | fb_model_hidden_size, 110 | sb_model_hidden_size, 111 | norm_type="offline_laplace_norm", 112 | num_groups_in_drop_band=2, 113 | weight_init=True, 114 | ): 115 | """ 116 | Subband model (cIRM mask) 117 | 118 | Args: 119 | num_freqs: Frequency dim of the input 120 | sb_num_neighbors: Number of the neighbor frequencies in each side 121 | look_ahead: Number of use of the future frames 122 | sequence_model: Chose one sequence model as the basic model (GRU, LSTM) 123 | """ 124 | super().__init__() 125 | assert sequence_model in ("GRU", "LSTM"), f"{self.__class__.__name__} only support GRU and LSTM." 126 | 127 | self.sb_model = SequenceModel( 128 | input_size=(sb_num_neighbors * 2 + 1), 129 | output_size=2, 130 | hidden_size=sb_model_hidden_size, 131 | num_layers=3, 132 | bidirectional=False, 133 | sequence_model=sequence_model, 134 | output_activate_function=sb_output_activate_function 135 | ) 136 | 137 | self.sb_num_neighbors = sb_num_neighbors 138 | self.look_ahead = look_ahead 139 | self.norm = self.norm_wrapper(norm_type) 140 | self.num_groups_in_drop_band = num_groups_in_drop_band 141 | 142 | if weight_init: 143 | self.apply(self.weight_init) 144 | 145 | def forward(self, noisy_mag): 146 | """ 147 | Args: 148 | noisy_mag: noisy magnitude spectrogram 149 | 150 | Returns: 151 | The real part and imag part of the enhanced spectrogram 152 | 153 | Shapes: 154 | noisy_mag: [B, 1, F, T] 155 | return: [B, 2, F, T] 156 | """ 157 | assert noisy_mag.dim() == 4 158 | noisy_mag = functional.pad(noisy_mag, [0, self.look_ahead]) # Pad the look ahead 159 | batch_size, num_channels, num_freqs, num_frames = noisy_mag.size() 160 | assert num_channels == 1, f"{self.__class__.__name__} takes the mag feature as inputs." 
161 | 162 | 163 | # Unfold noisy input, [B, N=F, C, F_s, T] 164 | noisy_mag_unfolded = self.unfold(noisy_mag, num_neighbor=self.sb_num_neighbors) 165 | noisy_mag_unfolded = noisy_mag_unfolded.reshape(batch_size, num_freqs, self.sb_num_neighbors * 2 + 1, 166 | num_frames) 167 | 168 | 169 | sb_input = self.norm(noisy_mag_unfolded) 170 | 171 | 172 | if batch_size > 1: 173 | sb_input = drop_band(sb_input.permute(0, 2, 1, 3), 174 | num_groups=self.num_groups_in_drop_band) # [B, F_s, F//num_groups, T] 175 | num_freqs = sb_input.shape[2] 176 | sb_input = sb_input.permute(0, 2, 1, 3) # [B, F//num_groups, F_s, T] 177 | 178 | sb_input = sb_input.reshape( 179 | batch_size * num_freqs, 180 | (self.sb_num_neighbors * 2 + 1), 181 | num_frames 182 | ) 183 | 184 | # [B * F, (F_s), T] => [B * F, 2, T] => [B, F, 2, T] 185 | sb_mask = self.sb_model(sb_input) 186 | sb_mask = sb_mask.reshape(batch_size, num_freqs, 2, num_frames).permute(0, 2, 1, 3).contiguous() 187 | 188 | output = sb_mask[:, :, :, self.look_ahead:] 189 | return output 190 | 191 | 192 | if __name__ == "__main__": 193 | model = Subband_model( 194 | look_ahead=2, 195 | sequence_model="LSTM", 196 | sb_num_neighbors=15, 197 | sb_output_activate_function=False, 198 | sb_model_hidden_size=384, 199 | weight_init=False, 200 | norm_type="offline_laplace_norm", 201 | num_groups_in_drop_band=2, 202 | ) 203 | 204 | -------------------------------------------------------------------------------- /speech_enhance/subband_model/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/subband_model/trainer/__init__.py -------------------------------------------------------------------------------- /speech_enhance/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/tools/__init__.py -------------------------------------------------------------------------------- /speech_enhance/tools/analyse.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import matplotlib 4 | 5 | def read_from_txt(filename): 6 | ans_dict = {} 7 | with open(filename, "r") as f: # 打开文件 8 | data = f.readlines() # 读取文件 9 | for line in data: 10 | line = line.strip('\n') 11 | lines = line.split(" ") 12 | ans_dict[lines[0]] = float(lines[1]) 13 | return ans_dict 14 | 15 | def write_to_txt(filename, total_list): 16 | with open(filename, 'w+') as temp_file: 17 | for i1, sisdr in total_list: 18 | string = i1 + ": " + str(sisdr) + '\n' 19 | temp_file.write(string) 20 | 21 | def takeSecond(elem): 22 | return elem[1] 23 | 24 | def make_rank(sidir_list): 25 | sidir_list.sort(key=takeSecond, reverse=True) 26 | 27 | def compare_two_data(data1, data2): 28 | ans_list = [] 29 | for wav in data1.keys(): 30 | num1 = data1[wav] 31 | num2 = data2[wav] 32 | ans_list.append((wav, num1 - num2)) 33 | make_rank(ans_list) 34 | return ans_list 35 | 36 | 37 | 38 | def draw_hist(data, filename): 39 | plt.hist(data, facecolor="blue", edgecolor="black", alpha=0.7) 40 | plt.xlabel("Interval") 41 | plt.ylabel("Frequency") 42 | plt.title("Frequency Distribution Histogram") 43 | plt.show() 44 | plt.savefig(filename) 45 | 46 | def draw_two_hist(data1, data1_name, data2, data2_name, filename): 47 | # bins = 
np.linspace(5, 30, 5) 48 | # plt.hist(data1, bins, edgecolor="black", alpha=0.7, label=data1_name) 49 | # plt.hist(data2, bins, edgecolor="black", alpha=0.7, label=data2_name) 50 | plt.hist(data1, edgecolor="black", alpha=0.7, label=data1_name) 51 | plt.hist(data2, edgecolor="black", alpha=0.7, label=data2_name) 52 | plt.legend(loc='upper right') 53 | plt.xlabel("Interval(Score)") 54 | plt.ylabel("Frequency") 55 | plt.title("Frequency Distribution Histogram") 56 | plt.show() 57 | plt.savefig(filename) 58 | 59 | 60 | data1 = read_from_txt("/workspace/project-nas-11025-sh/speech_enhance/egs/DNS-master/s1_16k/mertrics/fullsubnet_plus/NB_PESQ.txt") 61 | data2 = read_from_txt("/workspace/project-nas-11025-sh/speech_enhance/egs/DNS-master/s1_16k/mertrics/our_fullsubnet/NB_PESQ.txt") 62 | 63 | # print(compare_two_data(data1, data2)) 64 | write_to_txt("/workspace/project-nas-11025-sh/speech_enhance/egs/DNS-master/s1_16k/mertrics/compare/compare_NB_PESQ.txt", compare_two_data(data1, data2)) -------------------------------------------------------------------------------- /speech_enhance/tools/calculate_metrics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | from inspect import getmembers, isfunction 6 | from pathlib import Path 7 | 8 | import librosa 9 | import numpy as np 10 | from joblib import Parallel, delayed 11 | from tqdm import tqdm 12 | 13 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) 14 | import audio_zen.metrics as metrics 15 | from audio_zen.utils import prepare_empty_dir 16 | 17 | 18 | def load_wav_paths_from_scp(scp_path, to_abs=True): 19 | wav_paths = [line.rstrip('\n') for line in open(os.path.abspath(os.path.expanduser(scp_path)), "r")] 20 | if to_abs: 21 | tmp = [] 22 | for path in wav_paths: 23 | tmp.append(os.path.abspath(os.path.expanduser(path))) 24 | wav_paths = tmp 25 | return wav_paths 26 | 27 | 28 | def shrink_multi_channel_path( 29 | full_dataset_list: list, 30 | num_channels: int 31 | ) -> list: 32 | """ 33 | 34 | Args: 35 | full_dataset_list: [ 36 | 028000010_room1_rev_RT600.06_mic1_micpos1.5p0.5p1.93_srcpos0.46077p1.1p1.68_langle180_angle150_ds1.2_mic1.wav 37 | ... 
38 | 028000010_room1_rev_RT600.06_mic1_micpos1.5p0.5p1.93_srcpos0.46077p1.1p1.68_langle180_angle150_ds1.2_mic2.wav 39 | ] 40 | num_channels: 41 | 42 | Returns: 43 | 44 | """ 45 | assert len(full_dataset_list) % num_channels == 0, "Num error" 46 | 47 | shrunk_dataset_list = [] 48 | for index in range(0, len(full_dataset_list), num_channels): 49 | full_path = full_dataset_list[index] 50 | shrunk_path = f"{'_'.join(full_path.split('_')[:-1])}.wav" 51 | shrunk_dataset_list.append(shrunk_path) 52 | 53 | assert len(shrunk_dataset_list) == len(full_dataset_list) // num_channels 54 | return shrunk_dataset_list 55 | 56 | 57 | def get_basename(path): 58 | return os.path.splitext(os.path.basename(path))[0] 59 | 60 | 61 | def pre_processing(est, ref, specific_dataset=None): 62 | ref = Path(ref).expanduser().absolute() 63 | est = Path(est).expanduser().absolute() 64 | 65 | if ref.is_dir(): 66 | reference_wav_paths = librosa.util.find_files(ref.as_posix(), ext="wav") 67 | else: 68 | reference_wav_paths = load_wav_paths_from_scp(ref.as_posix()) 69 | 70 | if est.is_dir(): 71 | estimated_wav_paths = librosa.util.find_files(est.as_posix(), ext="wav") 72 | else: 73 | estimated_wav_paths = load_wav_paths_from_scp(est.as_posix()) 74 | 75 | if not specific_dataset: 76 | # By default, the two lists are assumed to correspond one-to-one 77 | check_two_aligned_list(reference_wav_paths, estimated_wav_paths) 78 | else: 79 | # For specific datasets, align the two lists manually to guarantee a one-to-one correspondence 80 | reordered_estimated_wav_paths = [] 81 | if specific_dataset == "dns_1": 82 | # Reorder estimated_wav_paths according to the file-id suffix of each file in reference_wav_paths 83 | # Extract the suffix 84 | for ref_path in reference_wav_paths: 85 | for est_path in estimated_wav_paths: 86 | est_basename = get_basename(est_path) 87 | if "clean_" + "_".join(est_basename.split("_")[-2:]) == get_basename(ref_path): 88 | reordered_estimated_wav_paths.append(est_path) 89 | elif specific_dataset == "dns_2": 90 | for ref_path in reference_wav_paths: 91 | for est_path in estimated_wav_paths: 92 | # synthetic_french_acejour_orleans_sb_64kb-01_jbq2HJt9QXw_snr14_tl-26_fileid_47 93 | # synthetic_clean_fileid_47 94 | est_basename = get_basename(est_path) 95 | file_id = est_basename.split('_')[-1] 96 | if f"synthetic_clean_fileid_{file_id}" == get_basename(ref_path): 97 | reordered_estimated_wav_paths.append(est_path) 98 | elif specific_dataset == "maxhub_noisy": 99 | # Reference_channel = 0 100 | # Find the corresponding clean speech 101 | reference_channel = 0 102 | print(f"Found #files: {len(reference_wav_paths)}") 103 | for est_path in estimated_wav_paths: 104 | # MC0604W0154_room4_rev_RT600.1_mic1_micpos1.5p0.5p1.84_srcpos4.507p1.5945p1.3_langle180_angle20_ds3.2_kesou_kesou_mic1.wav 105 | est_basename = get_basename(est_path) # the noisy file 106 | for ref_path in reference_wav_paths: 107 | ref_basename = get_basename(ref_path) 108 | 109 | else: 110 | raise NotImplementedError(f"Not supported specific dataset {specific_dataset}.") 111 | estimated_wav_paths = reordered_estimated_wav_paths 112 | 113 | return reference_wav_paths, estimated_wav_paths 114 | 115 | 116 | def check_two_aligned_list(a, b): 117 | assert len(a) == len(b), "The two lists have different lengths." 
118 | for z, (i, j) in enumerate(zip(a, b), start=1): 119 | assert get_basename(i) == get_basename(j), f"Mismatched filename between the two lists at line {z}:" \ 120 | f"\n\t {i}" \ 121 | f"\n\t{j}" 122 | 123 | 124 | def compute_metric(reference_wav_paths, estimated_wav_paths, sr, metric_type="SI_SDR"): 125 | metrics_dict = {o[0]: o[1] for o in getmembers(metrics) if isfunction(o[1])} 126 | assert metric_type in metrics_dict, f"Unsupported metric: {metric_type}" 127 | metric_function = metrics_dict[metric_type] 128 | if metric_type == "MOSNET": 129 | n_jobs = 1 130 | else: 131 | n_jobs = 40 132 | 133 | def calculate_metric(ref_wav_path, est_wav_path): 134 | ref_wav, _ = librosa.load(ref_wav_path, sr=sr) 135 | est_wav, _ = librosa.load(est_wav_path, sr=sr, mono=False) 136 | if est_wav.ndim > 1: 137 | est_wav = est_wav[0] 138 | 139 | basename = get_basename(ref_wav_path) 140 | 141 | ref_wav_len = len(ref_wav) 142 | est_wav_len = len(est_wav) 143 | 144 | if ref_wav_len != est_wav_len: 145 | print(f"[Warning] ref {ref_wav_len} and est {est_wav_len} are not of the same length") 146 | pass 147 | 148 | return basename, metric_function(ref_wav[:len(est_wav)], est_wav) 149 | 150 | metrics_result_store = Parallel(n_jobs=n_jobs)( 151 | delayed(calculate_metric)(ref, est) for ref, est in tqdm(zip(reference_wav_paths, estimated_wav_paths)) 152 | ) 153 | return metrics_result_store 154 | 155 | def takeSecond(elem): 156 | return elem[1] 157 | 158 | def make_rank(the_list): 159 | the_list.sort(key=takeSecond, reverse=True) 160 | 161 | def write_to_txt(filename, total_list): 162 | with open(filename, 'w+') as temp_file: 163 | for i1, sisdr in total_list: 164 | string = i1 + ": " + str(sisdr) + '\n' 165 | temp_file.write(string) 166 | 167 | def main(args): 168 | sr = args.sr 169 | metric_types = args.metric_types 170 | export_dir = args.export_dir 171 | specific_dataset = args.specific_dataset.lower() 172 | 173 | # Collect all wav samples from the given scp files or directories 174 | reference_wav_paths, estimated_wav_paths = pre_processing(args.estimated, args.reference, specific_dataset) 175 | 176 | # if export_dir: 177 | # export_dir = Path(export_dir).expanduser().absolute() 178 | # prepare_empty_dir([export_dir]) 179 | 180 | print(f"=== {args.estimated} === {args.reference} ===") 181 | for metric_type in metric_types.split(","): 182 | # print(reference_wav_paths) 183 | # print(estimated_wav_paths) 184 | metrics_result_store = compute_metric(reference_wav_paths, estimated_wav_paths, sr, metric_type=metric_type) 185 | 186 | # Print result 187 | print(f"{metric_type}: {metrics_result_store}") 188 | metric_value = np.mean(list(zip(*metrics_result_store))[1]) 189 | print(f"{metric_type}: {metric_value}") 190 | 191 | # Export result 192 | if export_dir: 193 | make_rank(metrics_result_store) 194 | write_to_txt(export_dir + str(metric_type) + ".txt", metrics_result_store) 195 | # import tablib 196 | 197 | # export_path = export_dir / f"{metric_type}.xlsx" 198 | # print(f"Export result to {export_path}") 199 | 200 | # headers = ("Speech", f"{metric_type}") 201 | # metric_seq = [[basename, metric_value] for basename, metric_value in metrics_result_store] 202 | # data = tablib.Dataset(*metric_seq, headers=headers) 203 | # with open(export_path.as_posix(), "wb") as f: 204 | # f.write(data.export("xlsx")) 205 | 206 | 207 | if __name__ == '__main__': 208 | parser = argparse.ArgumentParser( 209 | description="Given two directories or path lists, compute the mean of various evaluation metrics", 210 | epilog="python calculate_metrics.py -E 'est_dir' -R 'ref_dir' -M SI_SDR,STOI,WB_PESQ,NB_PESQ,SSNR,LSD,SRMR" 211 | ) 212 | parser.add_argument("-R", 
"--reference", required=True, type=str, help="") 213 | parser.add_argument("-E", "--estimated", required=True, type=str, help="") 214 | parser.add_argument("-M", "--metric_types", required=True, type=str, help="哪个评价指标,要与 util.metrics 中的内容一致.") 215 | parser.add_argument("--sr", type=int, default=16000, help="采样率") 216 | parser.add_argument("-D", "--export_dir", type=str, default="", help="") 217 | parser.add_argument("--limit", type=int, default=None, help="[正在开发]从列表中读取文件的上限数量.") 218 | parser.add_argument("--offset", type=int, default=0, help="[正在开发]从列表中指定位置开始读取文件.") 219 | parser.add_argument("-S", "--specific_dataset", type=str, default="", help="指定数据集类型,e.g. DNS_1, DNS_2, 大小写均可") 220 | args = parser.parse_args() 221 | main(args) 222 | 223 | """ 224 | TODO 225 | 1. 语音为多通道时如何算 226 | 2. 支持 register, 默认情况下应该计算所有 register 中语音 227 | """ 228 | -------------------------------------------------------------------------------- /speech_enhance/tools/collect_lst.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import sys 4 | from pathlib import Path 5 | import librosa 6 | from tqdm import tqdm 7 | 8 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", ".."))) 9 | from audio_zen.acoustics.mask import is_clipped, load_wav, activity_detector 10 | 11 | 12 | def offset_and_limit(data_list, offset, limit): 13 | data_list = data_list[offset:] 14 | if limit: 15 | data_list = data_list[:limit] 16 | return data_list 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser(description="FullSubNet") 21 | parser.add_argument('-candidate_datasets', '--candidate_datasets', help='delimited list input', 22 | type=lambda s: [item for item in s.split(',')]) 23 | parser.add_argument("-dist_file", "--dist_file", required=True, type=str, help="output lst") 24 | parser.add_argument("-sr", "--sr", type=int, default=16000, help="sample rate") 25 | parser.add_argument("-wav_min_second", "--wav_min_second", type=int, default=3, help="the min length of a wav") 26 | parser.add_argument("-activity_threshold", "--activity_threshold", type=float, default=0.6, 27 | help="the activity threshold of speech/sil") 28 | parser.add_argument("-total_hrs", "--total_hrs", type=int, default=30, help="the length in time of wav(s)") 29 | 30 | args = parser.parse_args() 31 | 32 | candidate_datasets = args.candidate_datasets 33 | dataset_limit = None 34 | dataset_offset = 0 35 | dist_file = args.dist_file 36 | 37 | # 声学参数 38 | sr = args.sr 39 | wav_min_second = args.wav_min_second 40 | activity_threshold = args.activity_threshold 41 | total_hrs = args.total_hrs # 计划收集语音的总时长 42 | 43 | all_wav_path_list = [] 44 | output_wav_path_list = [] 45 | accumulated_time = 0.0 46 | 47 | is_clipped_wav_list = [] 48 | is_low_activity_list = [] 49 | is_too_short_list = [] 50 | 51 | for dataset_path in candidate_datasets: 52 | dataset_path = Path(dataset_path).expanduser().absolute() 53 | all_wav_path_list += librosa.util.find_files(dataset_path.as_posix(), ext=["wav"]) 54 | 55 | all_wav_path_list = offset_and_limit(all_wav_path_list, dataset_offset, dataset_limit) 56 | random.shuffle(all_wav_path_list) 57 | 58 | for wav_file_path in tqdm(all_wav_path_list, desc="Checking"): 59 | y = load_wav(wav_file_path, sr=sr) 60 | wav_duration = len(y) / sr 61 | wav_file_user_path = wav_file_path.replace(Path(wav_file_path).home().as_posix(), "~") 62 | 63 | is_clipped_wav = is_clipped(y) 64 | is_low_activity = activity_detector(y) < activity_threshold 65 | is_too_short = 
wav_duration < wav_min_second 66 | 67 | if is_too_short: 68 | is_too_short_list.append(wav_file_user_path) 69 | continue 70 | 71 | if is_clipped_wav: 72 | is_clipped_wav_list.append(wav_file_user_path) 73 | continue 74 | 75 | if is_low_activity: 76 | is_low_activity_list.append(wav_file_user_path) 77 | continue 78 | 79 | if (not is_clipped_wav) and (not is_low_activity) and (not is_too_short): 80 | accumulated_time += wav_duration 81 | output_wav_path_list.append(wav_file_user_path) 82 | 83 | if accumulated_time >= (total_hrs * 3600): 84 | break 85 | 86 | os.makedirs(os.path.dirname(dist_file.as_posix()), exist_ok=True) 87 | with open(dist_file.as_posix(), 'w') as f: 88 | f.writelines(f"{file_path}\n" for file_path in output_wav_path_list) 89 | 90 | print("=" * 70) 91 | print("Speech Preprocessing") 92 | print(f"\t Original files: {len(all_wav_path_list)}") 93 | print(f"\t Selected files: {accumulated_time / 3600} hrs, {len(output_wav_path_list)} files.") 94 | print(f"\t is_clipped_wav: {len(is_clipped_wav_list)}") 95 | print(f"\t is_low_activity: {len(is_low_activity_list)}") 96 | print(f"\t is_too_short: {len(is_too_short_list)}") 97 | print(f"\t dist file:") 98 | print(f"\t {dist_file.as_posix()}") 99 | print("=" * 70) 100 | -------------------------------------------------------------------------------- /speech_enhance/tools/dns_mos.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import json 4 | import os 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import requests 9 | import soundfile as sf 10 | import librosa 11 | 12 | from urllib.parse import urlparse, urljoin 13 | 14 | # URL for the web service 15 | SCORING_URI_DNSMOS = 'https://dnsmos.azurewebsites.net/score' 16 | SCORING_URI_DNSMOS_P835 = 'https://dnsmos.azurewebsites.net/v1/dnsmosp835/score' 17 | # If the service is authenticated, set the key or token 18 | AUTH_KEY = 'd3VoYW4tdW5pdjpkbnNtb3M=' 19 | 20 | # Set the content type 21 | headers = {'Content-Type': 'application/json'} 22 | # If authentication is enabled, set the authorization header 23 | headers['Authorization'] = f'Basic {AUTH_KEY}' 24 | 25 | 26 | def main(args): 27 | print(args.testset_dir) 28 | audio_clips_list = glob.glob(os.path.join(args.testset_dir, "*.wav")) # glob:搜索列表中符合的文件,返回列表 29 | print(audio_clips_list) 30 | scores = [] 31 | dir_path = args.score_file.split('score.csv')[0] 32 | if not os.path.exists(dir_path): 33 | os.makedirs(dir_path) 34 | if not os.path.exists(os.path.join(dir_path, 'file_mos.txt')): 35 | f = open(os.path.join(dir_path, 'file_mos.txt'), 'w') 36 | dict = {} 37 | else: 38 | f = open(os.path.join(dir_path, 'file_mos.txt'), 'r') 39 | dict = {} 40 | lines = f.readlines() 41 | for line in lines: 42 | utt_id = line.split('.wav')[0] 43 | # print('utt_id store', utt_id) 44 | dict[utt_id] = 1 45 | flag = 0 46 | for fpath in audio_clips_list: 47 | utt_id = fpath.split('\\')[-1].split('.wav')[0] 48 | # print('utt_id', utt_id) 49 | if utt_id in dict: 50 | print('find uttid', utt_id) 51 | continue 52 | flag = 1 53 | f = open(os.path.join(dir_path, 'file_mos.txt'), 'a+') 54 | audio, fs = sf.read(fpath) 55 | if fs != 16000: 56 | print('Resample to 16k') 57 | audio = librosa.resample(audio, orig_sr=fs, target_sr=16000) 58 | data = {"data": audio.tolist(), "filename": os.path.basename(fpath)} 59 | input_data = json.dumps(data) 60 | # Make the request and display the response 61 | if args.method == 'p808': 62 | u = SCORING_URI_DNSMOS 63 | else: 64 | u = 
SCORING_URI_DNSMOS_P835 65 | try_flag = 1 66 | while try_flag: 67 | try: 68 | resp = requests.post(u, data=input_data, headers=headers, timeout=50) 69 | try_flag = 0 70 | score_dict = resp.json() 71 | except: 72 | try_flag = 1 73 | print('retry_1') 74 | continue 75 | try: 76 | score_dict['file_name'] = os.path.basename(fpath) 77 | if args.method == 'p808': 78 | f.write(score_dict['file_name'] + ' ' + str(score_dict['mos']) + '\n') 79 | print(score_dict['mos'], ' ', score_dict['file_name']) 80 | else: 81 | f.write(score_dict['file_name'] + ' SIG[{}], BAK[{}], OVR[{}]'.format(score_dict['mos_sig'], 82 | score_dict['mos_bak'], 83 | score_dict['mos_ovr']) + '\n') 84 | print(score_dict['file_name'] + ' SIG[{}], BAK[{}], OVR[{}]'.format(score_dict['mos_sig'], 85 | score_dict['mos_bak'], 86 | score_dict['mos_ovr'])) 87 | try_flag = 0 88 | except: 89 | try_flag = 1 90 | print('retry_2') 91 | continue 92 | f.close() 93 | scores.append(score_dict) 94 | if flag: 95 | df = pd.DataFrame(scores) 96 | if args.method == 'p808': 97 | print('Mean MOS Score for the files is ', np.mean(df['mos'])) 98 | else: 99 | print('Mean scores for the files: SIG[{}], BAK[{}], OVR[{}]'.format(np.mean(df['mos_sig']), 100 | np.mean(df['mos_bak']), 101 | np.mean(df['mos_ovr']))) 102 | 103 | if args.score_file: 104 | df.to_csv(args.score_file) 105 | 106 | 107 | if __name__ == "__main__": 108 | parser = argparse.ArgumentParser() 109 | parser.add_argument("--testset_dir", 110 | default=r'C:\Users\cyrillv\Desktop\谱修复样例及DNSMOS指标\test_data\谱修复测试集_yuanjun\noisy', 111 | help='Path to the dir containing audio clips to be evaluated') 112 | parser.add_argument('--score_file', default=r'./谱修复测试集_yuanjun/noisy/score.csv', 113 | help='If you want the scores in a CSV file provide the full path') 114 | parser.add_argument('--method', default='p835', const='p808', nargs='?', choices=['p808', 'p835'], 115 | help='Choose which method to compute P.808 or P.835. 
Default is P.835')
116 |     args = parser.parse_args()
117 |     main(args)
118 | 
-------------------------------------------------------------------------------- /speech_enhance/tools/gen_lst.py: --------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import argparse
4 | 
5 | def gen_lst(args):
6 |     wav_lst = glob(os.path.join(args.dataset_dir, "**/*.wav"), recursive=True)
7 |     os.makedirs(os.path.dirname(args.output_lst), exist_ok=True)
8 |     fc = open(args.output_lst, "w")
9 |     for one_wav in wav_lst:
10 |         fc.write(f"{one_wav}\n")
11 |     fc.close()
12 | 
13 | if __name__ == "__main__":
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument("--dataset_dir", type=str, default="")
16 |     parser.add_argument("--output_lst", type=str, default="")
17 |     args = parser.parse_args()
18 | 
19 |     gen_lst(args)
20 | 
-------------------------------------------------------------------------------- /speech_enhance/tools/inference.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | 
5 | import toml
6 | 
7 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
8 | from audio_zen.utils import initialize_module
9 | 
10 | 
11 | def main(config, checkpoint_path, output_dir):
12 |     inferencer_class = initialize_module(config["inferencer"]["path"], initialize=False)
13 |     inferencer = inferencer_class(
14 |         config,
15 |         checkpoint_path,
16 |         output_dir
17 |     )
18 |     inferencer()
19 | 
20 | 
21 | if __name__ == "__main__":
22 |     parser = argparse.ArgumentParser("Inference")
23 |     parser.add_argument("-C", "--configuration", type=str, required=True, help="Config file.")
24 |     parser.add_argument("-M", "--model_checkpoint_path", type=str, required=True, help="The path of the model's checkpoint.")
25 |     parser.add_argument('-I', '--dataset_dir_list', help='delimited list input',
26 |                         type=lambda s: [item.strip() for item in s.split(',')])
27 |     parser.add_argument("-O", "--output_dir", type=str, required=True, help="The path for saving enhanced speeches.")
28 |     args = parser.parse_args()
29 | 
30 |     configuration = toml.load(args.configuration)
31 |     checkpoint_path = args.model_checkpoint_path
32 |     output_dir = args.output_dir
33 |     if args.dataset_dir_list:  # -I is optional, so guard against None before overriding the config
34 |         print(f"use specified dataset_dir_list: {args.dataset_dir_list}, instead of in config")
35 |         configuration["dataset"]["args"]["dataset_dir_list"] = args.dataset_dir_list
36 | 
37 |     main(configuration, checkpoint_path, output_dir)
38 | 
-------------------------------------------------------------------------------- /speech_enhance/tools/noisyspeech_synthesizer.py: --------------------------------------------------------------------------------
1 | """
2 | @author: chkarada
3 | """
4 | import glob
5 | import numpy as np
6 | import soundfile as sf
7 | import os
8 | import argparse
9 | import yaml
10 | import configparser as CP
11 | from ..audio.audiolib import audioread, audiowrite, snr_mixer
12 | 
13 | def main(cfg):
14 |     snr_lower = float(cfg["snr_lower"])
15 |     snr_upper = float(cfg["snr_upper"])
16 |     total_snrlevels = int(cfg["total_snrlevels"])
17 | 
18 |     clean_dir = os.path.join(os.path.dirname(__file__), 'clean_train')
19 |     if cfg["speech_dir"]!='None':
20 |         clean_dir = cfg["speech_dir"]
21 |     if not os.path.exists(clean_dir):
22 |         assert False, ("Clean speech data is required")
23 | 
24 |     noise_dir = os.path.join(os.path.dirname(__file__), 'noise_train')
25 |     if cfg["noise_dir"]!='None':
26 |         noise_dir = 
cfg["noise_dir"] 27 | if not os.path.exists(noise_dir): 28 | assert False, ("Noise data is required") 29 | 30 | output_dir = os.path.join(os.path.dirname(__file__), 'train_data') 31 | if cfg["noisyspeech_dir"]!='None': 32 | output_dir = cfg["noisyspeech_dir"] 33 | os.makedirs(output_dir, exist_ok=True) 34 | 35 | fs = float(cfg["sampling_rate"]) 36 | audioformat = cfg["audioformat"] 37 | total_hours = float(cfg["total_hours"]) 38 | audio_length = float(cfg["audio_length"]) 39 | silence_length = float(cfg["silence_length"]) 40 | noisyspeech_dir = os.path.join(output_dir, 'NoisySpeech_training') 41 | if not os.path.exists(noisyspeech_dir): 42 | os.makedirs(noisyspeech_dir) 43 | clean_proc_dir = os.path.join(output_dir, 'CleanSpeech_training') 44 | if not os.path.exists(clean_proc_dir): 45 | os.makedirs(clean_proc_dir) 46 | noise_proc_dir = os.path.join(output_dir, 'Noise_training') 47 | if not os.path.exists(noise_proc_dir): 48 | os.makedirs(noise_proc_dir) 49 | 50 | total_secs = total_hours*60*60 51 | total_samples = int(total_secs * fs) 52 | audio_length = int(audio_length*fs) 53 | SNR = np.linspace(snr_lower, snr_upper, total_snrlevels) 54 | cleanfilenames = glob.glob(os.path.join(clean_dir, audioformat)) 55 | if cfg["noise_types_excluded"]=='None': 56 | noisefilenames = glob.glob(os.path.join(noise_dir, audioformat)) 57 | else: 58 | filestoexclude = cfg["noise_types_excluded"].split(',') 59 | noisefilenames = glob.glob(os.path.join(noise_dir, audioformat)) 60 | for i in range(len(filestoexclude)): 61 | noisefilenames = [fn for fn in noisefilenames if not os.path.basename(fn).startswith(filestoexclude[i])] 62 | 63 | filecounter = 0 64 | num_samples = 0 65 | 66 | while num_samples < total_samples: 67 | idx_s = np.random.randint(0, np.size(cleanfilenames)) 68 | clean, fs = audioread(cleanfilenames[idx_s]) 69 | 70 | if len(clean)>audio_length: 71 | clean = clean 72 | 73 | else: 74 | 75 | while len(clean)<=audio_length: 76 | idx_s = idx_s + 1 77 | if idx_s >= np.size(cleanfilenames)-1: 78 | idx_s = np.random.randint(0, np.size(cleanfilenames)) 79 | newclean, fs = audioread(cleanfilenames[idx_s]) 80 | cleanconcat = np.append(clean, np.zeros(int(fs*silence_length))) 81 | clean = np.append(cleanconcat, newclean) 82 | 83 | idx_n = np.random.randint(0, np.size(noisefilenames)) 84 | noise, fs = audioread(noisefilenames[idx_n]) 85 | 86 | if len(noise)>=len(clean): 87 | noise = noise[0:len(clean)] 88 | 89 | else: 90 | 91 | while len(noise)<=len(clean): 92 | idx_n = idx_n + 1 93 | if idx_n >= np.size(noisefilenames)-1: 94 | idx_n = np.random.randint(0, np.size(noisefilenames)) 95 | newnoise, fs = audioread(noisefilenames[idx_n]) 96 | noiseconcat = np.append(noise, np.zeros(int(fs*silence_length))) 97 | noise = np.append(noiseconcat, newnoise) 98 | noise = noise[0:len(clean)] 99 | filecounter = filecounter + 1 100 | 101 | for i in range(np.size(SNR)): 102 | clean_snr, noise_snr, noisy_snr = snr_mixer(clean=clean, noise=noise, snr=SNR[i]) 103 | noisyfilename = 'noisy'+str(filecounter)+'_SNRdb_'+str(SNR[i])+'_clnsp'+str(filecounter)+'.wav' 104 | cleanfilename = 'clnsp'+str(filecounter)+'.wav' 105 | noisefilename = 'noisy'+str(filecounter)+'_SNRdb_'+str(SNR[i])+'.wav' 106 | noisypath = os.path.join(noisyspeech_dir, noisyfilename) 107 | cleanpath = os.path.join(clean_proc_dir, cleanfilename) 108 | noisepath = os.path.join(noise_proc_dir, noisefilename) 109 | audiowrite(noisy_snr, fs, noisypath, norm=False) 110 | audiowrite(clean_snr, fs, cleanpath, norm=False) 111 | audiowrite(noise_snr, fs, noisepath, 
norm=False)
112 |             num_samples = num_samples + len(noisy_snr)
113 | 
114 | 
115 | if __name__=="__main__":
116 |     parser = argparse.ArgumentParser()
117 |     # Configurations: read noisyspeech_synthesizer.cfg
118 |     parser.add_argument("-c", "--config", type=str, help="Read ./config/config.yaml for all the details", default="./config/config.yaml")
119 |     args = parser.parse_args()
120 | 
121 |     cur_filename = os.path.basename((os.path.relpath(__file__))).split(".")[0]
122 |     # config
123 |     config = yaml.unsafe_load(open(args.config, "r"))
124 | 
125 |     main(config)
126 | 
-------------------------------------------------------------------------------- /speech_enhance/tools/resample_dir.py: --------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import argparse
4 | from tqdm import tqdm
5 | from joblib import Parallel, delayed
6 | 
7 | def resample_one_wav(wav_file, output_wav_file):
8 |     if not os.path.exists(wav_file):
9 |         print(f"not found {wav_file}, return")
10 |         return
11 |     ####
12 |     #print(wav_file)
13 |     #print(output_wav_file)
14 |     os.makedirs(os.path.dirname(output_wav_file), exist_ok=True)
15 |     cmd = f"sox {wav_file} -b16 {output_wav_file} rate -v -b 99.7 16k"
16 |     os.system(cmd)
17 | 
18 | def resample_dir(args):
19 |     ### get all wavs
20 |     wav_lst = glob(os.path.join(args.dataset_dir, "**/*.wav"), recursive=True)
21 |     os.makedirs(args.output_dir, exist_ok=True)
22 |     ###
23 |     num_workers = 40
24 |     Parallel(n_jobs=num_workers)(
25 |         delayed(resample_one_wav)(wav_file, wav_file.replace(args.dataset_dir, args.output_dir)) for wav_file in wav_lst)
26 | 
27 | if __name__ == "__main__":
28 |     parser = argparse.ArgumentParser()
29 |     parser.add_argument("--dataset_dir", type=str, default="")
30 |     parser.add_argument("--output_dir", type=str, default="")
31 |     args = parser.parse_args()
32 | 
33 |     resample_dir(args)
34 | 
-------------------------------------------------------------------------------- /speech_enhance/tools/train.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import sys
5 | from socket import socket
6 | 
7 | import numpy as np
8 | import toml
9 | import torch
10 | import torch.distributed as dist
11 | import torch.multiprocessing as mp
12 | from torch.utils.data import DataLoader, DistributedSampler
13 | 
14 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
15 | import audio_zen.loss as loss
16 | from audio_zen.utils import initialize_module
17 | from utils.logger import init
18 | 
19 | # get free gpu automatically
20 | import GPUtil
21 | 
22 | def entry(rank, world_size, config, resume, only_validation):
23 |     torch.manual_seed(config["meta"]["seed"])  # For both CPU and GPU
24 |     np.random.seed(config["meta"]["seed"])
25 |     random.seed(config["meta"]["seed"])
26 | 
27 |     os.environ["MASTER_ADDR"] = "localhost"
28 |     s = socket()
29 |     s.bind(("", 0))  # note: this bound socket is never used below
30 |     os.environ["MASTER_PORT"] = "1111"  # a fixed local port; every process in the group must use the same value
31 | 
32 |     # Initialize the process group
33 |     dist.init_process_group("gloo", rank=rank, world_size=world_size)
34 | 
35 |     # init log file
36 |     if rank==0:
37 |         os.makedirs(os.path.join(config["meta"]["save_dir"]), exist_ok=True)
38 |         init(os.path.join(config["meta"]["save_dir"], "train.log"), "train", slack_url=None)
39 | 
40 |     # The DistributedSampler splits the dataset into non-overlapping, per-process shards.
41 |     # By contrast, with sampler=None and shuffle=True, every GPU would iterate over the whole dataset.
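    # Because a sampler is passed to the DataLoader below, its shuffle flag must stay False;
    # per-epoch reshuffling is expected to come from sampler.set_epoch(epoch), e.g.
    #     sampler.set_epoch(current_epoch)  # hypothetical call site, normally inside the trainer's epoch loop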
42 | 43 | train_dataset = initialize_module(config["train_dataset"]["path"], args=config["train_dataset"]["args"]) 44 | sampler = DistributedSampler(dataset=train_dataset, num_replicas=world_size, rank=rank, shuffle=True) 45 | train_dataloader = DataLoader( 46 | dataset=train_dataset, 47 | sampler=sampler, 48 | shuffle=False, 49 | **config["train_dataset"]["dataloader"], 50 | ) 51 | 52 | valid_dataloader = DataLoader( 53 | dataset=initialize_module(config["validation_dataset"]["path"], args=config["validation_dataset"]["args"]), 54 | num_workers=0, 55 | batch_size=1 56 | ) 57 | 58 | model = initialize_module(config["model"]["path"], args=config["model"]["args"]) 59 | 60 | optimizer = torch.optim.Adam( 61 | params=model.parameters(), 62 | lr=config["optimizer"]["lr"], 63 | betas=(config["optimizer"]["beta1"], config["optimizer"]["beta2"]) 64 | ) 65 | 66 | loss_function = getattr(loss, config["loss_function"]["name"])(**config["loss_function"]["args"]) 67 | trainer_class = initialize_module(config["trainer"]["path"], initialize=False) 68 | 69 | trainer = trainer_class( 70 | dist=dist, 71 | rank=rank, 72 | config=config, 73 | resume=resume, 74 | only_validation=only_validation, 75 | model=model, 76 | loss_function=loss_function, 77 | optimizer=optimizer, 78 | train_dataloader=train_dataloader, 79 | validation_dataloader=valid_dataloader 80 | ) 81 | 82 | trainer.train() 83 | 84 | 85 | if __name__ == '__main__': 86 | parser = argparse.ArgumentParser(description="FullSubNet") 87 | parser.add_argument("-C", "--configuration", required=True, type=str, help="Configuration (*.toml).") 88 | parser.add_argument("-R", "--resume", action="store_true", help="Resume the experiment from latest checkpoint.") 89 | parser.add_argument("-V", "--only_validation", action="store_true", help="Only run validation. It is used for debugging validation.") 90 | parser.add_argument("-N", "--num_gpus", type=int, default=0, help="The number of GPUs you are using for training.") 91 | parser.add_argument("-P", "--preloaded_model_path", type=str, help="Path of the *.pth file of a model.") 92 | args = parser.parse_args() 93 | 94 | # set the gpu auto 95 | if args.num_gpus == 0: 96 | device_ids = GPUtil.getAvailable(order = 'first', limit = 8, maxLoad = 0.5, maxMemory = 0.5, includeNan=False, excludeID=[], excludeUUID=[]) 97 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([ str(device_id) for device_id in device_ids ]) 98 | args.num_gpus = len(device_ids) 99 | print(f"gpus: {os.environ['CUDA_VISIBLE_DEVICES']}") 100 | 101 | if args.preloaded_model_path: 102 | assert not args.resume, "The 'resume' conflicts with the 'preloaded_model_path'." 103 | 104 | configuration = toml.load(args.configuration) 105 | 106 | configuration["meta"]["experiment_name"], _ = os.path.splitext(os.path.basename(args.configuration)) 107 | configuration["meta"]["config_path"] = args.configuration 108 | configuration["meta"]["preloaded_model_path"] = args.preloaded_model_path 109 | 110 | # Expand python search path to "recipes" 111 | # sys.path.append(os.path.join(os.getcwd(), "..")) 112 | 113 | # One training job is corresponding to one group (world). 114 | # The world size is the number of processes for training, which is usually the number of GPUs you are using for distributed training. 115 | # the rank is the unique ID given to a process. 116 | # Find more information about DistributedDataParallel (DDP) in https://pytorch.org/tutorials/intermediate/ddp_tutorial.html. 
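    # mp.spawn passes the process index (the rank) as the implicit first argument to `entry`;
    # the args tuple below supplies the remaining parameters (world_size, config, resume, only_validation),
    # and nprocs launches one process per visible GPU.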
117 | mp.spawn(entry, 118 | args=(args.num_gpus, configuration, args.resume, args.only_validation), 119 | nprocs=args.num_gpus, 120 | join=True) 121 | 122 | -------------------------------------------------------------------------------- /speech_enhance/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/utils/__init__.py -------------------------------------------------------------------------------- /speech_enhance/utils/logger.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | from datetime import datetime 3 | import json 4 | from threading import Thread 5 | from urllib.request import Request, urlopen 6 | import os 7 | 8 | _format = '%Y-%m-%d %H:%M:%S.%f' 9 | _file = None 10 | _run_name = None 11 | _slack_url = None 12 | 13 | def init(filename, run_name, slack_url=None): 14 | os.makedirs(os.path.dirname(filename), exist_ok=True) 15 | 16 | global _file, _run_name, _slack_url 17 | _close_logfile() 18 | _file = open(filename, 'a') 19 | _file.write('\n-----------------------------------------------------------------\n') 20 | _file.write('Starting new training run\n') 21 | _file.write('-----------------------------------------------------------------\n') 22 | _file.flush() 23 | _run_name = run_name 24 | _slack_url = slack_url 25 | 26 | 27 | def log(msg, slack=False): 28 | cur_time = datetime.now().strftime(_format)[:-3] 29 | print('[%s] %s' % (cur_time, msg), end='\n', flush=True) 30 | if _file is not None: 31 | _file.write('[%s] %s\n' % (cur_time, msg)) 32 | _file.flush() 33 | if slack and _slack_url is not None: 34 | Thread(target=_send_slack, args=(msg,)).start() 35 | 36 | def _close_logfile(): 37 | global _file 38 | if _file is not None: 39 | _file.close() 40 | _file = None 41 | 42 | def _send_slack(msg): 43 | req = Request(_slack_url) 44 | req.add_header('Content-Type', 'application/json') 45 | urlopen(req, json.dumps({ 46 | 'username': 'tacotron', 47 | 'icon_emoji': ':taco:', 48 | 'text': '*%s*: %s' % (_run_name, msg) 49 | }).encode()) 50 | 51 | 52 | atexit.register(_close_logfile) 53 | -------------------------------------------------------------------------------- /speech_enhance/utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | matplotlib.use("Agg") 4 | 5 | import numpy as np 6 | 7 | from .logger import log 8 | 9 | 10 | def plot_alignment(alignment, path): 11 | alignment = np.where(alignment < 1, alignment, 1) 12 | # log("min and max:", np.min(alignment), np.max(alignment)) 13 | 14 | fig = plt.figure(figsize=(8, 6)) 15 | ax = fig.add_subplot(111) 16 | im = ax.imshow(alignment, aspect='auto', origin='lower', 17 | interpolation='none') 18 | fig.colorbar(im, ax=ax) 19 | 20 | plt.tight_layout() 21 | plt.savefig(path, format='png') 22 | plt.close() 23 | return 24 | 25 | 26 | def plot_spectrogram(pred_spectrogram, plot_path, title="mel-spec", show=False): 27 | fig = plt.figure(figsize=(20, 10)) 28 | fig.text(0.5, 0.18, title, horizontalalignment='center', fontsize=16) 29 | vmin = np.min(pred_spectrogram) 30 | vmax = np.max(pred_spectrogram) 31 | ax2 = fig.add_subplot(111) 32 | im = ax2.imshow(np.rot90(pred_spectrogram), interpolation='none', 33 | vmin=vmin, vmax=vmax) 34 | char = fig.colorbar(mappable=im, shrink=0.65, orientation='horizontal', 35 | ax=ax2) 36 | 37 | 
# char.set_ticks(np.arange(vmin, vmax))
38 |     char.set_ticks(np.arange(0, 1))
39 | 
40 |     plt.tight_layout()
41 |     plt.savefig(plot_path, format='png')
42 |     if show:
43 |         plt.show()
44 |     plt.close()
45 |     log("save spec png to {}".format(plot_path))
46 |     return
47 | 
48 | 
49 | def plot_two_spec(pred_spec, target_spec, pic_path, title=None,
50 |                   vmin=None, vmax=None):
51 |     # assert np.shape(pred_spec)[1] == 80 and np.shape(target_spec)[1] == 80
52 |     fig = plt.figure(figsize=(12, 8))
53 |     fig.text(0.5, 0.18, title, horizontalalignment='center', fontsize=16)
54 |     if vmin is None or vmax is None:
55 |         vmin = min(np.min(pred_spec), np.min(target_spec))
56 |         vmax = max(np.max(pred_spec), np.max(target_spec))
57 |     ax1 = fig.add_subplot(211)
58 |     ax1.set_title('Predicted Mel-Spectrogram')
59 |     im = ax1.imshow(np.rot90(pred_spec), interpolation='none',
60 |                     vmin=vmin, vmax=vmax)
61 |     fig.colorbar(mappable=im, shrink=0.65, orientation='horizontal', ax=ax1)
62 | 
63 |     ax2 = fig.add_subplot(212)
64 |     ax2.set_title('Target Mel-Spectrogram')
65 |     im = ax2.imshow(np.rot90(target_spec), interpolation='none',
66 |                     vmin=vmin, vmax=vmax)
67 |     fig.colorbar(mappable=im, shrink=0.65, orientation='horizontal', ax=ax2)
68 | 
69 |     plt.tight_layout()
70 |     plt.savefig(pic_path, format='png')
71 |     plt.close()
72 |     log("save spec png to {}".format(pic_path))
73 |     return
74 | 
75 | 
76 | def plot_line(path, x_list, y_list, label_list):
77 |     assert len(x_list) == len(y_list) == len(label_list)
78 |     plt.title('Result Analysis')
79 |     for x_data, y_data, label in zip(x_list, y_list, label_list):
80 |         plt.plot(x_data, y_data, label=label)
81 |     # plt.plot(x2, y2, color='red', label='predict')
82 |     plt.legend()  # show the legend
83 |     plt.xlabel('frame-index')
84 |     plt.ylabel('value')
85 |     plt.savefig(path, format='png')
86 |     plt.close()
87 |     return
88 | 
89 | 
90 | def plot_line_phone_time(path, time_pitch_index, pitch_seq, seq_label):
91 |     """ show phoneme and pitch in time"""
92 |     seq_str = []
93 |     seq_index = []
94 |     pre = seq_label[0]
95 |     counter = 0
96 |     for i in seq_label:
97 |         if i != pre:
98 |             seq_str.append(pre)
99 |             seq_index.append(counter)
100 |             pre = i
101 |         else:
102 |             counter += 1
103 |     fig = plt.figure()
104 |     ax1 = fig.add_subplot(111)
105 |     ax1.plot(time_pitch_index, pitch_seq, 'r')
106 |     ax1.set_ylabel('pitch')
107 |     for i in range(len(seq_index)):
108 |         plt.vlines(seq_index[i], ymin=0, ymax=700)
109 |         plt.text(seq_index[i] - 1, 800, seq_str[i], rotation=39)
110 | 
111 |     plt.savefig(path, format='png')
112 |     return
113 | 
114 | 
115 | def plot_mel(mel, path, info=None):
116 |     mel = mel.T
117 |     fig, ax = plt.subplots()
118 |     im = ax.imshow(
119 |         mel,
120 |         aspect='auto',
121 |         origin='lower',
122 |         interpolation='none')
123 |     fig.colorbar(im, ax=ax)
124 |     xlabel = 'Decoder timestep'
125 |     if info is not None:
126 |         xlabel += '\n\n' + info
127 |     plt.show()
128 |     plt.savefig(path, format='png')
129 |     plt.close()
130 | 
131 |     return fig
132 | 
133 | import json  # needed by plot_one_mel_pitch_energy below; not imported at the top of this file
134 | def plot_one_mel_pitch_energy(mel, pitch, energy, stat_json_file, title, path):
135 |     """plot mel/pitch/energy
136 | 
137 |     Args:
138 |         mel: [dim, T]
139 |         pitch: [T]
140 |         energy: [T]
141 |         stat_json_file: stat file
142 |         title: figure title
143 |         path: path for png
144 | 
145 |     """
146 |     with open(stat_json_file) as f:
147 |         stats = json.load(f)
148 |     stats = stats["phn_pitch"] + stats["phn_energy"][:2]
149 |     # stats = [min(pitch), max(pitch), 0, 1.0, min(energy), max(energy)]
150 | 
151 |     fig = plot_multi_mel_pitch_energy(
152 |         [
153 |             (mel, pitch, energy),
154 |         ],
155 |         stats,
156 |         [title],
157 |     )
158 |     plt.savefig(path, format='png')
159 |     plt.close()
160 | 
161 | 
162 | def expand(values, durations):
163 |     out = list()
164 |     for value, d in zip(values, durations):
165 |         out += [value] * max(0, int(d))
166 |     return np.array(out)
167 | 
168 | 
169 | def plot_multi_mel_pitch_energy(data, stats, titles):
170 |     fig, axes = plt.subplots(len(data), 1, squeeze=False, figsize=(6.4, 3.0 * len(data)))
171 |     if titles is None:
172 |         titles = [None for i in range(len(data))]
173 |     pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
174 |     pitch_min = pitch_min * pitch_std + pitch_mean
175 |     pitch_max = pitch_max * pitch_std + pitch_mean
176 | 
177 |     def add_axis(fig, old_ax):
178 |         ax = fig.add_axes(old_ax.get_position(), anchor="W")
179 |         ax.set_facecolor("None")
180 |         return ax
181 | 
182 |     for i in range(len(data)):
183 |         mel, pitch, energy = data[i]
184 |         # log(titles, mel.shape, pitch.shape, energy.shape)
185 |         pitch = pitch * pitch_std + pitch_mean
186 |         axes[i][0].imshow(mel, origin="lower")
187 |         axes[i][0].set_aspect(2.5, adjustable="box")
188 |         axes[i][0].set_ylim(0, mel.shape[0])
189 |         axes[i][0].set_title(titles[i], fontsize="medium")
190 |         axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
191 |         axes[i][0].set_anchor("W")
192 | 
193 |         ax1 = add_axis(fig, axes[i][0])
194 |         # log(pitch)
195 |         ax1.plot(pitch, color="tomato")
196 |         ax1.set_xlim(0, max(mel.shape[1], len(pitch)))
197 |         ax1.set_ylim(0, pitch_max)
198 |         ax1.set_ylabel("F0", color="tomato")
199 |         ax1.tick_params(
200 |             labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
201 |         )
202 | 
203 |         ax2 = add_axis(fig, axes[i][0])
204 |         ax2.plot(energy, color="darkviolet")
205 |         ax2.set_xlim(0, max(mel.shape[1], len(energy)))
206 |         ax2.set_ylim(energy_min, energy_max)
207 |         ax2.set_ylabel("Energy", color="darkviolet")
208 |         ax2.yaxis.set_label_position("right")
209 |         ax2.tick_params(
210 |             labelsize="x-small",
211 |             colors="darkviolet",
212 |             bottom=False,
213 |             labelbottom=False,
214 |             left=False,
215 |             labelleft=False,
216 |             right=True,
217 |             labelright=True,
218 |         )
219 | 
220 |     return fig
221 | 
222 | 
223 | if __name__ == '__main__':
224 |     pass
225 | 
-------------------------------------------------------------------------------- /speech_enhance/utils/utils.py: --------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import yaml
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import argparse  # needed by str2bool() below
8 | from .logger import log
9 | 
10 | 
11 | def touch_dir(d):
12 |     os.makedirs(d, exist_ok=True)
13 | 
14 | 
15 | def is_file_exists(f):
16 |     return os.path.exists(f)
17 | 
18 | 
19 | def check_file_exists(f):
20 |     if not os.path.exists(f):
21 |         log(f"not found file: {f}")
22 |         assert False, f"not found file: {f}"
23 | 
24 | 
25 | def read_lines(data_path):
26 |     lines = []
27 |     with open(data_path, encoding="utf-8") as fr:
28 |         for line in fr.readlines():
29 |             if len(line.strip().replace(" ", "")):
30 |                 lines.append(line.strip())
31 |     # log("read {} lines from {}".format(len(lines), data_path))
32 |     # log("example(last) {}\n".format(lines[-1]))
33 |     return lines
34 | 
35 | 
36 | def write_lines(data_path, lines):
37 |     with open(data_path, "w", encoding="utf-8") as fw:
38 |         for line in lines:
39 |             fw.write("{}\n".format(line))
40 |     # log("write {} lines to {}".format(len(lines), data_path))
41 |     # log("example(last line): {}\n".format(lines[-1]))
42 |     return
43 | 
44 | 
45 | def get_name_from_path(abs_path):
46 | return ".".join(os.path.basename(abs_path).split(".")[:-1]) 47 | 48 | 49 | class AttrDict(dict): 50 | def __init__(self, *args, **kwargs): 51 | super(AttrDict, self).__init__(*args, **kwargs) 52 | self.__dict__ = self 53 | return 54 | 55 | 56 | def load_hparams(yaml_path): 57 | with open(yaml_path, encoding="utf-8") as yaml_file: 58 | hparams = yaml.safe_load(yaml_file) 59 | return AttrDict(hparams) 60 | 61 | 62 | def dump_hparams(yaml_path, hparams): 63 | touch_dir(os.path.dirname(yaml_path)) 64 | with open(yaml_path, "w") as fw: 65 | yaml.dump(hparams, fw) 66 | log("save hparams to {}".format(yaml_path)) 67 | return 68 | 69 | 70 | def get_all_wav_path(file_dir): 71 | wav_list = [] 72 | for path, dir_list, file_list in os.walk(file_dir): 73 | for file_name in file_list: 74 | if file_name.endswith(".wav") or file_name.endswith(".WAV"): 75 | wav_path = os.path.join(path, file_name) 76 | wav_list.append(wav_path) 77 | return sorted(wav_list) 78 | 79 | 80 | def clean_and_new_dir(data_dir): 81 | if os.path.exists(data_dir): 82 | shutil.rmtree(data_dir) 83 | os.makedirs(data_dir) 84 | return 85 | 86 | 87 | def generate_dir_tree(synth_dir, dir_name_list, del_old=False): 88 | os.makedirs(synth_dir, exist_ok=True) 89 | dir_path_list = [] 90 | if del_old: 91 | shutil.rmtree(synth_dir, ignore_errors=True) 92 | for name in dir_name_list: 93 | dir_path = os.path.join(synth_dir, name) 94 | dir_path_list.append(dir_path) 95 | os.makedirs(dir_path, exist_ok=True) 96 | return dir_path_list 97 | 98 | 99 | def str2bool(v): 100 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 101 | return True 102 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 103 | return False 104 | else: 105 | raise argparse.ArgumentTypeError('Boolean value expected.') 106 | 107 | 108 | def pad(input_ele, mel_max_length=None): 109 | if mel_max_length: 110 | max_len = mel_max_length 111 | else: 112 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))]) 113 | 114 | out_list = list() 115 | for i, batch in enumerate(input_ele): 116 | if len(batch.shape) == 1: 117 | one_batch_padded = F.pad( 118 | batch, (0, max_len - batch.size(0)), "constant", 0.0 119 | ) 120 | elif len(batch.shape) == 2: 121 | one_batch_padded = F.pad( 122 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0 123 | ) 124 | out_list.append(one_batch_padded) 125 | out_padded = torch.stack(out_list) 126 | return out_padded 127 | 128 | 129 | def pad_1D(inputs, PAD=0): 130 | def pad_data(x, length, PAD): 131 | x_padded = np.pad( 132 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD 133 | ) 134 | return x_padded 135 | 136 | max_len = max((len(x) for x in inputs)) 137 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs]) 138 | 139 | return padded 140 | 141 | 142 | def pad_2D(inputs, maxlen=None): 143 | def pad(x, max_len): 144 | PAD = 0 145 | if np.shape(x)[0] > max_len: 146 | raise ValueError("not max_len") 147 | 148 | s = np.shape(x)[1] 149 | x_padded = np.pad( 150 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD 151 | ) 152 | return x_padded[:, :s] 153 | 154 | if maxlen: 155 | output = np.stack([pad(x, maxlen) for x in inputs]) 156 | else: 157 | max_len = max(np.shape(x)[0] for x in inputs) 158 | output = np.stack([pad(x, max_len) for x in inputs]) 159 | 160 | return output 161 | 162 | 163 | def get_mask_from_lengths(lengths, max_len=None): 164 | batch_size = lengths.shape[0] 165 | if max_len is None: 166 | max_len = torch.max(lengths).item() 167 | 168 | ids = torch.arange(0, 
max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device) 169 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 170 | 171 | return mask 172 | 173 | 174 | if __name__ == '__main__': 175 | pass 176 | --------------------------------------------------------------------------------