├── README.md ├── data_loader.py ├── enhanced.wav ├── inference └── real_time_inference │ ├── C++ │ ├── bin │ │ └── win.bin │ ├── lib │ │ └── onnxruntime-linux-x64-1.11.0.tgz │ └── src │ │ ├── fft.cpp │ │ ├── fft.h │ │ ├── onnxtest │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ └── run.cpp │ │ ├── test_STFT.cpp │ │ ├── wav.cpp │ │ └── wav.h │ ├── README.md │ ├── dpcrn_stateful_model.tflite │ ├── inference.py │ └── recording.py ├── main.py ├── modules.py ├── pretrain_model └── model_DPCRN_SNR+logMSE_causal_sinw.h5 ├── real_time_processing ├── __pycache__ │ ├── real_time_DPCRN.cpython-37.pyc │ └── stateful_modules.cpython-37.pyc ├── real_time_DPCRN.py └── stateful_modules.py ├── requirements.txt ├── samples ├── enhanced │ ├── ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_clatternoise_far_shouting_fileid_7.wav │ ├── ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_clatternoise_near_crying_fileid_8.wav │ ├── ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_creakingchair_far_crying_fileid_4.wav │ ├── ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_creakingchair_far_yelling_fileid_3.wav │ ├── ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_heavybreathing_near_surpised_fileid_11.wav │ ├── ms_realrec_english_headset_A2J9ZMQ5F4APMW_babycrying_fileid_7.wav │ ├── ms_realrec_english_headset_A2J9ZMQ5F4APMW_car_1_fileid_5.wav │ ├── ms_realrec_english_headset_A2J9ZMQ5F4APMW_dishwasher_fileid_6.wav │ ├── ms_realrec_english_laptop_A2PUL3ZDXOW0VZ_ClatteringNoise_near_4.wav │ ├── ms_realrec_english_laptopmicrophone_A3U20M3KJ10B1A_Creakingchair_far_fileid_6.wav │ ├── ms_realrec_musical_desktop_A8D0400E5EK2K_violin_far_fileid_8.wav │ ├── ms_realrec_musical_headset_A2H95JVPEKRUWA_accordion_far_fileid_11.wav │ ├── ms_realrec_musical_headset_A2H95JVPEKRUWA_harp_far_fileid_10.wav │ ├── ms_realrec_musical_laptopmicrophone_A14ZT8D7Z6T9IA_Guitar_far_fileid_6.wav │ └── ms_realrec_musical_microsoftsoundmapper_A4OAU6U3ZSBYY_guitar_far_fileid_9.wav └── noisy │ ├── ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_clatternoise_far_shouting_fileid_7.wav │ ├── ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_clatternoise_near_crying_fileid_8.wav │ ├── ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_creakingchair_far_crying_fileid_4.wav │ ├── ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_creakingchair_far_yelling_fileid_3.wav │ ├── ms_realrec_emotional_Desktopstandmic_AHLK9SWDJHBBZ_heavybreathing_near_surpised_fileid_11.wav │ ├── ms_realrec_english_headset_A2J9ZMQ5F4APMW_babycrying_fileid_7.wav │ ├── ms_realrec_english_headset_A2J9ZMQ5F4APMW_car_1_fileid_5.wav │ ├── ms_realrec_english_headset_A2J9ZMQ5F4APMW_dishwasher_fileid_6.wav │ ├── ms_realrec_english_laptop_A2PUL3ZDXOW0VZ_ClatteringNoise_near_4.wav │ ├── ms_realrec_english_laptopmicrophone_A3U20M3KJ10B1A_Creakingchair_far_fileid_6.wav │ ├── ms_realrec_musical_desktop_A8D0400E5EK2K_violin_far_fileid_8.wav │ ├── ms_realrec_musical_headset_A2H95JVPEKRUWA_accordion_far_fileid_11.wav │ ├── ms_realrec_musical_headset_A2H95JVPEKRUWA_harp_far_fileid_10.wav │ ├── ms_realrec_musical_laptopmicrophone_A14ZT8D7Z6T9IA_Guitar_far_fileid_6.wav │ └── ms_realrec_musical_microsoftsoundmapper_A4OAU6U3ZSBYY_guitar_far_fileid_9.wav ├── test.wav └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # DPCRN_DNS3 2 | *Created on Mon Oct 28 16:05:31 2021*
3 | *@author: xiaohuai.le* 4 | 5 | This repository is the official implementation of the paper "DPCRN: Dual-Path Convolution Recurrent Network for Single Channel Speech Enhancement". This work took third place in the third Deep Noise Suppression Challenge (DNS3). 6 | ## Requirements 7 | tensorflow>=1.14,
8 | numpy,
9 | matplotlib,
10 | librosa,
11 | soundfile.
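For convenience, the dependencies can be installed with pip (assuming the PyPI package names match the imports above):
```shell
pip install "tensorflow>=1.14" numpy matplotlib librosa soundfile
```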
12 | 13 | ## Datasets 14 | We use the [Deep Noise Suppression Dataset](https://github.com/microsoft/DNS-Challenge) together with the [OpenSLR26](http://www.openslr.org/26/) and [OpenSLR28](http://www.openslr.org/28/) RIR datasets for training and validation. The expected directory structure is shown below; a sketch of how these files are combined during training follows the directory trees:
15 | dataset
16 | ├── clean
17 | │ ├── audio1.wav
18 | │ ├── audio2.wav
19 | │ ├── audio3.wav
20 | │ ...
21 | ├── noise
22 | │ ├── audio1.wav
23 | │ ├── audio2.wav
24 | │ ├── audio3.wav
25 | │ ...
26 | 27 | RIR
28 | ├── rirs
29 | │ ├── rir1.wav
30 | │ ├── rir2.wav
31 | │ ├── rir3.wav
32 | │ ...
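
During training, each `clean_fileid_K.wav` is paired with the matching `noise_fileid_K.wav`, optionally convolved with a random RIR, and mixed at a random SNR (see `data_loader.py` below). A minimal sketch of that pipeline, mirroring `add_pyreverb` and `mk_mixture` from the loader; the file names are hypothetical:
```python
import numpy as np
import soundfile as sf
from scipy import signal

def add_reverb(clean, rir):
    # convolve clean speech with a RIR, trimming back to the original length
    l = len(rir) // 2
    return signal.fftconvolve(clean, rir, mode="full")[l : len(clean) + l]

def mix_at_snr(clean, noise, snr, eps=1e-8):
    # energy-normalize both signals, scale the clean one by 10**(snr/20),
    # then rescale everything so the mixture peak stays below 1
    n1 = clean / np.sqrt(np.sum(clean ** 2) + eps)
    n2 = noise / np.sqrt(np.sum(noise ** 2) + eps)
    alpha = 10 ** (snr / 20)
    mix = n2 + alpha * n1
    m = max(np.max(abs(mix)), np.max(abs(n2)), np.max(abs(alpha * n1))) + eps
    return alpha * n1 / m, n2 / m, mix / m

clean, fs = sf.read("dataset/clean/clean_fileid_1.wav", dtype="float32")
noise, _ = sf.read("dataset/noise/noise_fileid_1.wav", dtype="float32")
rir, _ = sf.read("RIR/rirs/rir1.wav", dtype="float32")
clean_rev = add_reverb(clean, rir)
_, _, noisy = mix_at_snr(clean_rev, noise, snr=0.0)
```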
33 | 34 | ## Training and testing 35 | Run the following command to train: 36 | ```shell 37 | python main.py --mode train --cuda 0 --experimentName experiment_1 38 | ``` 39 | Run the following command to test the model on the noisy files in a directory: 40 | ```shell 41 | python main.py --mode test --test_dir the_dir_of_noisy --output_dir the_dir_of_enhancement_results 42 | ``` 43 | ## More samples 44 | 45 | The final results on the blind test set of DNS3 are available at https://github.com/Le-Xiaohuai-speech/DPCRN_DNS3_Results.
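
Under the hood, test mode is mask-based enhancement in the STFT domain. The sketch below shows only the surrounding plumbing: `model` and its `predict` call are placeholders for the loaded DPCRN network (see `main.py` and `modules.py`), while the 400-sample window and 200-sample hop match the STFT settings used elsewhere in this repo:
```python
import librosa
import soundfile as sf

def enhance_file(model, noisy_path, out_path, n_fft=400, hop=200, fs=16000):
    # conceptual test-mode flow: STFT -> predicted mask -> masked iSTFT
    noisy, _ = librosa.load(noisy_path, sr=fs)
    spec = librosa.stft(noisy, n_fft=n_fft, hop_length=hop, center=False)
    mask = model.predict(spec)  # placeholder for the real DPCRN forward pass
    enhanced = librosa.istft(spec * mask, hop_length=hop, center=False)
    sf.write(out_path, enhanced, fs)
```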
46 | 47 | ## Real-time inference 48 | **Note that the real-time inference only runs on TensorFlow 1.x.** 49 | Run the real-time inference to measure the processing time per frame:
50 | ```shell 51 | python ./real_time_processing/real_time_DPCRN.py 52 | ``` 53 | ## Tensorflow Lite quantization and pruning 54 | The TFLite file of a smaller dpcrn model is uploaded. 55 | Enhance a single wav file: 56 | ```shell 57 | python ./inference/real_time_inference/inference.py 58 | ``` 59 | Streaming recording and enhancement: 60 | ```shell 61 | python ./inference/real_time_inference/recording.py 62 | ``` 63 | ## TensorRT deployment 64 | https://github.com/Xiaobin-Rong/TRT-SE?tab=readme-ov-file 65 | 66 | ## Citations 67 | ```shell 68 | @inproceedings{le21b_interspeech, 69 | author={Xiaohuai Le and Hongsheng Chen and Kai Chen and Jing Lu}, 70 | title={{DPCRN: Dual-Path Convolution Recurrent Network for Single Channel Speech Enhancement}}, 71 | year=2021, 72 | booktitle={Proc. Interspeech 2021}, 73 | pages={2811--2815}, 74 | doi={10.21437/Interspeech.2021-296} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Jan 12 14:57:00 2021 4 | 5 | @author: xiaohuaile 6 | """ 7 | import soundfile as sf 8 | #from wavinfo import WavInfoReader 9 | from random import shuffle, seed 10 | import numpy as np 11 | import librosa 12 | import os 13 | from scipy import signal 14 | ''' 15 | TRAIN_DIR: DNS data 16 | RIR_DIR: Room impulse response 17 | ''' 18 | TRAIN_DIR = '/data/ssd1/xiaohuai.le/DNS_data1/DNS_data' 19 | RIR_DIR = '/data/ssd1/xiaohuai.le/RIR_database/impulse_responses/' 20 | 21 | 22 | #FIR, frequencies below 60Hz will be filtered 23 | fir = signal.firls(1025,[0,40,50,60,70,8000],[0,0,0.1,0.5,1,1],fs = 16000) 24 | 25 | def add_pyreverb(clean_speech, rir): 26 | ''' 27 | convolve RIRs to the clean speech to generate reverbrant speech 28 | ''' 29 | l = len(rir)//2 30 | reverb_speech = signal.fftconvolve(clean_speech, rir, mode="full") 31 | # make reverb_speech same length as clean_speech 32 | reverb_speech = reverb_speech[l : clean_speech.shape[0]+l] 33 | 34 | return reverb_speech 35 | #按照snr混合音频 36 | def mk_mixture(s1,s2,snr,eps = 1e-8): 37 | ''' 38 | make mixture from s1 and s2 with snr 39 | ''' 40 | norm_sig1 = s1 / np.sqrt(np.sum(s1 ** 2) + eps) 41 | norm_sig2 = s2 / np.sqrt(np.sum(s2 ** 2) + eps) 42 | alpha = 10**(snr/20) 43 | mix = norm_sig2 + alpha*norm_sig1 44 | M = max(np.max(abs(mix)),np.max(abs(norm_sig2)),np.max(abs(alpha*norm_sig1))) + eps 45 | mix = mix / M 46 | norm_sig1 = norm_sig1 * alpha/ M 47 | norm_sig2 = norm_sig2 / M 48 | 49 | return norm_sig1,norm_sig2,mix,snr 50 | 51 | 52 | class data_generator(): 53 | 54 | def __init__(self,train_dir = TRAIN_DIR, 55 | RIR_dir = RIR_DIR, 56 | validation_rate=0.1, 57 | length_per_sample = 4, 58 | fs = 16000, 59 | n_fft = 400, 60 | n_hop = 200, 61 | batch_size = 8, 62 | sample_num=-1, 63 | add_reverb = True, 64 | reverb_rate = 0.5 65 | ): 66 | ''' 67 | keras data generator 68 | Para.: 69 | train_dir: folder storing training data, including train_dir/clean, train_dir/noise 70 | RIR_dir: folder storing RIRs, from OpenSLR26 and OpenSLR28 71 | validation_rate: how much data is used for validation 72 | length_per_sample: speech sample length in second 73 | fs: sample rate of the speech 74 | n_fft: FFT length and window length in STFT 75 | n_hop: hop length in STFT 76 | batch_size: batch size 77 | sample_num: how many samples are used for training and validation 78 | add_reverb: adding reverbrantion or not 79 | reverb_rate: how much data is 
reverbrant 80 | ''' 81 | 82 | self.train_dir = train_dir 83 | self.clean_dir = os.path.join(train_dir,'clean') 84 | self.noise_dir = os.path.join(train_dir,'noise') 85 | 86 | self.fs = fs 87 | self.batch_size = batch_size 88 | self.length_per_sample = length_per_sample 89 | self.L = length_per_sample * self.fs 90 | # calculate the length of each sample after iSTFT 91 | self.points_per_sample = ((self.L - n_fft) // n_hop) * n_hop + n_fft 92 | 93 | self.validation_rate = validation_rate 94 | self.add_reverb = add_reverb 95 | self.reverb_rate = reverb_rate 96 | 97 | if RIR_dir is not None: 98 | self.rir_dir = RIR_dir 99 | self.rir_list = librosa.util.find_files(self.rir_dir,ext = 'wav')[:sample_num] 100 | np.random.shuffle(self.rir_list) 101 | self.rir_list = self.rir_list[:sample_num] 102 | print('there are {} rir clips\n'.format(len(self.rir_list))) 103 | 104 | self.noise_file_list = os.listdir(self.noise_dir) 105 | self.clean_file_list = os.listdir(self.clean_dir)[:sample_num] 106 | self.train_length = int(len(self.clean_file_list)*(1-validation_rate)) 107 | self.train_list, self.validation_list = self.generating_train_validation(self.train_length) 108 | self.valid_length = len(self.validation_list) 109 | 110 | self.train_rir = self.rir_list[:self.train_length] 111 | self.valid_rir = self.rir_list[self.train_length : self.train_length + self.valid_length] 112 | print('have been generated DNS training list...\n') 113 | 114 | print('there are {} samples for training, {} for validation'.format(self.train_length,self.valid_length)) 115 | 116 | def find_files(self,file_name): 117 | ''' 118 | from file_name find parallel noise file and noisy file 119 | e.g. 120 | file_name: clean_fileid_1.wav 121 | noise_file_name: noise_fileid_1.wav 122 | noisy_file_name: noisy_fileid_1.wav 123 | ''' 124 | #noise_file_name = np.random.choice(self.noise_file_list) #randomly selection 125 | _,k1,k2 = file_name.split('_') 126 | noise_file_name = 'noise' + '_' + k1 + '_' + k2 127 | noisy_file_name = 'noisy' + '_' + k1 + '_' + k2 128 | 129 | # random segmentation 130 | Begin_S = int(np.random.uniform(0,30 - self.length_per_sample)) * self.fs 131 | Begin_N = int(np.random.uniform(0,30 - self.length_per_sample)) * self.fs 132 | return noise_file_name,noisy_file_name,Begin_S,Begin_N 133 | 134 | def generating_train_validation(self,training_length): 135 | ''' 136 | get training and validation data 137 | ''' 138 | np.random.shuffle(self.clean_file_list) 139 | self.train_list,self.validation_list = self.clean_file_list[:training_length],self.clean_file_list[training_length:] 140 | 141 | return self.train_list,self.validation_list 142 | 143 | def generator(self, batch_size, validation = False): 144 | ''' 145 | data generator, 146 | validation: if True, get validation data genertor 147 | ''' 148 | if validation: 149 | train_data = self.validation_list 150 | train_rir = self.valid_rir 151 | else: 152 | train_data = self.train_list 153 | train_rir = self.train_rir 154 | N_batch = len(train_data) // batch_size 155 | batch_num = 0 156 | while (True): 157 | 158 | batch_clean = np.zeros([batch_size,self.points_per_sample],dtype = np.float32) 159 | batch_noisy = np.zeros([batch_size,self.points_per_sample],dtype = np.float32) 160 | 161 | for i in range(batch_size): 162 | # random amplitude gain 163 | gain = np.random.normal(loc=-5,scale=10) 164 | gain = 10**(gain/10) 165 | gain = min(gain,3) 166 | gain = max(gain,0.01) 167 | 168 | SNR = np.random.uniform(-5,5) 169 | sample_num = batch_num*batch_size + i 170 | #get the path of 
clean audio 171 | clean_f = train_data[sample_num] 172 | rir_f = train_rir[sample_num] 173 | reverb_rate = np.random.rand() 174 | 175 | noise_f, noisy_f, Begin_S,Begin_N = self.find_files(clean_f) 176 | clean_s = sf.read(os.path.join(self.clean_dir,clean_f),dtype = 'float32',start= Begin_S,stop = Begin_S + self.points_per_sample)[0] 177 | noise_s = sf.read(os.path.join(self.noise_dir,noise_f),dtype = 'float32',start= Begin_N,stop = Begin_N + self.points_per_sample)[0] 178 | 179 | clean_s = add_pyreverb(clean_s, fir) 180 | 181 | #noise_s = noise_s - np.mean(noise_s) 182 | if self.add_reverb: 183 | if reverb_rate < self.reverb_rate: 184 | rir_s = sf.read(rir_f,dtype = 'float32')[0] 185 | if len(rir_s.shape)>1: 186 | rir_s = rir_s[:,0] 187 | clean_s = add_pyreverb(clean_s, rir_s) 188 | 189 | clean_s,noise_s,noisy_s,_ = mk_mixture(clean_s,noise_s,SNR,eps = 1e-8) 190 | 191 | batch_clean[i,:] = clean_s * gain 192 | batch_noisy[i,:] = noisy_s * gain 193 | 194 | batch_num += 1 195 | 196 | if batch_num == N_batch: 197 | batch_num = 0 198 | 199 | if validation: 200 | train_data = self.validation_list 201 | train_rir = self.valid_rir 202 | else: 203 | train_data = self.train_list 204 | train_rir = self.train_rir 205 | 206 | np.random.shuffle(train_data) 207 | np.random.shuffle(train_rir) 208 | np.random.shuffle(self.noise_file_list) 209 | 210 | N_batch = len(train_data) // batch_size 211 | 212 | yield batch_noisy,batch_clean 213 | 214 | 215 | -------------------------------------------------------------------------------- /enhanced.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Xiaohuai-speech/DPCRN_DNS3/c7fe17d02fcc2502f198dd6c2d29bba2c4e1c0ed/enhanced.wav -------------------------------------------------------------------------------- /inference/real_time_inference/C++/bin/win.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Xiaohuai-speech/DPCRN_DNS3/c7fe17d02fcc2502f198dd6c2d29bba2c4e1c0ed/inference/real_time_inference/C++/bin/win.bin -------------------------------------------------------------------------------- /inference/real_time_inference/C++/lib/onnxruntime-linux-x64-1.11.0.tgz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Xiaohuai-speech/DPCRN_DNS3/c7fe17d02fcc2502f198dd6c2d29bba2c4e1c0ed/inference/real_time_inference/C++/lib/onnxruntime-linux-x64-1.11.0.tgz -------------------------------------------------------------------------------- /inference/real_time_inference/C++/src/fft.cpp: -------------------------------------------------------------------------------- 1 | #include "fft.h" 2 | 3 | static void make_sintbl(int n, float* sintbl) 4 | { 5 | int i, n2, n4, n8; 6 | float c, s, dc, ds, t; 7 | 8 | n2 = n / 2; n4 = n / 4; n8 = n / 8; 9 | t = sin(M_PI / n); 10 | dc = 2 * t * t; ds = sqrt(dc * (2 - dc)); 11 | t = 2 * dc; c = sintbl[n4] = 1; s = sintbl[0] = 0; 12 | for (i = 1; i < n8; i++) { 13 | c -= dc; dc += t * c; 14 | s += ds; ds -= t * s; 15 | sintbl[i] = s; sintbl[n4 - i] = c; 16 | } 17 | if (n8 != 0) sintbl[n8] = sqrt(0.5); 18 | for (i = 0; i < n4; i++) 19 | sintbl[n2 - i] = sintbl[i]; 20 | for (i = 0; i < n2 + n4; i++) 21 | sintbl[i + n2] = -sintbl[i]; 22 | } 23 | 24 | static void make_bitrev(int n, int* bitrev) 25 | { 26 | int i, j, k, n2; 27 | 28 | n2 = n / 2; i = j = 0; 29 | for (;;) { 30 | bitrev[i] = j; 31 | if (++i >= n) break; 32 | k = n2; 33 | while (k 
<= j) { j -= k; k /= 2; } 34 | j += k; 35 | } 36 | } 37 | 38 | int fft(float* x, float* y, int n) 39 | { 40 | static int last_n = 0; /* previous n */ 41 | static int *bitrev = NULL; /* bit reversal table */ 42 | static float *sintbl = NULL; /* trigonometric function table */ 43 | int i, j, k, ik, h, d, k2, n4, inverse; 44 | float t, s, c, dx, dy; 45 | 46 | /* preparation */ 47 | if (n < 0) { 48 | n = -n; inverse = 1; /* inverse transform */ 49 | } 50 | else { 51 | inverse = 0; 52 | } 53 | n4 = n / 4; 54 | if (n != last_n || n == 0) { 55 | last_n = n; 56 | if (sintbl != NULL) free(sintbl); 57 | if (bitrev != NULL) free(bitrev); 58 | if (n == 0) return 0; /* free the memory */ 59 | sintbl = (float*)malloc((n + n4) * sizeof(float)); 60 | bitrev = (int*)malloc(n * sizeof(int)); 61 | if (sintbl == NULL || bitrev == NULL) { 62 | //Error("%s in %f(%d): out of memory\n", __func__, __FILE__, __LINE__); 63 | return 1; 64 | } 65 | make_sintbl(n, sintbl); 66 | make_bitrev(n, bitrev); 67 | } 68 | 69 | /* bit reversal */ 70 | for (i = 0; i < n; i++) { 71 | j = bitrev[i]; 72 | if (i < j) { 73 | t = x[i]; x[i] = x[j]; x[j] = t; 74 | t = y[i]; y[i] = y[j]; y[j] = t; 75 | } 76 | } 77 | 78 | /* transformation */ 79 | for (k = 1; k < n; k = k2) { 80 | h = 0; k2 = k + k; d = n / k2; 81 | for (j = 0; j < k; j++) { 82 | c = sintbl[h + n4]; 83 | if (inverse) 84 | s = -sintbl[h]; 85 | else 86 | s = sintbl[h]; 87 | for (i = j; i < n; i += k2) { 88 | ik = i + k; 89 | dx = s * y[ik] + c * x[ik]; 90 | dy = c * y[ik] - s * x[ik]; 91 | x[ik] = x[i] - dx; x[i] += dx; 92 | y[ik] = y[i] - dy; y[i] += dy; 93 | } 94 | h += d; 95 | } 96 | } 97 | if (inverse) { 98 | /* divide by n in case of the inverse transformation */ 99 | for (i = 0; i < n; i++) 100 | { 101 | x[i] /= n; 102 | y[i] /= n; 103 | } 104 | } 105 | return 0; /* finished successfully */ 106 | } 107 | 108 | void STFT(float *s, float *spec, float *win, int signal_length, int fft_length, int window_length, int hop_length){ 109 | 110 | int N_frame = signal_length / hop_length + 1; 111 | int index = 0; 112 | 113 | for (int i = 0; i < N_frame; i++){ 114 | for (int j = 0; j < window_length; j++ ){ 115 | index = i * hop_length + j; 116 | if (index < signal_length){ 117 | spec[i * fft_length * 2 + j] = win[j] * s[index]; 118 | } 119 | } 120 | fft(spec + i * fft_length * 2, spec + i * fft_length * 2 + fft_length, fft_length); 121 | } 122 | } 123 | 124 | void iSTFT(float *s, float *spec, float *win, int N_frame, int signal_length, int fft_length, int window_length, int hop_length){ 125 | 126 | for (int i = 0; i < N_frame; i++){ 127 | for (int j = 1; j < fft_length / 2; j++ ){ 128 | // real part 129 | spec[i * fft_length * 2 + fft_length - j] = spec[i * fft_length * 2 + j]; 130 | // imaginary part 131 | spec[i * fft_length * 2 + 2 * fft_length - j] = -spec[i * fft_length * 2 + fft_length + j]; 132 | } 133 | fft(spec + i * fft_length * 2, spec + i * fft_length * 2 + fft_length, -fft_length); 134 | // add window // overlap-add 135 | for (int k = 0; k < fft_length ; k++ ){ 136 | //spec[i * fft_length * 2 + k] = win[k] * spec[i * fft_length * 2 + k] 137 | if (i * hop_length + k < signal_length){ 138 | s[i * hop_length + k] = s[i * hop_length + k] + win[k] * spec[i * fft_length * 2 + k]; 139 | } 140 | } 141 | } 142 | } 143 | 144 | 145 | -------------------------------------------------------------------------------- /inference/real_time_inference/C++/src/fft.h: -------------------------------------------------------------------------------- 1 | /* @breif Fast Fourier Transform 2 | 
* @Author: Xiaohuaile 3 | * @Date: 2022-4-9 4 | * @Last Modified by: Xiaohuaile 5 | */ 6 | #ifndef FFT_H_ 7 | #define FFT_H_ 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | // #define M_PI 3.141592653589793238462643383279502 14 | 15 | /* 16 | * @breif Fast Fourier Transform 17 | * @x: real part 18 | * @y: imaginary part 19 | * @n: length of fft, negative for inverse FFT 20 | */ 21 | int fft(float* x, float* y, int n); 22 | 23 | /* 24 | * @brief: Short Time Fourier Transform 25 | * @s: input signal 26 | * @spec: spectrogram, N_frame * N_fft *2 (real and imag) 27 | * @win: window, window_length 28 | * @signal_length 29 | * @fft_length 30 | * @window_length 31 | * @hop_length 32 | */ 33 | void STFT(float *s, float *spec, float *win, int signal_length, int fft_length, int window_length, int hop_length); 34 | 35 | /* 36 | * @brief: inverse Short Time Fourier Transform 37 | * @s: output signal 38 | * @spec: spectrogram, N_frame * N_fft / 2 + 1 (rfft part) + N_fft/2 - 1 (rfft symmetrical part) 39 | * @win: window, window_length 40 | * @N_frame: number of the frame 41 | * @signal_length: length of the input signal 42 | * @fft_length 43 | * @window_length 44 | * @hop_length 45 | */ 46 | void iSTFT(float *s, float *spec, float *win, int N_frame, int signal_length, int fft_length, int window_length, int hop_length); 47 | 48 | #endif 49 | 50 | 51 | -------------------------------------------------------------------------------- /inference/real_time_inference/C++/src/onnxtest/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0.0) 2 | project(run_demo VERSION 0.1.0) 3 | # cross compiler for arm/aarch64 4 | #set(CMAKE_CXX_COMPILER "/opt/gcc-linaro-6.5.0-2018.12-x86_64_aarch64-linux-gnu/bin/aarch64-linux-gnu-g++") 5 | 6 | enable_testing() 7 | # onnxruntime dir 8 | set(HOST_PACKAGE_DIR "/mnt/d/codes/c++/my_project/dpcrn_onnx_demo/onnxruntime-linux-x64-1.11.0") 9 | 10 | include_directories( 11 | ${HOST_PACKAGE_DIR}/include/) 12 | 13 | link_directories( 14 | ${HOST_PACKAGE_DIR}/lib/) 15 | 16 | add_executable(run_demo run.cpp) 17 | 18 | target_link_libraries(run_demo ${HOST_PACKAGE_DIR}/lib/libonnxruntime.so.1.11.0) 19 | 20 | 21 | -------------------------------------------------------------------------------- /inference/real_time_inference/C++/src/onnxtest/README.md: -------------------------------------------------------------------------------- 1 | 2 | ```shell 3 | tar zxvf ../../lib/onnxruntime-linux-x64-1.11.0.tgz 4 | ``` 5 | Set the path of onnxruntime as the HOST_PACKAGE_DIR in CMakeLists.txt and make. 
6 | 7 | You can also compile the code by: 8 | ```shell 9 | g++ -o run run.cpp ../../lib/onnxruntime-linux-x64-1.11.0/lib/libonnxruntime.so.1.11.0 -I../../lib/onnxruntime-linux-x64-1.11.0/include/ 10 | ``` 11 | -------------------------------------------------------------------------------- /inference/real_time_inference/C++/src/onnxtest/run.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | int main(int argc, char* argv[]) { 9 | Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); 10 | Ort::SessionOptions session_options; 11 | session_options.SetIntraOpNumThreads(1); 12 | session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED); 13 | 14 | #ifdef _WIN32 15 | const wchar_t* model_path = L"./model.onnx"; 16 | #else 17 | const char* model_path = "dpcrn.onnx"; 18 | #endif 19 | int step=0; 20 | std::cout << "input the step:"<> step; 22 | 23 | Ort::Session session(env, model_path, session_options); 24 | // print model input layer (node names, types, shape etc.) 25 | Ort::AllocatorWithDefaultOptions allocator; 26 | 27 | // print number of model input nodes 28 | size_t num_input_nodes = session.GetInputCount(); 29 | std::vector input_node_names = {"input","h_in"}; 30 | std::vector output_node_names = {"output","h_out"}; 31 | // get input tensor 32 | std::vector input_node_dims = {1, 3, 1, 257}; 33 | size_t input_tensor_size = 3 * 257; 34 | std::vector input_tensor_values(input_tensor_size); 35 | 36 | for (unsigned int i = 0; i < input_tensor_size; i++){ 37 | input_tensor_values[i] = (float)i / input_tensor_size; 38 | } 39 | 40 | auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); 41 | Ort::Value input_tensor = Ort::Value::CreateTensor(memory_info, input_tensor_values.data(), input_tensor_size, input_node_dims.data(), 4); 42 | assert(input_tensor.IsTensor()); 43 | 44 | // get input hidden states 45 | std::vector input_hidden_node_dims = {2, 32, 128}; 46 | size_t input_hidden_tensor_size = 2 * 32 * 128; 47 | std::vector input_hidden_tensor_values(input_hidden_tensor_size); 48 | 49 | for (unsigned int i = 0; i < input_hidden_tensor_size; i++){ 50 | input_hidden_tensor_values[i] = (float)i / (input_hidden_tensor_size + 1); 51 | } 52 | 53 | auto hidden_memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); 54 | Ort::Value input_hidden_tensor = Ort::Value::CreateTensor(hidden_memory_info, input_hidden_tensor_values.data(), input_hidden_tensor_size, input_hidden_node_dims.data(), 3); 55 | assert(input_hidden_tensor.IsTensor()); 56 | 57 | // inference 58 | clock_t start, end; 59 | //std::vector time; 60 | 61 | std::vector ort_inputs; 62 | std::vector output_tensors; 63 | ort_inputs.push_back(std::move(input_tensor)); 64 | ort_inputs.push_back(std::move(input_hidden_tensor)); 65 | 66 | float *output, *output_hidden; 67 | /* 68 | session.Run(run_options, input_names, input_values, input_count, output_names, output_count) 69 | OrtRun(session_, nullptr, input_names, &input_tensor, input_count, output_names, output_count, &output_tensor); 70 | */ 71 | start = clock(); 72 | for(int i =0 ; i < step; i++){ 73 | 74 | fill(input_tensor_values.begin(), input_tensor_values.end(), (float)i); 75 | output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names.data(), ort_inputs.data(), ort_inputs.size(), output_node_names.data(), 2); 76 | // get pointer to output tensor float values 77 | output = 
output_tensors[0].GetTensorMutableData(); 78 | output_hidden = output_tensors[1].GetTensorMutableData(); 79 | //usleep(16000); 80 | } 81 | 82 | end = clock(); 83 | for(int i =0;i<20;i++){ 84 | std::cout << i << " " << output[i] << std::endl; 85 | } 86 | std::cout<< (double)(end - start) / CLOCKS_PER_SEC / step * 1000 << " ms/frame" << std::endl; 87 | 88 | std::cout<< input_node_names.data()[0] << std::endl; 89 | printf("Done!\n"); 90 | } 91 | -------------------------------------------------------------------------------- /inference/real_time_inference/C++/src/test_STFT.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "wav.h" 4 | #include "fft.h" 5 | /* 6 | run g++ ./test_STFT.cpp ./wav.cpp ./fft.cpp -o read_wav 7 | */ 8 | 9 | int main(int, char**) { 10 | std::cout << "input the wav file" << std::endl; 11 | std::string wav_file; 12 | // file_name = "./sample.wav"; 13 | std::cin >> wav_file; 14 | FILE* fp = fopen(wav_file.data(), "rb"); 15 | // header information and sample data 16 | wav_info info; 17 | wav_data wdata; 18 | // read the wav file 19 | read_wav_info(&info, &wdata, fp); 20 | fclose(fp); 21 | 22 | float len_in_s = (float)info.num_samples / (float)info.sample_rate; 23 | // print the information 24 | std::cout << "file name: " << wav_file << "\n" 25 | << "channel number: " << info.num_channels <<"\n" 26 | << "bits per sampe: " << info.bits_per_sample << "\n" 27 | << "sample rate: " << info.sample_rate<<"\n" 28 | << "num samples: " << info.num_samples<<"\n" 29 | << "time in second: " << len_in_s << std::endl; 30 | 31 | std::cout << "samples: "<< std::endl; 32 | 33 | for(int i=0;i<100;i++){ 34 | std::cout << wdata.data[i]<<" "; 35 | } 36 | std::cout << std::endl; 37 | 38 | // test STFT 39 | int fft_len = 512; 40 | int hop_len = 256; 41 | int N_frame = wdata.size / hop_len + 1; 42 | float win[fft_len]; 43 | float *s, *spec; 44 | std::string win_f = "../bin/win.bin"; 45 | s = new float[wdata.size]; 46 | spec = new float[N_frame * fft_len * 2]; 47 | std::cout << "frame length: " << fft_len << "\n" 48 | << "hop length: " << hop_len << "\n" 49 | << "frame number: " << N_frame << std::endl; 50 | // read window 51 | read_file_bin_data(win_f.c_str(), win, fft_len * 4); 52 | // STFT 53 | for (int i=0; i < wdata.size; i++){ 54 | s[i] = float(wdata.data[i]) / 32767.0; 55 | } 56 | STFT(s, spec, win, wdata.size, fft_len, fft_len, hop_len); 57 | /* 58 | check the output in python by: 59 | x = np.fromfile('../bin/stft.bin', dtype=np.float32) 60 | */ 61 | write_file_bin_data("../bin/stft.bin", spec, N_frame * fft_len * 2 * 4); 62 | // iSTFT 63 | memset(s, 0, wdata.size*4); 64 | iSTFT(s, spec, win, N_frame, wdata.size, fft_len, fft_len, hop_len); 65 | 66 | for (int i=0; i < wdata.size; i++){ 67 | wdata.data[i] = int16_t(32767.0 * s[i]); 68 | } 69 | // write the signal as a .pcm file 70 | write_file_bin_data("./output_s.pcm", wdata.data, wdata.size * info.bits_per_sample / 8); 71 | // write the signal as a .wav file 72 | fp = fopen("./output_s.wav", "wb"); 73 | write_file_wav(&info, fp, wdata.data); 74 | fclose(fp); 75 | 76 | free_source(&wdata); 77 | delete [] s; 78 | delete [] spec; 79 | } 80 | -------------------------------------------------------------------------------- /inference/real_time_inference/C++/src/wav.cpp: -------------------------------------------------------------------------------- 1 | #include "wav.h" 2 | #include 3 | #include 4 | #include 5 | 6 | void read_wav_info(struct wav_info *w, wav_data *wdata, 
FILE *fp) { 7 | // To be read from *fp 8 | uint32_t data_size; 9 | uint32_t size; 10 | uint32_t byte_rate; 11 | uint16_t block_align; 12 | 13 | uint8_t x[4]; /* buffer for reading from *fp */ 14 | 15 | // Start reading from beginning of *fp 16 | if(fseek(fp,0,SEEK_SET)) { 17 | fprintf(stderr,"Error with fseek in read_wav_info in wav.c\n"); 18 | exit(EXIT_FAILURE); 19 | } 20 | // First four bytes of any RIFF file should be the ASCII codes for "RIFF" 21 | fread(x,1,4,fp); 22 | if((x[0] != 0x52) || (x[1] != 0x49) || (x[2] != 0x46) || (x[3] != 0x46)) { 23 | fprintf(stderr,"Error: First 4 bytes indicate file is not a RIFF/WAVE file!\n"); 24 | exit(EXIT_FAILURE); 25 | } 26 | // Next four bytes give the RIFF chunk size RCS in Little Endian format 27 | fread(x,1,4,fp); 28 | // We're not going to do anything with it, but you could do 29 | // uint32_t RCS = x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); 30 | // here if you wanted to... 31 | // Next four bytes should be the ASCII codes for "WAVE" 32 | fread(x,1,4,fp); 33 | if((x[0] != 0x57) || (x[1] != 0x41) || (x[2] != 0x56) || (x[3] != 0x45)) { 34 | fprintf(stderr,"Error: Bytes 9-12 indicate file is not a RIFF/WAVE file!\n"); 35 | exit(EXIT_FAILURE); 36 | } 37 | 38 | // Look for the "fmt " subchunk of this RIFF file... 39 | while(1) { 40 | fread(x,1,4,fp); 41 | // See if the four bytes we just read are the ASCII codes for "fmt " 42 | if((x[0] == 0x66) && (x[1] == 0x6D) && (x[2] == 0x74) && (x[3] == 0x20)) { 43 | // Found the "fmt " subchunk 44 | // The next four bytes should give the size of the fmt subchunk 45 | // in Little Endian. This should be 16 if this is a PCM WAVE file. 46 | fread(x,1,4,fp); 47 | uint32_t y = x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); 48 | if(y != 16) { 49 | fprintf(stderr,"Error: File does not appear to be a PCM RIFF/WAVE file.\n"); 50 | fprintf(stderr,"fmt subchunk doesn't have size 16.\n"); 51 | exit(EXIT_FAILURE); 52 | } 53 | // Next two bytes should give the integer 1 (for PCM format) in Little Endian 54 | fread(x,1,4,fp); 55 | uint16_t wFormatTag = x[0] | (x[1] << 8); 56 | if(wFormatTag != 1) { 57 | fprintf(stderr,"Error: File does not appear to be a PCM RIFF/WAVE file.\n"); 58 | fprintf(stderr,"wFormatTag is not equal to 1.\n"); 59 | exit(EXIT_FAILURE); 60 | } 61 | // The rest of the fmt subchunk should give num_channels (as a two-byte 62 | // integer in L.E.), sample_rate (four-byte L.E.), byte_rate (four-byte L.E.), 63 | // block_align (two-byte L.E.), and bits_per_sample (two-byte L.E.) 64 | w->num_channels = x[2] | (x[3] << 8); 65 | fread(x,1,4,fp); 66 | w->sample_rate = x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); 67 | fread(x,1,4,fp); 68 | byte_rate = x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); 69 | fread(x,1,4,fp); 70 | block_align = x[0] | (x[1] << 8); 71 | w->bits_per_sample = x[2] | (x[3] << 8); 72 | // Now we're done with the fmt subchunk 73 | break; 74 | } 75 | // The four bytes after the four-byte "Chunk ID" in any RIFF file give 76 | // the size of the chunk as a four-byte integer in Little Endian 77 | uint32_t chunk_size; 78 | fread(x,1,4,fp); 79 | chunk_size = x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); 80 | printf("chunk size=%d",chunk_size); 81 | // Skip over this subchunk and keep looking for "fmt " subchunk 82 | if(fseek(fp,chunk_size,SEEK_CUR)) { 83 | fprintf(stderr,"Error: Couldn't find fmt subchunk in file.\n"); 84 | exit(EXIT_FAILURE); 85 | } 86 | } 87 | 88 | // Now look for the "data" subchunk of this RIFF file... 
89 | while(1) { 90 | fread(x,1,4,fp); 91 | // See if these four bytes are the ASCII codes for "data" 92 | if((x[0] == 0x64) && (x[1] == 0x61) && (x[2] == 0x74) && (x[3] == 0x61)) { 93 | // Found the "data" subchunk 94 | fread(x,1,4,fp); 95 | data_size = x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); 96 | // Now we're done reading from *fp... 97 | break; 98 | } 99 | // The four bytes after the four-byte "Chunk ID" in any RIFF file give 100 | // the size of the chunk as a four-byte integer in Little Endian 101 | uint32_t chunk_size; 102 | fread(x,1,4,fp); 103 | chunk_size = x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24); 104 | // Skip over this subchunk and keep looking for "data" subchunk 105 | if(fseek(fp,chunk_size,SEEK_CUR)) { 106 | fprintf(stderr,"Error: Couldn't find data subchunk in file.\n"); 107 | exit(EXIT_FAILURE); 108 | } 109 | } 110 | // Determine num_samples 111 | printf("bit size: %d\n",data_size); 112 | w->num_samples = data_size/((w->num_channels)*(w->bits_per_sample)/8); 113 | wdata->size = w->num_samples; 114 | 115 | wdata->data=(int16_t*)malloc(data_size); 116 | fread(wdata->data, sizeof(int16_t), wdata->size, fp); 117 | 118 | // Do some error checking: 119 | if(block_align != (w->num_channels)*(w->bits_per_sample)/8) { 120 | fprintf(stderr,"Error: block_align, num_channels, bits_per_sample mismatch in WAVE header.\n"); 121 | fprintf(stderr,"block_align=%i\n",block_align); 122 | fprintf(stderr,"num_channels=%i\n",(int) w->num_channels); 123 | fprintf(stderr,"bits_per_sample=%i\n",(int) w->bits_per_sample); 124 | exit(EXIT_FAILURE); 125 | } 126 | if(byte_rate != (w->sample_rate)*(w->num_channels)*(w->bits_per_sample)/8) { 127 | fprintf(stderr,"Error: byte_rate, sample_rate, num_channels mismatch in WAVE header.\n"); 128 | fprintf(stderr,"byte_rate=%i\n",byte_rate); 129 | fprintf(stderr,"sample_rate=%i\n",(int) w->sample_rate); 130 | fprintf(stderr,"num_channels=%i\n",(int) w->num_channels); 131 | exit(EXIT_FAILURE); 132 | } 133 | 134 | } 135 | 136 | void write_wav_hdr(const struct wav_info *w, FILE *fp) { 137 | // We'll need the following: 138 | uint32_t data_size = (w->num_samples)*(w->num_channels)*(w->bits_per_sample)/8; 139 | uint32_t RCS = data_size+36; 140 | uint32_t byte_rate = (w->sample_rate)*(w->num_channels)*(w->bits_per_sample)/8; 141 | uint16_t block_align = (w->num_channels)*(w->bits_per_sample)/8; 142 | 143 | // Prepare a standard 44 byte WAVE header from the info in w 144 | uint8_t h[44]; 145 | // Bytes 1-4 are the ASCII codes for the four characters "RIFF" 146 | h[0]=0x52; 147 | h[1]=0x49; 148 | h[2]=0x46; 149 | h[3]=0x46; 150 | // Bytes 5-8 are RCS (i.e. 
data_size plus the remaining 36 bytes in the header) 151 | // in Little Endian format 152 | for(int i=0; i<4; i++) h[4+i] = (RCS >> (8*i)) & 0xFF; 153 | // Bytes 9-12 are the ASCII codes for the four characters "WAVE" 154 | h[8]=0x57; 155 | h[9]=0x41; 156 | h[10]=0x56; 157 | h[11]=0x45; 158 | // Bytes 13-16 are the ASCII codes for the four characters "fmt " 159 | h[12]=0x66; 160 | h[13]=0x6D; 161 | h[14]=0x74; 162 | h[15]=0x20; 163 | // Bytes 17-20 are the integer 16 (the size of the "fmt " subchunk 164 | // in the RIFF header we are writing) as a four-byte integer in 165 | // Little Endian format 166 | h[16]=0x10; 167 | h[17]=0x00; 168 | h[18]=0x00; 169 | h[19]=0x00; 170 | // Bytes 21-22 are the integer 1 (to indicate PCM format), 171 | // written as a two-byte Little Endian 172 | h[20]=0x01; 173 | h[21]=0x00; 174 | // Bytes 23-24 are num_channels as a two-byte Little Endian 175 | for(int j=0; j<2; j++) h[22+j] = (w->num_channels >> (8*j)) & 0xFF; 176 | // Bytes 25-26 are sample_rate as a four-byte L.E. 177 | for(int i=0; i<4; i++) h[24+i] = (w->sample_rate >> (8*i)) & 0xFF; 178 | // Bytes 27-30 are byte_rate as a four-byte L.E. 179 | for(int i=0; i<4; i++) h[28+i] = (byte_rate >> (8*i)) & 0xFF; 180 | // Bytes 31-34 are block_align as a two-byte L.E. 181 | for(int j=0; j<2; j++) h[32+j] = (block_align >> (8*j)) & 0xFF; 182 | // Bytes 35-36 are bits_per_sample as a two-byte L.E. 183 | for(int j=0; j<2; j++) h[34+j] = (w->bits_per_sample >> (8*j)) & 0xFF; 184 | // Bytes 37-40 are the ASCII codes for the four characters "data" 185 | h[36]=0x64; 186 | h[37]=0x61; 187 | h[38]=0x74; 188 | h[39]=0x61; 189 | // Bytes 41-44 are data_size as a four-byte L.E. 190 | for(int i=0; i<4; i++) h[40+i] = (data_size >> (8*i)) & 0xFF; 191 | 192 | // Write the header to the beginning of *fp 193 | if(fseek(fp,0,SEEK_SET)) { 194 | fprintf(stderr,"Error with fseek in write_wav_header in wav.c\n"); 195 | exit(EXIT_FAILURE); 196 | } 197 | fwrite(h,1,44,fp); 198 | } 199 | 200 | void print_wav_info(const struct wav_info *w) { 201 | printf("Number of channels: %i\n",(int) w->num_channels); 202 | printf("Bits per sample: %i\n",(int) w->bits_per_sample); 203 | printf("Sample rate: %i Hz\n",(int) w->sample_rate); 204 | printf("Total samples per channel: %i\n", (int) w->num_samples); 205 | int duration = w->num_samples/w->sample_rate; 206 | int r = w->num_samples % w->sample_rate; 207 | if(r==0) printf("Duration: %i s\n", duration); 208 | else printf("Duration: %i + %i/%i s\n",duration,r,(int) w->sample_rate); 209 | } 210 | 211 | void write_file_wav(const struct wav_info* w, FILE* fp, const int16_t* sample) { 212 | write_wav_hdr(w, fp); 213 | int b = w->bits_per_sample/8; 214 | fwrite(sample, b, w->num_samples, fp); 215 | } 216 | 217 | void read_file_bin_data(const char *file, void *data, size_t byte_length) { 218 | std::ifstream in(file, std::ios::in | std::ios::binary); 219 | in.read((char *) data, byte_length); 220 | in.close(); 221 | } 222 | 223 | void write_file_bin_data(const char *file, void *data, size_t byte_length) { 224 | std::ofstream out(file, std::ios::out | std::ios::binary); 225 | out.write((char *) data, byte_length); 226 | out.close(); 227 | } 228 | 229 | void free_source(wav_data* wdata) { 230 | free(wdata->data); 231 | } 232 | 233 | -------------------------------------------------------------------------------- /inference/real_time_inference/C++/src/wav.h: -------------------------------------------------------------------------------- 1 | /* Provides basic handling of PCM format RIFF/WAVE audio 
files 2 | See 3 | http://www-mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/Docs/riffmci.pdf 4 | for some information on the relevant specification. 5 | This is not as elaborate as the WAVE file support in the "audiofile" package 6 | but is very simple, self-contained, and probably adequate for most purposes. 7 | */ 8 | 9 | #include 10 | #include 11 | 12 | // wav data 13 | struct wav_data{ 14 | public: 15 | int16_t* data; 16 | int size; 17 | 18 | wav_data(){ 19 | data=NULL; 20 | size=0; 21 | } 22 | }; 23 | // wav header info 24 | struct wav_info { 25 | uint_fast16_t num_channels; /* 1 for mono, 2 for stereo, etc. */ 26 | uint_fast16_t bits_per_sample; /* 16 for CD, 24 for high-res, etc. */ 27 | uint_fast32_t sample_rate; /* 44100 for CD, 88200, 96000, 192000, etc. */ 28 | uint_fast32_t num_samples; /* total number of samples per channel */ 29 | }; 30 | 31 | void read_wav_info(struct wav_info* w, wav_data* wdata, FILE* fp); 32 | /* Read wav_info from *fp, assuming *fp is a PCM format RIFF/WAVE file. 33 | Leaves the seek position of *fp at the beginning of the data section 34 | of *fp, so one could immediately begin reading/writing samples */ 35 | 36 | void write_wav_hdr(const struct wav_info* w, FILE* fp); 37 | /* Write a standard 44-byte PCM format RIFF/WAVE header to the beginning of *fp. 38 | Again, the seek position of *fp will be left at the beginning of 39 | the data section, so one can immediately begin writing samples */ 40 | 41 | void print_wav_info(const struct wav_info* w); 42 | /* Prints information from *w to stdout */ 43 | 44 | void write_file_wav(const struct wav_info* w, FILE* fp, const int16_t* sample); 45 | 46 | void free_source(wav_data* data); 47 | 48 | // read pcm data as .bin or .pcm 49 | void read_file_bin_data(const char *file, void *data, size_t byte_length); 50 | 51 | // write pcm data as .bin or .pcm 52 | void write_file_bin_data(const char *file, void *data, size_t byte_length); 53 | -------------------------------------------------------------------------------- /inference/real_time_inference/README.md: -------------------------------------------------------------------------------- 1 | # Real Time Inference 2 | *Created on Mon Apr 25 17:40:30 2022*
3 | *@author: xiaohuai.le* 4 | 5 | A smaller DPCRN model is used for real time inference with about 0.53 M parameters and 1.1 G MACs. The LSTMs of the baseline model are replaced by GRUs. The frame size and hop size are set to 32 ms and 16 ms respectively. 6 | -------------------------------------------------------------------------------- /inference/real_time_inference/dpcrn_stateful_model.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Le-Xiaohuai-speech/DPCRN_DNS3/c7fe17d02fcc2502f198dd6c2d29bba2c4e1c0ed/inference/real_time_inference/dpcrn_stateful_model.tflite -------------------------------------------------------------------------------- /inference/real_time_inference/inference.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Nov 20 22:16:58 2021 4 | 5 | @author: xiaohuai.le 6 | """ 7 | import numpy as np 8 | import tflite_runtime.interpreter as tflite 9 | import matplotlib.pyplot as plt 10 | import time 11 | import copy 12 | import soundfile as sf 13 | import librosa 14 | 15 | def enhancement_stateful(noisy_f, model_stateful = './dpcrn_stateful_model.tflite', output_f = './enhance_s.wav', plot = True, gain =1): 16 | 17 | noisy_s = sf.read(noisy_f,dtype = 'float32')[0] * gain 18 | 19 | length = len(noisy_s) 20 | 21 | N_frame = (length - 512) // 256 + 1 22 | 23 | enh_s = np.zeros([512 + 256 * (N_frame - 1)],dtype = np.float32) 24 | 25 | inp = np.zeros([1,1,257,3], dtype = np.float32) 26 | inp_state_1 = np.zeros([1,32,128], dtype = np.float32) 27 | inp_state_2 = np.zeros([1,32,128], dtype = np.float32) 28 | t = [] 29 | 30 | win = np.sin(np.arange(.5,512-.5+1)/512*np.pi) 31 | 32 | interpreter = tflite.Interpreter(model_path = model_stateful) 33 | interpreter.allocate_tensors() 34 | 35 | input_details = interpreter.get_input_details() 36 | output_details = interpreter.get_output_details() 37 | 38 | [print(i['name'],i['shape']) for i in input_details] 39 | [print(i['name'],i['shape']) for i in output_details] 40 | for i in range(N_frame): 41 | begin = time.perf_counter() 42 | 43 | noisy = noisy_s[i*256 : i*256 + 512] * win 44 | spec = np.fft.rfft(noisy).astype('complex64') 45 | spec1 = copy.copy(spec) 46 | 47 | inp[0,0,:,0] = spec1.real 48 | inp[0,0,:,1] = spec1.imag 49 | inp[0,0,:,2] = 2 * np.log(abs(spec)) 50 | 51 | interpreter.set_tensor(input_details[0]['index'], inp) 52 | interpreter.set_tensor(input_details[1]['index'], inp_state_1) 53 | interpreter.set_tensor(input_details[2]['index'], inp_state_2) 54 | 55 | interpreter.invoke() 56 | 57 | mag_mask = interpreter.get_tensor(output_details[0]['index'])[0,0] 58 | sin = interpreter.get_tensor(output_details[1]['index'])[0,0] 59 | cos = interpreter.get_tensor(output_details[2]['index'])[0,0] 60 | inp_state_1 = interpreter.get_tensor(output_details[3]['index']) 61 | inp_state_2 = interpreter.get_tensor(output_details[4]['index']) 62 | 63 | spec = spec * mag_mask * (cos + 1j*sin) 64 | 65 | enhanced = np.fft.irfft(spec) * win 66 | 67 | end = time.perf_counter() 68 | enh_s[i*256 : i*256 + 512] += enhanced 69 | t.append(end-begin) 70 | 71 | print('Total {} frames, inference time per frame:{}s'.format(N_frame,np.mean(t))) 72 | if plot: 73 | 74 | spec_n = librosa.stft(noisy_s,400,200,center = False) 75 | spec_e = librosa.stft(enh_s, 400,200,center = False) 76 | plt.figure(0) 77 | plt.plot(noisy_s) 78 | plt.plot(enh_s) 79 | plt.figure(1) 80 | plt.subplot(211) 81 | 
plt.imshow(np.log(abs(spec_n)),cmap= 'jet',origin ='lower') 82 | plt.subplot(212) 83 | plt.imshow(np.log(abs(spec_e)),cmap= 'jet',origin ='lower') 84 | 85 | sf.write(output_f,enh_s,16000) 86 | return noisy_s,enh_s 87 | 88 | if __name__ == '__main__': 89 | 90 | 91 | n,e = enhancement_stateful('D:/codes/test_audio/librispeech/white0/61-70968-0030.wav',model_stateful = './dpcrn_stateful_model.tflite') 92 | -------------------------------------------------------------------------------- /inference/real_time_inference/recording.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Oct 8 16:05:31 2021 4 | 5 | @author: xiaohuai.le 6 | """ 7 | 8 | import pyaudio 9 | import tkinter as tk 10 | import wave 11 | import threading 12 | import queue 13 | import matplotlib.pyplot as plt 14 | import matplotlib.animation as animation 15 | import matplotlib.lines as line 16 | import numpy as np 17 | from soundfile import write 18 | import tflite_runtime.interpreter as tflite 19 | 20 | #%% 21 | interpreter = tflite.Interpreter(model_path = './dpcrn_stateful_model.tflite') 22 | interpreter.allocate_tensors() 23 | 24 | input_details = interpreter.get_input_details() 25 | output_details = interpreter.get_output_details() 26 | 27 | [print('input_{}'.format(index),i['shape']) for index, i in enumerate(output_details)] 28 | [print('output_{}'.format(index),i['shape']) for index,i in enumerate(output_details)] 29 | #%% 30 | CHUNK = 256 31 | N_FFT = 512 32 | hop = 256 33 | FORMAT = pyaudio.paInt16 34 | CHANNELS = 1 35 | RATE = 16000 36 | 37 | data =[] 38 | frames=[] 39 | counter=1 40 | N = 200 41 | window = np.sin(np.arange(.5,N_FFT-.5+1)/N_FFT*np.pi) 42 | gain = 1 43 | MAX = 32767/gain 44 | frame = np.zeros(N_FFT) #256*4 45 | 46 | noisy_s = [] 47 | enh_s = [] 48 | 49 | #GUI 50 | class Application(tk.Frame): 51 | def __init__(self,master=None): 52 | tk.Frame.__init__(self,master) 53 | self.grid() 54 | self.creatWidgets() 55 | 56 | def creatWidgets(self): 57 | self.quitButton=tk.Button(self,text='quit',command=root.destroy) 58 | self.quitButton.grid(column=1,row=3) 59 | 60 | 61 | #make noisy axes and enhance axes 62 | fig = plt.figure() 63 | noisy_ax = plt.subplot(325,xlim=(0,CHUNK*N), ylim=(-MAX,MAX)) 64 | enhance_ax = plt.subplot(326,xlim=(0,CHUNK*N), ylim=(-MAX,MAX)) 65 | noisy_ax.set_title("noisy signal") 66 | enhance_ax.set_title("enhanced signal") 67 | noisy_line = line.Line2D([],[]) 68 | enhance_line = line.Line2D([],[]) 69 | #plot data update after reading buffer 70 | noisy_data = np.zeros(CHUNK*N,dtype=np.int16) 71 | enhance_data = np.zeros(CHUNK*N,dtype=np.int16) 72 | noisy_x_data = np.arange(0,CHUNK*N,1) 73 | enhance_x_data = np.arange(0,CHUNK*N,1) 74 | 75 | n_stft_ax = plt.subplot(311) 76 | n_stft_ax.set_title("noisy spectrogram") 77 | n_stft_ax.set_ylim(0,N_FFT//2 + 1) 78 | n_stft_ax.set_xlim(0,1000) 79 | n_image_stft = n_stft_ax.imshow(np.random.randn(N_FFT//2 + 1,1000),cmap ='jet') 80 | n_stft_data=np.zeros([257,1000],dtype=np.float32) 81 | 82 | stft_ax = plt.subplot(312) 83 | stft_ax.set_title("enhanced spectrogram") 84 | stft_ax.set_ylim(0,N_FFT//2 + 1) 85 | stft_ax.set_xlim(0,1000) 86 | image_stft = stft_ax.imshow(np.random.randn(N_FFT//2 + 1,1000),cmap ='jet') 87 | stft_data=np.zeros([N_FFT//2 + 1,1000],dtype=np.float32) 88 | 89 | def plot_init(): 90 | noisy_ax.add_line(noisy_line) 91 | enhance_ax.add_line(enhance_line) 92 | return enhance_line,noisy_line,image_stft,n_image_stft 93 | 94 | def plot_update(i): 95 | 
global noisy_data 96 | global enhance_data 97 | global enhance_x_data 98 | global stft_data 99 | global n_stft_data 100 | noisy_line.set_xdata(noisy_x_data) 101 | noisy_line.set_ydata(noisy_data) 102 | 103 | enhance_line.set_xdata(enhance_x_data) 104 | enhance_line.set_ydata(enhance_data) 105 | 106 | image_stft.set_data(stft_data) 107 | n_image_stft.set_data(n_stft_data) 108 | return enhance_line,noisy_line,image_stft,n_image_stft 109 | 110 | 111 | def audio_callback(in_data, frame_count, time_info, status): 112 | global ad_rdy_ev 113 | 114 | q.put(in_data) 115 | ad_rdy_ev.set() 116 | if counter <= 0: 117 | return (None,pyaudio.paComplete) 118 | else: 119 | return (None,pyaudio.paContinue) 120 | 121 | #processing block 122 | 123 | def read_audio_thead(q,stream,frames,ad_rdy_ev): 124 | global frame 125 | inp = np.zeros([1,1,257,3], dtype = np.float32) 126 | inp_state_1 = np.zeros([1,32,128], dtype = np.float32) 127 | inp_state_2 = np.zeros([1,32,128], dtype = np.float32) 128 | while stream.is_active(): 129 | ad_rdy_ev.wait(timeout=1000) 130 | if not q.empty(): 131 | #process audio data here 132 | data=q.get() 133 | while not q.empty(): 134 | q.get() 135 | # CHUNK * N_chunk 136 | noisy_data_0 = np.frombuffer(data,np.dtype('