├── .gitignore ├── .idea ├── .gitignore ├── AEC_DeepModel.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── data_preparation ├── __pycache__ │ └── data_preparation.cpython-36.pyc └── data_preparation.py ├── model ├── Baseline.py ├── TCN_model.py ├── __pycache__ │ ├── Baseline.cpython-36.pyc │ └── ops.cpython-36.pyc └── ops.py ├── test ├── echo_signal │ ├── echo_fileid_9992.wav │ ├── echo_fileid_9993.wav │ └── echo_fileid_9994.wav ├── farend_speech │ ├── farend_speech_fileid_9992.wav │ ├── farend_speech_fileid_9993.wav │ └── farend_speech_fileid_9994.wav ├── model_test.py ├── nearend_mic_signal │ ├── nearend_mic_fileid_9992.wav │ ├── nearend_mic_fileid_9993.wav │ └── nearend_mic_fileid_9994.wav ├── nearend_speech │ ├── nearend_speech_fileid_9992.wav │ ├── nearend_speech_fileid_9993.wav │ └── nearend_speech_fileid_9994.wav └── predict │ └── 深度学习生成的nearend_speech_fileid_9992.wav └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/AEC_baseline/10.pth 2 | Q_and_A/* 3 | data_preparation/tensor_complex.py 4 | 5 | data_preparation/Synthetic/TEST/echo_signal/*.wav 6 | data_preparation/Synthetic/TEST/farend_speech/*.wav 7 | data_preparation/Synthetic/TEST/nearend_mic_signal/*.wav 8 | data_preparation/Synthetic/TEST/nearend_speech/*.wav 9 | 10 | data_preparation/Synthetic/TRAIN/echo_signal/*.wav 11 | data_preparation/Synthetic/TRAIN/farend_speech/*.wav 12 | data_preparation/Synthetic/TRAIN/nearend_mic_signal/*.wav 13 | data_preparation/Synthetic/TRAIN/nearend_speech/*.wav 14 | 15 | data_preparation/Synthetic/VAL/echo_signal/*.wav 16 | data_preparation/Synthetic/VAL/farend_speech/*.wav 17 | data_preparation/Synthetic/VAL/nearend_mic_signal/*.wav 18 | data_preparation/Synthetic/VAL/nearend_speech/*.wav 19 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /../../../../../:\声学回声消除\Code\AEC_DeepModel\.idea/dataSources/ 6 | /dataSources.local.xml 7 | # 基于编辑器的 HTTP 客户端请求 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/AEC_DeepModel.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | 15 | 17 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 292 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AEC_DeepModel 2 | 基于深度学习的声学回声消除基线代码 3 | 4 | 具体解析见我的博客:https://www.cnblogs.com/LXP-Never/p/14779360.html 5 | 6 | 7 | # 数据准备 8 | 9 | 按照以下文件结构,放好语音,我直接使用的是AEC-Challenge 数据集中的合成数据集 10 | 11 | ```angular2html 12 | └─Synthetic 13 | ├─TEST 14 | │ ├─echo_signal 15 | │ ├─farend_speech 16 | │ ├─nearend_mic_signal 17 | │ └─nearend_speech 18 | ├─TRAIN 19 | │ ├─echo_signal 20 | │ ├─farend_speech 21 | │ ├─nearend_mic_signal 22 | │ └─nearend_speech 23 | └─VAL 24 | ├─echo_signal 25 | ├─farend_speech 26 | ├─nearend_mic_signal 27 | └─nearend_speech 28 | ``` 29 | 30 | 数据处理脚本为 `data_preparation.py` 31 | 32 | 如果想要自己生成回声的话建议使用 [RIR-Generator](https://github.com/ehabets/RIR-Generator) 方法,毕竟很多论文中使用的也是这个方法 33 | 34 | 35 | 36 | # 运行 37 | 38 | ``` 39 | python train.py 40 | ``` 41 | 具体的命令行解析参数见`train.py`脚本 42 | 43 | # 估计近端语音 44 | 45 | ``` 46 | python test/model_test.py 47 | ``` 48 | 49 | 50 | 点赞,关注,不迷路 51 | 52 | 我以后还会开源更有价值的内容 53 | 54 | # 参考 55 | 56 | - [语音数据增强及python实现](https://www.cnblogs.com/LXP-Never/p/13404523.html) 57 | 58 | -------------------------------------------------------------------------------- /data_preparation/__pycache__/data_preparation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/data_preparation/__pycache__/data_preparation.cpython-36.pyc -------------------------------------------------------------------------------- /data_preparation/data_preparation.py: -------------------------------------------------------------------------------- 1 | # Author:凌逆战 2 | # -*- coding:utf-8 -*- 3 | """ 4 | 作用: 5 | """ 6 | import glob 7 | import os 8 | import torch.nn.functional as F 9 | import torch 10 | import torchaudio 11 | from torch.utils.data import Dataset 12 | from torch.utils.data import DataLoader 13 | 14 | 15 | class FileDateset(Dataset): 16 | def __init__(self, dataset_path="./Synthetic/TRAIN", fs=16000, win_length=320, mode="train"): 17 | self.fs = fs 18 | self.win_length = win_length 19 | self.mode = mode 20 | 21 | farend_speech_path = os.path.join(dataset_path, "farend_speech") # "./Synthetic/TRAIN/farend_speech" 22 | nearend_mic_signal_path = os.path.join(dataset_path, "nearend_mic_signal") # "./Synthetic/TRAIN/nearend_mic_signal" 23 | nearend_speech_path = os.path.join(dataset_path, "nearend_speech") # "./Synthetic/TRAIN/nearend_speech" 24 | 25 | self.farend_speech_list = sorted(glob.glob(farend_speech_path+"/*.wav")) # 远端语音路径,list 26 | self.nearend_mic_signal_list = sorted(glob.glob(nearend_mic_signal_path+"/*.wav")) # 近端麦克风语音路径,list 27 | self.nearend_speech_list = sorted(glob.glob(nearend_speech_path+"/*.wav")) # 近端语音路径,list 28 | 29 | def spectrogram(self, wav_path): 30 | """ 31 | :param wav_path: 音频路径 32 | :return: 返回该音频的振幅和相位 33 | """ 34 | wav, _ = torchaudio.load(wav_path) 35 | wav = wav.squeeze() 36 | 37 | if len(wav) < 160000: 38 | wav = F.pad(wav, (0,160000-len(wav)), mode="constant",value=0) 39 | 40 | S = torch.stft(wav, n_fft=self.win_length, hop_length=self.win_length//2, 41 | win_length=self.win_length, window=torch.hann_window(window_length=self.win_length), 42 | center=False, return_complex=True) # (*, F,T) 43 | magnitude = torch.abs(S) # 振幅 44 | phase = torch.exp(1j * torch.angle(S)) # 相位 45 | return magnitude, phase 46 | 47 | 48 | def __getitem__(self, item): 49 | """__getitem__是类的专有方法,使类可以像list一样按照索引来获取元素 50 | :param item: 索引 51 | :return: 按 索引取出来的 元素 52 | """ 53 | # 远端语音 振幅,相位 (F, T),F为频点数,T为帧数 54 | farend_speech_magnitude, farend_speech_phase = self.spectrogram(self.farend_speech_list[item]) # torch.Size([161, 999]) 55 | # 近端麦克风 振幅,相位 56 | nearend_mic_magnitude, nearend_mic_phase = self.spectrogram(self.nearend_mic_signal_list[item]) 57 | # 近端语音 振幅,相位 58 | nearend_speech_magnitude, nearend_speech_phase = self.spectrogram(self.nearend_speech_list[item]) 59 | 60 | X = torch.cat((farend_speech_magnitude, nearend_mic_magnitude), dim=0) # 在频点维度上进行拼接(161*2, 999),模型输入 61 | 62 | _eps = torch.finfo(torch.float).eps # 防止分母出现0 63 | mask_IRM = torch.sqrt(nearend_speech_magnitude ** 2/(nearend_mic_magnitude ** 2+_eps)) # IRM,模型输出 64 | 65 | 66 | return X, mask_IRM, nearend_mic_magnitude, nearend_speech_magnitude 67 | 68 | def __len__(self): 69 | """__len__是类的专有方法,获取整个数据的长度""" 70 | return len(self.farend_speech_list) 71 | 72 | 73 | if __name__ == "__main__": 74 | train_set = FileDateset() 75 | train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True) 76 | 77 | for x, y, nearend_mic_magnitude,nearend_speech_magnitude in train_loader: 78 | print(x.shape) # torch.Size([64, 322, 999]) 79 | print(y.shape) # torch.Size([64, 161, 999]) 80 | print(nearend_mic_magnitude.shape) 81 | -------------------------------------------------------------------------------- /model/Baseline.py: -------------------------------------------------------------------------------- 1 | # Author:凌逆战 2 | # -*- coding:utf-8 -*- 3 | """ 4 | 作用:随便搭建的模型,只要符合输入大小是[64, 322, 999],输出大小是[64, 161, 999],就能跑通 5 | """ 6 | import torch.nn as nn 7 | import torch 8 | 9 | 10 | class Base_model(nn.Module): 11 | def __init__(self): 12 | super(Base_model, self).__init__() 13 | # [batch, channel, input_size] (B, F, T) 14 | # [64, 322, 999] ---> [64, 161, 999] 15 | self.model = nn.Sequential( 16 | nn.Conv1d(in_channels=322, out_channels=322, kernel_size=3, stride=1, padding=1), 17 | nn.LeakyReLU(0.2), 18 | nn.Conv1d(in_channels=322, out_channels=322, kernel_size=3, stride=1, padding=1), 19 | nn.LeakyReLU(0.2), 20 | nn.Conv1d(in_channels=322, out_channels=161, kernel_size=3, stride=1, padding=1), 21 | nn.LeakyReLU(0.2), 22 | nn.Conv1d(in_channels=161, out_channels=161, kernel_size=3, stride=1, padding=1), 23 | nn.Sigmoid() 24 | ) 25 | 26 | def forward(self, x): 27 | """ 28 | :param x: 麦克风信号和远端信号的特征串联在一起作为输入特征 (322, 206) 29 | :return: IRM_mask * input = 近端语音对数谱 30 | """ 31 | Estimated_IRM = self.model(x) 32 | 33 | return Estimated_IRM 34 | 35 | 36 | if __name__ == "__main__": 37 | model = Base_model().cuda() 38 | x = torch.randn(8, 322, 999).to("cuda") # 输入 [8, 322, 999] 39 | y = model(x) # 输出 [8, 161, 999] 40 | print(y.shape) 41 | -------------------------------------------------------------------------------- /model/TCN_model.py: -------------------------------------------------------------------------------- 1 | # Author:凌逆战 2 | # -*- coding:utf-8 -*- 3 | """ 4 | 作用: 5 | """ 6 | import torch.nn as nn 7 | import torch 8 | 9 | # Author:凌逆战 10 | # -*- coding:utf-8 -*- 11 | import torch 12 | from torch import nn 13 | from torch.nn import init 14 | from torch.nn.utils import weight_norm 15 | from torch.nn import functional as F 16 | 17 | 18 | def weights_init(m): 19 | classname = m.__class__.__name__ 20 | if classname.find("Conv") != -1: 21 | m.weight.data.normal_(0.0, 0.02) 22 | elif classname.find("BatchNorm2d") != -1: 23 | m.weight.data.normal_(1.0, 0.02) 24 | m.bias.data.fill_(0) 25 | 26 | 27 | class PixelShuffle1D(nn.Module): 28 | """ 29 | 1D pixel shuffler. https://arxiv.org/pdf/1609.05158.pdf 30 | Upscales sample length, downscales channel length 31 | "short" is input, "long" is output 32 | """ 33 | 34 | def __init__(self, upscale_factor): 35 | super(PixelShuffle1D, self).__init__() 36 | self.upscale_factor = upscale_factor 37 | 38 | def forward(self, x): 39 | batch_size, short_channel_len, short_width = x.size() 40 | 41 | long_channel_len = short_channel_len // self.upscale_factor 42 | long_width = self.upscale_factor * short_width 43 | 44 | x = x.contiguous().view([batch_size, self.upscale_factor, long_channel_len, short_width]) 45 | x = x.permute(0, 2, 3, 1).contiguous() 46 | x = x.view(batch_size, long_channel_len, long_width) 47 | 48 | return x 49 | 50 | 51 | class Chomp1d(nn.Module): 52 | def __init__(self, chomp_size): 53 | super(Chomp1d, self).__init__() 54 | self.chomp_size = chomp_size 55 | 56 | def forward(self, x): 57 | return x[:, :, :-self.chomp_size].contiguous() 58 | 59 | 60 | class TCN_block(nn.Module): 61 | """这个padding和我的有些不同""" 62 | 63 | def __init__(self, in_channel, out_channel, kernel_size, dilation): 64 | super(TCN_block, self).__init__() 65 | padding = (kernel_size - 1) * dilation 66 | self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size, padding=padding, dilation=dilation) 67 | self.chomp1 = Chomp1d(padding) 68 | self.bn1 = nn.BatchNorm1d(num_features=out_channel) # BN有bias的作用 69 | self.LeakyReLU1 = nn.LeakyReLU(negative_slope=0.2) 70 | # self.dropout1 = nn.Dropout(dropout) 71 | # --------------------------------------------------------------- 72 | self.conv2 = nn.Conv1d(out_channel, out_channel * 2, kernel_size, padding=kernel_size // 2) 73 | self.bn2 = nn.BatchNorm1d(num_features=out_channel * 2) 74 | # --------------------------------------------------------------- 75 | if in_channel == 2 * out_channel: 76 | self.downsample = None 77 | else: 78 | self.downsample = nn.Conv1d(in_channel, out_channel * 2, kernel_size=1) 79 | self.LeakyReLU = nn.LeakyReLU(negative_slope=0.2) 80 | self.init_weights() 81 | 82 | def init_weights(self): 83 | init.orthogonal_(self.conv1.weight) 84 | init.zeros_(self.conv1.bias) 85 | init.orthogonal_(self.conv2.weight) 86 | init.zeros_(self.conv2.bias) 87 | # BN层 88 | init.normal_(self.bn1.weight, mean=1.0, std=0.02) 89 | init.constant_(self.bn1.bias, 0) 90 | init.normal_(self.bn2.weight, mean=1.0, std=0.02) 91 | init.constant_(self.bn2.bias, 0) 92 | if self.downsample is not None: 93 | init.orthogonal_(self.downsample.weight) 94 | 95 | def forward(self, input): 96 | x = self.conv1(input) 97 | x = self.chomp1(x) 98 | x = self.bn1(x) 99 | x = self.LeakyReLU1(x) 100 | # -------------------------- 101 | x = self.conv2(x) 102 | out = self.bn2(x) 103 | 104 | if self.downsample is None: 105 | res = input 106 | else: 107 | res = self.downsample(input) 108 | return self.LeakyReLU(out + res) 109 | 110 | 111 | class TCN_block_k3(nn.Module): 112 | # 如果用这个kernel必须设置为3 113 | def __init__(self, in_channel, out_channel, kernel_size, dilation): 114 | super(TCN_block_k3, self).__init__() 115 | self.padding = nn.ReflectionPad1d(dilation) 116 | self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size=3, stride=1, dilation=dilation) 117 | self.bn1 = nn.BatchNorm1d(num_features=out_channel) # BN有bias的作用 118 | self.leakyrelu_1 = nn.LeakyReLU(negative_slope=0.2) 119 | # self.dropout1 = nn.Dropout(dropout) 120 | # --------------------------------------------------------------- 121 | self.conv2 = nn.Conv1d(out_channel, out_channel * 2, kernel_size=1, stride=1) 122 | self.bn2 = nn.BatchNorm1d(num_features=out_channel * 2) 123 | # --------------------------------------------------------------- 124 | if in_channel == 2 * out_channel: 125 | self.downsample = None 126 | else: 127 | self.downsample = nn.Conv1d(in_channel, out_channel * 2, kernel_size=1) 128 | self.LeakyReLU = nn.LeakyReLU(negative_slope=0.2) 129 | self.init_weights() 130 | 131 | def init_weights(self): 132 | init.orthogonal_(self.conv1.weight) # 第一层卷积权重初始化 133 | init.orthogonal_(self.conv2.weight) # 第二层卷积权重初始化 134 | # BN层 135 | init.normal_(self.bn1.weight, mean=1.0, std=0.02) 136 | init.constant_(self.bn1.bias, 0) 137 | init.normal_(self.bn2.weight, mean=1.0, std=0.02) 138 | init.constant_(self.bn2.bias, 0) 139 | if self.downsample is not None: 140 | init.orthogonal_(self.downsample.weight) 141 | 142 | def forward(self, input): 143 | x = self.padding(input) 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.leakyrelu_1(x) 147 | # -------------------------- 148 | x = self.conv2(x) 149 | out = self.bn2(x) 150 | 151 | if self.downsample is None: 152 | # print("1") 153 | res = input 154 | else: 155 | # print("2") 156 | res = self.downsample(input) 157 | return self.LeakyReLU(out + res) 158 | 159 | 160 | class TCN_model(nn.Module): 161 | def __init__(self): 162 | super(TCN_model, self).__init__() 163 | # (64, 322, 998) 164 | self.first = nn.Conv1d(in_channels=322, out_channels=161, kernel_size=9, stride=2, padding=9 // 2) # (64, 161, 499) 165 | self.TCN_conv0 = TCN_block(in_channel=161, out_channel=161, kernel_size=9, dilation=2 ** 0) 166 | self.TCN_conv1 = TCN_block(in_channel=322, out_channel=161, kernel_size=9, dilation=2 ** 1) 167 | self.TCN_conv2 = TCN_block(in_channel=322, out_channel=161, kernel_size=9, dilation=2 ** 2) 168 | 169 | self.lastconv = nn.Conv1d(in_channels=322, out_channels=2, kernel_size=9, stride=1, padding=9 // 2) 170 | self.subpix = PixelShuffle1D(upscale_factor=2) 171 | self.init_weights() 172 | 173 | def init_weights(self): 174 | init.orthogonal_(self.first.weight) 175 | init.constant_(self.first.bias, 0) 176 | init.normal_(self.lastconv.weight, mean=0, std=1e-3) 177 | init.constant_(self.lastconv.bias, 0) 178 | 179 | def forward(self, input): 180 | # inputs (64, 322, 999) 181 | x = self.first(input) # torch.Size([64, 161, 500]) 182 | print(x.shape) 183 | x = self.TCN_conv0(x) # torch.Size([64, 322, 500]) 184 | print(x.shape) 185 | x = self.TCN_conv1(x) 186 | print(x.shape) 187 | x = self.TCN_conv2(x) 188 | print(x.shape) 189 | x = self.lastconv(x) 190 | print(x.shape) 191 | x = self.subpix(x) 192 | print(x.shape) 193 | return x 194 | 195 | 196 | x = torch.randn(64, 322, 998) 197 | model = TCN_model() 198 | output = model(x) 199 | print(output.shape) 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /model/__pycache__/Baseline.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/model/__pycache__/Baseline.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/ops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/model/__pycache__/ops.cpython-36.pyc -------------------------------------------------------------------------------- /model/ops.py: -------------------------------------------------------------------------------- 1 | # Author:凌逆战 2 | # -*- coding:utf-8 -*- 3 | """ 4 | 作用: 5 | """ 6 | import torch 7 | 8 | 9 | 10 | def frequency_MSE_loss(logits, labels): 11 | """ 均方误差,频域损失 12 | labels:batch_labels 13 | logits:batch_logits 14 | """ 15 | loss = torch.mean((logits - labels) ** 2) 16 | return loss 17 | 18 | 19 | def frequency_RMSE_loss(logits, labels): 20 | """ 均方根误差,频域损失 21 | labels:batch_labels 22 | logits:batch_logits 23 | """ 24 | loss = torch.sqrt(torch.mean((logits - labels) ** 2, dim=[1, 2])) 25 | loss = torch.mean(loss, dim=0) 26 | return loss 27 | 28 | 29 | def frequency_MAE_loss(logits, labels): 30 | """ 平均绝对值误差,频域损失 31 | labels:batch_labels 32 | logits:batch_logits 33 | """ 34 | loss = torch.mean(torch.abs(logits - labels)) 35 | return loss 36 | 37 | # ###################### 计算LSD ###################### 38 | def pytorch_LSD(logits, labels): 39 | # (…, freq, time) 40 | 41 | logits_log = torch.log10(logits ** 2 + 3e-9) 42 | labels_log = torch.log10(labels ** 2 + 3e-9) 43 | original_target_squared = (labels_log - logits_log) ** 2 44 | 45 | # lsd = torch.mean(torch.sqrt(torch.mean(original_target_squared, dim=0))) 46 | lsd = torch.mean(torch.sqrt(torch.mean(original_target_squared, dim=1)), dim=1) 47 | lsd = torch.mean(lsd, dim=0) 48 | 49 | return lsd -------------------------------------------------------------------------------- /test/echo_signal/echo_fileid_9992.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/echo_signal/echo_fileid_9992.wav -------------------------------------------------------------------------------- /test/echo_signal/echo_fileid_9993.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/echo_signal/echo_fileid_9993.wav -------------------------------------------------------------------------------- /test/echo_signal/echo_fileid_9994.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/echo_signal/echo_fileid_9994.wav -------------------------------------------------------------------------------- /test/farend_speech/farend_speech_fileid_9992.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/farend_speech/farend_speech_fileid_9992.wav -------------------------------------------------------------------------------- /test/farend_speech/farend_speech_fileid_9993.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/farend_speech/farend_speech_fileid_9993.wav -------------------------------------------------------------------------------- /test/farend_speech/farend_speech_fileid_9994.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/farend_speech/farend_speech_fileid_9994.wav -------------------------------------------------------------------------------- /test/model_test.py: -------------------------------------------------------------------------------- 1 | # Author:凌逆战 2 | # -*- coding:utf-8 -*- 3 | """ 4 | 作用:通过模型生成近端语音 5 | """ 6 | import librosa 7 | import matplotlib 8 | import torchaudio 9 | import torch.nn.functional as F 10 | import torch 11 | import matplotlib.pyplot as plt 12 | from model.Baseline import Base_model 13 | from matplotlib.ticker import FuncFormatter 14 | import numpy as np 15 | 16 | plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 17 | plt.rcParams['axes.unicode_minus'] = False # 用来正常显示符号 18 | 19 | 20 | def spectrogram(wav_path, win_length=320): 21 | wav, _ = torchaudio.load(wav_path) 22 | wav = wav.squeeze() 23 | 24 | if len(wav) < 160000: 25 | wav = F.pad(wav, (0, 160000 - len(wav)), mode="constant", value=0) 26 | # if len(wav) != 160000: 27 | # print(wav_path) 28 | # print(len(wav)) 29 | 30 | S = torch.stft(wav, n_fft=win_length, hop_length=win_length // 2, 31 | win_length=win_length, window=torch.hann_window(window_length=win_length), 32 | center=False, return_complex=True) 33 | magnitude = torch.abs(S) 34 | phase = torch.exp(1j * torch.angle(S)) 35 | return magnitude, phase 36 | 37 | 38 | fs = 16000 39 | farend_speech = "./farend_speech/farend_speech_fileid_9992.wav" 40 | nearend_mic_signal = "./nearend_mic_signal/nearend_mic_fileid_9992.wav" 41 | nearend_speech = "./nearend_speech/nearend_speech_fileid_9992.wav" 42 | echo_signal = "./echo_signal/echo_fileid_9992.wav" 43 | 44 | print("GPU是否可用:", torch.cuda.is_available()) # True 45 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 46 | 47 | farend_speech_magnitude, farend_speech_phase = spectrogram(farend_speech) # 远端语音 振幅,相位 48 | nearend_mic_magnitude, nearend_mic_phase = spectrogram(nearend_mic_signal) # 近端麦克风语音 振幅,相位 49 | nearend_speech_magnitude, nearend_speech_phase = spectrogram(nearend_speech) # 近端语音振 幅,相位 50 | 51 | farend_speech_magnitude = farend_speech_magnitude.to(device) 52 | nearend_mic_phase = nearend_mic_phase.to(device) 53 | nearend_mic_magnitude = nearend_mic_magnitude.to(device) 54 | 55 | nearend_speech_magnitude = nearend_speech_magnitude.to(device) 56 | nearend_speech_phase = nearend_speech_phase.to(device) 57 | 58 | model = Base_model().to(device) # 实例化模型 59 | checkpoint = torch.load("../checkpoints/AEC_baseline/10.pth") 60 | model.load_state_dict(checkpoint["model"]) 61 | 62 | X = torch.cat((farend_speech_magnitude, nearend_mic_magnitude), dim=0) 63 | X = X.unsqueeze(0) 64 | per_mask = model(X) # [1, 322, 999]-->[1, 161, 999] 65 | 66 | per_nearend_magnitude = per_mask * nearend_mic_magnitude # 预测的近端语音 振幅 67 | 68 | complex_stft = per_nearend_magnitude * nearend_mic_phase # 振幅*相位=语音复数表示 69 | print("complex_stft", complex_stft.shape) # [1, 161, 999] 70 | 71 | per_nearend = torch.istft(complex_stft, n_fft=320, hop_length=160, win_length=320, 72 | window=torch.hann_window(window_length=320).to("cuda")) 73 | 74 | torchaudio.save("./predict/nearend_speech_fileid_9992.wav", src=per_nearend.cpu().detach(), sample_rate=fs) 75 | # print("近端语音", per_nearend.shape) # [1, 159680] 76 | 77 | y, _ = librosa.load(nearend_speech, sr=fs) 78 | time_y = np.arange(0, len(y)) * (1.0 / fs) 79 | recover_wav, _ = librosa.load("./predict/nearend_speech_fileid_9992.wav", sr=16000) 80 | time_recover = np.arange(0, len(recover_wav)) * (1.0 / fs) 81 | 82 | plt.figure(figsize=(8,6)) 83 | ax_1 = plt.subplot(3, 1, 1) 84 | plt.title("近端语音和预测近端波形图", fontsize=14) 85 | plt.plot(time_y, y, label="近端语音") 86 | plt.plot(time_recover, recover_wav, label="深度学习生成的近端语音波形") 87 | plt.xlabel('时间/s', fontsize=14) 88 | plt.ylabel('幅值', fontsize=14) 89 | plt.xticks(fontsize=14) 90 | plt.yticks(fontsize=14) 91 | plt.subplots_adjust(top=0.932, bottom=0.085, left=0.110, right=0.998) 92 | plt.subplots_adjust(hspace=0.809, wspace=0.365) # 调整子图间距 93 | plt.legend() 94 | 95 | norm = matplotlib.colors.Normalize(vmin=-200, vmax=-40) 96 | ax_2 = plt.subplot(3, 1, 2) 97 | plt.title("近端语音频谱", fontsize=14) 98 | plt.specgram(y, Fs=fs, scale_by_freq=True, sides='default', cmap="jet", norm=norm) 99 | plt.xlabel('时间/s', fontsize=14) 100 | plt.ylabel('频率/kHz', fontsize=14) 101 | plt.xticks(fontsize=14) 102 | plt.yticks(fontsize=14) 103 | plt.subplots_adjust(top=0.932, bottom=0.085, left=0.110, right=0.998) 104 | plt.subplots_adjust(hspace=0.809, wspace=0.365) # 调整子图间距 105 | 106 | ax_3 = plt.subplot(3, 1, 3) 107 | plt.title("深度学习生成的近端语音频谱", fontsize=14) 108 | plt.specgram(recover_wav, Fs=fs, scale_by_freq=True, sides='default', cmap="jet", norm=norm) 109 | plt.xlabel('时间/s', fontsize=14) 110 | plt.ylabel('频率/kHz', fontsize=14) 111 | plt.xticks(fontsize=14) 112 | plt.yticks(fontsize=14) 113 | plt.subplots_adjust(top=0.932, bottom=0.085, left=0.110, right=0.998) 114 | plt.subplots_adjust(hspace=0.809, wspace=0.365) # 调整子图间距 115 | 116 | def formatnum(x, pos): 117 | return '$%d$' % (x / 1000) 118 | 119 | 120 | formatter = FuncFormatter(formatnum) 121 | ax_2.yaxis.set_major_formatter(formatter) 122 | ax_3.yaxis.set_major_formatter(formatter) 123 | 124 | 125 | plt.show() 126 | -------------------------------------------------------------------------------- /test/nearend_mic_signal/nearend_mic_fileid_9992.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_mic_signal/nearend_mic_fileid_9992.wav -------------------------------------------------------------------------------- /test/nearend_mic_signal/nearend_mic_fileid_9993.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_mic_signal/nearend_mic_fileid_9993.wav -------------------------------------------------------------------------------- /test/nearend_mic_signal/nearend_mic_fileid_9994.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_mic_signal/nearend_mic_fileid_9994.wav -------------------------------------------------------------------------------- /test/nearend_speech/nearend_speech_fileid_9992.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_speech/nearend_speech_fileid_9992.wav -------------------------------------------------------------------------------- /test/nearend_speech/nearend_speech_fileid_9993.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_speech/nearend_speech_fileid_9993.wav -------------------------------------------------------------------------------- /test/nearend_speech/nearend_speech_fileid_9994.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_speech/nearend_speech_fileid_9994.wav -------------------------------------------------------------------------------- /test/predict/深度学习生成的nearend_speech_fileid_9992.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/predict/深度学习生成的nearend_speech_fileid_9992.wav -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import DataLoader 4 | from torch import nn 5 | import argparse 6 | from tensorboardX import SummaryWriter 7 | 8 | from data_preparation.data_preparation import FileDateset 9 | from model.Baseline import Base_model 10 | from model.ops import pytorch_LSD 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | # 重头开始训练 defaule=None, 继续训练defaule设置为'/**.pth' 16 | parser.add_argument("--model_name", type=str, default=None, help="是否加载模型继续训练 '/50.pth' None") 17 | parser.add_argument("--batch-size", type=int, default=16, help="") 18 | parser.add_argument("--epochs", type=int, default=20) 19 | parser.add_argument('--lr', type=float, default=3e-4, help='学习率 (default: 0.01)') 20 | parser.add_argument('--train_data', default="./data_preparation/Synthetic/TRAIN", help='数据集的path') 21 | parser.add_argument('--val_data', default="./data_preparation/Synthetic/VAL", help='验证样本的path') 22 | parser.add_argument('--checkpoints_dir', default="./checkpoints/AEC_baseline", help='模型检查点文件的路径(以继续培训)') 23 | parser.add_argument('--event_dir', default="./event_file/AEC_baseline", help='tensorboard事件文件的地址') 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | def main(): 29 | args = parse_args() 30 | print("GPU是否可用:", torch.cuda.is_available()) # True 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | 33 | # 实例化 Dataset 34 | train_set = FileDateset(dataset_path=args.train_data) # 实例化训练数据集 35 | val_set = FileDateset(dataset_path=args.val_data) # 实例化验证数据集 36 | 37 | # 数据加载器 38 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, drop_last=True) 39 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=True) 40 | 41 | # ########### 保存检查点的地址(如果检查点不存在,则创建) ############ 42 | if not os.path.exists(args.checkpoints_dir): 43 | os.makedirs(args.checkpoints_dir) 44 | 45 | ################################ 46 | # 实例化模型 # 47 | ################################ 48 | model = Base_model().to(device) # 实例化模型 49 | # summary(model, input_size=(322, 999)) # 模型输出 torch.Size([64, 322, 999]) 50 | # ########### 损失函数 ############ 51 | criterion = nn.MSELoss(reduce=True, size_average=True, reduction='mean') 52 | 53 | ############################### 54 | # 创建优化器 Create optimizers # 55 | ############################### 56 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, ) 57 | # lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,20], gamma=0.1) 58 | 59 | # ########### TensorBoard可视化 summary ############ 60 | writer = SummaryWriter(args.event_dir) # 创建事件文件 61 | 62 | # ########### 加载模型检查点 ############ 63 | start_epoch = 0 64 | if args.model_name: 65 | print("加载模型:", args.checkpoints_dir + args.model_name) 66 | checkpoint = torch.load(args.checkpoints_dir + args.model_name) 67 | model.load_state_dict(checkpoint["model"]) 68 | optimizer.load_state_dict(checkpoint["optimizer"]) 69 | start_epoch = checkpoint['epoch'] 70 | # lr_schedule.load_state_dict(checkpoint['lr_schedule']) # 加载lr_scheduler 71 | 72 | for epoch in range(start_epoch, args.epochs): 73 | model.train() # 训练模型 74 | for batch_idx, (train_X, train_mask, train_nearend_mic_magnitude, train_nearend_magnitude) in enumerate( 75 | train_loader): 76 | train_X = train_X.to(device) # 远端语音cat麦克风语音 [batch_size, 322, 999] (, F, T) 77 | train_mask = train_mask.to(device) # IRM [batch_size 161, 999] 78 | train_nearend_mic_magnitude = train_nearend_mic_magnitude.to(device) 79 | train_nearend_magnitude = train_nearend_magnitude.to(device) 80 | 81 | # 前向传播 82 | pred_mask = model(train_X) # [batch_size, 322, 999]--> [batch_size, 161, 999] 83 | train_loss = criterion(pred_mask, train_mask) 84 | 85 | # 近端语音信号频谱 = mask * 麦克风信号频谱 [batch_size, 161, 999] 86 | pred_near_spectrum = pred_mask * train_nearend_mic_magnitude 87 | train_lsd = pytorch_LSD(train_nearend_magnitude, pred_near_spectrum) 88 | 89 | # 反向传播 90 | optimizer.zero_grad() # 将梯度清零 91 | train_loss.backward() # 反向传播 92 | optimizer.step() # 更新参数 93 | 94 | # ########### 可视化打印 ############ 95 | print('Train Epoch: {} Loss: {:.6f} LSD: {:.6f}'.format(epoch + 1, train_loss.item(), train_lsd.item())) 96 | 97 | # ########### TensorBoard可视化 summary ############ 98 | # lr_schedule.step() # 学习率衰减 99 | # writer.add_scalar(tag="lr", scalar_value=model.state_dict()['param_groups'][0]['lr'], global_step=epoch + 1) 100 | writer.add_scalar(tag="train_loss", scalar_value=train_loss.item(), global_step=epoch + 1) 101 | writer.add_scalar(tag="train_lsd", scalar_value=train_lsd.item(), global_step=epoch + 1) 102 | writer.flush() 103 | 104 | # 神经网络在验证数据集上的表现 105 | model.eval() # 测试模型 106 | # 测试的时候不需要梯度 107 | with torch.no_grad(): 108 | for val_batch_idx, (val_X, val_mask, val_nearend_mic_magnitude, val_nearend_magnitude) in enumerate( 109 | val_loader): 110 | val_X = val_X.to(device) # 远端语音cat麦克风语音 [batch_size, 322, 999] (, F, T) 111 | val_mask = val_mask.to(device) # IRM [batch_size 161, 999] 112 | val_nearend_mic_magnitude = val_nearend_mic_magnitude.to(device) 113 | val_nearend_magnitude = val_nearend_magnitude.to(device) 114 | 115 | # 前向传播 116 | val_pred_mask = model(val_X) 117 | val_loss = criterion(val_pred_mask, val_mask) 118 | 119 | # 近端语音信号频谱 = mask * 麦克风信号频谱 [batch_size, 161, 999] 120 | val_pred_near_spectrum = val_pred_mask * val_nearend_mic_magnitude 121 | val_lsd = pytorch_LSD(val_nearend_magnitude, val_pred_near_spectrum) 122 | 123 | # ########### 可视化打印 ############ 124 | print(' val Epoch: {} \tLoss: {:.6f}\tlsd: {:.6f}'.format(epoch + 1, val_loss.item(), val_lsd.item())) 125 | ###################### 126 | # 更新tensorboard # 127 | ###################### 128 | writer.add_scalar(tag="val_loss", scalar_value=val_loss.item(), global_step=epoch + 1) 129 | writer.add_scalar(tag="val_lsd", scalar_value=val_lsd.item(), global_step=epoch + 1) 130 | writer.flush() 131 | 132 | # # ########### 保存模型 ############ 133 | if (epoch + 1) % 10 == 0: 134 | checkpoint = { 135 | "model": model.state_dict(), 136 | "optimizer": optimizer.state_dict(), 137 | "epoch": epoch + 1, 138 | # 'lr_schedule': lr_schedule.state_dict() 139 | } 140 | torch.save(checkpoint, '%s/%d.pth' % (args.checkpoints_dir, epoch + 1)) 141 | 142 | 143 | if __name__ == "__main__": 144 | main() 145 | --------------------------------------------------------------------------------