├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── inference.toml
│   └── train.toml
├── inference.sh
├── mertrics.sh
├── run.sh
└── speech_enhance
    ├── __init__.py
    ├── audio_zen
    │   ├── __init__.py
    │   ├── acoustics
    │   │   ├── __init__.py
    │   │   ├── beamforming.py
    │   │   ├── feature.py
    │   │   ├── mask.py
    │   │   └── utils.py
    │   ├── constant.py
    │   ├── dataset
    │   │   ├── __init__.py
    │   │   └── base_dataset.py
    │   ├── inferencer
    │   │   ├── __init__.py
    │   │   └── base_inferencer.py
    │   ├── loss.py
    │   ├── metrics.py
    │   ├── model
    │   │   ├── __init__.py
    │   │   ├── base_model.py
    │   │   └── module
    │   │       ├── __init__.py
    │   │       ├── attention_model.py
    │   │       ├── causal_conv.py
    │   │       ├── feature_norm.py
    │   │       ├── sequence_model.py
    │   │       └── si_module.py
    │   ├── trainer
    │   │   ├── __init__.py
    │   │   └── base_trainer.py
    │   └── utils.py
    ├── fullsubnet
    │   ├── __init__.py
    │   ├── dataset
    │   │   ├── __init__.py
    │   │   ├── dataset_inference.py
    │   │   ├── dataset_train.py
    │   │   └── dataset_validation.py
    │   ├── inferencer
    │   │   ├── __init__.py
    │   │   └── inferencer.py
    │   ├── model
    │   │   ├── __init__.py
    │   │   └── fullsubnet.py
    │   └── trainer
    │       ├── __init__.py
    │       └── trainer.py
    ├── fullsubnet_plus
    │   ├── __init__.py
    │   ├── dataset
    │   │   ├── __init__.py
    │   │   ├── dataset_inference.py
    │   │   ├── dataset_train.py
    │   │   └── dataset_validation.py
    │   ├── inferencer
    │   │   ├── __init__.py
    │   │   └── inferencer.py
    │   ├── model
    │   │   └── fullsubnet_plus.py
    │   └── trainer
    │       ├── __init__.py
    │       └── trainer.py
    ├── inter_subnet
    │   ├── __init__.py
    │   ├── dataset
    │   │   ├── __init__.py
    │   │   ├── dataset_inference.py
    │   │   ├── dataset_train.py
    │   │   └── dataset_validation.py
    │   ├── inferencer
    │   │   ├── __init__.py
    │   │   └── inferencer.py
    │   ├── model
    │   │   ├── Inter_SubNet.py
    │   │   └── __init__.py
    │   └── trainer
    │       ├── __init__.py
    │       ├── joint_trainer.py
    │       └── trainer.py
    ├── subband_model
    │   ├── __init__.py
    │   ├── dataset
    │   │   ├── __init__.py
    │   │   ├── dataset_inference.py
    │   │   ├── dataset_train.py
    │   │   └── dataset_validation.py
    │   ├── inferencer
    │   │   ├── __init__.py
    │   │   └── inferencer.py
    │   ├── model
    │   │   ├── __init__.py
    │   │   └── subband_model.py
    │   └── trainer
    │       ├── __init__.py
    │       ├── joint_trainer.py
    │       └── trainer.py
    ├── tools
    │   ├── __init__.py
    │   ├── analyse.py
    │   ├── calculate_metrics.py
    │   ├── collect_lst.py
    │   ├── dns_mos.py
    │   ├── gen_lst.py
    │   ├── inference.py
    │   ├── noisyspeech_synthesizer.py
    │   ├── resample_dir.py
    │   └── train.py
    └── utils
        ├── __init__.py
        ├── logger.py
        ├── plot.py
        └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Inter-SubNet
2 |
3 | The official PyTorch implementation of **"[Inter-SubNet: Speech Enhancement with Subband Interaction](https://arxiv.org/abs/2305.05599)"**, accepted by ICASSP 2023.
4 |
5 | 📜[[Full Paper](https://arxiv.org/abs/2305.05599)] ▶[[Demo](https://rookiejunchen.github.io/Inter-SubNet_demo/)] 💿[[Checkpoint](https://drive.google.com/file/d/1j9jdXRxPhXLE93XlYppCQtcOqMOJNjdt/view?usp=share_link)]
6 |
7 |
8 |
9 | ## Requirements
10 |
11 | - Linux or macOS
12 |
13 | - python>=3.6
14 |
15 | - Anaconda or Miniconda
16 |
17 | - NVIDIA GPU + CUDA + cuDNN (CPU is also supported)
18 |
19 |
20 |
21 | ### Environment && Installation
22 |
23 | Install Anaconda or Miniconda, and then install conda and pip packages:
24 |
25 | ```shell
26 | # Create conda environment
27 | conda create --name speech_enhance python=3.8
28 | conda activate speech_enhance
29 |
30 | # Install conda packages
31 | # Check python=3.8, cudatoolkit=10.2, pytorch=1.7.1, torchaudio=0.7
32 | conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
33 | conda install tensorboard joblib matplotlib
34 |
35 | # Install pip packages
36 | # Check librosa=0.8
37 | pip install Cython
38 | pip install librosa pesq pypesq pystoi tqdm toml colorful mir_eval torch_complex
39 |
40 | # (Optional) If you want to load "mp3" format audio in your dataset
41 | conda install -c conda-forge ffmpeg
42 | ```
43 |
44 |
45 |
46 | ### Quick Usage
47 |
48 | Clone the repository:
49 |
50 | ```shell
51 | git clone https://github.com/RookieJunChen/Inter-SubNet.git
52 | cd Inter-SubNet
53 | ```
54 |
55 | Download the [pre-trained checkpoint](https://drive.google.com/file/d/1j9jdXRxPhXLE93XlYppCQtcOqMOJNjdt/view?usp=share_link), then run:
56 |
57 | ```shell
58 | source activate speech_enhance
59 | python -m speech_enhance.tools.inference \
60 | -C config/inference.toml \
61 | -M $MODEL_DIR \
62 | -I $INPUT_DIR \
63 | -O $OUTPUT_DIR
64 | ```
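For reference, a filled-in invocation looks like the following (the checkpoint and audio directories are placeholders for your own paths):

```shell
source activate speech_enhance
python -m speech_enhance.tools.inference \
    -C config/inference.toml \
    -M ./checkpoints/inter_subnet_best_model.tar \
    -I ./data/noisy_wavs \
    -O ./data/enhanced_wavs
```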
65 |
66 |
67 |
68 | ## Start Up
69 |
70 | ### Clone
71 |
72 | ```shell
73 | git clone https://github.com/RookieJunChen/Inter-SubNet.git
74 | cd Inter-SubNet
75 | ```
76 |
77 |
78 |
79 | ### Data preparation
80 |
81 | #### Train data
82 |
83 | Please organize your data under the data dir as follows:
84 |
85 | - data/DNS-Challenge/DNS-Challenge-interspeech2020-master/
86 | - data/DNS-Challenge/DNS-Challenge-master/
87 |
88 | and set the train dir in the script `run.sh`.
89 |
90 | Then:
91 |
92 | ```shell
93 | source activate speech_enhance
94 | bash run.sh 0  # prepare training list or meta file
95 | ```
96 |
97 | #### Test data
98 |
99 | Please prepare your test-case dir, e.g. `data/test_cases_`, and set the test dir in the script `run.sh`.
100 |
101 |
102 |
103 | ### Training
104 |
105 | First, modify the configuration in `config/train.toml` to match your training setup.
106 |
107 | Then you can run training:
108 |
109 | ```shell
110 | source activate speech_enhance
111 | bash run.sh 1
112 | ```
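Stage 1 of `run.sh` boils down to the command below (the GPU ids and the `-N 2` process count are what the script currently hard-codes; adapt them to your machine):

```shell
CUDA_VISIBLE_DEVICES='0,1' python -m speech_enhance.tools.train -C config/train.toml -N 2
```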
113 |
114 |
115 |
116 | ### Inference
117 |
118 | After training, you can enhance noisy speech. Before inference, first modify the configuration in `config/inference.toml`.
119 |
120 | Then run inference:
121 |
122 | ```shell
123 | source activate speech_enhance
124 | bash run.sh 2
125 | ```
126 |
127 | Or you can just use `inference.sh`:
128 |
129 | ```shell
130 | source activate speech_enhance
131 | bash inference.sh
132 | ```
133 |
134 |
135 |
136 |
137 |
138 | ### Eval
139 |
140 | Calculate objective metrics (SI_SDR, STOI, WB_PESQ, NB_PESQ, etc.):
141 |
142 | ```shell
143 | bash mertrics.sh
144 | ```
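Under the hood this calls `speech_enhance/tools/calculate_metrics.py`; a direct invocation follows the same pattern as the script (the directories below are placeholders):

```shell
python speech_enhance/tools/calculate_metrics.py \
    -R ./test_set/with_reverb/clean \
    -E ./enhanced/with_reverb \
    -M SI_SDR,STOI,WB_PESQ,NB_PESQ \
    -S DNS_1 \
    -D ./metrics/with_reverb/
```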
145 |
146 | For test sets without reference signals, you can obtain subjective-quality scores (DNS_MOS, NISQA, etc.) through [DNSMOS](https://github.com/RookieJunChen/dns_mos_calculate) and [NISQA](https://github.com/RookieJunChen/my_NISQA).
147 |
148 |
149 |
150 |
151 | ## Citation
152 | If you find our work useful in your research, please consider citing:
153 | ```
154 | @inproceedings{chen2023inter,
155 | title={Inter-Subnet: Speech Enhancement with Subband Interaction},
156 | author={Chen, Jun and Rao, Wei and Wang, Zilin and Lin, Jiuxin and Wu, Zhiyong and Wang, Yannan and Shang, Shidong and Meng, Helen},
157 | booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
158 | pages={1--5},
159 | year={2023},
160 | organization={IEEE}
161 | }
162 | ```
163 |
--------------------------------------------------------------------------------
/config/inference.toml:
--------------------------------------------------------------------------------
1 | [acoustics]
2 | n_fft = 512
3 | win_length = 512
4 | sr = 16000
5 | hop_length = 256
6 |
7 |
8 | [inferencer]
9 | path = "inter_subnet.inferencer.inferencer.Inferencer"
10 | type = "full_band_crm_mask"
11 |
12 | [inferencer.args]
13 | n_neighbor = 15
14 |
15 |
16 | [dataset]
17 | path = "inter_subnet.dataset.dataset_inference.Dataset"
18 | [dataset.args]
19 | dataset_dir_list = [
20 | "/workspace/project-nas-11025-sh/speech_enhance/data/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set/synthetic/with_reverb/noisy"
21 | # "/workspace/project-nas-11025-sh/speech_enhance/data/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set/synthetic/no_reverb/noisy"
22 | ]
23 | sr = 16000
24 |
25 |
26 | [model]
27 | path = "inter_subnet.model.Inter_SubNet.Inter_SubNet"
28 | [model.args]
29 | sb_num_neighbors = 15
30 | num_freqs = 257
31 | look_ahead = 2
32 | sequence_model = "LSTM"
33 | sb_output_activate_function = false
34 | sb_model_hidden_size = 384
35 | weight_init = false
36 | norm_type = "offline_laplace_norm"
37 | num_groups_in_drop_band = 2
38 | sbinter_middle_hidden_times = 0.8
39 |
40 |
--------------------------------------------------------------------------------
/config/train.toml:
--------------------------------------------------------------------------------
1 | [meta]
2 | save_dir = "logs/Inter_SubNet"
3 | description = "This is a description of Inter-SubNet experiment."
4 | seed = 0 # set random seed for random, numpy, pytorch-gpu and pytorch-cpu
5 | port = "4396"
6 | keep_reproducibility = false # see https://pytorch.org/docs/stable/notes/randomness.html
7 | use_amp = false # use automatic mixed precision; it benefits Tensor Core-enabled GPUs (e.g. Volta, Turing, Ampere) with a 2-3x speedup
8 |
9 |
10 | [acoustics]
11 | n_fft = 512
12 | win_length = 512
13 | sr = 16000
14 | hop_length = 256
15 |
16 |
17 | [loss_function]
18 | name = "mse_loss"
19 | [loss_function.args]
20 |
21 |
22 | [optimizer]
23 | lr = 0.001
24 | beta1 = 0.9
25 | beta2 = 0.999
26 |
27 |
28 | [train_dataset]
29 | path = "inter_subnet.dataset.dataset_train.Dataset"
30 | [train_dataset.args]
31 | clean_dataset = "train_data_DNS_2021_16k/clean_book.txt"
32 | clean_dataset_limit = false
33 | clean_dataset_offset = 0
34 | noise_dataset = "train_data_DNS_2021_16k/noise.txt"
35 | noise_dataset_limit = false
36 | noise_dataset_offset = 0
37 | num_workers = 36
38 | pre_load_clean_dataset = false
39 | pre_load_noise = false
40 | pre_load_rir = false
41 | reverb_proportion = 0.75
42 | rir_dataset = "train_data_DNS_2021_16k/rir.txt"
43 | rir_dataset_limit = false
44 | rir_dataset_offset = 0
45 | silence_length = 0.2
46 | snr_range = [-5, 20]
47 | sr = 16000
48 | sub_sample_length = 3.072
49 | target_dB_FS = -25
50 | target_dB_FS_floating_value = 10
51 |
52 |
53 | [train_dataset.dataloader]
54 | batch_size = 20
55 | num_workers = 24
56 | drop_last = true
57 | pin_memory = true
58 |
59 |
60 | [validation_dataset]
61 | path = "inter_subnet.dataset.dataset_validation.Dataset"
62 | [validation_dataset.args]
63 | dataset_dir_list = [
64 | "data/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set/synthetic/with_reverb/",
65 | "data/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set/synthetic/no_reverb/"
66 | # "/dockerdata/thujunchen/data/DNS_Challenge/test_set/synthetic/with_reverb/",
67 | # "/dockerdata/thujunchen/data/DNS_Challenge/test_set/synthetic/no_reverb/"
68 | ]
69 | sr = 16000
70 |
71 |
72 | [model]
73 | path = "inter_subnet.model.Inter_SubNet.Inter_SubNet"
74 | [model.args]
75 | sb_num_neighbors = 15
76 | num_freqs = 257
77 | look_ahead = 2
78 | sequence_model = "LSTM"
79 | sb_output_activate_function = false
80 | sb_model_hidden_size = 384
81 | weight_init = false
82 | norm_type = "offline_laplace_norm"
83 | num_groups_in_drop_band = 2
84 | sbinter_middle_hidden_times = 0.8
85 |
86 |
87 | [trainer]
88 | path = "inter_subnet.trainer.trainer.New_Judge_Trainer"
89 | [trainer.train]
90 | clip_grad_norm_value = 10
91 | epochs = 9999
92 | alpha = 1
93 | save_checkpoint_interval = 1
94 | [trainer.validation]
95 | save_max_metric_score = true
96 | validation_interval = 1
97 | [trainer.visualization]
98 | metrics = ["WB_PESQ", "NB_PESQ", "STOI", "SI_SDR"]
99 | n_samples = 10
100 | num_workers = 12
101 |
--------------------------------------------------------------------------------
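Both configs rely on the same `path`/`args` convention: at runtime, `audio_zen.utils.initialize_module` imports the dotted `path` and instantiates it with the `args` table, which is exactly what `BaseInferencer._load_model` does for the `[model]` section. A minimal sketch of that flow, assuming the `speech_enhance` directory is on `PYTHONPATH`:

```python
import toml

from audio_zen.utils import initialize_module

config = toml.load("config/inference.toml")
model_cfg = config["model"]

# Imports "inter_subnet.model.Inter_SubNet" and calls Inter_SubNet(**model_cfg["args"])
model = initialize_module(model_cfg["path"], args=model_cfg["args"], initialize=True)
print(type(model).__name__)  # Inter_SubNet
```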
/inference.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # run enhancement (denoising)
4 | CUDA_VISIBLE_DEVICES='5' python -m speech_enhance.tools.inference \
5 | -C config/inference.toml \
6 | -M /cjcode/ft_local/intersubnet/intersubnet.tar \
7 | -I /DNS_2021/eval/testclips \
8 | -O /enhance_data/dns4_testclips/inter_subnet
9 |
10 |
11 | # Normalized to -6dB (optional)
12 | sdir="enhance_data/dns4_testclips/inter_subnet"
13 | fdir="enhance_data/dns4_testclips/inter_subnet_norm"
14 |
15 | softfiles=$(find $sdir -name "*.wav")
16 | for file in ${softfiles}
17 | do
18 | length=${#sdir}+1
19 | file=${file:$length}
20 | f=$sdir/$file
21 | echo $f
22 | dstfile=$fdir/$file
23 | echo $dstfile
24 | sox $f -b16 $dstfile rate -v -b 99.7 16k norm -6
25 | done
26 |
27 |
28 |
--------------------------------------------------------------------------------
/mertrics.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python speech_enhance/tools/calculate_metrics.py \
4 | -R /workspace/project-nas-11025-sh/speech_enhance/data/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set/synthetic/with_reverb/clean \
5 | -E /workspace/project-nas-11025-sh/speech_enhance/case/with_reverb/fullsubnet+/enhanced_0194 \
6 | -M SI_SDR,STOI,WB_PESQ,NB_PESQ \
7 | -S DNS_1 \
8 | -D /workspace/project-nas-11025-sh/speech_enhance/egs/DNS-master/s1_16k/mertrics/with_reverb_fullsubnet+/
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | #set -eux
4 | if test "$#" -eq 1; then
5 | stage=$(($1))
6 | stop_stage=$(($1))
7 | elif test "$#" -eq 2; then
8 | stage=$(($1))
9 | stop_stage=$(($2))
10 | else
11 | stage=0
12 | stop_stage=10
13 | fi
14 |
15 | # corpus data will be under ${data_dir}/${corpus_name}, should be read only for exp
16 | # train data(features, index) will be under ${data_dir}/${train_data_dir}
17 | cur_dir=$(pwd)
18 | data_dir=../../../data/
19 | corpus_name=data
20 | train_data_dir=train_data_fsn_dns_master
21 | exp_id=$(pwd | xargs basename)
22 | spkr_id=$(pwd | xargs dirname | xargs basename)
23 | corpus_id=${spkr_id}_${exp_id}
24 |
25 | #############################################
26 | # prepare files
27 |
28 | if [ ! -d ${corpus_name} ]; then
29 | ln -s ${data_dir} ${corpus_name}
30 | fi
31 |
32 | if [ ! -d ${train_data_dir} ]; then
33 | cpfs_cur_dir=${cur_dir/nas/cpfs}
34 | cpfs_data_dir=${cpfs_cur_dir}/${data_dir}
35 | mkdir -p ${cpfs_data_dir}/${train_data_dir}
36 | ln -s ${cpfs_data_dir}/${train_data_dir} ${train_data_dir}
37 | fi
38 |
39 | #############################################
40 | # generate list of clean/noise for training, dir of synthetic noisy-clean pairs for val
41 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
42 | # gen lst
43 | /usr/bin/python3 -m speech_enhance.tools.gen_lst --dataset_dir ${corpus_name}/DNS-Challenge/DNS-Challenge-master/datasets_16k/clean/ --output_lst ${train_data_dir}/clean.txt
44 | /usr/bin/python3 -m speech_enhance.tools.gen_lst --dataset_dir ${corpus_name}/DNS-Challenge/DNS-Challenge-master/datasets_16k/noise/ --output_lst ${train_data_dir}/noise.txt
45 | cat ${corpus_name}/rir/simulated_rirs_16k/*/rir_list ${corpus_name}/rir/RIRS_NOISES/real_rirs_isotropic_noises/rir_list | awk -F ' ' '{print "data/rir/"$5}' > ${train_data_dir}/rir.txt
46 | perl -pi -e 's#data/rir#data/rir_16k#g' ${train_data_dir}/rir.txt
47 | # just use the book wav as interspeech2020
48 | grep book train_data_fsn_dns_master/clean.txt > train_data_fsn_dns_master/clean_book.txt
49 | #
50 | test_set=${corpus_name}/DNS-Challenge/DNS-Challenge-interspeech2020-master/datasets/test_set
51 | if [ ! -d ${test_set} ]; then
52 | echo "please prepare the ${test_set} from https://github.com/microsoft/DNS-Challenge/tree/interspeech2020/master/datasets/test_set"
53 | exit
54 | fi
55 | fi
56 |
57 | # train
58 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
59 | CUDA_VISIBLE_DEVICES='0,1' python -m speech_enhance.tools.train -C config/train.toml -N 2
60 | fi
61 |
62 | # test
63 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
64 | # input 24k audio
65 | input_dir="/workspace/project-nas-10691-sh/durian/egs/m2voc/s13/logs-s13/acoustic/analysis/gt"
66 | output_dir="/workspace/project-nas-10691-sh/durian/egs/m2voc/s13/logs-s13/acoustic/analysis/gt_24k_enhance_s1_24k"
67 |
68 | input_dir="/workspace/project-nas-10691-sh/durian/egs/m2voc/s10_tst_tsv_male_global_pitch_only1/logs-s10_tst_tsv_male_global_pitch_only1/eval-992914"
69 | output_dir="/workspace/project-nas-10691-sh/durian/egs/m2voc/s10_tst_tsv_male_global_pitch_only1/logs-s10_tst_tsv_male_global_pitch_only1/eval-992914-enhance_dns_master_s1_24k"
70 |
71 | input_dir="data/test_cases_didi/eval-992914/"
72 | output_dir="logs/eval/test_cases_didi/eval-992914/"
73 |
74 |   # run enhancement (denoising)
75 | #CUDA_VISIBLE_DEVICES=0 \
76 | python -m speech_enhance.tools.inference \
77 | -C config/inference.toml \
78 | -M logs/FullSubNet/train/checkpoints/best_model.tar \
79 | -I ${input_dir} \
80 | -O ${output_dir}
81 | #-O logs/FullSubNet/inference/
82 |
83 | # norm volume to -6db
84 | for f in `ls ${output_dir}/*/*.wav | grep -v norm`; do
85 | echo $f
86 | fid=`basename $f .wav`
87 | fo=`dirname $f`/${fid}_norm-6db.wav
88 | echo $fo
89 | sox $f -b16 $fo rate -v -b 99.7 24k norm -6
90 | done
91 | fi
92 |
--------------------------------------------------------------------------------
/speech_enhance/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/acoustics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/acoustics/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/acoustics/beamforming.py:
--------------------------------------------------------------------------------
1 | from torch_complex import ComplexTensor
2 | from torch_complex import functional as FC
3 |
4 |
5 | def apply_crf_filter(cRM_filter: ComplexTensor, mix: ComplexTensor) -> ComplexTensor:
6 | """
7 | Apply complex Ratio Filter
8 |
9 | Args:
10 | cRM_filter: complex Ratio Filter
11 | mix: mixture
12 |
13 | Returns:
14 | [B, C, F, T]
15 | """
16 | # [B, F, T, Filter_delay] x [B, C, F, Filter_delay,T] => [B, C, F, T]
17 | es = FC.einsum("bftd, bcfdt -> bcft", [cRM_filter.conj(), mix])
18 | return es
19 |
20 |
21 | def get_power_spectral_density_matrix(complex_tensor: ComplexTensor) -> ComplexTensor:
22 | """
23 | Cross-channel power spectral density (PSD) matrix
24 |
25 | Args:
26 | complex_tensor: [..., F, C, T]
27 |
28 | Returns
29 | psd: [..., F, C, C]
30 | """
31 | # outer product: [..., C_1, T] x [..., C_2, T] => [..., T, C_1, C_2]
32 | return FC.einsum("...ct,...et->...tce", [complex_tensor, complex_tensor.conj()])
33 |
34 |
35 | def apply_beamforming_vector(beamforming_vector: ComplexTensor, mix: ComplexTensor) -> ComplexTensor:
36 | # [..., C] x [..., C, T] => [..., T]
37 | # There's no relationship between frequencies.
38 | es = FC.einsum("bftc, bfct -> bft", [beamforming_vector.conj(), mix])
39 | return es
40 |
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/acoustics/mask.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import librosa
4 | import numpy as np
5 | import torch
6 |
7 | from audio_zen.constant import EPSILON
8 |
9 |
10 | def build_ideal_ratio_mask(noisy_mag, clean_mag) -> torch.Tensor:
11 | """
12 |
13 | Args:
14 | noisy_mag: [B, F, T], noisy magnitude
15 | clean_mag: [B, F, T], clean magnitude
16 |
17 | Returns:
18 | [B, F, T, 1]
19 | """
20 | # noisy_mag_finetune = torch.sqrt(torch.square(noisy_mag) + EPSILON)
21 | # ratio_mask = clean_mag / noisy_mag_finetune
22 | ratio_mask = clean_mag / (noisy_mag + EPSILON)
23 | ratio_mask = ratio_mask[..., None]
24 | return compress_cIRM(ratio_mask, K=10, C=0.1)
25 |
26 |
27 | def build_complex_ideal_ratio_mask(noisy: torch.complex64, clean: torch.complex64) -> torch.Tensor:
28 | """
29 |
30 | Args:
31 | noisy: [B, F, T], noisy complex-valued stft coefficients
32 | clean: [B, F, T], clean complex-valued stft coefficients
33 |
34 | Returns:
35 | [B, F, T, 2]
36 | """
37 | denominator = torch.square(noisy.real) + torch.square(noisy.imag) + EPSILON
38 |
39 | mask_real = (noisy.real * clean.real + noisy.imag * clean.imag) / denominator
40 | mask_imag = (noisy.real * clean.imag - noisy.imag * clean.real) / denominator
41 |
42 | complex_ratio_mask = torch.stack((mask_real, mask_imag), dim=-1)
43 |
44 | return compress_cIRM(complex_ratio_mask, K=10, C=0.1)
45 |
46 |
47 | def compress_cIRM(mask, K=10, C=0.1):
48 | """
49 | Compress from (-inf, +inf) to [-K ~ K]
50 | """
51 | if torch.is_tensor(mask):
52 | mask = -100 * (mask <= -100) + mask * (mask > -100)
53 | mask = K * (1 - torch.exp(-C * mask)) / (1 + torch.exp(-C * mask))
54 | else:
55 | mask = -100 * (mask <= -100) + mask * (mask > -100)
56 | mask = K * (1 - np.exp(-C * mask)) / (1 + np.exp(-C * mask))
57 | return mask
58 |
59 |
60 | def decompress_cIRM(mask, K=10, limit=9.9):
61 | mask = limit * (mask >= limit) - limit * (mask <= -limit) + mask * (torch.abs(mask) < limit)
62 | mask = -K * torch.log((K - mask) / (K + mask))
63 | return mask
64 |
65 |
66 | def complex_mul(noisy_r, noisy_i, mask_r, mask_i):
67 | r = noisy_r * mask_r - noisy_i * mask_i
68 | i = noisy_r * mask_i + noisy_i * mask_r
69 | return r, i
70 |
--------------------------------------------------------------------------------
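A quick sanity check on the compress/decompress pair above (a minimal sketch; it assumes the `speech_enhance` directory is on `PYTHONPATH` so the `audio_zen` import resolves as elsewhere in the repo): with the defaults `K=10, C=0.1`, the compression equals `K * tanh(C * m / 2)`, and since `K * C = 1`, `decompress_cIRM` recovers the original mask exactly for values inside the compression range.

```python
import torch

from audio_zen.acoustics.mask import compress_cIRM, decompress_cIRM

# Synthetic complex ratio mask of shape [B, F, T, 2]
mask = torch.randn(2, 257, 100, 2) * 3

compressed = compress_cIRM(mask, K=10, C=0.1)               # bounded to (-10, 10)
recovered = decompress_cIRM(compressed, K=10, limit=9.9)    # inverse mapping, clipped at +/- 9.9

print(compressed.abs().max())                      # strictly below 10
print(torch.allclose(mask, recovered, atol=1e-3))  # True for masks inside the compression range
```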
/speech_enhance/audio_zen/acoustics/utils.py:
--------------------------------------------------------------------------------
1 | import librosa
2 |
3 |
4 | def transform_pesq_range(pesq_score):
5 | """
6 |     Transform the PESQ metric range from [-0.5, 4.5] to [0, 1]
7 | """
8 | return (pesq_score + 0.5) / 5
9 |
10 |
11 | def load_wav(path, sr=16000):
12 | return librosa.load(path, sr=sr)[0]
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/constant.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 |
5 | NEG_INF = torch.finfo(torch.float32).min
6 | PI = math.pi
7 | SOUND_SPEED = 343 # m/s
8 | EPSILON = np.finfo(np.float32).eps
9 | MAX_INT16 = np.iinfo(np.int16).max
10 |
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/dataset/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/dataset/base_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data
2 |
3 |
4 | class BaseDataset(data.Dataset):
5 | def __init__(self):
6 | super().__init__()
7 |
8 | @staticmethod
9 | def _offset_and_limit(dataset_list, offset, limit):
10 | dataset_list = dataset_list[offset:]
11 | if limit:
12 | dataset_list = dataset_list[:limit]
13 | return dataset_list
14 |
15 | @staticmethod
16 | def _parse_snr_range(snr_range):
17 | assert len(snr_range) == 2, f"The range of SNR should be [low, high], not {snr_range}."
18 |         assert snr_range[0] <= snr_range[-1], "The low SNR should not be larger than the high SNR."
19 |
20 | low, high = snr_range
21 | snr_list = []
22 | for i in range(low, high + 1, 1):
23 | snr_list.append(i)
24 |
25 | return snr_list
26 |
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/inferencer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/inferencer/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/inferencer/base_inferencer.py:
--------------------------------------------------------------------------------
1 | import time
2 | from functools import partial
3 | from pathlib import Path
4 |
5 | import librosa
6 | import numpy as np
7 | import soundfile as sf
8 | import toml
9 | import torch
10 | from torch.nn import functional
11 | from torch.utils.data import DataLoader
12 | from tqdm import tqdm
13 | import time
14 |
15 | from audio_zen.acoustics.feature import stft, istft, mc_stft
16 | from audio_zen.utils import initialize_module, prepare_device, prepare_empty_dir
17 |
18 | # for log
19 | from utils.logger import log
20 | print=log
21 |
22 | class BaseInferencer:
23 | def __init__(self, config, checkpoint_path, output_dir):
24 | checkpoint_path = Path(checkpoint_path).expanduser().absolute()
25 | root_dir = Path(output_dir).expanduser().absolute()
26 | self.device = prepare_device(torch.cuda.device_count())
27 |
28 | print("Loading inference dataset...")
29 | self.dataloader = self._load_dataloader(config["dataset"])
30 | print("Loading model...")
31 |
32 | self.model, epoch = self._load_model(config["model"], checkpoint_path, self.device)
33 | # self.model = self._load_model(config["model"], checkpoint_path, self.device)
34 | # epoch = 64
35 |
36 | self.inference_config = config["inferencer"]
37 |
38 | self.enhanced_dir = root_dir / f"enhanced_{str(epoch).zfill(4)}"
39 | prepare_empty_dir([self.enhanced_dir])
40 |
41 | # Acoustics
42 | self.acoustic_config = config["acoustics"]
43 |
44 | # Supported STFT
45 | self.n_fft = self.acoustic_config["n_fft"]
46 | self.hop_length = self.acoustic_config["hop_length"]
47 | self.win_length = self.acoustic_config["win_length"]
48 | self.sr = self.acoustic_config["sr"]
49 |
50 | # See utils_backup.py
51 | self.torch_stft = partial(stft, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length)
52 | self.torch_istft = partial(istft, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length)
53 | self.torch_mc_stft = partial(mc_stft, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length)
54 | self.librosa_stft = partial(librosa.stft, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length)
55 | self.librosa_istft = partial(librosa.istft, hop_length=self.hop_length, win_length=self.win_length)
56 |
57 | print("Configurations are as follows: ")
58 | print(toml.dumps(config))
59 | with open((root_dir / f"{time.strftime('%Y-%m-%d %H:%M:%S')}.toml").as_posix(), "w") as handle:
60 | toml.dump(config, handle)
61 |
62 | @staticmethod
63 | def _load_dataloader(dataset_config):
64 | dataset = initialize_module(dataset_config["path"], args=dataset_config["args"], initialize=True)
65 | dataloader = DataLoader(
66 | dataset=dataset,
67 | batch_size=1,
68 | num_workers=0,
69 | )
70 | return dataloader
71 |
72 | @staticmethod
73 | def _unfold(input, pad_mode, n_neighbor):
74 | """
75 |         Split the spectrogram into multiple overlapping sub-bands along the frequency axis.
76 |
77 | Args:
78 | input: [B, C, F, T]
79 |
80 | Returns:
81 |             [B, N, C, F, T], where F is the frequency size of each sub-band, e.g. [2, 161, 1, 19, 200]
82 | """
83 | assert input.dim() == 4, f"The dim of input is {input.dim()}, which should be 4."
84 | batch_size, n_channels, n_freqs, n_frames = input.size()
85 | output = input.reshape(batch_size * n_channels, 1, n_freqs, n_frames)
86 | sub_band_n_freqs = n_neighbor * 2 + 1
87 |
88 | output = functional.pad(output, [0, 0, n_neighbor, n_neighbor], mode=pad_mode)
89 | output = functional.unfold(output, (sub_band_n_freqs, n_frames))
90 | assert output.shape[-1] == n_freqs, f"n_freqs != N (sub_band), {n_freqs} != {output.shape[-1]}"
91 |
92 |         # Split the middle dimension produced by unfold
93 | output = output.reshape(batch_size, n_channels, sub_band_n_freqs, n_frames, n_freqs)
94 |         output = output.permute(0, 4, 1, 2, 3).contiguous()  # permute is not the same as reshape: the shapes match, but the memory layout differs
95 | return output
96 |
97 | @staticmethod
98 | def _load_model(model_config, checkpoint_path, device):
99 | model = initialize_module(model_config["path"], args=model_config["args"], initialize=True)
100 | model_checkpoint = torch.load(checkpoint_path, map_location="cpu")
101 |
102 | # model_static_dict = model_checkpoint
103 | model_static_dict = model_checkpoint["model"]
104 | epoch = model_checkpoint["epoch"]
105 |         print(f"Loading a .tar checkpoint; epoch = {epoch}.")
106 |
107 | model.load_state_dict(model_static_dict)
108 | model.to(device)
109 | model.eval()
110 | return model, model_checkpoint["epoch"]
111 | # return model
112 |
113 | @torch.no_grad()
114 | def multi_channel_mag_to_mag(self, noisy, inference_args=None):
115 | """
116 |         The model takes the **magnitude spectrogram** of noisy speech as input and outputs a **magnitude spectrogram** as well.
117 | """
118 | mixture_stft_coefficients = self.torch_mc_stft(noisy)
119 | mixture_mag = (mixture_stft_coefficients.real ** 2 + mixture_stft_coefficients.imag ** 2) ** 0.5
120 |
121 | enhanced_mag = self.model(mixture_mag)
122 |
123 | # Phase of the reference channel
124 | reference_channel_stft_coefficients = mixture_stft_coefficients[:, 0, ...]
125 | noisy_phase = torch.atan2(reference_channel_stft_coefficients.imag, reference_channel_stft_coefficients.real)
126 | complex_tensor = torch.stack([(enhanced_mag * torch.cos(noisy_phase)), (enhanced_mag * torch.sin(noisy_phase))], dim=-1)
127 | enhanced = self.torch_istft(complex_tensor, length=noisy.shape[-1])
128 |
129 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
130 |
131 | return enhanced
132 |
133 | @torch.no_grad()
134 | def __call__(self):
135 | inference_type = self.inference_config["type"]
136 | assert inference_type in dir(self), f"Not implemented Inferencer type: {inference_type}"
137 |
138 | inference_args = self.inference_config["args"]
139 |
140 | for noisy, name in tqdm(self.dataloader, desc="Inference"):
141 |             assert len(name) == 1, "The batch size of the inference stage must be 1."
142 | name = name[0]
143 |
144 | t1 = time.time()
145 | enhanced = getattr(self, inference_type)(noisy.to(self.device), inference_args)
146 | t2 = time.time()
147 |
148 | if (abs(enhanced) > 1).any():
149 | print(f"Warning: enhanced is not in the range [-1, 1], {name}")
150 |
151 | amp = np.iinfo(np.int16).max
152 | enhanced = np.int16(0.8 * amp * enhanced / np.max(np.abs(enhanced)))
153 |
154 | # cal rtf
155 | rtf = (t2 - t1) / (len(enhanced) * 1.0 / self.acoustic_config["sr"])
156 | print(f"{name}, rtf: {rtf}")
157 |
158 | # clnsp102_traffic_248091_3_snr0_tl-21_fileid_268 => clean_fileid_0
159 | # name = "clean_" + "_".join(name.split("_")[-2:])
160 | sf.write(self.enhanced_dir / f"{name}.wav", enhanced, samplerate=self.acoustic_config["sr"])
161 |
--------------------------------------------------------------------------------
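Since `_unfold` is a static method, the sub-band splitting it performs can be exercised on its own. A minimal sketch (it assumes `speech_enhance` is on `PYTHONPATH` and the module's dependencies such as `soundfile` and `toml` are installed, because importing `base_inferencer` pulls them in; the `"reflect"` pad mode is an illustrative choice):

```python
import torch

from audio_zen.inferencer.base_inferencer import BaseInferencer

spec = torch.rand(2, 1, 257, 100)  # [B, C, F, T] magnitude spectrogram
sub_bands = BaseInferencer._unfold(spec, pad_mode="reflect", n_neighbor=15)

# Each of the 257 frequency bins becomes a sub-band of 2 * 15 + 1 = 31 neighbouring bins.
print(sub_bands.shape)  # torch.Size([2, 257, 1, 31, 100]) -> [B, N, C, F_sub, T]
```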
/speech_enhance/audio_zen/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | l1_loss = torch.nn.L1Loss
4 | mse_loss = torch.nn.MSELoss
5 |
6 |
7 | def si_snr_loss():
8 | def si_snr(x, s, eps=1e-8):
9 | """
10 |
11 | Args:
12 | x: Enhanced fo shape [B, T]
13 | s: Reference of shape [B, T]
14 | eps:
15 |
16 | Returns:
17 | si_snr: [B]
18 | """
19 | def l2norm(mat, keep_dim=False):
20 | return torch.norm(mat, dim=-1, keepdim=keep_dim)
21 |
22 | if x.shape != s.shape:
23 | raise RuntimeError(f"Dimension mismatch when calculate si_snr, {x.shape} vs {s.shape}")
24 |
25 | x_zm = x - torch.mean(x, dim=-1, keepdim=True)
26 | s_zm = s - torch.mean(s, dim=-1, keepdim=True)
27 |
28 | t = torch.sum(x_zm * s_zm, dim=-1, keepdim=True) * s_zm / (l2norm(s_zm, keep_dim=True) ** 2 + eps)
29 |
30 | return -torch.mean(20 * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps)))
31 |
32 | return si_snr
33 |
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mir_eval.separation import bss_eval_sources
3 | from pesq import pesq
4 | from pypesq import pesq as nb_pesq
5 | from pystoi.stoi import stoi
6 | import librosa
7 |
8 | def _scale_bss_eval(references, estimate, idx, compute_sir_sar=True):
9 | """
10 | Helper for scale_bss_eval to avoid infinite recursion loop.
11 | """
12 | source = references[..., idx]
13 | source_energy = (source ** 2).sum()
14 |
15 | alpha = (
16 | source @ estimate / source_energy
17 | )
18 |
19 | e_true = source
20 | e_res = estimate - e_true
21 |
22 | signal = (e_true ** 2).sum()
23 | noise = (e_res ** 2).sum()
24 |
25 | snr = 10 * np.log10(signal / noise)
26 |
27 | e_true = source * alpha
28 | e_res = estimate - e_true
29 |
30 | signal = (e_true ** 2).sum()
31 | noise = (e_res ** 2).sum()
32 |
33 | si_sdr = 10 * np.log10(signal / noise)
34 |
35 | srr = -10 * np.log10((1 - (1 / alpha)) ** 2)
36 | sd_sdr = snr + 10 * np.log10(alpha ** 2)
37 |
38 | si_sir = np.nan
39 | si_sar = np.nan
40 |
41 | if compute_sir_sar:
42 | references_projection = references.T @ references
43 |
44 | references_onto_residual = np.dot(references.transpose(), e_res)
45 | b = np.linalg.solve(references_projection, references_onto_residual)
46 |
47 | e_interf = np.dot(references, b)
48 | e_artif = e_res - e_interf
49 |
50 | si_sir = 10 * np.log10(signal / (e_interf ** 2).sum())
51 | si_sar = 10 * np.log10(signal / (e_artif ** 2).sum())
52 |
53 | return si_sdr, si_sir, si_sar, sd_sdr, snr, srr
54 |
55 |
56 | def SDR(reference, estimation, sr=16000):
57 | sdr, _, _, _ = bss_eval_sources(reference[None, :], estimation[None, :])
58 | return sdr
59 |
60 |
61 | def SI_SDR(reference, estimation, sr=16000):
62 | """
63 | Scale-Invariant Signal-to-Distortion Ratio (SI-SDR)
64 |
65 | Args:
66 | reference: numpy.ndarray, [..., T]
67 | estimation: numpy.ndarray, [..., T]
68 |
69 | Returns:
70 | SI-SDR
71 |
72 | References
73 |         SDR – Half-baked or Well Done? (http://www.merl.com/publications/docs/TR2019-013.pdf)
74 | """
75 | estimation, reference = np.broadcast_arrays(estimation, reference)
76 | reference_energy = np.sum(reference ** 2, axis=-1, keepdims=True)
77 |
78 | optimal_scaling = np.sum(reference * estimation, axis=-1, keepdims=True) / reference_energy
79 |
80 | projection = optimal_scaling * reference
81 |
82 | noise = estimation - projection
83 |
84 | ratio = np.sum(projection ** 2, axis=-1) / np.sum(noise ** 2, axis=-1)
85 | return 10 * np.log10(ratio)
86 |
87 |
88 | def STOI(ref, est, sr=16000):
89 | return stoi(ref, est, sr, extended=False)
90 |
91 |
92 | def WB_PESQ(ref, est, sr=16000):
93 | if sr != 16000:
94 | wb_ref = librosa.resample(ref, sr, 16000)
95 | wb_est = librosa.resample(est, sr, 16000)
96 | else:
97 | wb_ref = ref
98 | wb_est = est
99 | # pesq will not downsample internally
100 | return pesq(16000, wb_ref, wb_est, "wb")
101 |
102 |
103 | def NB_PESQ(ref, est, sr=16000):
104 | if sr != 8000:
105 | nb_ref = librosa.resample(ref, sr, 8000)
106 | nb_est = librosa.resample(est, sr, 8000)
107 | else:
108 | nb_ref = ref
109 | nb_est = est
110 | # nb_pesq downsample to 8000 internally.
111 | return nb_pesq(nb_ref, nb_est, 8000)
112 |
113 | mos_metrics = None
114 | def MOSNET(ref, est, sr=16000):
115 | ##
116 | global mos_metrics
117 | if mos_metrics is None:
118 | import speechmetrics
119 | window_length = 10 # seconds
120 | mos_metrics = speechmetrics.load('mosnet', window_length)
121 | ##
122 | scores = mos_metrics(est, rate=sr)
123 | avg_score = np.mean(scores["mosnet"])
124 | #print(avg_score)
125 | return avg_score
126 |
127 | # Only registered metrics can be used.
128 | REGISTERED_METRICS = {
129 | "SI_SDR": SI_SDR,
130 | "STOI": STOI,
131 | "WB_PESQ": WB_PESQ,
132 | "NB_PESQ": NB_PESQ,
133 | "MOSNET": MOSNET
134 | }
135 |
--------------------------------------------------------------------------------
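A minimal sketch of calling the registered metrics directly on in-memory signals (the waveforms here are synthetic; it assumes `speech_enhance` is on `PYTHONPATH` and the `pesq`/`pypesq`/`pystoi`/`mir_eval` dependencies from the README are installed, since importing the module pulls them in):

```python
import numpy as np

from audio_zen.metrics import SI_SDR, STOI

sr = 16000
rng = np.random.default_rng(0)
clean = rng.standard_normal(2 * sr).astype(np.float32)                # 2 s "reference" signal
noisy = clean + 0.1 * rng.standard_normal(2 * sr).astype(np.float32)  # degraded estimate

print("SI_SDR:", SI_SDR(clean, noisy))      # roughly 20 dB for this noise level
print("STOI:  ", STOI(clean, noisy, sr=sr))
```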
/speech_enhance/audio_zen/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/model/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/model/module/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/model/module/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/model/module/causal_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class CausalConvBlock(nn.Module):
6 | def __init__(self, in_channels, out_channels, encoder_activate_function, **kwargs):
7 | super().__init__()
8 | self.conv = nn.Conv2d(
9 | in_channels=in_channels,
10 | out_channels=out_channels,
11 | kernel_size=(3, 2),
12 | stride=(2, 1),
13 | padding=(0, 1),
14 |             **kwargs  # padding is (freq, time): 0 along the frequency axis, 1 on each side along the time axis
15 | )
16 | self.norm = nn.BatchNorm2d(out_channels)
17 | self.activation = getattr(nn, encoder_activate_function)()
18 |
19 | def forward(self, x):
20 | """
21 | 2D Causal convolution.
22 |
23 | Args:
24 | x: [B, C, F, T]
25 | Returns:
26 | [B, C, F, T]
27 | """
28 | x = self.conv(x)
29 | x = x[:, :, :, :-1] # chomp size
30 | x = self.norm(x)
31 | x = self.activation(x)
32 | return x
33 |
34 |
35 | class CausalTransConvBlock(nn.Module):
36 | def __init__(self, in_channels, out_channels, is_last=False, output_padding=(0, 0)):
37 | super().__init__()
38 | self.conv = nn.ConvTranspose2d(
39 | in_channels=in_channels,
40 | out_channels=out_channels,
41 | kernel_size=(3, 2),
42 | stride=(2, 1),
43 | output_padding=output_padding
44 | )
45 | self.norm = nn.BatchNorm2d(out_channels)
46 | if is_last:
47 | self.activation = nn.ReLU()
48 | else:
49 | self.activation = nn.ELU()
50 |
51 | def forward(self, x):
52 | """
53 | 2D Causal convolution.
54 |
55 | Args:
56 | x: [B, C, F, T]
57 | Returns:
58 | [B, C, F, T]
59 | """
60 | x = self.conv(x)
61 | x = x[:, :, :, :-1] # chomp size
62 | x = self.norm(x)
63 | x = self.activation(x)
64 | return x
65 |
66 |
67 | class TCNBlock(nn.Module):
68 | def __init__(self, in_channels=257, hidden_channel=512, out_channels=257, kernel_size=3, dilation=1,
69 | use_skip_connection=True, causal=False):
70 | super().__init__()
71 | self.conv1x1 = nn.Conv1d(in_channels, hidden_channel, 1)
72 | self.prelu1 = nn.PReLU()
73 | self.norm1 = nn.GroupNorm(1, hidden_channel, eps=1e-8)
74 | padding = (dilation * (kernel_size - 1)) // 2 if not causal else (
75 | dilation * (kernel_size - 1))
76 | self.depthwise_conv = nn.Conv1d(hidden_channel, hidden_channel, kernel_size=kernel_size, stride=1,
77 | groups=hidden_channel, padding=padding, dilation=dilation)
78 | self.prelu2 = nn.PReLU()
79 | self.norm2 = nn.GroupNorm(1, hidden_channel, eps=1e-8)
80 | self.sconv = nn.Conv1d(hidden_channel, out_channels, 1)
81 | # self.tcn_block = nn.Sequential(
82 | # nn.Conv1d(in_channels, hidden_channel, 1),
83 | # nn.PReLU(),
84 | # nn.GroupNorm(1, hidden_channel, eps=1e-8),
85 | # nn.Conv1d(hidden_channel, hidden_channel, kernel_size=kernel_size, stride=1,
86 | # groups=hidden_channel, padding=padding, dilation=dilation, bias=True),
87 | # nn.PReLU(),
88 | # nn.GroupNorm(1, hidden_channel, eps=1e-8),
89 | # nn.Conv1d(hidden_channel, out_channels, 1)
90 | # )
91 |
92 | self.causal = causal
93 | self.padding = padding
94 | self.use_skip_connection = use_skip_connection
95 |
96 | def forward(self, x):
97 | """
98 | x: [channels, T]
99 | """
100 | if self.use_skip_connection:
101 | y = self.conv1x1(x)
102 | y = self.norm1(self.prelu1(y))
103 | y = self.depthwise_conv(y)
104 | if self.causal:
105 | y = y[:, :, :-self.padding]
106 | y = self.norm2(self.prelu2(y))
107 | output = self.sconv(y)
108 | return x + output
109 | else:
110 | y = self.conv1x1(x)
111 | y = self.norm1(self.prelu1(y))
112 | y = self.depthwise_conv(y)
113 | if self.causal:
114 | y = y[:, :, :-self.padding]
115 | y = self.norm2(self.prelu2(y))
116 | output = self.sconv(y)
117 | return output
118 |
119 |
120 | class STCNBlock(nn.Module):
121 | def __init__(self, in_channels=257, hidden_channel=512, out_channels=257, kernel_size=3, dilation=1,
122 | use_skip_connection=True, causal=False):
123 | super().__init__()
124 | self.conv1x1 = nn.Conv1d(in_channels, hidden_channel, 1)
125 | self.prelu1 = nn.PReLU()
126 | self.norm1 = nn.GroupNorm(1, hidden_channel, eps=1e-8)
127 | padding = (dilation * (kernel_size - 1)) // 2 if not causal else (
128 | dilation * (kernel_size - 1))
129 | self.depthwise_conv = nn.Conv1d(hidden_channel, hidden_channel, kernel_size=kernel_size, stride=1,
130 | groups=hidden_channel, padding=padding, dilation=dilation)
131 | self.prelu2 = nn.PReLU()
132 | self.norm2 = nn.GroupNorm(1, hidden_channel, eps=1e-8)
133 | self.sconv = nn.Conv1d(hidden_channel, out_channels, 1)
134 | # self.tcn_block = nn.Sequential(
135 | # nn.Conv1d(in_channels, hidden_channel, 1),
136 | # nn.PReLU(),
137 | # nn.GroupNorm(1, hidden_channel, eps=1e-8),
138 | # nn.Conv1d(hidden_channel, hidden_channel, kernel_size=kernel_size, stride=1,
139 | # groups=hidden_channel, padding=padding, dilation=dilation, bias=True),
140 | # nn.PReLU(),
141 | # nn.GroupNorm(1, hidden_channel, eps=1e-8),
142 | # nn.Conv1d(hidden_channel, out_channels, 1)
143 | # )
144 |
145 | self.causal = causal
146 | self.padding = padding
147 | self.use_skip_connection = use_skip_connection
148 |
149 | def forward(self, x):
150 | """
151 | x: [channels, T]
152 | """
153 | if self.use_skip_connection:
154 | y = self.conv1x1(x)
155 | y = self.norm1(self.prelu1(y))
156 | y = self.depthwise_conv(y)
157 | if self.causal:
158 | y = y[:, :, :-self.padding]
159 | y = self.norm2(self.prelu2(y))
160 | output = self.sconv(y)
161 | return x + output
162 | else:
163 | y = self.conv1x1(x)
164 | y = self.norm1(self.prelu1(y))
165 | y = self.depthwise_conv(y)
166 | if self.causal:
167 | y = y[:, :, :-self.padding]
168 | y = self.norm2(self.prelu2(y))
169 | output = self.sconv(y)
170 | return output
171 |
172 |
173 | if __name__ == '__main__':
174 |     a = torch.rand(2, 1, 161, 200)  # [B, C, F, T]; CausalConvBlock hard-codes kernel/stride/padding
175 |     l1 = CausalConvBlock(1, 20, "ReLU")
176 |     l2 = CausalConvBlock(20, 40, "ReLU")
177 |     l3 = CausalConvBlock(40, 40, "ReLU")
178 |     l4 = CausalConvBlock(40, 40, "ReLU")
179 |     print(l1(a).shape)
180 |     print(l4(l3(l2(l1(a)))).shape)
181 |
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/model/module/feature_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def cumulative_norm(input):
6 | eps = 1e-10
7 |
8 | # [B, C, F, T]
9 | batch_size, n_channels, n_freqs, n_frames = input.size()
10 | device = input.device
11 | data_type = input.dtype
12 |
13 | input = input.reshape(batch_size * n_channels, n_freqs, n_frames)
14 |
15 | step_sum = torch.sum(input, dim=1) # [B, T]
16 | step_pow_sum = torch.sum(torch.square(input), dim=1)
17 |
18 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T]
19 | cumulative_pow_sum = torch.cumsum(step_pow_sum, dim=-1) # [B, T]
20 |
21 | entry_count = torch.arange(n_freqs, n_freqs * n_frames + 1, n_freqs, dtype=data_type, device=device)
22 | entry_count = entry_count.reshape(1, n_frames) # [1, T]
23 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T]
24 |
25 | cum_mean = cumulative_sum / entry_count # B, T
26 | cum_var = (cumulative_pow_sum - 2 * cum_mean * cumulative_sum) / entry_count + cum_mean.pow(2) # B, T
27 | cum_std = (cum_var + eps).sqrt() # B, T
28 |
29 | cum_mean = cum_mean.reshape(batch_size * n_channels, 1, n_frames)
30 | cum_std = cum_std.reshape(batch_size * n_channels, 1, n_frames)
31 |
32 | x = (input - cum_mean) / cum_std
33 | x = x.reshape(batch_size, n_channels, n_freqs, n_frames)
34 |
35 | return x
36 |
37 |
38 | class CumulativeMagSpectralNorm(nn.Module):
39 | def __init__(self, cumulative=False, use_mid_freq_mu=False):
40 | """
41 |
42 | Args:
43 | cumulative: 是否采用累积的方式计算 mu
44 | use_mid_freq_mu: 仅采用中心频率的 mu 来代替全局 mu
45 |
46 | Notes:
47 | 先算均值再累加 等同于 先累加再算均值
48 |
49 | """
50 | super().__init__()
51 | self.eps = 1e-6
52 | self.cumulative = cumulative
53 | self.use_mid_freq_mu = use_mid_freq_mu
54 |
55 | def forward(self, input):
56 |         assert input.ndim == 4, f"{self.__class__.__name__} only supports 4D input."
57 | batch_size, n_channels, n_freqs, n_frames = input.size()
58 | device = input.device
59 | data_type = input.dtype
60 |
61 | input = input.reshape(batch_size * n_channels, n_freqs, n_frames)
62 |
63 | if self.use_mid_freq_mu:
64 | step_sum = input[:, int(n_freqs // 2 - 1), :] # [B * C, F, T] => [B * C, T]
65 | else:
66 | step_sum = torch.mean(input, dim=1) # [B * C, F, T] => [B * C, T]
67 |
68 | if self.cumulative:
69 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T]
70 | entry_count = torch.arange(1, n_frames + 1, dtype=data_type, device=device)
71 | entry_count = entry_count.reshape(1, n_frames) # [1, T]
72 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T]
73 |
74 | mu = cumulative_sum / entry_count # [B * C, T]
75 | mu = mu.reshape(batch_size * n_channels, 1, n_frames)
76 | else:
77 | mu = torch.mean(step_sum, dim=-1) # [B * C]
78 | mu = mu.reshape(batch_size * n_channels, 1, 1) # [B * C, 1, 1]
79 |
80 | input_normed = input / (mu + self.eps)
81 | input_normed = input_normed.reshape(batch_size, n_channels, n_freqs, n_frames)
82 | return input_normed
83 |
84 |
85 | if __name__ == '__main__':
86 | a = torch.rand(2, 1, 160, 200)
87 | ln = CumulativeMagSpectralNorm(cumulative=False, use_mid_freq_mu=False)
88 | ln_1 = CumulativeMagSpectralNorm(cumulative=True, use_mid_freq_mu=False)
89 | ln_2 = CumulativeMagSpectralNorm(cumulative=False, use_mid_freq_mu=False)
90 | ln_3 = CumulativeMagSpectralNorm(cumulative=True, use_mid_freq_mu=False)
91 | print(ln(a).mean())
92 | print(ln_1(a).mean())
93 | print(ln_2(a).mean())
94 | print(ln_3(a).mean())
95 |
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/model/module/si_module.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 |
8 |
9 | class subband_interaction(nn.Module):
10 | def __init__(self, input_size, hidden_size):
11 | super(subband_interaction, self).__init__()
12 | """
13 | Subband Interaction Module
14 | """
15 |
16 | self.input_linear = nn.Sequential(
17 | nn.Linear(input_size, hidden_size),
18 | nn.PReLU()
19 | )
20 | self.mean_linear = nn.Sequential(
21 | nn.Linear(hidden_size, hidden_size),
22 | nn.PReLU()
23 | )
24 | self.output_linear = nn.Sequential(
25 | nn.Linear(hidden_size * 2, input_size),
26 | nn.PReLU()
27 | )
28 | self.norm = nn.GroupNorm(1, input_size)
29 |
30 | def forward(self, input):
31 | """
32 | input: [B, F, F_s, T]
33 | """
34 | B, G, N, T = input.shape
35 |
36 | # Transform
37 | group_input = input # [B, F, F_s, T]
38 | group_input = group_input.permute(0, 3, 1, 2).contiguous().view(-1, N) # [B * T * F, F_s]
39 | group_output = self.input_linear(group_input).view(B, T, G, -1) # [B, T, F, H]
40 |
41 | # Avg pooling
42 | group_mean = group_output.mean(2).view(B * T, -1) # [B * T, H]
43 |
44 |         # Concatenate and transform
45 | group_output = group_output.view(B * T, G, -1) # [B * T, F, H]
46 | group_mean = self.mean_linear(group_mean).unsqueeze(1).expand_as(group_output).contiguous() # [B * T, F, H]
47 | group_output = torch.cat([group_output, group_mean], 2) # [B * T, F, 2H]
48 | group_output = self.output_linear(group_output.view(-1, group_output.shape[-1])) # [B * T * F, F_s]
49 | group_output = group_output.view(B, T, G, -1).permute(0, 2, 3, 1).contiguous() # [B, F, F_s, T]
50 | group_output = self.norm(group_output.view(B * G, N, T)) # [B * F, F_s, T]
51 | output = input + group_output.view(input.shape) # [B, F, F_s, T]
52 |
53 | return output
54 |
--------------------------------------------------------------------------------
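A minimal shape sketch for the module above (the sizes are illustrative: 257 sub-bands of 31 bins each correspond to `sb_num_neighbors = 15` in the configs, while `hidden_size` is an arbitrary choice here; it assumes `speech_enhance` is on `PYTHONPATH`):

```python
import torch

from audio_zen.model.module.si_module import subband_interaction

si = subband_interaction(input_size=31, hidden_size=24)

x = torch.rand(2, 257, 31, 100)  # [B, F, F_s, T]: 257 sub-bands, 31 bins each, 100 frames
y = si(x)

# The residual connection keeps the input shape.
print(y.shape)  # torch.Size([2, 257, 31, 100])
```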
/speech_enhance/audio_zen/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/audio_zen/trainer/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/audio_zen/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import time
4 | from copy import deepcopy
5 | from functools import reduce
6 |
7 | import torch
8 |
9 | # for log
10 | from utils.logger import log
11 | print=log
12 |
13 | def load_checkpoint(checkpoint_path, device):
14 | _, ext = os.path.splitext(os.path.basename(checkpoint_path))
15 |     assert ext in (".pth", ".tar"), "Only .pth and .tar checkpoint extensions are supported."
16 | model_checkpoint = torch.load(os.path.abspath(os.path.expanduser(checkpoint_path)), map_location=device)
17 |
18 | if ext == ".pth":
19 | print(f"Loading {checkpoint_path}.")
20 | return model_checkpoint
21 | else: # load tar
22 | print(f"Loading {checkpoint_path}, epoch = {model_checkpoint['epoch']}.")
23 | return model_checkpoint["l1"]
24 |
25 |
26 | def prepare_empty_dir(dirs, resume=False):
27 | """
28 | if resume the experiment, assert the dirs exist. If not the resume experiment, set up new dirs.
29 |
30 | Args:
31 | dirs (list): directors list
32 | resume (bool): whether to resume experiment, default is False
33 | """
34 | for dir_path in dirs:
35 | if resume:
36 |             assert dir_path.exists(), "In resume mode, you must have an existing experiment dir."
37 | else:
38 | dir_path.mkdir(parents=True, exist_ok=True)
39 |
40 |
41 | def check_nan(tensor, key=""):
42 | if torch.sum(torch.isnan(tensor)) > 0:
43 | print(f"Found NaN in {key}")
44 |
45 |
46 | class ExecutionTime:
47 | """
48 | Count execution time.
49 |
50 | Examples:
51 | timer = ExecutionTime()
52 | ...
53 | print(f"Finished in {timer.duration()} seconds.")
54 | """
55 |
56 | def __init__(self):
57 | self.start_time = time.time()
58 |
59 | def duration(self):
60 | return int(time.time() - self.start_time)
61 |
62 |
63 | def initialize_module(path: str, args: dict = None, initialize: bool = True):
64 | """
65 | Load module dynamically with "args".
66 |
67 | Args:
68 | path: module path in this project.
69 | args: parameters that passes to the Class or the Function in the module.
70 | initialize: initialize the Class or the Function with args.
71 |
72 | Examples:
73 | Config items are as follows:
74 |
75 | [model]
76 | path = "model.full_sub_net.FullSubNetModel"
77 | [model.args]
78 | n_frames = 32
79 | ...
80 |
81 | This function will:
82 | 1. Load the "model.full_sub_net" module.
83 | 2. Call "FullSubNetModel" Class (or Function) in "model.full_sub_net" module.
84 | 3. If initialize is True:
85 | instantiate (or call) the Class (or the Function) and pass the parameters (in "[model.args]") to it.
86 | """
87 | module_path = ".".join(path.split(".")[:-1])
88 | class_or_function_name = path.split(".")[-1]
89 |
90 | module = importlib.import_module(module_path)
91 | class_or_function = getattr(module, class_or_function_name)
92 |
93 | if initialize:
94 | if args:
95 | return class_or_function(**args)
96 | else:
97 | return class_or_function()
98 | else:
99 | return class_or_function
100 |
101 |
102 | def print_tensor_info(tensor, flag="Tensor"):
103 | def floor_tensor(float_tensor):
104 | return int(float(float_tensor) * 1000) / 1000
105 |
106 | print(
107 | f"{flag}\n"
108 | f"\t"
109 | f"max: {floor_tensor(torch.max(tensor))}, min: {float(torch.min(tensor))}, "
110 | f"mean: {floor_tensor(torch.mean(tensor))}, std: {floor_tensor(torch.std(tensor))}")
111 |
112 |
113 | def set_requires_grad(nets, requires_grad=False):
114 | """
115 | Args:
116 | nets: list of networks
117 | requires_grad
118 | """
119 | if not isinstance(nets, list):
120 | nets = [nets]
121 | for net in nets:
122 | if net is not None:
123 | for param in net.parameters():
124 | param.requires_grad = requires_grad
125 |
126 |
127 | def merge_config(*config_dicts):
128 | """
129 | Deep merge configuration dicts.
130 |
131 | Args:
132 | *config_dicts: any number of configuration dicts.
133 |
134 | Notes:
135 |         1. Values in the later configuration dict(s) update the ones in the former dict(s).
136 |         2. Every key in a later dict must already exist in the former dict, i.e. the first dict must contain all keys.
137 |
138 | Examples:
139 |         a = {
140 |             "a": 1,
141 |             "b": 2,
142 |             "c": {
143 |                 "d": 1
144 |             }
145 |         }
146 |         b = {
147 |             "a": 2,
148 |             "b": 2,
149 |             "c": {
150 |                 "e": 1
151 |             }
152 |         }
153 |         c = merge_config(a, b)
154 |         c = {
155 |             "a": 2,
156 |             "b": 2,
157 |             "c": {
158 |                 "d": 1,
159 |                 "e": 1
160 |             }
161 |         }
162 |
163 | Returns:
164 | New deep-copied configuration dict.
165 | """
166 |
167 | def merge(older_dict, newer_dict):
168 | for new_key in newer_dict:
169 | if new_key not in older_dict:
170 | # Checks items in custom config must be within common config
171 | raise KeyError(f"Key {new_key} is not exist in the common config.")
172 |
173 | if isinstance(older_dict[new_key], dict):
174 | older_dict[new_key] = merge(older_dict[new_key], newer_dict[new_key])
175 | else:
176 | older_dict[new_key] = deepcopy(newer_dict[new_key])
177 |
178 | return older_dict
179 |
180 | return reduce(merge, config_dicts[1:], deepcopy(config_dicts[0]))
181 |
182 |
183 | def prepare_device(n_gpu: int, keep_reproducibility=False):
184 | """
185 | Choose to use CPU or GPU depend on the value of "n_gpu".
186 |
187 | Args:
188 | n_gpu(int): the number of GPUs used in the experiment. if n_gpu == 0, use CPU; if n_gpu >= 1, use GPU.
189 | keep_reproducibility (bool): if we need to consider the repeatability of experiment, set keep_reproducibility to True.
190 |
191 | See Also
192 | Reproducibility: https://pytorch.org/docs/stable/notes/randomness.html
193 | """
194 | if n_gpu == 0:
195 | print("Using CPU in the experiment.")
196 | device = torch.device("cpu")
197 | else:
198 | # possibly at the cost of reduced performance
199 | if keep_reproducibility:
200 | print("Using CuDNN deterministic mode in the experiment.")
201 | torch.backends.cudnn.benchmark = False # ensures that CUDA selects the same convolution algorithm each time
202 | torch.set_deterministic(True) # configures PyTorch only to use deterministic implementation
203 | else:
204 | # causes cuDNN to benchmark multiple convolution algorithms and select the fastest
205 | torch.backends.cudnn.benchmark = True
206 |
207 | device = torch.device("cuda:0")
208 |
209 | return device
210 |
211 |
212 | def expand_path(path):
213 | return os.path.abspath(os.path.expanduser(path))
214 |
215 |
216 | def basename(path):
217 | filename, ext = os.path.splitext(os.path.basename(path))
218 | return filename, ext
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet/dataset/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet/dataset/dataset_inference.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import librosa
4 | import numpy as np
5 |
6 | from audio_zen.dataset.base_dataset import BaseDataset
7 | from audio_zen.utils import basename
8 |
9 |
10 | class Dataset(BaseDataset):
11 | def __init__(self,
12 | dataset_dir_list,
13 | sr,
14 | ):
15 | """
16 | Args:
17 | noisy_dataset_dir_list (str or list): noisy dir or noisy dir list
18 | """
19 | super().__init__()
20 | assert isinstance(dataset_dir_list, list)
21 | self.sr = sr
22 |
23 | noisy_file_path_list = []
24 | for dataset_dir in dataset_dir_list:
25 | dataset_dir = Path(dataset_dir).expanduser().absolute()
26 | noisy_file_path_list += librosa.util.find_files(dataset_dir.as_posix()) # Sorted
27 |
28 | self.noisy_file_path_list = noisy_file_path_list
29 | self.length = len(self.noisy_file_path_list)
30 |
31 | def __len__(self):
32 | return self.length
33 |
34 | def __getitem__(self, item):
35 | noisy_file_path = self.noisy_file_path_list[item]
36 | noisy_y = librosa.load(noisy_file_path, sr=self.sr)[0]
37 | noisy_y = noisy_y.astype(np.float32)
38 |
39 | return noisy_y, basename(noisy_file_path)[0]
40 |
--------------------------------------------------------------------------------
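A usage sketch for the inference `Dataset` above, assuming `speech_enhance` is on `PYTHONPATH` so the module imports as `fullsubnet.dataset.dataset_inference`; the directory path and the 16 kHz sample rate are placeholders:

```python
from torch.utils.data import DataLoader

from fullsubnet.dataset.dataset_inference import Dataset

# Placeholder directory containing noisy *.wav clips; sr=16000 is an assumption
dataset = Dataset(dataset_dir_list=["/path/to/noisy_wavs"], sr=16000)
loader = DataLoader(dataset, batch_size=1, num_workers=0)  # batch_size=1: clips have different lengths

for noisy_y, stem in loader:
    # noisy_y: [1, num_samples] float32 waveform; stem: filename without extension
    print(stem[0], noisy_y.shape)
```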
/speech_enhance/fullsubnet/dataset/dataset_train.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | from audio_zen.acoustics.feature import norm_amplitude, tailor_dB_FS, is_clipped, load_wav, subsample
5 | from audio_zen.dataset.base_dataset import BaseDataset
6 | from audio_zen.utils import expand_path
7 | from joblib import Parallel, delayed
8 | from scipy import signal
9 | from tqdm import tqdm
10 |
11 |
12 | class Dataset(BaseDataset):
13 | def __init__(self,
14 | clean_dataset,
15 | clean_dataset_limit,
16 | clean_dataset_offset,
17 | noise_dataset,
18 | noise_dataset_limit,
19 | noise_dataset_offset,
20 | rir_dataset,
21 | rir_dataset_limit,
22 | rir_dataset_offset,
23 | snr_range,
24 | reverb_proportion,
25 | silence_length,
26 | target_dB_FS,
27 | target_dB_FS_floating_value,
28 | sub_sample_length,
29 | sr,
30 | pre_load_clean_dataset,
31 | pre_load_noise,
32 | pre_load_rir,
33 | num_workers
34 | ):
35 | """
36 | Dynamic mixing for training
37 |
38 | Args:
39 | clean_dataset_limit:
40 | clean_dataset_offset:
41 | noise_dataset_limit:
42 | noise_dataset_offset:
43 | rir_dataset:
44 | rir_dataset_limit:
45 | rir_dataset_offset:
46 | snr_range:
47 | reverb_proportion:
48 | clean_dataset: scp file
49 | noise_dataset: scp file
50 | sub_sample_length:
51 | sr:
52 | """
53 | super().__init__()
54 | # acoustics args
55 | self.sr = sr
56 |
57 | # parallel args
58 | self.num_workers = num_workers
59 |
60 | clean_dataset_list = [line.rstrip('\n') for line in open(expand_path(clean_dataset), "r")]
61 | noise_dataset_list = [line.rstrip('\n') for line in open(expand_path(noise_dataset), "r")]
62 | rir_dataset_list = [line.rstrip('\n') for line in open(expand_path(rir_dataset), "r")]
63 |
64 | clean_dataset_list = self._offset_and_limit(clean_dataset_list, clean_dataset_offset, clean_dataset_limit)
65 | noise_dataset_list = self._offset_and_limit(noise_dataset_list, noise_dataset_offset, noise_dataset_limit)
66 | rir_dataset_list = self._offset_and_limit(rir_dataset_list, rir_dataset_offset, rir_dataset_limit)
67 |
68 | if pre_load_clean_dataset:
69 | clean_dataset_list = self._preload_dataset(clean_dataset_list, remark="Clean Dataset")
70 |
71 | if pre_load_noise:
72 | noise_dataset_list = self._preload_dataset(noise_dataset_list, remark="Noise Dataset")
73 |
74 | if pre_load_rir:
75 | rir_dataset_list = self._preload_dataset(rir_dataset_list, remark="RIR Dataset")
76 |
77 | self.clean_dataset_list = clean_dataset_list
78 | self.noise_dataset_list = noise_dataset_list
79 | self.rir_dataset_list = rir_dataset_list
80 |
81 | snr_list = self._parse_snr_range(snr_range)
82 | self.snr_list = snr_list
83 |
84 | assert 0 <= reverb_proportion <= 1, "reverberation proportion should be in [0, 1]"
85 | self.reverb_proportion = reverb_proportion
86 | self.silence_length = silence_length
87 | self.target_dB_FS = target_dB_FS
88 | self.target_dB_FS_floating_value = target_dB_FS_floating_value
89 | self.sub_sample_length = sub_sample_length
90 |
91 | self.length = len(self.clean_dataset_list)
92 |
93 | def __len__(self):
94 | return self.length
95 |
96 | def _preload_dataset(self, file_path_list, remark=""):
97 | waveform_list = Parallel(n_jobs=self.num_workers)(
98 | delayed(load_wav)(f_path) for f_path in tqdm(file_path_list, desc=remark)
99 | )
100 | return list(zip(file_path_list, waveform_list))
101 |
102 | @staticmethod
103 | def _random_select_from(dataset_list):
104 | return random.choice(dataset_list)
105 |
106 | def _select_noise_y(self, target_length):
107 | noise_y = np.zeros(0, dtype=np.float32)
108 | silence = np.zeros(int(self.sr * self.silence_length), dtype=np.float32)
109 | remaining_length = target_length
110 |
111 | while remaining_length > 0:
112 | noise_file = self._random_select_from(self.noise_dataset_list)
113 | noise_new_added = load_wav(noise_file, sr=self.sr)
114 | noise_y = np.append(noise_y, noise_new_added)
115 | remaining_length -= len(noise_new_added)
116 |
117 |             # If more noise is still needed, insert a short silence segment before the next noise clip
118 | if remaining_length > 0:
119 | silence_len = min(remaining_length, len(silence))
120 | noise_y = np.append(noise_y, silence[:silence_len])
121 | remaining_length -= silence_len
122 |
123 | if len(noise_y) > target_length:
124 | idx_start = np.random.randint(len(noise_y) - target_length)
125 | noise_y = noise_y[idx_start:idx_start + target_length]
126 |
127 | return noise_y
128 |
129 | @staticmethod
130 | def snr_mix(clean_y, noise_y, snr, target_dB_FS, target_dB_FS_floating_value, rir=None, eps=1e-6):
131 | """
132 |         Mix noise with clean speech. When the rir argument is not None, apply reverberation to the clean speech first.
133 |
134 | Args:
135 |             clean_y: clean speech
136 |             noise_y: noise
137 |             snr (int): signal-to-noise ratio in dB
138 |             target_dB_FS (int):
139 |             target_dB_FS_floating_value (int):
140 |             rir: room impulse response, None or np.ndarray
141 | eps: eps
142 |
143 | Returns:
144 | (noisy_y,clean_y)
145 | """
146 | if rir is not None:
147 | if rir.ndim > 1:
148 | rir_idx = np.random.randint(0, rir.shape[0])
149 | rir = rir[rir_idx, :]
150 |
151 | clean_y = signal.fftconvolve(clean_y, rir)[:len(clean_y)]
152 |
153 | clean_y, _ = norm_amplitude(clean_y)
154 | clean_y, _, _ = tailor_dB_FS(clean_y, target_dB_FS)
155 | clean_rms = (clean_y ** 2).mean() ** 0.5
156 |
157 | noise_y, _ = norm_amplitude(noise_y)
158 | noise_y, _, _ = tailor_dB_FS(noise_y, target_dB_FS)
159 | noise_rms = (noise_y ** 2).mean() ** 0.5
160 |
161 | snr_scalar = clean_rms / (10 ** (snr / 20)) / (noise_rms + eps)
162 | noise_y *= snr_scalar
163 | noisy_y = clean_y + noise_y
164 |
165 | # Randomly select RMS value of dBFS between -15 dBFS and -35 dBFS and normalize noisy speech with that value
166 | noisy_target_dB_FS = np.random.randint(
167 | target_dB_FS - target_dB_FS_floating_value,
168 | target_dB_FS + target_dB_FS_floating_value
169 | )
170 |
171 |         # Rescale both signals using the RMS of the noisy speech
172 | noisy_y, _, noisy_scalar = tailor_dB_FS(noisy_y, noisy_target_dB_FS)
173 | clean_y *= noisy_scalar
174 |
175 |         # Synthesizing the noisy speech may occasionally cause clipping
176 |         # In that case, slightly rescale noisy_y and clean_y
177 |         if is_clipped(noisy_y):
178 |             noisy_y_scalar = np.max(np.abs(noisy_y)) / (0.99 - eps)  # scale factor that brings the peak just below 1.0
179 | noisy_y = noisy_y / noisy_y_scalar
180 | clean_y = clean_y / noisy_y_scalar
181 |
182 | return noisy_y, clean_y
183 |
184 | def __getitem__(self, item):
185 | clean_file = self.clean_dataset_list[item]
186 | clean_y = load_wav(clean_file, sr=self.sr)
187 | clean_y = subsample(clean_y, sub_sample_length=int(self.sub_sample_length * self.sr))
188 |
189 | noise_y = self._select_noise_y(target_length=len(clean_y))
190 | assert len(clean_y) == len(noise_y), f"Inequality: {len(clean_y)} {len(noise_y)}"
191 |
192 | snr = self._random_select_from(self.snr_list)
193 | use_reverb = bool(np.random.random(1) < self.reverb_proportion)
194 |
195 | noisy_y, clean_y = self.snr_mix(
196 | clean_y=clean_y,
197 | noise_y=noise_y,
198 | snr=snr,
199 | target_dB_FS=self.target_dB_FS,
200 | target_dB_FS_floating_value=self.target_dB_FS_floating_value,
201 | rir=load_wav(self._random_select_from(self.rir_dataset_list), sr=self.sr) if use_reverb else None
202 | )
203 |
204 | noisy_y = noisy_y.astype(np.float32)
205 | clean_y = clean_y.astype(np.float32)
206 |
207 | return noisy_y, clean_y
208 |
--------------------------------------------------------------------------------
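The core of `snr_mix` above is the RMS-based scaling of the noise: after both signals are normalized to the same dBFS level, the noise is multiplied by `clean_rms / 10**(snr/20) / noise_rms` so that the mixture reaches the requested SNR. A self-contained numerical check of that step (random signals stand in for real speech and noise):

```python
import numpy as np

rng = np.random.default_rng(0)
clean_y = rng.standard_normal(16000).astype(np.float32)   # stand-in for clean speech
noise_y = rng.standard_normal(16000).astype(np.float32)   # stand-in for noise
snr, eps = 5, 1e-6                                         # target SNR in dB

clean_rms = np.sqrt((clean_y ** 2).mean())
noise_rms = np.sqrt((noise_y ** 2).mean())

# Same scaling rule as in snr_mix: scale the noise so 20*log10(clean_rms/noise_rms) == snr
snr_scalar = clean_rms / (10 ** (snr / 20)) / (noise_rms + eps)
noise_y = noise_y * snr_scalar

achieved_snr = 20 * np.log10(clean_rms / np.sqrt((noise_y ** 2).mean()))
print(round(float(achieved_snr), 2))  # ~5.0 dB
```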
/speech_enhance/fullsubnet/dataset/dataset_validation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import librosa
5 |
6 | from audio_zen.dataset.base_dataset import BaseDataset
7 | from audio_zen.acoustics.utils import load_wav
8 | from audio_zen.utils import basename
9 |
10 |
11 | class Dataset(BaseDataset):
12 | def __init__(
13 | self,
14 | dataset_dir_list,
15 | sr,
16 | ):
17 | """
18 | Construct DNS validation set
19 |
20 | synthetic/
21 | with_reverb/
22 | noisy/
23 | clean_y/
24 | no_reverb/
25 | noisy/
26 | clean_y/
27 | """
28 | super(Dataset, self).__init__()
29 | noisy_files_list = []
30 |
31 | for dataset_dir in dataset_dir_list:
32 | dataset_dir = Path(dataset_dir).expanduser().absolute()
33 | noisy_files_list += librosa.util.find_files((dataset_dir / "noisy").as_posix())
34 |
35 | self.length = len(noisy_files_list)
36 | self.noisy_files_list = noisy_files_list
37 | self.sr = sr
38 |
39 | def __len__(self):
40 | return self.length
41 |
42 | def __getitem__(self, item):
43 | """
44 | use the absolute path of the noisy speech to find the corresponding clean speech.
45 |
46 | Notes
47 | with_reverb and no_reverb dirs have same-named files.
48 |             If we only used `basename`, files with the same name would overwrite (cover) each other in visualization.
49 |
50 | Returns:
51 | noisy: [waveform...], clean: [waveform...], type: [reverb|no_reverb] + name
52 | """
53 | noisy_file_path = self.noisy_files_list[item]
54 | parent_dir = Path(noisy_file_path).parents[1].name
55 | noisy_filename, _ = basename(noisy_file_path)
56 |
57 | reverb_remark = "" # When the speech comes from reverb_dir, insert "with_reverb" before the filename
58 |
59 |         # speech_type must match the validation stage; it is used to distinguish samples in later visualization
60 | if parent_dir == "with_reverb":
61 | speech_type = "With_reverb"
62 | elif parent_dir == "no_reverb":
63 | speech_type = "No_reverb"
64 | elif parent_dir == "dns_2_non_english":
65 | speech_type = "Non_english"
66 | elif parent_dir == "dns_2_emotion":
67 | speech_type = "Emotion"
68 | elif parent_dir == "dns_2_singing":
69 | speech_type = "Singing"
70 | else:
71 |             raise NotImplementedError(f"Unsupported dir: {parent_dir}")
72 |
73 | # Find the corresponding clean speech using "parent_dir" and "file_id"
74 | file_id = noisy_filename.split("_")[-1]
75 | if parent_dir in ("dns_2_emotion", "dns_2_singing"):
76 | # e.g., synthetic_emotion_1792_snr19_tl-35_fileid_19 => synthetic_emotion_clean_fileid_15
77 | clean_filename = f"synthetic_{speech_type.lower()}_clean_fileid_{file_id}"
78 | elif parent_dir == "dns_2_non_english":
79 | # e.g., synthetic_german_collection044_14_-04_CFQQgBvv2xQ_snr8_tl-21_fileid_121 => synthetic_clean_fileid_121
80 | clean_filename = f"synthetic_clean_fileid_{file_id}"
81 | else:
82 | # e.g., clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300 => clean_fileid_300
83 | if parent_dir == "with_reverb":
84 | reverb_remark = "with_reverb"
85 | clean_filename = f"clean_fileid_{file_id}"
86 |
87 | clean_file_path = noisy_file_path.replace(f"noisy/{noisy_filename}", f"clean/{clean_filename}")
88 |
89 | noisy = load_wav(os.path.abspath(os.path.expanduser(noisy_file_path)), sr=self.sr)
90 | clean = load_wav(os.path.abspath(os.path.expanduser(clean_file_path)), sr=self.sr)
91 |
92 | return noisy, clean, reverb_remark + noisy_filename, speech_type
93 |
--------------------------------------------------------------------------------
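The validation `Dataset` above never stores clean paths; it derives them from the noisy filename by keeping only the trailing file id. A small illustration of that mapping for the standard DNS `no_reverb` case (the directory prefix is a placeholder; the filename comes from the docstring example above):

```python
noisy_filename = "clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300"
file_id = noisy_filename.split("_")[-1]            # "300"
clean_filename = f"clean_fileid_{file_id}"         # "clean_fileid_300"

noisy_file_path = f"/data/dns/synthetic/no_reverb/noisy/{noisy_filename}.wav"  # placeholder path
clean_file_path = noisy_file_path.replace(f"noisy/{noisy_filename}", f"clean/{clean_filename}")
print(clean_file_path)  # /data/dns/synthetic/no_reverb/clean/clean_fileid_300.wav
```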
/speech_enhance/fullsubnet/inferencer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet/inferencer/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet/inferencer/inferencer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 |
4 | from audio_zen.acoustics.feature import mag_phase
5 | from audio_zen.acoustics.mask import decompress_cIRM
6 | from audio_zen.inferencer.base_inferencer import BaseInferencer
7 |
8 | # for log
9 | from utils.logger import log
10 | print=log
11 |
12 | def cumulative_norm(input):
13 | eps = 1e-10
14 | device = input.device
15 | data_type = input.dtype
16 | n_dim = input.ndim
17 |
18 | assert n_dim in (3, 4)
19 |
20 | if n_dim == 3:
21 | n_channels = 1
22 | batch_size, n_freqs, n_frames = input.size()
23 | else:
24 | batch_size, n_channels, n_freqs, n_frames = input.size()
25 | input = input.reshape(batch_size * n_channels, n_freqs, n_frames)
26 |
27 | step_sum = torch.sum(input, dim=1) # [B, T]
28 | step_pow_sum = torch.sum(torch.square(input), dim=1)
29 |
30 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T]
31 | cumulative_pow_sum = torch.cumsum(step_pow_sum, dim=-1) # [B, T]
32 |
33 | entry_count = torch.arange(n_freqs, n_freqs * n_frames + 1, n_freqs, dtype=data_type, device=device)
34 | entry_count = entry_count.reshape(1, n_frames) # [1, T]
35 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T]
36 |
37 | cum_mean = cumulative_sum / entry_count # B, T
38 | cum_var = (cumulative_pow_sum - 2 * cum_mean * cumulative_sum) / entry_count + cum_mean.pow(2) # B, T
39 | cum_std = (cum_var + eps).sqrt() # B, T
40 |
41 | cum_mean = cum_mean.reshape(batch_size * n_channels, 1, n_frames)
42 | cum_std = cum_std.reshape(batch_size * n_channels, 1, n_frames)
43 |
44 | x = (input - cum_mean) / cum_std
45 |
46 | if n_dim == 4:
47 | x = x.reshape(batch_size, n_channels, n_freqs, n_frames)
48 |
49 | return x
50 |
51 |
52 | class Inferencer(BaseInferencer):
53 | def __init__(self, config, checkpoint_path, output_dir):
54 | super().__init__(config, checkpoint_path, output_dir)
55 |
56 | @torch.no_grad()
57 | def mag(self, noisy, inference_args):
58 | noisy_complex = self.torch_stft(noisy)
59 | noisy_mag, noisy_phase = mag_phase(noisy_complex) # [B, F, T] => [B, 1, F, T]
60 |
61 | enhanced_mag = self.model(noisy_mag.unsqueeze(1)).squeeze(1)
62 |
63 | enhanced = self.torch_istft((enhanced_mag, noisy_phase), length=noisy.size(-1), use_mag_phase=True)
64 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
65 |
66 | return enhanced
67 |
68 | @torch.no_grad()
69 | def scaled_mask(self, noisy, inference_args):
70 | noisy_complex = self.torch_stft(noisy)
71 | noisy_mag, noisy_phase = mag_phase(noisy_complex)
72 |
73 | # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2]
74 | noisy_mag = noisy_mag.unsqueeze(1)
75 | scaled_mask = self.model(noisy_mag)
76 | scaled_mask = scaled_mask.permute(0, 2, 3, 1)
77 |
78 | enhanced_complex = noisy_complex * scaled_mask
79 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1), use_mag_phase=False)
80 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
81 |
82 | return enhanced
83 |
84 | @torch.no_grad()
85 | def sub_band_crm_mask(self, noisy, inference_args):
86 | pad_mode = inference_args["pad_mode"]
87 | n_neighbor = inference_args["n_neighbor"]
88 |
89 | noisy = noisy.cpu().numpy().reshape(-1)
90 | noisy_D = self.librosa_stft(noisy)
91 |
92 | noisy_real = torch.tensor(noisy_D.real, device=self.device)
93 | noisy_imag = torch.tensor(noisy_D.imag, device=self.device)
94 | noisy_mag = torch.sqrt(torch.square(noisy_real) + torch.square(noisy_imag)) # [F, T]
95 | n_freqs, n_frames = noisy_mag.size()
96 |
97 | noisy_mag = noisy_mag.reshape(1, 1, n_freqs, n_frames)
98 | noisy_mag_padded = self._unfold(noisy_mag, pad_mode, n_neighbor) # [B, N, C, F_s, T] <=> [1, 257, 1, 31, T]
99 | noisy_mag_padded = noisy_mag_padded.squeeze(0).squeeze(1) # [257, 31, 200] <=> [B, F_s, T]
100 |
101 | pred_crm = self.model(noisy_mag_padded).detach() # [B, 2, T] <=> [F, 2, T]
102 | pred_crm = pred_crm.permute(0, 2, 1).contiguous() # [B, T, 2]
103 |
104 | lim = 9.99
105 | pred_crm = lim * (pred_crm >= lim) - lim * (pred_crm <= -lim) + pred_crm * (torch.abs(pred_crm) < lim)
106 | pred_crm = -10 * torch.log((10 - pred_crm) / (10 + pred_crm))
107 |
108 | enhanced_real = pred_crm[:, :, 0] * noisy_real - pred_crm[:, :, 1] * noisy_imag
109 | enhanced_imag = pred_crm[:, :, 1] * noisy_real + pred_crm[:, :, 0] * noisy_imag
110 |
111 | enhanced_real = enhanced_real.cpu().numpy()
112 | enhanced_imag = enhanced_imag.cpu().numpy()
113 | enhanced = self.librosa_istft(enhanced_real + 1j * enhanced_imag, length=len(noisy))
114 | return enhanced
115 |
116 | @torch.no_grad()
117 | def full_band_crm_mask(self, noisy, inference_args):
118 | noisy_complex = self.torch_stft(noisy)
119 | noisy_mag, _ = mag_phase(noisy_complex)
120 |
121 | noisy_mag = noisy_mag.unsqueeze(1)
122 | t1 = time.time()
123 | pred_crm = self.model(noisy_mag)
124 | t2 = time.time()
125 | pred_crm = pred_crm.permute(0, 2, 3, 1)
126 |
127 | pred_crm = decompress_cIRM(pred_crm)
128 | enhanced_real = pred_crm[..., 0] * noisy_complex.real - pred_crm[..., 1] * noisy_complex.imag
129 | enhanced_imag = pred_crm[..., 1] * noisy_complex.real + pred_crm[..., 0] * noisy_complex.imag
130 | enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1)
131 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1))
132 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
133 |
134 | #
135 | rtf = (t2 - t1) / (len(enhanced) * 1.0 / self.acoustic_config["sr"])
136 | print(f"model rtf: {rtf}")
137 |
138 | return enhanced
139 |
140 | @torch.no_grad()
141 | def overlapped_chunk(self, noisy, inference_args):
142 | sr = self.acoustic_config["sr"]
143 |
144 | noisy = noisy.squeeze(0)
145 |
146 | num_mics = 8
147 | chunk_length = sr * inference_args["chunk_length"]
148 | chunk_hop_length = chunk_length // 2
149 | num_chunks = int(noisy.shape[-1] / chunk_hop_length) + 1
150 |
151 | win = torch.hann_window(chunk_length, device=noisy.device)
152 |
153 | prev = None
154 | enhanced = None
155 |         # Simulate a leading silent segment so the model does not start directly on speech, which it handles poorly
156 | for chunk_idx in range(num_chunks):
157 | if chunk_idx == 0:
158 | pad = torch.zeros((num_mics, 256), device=noisy.device)
159 |
160 | chunk_start_position = chunk_idx * chunk_hop_length
161 | chunk_end_position = chunk_start_position + chunk_length
162 |
163 | # concat([(8, 256), (..., ... + chunk_length)])
164 | noisy_chunk = torch.cat((pad, noisy[:, chunk_start_position:chunk_end_position]), dim=1)
165 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0))
166 | enhanced_chunk = torch.squeeze(enhanced_chunk)
167 | enhanced_chunk = enhanced_chunk[256:]
168 |
169 |                 # Save the first half of the current chunk
170 | cur = enhanced_chunk[:chunk_length // 2]
171 |
172 |                 # Only for the 1st chunk: the first half of the very first chunk has no overlap
173 | prev = enhanced_chunk[chunk_length // 2:] * win[chunk_length // 2:]
174 | else:
175 | # use the previous noisy data as the pad
176 | pad = noisy[:, (chunk_idx * chunk_hop_length - 256):(chunk_idx * chunk_hop_length)]
177 |
178 | chunk_start_position = chunk_idx * chunk_hop_length
179 | chunk_end_position = chunk_start_position + chunk_length
180 |
181 | noisy_chunk = torch.cat((pad, noisy[:8, chunk_start_position:chunk_end_position]), dim=1)
182 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0))
183 | enhanced_chunk = torch.squeeze(enhanced_chunk)
184 | enhanced_chunk = enhanced_chunk[256:]
185 |
186 |                 # Use this window function to smooth the overlap at the concatenation boundary?
187 | enhanced_chunk = enhanced_chunk * win[:len(enhanced_chunk)]
188 |
189 | tmp = enhanced_chunk[:chunk_length // 2]
190 | cur = tmp[:min(len(tmp), len(prev))] + prev[:min(len(tmp), len(prev))]
191 | prev = enhanced_chunk[chunk_length // 2:]
192 |
193 | if enhanced is None:
194 | enhanced = cur
195 | else:
196 | enhanced = torch.cat((enhanced, cur), dim=0)
197 |
198 | enhanced = enhanced[:noisy.shape[1]]
199 | return enhanced.detach().squeeze(0).cpu().numpy()
200 |
201 | @torch.no_grad()
202 | def time_domain(self, noisy, inference_args):
203 | noisy = noisy.to(self.device)
204 | enhanced = self.model(noisy)
205 | return enhanced.detach().squeeze().cpu().numpy()
206 |
207 |
208 | if __name__ == '__main__':
209 | a = torch.rand(10, 2, 161, 200)
210 | print(cumulative_norm(a).shape)
211 |
--------------------------------------------------------------------------------
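`full_band_crm_mask` above decompresses the predicted mask and applies it to the noisy STFT via complex multiplication. A standalone sketch of those two steps with random tensors; `decompress_cirm_sketch` is an illustrative re-implementation of the formula visible in `sub_band_crm_mask` (K = 10), not audio_zen's `decompress_cIRM` itself:

```python
import torch


def decompress_cirm_sketch(mask, K=10.0, limit=9.99):
    """Map the bounded, compressed mask back to an unbounded cIRM (same formula as above)."""
    mask = mask.clamp(-limit, limit)
    return -K * torch.log((K - mask) / (K + mask))


noisy = torch.complex(torch.randn(1, 257, 100), torch.randn(1, 257, 100))  # [B, F, T] noisy STFT
pred_crm = torch.tanh(torch.randn(1, 257, 100, 2)) * 9.0                   # fake compressed mask, [B, F, T, 2]

crm = decompress_cirm_sketch(pred_crm)

# Complex multiplication (a+bi)(c+di), written out with the real/imag mask channels
enhanced_real = crm[..., 0] * noisy.real - crm[..., 1] * noisy.imag
enhanced_imag = crm[..., 1] * noisy.real + crm[..., 0] * noisy.imag
enhanced = torch.complex(enhanced_real, enhanced_imag)
print(enhanced.shape)  # torch.Size([1, 257, 100])
```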
/speech_enhance/fullsubnet/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet/model/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet/model/fullsubnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional
3 |
4 | from audio_zen.acoustics.feature import drop_band
5 | from audio_zen.model.base_model import BaseModel
6 | from audio_zen.model.module.sequence_model import SequenceModel
7 |
8 | # for log
9 | from utils.logger import log
10 | print=log
11 |
12 | class Model(BaseModel):
13 | def __init__(self,
14 | num_freqs,
15 | look_ahead,
16 | sequence_model,
17 | fb_num_neighbors,
18 | sb_num_neighbors,
19 | fb_output_activate_function,
20 | sb_output_activate_function,
21 | fb_model_hidden_size,
22 | sb_model_hidden_size,
23 | norm_type="offline_laplace_norm",
24 | num_groups_in_drop_band=2,
25 | weight_init=True,
26 | ):
27 | """
28 | FullSubNet model (cIRM mask)
29 |
30 | Args:
31 | num_freqs: Frequency dim of the input
32 |             sb_num_neighbors: Number of neighboring frequencies on each side of the sub-band unit
33 |             look_ahead: Number of future frames used (look-ahead)
34 |             sequence_model: Which sequence model to use as the backbone (GRU or LSTM)
35 | """
36 | super().__init__()
37 |         assert sequence_model in ("GRU", "LSTM"), f"{self.__class__.__name__} only supports GRU and LSTM."
38 |
39 | self.fb_model = SequenceModel(
40 | input_size=num_freqs,
41 | output_size=num_freqs,
42 | hidden_size=fb_model_hidden_size,
43 | num_layers=2,
44 | bidirectional=False,
45 | sequence_model=sequence_model,
46 | output_activate_function=fb_output_activate_function
47 | )
48 |
49 | self.sb_model = SequenceModel(
50 | input_size=(sb_num_neighbors * 2 + 1) + (fb_num_neighbors * 2 + 1),
51 | output_size=2,
52 | hidden_size=sb_model_hidden_size,
53 | num_layers=2,
54 | bidirectional=False,
55 | sequence_model=sequence_model,
56 | output_activate_function=sb_output_activate_function
57 | )
58 |
59 | self.sb_num_neighbors = sb_num_neighbors
60 | self.fb_num_neighbors = fb_num_neighbors
61 | self.look_ahead = look_ahead
62 | self.norm = self.norm_wrapper(norm_type)
63 | self.num_groups_in_drop_band = num_groups_in_drop_band
64 |
65 | if weight_init:
66 | self.apply(self.weight_init)
67 |
68 | def forward(self, noisy_mag):
69 | """
70 | Args:
71 | noisy_mag: noisy magnitude spectrogram
72 |
73 | Returns:
74 | The real part and imag part of the enhanced spectrogram
75 |
76 | Shapes:
77 | noisy_mag: [B, 1, F, T]
78 | return: [B, 2, F, T]
79 | """
80 | assert noisy_mag.dim() == 4
81 | noisy_mag = functional.pad(noisy_mag, [0, self.look_ahead]) # Pad the look ahead
82 | batch_size, num_channels, num_freqs, num_frames = noisy_mag.size()
83 |         assert num_channels == 1, f"{self.__class__.__name__} takes the magnitude feature as input."
84 |
85 | # Fullband model
86 | fb_input = self.norm(noisy_mag).reshape(batch_size, num_channels * num_freqs, num_frames)
87 | fb_output = self.fb_model(fb_input).reshape(batch_size, 1, num_freqs, num_frames)
88 |
89 | # Unfold the output of the fullband model, [B, N=F, C, F_f, T]
90 | fb_output_unfolded = self.unfold(fb_output, num_neighbor=self.fb_num_neighbors)
91 | fb_output_unfolded = fb_output_unfolded.reshape(batch_size, num_freqs, self.fb_num_neighbors * 2 + 1, num_frames)
92 |
93 | # Unfold noisy input, [B, N=F, C, F_s, T]
94 | noisy_mag_unfolded = self.unfold(noisy_mag, num_neighbor=self.sb_num_neighbors)
95 | noisy_mag_unfolded = noisy_mag_unfolded.reshape(batch_size, num_freqs, self.sb_num_neighbors * 2 + 1, num_frames)
96 |
97 | # Concatenation, [B, F, (F_s + F_f), T]
98 | sb_input = torch.cat([noisy_mag_unfolded, fb_output_unfolded], dim=2)
99 | sb_input = self.norm(sb_input)
100 |
101 |         # Dropping frequency bands speeds up training without significant performance degradation; details will be added to the paper later.
102 | if batch_size > 1:
103 | sb_input = drop_band(sb_input.permute(0, 2, 1, 3), num_groups=self.num_groups_in_drop_band) # [B, (F_s + F_f), F//num_groups, T]
104 | num_freqs = sb_input.shape[2]
105 | sb_input = sb_input.permute(0, 2, 1, 3) # [B, F//num_groups, (F_s + F_f), T]
106 |
107 | sb_input = sb_input.reshape(
108 | batch_size * num_freqs,
109 | (self.sb_num_neighbors * 2 + 1) + (self.fb_num_neighbors * 2 + 1),
110 | num_frames
111 | )
112 |
113 | # [B * F, (F_s + F_f), T] => [B * F, 2, T] => [B, F, 2, T]
114 | sb_mask = self.sb_model(sb_input)
115 | sb_mask = sb_mask.reshape(batch_size, num_freqs, 2, num_frames).permute(0, 2, 1, 3).contiguous()
116 |
117 | output = sb_mask[:, :, :, self.look_ahead:]
118 | return output
119 |
120 |
121 | if __name__ == "__main__":
122 | import datetime
123 |
124 | with torch.no_grad():
125 | model = Model(
126 | sb_num_neighbors=15,
127 | fb_num_neighbors=0,
128 | num_freqs=257,
129 | look_ahead=2,
130 | sequence_model="LSTM",
131 | fb_output_activate_function="ReLU",
132 | sb_output_activate_function=None,
133 | fb_model_hidden_size=512,
134 | sb_model_hidden_size=384,
135 | weight_init=False,
136 | norm_type="offline_laplace_norm",
137 | num_groups_in_drop_band=2,
138 | )
139 | # ipt = torch.rand(3, 800) # 1.6s
140 | # ipt_len = ipt.shape[-1]
141 |         # # 1000 frames (16s) - 5.65s (35.31%, model only) - 5.78s
142 |         # # 500 frames (8s) - 3.05s (38.12%, model only) - 3.04s
143 |         # # 200 frames (3.2s) - 1.19s (37.19%, model only) - 1.20s
144 |         # # 100 frames (1.6s) - 0.62s (38.75%, model only) - 0.65s
145 | # start = datetime.datetime.now()
146 | #
147 | # complex_tensor = torch.stft(ipt, n_fft=512, hop_length=256)
148 | # mag = (complex_tensor.pow(2.).sum(-1) + 1e-8).pow(0.5 * 1.0).unsqueeze(1)
149 | # print(f"STFT: {datetime.datetime.now() - start}, {mag.shape}")
150 | #
151 | # enhanced_complex_tensor = model(mag).detach().permute(0, 2, 3, 1)
152 | # print(enhanced_complex_tensor.shape)
153 | # print(f"Model Inference: {datetime.datetime.now() - start}")
154 | #
155 | # enhanced = torch.istft(enhanced_complex_tensor, 512, 256, length=ipt_len)
156 | # print(f"iSTFT: {datetime.datetime.now() - start}")
157 | #
158 | # print(f"{datetime.datetime.now() - start}")
159 | ipt = torch.rand(3, 1, 257, 200)
160 | print(model(ipt).shape)
161 |
--------------------------------------------------------------------------------
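The forward pass above relies on `BaseModel.unfold` to turn the `[B, 1, F, T]` magnitude into per-frequency sub-band units of width `2 * num_neighbors + 1`. A rough standalone sketch of that idea (the reflect padding here is an assumption; the repository's `unfold` implementation may differ in detail):

```python
import torch
import torch.nn.functional as F


def unfold_subbands(mag, num_neighbors):
    """[B, 1, F, T] -> [B, F, 2*num_neighbors+1, T]: for each frequency bin, gather its neighbors."""
    batch_size, _, num_freqs, num_frames = mag.shape
    # Pad the frequency axis so edge bins also get a full neighborhood (reflect is an assumption)
    mag = F.pad(mag, [0, 0, num_neighbors, num_neighbors], mode="reflect")
    # Slide a window of width 2*n+1 over the frequency axis
    sub_bands = mag.unfold(dimension=2, size=2 * num_neighbors + 1, step=1)  # [B, 1, F, T, 2n+1]
    return sub_bands.squeeze(1).permute(0, 1, 3, 2)                          # [B, F, 2n+1, T]


x = torch.rand(2, 1, 257, 50)
print(unfold_subbands(x, num_neighbors=15).shape)  # torch.Size([2, 257, 31, 50])
```

Each of the F rows then becomes an independent sequence for the sub-band model, which is why `sb_input` is reshaped to `[B * F, F_s + F_f, T]` before `self.sb_model` is applied.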
/speech_enhance/fullsubnet/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet/trainer/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | from torch.cuda.amp import autocast
4 | from tqdm import tqdm
5 |
6 | from audio_zen.acoustics.feature import mag_phase, drop_band
7 | from audio_zen.acoustics.mask import build_complex_ideal_ratio_mask, decompress_cIRM
8 | from audio_zen.trainer.base_trainer import BaseTrainer
9 | from utils.logger import log
10 |
11 | plt.switch_backend('agg')
12 |
13 |
14 | class Trainer(BaseTrainer):
15 | def __init__(self, dist, rank, config, resume, only_validation, model, loss_function, optimizer, train_dataloader, validation_dataloader):
16 | super().__init__(dist, rank, config, resume, only_validation, model, loss_function, optimizer)
17 | self.train_dataloader = train_dataloader
18 | self.valid_dataloader = validation_dataloader
19 |
20 | def _train_epoch(self, epoch):
21 |
22 | loss_total = 0.0
23 | progress_bar = None
24 |
25 | if self.rank == 0:
26 | progress_bar = tqdm(total=len(self.train_dataloader), desc=f"Training")
27 |
28 | for noisy, clean in self.train_dataloader:
29 | self.optimizer.zero_grad()
30 |
31 | noisy = noisy.to(self.rank)
32 | clean = clean.to(self.rank)
33 |
34 | noisy_complex = self.torch_stft(noisy)
35 | clean_complex = self.torch_stft(clean)
36 |
37 | noisy_mag, _ = mag_phase(noisy_complex)
38 | ground_truth_cIRM = build_complex_ideal_ratio_mask(noisy_complex, clean_complex) # [B, F, T, 2]
39 | ground_truth_cIRM = drop_band(
40 | ground_truth_cIRM.permute(0, 3, 1, 2), # [B, 2, F ,T]
41 | self.model.module.num_groups_in_drop_band
42 | ).permute(0, 2, 3, 1)
43 |
44 | with autocast(enabled=self.use_amp):
45 | # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2]
46 | noisy_mag = noisy_mag.unsqueeze(1)
47 | cRM = self.model(noisy_mag)
48 | cRM = cRM.permute(0, 2, 3, 1)
49 | loss = self.loss_function(ground_truth_cIRM, cRM)
50 |
51 | self.scaler.scale(loss).backward()
52 | self.scaler.unscale_(self.optimizer)
53 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm_value)
54 | self.scaler.step(self.optimizer)
55 | self.scaler.update()
56 |
57 | loss_total += loss.item()
58 |
59 | if self.rank == 0:
60 | progress_bar.update(1)
61 | progress_bar.refresh()
62 |
63 | if self.rank == 0:
64 | log(f"[Train] Epoch {epoch}, Loss {loss_total / len(self.train_dataloader)}")
65 | self.writer.add_scalar(f"Loss/Train", loss_total / len(self.train_dataloader), epoch)
66 |
67 | @torch.no_grad()
68 | def _validation_epoch(self, epoch):
69 | progress_bar = None
70 | if self.rank == 0:
71 | progress_bar = tqdm(total=len(self.valid_dataloader), desc=f"Validation")
72 |
73 | visualization_n_samples = self.visualization_config["n_samples"]
74 | visualization_num_workers = self.visualization_config["num_workers"]
75 | visualization_metrics = self.visualization_config["metrics"]
76 |
77 | loss_total = 0.0
78 | loss_list = {"With_reverb": 0.0, "No_reverb": 0.0, }
79 | item_idx_list = {"With_reverb": 0, "No_reverb": 0, }
80 | noisy_y_list = {"With_reverb": [], "No_reverb": [], }
81 | clean_y_list = {"With_reverb": [], "No_reverb": [], }
82 | enhanced_y_list = {"With_reverb": [], "No_reverb": [], }
83 | validation_score_list = {"With_reverb": 0.0, "No_reverb": 0.0}
84 |
85 | # speech_type in ("with_reverb", "no_reverb")
86 | for i, (noisy, clean, name, speech_type) in enumerate(self.valid_dataloader):
87 | assert len(name) == 1, "The batch size for the validation stage must be one."
88 | name = name[0]
89 | speech_type = speech_type[0]
90 |
91 | noisy = noisy.to(self.rank)
92 | clean = clean.to(self.rank)
93 |
94 | noisy_complex = self.torch_stft(noisy)
95 | clean_complex = self.torch_stft(clean)
96 |
97 | noisy_mag, _ = mag_phase(noisy_complex)
98 | cIRM = build_complex_ideal_ratio_mask(noisy_complex, clean_complex) # [B, F, T, 2]
99 |
100 | noisy_mag = noisy_mag.unsqueeze(1)
101 | cRM = self.model(noisy_mag)
102 | cRM = cRM.permute(0, 2, 3, 1)
103 |
104 | loss = self.loss_function(cIRM, cRM)
105 |
106 | cRM = decompress_cIRM(cRM)
107 |
108 | enhanced_real = cRM[..., 0] * noisy_complex.real - cRM[..., 1] * noisy_complex.imag
109 | enhanced_imag = cRM[..., 1] * noisy_complex.real + cRM[..., 0] * noisy_complex.imag
110 | enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1)
111 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1))
112 |
113 | noisy = noisy.detach().squeeze(0).cpu().numpy()
114 | clean = clean.detach().squeeze(0).cpu().numpy()
115 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
116 |
117 | assert len(noisy) == len(clean) == len(enhanced)
118 | loss_total += loss
119 |
120 | # Separated loss
121 | loss_list[speech_type] += loss
122 | item_idx_list[speech_type] += 1
123 |
124 | if item_idx_list[speech_type] <= visualization_n_samples:
125 | self.spec_audio_visualization(noisy, enhanced, clean, name, epoch, mark=speech_type)
126 |
127 | noisy_y_list[speech_type].append(noisy)
128 | clean_y_list[speech_type].append(clean)
129 | enhanced_y_list[speech_type].append(enhanced)
130 |
131 | if self.rank == 0:
132 | progress_bar.update(1)
133 |
134 | log(f"[Test] Epoch {epoch}, Loss {loss_total / len(self.valid_dataloader)}")
135 | self.writer.add_scalar(f"Loss/Validation_Total", loss_total / len(self.valid_dataloader), epoch)
136 |
137 | for speech_type in ("With_reverb", "No_reverb"):
138 | log(f"[Test] Epoch {epoch}, {speech_type}, Loss {loss_list[speech_type] / len(self.valid_dataloader)}")
139 | self.writer.add_scalar(f"Loss/{speech_type}", loss_list[speech_type] / len(self.valid_dataloader), epoch)
140 |
141 | validation_score_list[speech_type] = self.metrics_visualization(
142 | noisy_y_list[speech_type], clean_y_list[speech_type], enhanced_y_list[speech_type],
143 | visualization_metrics, epoch, visualization_num_workers, mark=speech_type
144 | )
145 |
146 | return validation_score_list["No_reverb"]
147 |
--------------------------------------------------------------------------------
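The trainer above regresses a compressed complex ideal ratio mask (cIRM): the target maps the noisy STFT onto the clean STFT and is squashed into (-10, 10) before the MSE loss. A standalone sketch of how such a target can be formed; it mirrors the role of `build_complex_ideal_ratio_mask`, and the constants (K = 10, C = 0.1) are inferred from the decompression used at inference, so audio_zen's exact clamping details may differ:

```python
import torch
import torch.nn.functional as F


def build_cirm_sketch(noisy, clean, eps=1e-8):
    """noisy, clean: complex STFTs [B, F, T] -> compressed cIRM target [B, F, T, 2]."""
    denom = noisy.real ** 2 + noisy.imag ** 2 + eps
    mask_real = (noisy.real * clean.real + noisy.imag * clean.imag) / denom
    mask_imag = (noisy.real * clean.imag - noisy.imag * clean.real) / denom
    mask = torch.stack((mask_real, mask_imag), dim=-1)
    # Compression K*(1 - e^{-C*m})/(1 + e^{-C*m}) with K=10, C=0.1, written as tanh for numerical stability
    return 10.0 * torch.tanh(0.05 * mask)


noisy = torch.complex(torch.randn(2, 257, 100), torch.randn(2, 257, 100))
clean = torch.complex(torch.randn(2, 257, 100), torch.randn(2, 257, 100))

target = build_cirm_sketch(noisy, clean)       # ground-truth cIRM, bounded to (-10, 10)
pred = torch.zeros_like(target)                # stand-in for the model output cRM
loss = F.mse_loss(pred, target)
print(target.shape, float(loss))
```

During training, the same `drop_band` grouping is applied to this target so that its frequency axis matches the model output.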
/speech_enhance/fullsubnet_plus/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet_plus/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet_plus/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet_plus/dataset/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet_plus/dataset/dataset_inference.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import librosa
4 | import numpy as np
5 |
6 | from audio_zen.dataset.base_dataset import BaseDataset
7 | from audio_zen.utils import basename
8 |
9 |
10 | class Dataset(BaseDataset):
11 | def __init__(self,
12 | dataset_dir_list,
13 | sr,
14 | ):
15 | """
16 | Args:
17 |             dataset_dir_list (list): list of directories containing noisy audio files
18 | """
19 | super().__init__()
20 | assert isinstance(dataset_dir_list, list)
21 | self.sr = sr
22 |
23 | noisy_file_path_list = []
24 | for dataset_dir in dataset_dir_list:
25 | dataset_dir = Path(dataset_dir).expanduser().absolute()
26 | noisy_file_path_list += librosa.util.find_files(dataset_dir.as_posix()) # Sorted
27 |
28 | self.noisy_file_path_list = noisy_file_path_list
29 | self.length = len(self.noisy_file_path_list)
30 |
31 | def __len__(self):
32 | return self.length
33 |
34 | def __getitem__(self, item):
35 | noisy_file_path = self.noisy_file_path_list[item]
36 | noisy_y = librosa.load(noisy_file_path, sr=self.sr)[0]
37 | noisy_y = noisy_y.astype(np.float32)
38 |
39 | return noisy_y, basename(noisy_file_path)[0]
40 |
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet_plus/dataset/dataset_train.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | from audio_zen.acoustics.feature import norm_amplitude, tailor_dB_FS, is_clipped, load_wav, subsample
5 | from audio_zen.dataset.base_dataset import BaseDataset
6 | from audio_zen.utils import expand_path
7 | from joblib import Parallel, delayed
8 | from scipy import signal
9 | from tqdm import tqdm
10 |
11 |
12 | class Dataset(BaseDataset):
13 | def __init__(self,
14 | clean_dataset,
15 | clean_dataset_limit,
16 | clean_dataset_offset,
17 | noise_dataset,
18 | noise_dataset_limit,
19 | noise_dataset_offset,
20 | rir_dataset,
21 | rir_dataset_limit,
22 | rir_dataset_offset,
23 | snr_range,
24 | reverb_proportion,
25 | silence_length,
26 | target_dB_FS,
27 | target_dB_FS_floating_value,
28 | sub_sample_length,
29 | sr,
30 | pre_load_clean_dataset,
31 | pre_load_noise,
32 | pre_load_rir,
33 | num_workers
34 | ):
35 | """
36 | Dynamic mixing for training
37 |
38 | Args:
39 | clean_dataset_limit:
40 | clean_dataset_offset:
41 | noise_dataset_limit:
42 | noise_dataset_offset:
43 | rir_dataset:
44 | rir_dataset_limit:
45 | rir_dataset_offset:
46 | snr_range:
47 | reverb_proportion:
48 | clean_dataset: scp file
49 | noise_dataset: scp file
50 | sub_sample_length:
51 | sr:
52 | """
53 | super().__init__()
54 | # acoustics args
55 | self.sr = sr
56 |
57 | # parallel args
58 | self.num_workers = num_workers
59 |
60 | clean_dataset_list = [line.rstrip('\n') for line in open(expand_path(clean_dataset), "r")]
61 | noise_dataset_list = [line.rstrip('\n') for line in open(expand_path(noise_dataset), "r")]
62 | rir_dataset_list = [line.rstrip('\n') for line in open(expand_path(rir_dataset), "r")]
63 |
64 | clean_dataset_list = self._offset_and_limit(clean_dataset_list, clean_dataset_offset, clean_dataset_limit)
65 | noise_dataset_list = self._offset_and_limit(noise_dataset_list, noise_dataset_offset, noise_dataset_limit)
66 | rir_dataset_list = self._offset_and_limit(rir_dataset_list, rir_dataset_offset, rir_dataset_limit)
67 |
68 | if pre_load_clean_dataset:
69 | clean_dataset_list = self._preload_dataset(clean_dataset_list, remark="Clean Dataset")
70 |
71 | if pre_load_noise:
72 | noise_dataset_list = self._preload_dataset(noise_dataset_list, remark="Noise Dataset")
73 |
74 | if pre_load_rir:
75 | rir_dataset_list = self._preload_dataset(rir_dataset_list, remark="RIR Dataset")
76 |
77 | self.clean_dataset_list = clean_dataset_list
78 | self.noise_dataset_list = noise_dataset_list
79 | self.rir_dataset_list = rir_dataset_list
80 |
81 | snr_list = self._parse_snr_range(snr_range)
82 | self.snr_list = snr_list
83 |
84 | assert 0 <= reverb_proportion <= 1, "reverberation proportion should be in [0, 1]"
85 | self.reverb_proportion = reverb_proportion
86 | self.silence_length = silence_length
87 | self.target_dB_FS = target_dB_FS
88 | self.target_dB_FS_floating_value = target_dB_FS_floating_value
89 | self.sub_sample_length = sub_sample_length
90 |
91 | self.length = len(self.clean_dataset_list)
92 |
93 | def __len__(self):
94 | return self.length
95 |
96 | def _preload_dataset(self, file_path_list, remark=""):
97 | waveform_list = Parallel(n_jobs=self.num_workers)(
98 | delayed(load_wav)(f_path) for f_path in tqdm(file_path_list, desc=remark)
99 | )
100 | return list(zip(file_path_list, waveform_list))
101 |
102 | @staticmethod
103 | def _random_select_from(dataset_list):
104 | return random.choice(dataset_list)
105 |
106 | def _select_noise_y(self, target_length):
107 | noise_y = np.zeros(0, dtype=np.float32)
108 | silence = np.zeros(int(self.sr * self.silence_length), dtype=np.float32)
109 | remaining_length = target_length
110 |
111 | while remaining_length > 0:
112 | noise_file = self._random_select_from(self.noise_dataset_list)
113 | noise_new_added = load_wav(noise_file, sr=self.sr)
114 | noise_y = np.append(noise_y, noise_new_added)
115 | remaining_length -= len(noise_new_added)
116 |
117 |             # If more noise is still needed, insert a short silence segment before the next noise clip
118 | if remaining_length > 0:
119 | silence_len = min(remaining_length, len(silence))
120 | noise_y = np.append(noise_y, silence[:silence_len])
121 | remaining_length -= silence_len
122 |
123 | if len(noise_y) > target_length:
124 | idx_start = np.random.randint(len(noise_y) - target_length)
125 | noise_y = noise_y[idx_start:idx_start + target_length]
126 |
127 | return noise_y
128 |
129 | @staticmethod
130 | def snr_mix(clean_y, noise_y, snr, target_dB_FS, target_dB_FS_floating_value, rir=None, eps=1e-6):
131 | """
132 |         Mix noise with clean speech. When the rir argument is not None, apply reverberation to the clean speech first.
133 |
134 | Args:
135 |             clean_y: clean speech
136 |             noise_y: noise
137 |             snr (int): signal-to-noise ratio in dB
138 |             target_dB_FS (int):
139 |             target_dB_FS_floating_value (int):
140 |             rir: room impulse response, None or np.ndarray
141 | eps: eps
142 |
143 | Returns:
144 | (noisy_y,clean_y)
145 | """
146 | if rir is not None:
147 | if rir.ndim > 1:
148 | rir_idx = np.random.randint(0, rir.shape[0])
149 | rir = rir[rir_idx, :]
150 |
151 | clean_y = signal.fftconvolve(clean_y, rir)[:len(clean_y)]
152 |
153 | clean_y, _ = norm_amplitude(clean_y)
154 | clean_y, _, _ = tailor_dB_FS(clean_y, target_dB_FS)
155 | clean_rms = (clean_y ** 2).mean() ** 0.5
156 |
157 | noise_y, _ = norm_amplitude(noise_y)
158 | noise_y, _, _ = tailor_dB_FS(noise_y, target_dB_FS)
159 | noise_rms = (noise_y ** 2).mean() ** 0.5
160 |
161 | snr_scalar = clean_rms / (10 ** (snr / 20)) / (noise_rms + eps)
162 | noise_y *= snr_scalar
163 | noisy_y = clean_y + noise_y
164 |
165 | # Randomly select RMS value of dBFS between -15 dBFS and -35 dBFS and normalize noisy speech with that value
166 | noisy_target_dB_FS = np.random.randint(
167 | target_dB_FS - target_dB_FS_floating_value,
168 | target_dB_FS + target_dB_FS_floating_value
169 | )
170 |
171 |         # Rescale both signals using the RMS of the noisy speech
172 | noisy_y, _, noisy_scalar = tailor_dB_FS(noisy_y, noisy_target_dB_FS)
173 | clean_y *= noisy_scalar
174 |
175 |         # Synthesizing the noisy speech may occasionally cause clipping
176 |         # In that case, slightly rescale noisy_y and clean_y
177 |         if is_clipped(noisy_y):
178 |             noisy_y_scalar = np.max(np.abs(noisy_y)) / (0.99 - eps)  # scale factor that brings the peak just below 1.0
179 | noisy_y = noisy_y / noisy_y_scalar
180 | clean_y = clean_y / noisy_y_scalar
181 |
182 | return noisy_y, clean_y
183 |
184 | def __getitem__(self, item):
185 | clean_file = self.clean_dataset_list[item]
186 | clean_y = load_wav(clean_file, sr=self.sr)
187 | clean_y = subsample(clean_y, sub_sample_length=int(self.sub_sample_length * self.sr))
188 |
189 | noise_y = self._select_noise_y(target_length=len(clean_y))
190 | assert len(clean_y) == len(noise_y), f"Inequality: {len(clean_y)} {len(noise_y)}"
191 |
192 | snr = self._random_select_from(self.snr_list)
193 | use_reverb = bool(np.random.random(1) < self.reverb_proportion)
194 |
195 | noisy_y, clean_y = self.snr_mix(
196 | clean_y=clean_y,
197 | noise_y=noise_y,
198 | snr=snr,
199 | target_dB_FS=self.target_dB_FS,
200 | target_dB_FS_floating_value=self.target_dB_FS_floating_value,
201 | rir=load_wav(self._random_select_from(self.rir_dataset_list), sr=self.sr) if use_reverb else None
202 | )
203 |
204 | noisy_y = noisy_y.astype(np.float32)
205 | clean_y = clean_y.astype(np.float32)
206 |
207 | return noisy_y, clean_y
208 |
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet_plus/dataset/dataset_validation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import librosa
5 |
6 | from audio_zen.dataset.base_dataset import BaseDataset
7 | from audio_zen.acoustics.utils import load_wav
8 | from audio_zen.utils import basename
9 |
10 |
11 | class Dataset(BaseDataset):
12 | def __init__(
13 | self,
14 | dataset_dir_list,
15 | sr,
16 | ):
17 | """
18 | Construct DNS validation set
19 |
20 | synthetic/
21 | with_reverb/
22 | noisy/
23 | clean_y/
24 | no_reverb/
25 | noisy/
26 | clean_y/
27 | """
28 | super(Dataset, self).__init__()
29 | noisy_files_list = []
30 |
31 | for dataset_dir in dataset_dir_list:
32 | dataset_dir = Path(dataset_dir).expanduser().absolute()
33 | noisy_files_list += librosa.util.find_files((dataset_dir / "noisy").as_posix())
34 |
35 | self.length = len(noisy_files_list)
36 | self.noisy_files_list = noisy_files_list
37 | self.sr = sr
38 |
39 | def __len__(self):
40 | return self.length
41 |
42 | def __getitem__(self, item):
43 | """
44 | use the absolute path of the noisy speech to find the corresponding clean speech.
45 |
46 | Notes
47 | with_reverb and no_reverb dirs have same-named files.
48 |             If we only used `basename`, files with the same name would overwrite (cover) each other in visualization.
49 |
50 | Returns:
51 | noisy: [waveform...], clean: [waveform...], type: [reverb|no_reverb] + name
52 | """
53 | noisy_file_path = self.noisy_files_list[item]
54 | parent_dir = Path(noisy_file_path).parents[1].name
55 | noisy_filename, _ = basename(noisy_file_path)
56 |
57 | reverb_remark = "" # When the speech comes from reverb_dir, insert "with_reverb" before the filename
58 |
59 |         # speech_type must match the validation stage; it is used to distinguish samples in later visualization
60 | if parent_dir == "with_reverb":
61 | speech_type = "With_reverb"
62 | elif parent_dir == "no_reverb":
63 | speech_type = "No_reverb"
64 | elif parent_dir == "dns_2_non_english":
65 | speech_type = "Non_english"
66 | elif parent_dir == "dns_2_emotion":
67 | speech_type = "Emotion"
68 | elif parent_dir == "dns_2_singing":
69 | speech_type = "Singing"
70 | else:
71 |             raise NotImplementedError(f"Unsupported dir: {parent_dir}")
72 |
73 | # Find the corresponding clean speech using "parent_dir" and "file_id"
74 | file_id = noisy_filename.split("_")[-1]
75 | if parent_dir in ("dns_2_emotion", "dns_2_singing"):
76 | # e.g., synthetic_emotion_1792_snr19_tl-35_fileid_19 => synthetic_emotion_clean_fileid_15
77 | clean_filename = f"synthetic_{speech_type.lower()}_clean_fileid_{file_id}"
78 | elif parent_dir == "dns_2_non_english":
79 | # e.g., synthetic_german_collection044_14_-04_CFQQgBvv2xQ_snr8_tl-21_fileid_121 => synthetic_clean_fileid_121
80 | clean_filename = f"synthetic_clean_fileid_{file_id}"
81 | else:
82 | # e.g., clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300 => clean_fileid_300
83 | if parent_dir == "with_reverb":
84 | reverb_remark = "with_reverb"
85 | clean_filename = f"clean_fileid_{file_id}"
86 |
87 | clean_file_path = noisy_file_path.replace(f"noisy/{noisy_filename}", f"clean/{clean_filename}")
88 |
89 | noisy = load_wav(os.path.abspath(os.path.expanduser(noisy_file_path)), sr=self.sr)
90 | clean = load_wav(os.path.abspath(os.path.expanduser(clean_file_path)), sr=self.sr)
91 |
92 | return noisy, clean, reverb_remark + noisy_filename, speech_type
93 |
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet_plus/inferencer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet_plus/inferencer/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/fullsubnet_plus/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/fullsubnet_plus/trainer/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/inter_subnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/inter_subnet/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/inter_subnet/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/inter_subnet/dataset/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/inter_subnet/dataset/dataset_inference.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import librosa
4 | import numpy as np
5 |
6 | from audio_zen.dataset.base_dataset import BaseDataset
7 | from audio_zen.utils import basename
8 |
9 |
10 | class Dataset(BaseDataset):
11 | def __init__(self,
12 | dataset_dir_list,
13 | sr,
14 | ):
15 | """
16 | Args:
17 |             dataset_dir_list (list): list of directories containing noisy audio files
18 | """
19 | super().__init__()
20 | assert isinstance(dataset_dir_list, list)
21 | self.sr = sr
22 |
23 | noisy_file_path_list = []
24 | for dataset_dir in dataset_dir_list:
25 | dataset_dir = Path(dataset_dir).expanduser().absolute()
26 | noisy_file_path_list += librosa.util.find_files(dataset_dir.as_posix()) # Sorted
27 |
28 | self.noisy_file_path_list = noisy_file_path_list
29 | self.length = len(self.noisy_file_path_list)
30 |
31 | def __len__(self):
32 | return self.length
33 |
34 | def __getitem__(self, item):
35 | noisy_file_path = self.noisy_file_path_list[item]
36 | noisy_y = librosa.load(noisy_file_path, sr=self.sr)[0]
37 | noisy_y = noisy_y.astype(np.float32)
38 |
39 | return noisy_y, basename(noisy_file_path)[0]
40 |
--------------------------------------------------------------------------------
/speech_enhance/inter_subnet/dataset/dataset_train.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | from audio_zen.acoustics.feature import norm_amplitude, tailor_dB_FS, is_clipped, load_wav, subsample
5 | from audio_zen.dataset.base_dataset import BaseDataset
6 | from audio_zen.utils import expand_path
7 | from joblib import Parallel, delayed
8 | from scipy import signal
9 | from tqdm import tqdm
10 |
11 |
12 | class Dataset(BaseDataset):
13 | def __init__(self,
14 | clean_dataset,
15 | clean_dataset_limit,
16 | clean_dataset_offset,
17 | noise_dataset,
18 | noise_dataset_limit,
19 | noise_dataset_offset,
20 | rir_dataset,
21 | rir_dataset_limit,
22 | rir_dataset_offset,
23 | snr_range,
24 | reverb_proportion,
25 | silence_length,
26 | target_dB_FS,
27 | target_dB_FS_floating_value,
28 | sub_sample_length,
29 | sr,
30 | pre_load_clean_dataset,
31 | pre_load_noise,
32 | pre_load_rir,
33 | num_workers
34 | ):
35 | """
36 | Dynamic mixing for training
37 |
38 | Args:
39 | clean_dataset_limit:
40 | clean_dataset_offset:
41 | noise_dataset_limit:
42 | noise_dataset_offset:
43 | rir_dataset:
44 | rir_dataset_limit:
45 | rir_dataset_offset:
46 | snr_range:
47 | reverb_proportion:
48 | clean_dataset: scp file
49 | noise_dataset: scp file
50 | sub_sample_length:
51 | sr:
52 | """
53 | super().__init__()
54 | # acoustics args
55 | self.sr = sr
56 |
57 | # parallel args
58 | self.num_workers = num_workers
59 |
60 | clean_dataset_list = [line.rstrip('\n') for line in open(expand_path(clean_dataset), "r")]
61 | noise_dataset_list = [line.rstrip('\n') for line in open(expand_path(noise_dataset), "r")]
62 | rir_dataset_list = [line.rstrip('\n') for line in open(expand_path(rir_dataset), "r")]
63 |
64 | clean_dataset_list = self._offset_and_limit(clean_dataset_list, clean_dataset_offset, clean_dataset_limit)
65 | noise_dataset_list = self._offset_and_limit(noise_dataset_list, noise_dataset_offset, noise_dataset_limit)
66 | rir_dataset_list = self._offset_and_limit(rir_dataset_list, rir_dataset_offset, rir_dataset_limit)
67 |
68 | if pre_load_clean_dataset:
69 | clean_dataset_list = self._preload_dataset(clean_dataset_list, remark="Clean Dataset")
70 |
71 | if pre_load_noise:
72 | noise_dataset_list = self._preload_dataset(noise_dataset_list, remark="Noise Dataset")
73 |
74 | if pre_load_rir:
75 | rir_dataset_list = self._preload_dataset(rir_dataset_list, remark="RIR Dataset")
76 |
77 | self.clean_dataset_list = clean_dataset_list
78 | self.noise_dataset_list = noise_dataset_list
79 | self.rir_dataset_list = rir_dataset_list
80 |
81 | snr_list = self._parse_snr_range(snr_range)
82 | self.snr_list = snr_list
83 |
84 | assert 0 <= reverb_proportion <= 1, "reverberation proportion should be in [0, 1]"
85 | self.reverb_proportion = reverb_proportion
86 | self.silence_length = silence_length
87 | self.target_dB_FS = target_dB_FS
88 | self.target_dB_FS_floating_value = target_dB_FS_floating_value
89 | self.sub_sample_length = sub_sample_length
90 |
91 | self.length = len(self.clean_dataset_list)
92 |
93 | def __len__(self):
94 | return self.length
95 |
96 | def _preload_dataset(self, file_path_list, remark=""):
97 | waveform_list = Parallel(n_jobs=self.num_workers)(
98 | delayed(load_wav)(f_path) for f_path in tqdm(file_path_list, desc=remark)
99 | )
100 | return list(zip(file_path_list, waveform_list))
101 |
102 | @staticmethod
103 | def _random_select_from(dataset_list):
104 | return random.choice(dataset_list)
105 |
106 | def _select_noise_y(self, target_length):
107 | noise_y = np.zeros(0, dtype=np.float32)
108 | silence = np.zeros(int(self.sr * self.silence_length), dtype=np.float32)
109 | remaining_length = target_length
110 |
111 | while remaining_length > 0:
112 | noise_file = self._random_select_from(self.noise_dataset_list)
113 | noise_new_added = load_wav(noise_file, sr=self.sr)
114 | noise_y = np.append(noise_y, noise_new_added)
115 | remaining_length -= len(noise_new_added)
116 |
117 |             # If more noise is still needed, insert a short silence segment before the next noise clip
118 | if remaining_length > 0:
119 | silence_len = min(remaining_length, len(silence))
120 | noise_y = np.append(noise_y, silence[:silence_len])
121 | remaining_length -= silence_len
122 |
123 | if len(noise_y) > target_length:
124 | idx_start = np.random.randint(len(noise_y) - target_length)
125 | noise_y = noise_y[idx_start:idx_start + target_length]
126 |
127 | return noise_y
128 |
129 | @staticmethod
130 | def snr_mix(clean_y, noise_y, snr, target_dB_FS, target_dB_FS_floating_value, rir=None, eps=1e-6):
131 | """
132 |         Mix noise with clean speech. When the rir argument is not None, apply reverberation to the clean speech first.
133 |
134 | Args:
135 |             clean_y: clean speech
136 |             noise_y: noise
137 |             snr (int): signal-to-noise ratio in dB
138 |             target_dB_FS (int):
139 |             target_dB_FS_floating_value (int):
140 |             rir: room impulse response, None or np.ndarray
141 | eps: eps
142 |
143 | Returns:
144 | (noisy_y,clean_y)
145 | """
146 | if rir is not None:
147 | if rir.ndim > 1:
148 | rir_idx = np.random.randint(0, rir.shape[0])
149 | rir = rir[rir_idx, :]
150 |
151 | clean_y = signal.fftconvolve(clean_y, rir)[:len(clean_y)]
152 |
153 | clean_y, _ = norm_amplitude(clean_y)
154 | clean_y, _, _ = tailor_dB_FS(clean_y, target_dB_FS)
155 | clean_rms = (clean_y ** 2).mean() ** 0.5
156 |
157 | noise_y, _ = norm_amplitude(noise_y)
158 | noise_y, _, _ = tailor_dB_FS(noise_y, target_dB_FS)
159 | noise_rms = (noise_y ** 2).mean() ** 0.5
160 |
161 | snr_scalar = clean_rms / (10 ** (snr / 20)) / (noise_rms + eps)
162 | noise_y *= snr_scalar
163 | noisy_y = clean_y + noise_y
164 |
165 | # Randomly select RMS value of dBFS between -15 dBFS and -35 dBFS and normalize noisy speech with that value
166 | noisy_target_dB_FS = np.random.randint(
167 | target_dB_FS - target_dB_FS_floating_value,
168 | target_dB_FS + target_dB_FS_floating_value
169 | )
170 |
171 |         # Rescale both signals using the RMS of the noisy speech
172 | noisy_y, _, noisy_scalar = tailor_dB_FS(noisy_y, noisy_target_dB_FS)
173 | clean_y *= noisy_scalar
174 |
175 |         # Synthesizing the noisy speech may occasionally cause clipping
176 |         # In that case, slightly rescale noisy_y and clean_y
177 |         if is_clipped(noisy_y):
178 |             noisy_y_scalar = np.max(np.abs(noisy_y)) / (0.99 - eps)  # scale factor that brings the peak just below 1.0
179 | noisy_y = noisy_y / noisy_y_scalar
180 | clean_y = clean_y / noisy_y_scalar
181 |
182 | return noisy_y, clean_y
183 |
184 | def __getitem__(self, item):
185 | clean_file = self.clean_dataset_list[item]
186 | clean_y = load_wav(clean_file, sr=self.sr)
187 | clean_y = subsample(clean_y, sub_sample_length=int(self.sub_sample_length * self.sr))
188 |
189 | noise_y = self._select_noise_y(target_length=len(clean_y))
190 | assert len(clean_y) == len(noise_y), f"Inequality: {len(clean_y)} {len(noise_y)}"
191 |
192 | snr = self._random_select_from(self.snr_list)
193 | use_reverb = bool(np.random.random(1) < self.reverb_proportion)
194 |
195 | noisy_y, clean_y = self.snr_mix(
196 | clean_y=clean_y,
197 | noise_y=noise_y,
198 | snr=snr,
199 | target_dB_FS=self.target_dB_FS,
200 | target_dB_FS_floating_value=self.target_dB_FS_floating_value,
201 | rir=load_wav(self._random_select_from(self.rir_dataset_list), sr=self.sr) if use_reverb else None
202 | )
203 |
204 | noisy_y = noisy_y.astype(np.float32)
205 | clean_y = clean_y.astype(np.float32)
206 |
207 | return noisy_y, clean_y
208 |
--------------------------------------------------------------------------------
/speech_enhance/inter_subnet/dataset/dataset_validation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import librosa
5 |
6 | from audio_zen.dataset.base_dataset import BaseDataset
7 | from audio_zen.acoustics.utils import load_wav
8 | from audio_zen.utils import basename
9 |
10 |
11 | class Dataset(BaseDataset):
12 | def __init__(
13 | self,
14 | dataset_dir_list,
15 | sr,
16 | ):
17 | """
18 | Construct DNS validation set
19 |
20 | synthetic/
21 | with_reverb/
22 | noisy/
23 | clean_y/
24 | no_reverb/
25 | noisy/
26 | clean_y/
27 | """
28 | super(Dataset, self).__init__()
29 | noisy_files_list = []
30 |
31 | for dataset_dir in dataset_dir_list:
32 | dataset_dir = Path(dataset_dir).expanduser().absolute()
33 | noisy_files_list += librosa.util.find_files((dataset_dir / "noisy").as_posix())
34 |
35 | self.length = len(noisy_files_list)
36 | self.noisy_files_list = noisy_files_list
37 | self.sr = sr
38 |
39 | def __len__(self):
40 | return self.length
41 |
42 | def __getitem__(self, item):
43 | """
44 | Use the absolute path of the noisy speech to find the corresponding clean speech.
45 | 
46 | Notes
47 | The with_reverb and no_reverb dirs contain files with identical names.
48 | If we only used `basename`, they would overwrite (cover) each other in the visualization.
49 |
50 | Returns:
51 | noisy: [waveform...], clean: [waveform...], type: [reverb|no_reverb] + name
52 | """
53 | noisy_file_path = self.noisy_files_list[item]
54 | parent_dir = Path(noisy_file_path).parents[1].name
55 | noisy_filename, _ = basename(noisy_file_path)
56 |
57 | reverb_remark = "" # When the speech comes from reverb_dir, insert "with_reverb" before the filename
58 |
59 | # speech_type must match the one used in the validation stage; it is used to distinguish the later visualizations
60 | if parent_dir == "with_reverb":
61 | speech_type = "With_reverb"
62 | elif parent_dir == "no_reverb":
63 | speech_type = "No_reverb"
64 | elif parent_dir == "dns_2_non_english":
65 | speech_type = "Non_english"
66 | elif parent_dir == "dns_2_emotion":
67 | speech_type = "Emotion"
68 | elif parent_dir == "dns_2_singing":
69 | speech_type = "Singing"
70 | else:
71 | raise NotImplementedError(f"Not supported dir: {parent_dir}")
72 |
73 | # Find the corresponding clean speech using "parent_dir" and "file_id"
74 | file_id = noisy_filename.split("_")[-1]
75 | if parent_dir in ("dns_2_emotion", "dns_2_singing"):
76 | # e.g., synthetic_emotion_1792_snr19_tl-35_fileid_19 => synthetic_emotion_clean_fileid_19
77 | clean_filename = f"synthetic_{speech_type.lower()}_clean_fileid_{file_id}"
78 | elif parent_dir == "dns_2_non_english":
79 | # e.g., synthetic_german_collection044_14_-04_CFQQgBvv2xQ_snr8_tl-21_fileid_121 => synthetic_clean_fileid_121
80 | clean_filename = f"synthetic_clean_fileid_{file_id}"
81 | else:
82 | # e.g., clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300 => clean_fileid_300
83 | if parent_dir == "with_reverb":
84 | reverb_remark = "with_reverb"
85 | clean_filename = f"clean_fileid_{file_id}"
86 |
87 | clean_file_path = noisy_file_path.replace(f"noisy/{noisy_filename}", f"clean/{clean_filename}")
88 |
89 | noisy = load_wav(os.path.abspath(os.path.expanduser(noisy_file_path)), sr=self.sr)
90 | clean = load_wav(os.path.abspath(os.path.expanduser(clean_file_path)), sr=self.sr)
91 |
92 | return noisy, clean, reverb_remark + noisy_filename, speech_type
93 |
--------------------------------------------------------------------------------
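The `__getitem__` above maps each noisy filename back to its clean counterpart purely from the trailing `fileid` number and the parent directory. A compact sketch of that mapping follows; the function name `clean_name_for` is hypothetical.

```python
def clean_name_for(noisy_filename: str, parent_dir: str) -> str:
    # e.g. "clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300" -> "clean_fileid_300"
    file_id = noisy_filename.split("_")[-1]
    if parent_dir in ("dns_2_emotion", "dns_2_singing"):
        kind = parent_dir.replace("dns_2_", "")  # "emotion" or "singing"
        return f"synthetic_{kind}_clean_fileid_{file_id}"
    if parent_dir == "dns_2_non_english":
        return f"synthetic_clean_fileid_{file_id}"
    # with_reverb / no_reverb
    return f"clean_fileid_{file_id}"

assert clean_name_for("clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300", "no_reverb") == "clean_fileid_300"
```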
/speech_enhance/inter_subnet/inferencer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/inter_subnet/inferencer/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/inter_subnet/inferencer/inferencer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 |
4 | from audio_zen.acoustics.feature import mag_phase
5 | from audio_zen.acoustics.mask import decompress_cIRM
6 | from audio_zen.inferencer.base_inferencer import BaseInferencer
7 | import soundfile as sf
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | # for log
12 | from utils.logger import log
13 | print=log
14 |
15 | def cumulative_norm(input):
16 | eps = 1e-10
17 | device = input.device
18 | data_type = input.dtype
19 | n_dim = input.ndim
20 |
21 | assert n_dim in (3, 4)
22 |
23 | if n_dim == 3:
24 | n_channels = 1
25 | batch_size, n_freqs, n_frames = input.size()
26 | else:
27 | batch_size, n_channels, n_freqs, n_frames = input.size()
28 | input = input.reshape(batch_size * n_channels, n_freqs, n_frames)
29 |
30 | step_sum = torch.sum(input, dim=1) # [B, T]
31 | step_pow_sum = torch.sum(torch.square(input), dim=1)
32 |
33 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T]
34 | cumulative_pow_sum = torch.cumsum(step_pow_sum, dim=-1) # [B, T]
35 |
36 | entry_count = torch.arange(n_freqs, n_freqs * n_frames + 1, n_freqs, dtype=data_type, device=device)
37 | entry_count = entry_count.reshape(1, n_frames) # [1, T]
38 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T]
39 |
40 | cum_mean = cumulative_sum / entry_count # B, T
41 | cum_var = (cumulative_pow_sum - 2 * cum_mean * cumulative_sum) / entry_count + cum_mean.pow(2) # B, T
42 | cum_std = (cum_var + eps).sqrt() # B, T
43 |
44 | cum_mean = cum_mean.reshape(batch_size * n_channels, 1, n_frames)
45 | cum_std = cum_std.reshape(batch_size * n_channels, 1, n_frames)
46 |
47 | x = (input - cum_mean) / cum_std
48 |
49 | if n_dim == 4:
50 | x = x.reshape(batch_size, n_channels, n_freqs, n_frames)
51 |
52 | return x
53 |
54 |
55 | class Inferencer(BaseInferencer):
56 | def __init__(self, config, checkpoint_path, output_dir):
57 | super().__init__(config, checkpoint_path, output_dir)
58 |
59 | @torch.no_grad()
60 | def mag(self, noisy, inference_args):
61 | noisy_complex = self.torch_stft(noisy)
62 | noisy_mag, noisy_phase = mag_phase(noisy_complex) # [B, F, T] => [B, 1, F, T]
63 |
64 | enhanced_mag = self.model(noisy_mag.unsqueeze(1)).squeeze(1)
65 |
66 | enhanced = self.torch_istft((enhanced_mag, noisy_phase), length=noisy.size(-1), use_mag_phase=True)
67 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
68 |
69 | return enhanced
70 |
71 | @torch.no_grad()
72 | def scaled_mask(self, noisy, inference_args):
73 | noisy_complex = self.torch_stft(noisy)
74 | noisy_mag, noisy_phase = mag_phase(noisy_complex)
75 |
76 | # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2]
77 | noisy_mag = noisy_mag.unsqueeze(1)
78 | scaled_mask = self.model(noisy_mag)
79 | scaled_mask = scaled_mask.permute(0, 2, 3, 1)
80 |
81 | enhanced_complex = noisy_complex * scaled_mask
82 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1), use_mag_phase=False)
83 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
84 |
85 | return enhanced
86 |
87 | @torch.no_grad()
88 | def sub_band_crm_mask(self, noisy, inference_args):
89 | pad_mode = inference_args["pad_mode"]
90 | n_neighbor = inference_args["n_neighbor"]
91 |
92 | noisy = noisy.cpu().numpy().reshape(-1)
93 | noisy_D = self.librosa_stft(noisy)
94 |
95 | noisy_real = torch.tensor(noisy_D.real, device=self.device)
96 | noisy_imag = torch.tensor(noisy_D.imag, device=self.device)
97 | noisy_mag = torch.sqrt(torch.square(noisy_real) + torch.square(noisy_imag)) # [F, T]
98 | n_freqs, n_frames = noisy_mag.size()
99 |
100 | noisy_mag = noisy_mag.reshape(1, 1, n_freqs, n_frames)
101 | noisy_mag_padded = self._unfold(noisy_mag, pad_mode, n_neighbor) # [B, N, C, F_s, T] <=> [1, 257, 1, 31, T]
102 | noisy_mag_padded = noisy_mag_padded.squeeze(0).squeeze(1) # [257, 31, 200] <=> [B, F_s, T]
103 |
104 | pred_crm = self.model(noisy_mag_padded).detach() # [B, 2, T] <=> [F, 2, T]
105 | pred_crm = pred_crm.permute(0, 2, 1).contiguous() # [B, T, 2]
106 |
107 | lim = 9.99
108 | pred_crm = lim * (pred_crm >= lim) - lim * (pred_crm <= -lim) + pred_crm * (torch.abs(pred_crm) < lim)
109 | pred_crm = -10 * torch.log((10 - pred_crm) / (10 + pred_crm))
110 |
111 | enhanced_real = pred_crm[:, :, 0] * noisy_real - pred_crm[:, :, 1] * noisy_imag
112 | enhanced_imag = pred_crm[:, :, 1] * noisy_real + pred_crm[:, :, 0] * noisy_imag
113 |
114 | enhanced_real = enhanced_real.cpu().numpy()
115 | enhanced_imag = enhanced_imag.cpu().numpy()
116 | enhanced = self.librosa_istft(enhanced_real + 1j * enhanced_imag, length=len(noisy))
117 | return enhanced
118 |
119 | @torch.no_grad()
120 | def full_band_crm_mask(self, noisy, inference_args):
121 | noisy_complex = self.torch_stft(noisy)
122 | noisy_mag, _ = mag_phase(noisy_complex)
123 |
124 | noisy_mag = noisy_mag.unsqueeze(1)
125 | t1 = time.time()
126 | pred_crm = self.model(noisy_mag)
127 | t2 = time.time()
128 | pred_crm = pred_crm.permute(0, 2, 3, 1)
129 |
130 | pred_crm = decompress_cIRM(pred_crm)
131 | enhanced_real = pred_crm[..., 0] * noisy_complex.real - pred_crm[..., 1] * noisy_complex.imag
132 | enhanced_imag = pred_crm[..., 1] * noisy_complex.real + pred_crm[..., 0] * noisy_complex.imag
133 | enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1)
134 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1))
135 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
136 |
137 | # Compute the real-time factor (RTF)
138 | rtf = (t2 - t1) / (len(enhanced) * 1.0 / self.acoustic_config["sr"])
139 | print(f"model rtf: {rtf}")
140 |
141 | return enhanced
142 |
143 |
144 | @torch.no_grad()
145 | def overlapped_chunk(self, noisy, inference_args):
146 | sr = self.acoustic_config["sr"]
147 |
148 | noisy = noisy.squeeze(0)
149 |
150 | num_mics = 8
151 | chunk_length = sr * inference_args["chunk_length"]
152 | chunk_hop_length = chunk_length // 2
153 | num_chunks = int(noisy.shape[-1] / chunk_hop_length) + 1
154 |
155 | win = torch.hann_window(chunk_length, device=noisy.device)
156 |
157 | prev = None
158 | enhanced = None
159 | # Prepend a simulated silent segment so the model is not fed speech from the very first sample, which it handles poorly
160 | for chunk_idx in range(num_chunks):
161 | if chunk_idx == 0:
162 | pad = torch.zeros((num_mics, 256), device=noisy.device)
163 |
164 | chunk_start_position = chunk_idx * chunk_hop_length
165 | chunk_end_position = chunk_start_position + chunk_length
166 |
167 | # concat([(8, 256), (..., ... + chunk_length)])
168 | noisy_chunk = torch.cat((pad, noisy[:, chunk_start_position:chunk_end_position]), dim=1)
169 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0))
170 | enhanced_chunk = torch.squeeze(enhanced_chunk)
171 | enhanced_chunk = enhanced_chunk[256:]
172 |
173 | # Save the prior half of the chunk
174 | cur = enhanced_chunk[:chunk_length // 2]
175 |
176 | # Only for the 1st chunk: its prior half has no overlap
177 | prev = enhanced_chunk[chunk_length // 2:] * win[chunk_length // 2:]
178 | else:
179 | # use the previous noisy data as the pad
180 | pad = noisy[:, (chunk_idx * chunk_hop_length - 256):(chunk_idx * chunk_hop_length)]
181 |
182 | chunk_start_position = chunk_idx * chunk_hop_length
183 | chunk_end_position = chunk_start_position + chunk_length
184 |
185 | noisy_chunk = torch.cat((pad, noisy[:8, chunk_start_position:chunk_end_position]), dim=1)
186 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0))
187 | enhanced_chunk = torch.squeeze(enhanced_chunk)
188 | enhanced_chunk = enhanced_chunk[256:]
189 |
190 | # Use the window function to smooth the seam where chunks are joined
191 | enhanced_chunk = enhanced_chunk * win[:len(enhanced_chunk)]
192 |
193 | tmp = enhanced_chunk[:chunk_length // 2]
194 | cur = tmp[:min(len(tmp), len(prev))] + prev[:min(len(tmp), len(prev))]
195 | prev = enhanced_chunk[chunk_length // 2:]
196 |
197 | if enhanced is None:
198 | enhanced = cur
199 | else:
200 | enhanced = torch.cat((enhanced, cur), dim=0)
201 |
202 | enhanced = enhanced[:noisy.shape[1]]
203 | return enhanced.detach().squeeze(0).cpu().numpy()
204 |
205 | @torch.no_grad()
206 | def time_domain(self, noisy, inference_args):
207 | noisy = noisy.to(self.device)
208 | enhanced = self.model(noisy)
209 | return enhanced.detach().squeeze().cpu().numpy()
210 |
211 |
212 | if __name__ == '__main__':
213 | a = torch.rand(10, 2, 161, 200)
214 | print(cumulative_norm(a).shape)
215 |
--------------------------------------------------------------------------------
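In `full_band_crm_mask` above, the decompressed cIRM is applied to the noisy STFT as a complex multiplication. A hedged, stand-alone sketch of that arithmetic with dummy tensors (shapes are for illustration only):

```python
import torch

noisy_real = torch.randn(1, 257, 100)   # real part of the noisy STFT
noisy_imag = torch.randn(1, 257, 100)   # imaginary part of the noisy STFT
pred_crm = torch.randn(1, 257, 100, 2)  # decompressed complex ratio mask [..., (real, imag)]

# (Mr + j*Mi) * (Yr + j*Yi) = (Mr*Yr - Mi*Yi) + j*(Mi*Yr + Mr*Yi)
enhanced_real = pred_crm[..., 0] * noisy_real - pred_crm[..., 1] * noisy_imag
enhanced_imag = pred_crm[..., 1] * noisy_real + pred_crm[..., 0] * noisy_imag
enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1)  # [B, F, T, 2]
```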
/speech_enhance/inter_subnet/model/Inter_SubNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional
3 |
4 | from audio_zen.acoustics.feature import drop_band
5 | from audio_zen.model.base_model import BaseModel
6 | from audio_zen.model.module.sequence_model import stacked_SIL_blocks_SequenceModel
7 |
8 | # for log
9 | from utils.logger import log
10 |
11 | print = log
12 |
13 |
14 | class Inter_SubNet(BaseModel):
15 | def __init__(self,
16 | num_freqs,
17 | look_ahead,
18 | sequence_model,
19 | sb_num_neighbors,
20 | sb_output_activate_function,
21 | sb_model_hidden_size,
22 | norm_type="offline_laplace_norm",
23 | num_groups_in_drop_band=2,
24 | weight_init=True,
25 | sbinter_middle_hidden_times=0.66,
26 | ):
27 | """
28 | Inter-SubNet model (cIRM mask)
29 |
30 | Args:
31 | num_freqs: Frequency dim of the input
32 | sb_num_neighbors: Number of the neighbor frequencies in each side
33 | sequence_model: Choose one sequence model as the basic model (GRU, LSTM)
34 | """
35 | super().__init__()
36 | assert sequence_model in ("GRU", "LSTM"), f"{self.__class__.__name__} only supports GRU and LSTM."
37 |
38 | subband_input_size = (sb_num_neighbors * 2 + 1)
39 | self.sb_model = stacked_SIL_blocks_SequenceModel(
40 | input_size=subband_input_size,
41 | output_size=2,
42 | hidden_size=sb_model_hidden_size,
43 | num_layers=2,
44 | bidirectional=False,
45 | sequence_model=sequence_model,
46 | output_activate_function=sb_output_activate_function,
47 | middle_tac_hidden_times=sbinter_middle_hidden_times
48 | )
49 |
50 | self.sb_num_neighbors = sb_num_neighbors
51 | # self.fb_num_neighbors = fb_num_neighbors
52 | self.look_ahead = look_ahead
53 | self.norm = self.norm_wrapper(norm_type)
54 | self.num_groups_in_drop_band = num_groups_in_drop_band
55 |
56 | if weight_init:
57 | self.apply(self.weight_init)
58 |
59 | def forward(self, noisy_mag):
60 | """
61 | Args:
62 | noisy_mag: noisy magnitude spectrogram
63 |
64 | Returns:
65 | The real part and imag part of the enhanced spectrogram
66 |
67 | Shapes:
68 | noisy_mag: [B, 1, F, T]
69 | return: [B, 2, F, T]
70 | """
71 | assert noisy_mag.dim() == 4
72 | noisy_mag = functional.pad(noisy_mag, [0, self.look_ahead]) # Pad the look ahead
73 | batch_size, num_channels, num_freqs, num_frames = noisy_mag.size()
74 | assert num_channels == 1, f"{self.__class__.__name__} takes the mag feature as inputs."
75 |
76 | # Unfold noisy input, [B, N=F, C, F_s, T]
77 | noisy_mag_unfolded = self.unfold(noisy_mag, num_neighbor=self.sb_num_neighbors)
78 | noisy_mag_unfolded = noisy_mag_unfolded.reshape(batch_size, num_freqs, self.sb_num_neighbors * 2 + 1,
79 | num_frames)
80 |
81 | sb_input = self.norm(noisy_mag_unfolded)
82 |
83 | if batch_size > 1:
84 | sb_input = drop_band(sb_input.permute(0, 2, 1, 3),
85 | num_groups=self.num_groups_in_drop_band) # [B, F_s, F//num_groups, T]
86 | num_freqs = sb_input.shape[2]
87 | sb_input = sb_input.permute(0, 2, 1, 3) # [B, F//num_groups, F_s, T]
88 |
89 | # [B, F//num_groups, F_s, T] => [B * F, 2, T] => [B, F, 2, T]
90 | sb_mask = self.sb_model(sb_input)
91 | sb_mask = sb_mask.reshape(batch_size, num_freqs, 2, num_frames).permute(0, 2, 1, 3).contiguous()
92 |
93 | output = sb_mask[:, :, :, self.look_ahead:]
94 | return output
95 |
96 |
97 | if __name__ == "__main__":
98 | model = Inter_SubNet(
99 | num_freqs=257,
100 | look_ahead=2,
101 | sequence_model="LSTM",
102 | sb_num_neighbors=15,
103 | sb_output_activate_function=False,
104 | sb_model_hidden_size=384,
105 | weight_init=False,
106 | norm_type="offline_laplace_norm",
107 | num_groups_in_drop_band=2,
108 | sbinter_middle_hidden_times=0.8
109 | )
110 |
--------------------------------------------------------------------------------
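A quick smoke test of the model above, assuming `audio_zen` and its `stacked_SIL_blocks_SequenceModel` are importable. Per the forward docstring, a `[B, 1, F, T]` magnitude input should yield a `[B, 2, F, T]` cIRM; with batch size 1, `drop_band` is skipped, so F is preserved.

```python
import torch

model = Inter_SubNet(
    num_freqs=257,
    look_ahead=2,
    sequence_model="LSTM",
    sb_num_neighbors=15,
    sb_output_activate_function=False,
    sb_model_hidden_size=384,
    weight_init=False,
    norm_type="offline_laplace_norm",
    num_groups_in_drop_band=2,
    sbinter_middle_hidden_times=0.8,
)
print(sum(p.numel() for p in model.parameters()))  # rough parameter count

with torch.no_grad():
    noisy_mag = torch.rand(1, 1, 257, 63)  # [B, 1, F, T]
    mask = model(noisy_mag)
    print(mask.shape)  # expected per the docstring: torch.Size([1, 2, 257, 63])
```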
/speech_enhance/inter_subnet/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/inter_subnet/model/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/inter_subnet/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/inter_subnet/trainer/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/subband_model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/subband_model/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/subband_model/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/subband_model/dataset/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/subband_model/dataset/dataset_inference.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import librosa
4 | import numpy as np
5 |
6 | from audio_zen.dataset.base_dataset import BaseDataset
7 | from audio_zen.utils import basename
8 |
9 |
10 | class Dataset(BaseDataset):
11 | def __init__(self,
12 | dataset_dir_list,
13 | sr,
14 | ):
15 | """
16 | Args:
17 | noisy_dataset_dir_list (str or list): noisy dir or noisy dir list
18 | """
19 | super().__init__()
20 | assert isinstance(dataset_dir_list, list)
21 | self.sr = sr
22 |
23 | noisy_file_path_list = []
24 | for dataset_dir in dataset_dir_list:
25 | dataset_dir = Path(dataset_dir).expanduser().absolute()
26 | noisy_file_path_list += librosa.util.find_files(dataset_dir.as_posix()) # Sorted
27 |
28 | self.noisy_file_path_list = noisy_file_path_list
29 | self.length = len(self.noisy_file_path_list)
30 |
31 | def __len__(self):
32 | return self.length
33 |
34 | def __getitem__(self, item):
35 | noisy_file_path = self.noisy_file_path_list[item]
36 | noisy_y = librosa.load(noisy_file_path, sr=self.sr)[0]
37 | noisy_y = noisy_y.astype(np.float32)
38 |
39 | return noisy_y, basename(noisy_file_path)[0]
40 |
--------------------------------------------------------------------------------
/speech_enhance/subband_model/dataset/dataset_train.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | from audio_zen.acoustics.feature import norm_amplitude, tailor_dB_FS, is_clipped, load_wav, subsample
5 | from audio_zen.dataset.base_dataset import BaseDataset
6 | from audio_zen.utils import expand_path
7 | from joblib import Parallel, delayed
8 | from scipy import signal
9 | from tqdm import tqdm
10 |
11 |
12 | class Dataset(BaseDataset):
13 | def __init__(self,
14 | clean_dataset,
15 | clean_dataset_limit,
16 | clean_dataset_offset,
17 | noise_dataset,
18 | noise_dataset_limit,
19 | noise_dataset_offset,
20 | rir_dataset,
21 | rir_dataset_limit,
22 | rir_dataset_offset,
23 | snr_range,
24 | reverb_proportion,
25 | silence_length,
26 | target_dB_FS,
27 | target_dB_FS_floating_value,
28 | sub_sample_length,
29 | sr,
30 | pre_load_clean_dataset,
31 | pre_load_noise,
32 | pre_load_rir,
33 | num_workers
34 | ):
35 | """
36 | Dynamic mixing for training
37 |
38 | Args:
39 | clean_dataset_limit:
40 | clean_dataset_offset:
41 | noise_dataset_limit:
42 | noise_dataset_offset:
43 | rir_dataset:
44 | rir_dataset_limit:
45 | rir_dataset_offset:
46 | snr_range:
47 | reverb_proportion:
48 | clean_dataset: scp file
49 | noise_dataset: scp file
50 | sub_sample_length:
51 | sr:
52 | """
53 | super().__init__()
54 | # acoustics args
55 | self.sr = sr
56 |
57 | # parallel args
58 | self.num_workers = num_workers
59 |
60 | clean_dataset_list = [line.rstrip('\n') for line in open(expand_path(clean_dataset), "r")]
61 | noise_dataset_list = [line.rstrip('\n') for line in open(expand_path(noise_dataset), "r")]
62 | rir_dataset_list = [line.rstrip('\n') for line in open(expand_path(rir_dataset), "r")]
63 |
64 | clean_dataset_list = self._offset_and_limit(clean_dataset_list, clean_dataset_offset, clean_dataset_limit)
65 | noise_dataset_list = self._offset_and_limit(noise_dataset_list, noise_dataset_offset, noise_dataset_limit)
66 | rir_dataset_list = self._offset_and_limit(rir_dataset_list, rir_dataset_offset, rir_dataset_limit)
67 |
68 | if pre_load_clean_dataset:
69 | clean_dataset_list = self._preload_dataset(clean_dataset_list, remark="Clean Dataset")
70 |
71 | if pre_load_noise:
72 | noise_dataset_list = self._preload_dataset(noise_dataset_list, remark="Noise Dataset")
73 |
74 | if pre_load_rir:
75 | rir_dataset_list = self._preload_dataset(rir_dataset_list, remark="RIR Dataset")
76 |
77 | self.clean_dataset_list = clean_dataset_list
78 | self.noise_dataset_list = noise_dataset_list
79 | self.rir_dataset_list = rir_dataset_list
80 |
81 | snr_list = self._parse_snr_range(snr_range)
82 | self.snr_list = snr_list
83 |
84 | assert 0 <= reverb_proportion <= 1, "reverberation proportion should be in [0, 1]"
85 | self.reverb_proportion = reverb_proportion
86 | self.silence_length = silence_length
87 | self.target_dB_FS = target_dB_FS
88 | self.target_dB_FS_floating_value = target_dB_FS_floating_value
89 | self.sub_sample_length = sub_sample_length
90 |
91 | self.length = len(self.clean_dataset_list)
92 |
93 | def __len__(self):
94 | return self.length
95 |
96 | def _preload_dataset(self, file_path_list, remark=""):
97 | waveform_list = Parallel(n_jobs=self.num_workers)(
98 | delayed(load_wav)(f_path) for f_path in tqdm(file_path_list, desc=remark)
99 | )
100 | return list(zip(file_path_list, waveform_list))
101 |
102 | @staticmethod
103 | def _random_select_from(dataset_list):
104 | return random.choice(dataset_list)
105 |
106 | def _select_noise_y(self, target_length):
107 | noise_y = np.zeros(0, dtype=np.float32)
108 | silence = np.zeros(int(self.sr * self.silence_length), dtype=np.float32)
109 | remaining_length = target_length
110 |
111 | while remaining_length > 0:
112 | noise_file = self._random_select_from(self.noise_dataset_list)
113 | noise_new_added = load_wav(noise_file, sr=self.sr)
114 | noise_y = np.append(noise_y, noise_new_added)
115 | remaining_length -= len(noise_new_added)
116 |
117 | # If more noise is still needed, insert a short silence segment before the next clip
118 | if remaining_length > 0:
119 | silence_len = min(remaining_length, len(silence))
120 | noise_y = np.append(noise_y, silence[:silence_len])
121 | remaining_length -= silence_len
122 |
123 | if len(noise_y) > target_length:
124 | idx_start = np.random.randint(len(noise_y) - target_length)
125 | noise_y = noise_y[idx_start:idx_start + target_length]
126 |
127 | return noise_y
128 |
129 | @staticmethod
130 | def snr_mix(clean_y, noise_y, snr, target_dB_FS, target_dB_FS_floating_value, rir=None, eps=1e-6):
131 | """
132 | Mix noise with clean speech; when the rir argument is not None, apply reverberation to the clean speech first
133 | 
134 | Args:
135 | clean_y: clean speech
136 | noise_y: noise
137 | snr (int): signal-to-noise ratio
138 | target_dB_FS (int):
139 | target_dB_FS_floating_value (int):
140 | rir: room impulse response, None or np.array
141 | eps: eps
142 | 
143 | Returns:
144 | (noisy_y, clean_y)
145 | """
146 | if rir is not None:
147 | if rir.ndim > 1:
148 | rir_idx = np.random.randint(0, rir.shape[0])
149 | rir = rir[rir_idx, :]
150 |
151 | clean_y = signal.fftconvolve(clean_y, rir)[:len(clean_y)]
152 |
153 | clean_y, _ = norm_amplitude(clean_y)
154 | clean_y, _, _ = tailor_dB_FS(clean_y, target_dB_FS)
155 | clean_rms = (clean_y ** 2).mean() ** 0.5
156 |
157 | noise_y, _ = norm_amplitude(noise_y)
158 | noise_y, _, _ = tailor_dB_FS(noise_y, target_dB_FS)
159 | noise_rms = (noise_y ** 2).mean() ** 0.5
160 |
161 | snr_scalar = clean_rms / (10 ** (snr / 20)) / (noise_rms + eps)
162 | noise_y *= snr_scalar
163 | noisy_y = clean_y + noise_y
164 |
165 | # Randomly select a target level in dBFS (by default between -15 dBFS and -35 dBFS) and normalize the noisy speech to it
166 | noisy_target_dB_FS = np.random.randint(
167 | target_dB_FS - target_dB_FS_floating_value,
168 | target_dB_FS + target_dB_FS_floating_value
169 | )
170 |
171 | # Rescale the audio using the scalar derived from the noisy speech's RMS
172 | noisy_y, _, noisy_scalar = tailor_dB_FS(noisy_y, noisy_target_dB_FS)
173 | clean_y *= noisy_scalar
174 |
175 | # The synthesized noisy speech may occasionally clip (very rarely),
176 | # so slightly adjust noisy, clean_y and noise_y in that case
177 | if is_clipped(noisy_y):
178 | noisy_y_scalar = np.max(np.abs(noisy_y)) / (0.99 - eps) # scale so the peak stays just below 1.0
179 | noisy_y = noisy_y / noisy_y_scalar
180 | clean_y = clean_y / noisy_y_scalar
181 |
182 | return noisy_y, clean_y
183 |
184 | def __getitem__(self, item):
185 | clean_file = self.clean_dataset_list[item]
186 | clean_y = load_wav(clean_file, sr=self.sr)
187 | clean_y = subsample(clean_y, sub_sample_length=int(self.sub_sample_length * self.sr))
188 |
189 | noise_y = self._select_noise_y(target_length=len(clean_y))
190 | assert len(clean_y) == len(noise_y), f"Inequality: {len(clean_y)} {len(noise_y)}"
191 |
192 | snr = self._random_select_from(self.snr_list)
193 | use_reverb = bool(np.random.random(1) < self.reverb_proportion)
194 |
195 | noisy_y, clean_y = self.snr_mix(
196 | clean_y=clean_y,
197 | noise_y=noise_y,
198 | snr=snr,
199 | target_dB_FS=self.target_dB_FS,
200 | target_dB_FS_floating_value=self.target_dB_FS_floating_value,
201 | rir=load_wav(self._random_select_from(self.rir_dataset_list), sr=self.sr) if use_reverb else None
202 | )
203 |
204 | noisy_y = noisy_y.astype(np.float32)
205 | clean_y = clean_y.astype(np.float32)
206 |
207 | return noisy_y, clean_y
208 |
--------------------------------------------------------------------------------
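`_select_noise_y` above keeps appending randomly chosen noise clips, separated by short silences, until the target length is covered, then crops a random window. Below is a standalone sketch of the same idea with synthetic clips; `build_noise_track` is an illustrative name, not a repository function.

```python
import numpy as np

def build_noise_track(noise_clips, target_length, sr=16000, silence_length=0.2):
    # Concatenate random clips separated by short silences until long enough,
    # then crop a random window of exactly target_length samples.
    silence = np.zeros(int(sr * silence_length), dtype=np.float32)
    track = np.zeros(0, dtype=np.float32)
    remaining = target_length
    while remaining > 0:
        clip = noise_clips[np.random.randint(len(noise_clips))]
        track = np.append(track, clip)
        remaining -= len(clip)
        if remaining > 0:
            pad = silence[:min(remaining, len(silence))]
            track = np.append(track, pad)
            remaining -= len(pad)
    if len(track) > target_length:
        start = np.random.randint(len(track) - target_length)
        track = track[start:start + target_length]
    return track

clips = [np.random.randn(4000).astype(np.float32) for _ in range(3)]
print(len(build_noise_track(clips, target_length=16000)))  # 16000
```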
/speech_enhance/subband_model/dataset/dataset_validation.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import librosa
5 |
6 | from audio_zen.dataset.base_dataset import BaseDataset
7 | from audio_zen.acoustics.utils import load_wav
8 | from audio_zen.utils import basename
9 |
10 |
11 | class Dataset(BaseDataset):
12 | def __init__(
13 | self,
14 | dataset_dir_list,
15 | sr,
16 | ):
17 | """
18 | Construct DNS validation set
19 |
20 | synthetic/
21 | with_reverb/
22 | noisy/
23 | clean_y/
24 | no_reverb/
25 | noisy/
26 | clean_y/
27 | """
28 | super(Dataset, self).__init__()
29 | noisy_files_list = []
30 |
31 | for dataset_dir in dataset_dir_list:
32 | dataset_dir = Path(dataset_dir).expanduser().absolute()
33 | noisy_files_list += librosa.util.find_files((dataset_dir / "noisy").as_posix())
34 |
35 | self.length = len(noisy_files_list)
36 | self.noisy_files_list = noisy_files_list
37 | self.sr = sr
38 |
39 | def __len__(self):
40 | return self.length
41 |
42 | def __getitem__(self, item):
43 | """
44 | Use the absolute path of the noisy speech to find the corresponding clean speech.
45 | 
46 | Notes
47 | The with_reverb and no_reverb dirs contain files with identical names.
48 | If we only used `basename`, they would overwrite (cover) each other in the visualization.
49 |
50 | Returns:
51 | noisy: [waveform...], clean: [waveform...], type: [reverb|no_reverb] + name
52 | """
53 | noisy_file_path = self.noisy_files_list[item]
54 | parent_dir = Path(noisy_file_path).parents[1].name
55 | noisy_filename, _ = basename(noisy_file_path)
56 |
57 | reverb_remark = "" # When the speech comes from reverb_dir, insert "with_reverb" before the filename
58 |
59 | # speech_type must match the one used in the validation stage; it is used to distinguish the later visualizations
60 | if parent_dir == "with_reverb":
61 | speech_type = "With_reverb"
62 | elif parent_dir == "no_reverb":
63 | speech_type = "No_reverb"
64 | elif parent_dir == "dns_2_non_english":
65 | speech_type = "Non_english"
66 | elif parent_dir == "dns_2_emotion":
67 | speech_type = "Emotion"
68 | elif parent_dir == "dns_2_singing":
69 | speech_type = "Singing"
70 | else:
71 | raise NotImplementedError(f"Not supported dir: {parent_dir}")
72 |
73 | # Find the corresponding clean speech using "parent_dir" and "file_id"
74 | file_id = noisy_filename.split("_")[-1]
75 | if parent_dir in ("dns_2_emotion", "dns_2_singing"):
76 | # e.g., synthetic_emotion_1792_snr19_tl-35_fileid_19 => synthetic_emotion_clean_fileid_19
77 | clean_filename = f"synthetic_{speech_type.lower()}_clean_fileid_{file_id}"
78 | elif parent_dir == "dns_2_non_english":
79 | # e.g., synthetic_german_collection044_14_-04_CFQQgBvv2xQ_snr8_tl-21_fileid_121 => synthetic_clean_fileid_121
80 | clean_filename = f"synthetic_clean_fileid_{file_id}"
81 | else:
82 | # e.g., clnsp587_Unt_WsHPhfA_snr8_tl-30_fileid_300 => clean_fileid_300
83 | if parent_dir == "with_reverb":
84 | reverb_remark = "with_reverb"
85 | clean_filename = f"clean_fileid_{file_id}"
86 |
87 | clean_file_path = noisy_file_path.replace(f"noisy/{noisy_filename}", f"clean/{clean_filename}")
88 |
89 | noisy = load_wav(os.path.abspath(os.path.expanduser(noisy_file_path)), sr=self.sr)
90 | clean = load_wav(os.path.abspath(os.path.expanduser(clean_file_path)), sr=self.sr)
91 |
92 | return noisy, clean, reverb_remark + noisy_filename, speech_type
93 |
--------------------------------------------------------------------------------
/speech_enhance/subband_model/inferencer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/subband_model/inferencer/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/subband_model/inferencer/inferencer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 |
4 | from audio_zen.acoustics.feature import mag_phase
5 | from audio_zen.acoustics.mask import decompress_cIRM
6 | from audio_zen.inferencer.base_inferencer import BaseInferencer
7 | import soundfile as sf
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | # for log
12 | from utils.logger import log
13 | print=log
14 |
15 | def cumulative_norm(input):
16 | eps = 1e-10
17 | device = input.device
18 | data_type = input.dtype
19 | n_dim = input.ndim
20 |
21 | assert n_dim in (3, 4)
22 |
23 | if n_dim == 3:
24 | n_channels = 1
25 | batch_size, n_freqs, n_frames = input.size()
26 | else:
27 | batch_size, n_channels, n_freqs, n_frames = input.size()
28 | input = input.reshape(batch_size * n_channels, n_freqs, n_frames)
29 |
30 | step_sum = torch.sum(input, dim=1) # [B, T]
31 | step_pow_sum = torch.sum(torch.square(input), dim=1)
32 |
33 | cumulative_sum = torch.cumsum(step_sum, dim=-1) # [B, T]
34 | cumulative_pow_sum = torch.cumsum(step_pow_sum, dim=-1) # [B, T]
35 |
36 | entry_count = torch.arange(n_freqs, n_freqs * n_frames + 1, n_freqs, dtype=data_type, device=device)
37 | entry_count = entry_count.reshape(1, n_frames) # [1, T]
38 | entry_count = entry_count.expand_as(cumulative_sum) # [1, T] => [B, T]
39 |
40 | cum_mean = cumulative_sum / entry_count # B, T
41 | cum_var = (cumulative_pow_sum - 2 * cum_mean * cumulative_sum) / entry_count + cum_mean.pow(2) # B, T
42 | cum_std = (cum_var + eps).sqrt() # B, T
43 |
44 | cum_mean = cum_mean.reshape(batch_size * n_channels, 1, n_frames)
45 | cum_std = cum_std.reshape(batch_size * n_channels, 1, n_frames)
46 |
47 | x = (input - cum_mean) / cum_std
48 |
49 | if n_dim == 4:
50 | x = x.reshape(batch_size, n_channels, n_freqs, n_frames)
51 |
52 | return x
53 |
54 |
55 | class Inferencer(BaseInferencer):
56 | def __init__(self, config, checkpoint_path, output_dir):
57 | super().__init__(config, checkpoint_path, output_dir)
58 |
59 | @torch.no_grad()
60 | def mag(self, noisy, inference_args):
61 | noisy_complex = self.torch_stft(noisy)
62 | noisy_mag, noisy_phase = mag_phase(noisy_complex) # [B, F, T] => [B, 1, F, T]
63 |
64 | enhanced_mag = self.model(noisy_mag.unsqueeze(1)).squeeze(1)
65 |
66 | enhanced = self.torch_istft((enhanced_mag, noisy_phase), length=noisy.size(-1), use_mag_phase=True)
67 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
68 |
69 | return enhanced
70 |
71 | @torch.no_grad()
72 | def scaled_mask(self, noisy, inference_args):
73 | noisy_complex = self.torch_stft(noisy)
74 | noisy_mag, noisy_phase = mag_phase(noisy_complex)
75 |
76 | # [B, F, T] => [B, 1, F, T] => model => [B, 2, F, T] => [B, F, T, 2]
77 | noisy_mag = noisy_mag.unsqueeze(1)
78 | scaled_mask = self.model(noisy_mag)
79 | scaled_mask = scaled_mask.permute(0, 2, 3, 1)
80 |
81 | enhanced_complex = noisy_complex * scaled_mask
82 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1), use_mag_phase=False)
83 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
84 |
85 | return enhanced
86 |
87 | @torch.no_grad()
88 | def sub_band_crm_mask(self, noisy, inference_args):
89 | pad_mode = inference_args["pad_mode"]
90 | n_neighbor = inference_args["n_neighbor"]
91 |
92 | noisy = noisy.cpu().numpy().reshape(-1)
93 | noisy_D = self.librosa_stft(noisy)
94 |
95 | noisy_real = torch.tensor(noisy_D.real, device=self.device)
96 | noisy_imag = torch.tensor(noisy_D.imag, device=self.device)
97 | noisy_mag = torch.sqrt(torch.square(noisy_real) + torch.square(noisy_imag)) # [F, T]
98 | n_freqs, n_frames = noisy_mag.size()
99 |
100 | noisy_mag = noisy_mag.reshape(1, 1, n_freqs, n_frames)
101 | noisy_mag_padded = self._unfold(noisy_mag, pad_mode, n_neighbor) # [B, N, C, F_s, T] <=> [1, 257, 1, 31, T]
102 | noisy_mag_padded = noisy_mag_padded.squeeze(0).squeeze(1) # [257, 31, 200] <=> [B, F_s, T]
103 |
104 | pred_crm = self.model(noisy_mag_padded).detach() # [B, 2, T] <=> [F, 2, T]
105 | pred_crm = pred_crm.permute(0, 2, 1).contiguous() # [B, T, 2]
106 |
107 | lim = 9.99
108 | pred_crm = lim * (pred_crm >= lim) - lim * (pred_crm <= -lim) + pred_crm * (torch.abs(pred_crm) < lim)
109 | pred_crm = -10 * torch.log((10 - pred_crm) / (10 + pred_crm))
110 |
111 | enhanced_real = pred_crm[:, :, 0] * noisy_real - pred_crm[:, :, 1] * noisy_imag
112 | enhanced_imag = pred_crm[:, :, 1] * noisy_real + pred_crm[:, :, 0] * noisy_imag
113 |
114 | enhanced_real = enhanced_real.cpu().numpy()
115 | enhanced_imag = enhanced_imag.cpu().numpy()
116 | enhanced = self.librosa_istft(enhanced_real + 1j * enhanced_imag, length=len(noisy))
117 | return enhanced
118 |
119 | @torch.no_grad()
120 | def full_band_crm_mask(self, noisy, inference_args):
121 | noisy_complex = self.torch_stft(noisy)
122 | noisy_mag, _ = mag_phase(noisy_complex)
123 |
124 | noisy_mag = noisy_mag.unsqueeze(1)
125 | t1 = time.time()
126 | pred_crm = self.model(noisy_mag)
127 | t2 = time.time()
128 | pred_crm = pred_crm.permute(0, 2, 3, 1)
129 |
130 | pred_crm = decompress_cIRM(pred_crm)
131 | enhanced_real = pred_crm[..., 0] * noisy_complex.real - pred_crm[..., 1] * noisy_complex.imag
132 | enhanced_imag = pred_crm[..., 1] * noisy_complex.real + pred_crm[..., 0] * noisy_complex.imag
133 | enhanced_complex = torch.stack((enhanced_real, enhanced_imag), dim=-1)
134 | enhanced = self.torch_istft(enhanced_complex, length=noisy.size(-1))
135 | enhanced = enhanced.detach().squeeze(0).cpu().numpy()
136 |
137 | # Compute the real-time factor (RTF)
138 | rtf = (t2 - t1) / (len(enhanced) * 1.0 / self.acoustic_config["sr"])
139 | print(f"model rtf: {rtf}")
140 |
141 | return enhanced
142 |
143 |
144 | @torch.no_grad()
145 | def overlapped_chunk(self, noisy, inference_args):
146 | sr = self.acoustic_config["sr"]
147 |
148 | noisy = noisy.squeeze(0)
149 |
150 | num_mics = 8
151 | chunk_length = sr * inference_args["chunk_length"]
152 | chunk_hop_length = chunk_length // 2
153 | num_chunks = int(noisy.shape[-1] / chunk_hop_length) + 1
154 |
155 | win = torch.hann_window(chunk_length, device=noisy.device)
156 |
157 | prev = None
158 | enhanced = None
159 | # Prepend a simulated silent segment so the model is not fed speech from the very first sample, which it handles poorly
160 | for chunk_idx in range(num_chunks):
161 | if chunk_idx == 0:
162 | pad = torch.zeros((num_mics, 256), device=noisy.device)
163 |
164 | chunk_start_position = chunk_idx * chunk_hop_length
165 | chunk_end_position = chunk_start_position + chunk_length
166 |
167 | # concat([(8, 256), (..., ... + chunk_length)])
168 | noisy_chunk = torch.cat((pad, noisy[:, chunk_start_position:chunk_end_position]), dim=1)
169 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0))
170 | enhanced_chunk = torch.squeeze(enhanced_chunk)
171 | enhanced_chunk = enhanced_chunk[256:]
172 |
173 | # Save the prior half of the chunk
174 | cur = enhanced_chunk[:chunk_length // 2]
175 |
176 | # Only for the 1st chunk: its prior half has no overlap
177 | prev = enhanced_chunk[chunk_length // 2:] * win[chunk_length // 2:]
178 | else:
179 | # use the previous noisy data as the pad
180 | pad = noisy[:, (chunk_idx * chunk_hop_length - 256):(chunk_idx * chunk_hop_length)]
181 |
182 | chunk_start_position = chunk_idx * chunk_hop_length
183 | chunk_end_position = chunk_start_position + chunk_length
184 |
185 | noisy_chunk = torch.cat((pad, noisy[:8, chunk_start_position:chunk_end_position]), dim=1)
186 | enhanced_chunk = self.model(noisy_chunk.unsqueeze(0))
187 | enhanced_chunk = torch.squeeze(enhanced_chunk)
188 | enhanced_chunk = enhanced_chunk[256:]
189 |
190 | # Use the window function to smooth the seam where chunks are joined
191 | enhanced_chunk = enhanced_chunk * win[:len(enhanced_chunk)]
192 |
193 | tmp = enhanced_chunk[:chunk_length // 2]
194 | cur = tmp[:min(len(tmp), len(prev))] + prev[:min(len(tmp), len(prev))]
195 | prev = enhanced_chunk[chunk_length // 2:]
196 |
197 | if enhanced is None:
198 | enhanced = cur
199 | else:
200 | enhanced = torch.cat((enhanced, cur), dim=0)
201 |
202 | enhanced = enhanced[:noisy.shape[1]]
203 | return enhanced.detach().squeeze(0).cpu().numpy()
204 |
205 | @torch.no_grad()
206 | def time_domain(self, noisy, inference_args):
207 | noisy = noisy.to(self.device)
208 | enhanced = self.model(noisy)
209 | return enhanced.detach().squeeze().cpu().numpy()
210 |
211 |
212 | if __name__ == '__main__':
213 | a = torch.rand(10, 2, 161, 200)
214 | print(cumulative_norm(a).shape)
215 |
--------------------------------------------------------------------------------
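`overlapped_chunk` above processes 50%-overlapping chunks and smooths the seams with a Hann window. The sketch below shows the simpler textbook overlap-add variant of the same idea (windowing every chunk and dividing by the summed window), with an identity stand-in for the model rather than the exact half-window scheme used above:

```python
import torch

signal = torch.randn(48000)  # mono waveform
chunk_length = 16000
hop = chunk_length // 2      # 50% overlap
win = torch.hann_window(chunk_length)

out = torch.zeros_like(signal)
norm = torch.zeros_like(signal)
for start in range(0, len(signal) - chunk_length + 1, hop):
    chunk = signal[start:start + chunk_length]       # stand-in for the model output
    out[start:start + chunk_length] += chunk * win   # windowed overlap-add
    norm[start:start + chunk_length] += win
reconstructed = out / norm.clamp_min(1e-8)
```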
/speech_enhance/subband_model/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/subband_model/model/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/subband_model/model/subband_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional
3 |
4 | from audio_zen.acoustics.feature import drop_band
5 | from audio_zen.model.base_model import BaseModel
6 | from audio_zen.model.module.sequence_model import SequenceModel
7 |
8 | # for log
9 | from utils.logger import log
10 |
11 | print = log
12 |
13 |
14 | class Subband_model(BaseModel):
15 | def __init__(self,
16 | look_ahead,
17 | sequence_model,
18 | sb_num_neighbors,
19 | sb_output_activate_function,
20 | sb_model_hidden_size,
21 | norm_type="offline_laplace_norm",
22 | num_groups_in_drop_band=2,
23 | weight_init=True,
24 | ):
25 | """
26 | Subband model (cIRM mask)
27 |
28 | Args:
29 | sb_model_hidden_size: Hidden size of the subband sequence model
30 | sb_num_neighbors: Number of the neighbor frequencies in each side
31 | look_ahead: Number of future frames to use
32 | sequence_model: Choose one sequence model as the basic model (GRU, LSTM)
33 | """
34 | super().__init__()
35 | assert sequence_model in ("GRU", "LSTM"), f"{self.__class__.__name__} only supports GRU and LSTM."
36 |
37 | self.sb_model = SequenceModel(
38 | input_size=(sb_num_neighbors * 2 + 1),
39 | output_size=2,
40 | hidden_size=sb_model_hidden_size,
41 | num_layers=2,
42 | bidirectional=False,
43 | sequence_model=sequence_model,
44 | output_activate_function=sb_output_activate_function
45 | )
46 |
47 | self.sb_num_neighbors = sb_num_neighbors
48 | self.look_ahead = look_ahead
49 | self.norm = self.norm_wrapper(norm_type)
50 | self.num_groups_in_drop_band = num_groups_in_drop_band
51 |
52 | if weight_init:
53 | self.apply(self.weight_init)
54 |
55 | def forward(self, noisy_mag):
56 | """
57 | Args:
58 | noisy_mag: noisy magnitude spectrogram
59 |
60 | Returns:
61 | The real part and imag part of the enhanced spectrogram
62 |
63 | Shapes:
64 | noisy_mag: [B, 1, F, T]
65 | return: [B, 2, F, T]
66 | """
67 | assert noisy_mag.dim() == 4
68 | noisy_mag = functional.pad(noisy_mag, [0, self.look_ahead]) # Pad the look ahead
69 | batch_size, num_channels, num_freqs, num_frames = noisy_mag.size()
70 | assert num_channels == 1, f"{self.__class__.__name__} takes the mag feature as inputs."
71 |
72 | # Unfold noisy input, [B, N=F, C, F_s, T]
73 | noisy_mag_unfolded = self.unfold(noisy_mag, num_neighbor=self.sb_num_neighbors)
74 | noisy_mag_unfolded = noisy_mag_unfolded.reshape(batch_size, num_freqs, self.sb_num_neighbors * 2 + 1,
75 | num_frames)
76 |
77 | sb_input = self.norm(noisy_mag_unfolded)
78 |
79 | # Dropping frequency bands speeds up training without significant performance degradation. Details will be added to the paper later.
80 | if batch_size > 1:
81 | sb_input = drop_band(sb_input.permute(0, 2, 1, 3),
82 | num_groups=self.num_groups_in_drop_band) # [B, F_s, F//num_groups, T]
83 | num_freqs = sb_input.shape[2]
84 | sb_input = sb_input.permute(0, 2, 1, 3) # [B, F//num_groups, F_s, T]
85 |
86 | sb_input = sb_input.reshape(
87 | batch_size * num_freqs,
88 | (self.sb_num_neighbors * 2 + 1),
89 | num_frames
90 | )
91 |
92 | # [B * F, F_s, T] => [B * F, 2, T] => [B, F, 2, T]
93 | sb_mask = self.sb_model(sb_input)
94 | sb_mask = sb_mask.reshape(batch_size, num_freqs, 2, num_frames).permute(0, 2, 1, 3).contiguous()
95 |
96 | output = sb_mask[:, :, :, self.look_ahead:]
97 | return output
98 |
99 |
100 | class Subband_model_Large(BaseModel):
101 | def __init__(self,
102 | num_freqs,
103 | look_ahead,
104 | sequence_model,
105 | fb_num_neighbors,
106 | sb_num_neighbors,
107 | fb_output_activate_function,
108 | sb_output_activate_function,
109 | fb_model_hidden_size,
110 | sb_model_hidden_size,
111 | norm_type="offline_laplace_norm",
112 | num_groups_in_drop_band=2,
113 | weight_init=True,
114 | ):
115 | """
116 | Subband model (cIRM mask)
117 |
118 | Args:
119 | num_freqs: Frequency dim of the input
120 | sb_num_neighbors: Number of the neighbor frequencies in each side
121 | look_ahead: Number of future frames to use
122 | sequence_model: Choose one sequence model as the basic model (GRU, LSTM)
123 | """
124 | super().__init__()
125 | assert sequence_model in ("GRU", "LSTM"), f"{self.__class__.__name__} only supports GRU and LSTM."
126 |
127 | self.sb_model = SequenceModel(
128 | input_size=(sb_num_neighbors * 2 + 1),
129 | output_size=2,
130 | hidden_size=sb_model_hidden_size,
131 | num_layers=3,
132 | bidirectional=False,
133 | sequence_model=sequence_model,
134 | output_activate_function=sb_output_activate_function
135 | )
136 |
137 | self.sb_num_neighbors = sb_num_neighbors
138 | self.look_ahead = look_ahead
139 | self.norm = self.norm_wrapper(norm_type)
140 | self.num_groups_in_drop_band = num_groups_in_drop_band
141 |
142 | if weight_init:
143 | self.apply(self.weight_init)
144 |
145 | def forward(self, noisy_mag):
146 | """
147 | Args:
148 | noisy_mag: noisy magnitude spectrogram
149 |
150 | Returns:
151 | The real part and imag part of the enhanced spectrogram
152 |
153 | Shapes:
154 | noisy_mag: [B, 1, F, T]
155 | return: [B, 2, F, T]
156 | """
157 | assert noisy_mag.dim() == 4
158 | noisy_mag = functional.pad(noisy_mag, [0, self.look_ahead]) # Pad the look ahead
159 | batch_size, num_channels, num_freqs, num_frames = noisy_mag.size()
160 | assert num_channels == 1, f"{self.__class__.__name__} takes the mag feature as inputs."
161 |
162 |
163 | # Unfold noisy input, [B, N=F, C, F_s, T]
164 | noisy_mag_unfolded = self.unfold(noisy_mag, num_neighbor=self.sb_num_neighbors)
165 | noisy_mag_unfolded = noisy_mag_unfolded.reshape(batch_size, num_freqs, self.sb_num_neighbors * 2 + 1,
166 | num_frames)
167 |
168 |
169 | sb_input = self.norm(noisy_mag_unfolded)
170 |
171 |
172 | if batch_size > 1:
173 | sb_input = drop_band(sb_input.permute(0, 2, 1, 3),
174 | num_groups=self.num_groups_in_drop_band) # [B, F_s, F//num_groups, T]
175 | num_freqs = sb_input.shape[2]
176 | sb_input = sb_input.permute(0, 2, 1, 3) # [B, F//num_groups, F_s, T]
177 |
178 | sb_input = sb_input.reshape(
179 | batch_size * num_freqs,
180 | (self.sb_num_neighbors * 2 + 1),
181 | num_frames
182 | )
183 |
184 | # [B * F, (F_s), T] => [B * F, 2, T] => [B, F, 2, T]
185 | sb_mask = self.sb_model(sb_input)
186 | sb_mask = sb_mask.reshape(batch_size, num_freqs, 2, num_frames).permute(0, 2, 1, 3).contiguous()
187 |
188 | output = sb_mask[:, :, :, self.look_ahead:]
189 | return output
190 |
191 |
192 | if __name__ == "__main__":
193 | model = Subband_model(
194 | look_ahead=2,
195 | sequence_model="LSTM",
196 | sb_num_neighbors=15,
197 | sb_output_activate_function=False,
198 | sb_model_hidden_size=384,
199 | weight_init=False,
200 | norm_type="offline_laplace_norm",
201 | num_groups_in_drop_band=2,
202 | )
203 |
204 |
--------------------------------------------------------------------------------
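Both forward passes above call `self.unfold` (defined in `BaseModel`, not shown here) to gather, for every frequency bin, its `2 * sb_num_neighbors + 1` neighbouring bins. The sketch below reproduces that unfolding with plain PyTorch ops; the reflection padding at the frequency borders is an assumption, not necessarily what `BaseModel.unfold` does.

```python
import torch
import torch.nn.functional as F

def unfold_subbands(mag, n_neighbor):
    # mag: [B, 1, F, T] -> [B, F, 2*n_neighbor+1, T]
    padded = F.pad(mag, [0, 0, n_neighbor, n_neighbor], mode="reflect")  # pad the frequency axis
    sub = padded.unfold(2, 2 * n_neighbor + 1, 1)                        # [B, 1, F, T, 2n+1]
    return sub.squeeze(1).permute(0, 1, 3, 2)                            # [B, F, 2n+1, T]

mag = torch.rand(2, 1, 257, 63)
print(unfold_subbands(mag, n_neighbor=15).shape)  # torch.Size([2, 257, 31, 63])
```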
/speech_enhance/subband_model/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/subband_model/trainer/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/tools/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/tools/analyse.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import matplotlib
4 |
5 | def read_from_txt(filename):
6 | ans_dict = {}
7 | with open(filename, "r") as f: # open the file
8 | data = f.readlines() # read all lines
9 | for line in data:
10 | line = line.strip('\n')
11 | lines = line.split(" ")
12 | ans_dict[lines[0]] = float(lines[1])
13 | return ans_dict
14 |
15 | def write_to_txt(filename, total_list):
16 | with open(filename, 'w+') as temp_file:
17 | for i1, sisdr in total_list:
18 | string = i1 + ": " + str(sisdr) + '\n'
19 | temp_file.write(string)
20 |
21 | def takeSecond(elem):
22 | return elem[1]
23 |
24 | def make_rank(sidir_list):
25 | sidir_list.sort(key=takeSecond, reverse=True)
26 |
27 | def compare_two_data(data1, data2):
28 | ans_list = []
29 | for wav in data1.keys():
30 | num1 = data1[wav]
31 | num2 = data2[wav]
32 | ans_list.append((wav, num1 - num2))
33 | make_rank(ans_list)
34 | return ans_list
35 |
36 |
37 |
38 | def draw_hist(data, filename):
39 | plt.hist(data, facecolor="blue", edgecolor="black", alpha=0.7)
40 | plt.xlabel("Interval")
41 | plt.ylabel("Frequency")
42 | plt.title("Frequency Distribution Histogram")
43 | plt.savefig(filename)
44 | plt.show()
45 |
46 | def draw_two_hist(data1, data1_name, data2, data2_name, filename):
47 | # bins = np.linspace(5, 30, 5)
48 | # plt.hist(data1, bins, edgecolor="black", alpha=0.7, label=data1_name)
49 | # plt.hist(data2, bins, edgecolor="black", alpha=0.7, label=data2_name)
50 | plt.hist(data1, edgecolor="black", alpha=0.7, label=data1_name)
51 | plt.hist(data2, edgecolor="black", alpha=0.7, label=data2_name)
52 | plt.legend(loc='upper right')
53 | plt.xlabel("Interval(Score)")
54 | plt.ylabel("Frequency")
55 | plt.title("Frequency Distribution Histogram")
56 | plt.savefig(filename)
57 | plt.show()
58 |
59 |
60 | data1 = read_from_txt("/workspace/project-nas-11025-sh/speech_enhance/egs/DNS-master/s1_16k/mertrics/fullsubnet_plus/NB_PESQ.txt")
61 | data2 = read_from_txt("/workspace/project-nas-11025-sh/speech_enhance/egs/DNS-master/s1_16k/mertrics/our_fullsubnet/NB_PESQ.txt")
62 |
63 | # print(compare_two_data(data1, data2))
64 | write_to_txt("/workspace/project-nas-11025-sh/speech_enhance/egs/DNS-master/s1_16k/mertrics/compare/compare_NB_PESQ.txt", compare_two_data(data1, data2))
--------------------------------------------------------------------------------
/speech_enhance/tools/calculate_metrics.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | from inspect import getmembers, isfunction
6 | from pathlib import Path
7 |
8 | import librosa
9 | import numpy as np
10 | from joblib import Parallel, delayed
11 | from tqdm import tqdm
12 |
13 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
14 | import audio_zen.metrics as metrics
15 | from audio_zen.utils import prepare_empty_dir
16 |
17 |
18 | def load_wav_paths_from_scp(scp_path, to_abs=True):
19 | wav_paths = [line.rstrip('\n') for line in open(os.path.abspath(os.path.expanduser(scp_path)), "r")]
20 | if to_abs:
21 | tmp = []
22 | for path in wav_paths:
23 | tmp.append(os.path.abspath(os.path.expanduser(path)))
24 | wav_paths = tmp
25 | return wav_paths
26 |
27 |
28 | def shrink_multi_channel_path(
29 | full_dataset_list: list,
30 | num_channels: int
31 | ) -> list:
32 | """
33 |
34 | Args:
35 | full_dataset_list: [
36 | 028000010_room1_rev_RT600.06_mic1_micpos1.5p0.5p1.93_srcpos0.46077p1.1p1.68_langle180_angle150_ds1.2_mic1.wav
37 | ...
38 | 028000010_room1_rev_RT600.06_mic1_micpos1.5p0.5p1.93_srcpos0.46077p1.1p1.68_langle180_angle150_ds1.2_mic2.wav
39 | ]
40 | num_channels:
41 |
42 | Returns:
43 |
44 | """
45 | assert len(full_dataset_list) % num_channels == 0, "The number of files must be divisible by num_channels"
46 |
47 | shrunk_dataset_list = []
48 | for index in range(0, len(full_dataset_list), num_channels):
49 | full_path = full_dataset_list[index]
50 | shrunk_path = f"{'_'.join(full_path.split('_')[:-1])}.wav"
51 | shrunk_dataset_list.append(shrunk_path)
52 |
53 | assert len(shrunk_dataset_list) == len(full_dataset_list) // num_channels
54 | return shrunk_dataset_list
55 |
56 |
57 | def get_basename(path):
58 | return os.path.splitext(os.path.basename(path))[0]
59 |
60 |
61 | def pre_processing(est, ref, specific_dataset=None):
62 | ref = Path(ref).expanduser().absolute()
63 | est = Path(est).expanduser().absolute()
64 |
65 | if ref.is_dir():
66 | reference_wav_paths = librosa.util.find_files(ref.as_posix(), ext="wav")
67 | else:
68 | reference_wav_paths = load_wav_paths_from_scp(ref.as_posix())
69 |
70 | if est.is_dir():
71 | estimated_wav_paths = librosa.util.find_files(est.as_posix(), ext="wav")
72 | else:
73 | estimated_wav_paths = load_wav_paths_from_scp(est.as_posix())
74 |
75 | if not specific_dataset:
76 | # By default, the two lists are expected to correspond one-to-one
77 | check_two_aligned_list(reference_wav_paths, estimated_wav_paths)
78 | else:
79 | # For specific datasets, align the two lists manually so that they correspond one-to-one
80 | reordered_estimated_wav_paths = []
81 | if specific_dataset == "dns_1":
82 | # Reorder estimated_wav_paths according to the file-id suffix of the files in reference_wav_paths
83 | # Extract the suffix
84 | for ref_path in reference_wav_paths:
85 | for est_path in estimated_wav_paths:
86 | est_basename = get_basename(est_path)
87 | if "clean_" + "_".join(est_basename.split("_")[-2:]) == get_basename(ref_path):
88 | reordered_estimated_wav_paths.append(est_path)
89 | elif specific_dataset == "dns_2":
90 | for ref_path in reference_wav_paths:
91 | for est_path in estimated_wav_paths:
92 | # synthetic_french_acejour_orleans_sb_64kb-01_jbq2HJt9QXw_snr14_tl-26_fileid_47
93 | # synthetic_clean_fileid_47
94 | est_basename = get_basename(est_path)
95 | file_id = est_basename.split('_')[-1]
96 | if f"synthetic_clean_fileid_{file_id}" == get_basename(ref_path):
97 | reordered_estimated_wav_paths.append(est_path)
98 | elif specific_dataset == "maxhub_noisy":
99 | # Reference_channel = 0
100 | # Find the corresponding clean speech
101 | reference_channel = 0
102 | print(f"Found #files: {len(reference_wav_paths)}")
103 | for est_path in estimated_wav_paths:
104 | # MC0604W0154_room4_rev_RT600.1_mic1_micpos1.5p0.5p1.84_srcpos4.507p1.5945p1.3_langle180_angle20_ds3.2_kesou_kesou_mic1.wav
105 | est_basename = get_basename(est_path) # the noisy one
106 | for ref_path in reference_wav_paths:
107 | ref_basename = get_basename(ref_path)
108 |
109 | else:
110 | raise NotImplementedError(f"Not supported specific dataset {specific_dataset}.")
111 | estimated_wav_paths = reordered_estimated_wav_paths
112 |
113 | return reference_wav_paths, estimated_wav_paths
114 |
115 |
116 | def check_two_aligned_list(a, b):
117 | assert len(a) == len(b), "The two lists have different lengths."
118 | for z, (i, j) in enumerate(zip(a, b), start=1):
119 | assert get_basename(i) == get_basename(j), f"Mismatched filenames between the two lists at position {z}" \
120 | f"\n\t {i}" \
121 | f"\n\t{j}"
122 |
123 |
124 | def compute_metric(reference_wav_paths, estimated_wav_paths, sr, metric_type="SI_SDR"):
125 | metrics_dict = {o[0]: o[1] for o in getmembers(metrics) if isfunction(o[1])}
126 | assert metric_type in metrics_dict, f"Unsupported metric: {metric_type}"
127 | metric_function = metrics_dict[metric_type]
128 | if metric_type == "MOSNET":
129 | n_jobs = 1
130 | else:
131 | n_jobs = 40
132 |
133 | def calculate_metric(ref_wav_path, est_wav_path):
134 | ref_wav, _ = librosa.load(ref_wav_path, sr=sr)
135 | est_wav, _ = librosa.load(est_wav_path, sr=sr, mono=False)
136 | if est_wav.ndim > 1:
137 | est_wav = est_wav[0]
138 |
139 | basename = get_basename(ref_wav_path)
140 |
141 | ref_wav_len = len(ref_wav)
142 | est_wav_len = len(est_wav)
143 |
144 | if ref_wav_len != est_wav_len:
145 | print(f"[Warning] ref {ref_wav_len} and est {est_wav_len} do not have the same length")
146 | pass
147 |
148 | return basename, metric_function(ref_wav[:len(est_wav)], est_wav)
149 |
150 | metrics_result_store = Parallel(n_jobs=n_jobs)(
151 | delayed(calculate_metric)(ref, est) for ref, est in tqdm(zip(reference_wav_paths, estimated_wav_paths))
152 | )
153 | return metrics_result_store
154 |
155 | def takeSecond(elem):
156 | return elem[1]
157 |
158 | def make_rank(the_list):
159 | the_list.sort(key=takeSecond, reverse=True)
160 |
161 | def write_to_txt(filename, total_list):
162 | with open(filename, 'w+') as temp_file:
163 | for i1, sisdr in total_list:
164 | string = i1 + ": " + str(sisdr) + '\n'
165 | temp_file.write(string)
166 |
167 | def main(args):
168 | sr = args.sr
169 | metric_types = args.metric_types
170 | export_dir = args.export_dir
171 | specific_dataset = args.specific_dataset.lower()
172 |
173 | # Collect all wav samples from the given scp files or directories
174 | reference_wav_paths, estimated_wav_paths = pre_processing(args.estimated, args.reference, specific_dataset)
175 |
176 | # if export_dir:
177 | # export_dir = Path(export_dir).expanduser().absolute()
178 | # prepare_empty_dir([export_dir])
179 |
180 | print(f"=== {args.estimated} === {args.reference} ===")
181 | for metric_type in metric_types.split(","):
182 | # print(reference_wav_paths)
183 | # print(estimated_wav_paths)
184 | metrics_result_store = compute_metric(reference_wav_paths, estimated_wav_paths, sr, metric_type=metric_type)
185 |
186 | # Print result
187 | print(f"{metric_type}: {metrics_result_store}")
188 | metric_value = np.mean(list(zip(*metrics_result_store))[1])
189 | print(f"{metric_type}: {metric_value}")
190 |
191 | # Export result
192 | if export_dir:
193 | make_rank(metrics_result_store)
194 | write_to_txt(export_dir + str(metric_type) + ".txt", metrics_result_store)
195 | # import tablib
196 |
197 | # export_path = export_dir / f"{metric_type}.xlsx"
198 | # print(f"Export result to {export_path}")
199 |
200 | # headers = ("Speech", f"{metric_type}")
201 | # metric_seq = [[basename, metric_value] for basename, metric_value in metrics_result_store]
202 | # data = tablib.Dataset(*metric_seq, headers=headers)
203 | # with open(export_path.as_posix(), "wb") as f:
204 | # f.write(data.export("xlsx"))
205 |
206 |
207 | if __name__ == '__main__':
208 | parser = argparse.ArgumentParser(
209 | description="Given two directories or scp lists, compute the mean of various evaluation metrics",
210 | epilog="python calculate_metrics.py -E 'est_dir' -R 'ref_dir' -M SI_SDR,STOI,WB_PESQ,NB_PESQ,SSNR,LSD,SRMR"
211 | )
212 | parser.add_argument("-R", "--reference", required=True, type=str, help="")
213 | parser.add_argument("-E", "--estimated", required=True, type=str, help="")
214 | parser.add_argument("-M", "--metric_types", required=True, type=str, help="Metrics to compute; names must match those in util.metrics.")
215 | parser.add_argument("--sr", type=int, default=16000, help="sample rate")
216 | parser.add_argument("-D", "--export_dir", type=str, default="", help="")
217 | parser.add_argument("--limit", type=int, default=None, help="[WIP] Maximum number of files to read from the list.")
218 | parser.add_argument("--offset", type=int, default=0, help="[WIP] Position in the list to start reading files from.")
219 | parser.add_argument("-S", "--specific_dataset", type=str, default="", help="Specific dataset type, e.g. DNS_1, DNS_2 (case-insensitive)")
220 | args = parser.parse_args()
221 | main(args)
222 |
223 | """
224 | TODO
225 | 1. How to compute metrics when the speech is multi-channel
226 | 2. Support a registry; by default, compute metrics for all speech in the registry
227 | """
228 |
--------------------------------------------------------------------------------
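`compute_metric` above discovers metric functions by reflecting over `audio_zen.metrics`, so every name passed via `-M` must exactly match a function name in that module. The snippet below is a minimal illustration of that discovery pattern using a throwaway stand-in module; the real module defines SI_SDR, STOI, WB_PESQ, and so on.

```python
import types
from inspect import getmembers, isfunction

fake_metrics = types.ModuleType("fake_metrics")  # stand-in for audio_zen.metrics

def SI_SDR(ref, est):
    return 0.0  # placeholder implementation

fake_metrics.SI_SDR = SI_SDR

metrics_dict = {name: fn for name, fn in getmembers(fake_metrics) if isfunction(fn)}
print(sorted(metrics_dict))  # ['SI_SDR']
metric_fn = metrics_dict["SI_SDR"]
```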
/speech_enhance/tools/collect_lst.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os, random
3 | import sys
4 | from pathlib import Path
5 | import librosa
6 | from tqdm import tqdm
7 |
8 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
9 | from audio_zen.acoustics.mask import is_clipped, load_wav, activity_detector
10 |
11 |
12 | def offset_and_limit(data_list, offset, limit):
13 | data_list = data_list[offset:]
14 | if limit:
15 | data_list = data_list[:limit]
16 | return data_list
17 |
18 |
19 | if __name__ == '__main__':
20 | parser = argparse.ArgumentParser(description="FullSubNet")
21 | parser.add_argument('-candidate_datasets', '--candidate_datasets', help='delimited list input',
22 | type=lambda s: [item for item in s.split(',')])
23 | parser.add_argument("-dist_file", "--dist_file", required=True, type=str, help="output lst")
24 | parser.add_argument("-sr", "--sr", type=int, default=16000, help="sample rate")
25 | parser.add_argument("-wav_min_second", "--wav_min_second", type=int, default=3, help="the min length of a wav")
26 | parser.add_argument("-activity_threshold", "--activity_threshold", type=float, default=0.6,
27 | help="the activity threshold of speech/sil")
28 | parser.add_argument("-total_hrs", "--total_hrs", type=int, default=30, help="the length in time of wav(s)")
29 |
30 | args = parser.parse_args()
31 |
32 | candidate_datasets = args.candidate_datasets
33 | dataset_limit = None
34 | dataset_offset = 0
35 | dist_file = Path(args.dist_file).expanduser().absolute()
36 |
37 | # Acoustic parameters
38 | sr = args.sr
39 | wav_min_second = args.wav_min_second
40 | activity_threshold = args.activity_threshold
41 | total_hrs = args.total_hrs  # total planned duration (in hours) of the collected speech
42 |
43 | all_wav_path_list = []
44 | output_wav_path_list = []
45 | accumulated_time = 0.0
46 |
47 | is_clipped_wav_list = []
48 | is_low_activity_list = []
49 | is_too_short_list = []
50 |
51 | for dataset_path in candidate_datasets:
52 | dataset_path = Path(dataset_path).expanduser().absolute()
53 | all_wav_path_list += librosa.util.find_files(dataset_path.as_posix(), ext=["wav"])
54 |
55 | all_wav_path_list = offset_and_limit(all_wav_path_list, dataset_offset, dataset_limit)
56 | random.shuffle(all_wav_path_list)
57 |
58 | for wav_file_path in tqdm(all_wav_path_list, desc="Checking"):
59 | y = load_wav(wav_file_path, sr=sr)
60 | wav_duration = len(y) / sr
61 | wav_file_user_path = wav_file_path.replace(Path(wav_file_path).home().as_posix(), "~")
62 |
63 | is_clipped_wav = is_clipped(y)
64 | is_low_activity = activity_detector(y) < activity_threshold
65 | is_too_short = wav_duration < wav_min_second
66 |
67 | if is_too_short:
68 | is_too_short_list.append(wav_file_user_path)
69 | continue
70 |
71 | if is_clipped_wav:
72 | is_clipped_wav_list.append(wav_file_user_path)
73 | continue
74 |
75 | if is_low_activity:
76 | is_low_activity_list.append(wav_file_user_path)
77 | continue
78 |
79 | if (not is_clipped_wav) and (not is_low_activity) and (not is_too_short):
80 | accumulated_time += wav_duration
81 | output_wav_path_list.append(wav_file_user_path)
82 |
83 | if accumulated_time >= (total_hrs * 3600):
84 | break
85 |
86 | os.makedirs(os.path.dirname(dist_file.as_posix()), exist_ok=True)
87 | with open(dist_file.as_posix(), 'w') as f:
88 | f.writelines(f"{file_path}\n" for file_path in output_wav_path_list)
89 |
90 | print("=" * 70)
91 | print("Speech Preprocessing")
92 | print(f"\t Original files: {len(all_wav_path_list)}")
93 | print(f"\t Selected files: {accumulated_time / 3600} hrs, {len(output_wav_path_list)} files.")
94 | print(f"\t is_clipped_wav: {len(is_clipped_wav_list)}")
95 | print(f"\t is_low_activity: {len(is_low_activity_list)}")
96 | print(f"\t is_too_short: {len(is_too_short_list)}")
97 | print(f"\t dist file:")
98 | print(f"\t {dist_file.as_posix()}")
99 | print("=" * 70)
100 |
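The resulting `dist_file` is a plain-text list with one wav path per line, with the home directory abbreviated to `~`. A minimal sketch (the list filename is illustrative) of how a downstream script could read it back and re-expand the paths:

```python
from pathlib import Path

# Hypothetical list file produced by the script above.
dist_file = "datasets/clean_30h.lst"

with open(dist_file) as f:
    # Each line is a wav path whose home-directory prefix was rewritten as "~".
    wav_paths = [Path(line.strip()).expanduser() for line in f if line.strip()]

print(f"{len(wav_paths)} files, e.g. {wav_paths[:2]}")
```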
--------------------------------------------------------------------------------
/speech_enhance/tools/dns_mos.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import json
4 | import os
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import requests
9 | import soundfile as sf
10 | import librosa
11 |
12 | from urllib.parse import urlparse, urljoin
13 |
14 | # URL for the web service
15 | SCORING_URI_DNSMOS = 'https://dnsmos.azurewebsites.net/score'
16 | SCORING_URI_DNSMOS_P835 = 'https://dnsmos.azurewebsites.net/v1/dnsmosp835/score'
17 | # If the service is authenticated, set the key or token
18 | AUTH_KEY = 'd3VoYW4tdW5pdjpkbnNtb3M='
19 |
20 | # Set the content type
21 | headers = {'Content-Type': 'application/json'}
22 | # If authentication is enabled, set the authorization header
23 | headers['Authorization'] = f'Basic {AUTH_KEY}'
24 |
25 |
26 | def main(args):
27 | print(args.testset_dir)
28 | audio_clips_list = glob.glob(os.path.join(args.testset_dir, "*.wav"))  # glob returns the list of matching wav files
29 | print(audio_clips_list)
30 | scores = []
31 | dir_path = args.score_file.split('score.csv')[0]
32 | if not os.path.exists(dir_path):
33 | os.makedirs(dir_path)
34 | if not os.path.exists(os.path.join(dir_path, 'file_mos.txt')):
35 | f = open(os.path.join(dir_path, 'file_mos.txt'), 'w')
36 | dict = {}
37 | else:
38 | f = open(os.path.join(dir_path, 'file_mos.txt'), 'r')
39 | dict = {}
40 | lines = f.readlines()
41 | for line in lines:
42 | utt_id = line.split('.wav')[0]
43 | # print('utt_id store', utt_id)
44 | dict[utt_id] = 1
45 | flag = 0
46 | for fpath in audio_clips_list:
47 | utt_id = os.path.basename(fpath).split('.wav')[0]  # basename without extension
48 | # print('utt_id', utt_id)
49 | if utt_id in dict:
50 | print('find uttid', utt_id)
51 | continue
52 | flag = 1
53 | f = open(os.path.join(dir_path, 'file_mos.txt'), 'a+')
54 | audio, fs = sf.read(fpath)
55 | if fs != 16000:
56 | print('Resample to 16k')
57 | audio = librosa.resample(audio, orig_sr=fs, target_sr=16000)
58 | data = {"data": audio.tolist(), "filename": os.path.basename(fpath)}
59 | input_data = json.dumps(data)
60 | # Make the request and display the response
61 | if args.method == 'p808':
62 | u = SCORING_URI_DNSMOS
63 | else:
64 | u = SCORING_URI_DNSMOS_P835
65 | try_flag = 1
66 | while try_flag:
67 | try:
68 | resp = requests.post(u, data=input_data, headers=headers, timeout=50)
69 | try_flag = 0
70 | score_dict = resp.json()
71 | except:
72 | try_flag = 1
73 | print('retry_1')
74 | continue
75 | try:
76 | score_dict['file_name'] = os.path.basename(fpath)
77 | if args.method == 'p808':
78 | f.write(score_dict['file_name'] + ' ' + str(score_dict['mos']) + '\n')
79 | print(score_dict['mos'], ' ', score_dict['file_name'])
80 | else:
81 | f.write(score_dict['file_name'] + ' SIG[{}], BAK[{}], OVR[{}]'.format(score_dict['mos_sig'],
82 | score_dict['mos_bak'],
83 | score_dict['mos_ovr']) + '\n')
84 | print(score_dict['file_name'] + ' SIG[{}], BAK[{}], OVR[{}]'.format(score_dict['mos_sig'],
85 | score_dict['mos_bak'],
86 | score_dict['mos_ovr']))
87 | try_flag = 0
88 | except:
89 | try_flag = 1
90 | print('retry_2')
91 | continue
92 | f.close()
93 | scores.append(score_dict)
94 | if flag:
95 | df = pd.DataFrame(scores)
96 | if args.method == 'p808':
97 | print('Mean MOS Score for the files is ', np.mean(df['mos']))
98 | else:
99 | print('Mean scores for the files: SIG[{}], BAK[{}], OVR[{}]'.format(np.mean(df['mos_sig']),
100 | np.mean(df['mos_bak']),
101 | np.mean(df['mos_ovr'])))
102 |
103 | if args.score_file:
104 | df.to_csv(args.score_file)
105 |
106 |
107 | if __name__ == "__main__":
108 | parser = argparse.ArgumentParser()
109 | parser.add_argument("--testset_dir",
110 | default=r'C:\Users\cyrillv\Desktop\谱修复样例及DNSMOS指标\test_data\谱修复测试集_yuanjun\noisy',
111 | help='Path to the dir containing audio clips to be evaluated')
112 | parser.add_argument('--score_file', default=r'./谱修复测试集_yuanjun/noisy/score.csv',
113 | help='If you want the scores in a CSV file provide the full path')
114 | parser.add_argument('--method', default='p835', const='p808', nargs='?', choices=['p808', 'p835'],
115 | help='Choose whether to compute P.808 or P.835 scores. Default is P.835.')
116 | args = parser.parse_args()
117 | main(args)
118 |
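For clarity, the request body posted to the DNSMOS endpoints above is a JSON object carrying the 16 kHz samples plus the file name, and the script then reads either a single `mos` field (P.808) or the `mos_sig`/`mos_bak`/`mos_ovr` fields (P.835) from the JSON response. A minimal sketch of that payload with a silent, made-up clip:

```python
import json
import numpy as np

# Hypothetical one-second clip at 16 kHz standing in for a real wav file.
audio = np.zeros(16000, dtype=np.float32)

payload = json.dumps({"data": audio.tolist(), "filename": "example.wav"})

# Depending on --method, the script above expects a response shaped like:
#   p808: {"mos": 3.2, ...}
#   p835: {"mos_sig": 3.4, "mos_bak": 3.8, "mos_ovr": 3.1, ...}
```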
--------------------------------------------------------------------------------
/speech_enhance/tools/gen_lst.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import argparse
4 |
5 | def gen_lst(args):
6 | wav_lst = glob(os.path.join(args.dataset_dir, "**/*.wav"), recursive=True)
7 | os.makedirs(os.path.dirname(args.output_lst), exist_ok=True)
8 | fc = open(args.output_lst, "w")
9 | for one_wav in wav_lst:
10 | fc.write(f"{one_wav}\n")
11 | fc.close()
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--dataset_dir", type=str, default="")
16 | parser.add_argument("--output_lst", type=str, default="")
17 | args = parser.parse_args()
18 |
19 | gen_lst(args)
20 |
--------------------------------------------------------------------------------
/speech_enhance/tools/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 |
5 | import toml
6 |
7 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
8 | from audio_zen.utils import initialize_module
9 |
10 |
11 | def main(config, checkpoint_path, output_dir):
12 | inferencer_class = initialize_module(config["inferencer"]["path"], initialize=False)
13 | inferencer = inferencer_class(
14 | config,
15 | checkpoint_path,
16 | output_dir
17 | )
18 | inferencer()
19 |
20 |
21 | if __name__ == "__main__":
22 | parser = argparse.ArgumentParser("Inference")
23 | parser.add_argument("-C", "--configuration", type=str, required=True, help="Config file.")
24 | parser.add_argument("-M", "--model_checkpoint_path", type=str, required=True, help="The path of the model's checkpoint.")
25 | parser.add_argument('-I', '--dataset_dir_list', help='delimited list input',
26 | type=lambda s: [item.strip() for item in s.split(',')])
27 | parser.add_argument("-O", "--output_dir", type=str, required=True, help="The path for saving enhanced speeches.")
28 | args = parser.parse_args()
29 |
30 | configuration = toml.load(args.configuration)
31 | checkpoint_path = args.model_checkpoint_path
32 | output_dir = args.output_dir
33 | if args.dataset_dir_list:  # also covers the case where -I is not given (None)
34 | print(f"Using the specified dataset_dir_list {args.dataset_dir_list} instead of the one in the config")
35 | configuration["dataset"]["args"]["dataset_dir_list"] = args.dataset_dir_list
36 |
37 | main(configuration, checkpoint_path, output_dir)
38 |
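The inferencer class is resolved from `config["inferencer"]["path"]`, and `-I` (when given) overwrites `config["dataset"]["args"]["dataset_dir_list"]`; everything else in the TOML is passed through to the inferencer. A minimal sketch of the parsed configuration as the Python dict `toml.load` would produce (only the two keys above are read by this script; the module paths and remaining fields are placeholders):

```python
# Illustrative structure only; the concrete module paths and options come from
# the *.toml files shipped with the repository.
configuration = {
    "inferencer": {
        "path": "inter_subnet.inferencer.inferencer.Inferencer",  # placeholder module path
        "args": {},
    },
    "dataset": {
        "path": "inter_subnet.dataset.dataset_inference.Dataset",  # placeholder module path
        "args": {"dataset_dir_list": ["/path/to/noisy_wavs"]},
    },
}
```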
--------------------------------------------------------------------------------
/speech_enhance/tools/noisyspeech_synthesizer.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: chkarada
3 | """
4 | import glob
5 | import numpy as np
6 | import soundfile as sf
7 | import os
8 | import argparse
9 | import yaml
10 | import configparser as CP
11 | from ..audio.audiolib import audioread, audiowrite, snr_mixer
12 |
13 | def main(cfg):
14 | snr_lower = float(cfg["snr_lower"])
15 | snr_upper = float(cfg["snr_upper"])
16 | total_snrlevels = int(cfg["total_snrlevels"])
17 |
18 | clean_dir = os.path.join(os.path.dirname(__file__), 'clean_train')
19 | if cfg["speech_dir"]!='None':
20 | clean_dir = cfg["speech_dir"]
21 | if not os.path.exists(clean_dir):
22 | assert False, ("Clean speech data is required")
23 |
24 | noise_dir = os.path.join(os.path.dirname(__file__), 'noise_train')
25 | if cfg["noise_dir"]!='None':
26 | noise_dir = cfg["noise_dir"]
27 | if not os.path.exists(noise_dir):
28 | assert False, ("Noise data is required")
29 |
30 | output_dir = os.path.join(os.path.dirname(__file__), 'train_data')
31 | if cfg["noisyspeech_dir"]!='None':
32 | output_dir = cfg["noisyspeech_dir"]
33 | os.makedirs(output_dir, exist_ok=True)
34 |
35 | fs = float(cfg["sampling_rate"])
36 | audioformat = cfg["audioformat"]
37 | total_hours = float(cfg["total_hours"])
38 | audio_length = float(cfg["audio_length"])
39 | silence_length = float(cfg["silence_length"])
40 | noisyspeech_dir = os.path.join(output_dir, 'NoisySpeech_training')
41 | if not os.path.exists(noisyspeech_dir):
42 | os.makedirs(noisyspeech_dir)
43 | clean_proc_dir = os.path.join(output_dir, 'CleanSpeech_training')
44 | if not os.path.exists(clean_proc_dir):
45 | os.makedirs(clean_proc_dir)
46 | noise_proc_dir = os.path.join(output_dir, 'Noise_training')
47 | if not os.path.exists(noise_proc_dir):
48 | os.makedirs(noise_proc_dir)
49 |
50 | total_secs = total_hours*60*60
51 | total_samples = int(total_secs * fs)
52 | audio_length = int(audio_length*fs)
53 | SNR = np.linspace(snr_lower, snr_upper, total_snrlevels)
54 | cleanfilenames = glob.glob(os.path.join(clean_dir, audioformat))
55 | if cfg["noise_types_excluded"]=='None':
56 | noisefilenames = glob.glob(os.path.join(noise_dir, audioformat))
57 | else:
58 | filestoexclude = cfg["noise_types_excluded"].split(',')
59 | noisefilenames = glob.glob(os.path.join(noise_dir, audioformat))
60 | for i in range(len(filestoexclude)):
61 | noisefilenames = [fn for fn in noisefilenames if not os.path.basename(fn).startswith(filestoexclude[i])]
62 |
63 | filecounter = 0
64 | num_samples = 0
65 |
66 | while num_samples < total_samples:
67 | idx_s = np.random.randint(0, np.size(cleanfilenames))
68 | clean, fs = audioread(cleanfilenames[idx_s])
69 |
70 | if len(clean)>audio_length:
71 | clean = clean
72 |
73 | else:
74 |
75 | while len(clean)<=audio_length:
76 | idx_s = idx_s + 1
77 | if idx_s >= np.size(cleanfilenames)-1:
78 | idx_s = np.random.randint(0, np.size(cleanfilenames))
79 | newclean, fs = audioread(cleanfilenames[idx_s])
80 | cleanconcat = np.append(clean, np.zeros(int(fs*silence_length)))
81 | clean = np.append(cleanconcat, newclean)
82 |
83 | idx_n = np.random.randint(0, np.size(noisefilenames))
84 | noise, fs = audioread(noisefilenames[idx_n])
85 |
86 | if len(noise)>=len(clean):
87 | noise = noise[0:len(clean)]
88 |
89 | else:
90 |
91 | while len(noise)<=len(clean):
92 | idx_n = idx_n + 1
93 | if idx_n >= np.size(noisefilenames)-1:
94 | idx_n = np.random.randint(0, np.size(noisefilenames))
95 | newnoise, fs = audioread(noisefilenames[idx_n])
96 | noiseconcat = np.append(noise, np.zeros(int(fs*silence_length)))
97 | noise = np.append(noiseconcat, newnoise)
98 | noise = noise[0:len(clean)]
99 | filecounter = filecounter + 1
100 |
101 | for i in range(np.size(SNR)):
102 | clean_snr, noise_snr, noisy_snr = snr_mixer(clean=clean, noise=noise, snr=SNR[i])
103 | noisyfilename = 'noisy'+str(filecounter)+'_SNRdb_'+str(SNR[i])+'_clnsp'+str(filecounter)+'.wav'
104 | cleanfilename = 'clnsp'+str(filecounter)+'.wav'
105 | noisefilename = 'noisy'+str(filecounter)+'_SNRdb_'+str(SNR[i])+'.wav'
106 | noisypath = os.path.join(noisyspeech_dir, noisyfilename)
107 | cleanpath = os.path.join(clean_proc_dir, cleanfilename)
108 | noisepath = os.path.join(noise_proc_dir, noisefilename)
109 | audiowrite(noisy_snr, fs, noisypath, norm=False)
110 | audiowrite(clean_snr, fs, cleanpath, norm=False)
111 | audiowrite(noise_snr, fs, noisepath, norm=False)
112 | num_samples = num_samples + len(noisy_snr)
113 |
114 |
115 | if __name__=="__main__":
116 | parser = argparse.ArgumentParser()
117 | # Configuration: read from the YAML config file
118 | parser.add_argument("-c", "--config", type=str, help="Read ./config/config.yaml for all the details", default="./config/config.yaml")
119 | args = parser.parse_args()
120 |
121 | cur_filename = os.path.basename((os.path.relpath(__file__))).split(".")[0]
122 | # config
123 | config = yaml.unsafe_load(open(args.config, "r"))
124 |
125 | main(config)
126 |
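`snr_mixer` is imported from the external `audiolib` module and is not shown in this listing. For intuition, a common way to scale noise so that a mixture hits a target SNR (a generic sketch, not necessarily the exact `audiolib` implementation) is:

```python
import numpy as np

def mix_at_snr(clean, noise, snr_db, eps=1e-12):
    """Scale `noise` so that 10*log10(P_clean / P_noise) equals snr_db, then mix."""
    clean_power = np.mean(clean ** 2) + eps
    noise_power = np.mean(noise ** 2) + eps
    # Gain that brings the noise to the power required by the target SNR.
    gain = np.sqrt(clean_power / (noise_power * 10 ** (snr_db / 10.0)))
    noise_scaled = gain * noise
    return clean, noise_scaled, clean + noise_scaled
```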
--------------------------------------------------------------------------------
/speech_enhance/tools/resample_dir.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import argparse
4 | from tqdm import tqdm
5 | from joblib import Parallel, delayed
6 |
7 | def resample_one_wav(wav_file, output_wav_file):
8 | if not os.path.exists(wav_file):
9 | print(f"not found {wav_file}, return")
10 | return
11 | ####
12 | #print(wav_file)
13 | #print(output_wav_file)
14 | os.makedirs(os.path.dirname(output_wav_file), exist_ok=True)
15 | cmd = f"sox {wav_file} -b16 {output_wav_file} rate -v -b 99.7 16k"
16 | os.system(cmd)
17 |
18 | def resample_dir(args):
19 | ### get all wavs
20 | wav_lst = glob(os.path.join(args.dataset_dir, "**/*.wav"), recursive=True)
21 | os.makedirs(args.output_dir, exist_ok=True)
22 | ###
23 | num_workers = 40
24 | Parallel(n_jobs=num_workers)(
25 | delayed(resample_one_wav)(wav_file, wav_file.replace(args.dataset_dir, args.output_dir)) for wav_file in wav_lst)
26 |
27 | if __name__ == "__main__":
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("--dataset_dir", type=str, default="")
30 | parser.add_argument("--output_dir", type=str, default="")
31 | args = parser.parse_args()
32 |
33 | resample_dir(args)
34 |
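This tool shells out to `sox` for the actual resampling. If `sox` is unavailable, a pure-Python fallback for `resample_one_wav` could look like the sketch below (using librosa/soundfile; this is not part of the original tool and skips the bandwidth option of the sox command):

```python
import os
import librosa
import soundfile as sf

def resample_one_wav_py(wav_file, output_wav_file, target_sr=16000):
    """Resample a single wav to 16 kHz, 16-bit PCM, without relying on sox."""
    y, sr = librosa.load(wav_file, sr=None)  # keep the original sample rate
    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
    os.makedirs(os.path.dirname(output_wav_file), exist_ok=True)
    sf.write(output_wav_file, y, target_sr, subtype="PCM_16")
```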
--------------------------------------------------------------------------------
/speech_enhance/tools/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import sys
5 | from socket import socket
6 |
7 | import numpy as np
8 | import toml
9 | import torch
10 | import torch.distributed as dist
11 | import torch.multiprocessing as mp
12 | from torch.utils.data import DataLoader, DistributedSampler
13 |
14 | sys.path.append(os.path.abspath(os.path.join(__file__, "..", "..")))
15 | import audio_zen.loss as loss
16 | from audio_zen.utils import initialize_module
17 | from utils.logger import init
18 |
19 | # get free gpu automatically
20 | import GPUtil
21 |
22 | def entry(rank, world_size, config, resume, only_validation):
23 | torch.manual_seed(config["meta"]["seed"]) # For both CPU and GPU
24 | np.random.seed(config["meta"]["seed"])
25 | random.seed(config["meta"]["seed"])
26 |
27 | os.environ["MASTER_ADDR"] = "localhost"
28 | s = socket()
29 | s.bind(("", 0))
30 | os.environ["MASTER_PORT"] = "1111" # A random local port
31 |
32 | # Initialize the process group
33 | dist.init_process_group("gloo", rank=rank, world_size=world_size)
34 |
35 | # init log file
36 | if rank==0:
37 | os.makedirs(os.path.join(config["meta"]["save_dir"]), exist_ok=True)
38 | init(os.path.join(config["meta"]["save_dir"], "train.log"), "train", slack_url=None)
39 |
40 | # The DistributedSampler splits the dataset into non-overlapping shards, one per process.
41 | # With "sampler=None, shuffle=True", every GPU would instead see the whole dataset.
42 |
43 | train_dataset = initialize_module(config["train_dataset"]["path"], args=config["train_dataset"]["args"])
44 | sampler = DistributedSampler(dataset=train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
45 | train_dataloader = DataLoader(
46 | dataset=train_dataset,
47 | sampler=sampler,
48 | shuffle=False,
49 | **config["train_dataset"]["dataloader"],
50 | )
51 |
52 | valid_dataloader = DataLoader(
53 | dataset=initialize_module(config["validation_dataset"]["path"], args=config["validation_dataset"]["args"]),
54 | num_workers=0,
55 | batch_size=1
56 | )
57 |
58 | model = initialize_module(config["model"]["path"], args=config["model"]["args"])
59 |
60 | optimizer = torch.optim.Adam(
61 | params=model.parameters(),
62 | lr=config["optimizer"]["lr"],
63 | betas=(config["optimizer"]["beta1"], config["optimizer"]["beta2"])
64 | )
65 |
66 | loss_function = getattr(loss, config["loss_function"]["name"])(**config["loss_function"]["args"])
67 | trainer_class = initialize_module(config["trainer"]["path"], initialize=False)
68 |
69 | trainer = trainer_class(
70 | dist=dist,
71 | rank=rank,
72 | config=config,
73 | resume=resume,
74 | only_validation=only_validation,
75 | model=model,
76 | loss_function=loss_function,
77 | optimizer=optimizer,
78 | train_dataloader=train_dataloader,
79 | validation_dataloader=valid_dataloader
80 | )
81 |
82 | trainer.train()
83 |
84 |
85 | if __name__ == '__main__':
86 | parser = argparse.ArgumentParser(description="FullSubNet")
87 | parser.add_argument("-C", "--configuration", required=True, type=str, help="Configuration (*.toml).")
88 | parser.add_argument("-R", "--resume", action="store_true", help="Resume the experiment from latest checkpoint.")
89 | parser.add_argument("-V", "--only_validation", action="store_true", help="Only run validation. It is used for debugging validation.")
90 | parser.add_argument("-N", "--num_gpus", type=int, default=0, help="The number of GPUs you are using for training.")
91 | parser.add_argument("-P", "--preloaded_model_path", type=str, help="Path of the *.pth file of a model.")
92 | args = parser.parse_args()
93 |
94 | # set the gpu auto
95 | if args.num_gpus == 0:
96 | device_ids = GPUtil.getAvailable(order = 'first', limit = 8, maxLoad = 0.5, maxMemory = 0.5, includeNan=False, excludeID=[], excludeUUID=[])
97 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([ str(device_id) for device_id in device_ids ])
98 | args.num_gpus = len(device_ids)
99 | print(f"gpus: {os.environ['CUDA_VISIBLE_DEVICES']}")
100 |
101 | if args.preloaded_model_path:
102 | assert not args.resume, "The 'resume' conflicts with the 'preloaded_model_path'."
103 |
104 | configuration = toml.load(args.configuration)
105 |
106 | configuration["meta"]["experiment_name"], _ = os.path.splitext(os.path.basename(args.configuration))
107 | configuration["meta"]["config_path"] = args.configuration
108 | configuration["meta"]["preloaded_model_path"] = args.preloaded_model_path
109 |
110 | # Expand python search path to "recipes"
111 | # sys.path.append(os.path.join(os.getcwd(), ".."))
112 |
113 | # One training job corresponds to one process group (world).
114 | # The world size is the number of training processes, which is usually the number of GPUs used for distributed training.
115 | # The rank is the unique ID given to each process.
116 | # Find more information about DistributedDataParallel (DDP) in https://pytorch.org/tutorials/intermediate/ddp_tutorial.html.
117 | mp.spawn(entry,
118 | args=(args.num_gpus, configuration, args.resume, args.only_validation),
119 | nprocs=args.num_gpus,
120 | join=True)
121 |
122 |
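For orientation, the keys this launcher reads from the TOML are `meta.seed`, `meta.save_dir`, `train_dataset.{path,args,dataloader}`, `validation_dataset.{path,args}`, `model.{path,args}`, `optimizer.{lr,beta1,beta2}`, `loss_function.{name,args}` and `trainer.path`. A minimal sketch of that structure as the Python dict produced by `toml.load` (module paths and values are placeholders; the real settings live in the config *.toml files):

```python
configuration = {
    "meta": {"seed": 0, "save_dir": "exp/inter_subnet"},
    "train_dataset": {
        "path": "inter_subnet.dataset.dataset_train.Dataset",    # placeholder module path
        "args": {},                                               # dataset-specific options
        "dataloader": {"batch_size": 8, "num_workers": 4, "pin_memory": True},
    },
    "validation_dataset": {"path": "inter_subnet.dataset.dataset_validation.Dataset", "args": {}},
    "model": {"path": "inter_subnet.model.Inter_SubNet.Inter_SubNet", "args": {}},   # placeholder
    "optimizer": {"lr": 1e-3, "beta1": 0.9, "beta2": 0.999},
    "loss_function": {"name": "mse_loss", "args": {}},            # must be a callable in audio_zen.loss
    "trainer": {"path": "inter_subnet.trainer.trainer.Trainer"},  # placeholder module path
}
```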
--------------------------------------------------------------------------------
/speech_enhance/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RookieJunChen/Inter-SubNet/ae1af09b8ca5364c50c95e8de94fbf9d67018d90/speech_enhance/utils/__init__.py
--------------------------------------------------------------------------------
/speech_enhance/utils/logger.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | from datetime import datetime
3 | import json
4 | from threading import Thread
5 | from urllib.request import Request, urlopen
6 | import os
7 |
8 | _format = '%Y-%m-%d %H:%M:%S.%f'
9 | _file = None
10 | _run_name = None
11 | _slack_url = None
12 |
13 | def init(filename, run_name, slack_url=None):
14 | os.makedirs(os.path.dirname(filename), exist_ok=True)
15 |
16 | global _file, _run_name, _slack_url
17 | _close_logfile()
18 | _file = open(filename, 'a')
19 | _file.write('\n-----------------------------------------------------------------\n')
20 | _file.write('Starting new training run\n')
21 | _file.write('-----------------------------------------------------------------\n')
22 | _file.flush()
23 | _run_name = run_name
24 | _slack_url = slack_url
25 |
26 |
27 | def log(msg, slack=False):
28 | cur_time = datetime.now().strftime(_format)[:-3]
29 | print('[%s] %s' % (cur_time, msg), end='\n', flush=True)
30 | if _file is not None:
31 | _file.write('[%s] %s\n' % (cur_time, msg))
32 | _file.flush()
33 | if slack and _slack_url is not None:
34 | Thread(target=_send_slack, args=(msg,)).start()
35 |
36 | def _close_logfile():
37 | global _file
38 | if _file is not None:
39 | _file.close()
40 | _file = None
41 |
42 | def _send_slack(msg):
43 | req = Request(_slack_url)
44 | req.add_header('Content-Type', 'application/json')
45 | urlopen(req, json.dumps({
46 | 'username': 'tacotron',
47 | 'icon_emoji': ':taco:',
48 | 'text': '*%s*: %s' % (_run_name, msg)
49 | }).encode())
50 |
51 |
52 | atexit.register(_close_logfile)
53 |
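A minimal usage sketch for this logger (the paths are illustrative): `init` opens the log file once, then each `log` call prints and appends a timestamped line, and `slack=True` only posts when a `slack_url` was passed to `init`.

```python
# Assumes speech_enhance/ is on sys.path, as in tools/train.py.
from utils.logger import init, log

init("exp/demo/train.log", "demo_run")   # creates exp/demo/ if it does not exist
log("start training")                    # printed and appended to the log file
log("epoch 1 finished", slack=True)      # no Slack post unless slack_url was given to init()
```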
--------------------------------------------------------------------------------
/speech_enhance/utils/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use("Agg")  # select the non-interactive backend before pyplot is imported
3 | import matplotlib.pyplot as plt
4 | import json
5 | import numpy as np
6 |
7 | from .logger import log
8 |
9 |
10 | def plot_alignment(alignment, path):
11 | alignment = np.where(alignment < 1, alignment, 1)
12 | # log("min and max:", np.min(alignment), np.max(alignment))
13 |
14 | fig = plt.figure(figsize=(8, 6))
15 | ax = fig.add_subplot(111)
16 | im = ax.imshow(alignment, aspect='auto', origin='lower',
17 | interpolation='none')
18 | fig.colorbar(im, ax=ax)
19 |
20 | plt.tight_layout()
21 | plt.savefig(path, format='png')
22 | plt.close()
23 | return
24 |
25 |
26 | def plot_spectrogram(pred_spectrogram, plot_path, title="mel-spec", show=False):
27 | fig = plt.figure(figsize=(20, 10))
28 | fig.text(0.5, 0.18, title, horizontalalignment='center', fontsize=16)
29 | vmin = np.min(pred_spectrogram)
30 | vmax = np.max(pred_spectrogram)
31 | ax2 = fig.add_subplot(111)
32 | im = ax2.imshow(np.rot90(pred_spectrogram), interpolation='none',
33 | vmin=vmin, vmax=vmax)
34 | char = fig.colorbar(mappable=im, shrink=0.65, orientation='horizontal',
35 | ax=ax2)
36 |
37 | # char.set_ticks(np.arange(vmin, vmax))
38 | char.set_ticks(np.arange(0, 1))
39 |
40 | plt.tight_layout()
41 | plt.savefig(plot_path, format='png')
42 | if show:
43 | plt.show()
44 | plt.close()
45 | log("save spec png to {}".format(plot_path))
46 | return
47 |
48 |
49 | def plot_two_spec(pred_spec, target_spec, pic_path, title=None,
50 | vmin=None, vmax=None):
51 | # assert np.shape(pred_spec)[1] == 80 and np.shape(target_spec)[1] == 80
52 | fig = plt.figure(figsize=(12, 8))
53 | fig.text(0.5, 0.18, title, horizontalalignment='center', fontsize=16)
54 | if vmin is None or vmax is None:
55 | vmin = min(np.min(pred_spec), np.min(target_spec))
56 | vmax = max(np.max(pred_spec), np.max(target_spec))
57 | ax1 = fig.add_subplot(211)
58 | ax1.set_title('Predicted Mel-Spectrogram')
59 | im = ax1.imshow(np.rot90(pred_spec), interpolation='none',
60 | vmin=vmin, vmax=vmax)
61 | fig.colorbar(mappable=im, shrink=0.65, orientation='horizontal', ax=ax1)
62 |
63 | ax2 = fig.add_subplot(212)
64 | ax2.set_title('Target Mel-Spectrogram')
65 | im = ax2.imshow(np.rot90(target_spec), interpolation='none',
66 | vmin=vmin, vmax=vmax)
67 | fig.colorbar(mappable=im, shrink=0.65, orientation='horizontal', ax=ax2)
68 |
69 | plt.tight_layout()
70 | plt.savefig(pic_path, format='png')
71 | plt.close()
72 | log("save spec png to {}".format(pic_path))
73 | return
74 |
75 |
76 | def plot_line(path, x_list, y_list, label_list):
77 | assert len(x_list) == len(y_list) == len(label_list)
78 | plt.title('Result Analysis')
79 | for x_data, y_data, label in zip(x_list, y_list, label_list):
80 | plt.plot(x_data, y_data, label=label)
81 | # plt.plot(x2, y2, color='red', label='predict')
82 | plt.legend()  # show the legend
83 | plt.xlabel('frame-index')
84 | plt.ylabel('value')
85 | plt.savefig(path, format='png')
86 | plt.close()
87 | return
88 |
89 |
90 | def plot_line_phone_time(path, time_pitch_index, pitch_seq, seq_label):
91 | """ show phoneme and pitch in time"""
92 | seq_str = []
93 | seq_index = []
94 | pre = seq_label[0]
95 | counter = 0
96 | for i in seq_label:
97 | if i != pre:
98 | seq_str.append(pre)
99 | seq_index.append(counter)
100 | pre = i
101 | else:
102 | counter += 1
103 | fig = plt.figure()
104 | ax1 = fig.add_subplot(111)
105 | ax1.plot(time_pitch_index, pitch_seq, 'r')
106 | ax1.set_ylabel('pitch')
107 | for i in range(len(seq_index)):
108 | plt.vlines(seq_index[i], ymin=0, ymax=700)
109 | plt.text(seq_index[i] - 1, 800, seq_str[i], rotation=39)
110 |
111 | plt.savefig(path, format='png')
112 | return
113 |
114 |
115 | def plot_mel(mel, path, info=None):
116 | mel = mel.T
117 | fig, ax = plt.subplots()
118 | im = ax.imshow(
119 | mel,
120 | aspect='auto',
121 | origin='lower',
122 | interpolation='none')
123 | fig.colorbar(im, ax=ax)
124 | xlabel = 'Decoder timestep'
125 | if info is not None:
126 | xlabel += '\n\n' + info
127 | plt.show()
128 | plt.savefig(path, format='png')
129 | plt.close()
130 |
131 | return fig
132 |
133 |
134 | def plot_one_mel_pitch_energy(mel, pitch, energy, stat_json_file, title, path):
135 | """plot mel/pitch/energy
136 |
137 | Args:
138 | mel: [dim, T]
139 | pitch: [T]
140 | energy: [T]
141 | stat_json_file: stat file
142 | title: figure title
143 | path: path for png
144 |
145 | """
146 | with open(stat_json_file) as f:
147 | stats = json.load(f)
148 | stats = stats["phn_pitch"] + stats["phn_energy"][:2]
149 | # stats = [min(pitch), max(pitch), 0, 1.0, min(energy), max(energy)]
150 |
151 | fig = plot_multi_mel_pitch_energy(
152 | [
153 | (mel, pitch, energy),
154 | ],
155 | stats,
156 | [title],
157 | )
158 | plt.savefig(path, format='png')
159 | plt.close()
160 |
161 |
162 | def expand(values, durations):
163 | out = list()
164 | for value, d in zip(values, durations):
165 | out += [value] * max(0, int(d))
166 | return np.array(out)
167 |
168 |
169 | def plot_multi_mel_pitch_energy(data, stats, titles):
170 | fig, axes = plt.subplots(len(data), 1, squeeze=False, figsize=(6.4, 3.0 * len(data)))
171 | if titles is None:
172 | titles = [None for i in range(len(data))]
173 | pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
174 | pitch_min = pitch_min * pitch_std + pitch_mean
175 | pitch_max = pitch_max * pitch_std + pitch_mean
176 |
177 | def add_axis(fig, old_ax):
178 | ax = fig.add_axes(old_ax.get_position(), anchor="W")
179 | ax.set_facecolor("None")
180 | return ax
181 |
182 | for i in range(len(data)):
183 | mel, pitch, energy = data[i]
184 | # log(titles, mel.shape, pitch.shape, energy.shape)
185 | pitch = pitch * pitch_std + pitch_mean
186 | axes[i][0].imshow(mel, origin="lower")
187 | axes[i][0].set_aspect(2.5, adjustable="box")
188 | axes[i][0].set_ylim(0, mel.shape[0])
189 | axes[i][0].set_title(titles[i], fontsize="medium")
190 | axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
191 | axes[i][0].set_anchor("W")
192 |
193 | ax1 = add_axis(fig, axes[i][0])
194 | # log(pitch)
195 | ax1.plot(pitch, color="tomato")
196 | ax1.set_xlim(0, max(mel.shape[1], len(pitch)))
197 | ax1.set_ylim(0, pitch_max)
198 | ax1.set_ylabel("F0", color="tomato")
199 | ax1.tick_params(
200 | labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
201 | )
202 |
203 | ax2 = add_axis(fig, axes[i][0])
204 | ax2.plot(energy, color="darkviolet")
205 | ax2.set_xlim(0, max(mel.shape[1], len(energy)))
206 | ax2.set_ylim(energy_min, energy_max)
207 | ax2.set_ylabel("Energy", color="darkviolet")
208 | ax2.yaxis.set_label_position("right")
209 | ax2.tick_params(
210 | labelsize="x-small",
211 | colors="darkviolet",
212 | bottom=False,
213 | labelbottom=False,
214 | left=False,
215 | labelleft=False,
216 | right=True,
217 | labelright=True,
218 | )
219 |
220 | return fig
221 |
222 |
223 | if __name__ == '__main__':
224 | pass
225 |
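A minimal usage sketch for the spectrogram helpers above (random data stands in for a real mel-spectrogram; with the Agg backend the figures are only written to disk):

```python
# Assumes speech_enhance/ is on sys.path, as in tools/train.py.
import numpy as np
from utils.plot import plot_spectrogram, plot_two_spec

fake_spec = np.random.rand(200, 80)  # [frames, mel bins], placeholder data

plot_spectrogram(fake_spec, "demo_spec.png", title="demo mel-spec")
plot_two_spec(fake_spec, fake_spec, "demo_pair.png", title="pred vs. target")
```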
--------------------------------------------------------------------------------
/speech_enhance/utils/utils.py:
--------------------------------------------------------------------------------
1 | import argparse, os
2 | import shutil
3 | import yaml
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | from .logger import log
9 |
10 |
11 | def touch_dir(d):
12 | os.makedirs(d, exist_ok=True)
13 |
14 |
15 | def is_file_exists(f):
16 | return os.path.exists(f)
17 |
18 |
19 | def check_file_exists(f):
20 | if not os.path.exists(f):
21 | log(f"not found file: {f}")
22 | assert False, f"not found file: {f}"
23 |
24 |
25 | def read_lines(data_path):
26 | lines = []
27 | with open(data_path, encoding="utf-8") as fr:
28 | for line in fr.readlines():
29 | if len(line.strip().replace(" ", "")):
30 | lines.append(line.strip())
31 | # log("read {} lines from {}".format(len(lines), data_path))
32 | # log("example(last) {}\n".format(lines[-1]))
33 | return lines
34 |
35 |
36 | def write_lines(data_path, lines):
37 | with open(data_path, "w", encoding="utf-8") as fw:
38 | for line in lines:
39 | fw.write("{}\n".format(line))
40 | # log("write {} lines to {}".format(len(lines), data_path))
41 | # log("example(last line): {}\n".format(lines[-1]))
42 | return
43 |
44 |
45 | def get_name_from_path(abs_path):
46 | return ".".join(os.path.basename(abs_path).split(".")[:-1])
47 |
48 |
49 | class AttrDict(dict):
50 | def __init__(self, *args, **kwargs):
51 | super(AttrDict, self).__init__(*args, **kwargs)
52 | self.__dict__ = self
53 | return
54 |
55 |
56 | def load_hparams(yaml_path):
57 | with open(yaml_path, encoding="utf-8") as yaml_file:
58 | hparams = yaml.safe_load(yaml_file)
59 | return AttrDict(hparams)
60 |
61 |
62 | def dump_hparams(yaml_path, hparams):
63 | touch_dir(os.path.dirname(yaml_path))
64 | with open(yaml_path, "w") as fw:
65 | yaml.dump(hparams, fw)
66 | log("save hparams to {}".format(yaml_path))
67 | return
68 |
69 |
70 | def get_all_wav_path(file_dir):
71 | wav_list = []
72 | for path, dir_list, file_list in os.walk(file_dir):
73 | for file_name in file_list:
74 | if file_name.endswith(".wav") or file_name.endswith(".WAV"):
75 | wav_path = os.path.join(path, file_name)
76 | wav_list.append(wav_path)
77 | return sorted(wav_list)
78 |
79 |
80 | def clean_and_new_dir(data_dir):
81 | if os.path.exists(data_dir):
82 | shutil.rmtree(data_dir)
83 | os.makedirs(data_dir)
84 | return
85 |
86 |
87 | def generate_dir_tree(synth_dir, dir_name_list, del_old=False):
88 | os.makedirs(synth_dir, exist_ok=True)
89 | dir_path_list = []
90 | if del_old:
91 | shutil.rmtree(synth_dir, ignore_errors=True)
92 | for name in dir_name_list:
93 | dir_path = os.path.join(synth_dir, name)
94 | dir_path_list.append(dir_path)
95 | os.makedirs(dir_path, exist_ok=True)
96 | return dir_path_list
97 |
98 |
99 | def str2bool(v):
100 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
101 | return True
102 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
103 | return False
104 | else:
105 | raise argparse.ArgumentTypeError('Boolean value expected.')
106 |
107 |
108 | def pad(input_ele, mel_max_length=None):
109 | if mel_max_length:
110 | max_len = mel_max_length
111 | else:
112 | max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])
113 |
114 | out_list = list()
115 | for i, batch in enumerate(input_ele):
116 | if len(batch.shape) == 1:
117 | one_batch_padded = F.pad(
118 | batch, (0, max_len - batch.size(0)), "constant", 0.0
119 | )
120 | elif len(batch.shape) == 2:
121 | one_batch_padded = F.pad(
122 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
123 | )
124 | out_list.append(one_batch_padded)
125 | out_padded = torch.stack(out_list)
126 | return out_padded
127 |
128 |
129 | def pad_1D(inputs, PAD=0):
130 | def pad_data(x, length, PAD):
131 | x_padded = np.pad(
132 | x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
133 | )
134 | return x_padded
135 |
136 | max_len = max((len(x) for x in inputs))
137 | padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])
138 |
139 | return padded
140 |
141 |
142 | def pad_2D(inputs, maxlen=None):
143 | def pad(x, max_len):
144 | PAD = 0
145 | if np.shape(x)[0] > max_len:
146 | raise ValueError("not max_len")
147 |
148 | s = np.shape(x)[1]
149 | x_padded = np.pad(
150 | x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
151 | )
152 | return x_padded[:, :s]
153 |
154 | if maxlen:
155 | output = np.stack([pad(x, maxlen) for x in inputs])
156 | else:
157 | max_len = max(np.shape(x)[0] for x in inputs)
158 | output = np.stack([pad(x, max_len) for x in inputs])
159 |
160 | return output
161 |
162 |
163 | def get_mask_from_lengths(lengths, max_len=None):
164 | batch_size = lengths.shape[0]
165 | if max_len is None:
166 | max_len = torch.max(lengths).item()
167 |
168 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device)
169 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
170 |
171 | return mask
172 |
173 |
174 | if __name__ == '__main__':
175 | pass
176 |
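A quick example of `get_mask_from_lengths` with made-up lengths: it returns a boolean mask of shape `[batch, max_len]` that is `True` at padded positions.

```python
# Assumes speech_enhance/ is on sys.path, as in tools/train.py.
import torch
from utils.utils import get_mask_from_lengths

lengths = torch.tensor([3, 5, 2])       # per-utterance frame counts (made up)
mask = get_mask_from_lengths(lengths)   # shape [3, 5]

print(mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False],
#         [False, False,  True,  True,  True]])
```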
--------------------------------------------------------------------------------