# Repository layout (from the source dump):
# ├── DeFTAN_figure.png
# ├── DeFT_AN.py
# ├── Exp_result.png
# ├── Sub-block_figure.png
# ├── audio clips
# │   ├── clip 1
# │   │   ├── enhanced.wav
# │   │   ├── noisy.wav
# │   │   └── target.wav
# │   ├── clip 2
# │   │   ├── enhanced.wav
# │   │   ├── noisy.wav
# │   │   └── target.wav
# │   ├── clip 3
# │   │   ├── enhanced.wav
# │   │   ├── noisy.wav
# │   │   └── target.wav
# │   └── clip 4
# │       ├── enhanced.wav
# │       ├── noisy.wav
# │       └── target.wav
# ├── extra
# │   ├── DeFTAN_revised_version.py
# │   └── stft_loss.py
# ├── generate_rir
# │   ├── donghen_pyroom_rirs.py
# │   └── pyroom_rir.cfg
# ├── readme.md
# └── speech_enhancement.png
#
# /DeFTAN_figure.png:
# https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/DeFTAN_figure.png
#
# /DeFT_AN.py:
import math
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.autograd import Variable


class DenseBlock(nn.Module):
    """Densely connected stack of 3x3 2-D convolutions.

    Layer ``i`` receives the channel-wise concatenation of the block input and
    the outputs of all previous layers, so its convolution has
    ``in_channels * (i + 1)`` input channels. Only the output of the last
    layer is returned.

    Args:
        in_channels: Channel count C of the input and of every layer output.
        depth: Number of convolution layers (must be >= 1).
        num_freqs: Size of the last input axis normalized by ``LayerNorm``
            (the number of STFT bins, ``win // 2 + 1``). The default 257
            matches the original hard-coded value (512-point STFT).

    Raises:
        ValueError: if ``depth < 1`` (forward would have no output).
    """

    def __init__(self, in_channels, depth, num_freqs=257):
        super(DenseBlock, self).__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.depth = depth
        self.in_channels = in_channels
        self.block = nn.ModuleList([])
        for i in range(self.depth):
            self.block.append(nn.Sequential(
                nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=(3, 3), padding=(1, 1), dilation=(1, 1)),
                nn.LayerNorm(num_freqs),
                nn.PReLU(in_channels)
            ))

    def forward(self, input):
        # input: (B, C, L, F) -> output: (B, C, L, F)
        skip = input
        for i in range(self.depth):
            output = self.block[i](skip)
            # Newest output first, then everything seen so far.
            skip = torch.cat([output, skip], dim=1)

        return output


class Freq_Conformer(nn.Module):
    """Post-norm transformer layer (self-attention + GELU feed-forward).

    Operates on ``(batch, seq, dim_model)`` tensors (``batch_first=True``);
    in this model the sequence axis is frequency.
    """

    def __init__(self, dim_model, num_head, dropout):
        super(Freq_Conformer, self).__init__()
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_ffn = 4 * self.dim_model
        self.dropout = dropout

        self.self_attn = nn.MultiheadAttention(embed_dim=dim_model, num_heads=num_head, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(self.dropout)
        self.norm1 = nn.LayerNorm(self.dim_model, eps=1e-5)

        self.linear1 = nn.Linear(self.dim_model, self.dim_ffn)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(self.dim_ffn, self.dim_model)

        self.dropout3 = nn.Dropout(self.dropout)
        self.norm2 = nn.LayerNorm(self.dim_model, eps=1e-5)

    def forward(self, input):
        # Attention weights are never used, so don't ask MHA to compute them.
        att_out, _ = self.self_attn(input, input, input, need_weights=False)
        norm_out = self.norm1(input + self.dropout1(att_out))

        ffw_out = self.linear2(self.activation(self.linear1(norm_out)))
        output = self.norm2(norm_out + self.dropout3(ffw_out))

        return output


class Time_Conformer(nn.Module):
    """Post-norm transformer layer whose feed-forward contains dilated convs.

    Between the two feed-forward linears, ``num_layers`` depthwise Conv1d
    layers (dilation ``2 ** i``, 'same' padding) run along the sequence axis,
    which in this model is time. Input/output: ``(batch, seq, dim_model)``.
    """

    def __init__(self, dim_model, num_head, dropout, num_layers):
        super(Time_Conformer, self).__init__()
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_ffn = 4 * self.dim_model
        self.dropout = dropout
        self.num_layers = num_layers

        self.self_attn = nn.MultiheadAttention(embed_dim=dim_model, num_heads=num_head, dropout=dropout, batch_first=True)
        self.dropout1 = nn.Dropout(self.dropout)
        self.norm1 = nn.LayerNorm(self.dim_model, eps=1e-5)

        self.linear1 = nn.Linear(self.dim_model, self.dim_ffn)
        self.activation = nn.GELU()
        convs = []
        for i in range(self.num_layers):
            convs.append(nn.Sequential(
                nn.Conv1d(self.dim_ffn, self.dim_ffn, kernel_size=3, dilation=2 ** i, padding='same', groups=self.dim_ffn),
                nn.GroupNorm(1, self.dim_ffn),
                nn.PReLU(self.dim_ffn)
            ))
        self.seq_conv = nn.Sequential(*convs)
        self.linear2 = nn.Linear(self.dim_ffn, self.dim_model)

        self.dropout3 = nn.Dropout(self.dropout)
        self.norm2 = nn.LayerNorm(self.dim_model, eps=1e-5)

    def forward(self, input):
        att_out, _ = self.self_attn(input, input, input, need_weights=False)
        norm_out = self.norm1(input + self.dropout1(att_out))

        # Conv1d wants (batch, channels, seq): transpose around the conv stack.
        ffw_mid = self.activation(self.linear1(norm_out)).transpose(1, 2).contiguous()
        ffw_out = self.linear2(self.seq_conv(ffw_mid).transpose(1, 2).contiguous())
        output = self.norm2(norm_out + self.dropout3(ffw_out))

        return output


class proposed_block(nn.Module):
    """DeFT-AN sub-block: dense block -> frequency transformer -> time conformer.

    Args:
        ch_dim: Channel / embedding size C.
        num_head: Attention heads (must divide ``ch_dim``).
        dropout: Dropout probability.
        depth: Dense block depth.
        num_freqs: Number of frequency bins F (default 257 = 512-point STFT).
    """

    def __init__(self, ch_dim, num_head, dropout, depth, num_freqs=257):
        super(proposed_block, self).__init__()
        self.dense_block = DenseBlock(in_channels=ch_dim, depth=depth, num_freqs=num_freqs)
        self.freq_conformer = Freq_Conformer(dim_model=ch_dim, num_head=num_head, dropout=dropout)
        self.temp_conformer = Time_Conformer(dim_model=ch_dim, num_head=num_head, dropout=dropout, num_layers=3)

    def forward(self, input):
        # n_freq rather than F so torch.nn.functional is not shadowed.
        B, C, L, n_freq = input.size()

        output_c = self.dense_block(input)  # B, C, L, F

        input_f = output_c.permute(0, 2, 3, 1).contiguous()  # B, L, F, C
        input_f = input_f.view(B * L, n_freq, C)  # B*L, F, C
        output_f = self.freq_conformer(input_f)  # B*L, F, C
        output_f = output_f.view(B, L, n_freq, C)  # B, L, F, C

        input_t = output_f.permute(0, 2, 1, 3).contiguous()  # B, F, L, C
        input_t = input_t.view(B * n_freq, L, C)  # B*F, L, C
        output_t = self.temp_conformer(input_t)  # B*F, L, C
        output_t = output_t.view(B, n_freq, L, C)  # B, F, L, C

        output = output_t.permute(0, 3, 2, 1)  # B, C, L, F

        return output


class Network(nn.Module):
    """DeFT-AN mask-estimation network.

    Takes multichannel waveforms and returns one enhanced waveform. The
    estimated mask multiplies the real and imaginary STFT parts of microphone
    0 element-wise.

    Args:
        mic_num: Number of input microphones M (input conv expects 2*M channels).
        ch_dim: Internal channel size.
        win: STFT size; hop is ``win // 2`` and there are ``win // 2 + 1`` bins.
        num_head: Attention heads (must divide ``ch_dim``).
        dropout: Dropout probability.
        num_layer: Number of ``proposed_block`` sub-blocks.
        depth: Dense block depth inside each sub-block.
    """

    def __init__(self, mic_num=4, ch_dim=64, win=512, num_head=4, dropout=0.1, num_layer=4, depth=5):
        super(Network, self).__init__()

        self.mic_num = mic_num
        self.ch_dim = ch_dim

        self.win = win
        self.hop = self.win // 2
        # Number of one-sided STFT bins; replaces the former hard-coded 257.
        self.num_freqs = self.win // 2 + 1

        # Not used in forward; kept so state_dict keys of existing checkpoints match.
        self.prelu = nn.PReLU()
        self.dim_model = ch_dim
        self.num_head = num_head
        self.dim_ffn = self.dim_model * 2

        self.dropout = dropout
        self.n_conv_groups = self.dim_ffn

        self.num_layer = num_layer

        self.inp_conv = nn.Sequential(
            nn.Conv2d(in_channels=2 * self.mic_num, out_channels=self.ch_dim, kernel_size=(5, 5), padding=(2, 2)),
            nn.LayerNorm(self.num_freqs),
            nn.PReLU(self.ch_dim)
        )

        self.proposed_block = nn.ModuleList([])
        for ii in range(num_layer):
            self.proposed_block.append(
                proposed_block(ch_dim=ch_dim, num_head=num_head, dropout=dropout, depth=depth,
                               num_freqs=self.num_freqs)
            )

        self.out_conv = nn.Conv2d(in_channels=self.ch_dim, out_channels=2, kernel_size=(5, 5), padding=(2, 2))

    def pad_signal(self, input):
        """Zero-pad waveforms so the STFT/iSTFT round trip is length-exact.

        Args:
            input: (B, T) or (B, M, T) waveforms.

        Returns:
            (padded, rest): ``padded`` is (B, M, hop + T + rest + hop); ``rest``
            is the number of trailing zeros added before the ``hop`` guard.

        Raises:
            RuntimeError: if ``input`` is not 2- or 3-dimensional.
        """
        if input.dim() not in [2, 3]:
            raise RuntimeError("Input can only be 2 or 3 dimensional.")

        if input.dim() == 2:
            input = input.unsqueeze(1)
        batch_size, nchannel, nsample = input.size()
        # Makes hop + nsample + rest a multiple of win.
        rest = self.win - (self.hop + nsample % self.win) % self.win
        if rest > 0:
            pad = input.new_zeros(batch_size, nchannel, rest)
            input = torch.cat([input, pad], 2)

        pad_aux = input.new_zeros(batch_size, nchannel, self.hop)
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest

    def forward(self, input):
        """Enhance ``input`` ((B, T) or (B, M, T)); returns (B, 1, T)."""
        input, rest = self.pad_signal(input)
        B, M, T = input.size()  # batch B, mic M, time samples T

        # Explicit rectangular window: numerically identical to torch.stft's
        # window=None default, without the missing-window warning.
        window = torch.ones(self.win, dtype=input.dtype, device=input.device)
        # return_complex=False is deprecated; view_as_real gives the same (..., 2) layout.
        spec = torch.stft(input.view([-1, T]), n_fft=self.win, hop_length=self.hop,
                          window=window, return_complex=True)
        stft_input = torch.view_as_real(spec)  # B*M, F, L, 2
        _, n_freq, L, _ = stft_input.size()
        xi = stft_input.view([B, M, n_freq, L, 2])  # B, M, F, L, 2
        xi = xi.permute(0, 1, 4, 3, 2).contiguous()  # B, M, 2, L, F
        xi = xi.view([B, M * 2, L, n_freq])  # B, 2*M, L, F (channel = 2*m + real/imag)

        xo = self.inp_conv(xi)  # B, C, L, F
        for idx in range(self.num_layer):
            xo = self.proposed_block[idx](xo)  # B, C, L, F
        mask = self.out_conv(xo)  # B, 2, L, F
        ref_enc_out = xi[:, 0:2]  # real/imag of microphone 0
        masked_enc_out = ref_enc_out * mask
        yo = masked_enc_out.permute(0, 3, 2, 1).contiguous()  # B, 2, L, F -> B, F, L, 2

        istft_input = torch.complex(yo[:, :, :, 0], yo[:, :, :, 1])
        istft_output = torch.istft(istft_input, n_fft=self.win, hop_length=self.hop,
                                   window=window, return_complex=False)

        # Strip the leading hop guard and the trailing rest + hop padding.
        output = istft_output[:, self.hop:-(rest + self.hop)].unsqueeze(1)  # B, 1, T

        return output


# /Exp_result.png:
# https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/Exp_result.png
#
# /Sub-block_figure.png:
# https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/Sub-block_figure.png
#
# /audio clips/clip 1/enhanced.wav:
# https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 1/enhanced.wav
#
# /audio clips/clip 1/noisy.wav:
# https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 1/noisy.wav
-------------------------------------------------------------------------------- /audio clips/clip 1/target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 1/target.wav -------------------------------------------------------------------------------- /audio clips/clip 2/enhanced.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 2/enhanced.wav -------------------------------------------------------------------------------- /audio clips/clip 2/noisy.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 2/noisy.wav -------------------------------------------------------------------------------- /audio clips/clip 2/target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 2/target.wav -------------------------------------------------------------------------------- /audio clips/clip 3/enhanced.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 3/enhanced.wav -------------------------------------------------------------------------------- /audio clips/clip 3/noisy.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 3/noisy.wav 
# /audio clips/clip 3/target.wav:
# https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 3/target.wav
#
# /audio clips/clip 4/enhanced.wav:
# https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 4/enhanced.wav
#
# /audio clips/clip 4/noisy.wav:
# https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 4/noisy.wav
#
# /audio clips/clip 4/target.wav:
# https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/audio clips/clip 4/target.wav
#
# ======================================================================
# /extra/DeFTAN_revised_version.py:
# ======================================================================
import math
from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch import Tensor
from torch.autograd import Variable
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func


class DenseBlock(nn.Module):
    """Densely connected stack of 3x3 2-D convolutions.

    Layer ``i`` sees the concatenation of the block input and all previous
    layer outputs (``in_channels * (i + 1)`` channels); only the last layer's
    output is returned.

    Args:
        in_channels: Channel count C of input and outputs.
        depth: Number of layers (must be >= 1).
        num_freqs: Size of the last axis normalized by ``LayerNorm``
            (default 257, the original hard-coded value for a 512-point STFT).
    """

    def __init__(self, in_channels, depth, num_freqs=257):
        super(DenseBlock, self).__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.depth = depth
        self.in_channels = in_channels
        self.block = nn.ModuleList([])
        for i in range(self.depth):
            self.block.append(nn.Sequential(
                nn.Conv2d(self.in_channels * (i + 1), self.in_channels, kernel_size=(3, 3), padding=(1, 1), dilation=(1, 1)),
                nn.LayerNorm(num_freqs),
                nn.PReLU(in_channels)
            ))

    def forward(self, input):
        skip = input
        for i in range(self.depth):
            output = self.block[i](skip)
            skip = torch.cat([output, skip], dim=1)

        return output


class Attention(nn.Module):
    """Multi-head self-attention computed with FlashAttention.

    Each head has size ``dim`` (not ``dim // heads``), so the Q/K/V projections
    map ``dim -> dim * heads``; ``to_out`` maps back to ``dim`` (identity when
    ``heads == 1``).

    Args:
        dim: Feature size of the input's last axis.
        heads: Number of heads.
        dropout: Attention dropout (applied only in training mode) and output
            projection dropout.
    """

    def __init__(self, dim, heads, dropout):
        dim_head = dim
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        # Same value flash_attn_func uses by default (1 / sqrt(headdim)).
        self.scale = dim_head ** -0.5

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        self.dropout = dropout

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, seq, dim). flash_attn_func takes (batch, seqlen, nheads, headdim)
        # in fp16/bf16; a view splits the projected features into heads directly
        # (the former einops `rearrange` was never imported).
        b, n, _ = x.shape
        q = self.to_q(x).view(b, n, self.heads, -1).to(torch.bfloat16)
        k = self.to_k(x).view(b, n, self.heads, -1).to(torch.bfloat16)
        v = self.to_v(x).view(b, n, self.heads, -1).to(torch.bfloat16)
        # FlashAttention has no train/eval notion: disable dropout outside training.
        dropout_p = self.dropout if self.training else 0.0
        out = flash_attn_func(q, k, v, dropout_p=dropout_p)
        # Merge heads and return to the input dtype so the output Linear accepts it.
        out = out.reshape(b, n, -1).to(x.dtype)
        return self.to_out(out)


class Freq_Conformer(nn.Module):
    """Post-norm transformer layer (FlashAttention + GELU feed-forward) on (batch, seq, dim)."""

    def __init__(self, dim_model, num_head, dropout):
        super(Freq_Conformer, self).__init__()
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_ffn = 4 * self.dim_model
        self.dropout = dropout

        self.self_attn = Attention(dim=dim_model, heads=num_head, dropout=dropout)
        self.dropout1 = nn.Dropout(self.dropout)
        self.norm1 = nn.LayerNorm(self.dim_model, eps=1e-5)

        self.linear1 = nn.Linear(self.dim_model, self.dim_ffn)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(self.dim_ffn, self.dim_model)

        self.dropout3 = nn.Dropout(self.dropout)
        self.norm2 = nn.LayerNorm(self.dim_model, eps=1e-5)

    def forward(self, input):
        att_out = self.self_attn(input)
        norm_out = self.norm1(input + self.dropout1(att_out))

        ffw_out = self.linear2(self.activation(self.linear1(norm_out)))
        output = self.norm2(norm_out + self.dropout3(ffw_out))

        return output


class Time_Conformer(nn.Module):
    """Post-norm transformer layer whose feed-forward contains dilated depthwise Conv1d layers."""

    def __init__(self, dim_model, num_head, dropout, num_layers):
        super(Time_Conformer, self).__init__()
        self.dim_model = dim_model
        self.num_head = num_head
        self.dim_ffn = 4 * self.dim_model
        self.dropout = dropout
        self.num_layers = num_layers

        self.self_attn = Attention(dim=dim_model, heads=num_head, dropout=dropout)
        self.dropout1 = nn.Dropout(self.dropout)
        self.norm1 = nn.LayerNorm(self.dim_model, eps=1e-5)

        self.linear1 = nn.Linear(self.dim_model, self.dim_ffn)
        self.activation = nn.GELU()
        convs = []
        for i in range(self.num_layers):
            convs.append(nn.Sequential(
                nn.Conv1d(self.dim_ffn, self.dim_ffn, kernel_size=3, dilation=2 ** i, padding='same', groups=self.dim_ffn),
                nn.GroupNorm(1, self.dim_ffn),
                nn.PReLU(self.dim_ffn)
            ))
        self.seq_conv = nn.Sequential(*convs)
        self.linear2 = nn.Linear(self.dim_ffn, self.dim_model)

        self.dropout3 = nn.Dropout(self.dropout)
        self.norm2 = nn.LayerNorm(self.dim_model, eps=1e-5)

    def forward(self, input):
        att_out = self.self_attn(input)
        norm_out = self.norm1(input + self.dropout1(att_out))

        # Conv1d wants (batch, channels, seq).
        ffw_mid = self.activation(self.linear1(norm_out)).transpose(1, 2).contiguous()
        ffw_out = self.linear2(self.seq_conv(ffw_mid).transpose(1, 2).contiguous())
        output = self.norm2(norm_out + self.dropout3(ffw_out))

        return output


class proposed_block(nn.Module):
    """Sub-block: dense block -> frequency transformer -> time conformer; (B, C, L, F) in/out."""

    def __init__(self, ch_dim, num_head, dropout, depth, num_freqs=257):
        super(proposed_block, self).__init__()
        self.dense_block = DenseBlock(in_channels=ch_dim, depth=depth, num_freqs=num_freqs)
        self.freq_conformer = Freq_Conformer(dim_model=ch_dim, num_head=num_head, dropout=dropout)
        self.temp_conformer = Time_Conformer(dim_model=ch_dim, num_head=num_head, dropout=dropout, num_layers=3)

    def forward(self, input):
        B, C, L, n_freq = input.size()

        output_c = self.dense_block(input)  # B, C, L, F

        input_f = output_c.permute(0, 2, 3, 1).contiguous()  # B, L, F, C
        input_f = input_f.view(B * L, n_freq, C)  # B*L, F, C
        output_f = self.freq_conformer(input_f)  # B*L, F, C
        output_f = output_f.view(B, L, n_freq, C)  # B, L, F, C

        input_t = output_f.permute(0, 2, 1, 3).contiguous()  # B, F, L, C
        input_t = input_t.view(B * n_freq, L, C)  # B*F, L, C
        output_t = self.temp_conformer(input_t)  # B*F, L, C
        output_t = output_t.view(B, n_freq, L, C)  # B, F, L, C

        output = output_t.permute(0, 3, 2, 1)  # B, C, L, F

        return output


class Network(nn.Module):
    """Revised DeFT-AN: estimates a mask for every microphone and returns (B, M, T) waveforms.

    Args:
        mic_num: Number of microphones M (in/out conv use 2*M channels).
        ch_dim: Internal channel size.
        win: STFT size; hop ``win // 2``; ``win // 2 + 1`` bins.
        num_head: Attention heads.
        dropout: Dropout probability.
        num_layer: Number of sub-blocks.
        depth: Dense block depth.
    """

    def __init__(self, mic_num=4, ch_dim=64, win=512, num_head=4, dropout=0.1, num_layer=4, depth=5):
        super(Network, self).__init__()

        self.mic_num = mic_num
        # One output channel per microphone (out_conv predicts 2 * mic_num mask
        # channels). The original assigned an undefined name `out_ch_num` here.
        self.out_ch_num = mic_num
        self.ch_dim = ch_dim

        self.win = win
        self.hop = self.win // 2
        self.num_freqs = self.win // 2 + 1

        # Unused in forward; kept so checkpoint state_dict keys still match.
        self.prelu = nn.PReLU()
        self.dim_model = ch_dim
        self.num_head = num_head
        self.dim_ffn = self.dim_model * 2

        self.dropout = dropout
        self.n_conv_groups = self.dim_ffn

        self.num_layer = num_layer

        self.inp_conv = nn.Sequential(
            nn.Conv2d(in_channels=2 * self.mic_num, out_channels=self.ch_dim, kernel_size=(5, 5), padding=(2, 2)),
            nn.LayerNorm(self.num_freqs),
            nn.PReLU(self.ch_dim)
        )

        self.proposed_block = nn.ModuleList([])
        for ii in range(num_layer):
            self.proposed_block.append(
                proposed_block(ch_dim=ch_dim, num_head=num_head, dropout=dropout, depth=depth,
                               num_freqs=self.num_freqs)
            )

        self.out_conv = nn.Conv2d(in_channels=self.ch_dim, out_channels=2 * self.mic_num, kernel_size=(5, 5), padding=(2, 2))

    def pad_signal(self, input):
        """Zero-pad (B, T) / (B, M, T) waveforms; returns (padded (B, M, hop+T+rest+hop), rest)."""
        if input.dim() not in [2, 3]:
            raise RuntimeError("Input can only be 2 or 3 dimensional.")

        if input.dim() == 2:
            input = input.unsqueeze(1)
        batch_size, nchannel, nsample = input.size()
        rest = self.win - (self.hop + nsample % self.win) % self.win
        if rest > 0:
            pad = input.new_zeros(batch_size, nchannel, rest)
            input = torch.cat([input, pad], 2)

        pad_aux = input.new_zeros(batch_size, nchannel, self.hop)
        input = torch.cat([pad_aux, input, pad_aux], 2)

        return input, rest

    def forward(self, input):
        input, rest = self.pad_signal(input)
        B, M, T = input.size()  # batch B, mic M, time samples T

        # Rectangular window, identical to the window=None default.
        window = torch.ones(self.win, dtype=input.dtype, device=input.device)
        spec = torch.stft(input.view([-1, T]), n_fft=self.win, hop_length=self.hop,
                          window=window, return_complex=True)
        stft_input = torch.view_as_real(spec)  # B*M, F, L, 2
        _, n_freq, L, _ = stft_input.size()
        xi = stft_input.view([B, M, n_freq, L, 2])  # B, M, F, L, 2
        xi = xi.permute(0, 1, 4, 3, 2).contiguous()  # B, M, 2, L, F
        xi = xi.view([B, M * 2, L, n_freq])  # B, 2*M, L, F (channel = 2*m + real/imag)

        xo = self.inp_conv(xi)  # B, C, L, F
        for idx in range(self.num_layer):
            xo = self.proposed_block[idx](xo)  # B, C, L, F
        mask = self.out_conv(xo)  # B, 2*M, L, F

        masked_enc_out = xi * mask
        # Undo xi's (M, 2) channel packing. The original reshaped the channel
        # axis as (2, M), which mixed real/imag parts of different microphones.
        yo = masked_enc_out.reshape([B, M, 2, L, n_freq])  # B, M, 2, L, F
        yo = yo.permute(0, 1, 4, 3, 2).reshape([B * M, n_freq, L, 2])  # BM, F, L, 2

        output = torch.istft(torch.complex(yo[..., 0], yo[..., 1]), n_fft=self.win, hop_length=self.hop,
                             window=window, return_complex=False)
        output = output[:, self.hop:-(rest + self.hop)].reshape([B, M, -1])

        return output


# ======================================================================
# /extra/stft_loss.py:
# ======================================================================
import numpy as np
from itertools import permutations
from torch.autograd import Variable
import torch
import torch.nn as nn
from scipy import linalg
import scipy


class stft_loss(nn.Module):
    """Mean absolute ('mae') or squared ('mse') error between Hann-windowed STFTs.

    Inputs are (B, C, T) waveforms; real and imaginary parts are compared.

    Raises:
        ValueError: if ``loss_type`` is not 'mse' or 'mae' (previously this
            surfaced as an UnboundLocalError in forward).
    """

    def __init__(self, win, hop, loss_type='mae'):
        super(stft_loss, self).__init__()
        if loss_type not in ('mse', 'mae'):
            raise ValueError(f"loss_type must be 'mse' or 'mae', got {loss_type!r}")
        self.win = win
        self.hop = hop
        self.loss_type = loss_type

    def _stft(self, x):
        """Return the (B*C, F, L, 2) real/imag STFT of a (B, C, T) waveform."""
        window = torch.hann_window(self.win, dtype=x.dtype, device=x.device)
        spec = torch.stft(x.reshape(-1, x.size(2)), n_fft=self.win, hop_length=self.hop,
                          window=window, return_complex=True)
        return torch.view_as_real(spec)

    def forward(self, est, org):
        diff = self._stft(org) - self._stft(est)
        if self.loss_type == 'mse':
            return torch.mean(diff ** 2)
        return torch.mean(torch.abs(diff))


class PCM_Loss(nn.Module):
    """Equal-weight sum of speech and noise STFT MAE losses.

    The noise estimate/target are ``mixed - estimation`` / ``mixed - origin``.
    """

    def __init__(self):
        super(PCM_Loss, self).__init__()
        self.loss = stft_loss(win=512, hop=256, loss_type='mae')

    def forward(self, mixed, estimation, origin):
        loss_speech = self.loss(estimation, origin)
        loss_noise = self.loss(mixed - estimation, mixed - origin)

        return 0.5 * loss_speech + 0.5 * loss_noise


# ======================================================================
# /generate_rir/donghen_pyroom_rirs.py:
# ======================================================================
import os
import glob
import argparse
import configparser as CP
import multiprocessing
from itertools import repeat

import math
import random
import numpy as np

import pyroomacoustics as pra
import utils
import scipy.io.wavfile


PROCESSES = multiprocessing.cpu_count()


def _max_rir_length(rirs):
    """Return the longest RIR length in a nested ``[mic][source]`` list of 1-D arrays."""
    return max(len(rir) for mic_rirs in rirs for rir in mic_rirs)


def _pad_rirs(rirs, length):
    """Zero-pad every RIR of a nested ``[mic][source]`` list to ``length`` samples.

    Returns an array of shape (num_mics, num_sources, length).
    """
    return np.array([[np.pad(rir, (0, length - len(rir)), 'constant') for rir in mic_rirs]
                     for mic_rirs in rirs])


def gen_rirs(params, filenum):
    """Generate and save one set of multichannel RIRs.

    Three shoebox rooms share a random geometry and circular microphone array:
    reverberant speech (max_order=6 + ray tracing), direct speech
    (max_order=0), and reverberant noise from a separate random position.
    Random configurations are retried until every microphone has a non-zero
    RIR in all three sets; the three WAV files are then written to
    ``params['rir_proc_dir']``.

    Args:
        params: Dict of settings built by ``main_body``.
        filenum: Index embedded in the output file names.

    Returns:
        The room dimensions ``np.array([width, length, height])``.

    Raises:
        RuntimeError: if no speech source within the configured distance
            range is found after ``max_tries_room`` attempts.
    """
    while True:
        np.random.seed(random.randint(0, 65536))
        max_tries_room = 1000

        # Set RT60
        rt60 = params['room_rt60_min'] + np.random.rand() * (params['room_rt60_max'] - params['room_rt60_min'])
        print(f"Generating RIR #{filenum}, RT60: {rt60}s...")

        # Set room geometry
        width = params['room_width_min'] + np.random.rand() * (params['room_width_max'] - params['room_width_min'])
        length = params['room_length_min'] + np.random.rand() * (params['room_length_max'] - params['room_length_min'])
        height = params['room_height_min'] + np.random.rand() * (params['room_height_max'] - params['room_height_min'])
        room_geometry = np.array([width, length, height])

        # Set microphone array location (within +-0.5 m of the room center, 1-2 m high)
        axx = width / 2 + (-0.5) + np.random.rand() * 1
        ayy = length / 2 + (-0.5) + np.random.rand() * 1
        azz = 1 + np.random.rand() * 1

        array_center_location = np.array([axx, ayy, azz])
        mic_angle = np.random.rand() * math.pi / 4

        # Draw speech/noise locations until the speech source is at an allowed distance.
        # The noise draw is kept inside the loop so the random sequence is unchanged.
        cnt = 0
        inner = room_geometry - params['room_offset_inside'] * 2
        while True:
            source_location = params['room_offset_inside'] + np.random.rand(3) * inner
            src_dist = np.sqrt(np.sum(np.power(source_location - array_center_location, 2)))

            noise_location = params['room_offset_inside'] + np.random.rand(3) * inner
            if params['array_source_distance_min'] < src_dist < params['array_source_distance_max']:
                break
            cnt += 1
            if cnt > max_tries_room:
                raise RuntimeError("Speech source locating failed.")

        # Absorption coefficients for the requested RT60
        e_absorption, _ = pra.inverse_sabine(rt60=rt60, room_dim=room_geometry)
        # reverberant speech
        room1 = pra.ShoeBox(room_geometry, fs=params['fs'], max_order=6, ray_tracing=True, materials=pra.Material(e_absorption))
        room1.set_ray_tracing()
        # direct speech
        room2 = pra.ShoeBox(room_geometry, fs=params['fs'], max_order=0)
        # reverberant noise
        room3 = pra.ShoeBox(room_geometry, fs=params['fs'], max_order=6, ray_tracing=True, materials=pra.Material(e_absorption))
        room3.set_ray_tracing()
        rooms = (room1, room2, room3)

        # Circular microphone array in the xy-plane, shared by all three rooms
        mics = pra.beamforming.circular_microphone_array_xyplane(center=array_center_location, M=params['microphone_num'],
                                                                 phi0=mic_angle, radius=params['microphone_radius'],
                                                                 fs=params['fs'], directivity=None)
        for room in rooms:
            room.add(mics)

        room1.add_source(position=source_location)
        room2.add_source(position=source_location)
        room3.add_source(position=noise_location)

        for room in rooms:
            room.compute_rir()

        prefix = (f'RIR_{width:.2f}-{length:.2f}-{height:.2f}'
                  f'_rt60-{room1.rt60_theory():.2f}_index_{filenum}')
        save_dir1 = os.path.join(params['rir_proc_dir'], f'{prefix}_reverb.wav')
        save_dir2 = os.path.join(params['rir_proc_dir'], f'{prefix}_direct.wav')
        save_dir3 = os.path.join(params['rir_proc_dir'], f'{prefix}_noise.wav')

        # Zero-pad all RIRs to one common length. The original padded the direct
        # and noise RIRs to the *reverberant* RIR's maximum length, which raised
        # ValueError (negative pad width) whenever the noise RIR was longer.
        raw_rirs = [room.rir for room in rooms]
        rir_len = max(_max_rir_length(r) for r in raw_rirs)
        rirs1, rirs2, rirs3 = (_pad_rirs(r, rir_len) for r in raw_rirs)

        # Retry with a new configuration if any microphone got an all-zero RIR.
        if not all(np.all(np.any(r, axis=-1)) for r in (rirs1, rirs2, rirs3)):
            continue

        # (mics, 1 source, samples) -> (samples, mics) for multichannel WAV.
        scipy.io.wavfile.write(save_dir1, params['fs'], np.transpose(np.squeeze(rirs1, axis=1)))
        scipy.io.wavfile.write(save_dir2, params['fs'], np.transpose(np.squeeze(rirs2, axis=1)))
        scipy.io.wavfile.write(save_dir3, params['fs'], np.transpose(np.squeeze(rirs3, axis=1)))

        return room_geometry


def main_body():
    '''Read the RIR config and generate all requested RIR sets in a process pool.'''
    parser = argparse.ArgumentParser()

    # Configurations: read pyroom_rir.cfg and gather inputs
    parser.add_argument('--cfg', default='pyroom_rir.cfg',
                        help='Read pyroom_rir.cfg for all the details')
    # Section name inside the cfg; the shipped pyroom_rir.cfg defines [pyroom_rirs]
    # (the former default 'pyroom_rir' raised KeyError).
    parser.add_argument('--cfg_str', type=str, default='pyroom_rirs')
    args = parser.parse_args()

    params = dict()
    params['args'] = args
    cfgpath = os.path.join(os.path.dirname(__file__), args.cfg)
    if not os.path.exists(cfgpath):
        raise FileNotFoundError(f'No configuration file as [{cfgpath}]')

    cfg = CP.ConfigParser()
    cfg._interpolation = CP.ExtendedInterpolation()
    cfg.read(cfgpath)
    if not cfg.has_section(args.cfg_str):
        raise KeyError(f'Section [{args.cfg_str}] not found in {cfgpath}')
    params['cfg'] = cfg._sections[args.cfg_str]
    cfg = params['cfg']

    params['fs'] = int(cfg['sampling_rate'])

    params['room_rt60_min'] = float(cfg['room_rt60_min'])
    params['room_rt60_max'] = float(cfg['room_rt60_max'])

    params['room_width_min'] = float(cfg['room_width_min'])
    params['room_width_max'] = float(cfg['room_width_max'])
    params['room_length_min'] = float(cfg['room_length_min'])
    params['room_length_max'] = float(cfg['room_length_max'])
    params['room_height_min'] = float(cfg['room_height_min'])
    params['room_height_max'] = float(cfg['room_height_max'])

    params['room_offset_inside'] = float(cfg['room_offset_inside'])

    params['microphone_num'] = int(cfg['microphone_num'])
    params['microphone_radius'] = float(cfg['microphone_radius'])

    params['array_source_distance_min'] = float(cfg['array_source_distance_min'])
    params['array_source_distance_max'] = float(cfg['array_source_distance_max'])

    # The original tested fileindex_start twice and never fileindex_end.
    if cfg['fileindex_start'] != 'None' and cfg['fileindex_end'] != 'None':
        params['fileindex_start'] = int(cfg['fileindex_start'])
        params['fileindex_end'] = int(cfg['fileindex_end'])
        params['num_files'] = params['fileindex_end'] - params['fileindex_start']
        fileindices = range(params['fileindex_start'], params['fileindex_end'])
    else:
        # These keys were read but never set before; take them from the cfg.
        params['total_hours'] = float(cfg['total_hours'])
        params['audio_length'] = float(cfg['audio_length'])
        params['num_files'] = int((params['total_hours'] * 60 * 60) / params['audio_length'])
        fileindices = range(params['num_files'])

    print('Number of files to be synthesized:', params['num_files'])
    params['is_test_set'] = utils.str2bool(cfg['is_test_set'])
    params['rir_proc_dir'] = utils.get_dir(cfg, 'rir_destination', 'pyroom_RIRs')

    with multiprocessing.Pool(processes=PROCESSES) as multi_pool:
        output_lists = multi_pool.starmap(gen_rirs, zip(repeat(params), fileindices))


if __name__ == '__main__':
    main_body()

# ======================================================================
# /generate_rir/pyroom_rir.cfg:
# ======================================================================
1 | # Configuration for donghen_pyroom_rirs.py 2 | 3 | [pyroom_rirs] 4 | 5 | ;1) WSJCAM0 6 | ;sampling_rate: 16000 7 | ; 8 | ;room_rt60_min: 0.2 9 | ;room_rt60_max: 1.2 10 | ; 11 | ;room_width_min: 5 12 | ;room_width_max: 10 13 | ;room_length_min: 5 14 | ;room_length_max: 10 15 | ;room_height_min: 3 16 | ;room_height_max: 4 17 | ; 18 | ;room_offset_inside: 0.5 19 | ; 20 | ;microphone_num: 8 21 | ;microphone_radius: 0.1 22 | ; 23 | ;array_source_distance_min: 0.75 24 | ;array_source_distance_max: 2.5 25 | ; 26 | ; 27 | ;fileindex_start: 0 28 | ;fileindex_end: 3264 29 | ;is_test_set: False 30 | ;rir_destination: ./pyroomacoustics_wsjcam0_rir 31 | 32 | ;2) DNS_challenge 33 | sampling_rate: 16000 34 | 35 | room_rt60_min: 0.2 36 | room_rt60_max: 1.2 37 | 38 | room_width_min: 5 39 | room_width_max: 10 40 | room_length_min: 5 41 | room_length_max: 10 42 | room_height_min: 3 43 | room_height_max: 4 44 | 45 | room_offset_inside: 0.5 46 | 47 | microphone_num: 4 48 | microphone_radius: 0.1 49 | 50 | array_source_distance_min: 0.75 51 | array_source_distance_max: 2 52 | 53 | fileindex_start: 0 54 | fileindex_end: 3304 55 | is_test_set: False 56 | rir_destination: ./pyroomacoustics_dns_challenge_rir -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # DeFT-AN: Dense Frequency-Time Attentive Network for multichannel speech enhancement 2 | [D. Lee and J-W. Choi, "DeFT-AN: Dense Frequency-Time Attentive Network for Multichannel Speech Enhancement," IEEE Signal Processing Letters, vol. 30, pp. 155-159, 2023](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=10042963) 3 | 4 | ![Speech enhancement diagram](speech_enhancement.png) 5 | 6 | 7 | In this study, we propose a dense frequency-time attentive network (DeFT-AN) for multichannel speech enhancement.
DeFT-AN is a mask estimation network that predicts a complex spectral masking pattern for suppressing the noise and reverberation embedded in the short-time Fourier transform (STFT) of an input signal. The proposed mask estimation network incorporates three different types of blocks for aggregating information in the spatial, spectral, and temporal dimensions. It utilizes a spectral transformer with a modified feed-forward network and a temporal conformer with sequential dilated convolutions. The use of dense blocks and transformers dedicated to the three different characteristics of audio signals enables more comprehensive enhancement in noisy and reverberant environments. The superior performance of DeFT-AN over state-of-the-art multichannel models is demonstrated on two popular noisy and reverberant datasets in terms of various speech quality and intelligibility metrics. 8 | 9 | ![DeFT-AN diagram](DeFTAN_figure.png) 10 | 11 | The DeFT-AN model has a series of sub-blocks consisting of a dense block for aggregating spatial information and two transformer blocks for handling spectral and temporal information, respectively. To enable more comprehensive analysis and synthesis of the spectral information, we 12 | introduce an F-transformer for focusing on the spectral information, followed by a T-conformer designed to realize a parallelizable architecture without losing the local information or receptive field in time. This is possible by combining a temporal convolutional network (TCN) structure with the attention module. Finally, we demonstrate the performance of DeFT-AN relative to other state-of-the-art approaches 13 | through training and testing on two noisy reverberant datasets.
14 | 15 | ![DeFT-AN sub-block](Sub-block_figure.png) 16 | 17 | Both experiments, with the spatialized WSJCAM0 dataset and with the spatialized DNS challenge dataset, show that DeFT-AN outperforms the state-of-the-art models by a large margin in terms of most evaluation measures. Notably, the proposed method exhibits substantially improved SI-SDR and PESQ values relative to those of the baseline models. The proposed method was also compared with the two-stage approach using ADCN for the first stage and TPARN for the second stage (ADCN-TPARN). The single-stage performance of the proposed model is similar to that of the two-stage approach, but the two-stage setting of the proposed model again outperforms ADCN-TPARN. 18 | 19 | ![Experimental_result](Exp_result.png) 20 | 21 | 22 | http://www.sound.kaist.ac.kr 23 | -------------------------------------------------------------------------------- /speech_enhancement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/donghoney0416/DeFT-AN/afdd6c790c07a657d9f6b3e9d8aa36491b027f65/speech_enhancement.png --------------------------------------------------------------------------------