├── .gitignore
├── .idea
├── .gitignore
├── AEC_DeepModel.iml
├── inspectionProfiles
│ ├── Project_Default.xml
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
└── vcs.xml
├── README.md
├── data_preparation
├── __pycache__
│ └── data_preparation.cpython-36.pyc
└── data_preparation.py
├── model
├── Baseline.py
├── TCN_model.py
├── __pycache__
│ ├── Baseline.cpython-36.pyc
│ └── ops.cpython-36.pyc
└── ops.py
├── test
├── echo_signal
│ ├── echo_fileid_9992.wav
│ ├── echo_fileid_9993.wav
│ └── echo_fileid_9994.wav
├── farend_speech
│ ├── farend_speech_fileid_9992.wav
│ ├── farend_speech_fileid_9993.wav
│ └── farend_speech_fileid_9994.wav
├── model_test.py
├── nearend_mic_signal
│ ├── nearend_mic_fileid_9992.wav
│ ├── nearend_mic_fileid_9993.wav
│ └── nearend_mic_fileid_9994.wav
├── nearend_speech
│ ├── nearend_speech_fileid_9992.wav
│ ├── nearend_speech_fileid_9993.wav
│ └── nearend_speech_fileid_9994.wav
└── predict
│ └── 深度学习生成的nearend_speech_fileid_9992.wav
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints/AEC_baseline/10.pth
2 | Q_and_A/*
3 | data_preparation/tensor_complex.py
4 |
5 | data_preparation/Synthetic/TEST/echo_signal/*.wav
6 | data_preparation/Synthetic/TEST/farend_speech/*.wav
7 | data_preparation/Synthetic/TEST/nearend_mic_signal/*.wav
8 | data_preparation/Synthetic/TEST/nearend_speech/*.wav
9 |
10 | data_preparation/Synthetic/TRAIN/echo_signal/*.wav
11 | data_preparation/Synthetic/TRAIN/farend_speech/*.wav
12 | data_preparation/Synthetic/TRAIN/nearend_mic_signal/*.wav
13 | data_preparation/Synthetic/TRAIN/nearend_speech/*.wav
14 |
15 | data_preparation/Synthetic/VAL/echo_signal/*.wav
16 | data_preparation/Synthetic/VAL/farend_speech/*.wav
17 | data_preparation/Synthetic/VAL/nearend_mic_signal/*.wav
18 | data_preparation/Synthetic/VAL/nearend_speech/*.wav
19 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../../../:\声学回声消除\Code\AEC_DeepModel\.idea/dataSources/
6 | /dataSources.local.xml
7 | # 基于编辑器的 HTTP 客户端请求
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/AEC_DeepModel.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
290 |
291 |
292 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AEC_DeepModel
2 | 基于深度学习的声学回声消除基线代码
3 |
4 | 具体解析见我的博客:https://www.cnblogs.com/LXP-Never/p/14779360.html
5 |
6 |
7 | # 数据准备
8 |
9 | 按照以下文件结构,放好语音,我直接使用的是AEC-Challenge 数据集中的合成数据集
10 |
11 | ```angular2html
12 | └─Synthetic
13 | ├─TEST
14 | │ ├─echo_signal
15 | │ ├─farend_speech
16 | │ ├─nearend_mic_signal
17 | │ └─nearend_speech
18 | ├─TRAIN
19 | │ ├─echo_signal
20 | │ ├─farend_speech
21 | │ ├─nearend_mic_signal
22 | │ └─nearend_speech
23 | └─VAL
24 | ├─echo_signal
25 | ├─farend_speech
26 | ├─nearend_mic_signal
27 | └─nearend_speech
28 | ```
29 |
30 | 数据处理脚本为 `data_preparation.py`
31 |
32 | 如果想要自己生成回声的话建议使用 [RIR-Generator](https://github.com/ehabets/RIR-Generator) 方法,毕竟很多论文中使用的也是这个方法
33 |
34 |
35 |
36 | # 运行
37 |
38 | ```
39 | python train.py
40 | ```
41 | 具体的命令行解析参数见`train.py`脚本
42 |
43 | # 估计近端语音
44 |
45 | ```
46 | python test/model_test.py
47 | ```
48 |
49 |
50 | 点赞,关注,不迷路
51 |
52 | 我以后还会开源更有价值的内容
53 |
54 | # 参考
55 |
56 | - [语音数据增强及python实现](https://www.cnblogs.com/LXP-Never/p/13404523.html)
57 |
58 |
--------------------------------------------------------------------------------
/data_preparation/__pycache__/data_preparation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/data_preparation/__pycache__/data_preparation.cpython-36.pyc
--------------------------------------------------------------------------------
/data_preparation/data_preparation.py:
--------------------------------------------------------------------------------
1 | # Author:凌逆战
2 | # -*- coding:utf-8 -*-
3 | """
4 | 作用:
5 | """
6 | import glob
7 | import os
8 | import torch.nn.functional as F
9 | import torch
10 | import torchaudio
11 | from torch.utils.data import Dataset
12 | from torch.utils.data import DataLoader
13 |
14 |
class FileDateset(Dataset):
    """Paired far-end / microphone / near-end clips for echo-cancellation training.

    Item ``i`` is built from the ``i``-th file of the sorted ``*.wav`` listings of
    the ``farend_speech``, ``nearend_mic_signal`` and ``nearend_speech``
    sub-directories of ``dataset_path``.

    :param dataset_path: directory that contains the three sub-directories
    :param fs: sampling rate; stored as ``self.fs`` but not used for resampling here
    :param win_length: STFT window length, also used as ``n_fft``; the hop is half of it
    :param mode: stored as ``self.mode``; not read anywhere in this class
    :param num_samples: every clip is zero-padded or cropped to this many samples so
        that all spectrograms share one shape and can be stacked into a batch
    """

    def __init__(self, dataset_path="./Synthetic/TRAIN", fs=16000, win_length=320, mode="train",
                 num_samples=160000):
        self.fs = fs
        self.win_length = win_length
        self.mode = mode
        self.num_samples = num_samples
        # Build the analysis window once instead of on every STFT call.
        self.window = torch.hann_window(window_length=win_length)

        self.farend_speech_list = self._list_wavs(dataset_path, "farend_speech")  # far-end speech paths
        self.nearend_mic_signal_list = self._list_wavs(dataset_path, "nearend_mic_signal")  # microphone paths
        self.nearend_speech_list = self._list_wavs(dataset_path, "nearend_speech")  # near-end speech paths

    @staticmethod
    def _list_wavs(dataset_path, sub_dir):
        """Return the sorted list of ``*.wav`` paths inside ``dataset_path/sub_dir``."""
        return sorted(glob.glob(os.path.join(dataset_path, sub_dir) + "/*.wav"))

    def _fit_length(self, wav):
        """Zero-pad or crop ``wav`` along its last axis to ``self.num_samples`` samples."""
        length = wav.shape[-1]
        if length < self.num_samples:
            return F.pad(wav, (0, self.num_samples - length), mode="constant", value=0)
        # Longer clips used to pass through unchanged, which produced spectrograms
        # with a different number of frames and broke batching.
        return wav[..., :self.num_samples]

    def _stft(self, wav):
        """Return the complex STFT of ``wav`` with layout (*, F, T)."""
        return torch.stft(wav, n_fft=self.win_length, hop_length=self.win_length // 2,
                          win_length=self.win_length, window=self.window,
                          center=False, return_complex=True)

    def _load_stft(self, wav_path):
        """Load ``wav_path``, fix its length and return its complex STFT."""
        wav, _ = torchaudio.load(wav_path)
        return self._stft(self._fit_length(wav.squeeze()))

    def spectrogram(self, wav_path):
        """Compute the magnitude and unit-modulus phase of one audio file.

        :param wav_path: path of the audio file
        :return: ``(magnitude, phase)`` where ``magnitude = |S|`` and
            ``phase = exp(1j * angle(S))`` for the complex STFT ``S``
        """
        S = self._load_stft(wav_path)
        magnitude = torch.abs(S)
        phase = torch.exp(1j * torch.angle(S))
        return magnitude, phase

    def __getitem__(self, item):
        """Return the training tuple for index ``item``.

        :param item: index into the sorted file lists
        :return: ``(X, mask_IRM, nearend_mic_magnitude, nearend_speech_magnitude)``.
            ``X`` is the far-end and microphone magnitudes concatenated along the
            frequency axis (model input); ``mask_IRM`` is the training target.
            With the default arguments each magnitude is (161, 999) and ``X`` is (322, 999).
        """
        # Only magnitudes are needed here, so the phase tensors are not computed.
        farend_speech_magnitude = torch.abs(self._load_stft(self.farend_speech_list[item]))
        nearend_mic_magnitude = torch.abs(self._load_stft(self.nearend_mic_signal_list[item]))
        nearend_speech_magnitude = torch.abs(self._load_stft(self.nearend_speech_list[item]))

        X = torch.cat((farend_speech_magnitude, nearend_mic_magnitude), dim=0)

        _eps = torch.finfo(torch.float).eps  # keeps the denominator away from zero
        # sqrt(|near|^2 / (|mic|^2 + eps)); the ratio is not clipped, so it may exceed 1.
        mask_IRM = torch.sqrt(nearend_speech_magnitude ** 2 / (nearend_mic_magnitude ** 2 + _eps))

        return X, mask_IRM, nearend_mic_magnitude, nearend_speech_magnitude

    def __len__(self):
        """Return the number of far-end files, which defines the dataset size."""
        return len(self.farend_speech_list)
71 |
72 |
if __name__ == "__main__":
    # Smoke test: iterate the default training set once and print the batch shapes.
    loader = DataLoader(FileDateset(), batch_size=64, shuffle=True, drop_last=True)

    for features, irm_mask, mic_magnitude, _ in loader:
        print(features.shape)  # expected torch.Size([64, 322, 999])
        print(irm_mask.shape)  # expected torch.Size([64, 161, 999])
        print(mic_magnitude.shape)
81 |
--------------------------------------------------------------------------------
/model/Baseline.py:
--------------------------------------------------------------------------------
1 | # Author:凌逆战
2 | # -*- coding:utf-8 -*-
3 | """
4 | 作用:随便搭建的模型,只要符合输入大小是[64, 322, 999],输出大小是[64, 161, 999],就能跑通
5 | """
6 | import torch.nn as nn
7 | import torch
8 |
9 |
class Base_model(nn.Module):
    """Four-layer 1-D convolutional mask estimator.

    Input layout is (batch, 322, T) and output layout is (batch, 161, T): every
    convolution uses kernel 3, stride 1 and padding 1, so the time length is kept.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) of the four stacked convolutions.
        channel_plan = [(322, 322), (322, 322), (322, 161), (161, 161)]
        last_index = len(channel_plan) - 1

        layers = []
        for index, (n_in, n_out) in enumerate(channel_plan):
            layers.append(nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=3, stride=1, padding=1))
            # The final layer is squashed into (0, 1); the others use LeakyReLU.
            layers.append(nn.Sigmoid() if index == last_index else nn.LeakyReLU(0.2))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Estimate a mask from the concatenated input features.

        :param x: tensor with 322 channels on dim 1, i.e. (batch, 322, T)
        :return: tensor (batch, 161, T) with values in (0, 1) from the final sigmoid
        """
        return self.model(x)
34 |
35 |
if __name__ == "__main__":
    # Shape smoke test. Fall back to the CPU when no GPU is present instead of
    # failing on a hard-coded .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Base_model().to(device)
    x = torch.randn(8, 322, 999).to(device)  # input [8, 322, 999]
    y = model(x)  # output [8, 161, 999]
    print(y.shape)
41 |
--------------------------------------------------------------------------------
/model/TCN_model.py:
--------------------------------------------------------------------------------
1 | # Author:凌逆战
2 | # -*- coding:utf-8 -*-
3 | """
4 | 作用:
5 | """
6 | import torch.nn as nn
7 | import torch
8 |
9 | # Author:凌逆战
10 | # -*- coding:utf-8 -*-
11 | import torch
12 | from torch import nn
13 | from torch.nn import init
14 | from torch.nn.utils import weight_norm
15 | from torch.nn import functional as F
16 |
17 |
def weights_init(m):
    """Re-initialise ``m`` in place according to its class name.

    Classes whose name contains ``"Conv"`` get weights drawn from N(0, 0.02).
    Classes whose name contains ``"BatchNorm2d"`` get weights drawn from
    N(1, 0.02) and a zero bias. Any other module is left untouched.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:
        m.weight.data.normal_(0.0, 0.02)
        return
    if "BatchNorm2d" in layer_type:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
25 |
26 |
class PixelShuffle1D(nn.Module):
    """
    1D sub-pixel shuffle (https://arxiv.org/pdf/1609.05158.pdf).

    Trades channels for length: (batch, C, W) becomes
    (batch, C // upscale_factor, W * upscale_factor).
    """

    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor

    def forward(self, x):
        n_batch, n_channels, width = x.size()
        factor = self.upscale_factor
        out_channels = n_channels // factor

        # Split the channel axis into (factor, out_channels) ...
        grouped = x.reshape(n_batch, factor, out_channels, width)
        # ... move the factor axis next to the width axis ...
        interleaved = grouped.permute(0, 2, 3, 1)
        # ... and merge (width, factor) into one longer time axis.
        return interleaved.reshape(n_batch, out_channels, factor * width)
49 |
50 |
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` samples along the final axis of a (B, C, T) tensor."""

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return ``x`` without its trailing ``chomp_size`` time steps, as a contiguous tensor."""
        if self.chomp_size == 0:
            # ``x[:, :, :-0]`` is ``x[:, :, :0]`` and would return an empty tensor.
            return x.contiguous()
        return x[:, :, :-self.chomp_size].contiguous()
58 |
59 |
class TCN_block(nn.Module):
    """Residual block: dilated conv -> chomp -> BN -> LeakyReLU -> conv -> BN, plus a skip path.

    The output has ``2 * out_channel`` channels. When ``in_channel`` already equals
    that width the input is added directly; otherwise it goes through a 1x1 convolution.
    (Original author's note: the padding here differs a little from their own variant.)
    """

    def __init__(self, in_channel, out_channel, kernel_size, dilation):
        super().__init__()
        widened = out_channel * 2
        # conv1 pads both ends by this amount; chomp1 then removes the same amount
        # from the end, so the sequence length is unchanged.
        trailing_pad = dilation * (kernel_size - 1)

        self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size, padding=trailing_pad, dilation=dilation)
        self.chomp1 = Chomp1d(trailing_pad)
        self.bn1 = nn.BatchNorm1d(num_features=out_channel)
        self.LeakyReLU1 = nn.LeakyReLU(negative_slope=0.2)

        # Second stage: undilated convolution with symmetric padding, doubling the channels.
        self.conv2 = nn.Conv1d(out_channel, widened, kernel_size, padding=kernel_size // 2)
        self.bn2 = nn.BatchNorm1d(num_features=widened)

        # Skip path: identity when the widths already match, 1x1 projection otherwise.
        needs_projection = in_channel != widened
        self.downsample = nn.Conv1d(in_channel, widened, kernel_size=1) if needs_projection else None
        self.LeakyReLU = nn.LeakyReLU(negative_slope=0.2)
        self.init_weights()

    def init_weights(self):
        """Orthogonal conv weights with zero bias; BN weights ~ N(1, 0.02) with zero bias."""
        for conv in (self.conv1, self.conv2):
            init.orthogonal_(conv.weight)
            init.zeros_(conv.bias)
        for norm in (self.bn1, self.bn2):
            init.normal_(norm.weight, mean=1.0, std=0.02)
            init.constant_(norm.bias, 0)
        if self.downsample is not None:
            init.orthogonal_(self.downsample.weight)

    def forward(self, input):
        """Map (B, in_channel, T) to the activated sum of the main and skip paths."""
        hidden = self.LeakyReLU1(self.bn1(self.chomp1(self.conv1(input))))
        out = self.bn2(self.conv2(hidden))
        res = input if self.downsample is None else self.downsample(input)
        return self.LeakyReLU(out + res)
109 |
110 |
class TCN_block_k3(nn.Module):
    """Residual block built around a dilated kernel-3 convolution with reflection padding.

    The ``kernel_size`` argument is accepted but ignored: conv1 always uses a kernel
    of 3 (the original note says the kernel must be 3 when this block is used).
    The output has ``2 * out_channel`` channels.
    """

    def __init__(self, in_channel, out_channel, kernel_size, dilation):
        super().__init__()
        widened = out_channel * 2

        # Reflect-pad by ``dilation`` on both ends so the dilated kernel-3 conv keeps the length.
        self.padding = nn.ReflectionPad1d(dilation)
        self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size=3, stride=1, dilation=dilation)
        self.bn1 = nn.BatchNorm1d(num_features=out_channel)
        self.leakyrelu_1 = nn.LeakyReLU(negative_slope=0.2)

        # Pointwise convolution that doubles the channel count.
        self.conv2 = nn.Conv1d(out_channel, widened, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm1d(num_features=widened)

        # Skip path: identity when the widths already match, 1x1 projection otherwise.
        needs_projection = in_channel != widened
        self.downsample = nn.Conv1d(in_channel, widened, kernel_size=1) if needs_projection else None
        self.LeakyReLU = nn.LeakyReLU(negative_slope=0.2)
        self.init_weights()

    def init_weights(self):
        """Orthogonal conv weights; BN weights ~ N(1, 0.02) with zero bias."""
        for conv in (self.conv1, self.conv2):
            init.orthogonal_(conv.weight)
        for norm in (self.bn1, self.bn2):
            init.normal_(norm.weight, mean=1.0, std=0.02)
            init.constant_(norm.bias, 0)
        if self.downsample is not None:
            init.orthogonal_(self.downsample.weight)

    def forward(self, input):
        """Map (B, in_channel, T) to the activated sum of the main and skip paths."""
        hidden = self.leakyrelu_1(self.bn1(self.conv1(self.padding(input))))
        out = self.bn2(self.conv2(hidden))
        res = input if self.downsample is None else self.downsample(input)
        return self.LeakyReLU(out + res)
158 |
159 |
class TCN_model(nn.Module):
    """Strided conv -> three dilated TCN blocks -> conv -> 1-D pixel shuffle.

    Shapes for an input (B, 322, T), with ``T' = (T - 1) // 2 + 1``:
      first      -> (B, 161, T')   (kernel 9, stride 2, padding 4)
      TCN_conv0  -> (B, 322, T')
      TCN_conv1  -> (B, 322, T')
      TCN_conv2  -> (B, 322, T')
      lastconv   -> (B, 2, T')
      subpix     -> (B, 1, 2 * T')
    For T = 998 this gives T' = 499 and a (B, 1, 998) output.
    """

    def __init__(self):
        super(TCN_model, self).__init__()
        self.first = nn.Conv1d(in_channels=322, out_channels=161, kernel_size=9, stride=2, padding=9 // 2)
        # Each TCN_block emits 2 * out_channel = 322 channels, hence in_channel=322 after the first one.
        self.TCN_conv0 = TCN_block(in_channel=161, out_channel=161, kernel_size=9, dilation=2 ** 0)
        self.TCN_conv1 = TCN_block(in_channel=322, out_channel=161, kernel_size=9, dilation=2 ** 1)
        self.TCN_conv2 = TCN_block(in_channel=322, out_channel=161, kernel_size=9, dilation=2 ** 2)

        self.lastconv = nn.Conv1d(in_channels=322, out_channels=2, kernel_size=9, stride=1, padding=9 // 2)
        self.subpix = PixelShuffle1D(upscale_factor=2)
        self.init_weights()

    def init_weights(self):
        """Orthogonal init for the first conv, N(0, 1e-3) for the last; both biases zero."""
        init.orthogonal_(self.first.weight)
        init.constant_(self.first.bias, 0)
        init.normal_(self.lastconv.weight, mean=0, std=1e-3)
        init.constant_(self.lastconv.bias, 0)

    def forward(self, input):
        """Run the full stack on ``input`` of shape (B, 322, T); see the class docstring for shapes."""
        # The per-layer debug prints were removed: they wrote to stdout on every forward pass.
        x = self.first(input)
        x = self.TCN_conv0(x)
        x = self.TCN_conv1(x)
        x = self.TCN_conv2(x)
        x = self.lastconv(x)
        x = self.subpix(x)
        return x
194 |
195 |
if __name__ == "__main__":
    # Shape smoke test. Guarded so that importing this module no longer runs a
    # full forward pass (Baseline.py guards its demo the same way).
    x = torch.randn(64, 322, 998)
    model = TCN_model()
    output = model(x)
    print(output.shape)
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
--------------------------------------------------------------------------------
/model/__pycache__/Baseline.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/model/__pycache__/Baseline.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/ops.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/model/__pycache__/ops.cpython-36.pyc
--------------------------------------------------------------------------------
/model/ops.py:
--------------------------------------------------------------------------------
1 | # Author:凌逆战
2 | # -*- coding:utf-8 -*-
3 | """
4 | 作用:
5 | """
6 | import torch
7 |
8 |
9 |
def frequency_MSE_loss(logits, labels):
    """Mean squared error over every element (frequency-domain loss).

    labels: batch of targets
    logits: batch of predictions
    """
    squared_error = (logits - labels).pow(2)
    return squared_error.mean()
17 |
18 |
def frequency_RMSE_loss(logits, labels):
    """Root mean squared error per item, averaged over the batch (frequency-domain loss).

    The squared error is averaged over dims 1 and 2 of each item, square-rooted,
    and the resulting per-item values are averaged over dim 0.

    labels: batch of targets
    logits: batch of predictions
    """
    per_item_mse = (logits - labels).pow(2).mean(dim=[1, 2])
    return per_item_mse.sqrt().mean(dim=0)
27 |
28 |
def frequency_MAE_loss(logits, labels):
    """Mean absolute error over every element (frequency-domain loss).

    labels: batch of targets
    logits: batch of predictions
    """
    absolute_error = (logits - labels).abs()
    return absolute_error.mean()
36 |
37 | # ###################### 计算LSD ######################
def pytorch_LSD(logits, labels):
    """Log-spectral distance between two magnitude spectrograms.

    Both inputs use the layout (..., freq, time). For every frame the squared
    difference of ``log10(x ** 2 + 3e-9)`` is averaged over the frequency axis and
    square-rooted; the per-frame values are then averaged into one scalar.

    The reduction axes are addressed from the end, so any number of leading
    dimensions (including none) is accepted. For 3-D (batch, freq, time) input
    the result matches the former ``dim=1`` / ``dim=0`` reduction up to
    floating-point rounding.
    """
    eps = 3e-9  # keeps log10 finite for zero-magnitude bins
    logits_log = torch.log10(logits ** 2 + eps)
    labels_log = torch.log10(labels ** 2 + eps)
    squared_diff = (labels_log - logits_log) ** 2

    per_frame = torch.sqrt(torch.mean(squared_diff, dim=-2))  # (..., time)
    return torch.mean(per_frame)
--------------------------------------------------------------------------------
/test/echo_signal/echo_fileid_9992.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/echo_signal/echo_fileid_9992.wav
--------------------------------------------------------------------------------
/test/echo_signal/echo_fileid_9993.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/echo_signal/echo_fileid_9993.wav
--------------------------------------------------------------------------------
/test/echo_signal/echo_fileid_9994.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/echo_signal/echo_fileid_9994.wav
--------------------------------------------------------------------------------
/test/farend_speech/farend_speech_fileid_9992.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/farend_speech/farend_speech_fileid_9992.wav
--------------------------------------------------------------------------------
/test/farend_speech/farend_speech_fileid_9993.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/farend_speech/farend_speech_fileid_9993.wav
--------------------------------------------------------------------------------
/test/farend_speech/farend_speech_fileid_9994.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/farend_speech/farend_speech_fileid_9994.wav
--------------------------------------------------------------------------------
/test/model_test.py:
--------------------------------------------------------------------------------
1 | # Author:凌逆战
2 | # -*- coding:utf-8 -*-
3 | """
4 | 作用:通过模型生成近端语音
5 | """
6 | import librosa
7 | import matplotlib
8 | import torchaudio
9 | import torch.nn.functional as F
10 | import torch
11 | import matplotlib.pyplot as plt
12 | from model.Baseline import Base_model
13 | from matplotlib.ticker import FuncFormatter
14 | import numpy as np
15 |
16 | plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
17 | plt.rcParams['axes.unicode_minus'] = False # 用来正常显示符号
18 |
19 |
def spectrogram(wav_path, win_length=320):
    """Load an audio file and return the magnitude and unit-modulus phase of its STFT.

    Signals shorter than 160000 samples are zero-padded at the end to that length;
    longer signals are used as they are. The STFT uses a Hann window of
    ``win_length`` samples, ``n_fft = win_length``, a hop of half the window and
    ``center=False``.

    :param wav_path: path of the audio file
    :param win_length: STFT window length
    :return: ``(magnitude, phase)`` with ``phase = exp(1j * angle(S))``
    """
    waveform, _ = torchaudio.load(wav_path)
    waveform = waveform.squeeze()

    missing = 160000 - len(waveform)
    if missing > 0:
        waveform = F.pad(waveform, (0, missing), mode="constant", value=0)

    stft = torch.stft(waveform, n_fft=win_length, hop_length=win_length // 2, win_length=win_length,
                      window=torch.hann_window(window_length=win_length),
                      center=False, return_complex=True)
    return torch.abs(stft), torch.exp(1j * torch.angle(stft))
36 |
37 |
# ---------------------------------------------------------------------------
# Inference demo: estimate the near-end speech of one test clip, save it as a
# WAV file and plot it against the reference recording.
# All paths are relative to the current working directory.
# ---------------------------------------------------------------------------
fs = 16000
farend_speech = "./farend_speech/farend_speech_fileid_9992.wav"
nearend_mic_signal = "./nearend_mic_signal/nearend_mic_fileid_9992.wav"
nearend_speech = "./nearend_speech/nearend_speech_fileid_9992.wav"
echo_signal = "./echo_signal/echo_fileid_9992.wav"  # not used below
predict_path = "./predict/nearend_speech_fileid_9992.wav"

print("GPU是否可用:", torch.cuda.is_available())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only the far-end magnitude and the microphone magnitude/phase are needed for
# inference; the reference near-end clip is loaded later as a waveform for plotting.
farend_speech_magnitude, _ = spectrogram(farend_speech)
nearend_mic_magnitude, nearend_mic_phase = spectrogram(nearend_mic_signal)

farend_speech_magnitude = farend_speech_magnitude.to(device)
nearend_mic_phase = nearend_mic_phase.to(device)
nearend_mic_magnitude = nearend_mic_magnitude.to(device)

model = Base_model().to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
checkpoint = torch.load("../checkpoints/AEC_baseline/10.pth", map_location=device)
model.load_state_dict(checkpoint["model"])
model.eval()

X = torch.cat((farend_speech_magnitude, nearend_mic_magnitude), dim=0)
X = X.unsqueeze(0)
with torch.no_grad():  # inference only, no autograd graph needed
    per_mask = model(X)  # [1, 322, T] --> [1, 161, T]

per_nearend_magnitude = per_mask * nearend_mic_magnitude  # predicted near-end magnitude

complex_stft = per_nearend_magnitude * nearend_mic_phase  # magnitude * phase = complex spectrum
print("complex_stft", complex_stft.shape)

# The window must live on the same device as the spectrum; it used to be forced
# onto "cuda", which failed whenever `device` was the CPU.
# NOTE(review): spectrogram() uses center=False while istft keeps its default
# `center` setting, so the output may be shorter than / offset from the input — confirm.
per_nearend = torch.istft(complex_stft, n_fft=320, hop_length=160, win_length=320,
                          window=torch.hann_window(window_length=320).to(device))

torchaudio.save(predict_path, src=per_nearend.cpu().detach(), sample_rate=fs)

y, _ = librosa.load(nearend_speech, sr=fs)
time_y = np.arange(0, len(y)) * (1.0 / fs)
recover_wav, _ = librosa.load(predict_path, sr=fs)
time_recover = np.arange(0, len(recover_wav)) * (1.0 / fs)

plt.figure(figsize=(8, 6))
ax_1 = plt.subplot(3, 1, 1)
plt.title("近端语音和预测近端波形图", fontsize=14)
plt.plot(time_y, y, label="近端语音")
plt.plot(time_recover, recover_wav, label="深度学习生成的近端语音波形")
plt.xlabel('时间/s', fontsize=14)
plt.ylabel('幅值', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend()

norm = matplotlib.colors.Normalize(vmin=-200, vmax=-40)


def plot_spectrum(position, title, samples):
    """Draw the spectrogram of ``samples`` in row ``position`` of the 3x1 grid and return its axes."""
    axis = plt.subplot(3, 1, position)
    plt.title(title, fontsize=14)
    plt.specgram(samples, Fs=fs, scale_by_freq=True, sides='default', cmap="jet", norm=norm)
    plt.xlabel('时间/s', fontsize=14)
    plt.ylabel('频率/kHz', fontsize=14)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    return axis


ax_2 = plot_spectrum(2, "近端语音频谱", y)
ax_3 = plot_spectrum(3, "深度学习生成的近端语音频谱", recover_wav)

# Figure-wide margins and subplot spacing; one call replaces the identical
# calls that were repeated after every subplot.
plt.subplots_adjust(top=0.932, bottom=0.085, left=0.110, right=0.998, hspace=0.809, wspace=0.365)


def formatnum(x, pos):
    """Tick formatter: show a frequency given in Hz as whole kHz."""
    return '$%d$' % (x / 1000)


formatter = FuncFormatter(formatnum)
ax_2.yaxis.set_major_formatter(formatter)
ax_3.yaxis.set_major_formatter(formatter)


plt.show()
126 |
--------------------------------------------------------------------------------
/test/nearend_mic_signal/nearend_mic_fileid_9992.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_mic_signal/nearend_mic_fileid_9992.wav
--------------------------------------------------------------------------------
/test/nearend_mic_signal/nearend_mic_fileid_9993.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_mic_signal/nearend_mic_fileid_9993.wav
--------------------------------------------------------------------------------
/test/nearend_mic_signal/nearend_mic_fileid_9994.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_mic_signal/nearend_mic_fileid_9994.wav
--------------------------------------------------------------------------------
/test/nearend_speech/nearend_speech_fileid_9992.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_speech/nearend_speech_fileid_9992.wav
--------------------------------------------------------------------------------
/test/nearend_speech/nearend_speech_fileid_9993.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_speech/nearend_speech_fileid_9993.wav
--------------------------------------------------------------------------------
/test/nearend_speech/nearend_speech_fileid_9994.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/nearend_speech/nearend_speech_fileid_9994.wav
--------------------------------------------------------------------------------
/test/predict/深度学习生成的nearend_speech_fileid_9992.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LXP-Never/AEC_DeepModel/1236fd703487f9a09122fbee316f3af1b88c143a/test/predict/深度学习生成的nearend_speech_fileid_9992.wav
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import DataLoader
4 | from torch import nn
5 | import argparse
6 | from tensorboardX import SummaryWriter
7 |
8 | from data_preparation.data_preparation import FileDateset
9 | from model.Baseline import Base_model
10 | from model.ops import pytorch_LSD
11 |
12 |
def parse_args(argv=None):
    """Parse the training command-line options.

    :param argv: list of argument strings; ``None`` (the default) makes argparse
        read ``sys.argv[1:]`` as before
    :return: the populated ``argparse.Namespace``
    """
    parser = argparse.ArgumentParser()
    # Train from scratch with the default None; to resume, pass the checkpoint
    # file name, e.g. '/50.pth'.
    parser.add_argument("--model_name", type=str, default=None, help="是否加载模型继续训练 '/50.pth' None")
    parser.add_argument("--batch-size", type=int, default=16, help="")
    parser.add_argument("--epochs", type=int, default=20)
    # The help text used to claim a default of 0.01; it is now derived from the real default.
    parser.add_argument('--lr', type=float, default=3e-4, help='学习率 (default: %(default)s)')
    parser.add_argument('--train_data', default="./data_preparation/Synthetic/TRAIN", help='数据集的path')
    parser.add_argument('--val_data', default="./data_preparation/Synthetic/VAL", help='验证样本的path')
    parser.add_argument('--checkpoints_dir', default="./checkpoints/AEC_baseline", help='模型检查点文件的路径(以继续培训)')
    parser.add_argument('--event_dir', default="./event_file/AEC_baseline", help='tensorboard事件文件的地址')
    return parser.parse_args(argv)
26 |
27 |
def _mean(values):
    """Return the arithmetic mean of ``values``, or ``None`` when the list is empty."""
    return sum(values) / len(values) if values else None


def main():
    """Train ``Base_model`` to predict the mask target and validate after every epoch.

    Per batch the loss and LSD are printed. Per epoch the batch averages are
    written to TensorBoard under ``train_loss``, ``train_lsd``, ``val_loss`` and
    ``val_lsd``. Every 10th epoch a checkpoint holding the model state, the
    optimizer state and the epoch number is saved to ``<checkpoints_dir>/<epoch>.pth``.
    """
    args = parse_args()
    print("GPU是否可用:", torch.cuda.is_available())
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_set = FileDateset(dataset_path=args.train_data)
    val_set = FileDateset(dataset_path=args.val_data)

    # Training batches are reshuffled every epoch; the file lists are sorted, so
    # without shuffling every epoch would see identical batches in identical order.
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=True)

    # Create the checkpoint directory when it does not exist yet.
    os.makedirs(args.checkpoints_dir, exist_ok=True)

    model = Base_model().to(device)
    # Only `reduction` is passed; `reduce` and `size_average` are deprecated.
    criterion = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    writer = SummaryWriter(args.event_dir)  # TensorBoard event file

    try:
        # Optionally resume from a checkpoint.
        start_epoch = 0
        if args.model_name:
            # '/50.pth' is concatenated exactly as before; a bare '50.pth' is
            # joined to the directory instead of producing a broken path.
            if args.model_name.startswith(("/", "\\")):
                checkpoint_path = args.checkpoints_dir + args.model_name
            else:
                checkpoint_path = os.path.join(args.checkpoints_dir, args.model_name)
            print("加载模型:", checkpoint_path)
            # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint["model"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            start_epoch = checkpoint['epoch']

        for epoch in range(start_epoch, args.epochs):
            # ----------------------------- training -----------------------------
            model.train()
            train_losses, train_lsds = [], []
            for train_X, train_mask, train_nearend_mic_magnitude, train_nearend_magnitude in train_loader:
                train_X = train_X.to(device)  # far-end cat microphone magnitudes [B, 322, T]
                train_mask = train_mask.to(device)  # mask target [B, 161, T]
                train_nearend_mic_magnitude = train_nearend_mic_magnitude.to(device)
                train_nearend_magnitude = train_nearend_magnitude.to(device)

                pred_mask = model(train_X)  # [B, 322, T] --> [B, 161, T]
                train_loss = criterion(pred_mask, train_mask)

                # LSD is only a monitoring metric, so it is computed without an autograd graph.
                with torch.no_grad():
                    # estimated near-end magnitude = mask * microphone magnitude
                    pred_near_spectrum = pred_mask.detach() * train_nearend_mic_magnitude
                    train_lsd = pytorch_LSD(train_nearend_magnitude, pred_near_spectrum)

                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                train_losses.append(train_loss.item())
                train_lsds.append(train_lsd.item())
                print('Train Epoch: {} Loss: {:.6f} LSD: {:.6f}'.format(epoch + 1, train_losses[-1], train_lsds[-1]))

            # Log the epoch average rather than whatever the last batch produced; skip
            # logging when the loader yielded no batch (previously a NameError).
            if train_losses:
                writer.add_scalar(tag="train_loss", scalar_value=_mean(train_losses), global_step=epoch + 1)
                writer.add_scalar(tag="train_lsd", scalar_value=_mean(train_lsds), global_step=epoch + 1)
                writer.flush()

            # ---------------------------- validation ----------------------------
            model.eval()
            val_losses, val_lsds = [], []
            with torch.no_grad():
                for val_X, val_mask, val_nearend_mic_magnitude, val_nearend_magnitude in val_loader:
                    val_X = val_X.to(device)
                    val_mask = val_mask.to(device)
                    val_nearend_mic_magnitude = val_nearend_mic_magnitude.to(device)
                    val_nearend_magnitude = val_nearend_magnitude.to(device)

                    val_pred_mask = model(val_X)
                    val_loss = criterion(val_pred_mask, val_mask)

                    val_pred_near_spectrum = val_pred_mask * val_nearend_mic_magnitude
                    val_lsd = pytorch_LSD(val_nearend_magnitude, val_pred_near_spectrum)

                    val_losses.append(val_loss.item())
                    val_lsds.append(val_lsd.item())
                    print(' val Epoch: {} \tLoss: {:.6f}\tlsd: {:.6f}'.format(epoch + 1, val_losses[-1], val_lsds[-1]))

            if val_losses:
                writer.add_scalar(tag="val_loss", scalar_value=_mean(val_losses), global_step=epoch + 1)
                writer.add_scalar(tag="val_lsd", scalar_value=_mean(val_lsds), global_step=epoch + 1)
                writer.flush()

            # ---------------------------- checkpoint ----------------------------
            if (epoch + 1) % 10 == 0:
                checkpoint = {
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch + 1,
                }
                torch.save(checkpoint, '%s/%d.pth' % (args.checkpoints_dir, epoch + 1))
    finally:
        # Flush and release the event file even when training is interrupted.
        writer.close()


if __name__ == "__main__":
    main()
145 |
--------------------------------------------------------------------------------