├── linear_real.py ├── linear_cplx.py ├── fusion.py ├── README.md ├── ff_real.py ├── ff_cplx.py ├── show.py ├── script.py ├── conv2d_real.py ├── dsconv2d_cplx.py ├── dsconv2d_real.py ├── f_att_real.py ├── conv2d_cplx.py ├── t_att_real.py ├── dilated_dualpath_conformer.py ├── misc.py ├── f_att_cplx.py ├── t_att_cplx.py ├── loss.py ├── uformer.py └── trans.py /linear_real.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | EPSILON = torch.finfo(torch.float32).eps 9 | 10 | 11 | class Real_Linear(nn.Module): 12 | 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim 17 | ): 18 | super(Real_Linear, self).__init__() 19 | self.linear = nn.Linear(in_dim, out_dim) 20 | 21 | def forward(self, inputs): 22 | # N, *, F 23 | out = self.linear(inputs) 24 | return out 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /linear_cplx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | EPSILON = torch.finfo(torch.float32).eps 10 | 11 | class Complex_Linear(nn.Module): 12 | 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim 17 | ): 18 | super(Complex_Linear, self).__init__() 19 | self.real_linear = nn.Linear(in_dim, out_dim) 20 | self.imag_linear = nn.Linear(in_dim, out_dim) 21 | 22 | def forward(self,inputs): 23 | # N, *, F, 2 24 | inputs_real, inputs_imag = inputs[...,0], inputs[...,1] 25 | out_real = self.real_linear(inputs_real) - self.imag_linear(inputs_imag) 26 | out_imag = self.real_linear(inputs_imag) + self.imag_linear(inputs_real) 27 | return torch.stack([out_real, out_imag], -1) 28 | -------------------------------------------------------------------------------- /fusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import os 7 | import sys 8 | 9 | 10 | EPSILON = torch.finfo(torch.float32).eps 11 | 12 | 13 | def fusion(cplx, mag): 14 | cplx_mag = torch.sqrt(torch.clamp(cplx[...,0]**2+cplx[...,1]**2, EPSILON)) 15 | mag_out = mag + torch.sigmoid(cplx_mag) 16 | cplx_real = cplx[...,0] + torch.sigmoid(mag) 17 | cplx_imag = cplx[...,1] + torch.sigmoid(mag) 18 | cplx_out = torch.stack([cplx_real, cplx_imag], -1) 19 | return cplx_out, mag_out 20 | 21 | def fusion_magpha(cplx, mag): 22 | cplx_mag = torch.sqrt(torch.clamp(cplx[...,0]**2+cplx[...,1]**2, EPSILON)) 23 | cplx_pha = torch.atan2(cplx[...,1]+EPSILON, cplx[...,0]) 24 | mag_out = mag + cplx_mag 25 | cplx_real = cplx[...,0] + mag * torch.cos(cplx_pha) 26 | cplx_imag = cplx[...,1] + mag * torch.sin(cplx_pha) 27 | cplx_out = torch.stack([cplx_real, cplx_imag], -1) 28 | return cplx_out, mag_out -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Uformer 2 | The implementation of Uformer: A Unet based dilated complex & real dual-path conformer network for simultaneous speech enhancement and dereverberation 3 | 4 | The paper is available at: https://arxiv.org/abs/2111.06015 5 | 6 | Please cite the paper if you want to follow the idea or results: 7 | 8 | @article{fu2021uformer, 9 | title={Uformer: A Unet 
based dilated complex \& real dual-path conformer network for simultaneous speech enhancement and dereverberation}, 10 | author={Fu, Yihui and Liu, Yun and Li, Jingdong and Luo, Dawei and Lv, Shubo and Jv, Yukai and Xie, Lei}, 11 | journal={arXiv preprint arXiv:2111.06015}, 12 | year={2021} 13 | } 14 | 15 | 2022.6.18 16 | I made some modifications to the model, including: 17 | 1. Replace local temporal attention with global temporal attention. 18 | 2. Cancel the encoder-decoder attention to reduce the number of parameters. 19 | 3. Change some activation and normalization functions. 20 | -------------------------------------------------------------------------------- /ff_real.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from linear_real import Real_Linear 9 | EPSILON = torch.finfo(torch.float32).eps 10 | 11 | class FF_Real(nn.Module): 12 | 13 | def __init__(self, 14 | in_dim, 15 | hidden_dim): 16 | super(FF_Real, self).__init__() 17 | self.layernorm_linear = nn.LayerNorm(in_dim) 18 | self.linear1 = Real_Linear(in_dim, hidden_dim) 19 | self.linear2 = Real_Linear(hidden_dim, in_dim) 20 | self.prelu = nn.PReLU() 21 | self.dropout = nn.Dropout(p=0.1) 22 | 23 | def forward(self, x): 24 | # N C F T 25 | y = self.layernorm_linear(x.transpose(1, 3)) 26 | y = self.linear1(y) 27 | y = self.prelu(y) 28 | y = self.dropout(y) 29 | y = self.linear2(y) 30 | y = self.dropout(y) 31 | y = y.transpose(1, 3) 32 | y = y*0.5 + x 33 | return y 34 | 35 | if __name__ == '__main__': 36 | net = FF_Real(128, 64) 37 | inputs = torch.ones([10, 128, 4, 397]) 38 | y = net(inputs) 39 | print(y.shape) 40 | -------------------------------------------------------------------------------- /ff_cplx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from linear_cplx import Complex_Linear 8 | EPSILON = torch.finfo(torch.float32).eps 9 | 10 | class FF_Cplx(nn.Module): 11 | 12 | def __init__(self, 13 | in_dim, 14 | hidden_dim): 15 | super(FF_Cplx, self).__init__() 16 | self.layernorm_linear = nn.LayerNorm(in_dim) 17 | self.linear1 = Complex_Linear(in_dim, hidden_dim) 18 | self.linear2 = Complex_Linear(hidden_dim, in_dim) 19 | self.prelu = nn.PReLU() 20 | self.dropout = nn.Dropout(p=0.1) 21 | 22 | def forward(self, x): 23 | # N C F T 2 24 | y = self.layernorm_linear(x.transpose(1, 4)).transpose(1, 4) 25 | y = y.transpose(1, 3) 26 | y = self.linear1(y) 27 | y = self.prelu(y) 28 | y = self.dropout(y) 29 | y = self.linear2(y) 30 | y = self.dropout(y) 31 | y = y.transpose(1, 3) 32 | y = y*0.5 + x 33 | return y 34 | 35 | if __name__ == '__main__': 36 | net = FF_Cplx(128, 64) 37 | inputs = torch.ones([10, 128, 4, 397, 2]) 38 | y = net(inputs) 39 | print(y.shape) 40 | -------------------------------------------------------------------------------- /show.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python -u 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2018 Northwestern Polytechnical University (author: Ke Wang) 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | 11 | def show_params(nnet, fid): 12 | 13 | print("=" * 40, "Model Parameters", "=" * 40) 14 | if fid is not 
None: 15 | fid.write("=" * 40+ "Model Parameters"+ "=" * 40 +"\n") 16 | num_params = 0 17 | for module_name, m in nnet.named_modules(): 18 | if module_name == '': 19 | for name, params in m.named_parameters(): 20 | print(name, params.size()) 21 | if fid is not None: 22 | fid.write(str(name)+ str(params.size())+'\n') 23 | i = 1 24 | for j in params.size(): 25 | i = i * j 26 | num_params += i 27 | print('[*] Parameter Size: {}'.format(num_params)) 28 | print("=" * 98) 29 | if fid is not None: 30 | fid.write('[*] Parameter Size: {}'.format(num_params)+'\n') 31 | fid.write("=" * 98+'\n') 32 | fid.flush() 33 | 34 | 35 | def show_model(nnet, fid): 36 | print("=" * 40, "Model Structures", "=" * 40) 37 | if fid is not None: 38 | fid.write("=" * 40+ "Model Structures"+"=" * 40+'\n') 39 | for module_name, m in nnet.named_modules(): 40 | if module_name == '': 41 | print(m) 42 | if fid is not None: 43 | fid.write(str(m)) 44 | print("=" * 98) 45 | if fid is not None: 46 | fid.write("=" * 98+'\n') 47 | fid.flush() 48 | -------------------------------------------------------------------------------- /script.py: -------------------------------------------------------------------------------- 1 | from trans import STFT, iSTFT, MelTransform 2 | import soundfile as sf 3 | import torch 4 | from scipy import signal 5 | import numpy as np 6 | 7 | stft = STFT(400, 160) 8 | istft = iSTFT(400, 160) 9 | 10 | y1, fs = sf.read('../fuck/1089-134686-0000.flac') 11 | y2, fs = sf.read('../fuck/121-121726-0000.flac') 12 | rir, fs = sf.read('../fuck/5.46_4.40_3.58_3.04_1.97_1.36_218.3015_254.7778_26.3611_0.2320.wav') 13 | 14 | if y1.shape[0] > y2.shape[0]: 15 | length = y2.shape[0] 16 | else: 17 | length = y1.shape[0] 18 | 19 | y1 = y1[:length] 20 | y2 = y2[:length] 21 | a = [] 22 | b = [] 23 | for i in range(8): 24 | a.append(signal.oaconvolve(y1, rir[:, i])) 25 | b.append(signal.oaconvolve(y2, rir[:, i+8])) 26 | y1 = np.stack(a, -1) 27 | y2 = np.stack(b, -1) 28 | y1 = y1[:-7999] 29 | y2 = y2[:-7999] 30 | a=y1 31 | mixwav = y1 + y2 32 | y1 = torch.from_numpy(y1).float() 33 | y2 = torch.from_numpy(y2).float() 34 | y1 = y1.unsqueeze(0).transpose(2, 1) 35 | y2 = y2.unsqueeze(0).transpose(2, 1) 36 | mix = y1 + y2 37 | mix_real, mix_imag = stft(mix) 38 | y1_real, y1_imag = stft(y1) 39 | y2_real, y2_imag = stft(y2) 40 | 41 | mix_mag = torch.sqrt(torch.clamp(mix_real**2 + mix_imag**2, 1e-7)) 42 | y1_mag = torch.sqrt(torch.clamp(y1_real**2 + y1_imag**2, 1e-7)) 43 | y2_mag = torch.sqrt(torch.clamp(y2_real**2 + y2_imag**2, 1e-7)) 44 | mix_pha = torch.atan2(mix_imag, mix_real) 45 | mask = y1_mag / torch.clamp(mix_mag, 1e-7) 46 | mask = mask[:,0] 47 | # max_abs = torch.norm(mask, float("inf"), dim=1, keepdim=True) 48 | # mask = mask / torch.clamp(max_abs, 1e-7) 49 | # mask = torch.transpose(mask, 1, 2) 50 | mask = mask.squeeze() 51 | 52 | mask = np.array(mask) 53 | enh = mix_mag[0,0]*mask#.transpose(1,0) 54 | enh = enh.unsqueeze(0) 55 | 56 | 57 | enh = istft((enh*torch.cos(mix_pha[:,0]), enh*torch.sin(mix_pha[:,0]))) 58 | enh = enh.squeeze() 59 | enh = np.array(enh) 60 | sf.write('../fuck/masked.wav',enh,16000) 61 | print(enh.shape) 62 | 63 | 64 | 65 | np.save('../fuck/testmask.npy', mask.transpose(1,0)) 66 | sf.write('../fuck/mix.wav',mixwav, 16000) 67 | sf.write('../fuck/spk1.wav',a, 16000) -------------------------------------------------------------------------------- /conv2d_real.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 
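# Note: both wrappers below crop the time axis back to the input length
# (out[..., :inputs.shape[-1]]). With the (5, 2) kernels and time padding of 1
# used in uformer.py, the convolution yields one extra frame, so the crop keeps
# frame counts aligned across the encoder/decoder stacks. Shape sketch (assuming
# the 512-point FFT front-end with the DC bin dropped): RealConv2d_Encoder(1, 8,
# kernel_size=(5, 2), stride=(2, 1), padding=(2, 1)) maps a magnitude tensor of
# shape [N, 1, 256, T] to [N, 8, 128, T].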
5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | EPSILON = torch.finfo(torch.float32).eps 9 | 10 | 11 | class RealConv2d_Encoder(nn.Module): 12 | 13 | def __init__( 14 | self, 15 | in_channels, 16 | out_channels, 17 | kernel_size=(1,1), 18 | stride=(1,1), 19 | padding=(0,0), 20 | dilation=(1,1), 21 | groups=1, 22 | ): 23 | ''' 24 | in_channels: number of input channels (magnitude branch) 25 | out_channels: number of output channels 26 | ''' 27 | super(RealConv2d_Encoder, self).__init__() 28 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = padding, dilation = dilation, groups = groups) 29 | 30 | def forward(self, inputs): 31 | # inputs : N C F T 32 | out = self.conv(inputs) 33 | out = out[...,:inputs.shape[-1]] 34 | return out 35 | 36 | class RealConv2d_Decoder(nn.Module): 37 | 38 | def __init__( 39 | self, 40 | in_channels, 41 | out_channels, 42 | kernel_size=(1,1), 43 | stride=(1,1), 44 | padding=(0,0), 45 | output_padding=(0,0), 46 | dilation=(1,1), 47 | groups=1, 48 | ): 49 | ''' 50 | in_channels: number of input channels (magnitude branch) 51 | out_channels: number of output channels 52 | ''' 53 | super(RealConv2d_Decoder, self).__init__() 54 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = padding, output_padding = output_padding, dilation = dilation, groups = groups) 55 | 56 | def forward(self,inputs): 57 | # inputs : N C F T 58 | 59 | out = self.conv(inputs) 60 | out = out[...,:inputs.shape[-1]] 61 | return out 62 | 63 | 64 | -------------------------------------------------------------------------------- /dsconv2d_cplx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from conv2d_cplx import ComplexConv2d_Encoder 9 | EPSILON = torch.finfo(torch.float32).eps 10 | 11 | class DSConv2d(nn.Module): 12 | """ 13 | 2D convolutional block: 14 | Norm - Conv1x1 - PReLU - dual dilated DConv (sigmoid gate) - Norm - Swish - SConv 15 | """ 16 | 17 | def __init__(self, 18 | in_channels, 19 | conv_channels, 20 | dilation1, 21 | dilation2, 22 | kernel_size=3, 23 | causal=False): 24 | super(DSConv2d, self).__init__() 25 | # 1x1 conv 26 | self.conv1x1 = ComplexConv2d_Encoder(in_channels, conv_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0,0)) 27 | self.prelu = nn.PReLU() 28 | self.layernorm_conv1 = nn.LayerNorm(in_channels) 29 | dconv_pad1 = (dilation1 * (kernel_size - 1)) // 2 if not causal else ( 30 | dilation1 * (kernel_size - 1)) 31 | dconv_pad2 = (dilation2 * (kernel_size - 1)) // 2 if not causal else ( 32 | dilation2 * (kernel_size - 1)) 33 | # depthwise conv 34 | self.dconv1 = ComplexConv2d_Encoder(conv_channels, conv_channels, kernel_size=(3, kernel_size), stride=(1, 1), padding=(1,dconv_pad1), dilation = (1,dilation1)) 35 | self.dconv2 = ComplexConv2d_Encoder(conv_channels, conv_channels, kernel_size=(3, kernel_size), stride=(1, 1), padding=(1,dconv_pad2), dilation = (1,dilation2)) 36 | self.layernorm_conv2 = nn.LayerNorm(conv_channels) 37 | # 1x1 conv cross channel 38 | self.sconv = ComplexConv2d_Encoder(conv_channels, in_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0,0)) 39 | # different padding way 40 | self.causal = causal 41 | self.dropout = nn.Dropout(p=0.1) 42 | 43 | 44 | def forward(self, x): 45 | # N C F T 2 46 | y = self.layernorm_conv1(x.transpose(1,4)).transpose(1,4) 47 | 48 | y = self.conv1x1(y) 49 | y = self.prelu(y) 50 | 51 | y1 = self.dconv1(y) 52 | y2 = self.dconv2(y) 53 | 
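# The two depthwise branches run at complementary dilations
# (dilated_dualpath_conformer.py pairs them as (1, 128), (2, 64), (4, 32), ...):
# the second branch is passed through a sigmoid and gates the first, GLU-style,
# before layernorm, an x*sigmoid(x) (Swish) activation and the pointwise sconv
# project back to in_channels for the residual connection.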
54 | y = y1 * torch.sigmoid(y2) 55 | y = self.layernorm_conv2(y.transpose(1,4)).transpose(1,4) 56 | y = y * torch.sigmoid(y) 57 | y = self.sconv(y) 58 | y = self.dropout(y) 59 | y = x + y 60 | return y 61 | 62 | if __name__ == '__main__': 63 | net = DSConv2d(128, 64, 2, 4) 64 | inputs = torch.ones([10, 128, 4, 397, 2]) 65 | y = net(inputs) 66 | print(y.shape) 67 | -------------------------------------------------------------------------------- /dsconv2d_real.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from conv2d_real import RealConv2d_Encoder 9 | EPSILON = torch.finfo(torch.float32).eps 10 | 11 | class DSConv2d_Real(nn.Module): 12 | """ 13 | 2D convolutional block: 14 | Norm - Conv1x1 - PReLU - dual dilated DConv (sigmoid gate) - Norm - Swish - SConv 15 | """ 16 | 17 | def __init__(self, 18 | in_channels, 19 | conv_channels, 20 | dilation1, 21 | dilation2, 22 | kernel_size=3, 23 | causal=False): 24 | super(DSConv2d_Real, self).__init__() 25 | # 1x1 conv 26 | self.conv1x1 = RealConv2d_Encoder(in_channels, conv_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0,0)) 27 | self.prelu = nn.PReLU() 28 | self.layernorm_conv1 = nn.LayerNorm(in_channels) 29 | dconv_pad1 = (dilation1 * (kernel_size - 1)) // 2 if not causal else ( 30 | dilation1 * (kernel_size - 1)) 31 | dconv_pad2 = (dilation2 * (kernel_size - 1)) // 2 if not causal else ( 32 | dilation2 * (kernel_size - 1)) 33 | # depthwise conv 34 | self.dconv1 = RealConv2d_Encoder(conv_channels, conv_channels, kernel_size=(3, kernel_size), stride=(1, 1), padding=(1,dconv_pad1), dilation = (1,dilation1)) 35 | self.dconv2 = RealConv2d_Encoder(conv_channels, conv_channels, kernel_size=(3, kernel_size), stride=(1, 1), padding=(1,dconv_pad2), dilation = (1,dilation2)) 36 | self.layernorm_conv2 = nn.LayerNorm(conv_channels) 37 | # 1x1 conv cross channel 38 | self.sconv = RealConv2d_Encoder(conv_channels, in_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0,0)) 39 | # different padding way 40 | self.causal = causal 41 | self.dropout = nn.Dropout(p=0.1) 42 | 43 | # self.se = SELayer(in_channels) 44 | 45 | def forward(self, x): 46 | # N C F T 47 | y = self.layernorm_conv1(x.transpose(1,3)).transpose(1,3) 48 | y = self.conv1x1(y) 49 | y = self.prelu(y) 50 | 51 | y1 = self.dconv1(y) 52 | y2 = self.dconv2(y) 53 | 54 | y = y1 * torch.sigmoid(y2) 55 | y = self.layernorm_conv2(y.transpose(1,3)).transpose(1,3) 56 | y = y * torch.sigmoid(y) 57 | y = self.sconv(y) 58 | y = self.dropout(y) 59 | y = x + y 60 | return y 61 | 62 | if __name__ == '__main__': 63 | net = DSConv2d_Real(128, 64, 2, 4) 64 | inputs = torch.ones([10, 128, 4, 397]) 65 | y = net(inputs) 66 | print(y.shape) 67 | -------------------------------------------------------------------------------- /f_att_real.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | from linear_real import Real_Linear 7 | 8 | 9 | EPSILON = torch.finfo(torch.float32).eps 10 | 11 | 12 | 13 | class F_att_real(nn.Module): 14 | def __init__(self, in_channel, hidden_channel): 15 | super(F_att_real, self).__init__() 16 | self.query = Real_Linear(in_channel, hidden_channel) 17 | self.key = Real_Linear(in_channel, hidden_channel) 18 | self.value = Real_Linear(in_channel, hidden_channel) 19 | self.softmax = nn.Softmax(dim = -1) 20 | 
self.hidden_channel = hidden_channel 21 | def forward(self, q, k, v): 22 | # NT * F * C 23 | query = self.query(q) 24 | key = self.key(k) 25 | value = self.value(v) 26 | # output = [] 27 | energy = torch.einsum("...tf,...fy->...ty", [query, key.transpose(1, 2)]) / self.hidden_channel**0.5 28 | energy = self.softmax(energy) # NT * F * F 29 | weighted_value = torch.einsum("...tf,...fy->...ty", [energy, value]) 30 | 31 | return weighted_value 32 | 33 | class Self_Attention_F_real(nn.Module): 34 | def __init__(self, in_channel, hidden_channel): 35 | super(Self_Attention_F_real, self).__init__() 36 | self.F_att = F_att_real(in_channel, hidden_channel) 37 | self.layernorm1 = nn.LayerNorm(in_channel) 38 | self.layernorm2 = nn.LayerNorm(hidden_channel) 39 | 40 | def forward(self, x): 41 | # N*T, F, C 42 | out = self.layernorm1(x) 43 | out = self.F_att(out, out, out) 44 | out = self.layernorm2(out) 45 | return out 46 | 47 | class Multihead_Attention_F_Branch_real(nn.Module): 48 | def __init__(self, in_channel, hidden_channel, n_heads=1): 49 | super(Multihead_Attention_F_Branch_real, self).__init__() 50 | self.attn_heads = nn.ModuleList([Self_Attention_F_real(in_channel, hidden_channel) for _ in range(n_heads)] ) 51 | self.transform_linear = Real_Linear(hidden_channel, in_channel) 52 | self.layernorm3 = nn.LayerNorm(in_channel) 53 | self.dropout = nn.Dropout(p=0.1) 54 | self.prelu = nn.PReLU() 55 | 56 | def forward(self, inputs): 57 | # N * C * F * T 58 | 59 | N, C, F, T = inputs.shape 60 | x = inputs.permute(0, 3, 2, 1) # N T F C 61 | x = x.contiguous().view([N*T, F, C]) 62 | x = [attn(x) for i, attn in enumerate(self.attn_heads)] 63 | x = torch.stack(x, -1) 64 | x = x.squeeze(-1) 65 | 66 | out = self.transform_linear(x) 67 | 68 | out = out.contiguous().view([N, T, F, C]) 69 | out = out.permute(0, 3, 2, 1) 70 | out = self.prelu(self.layernorm3(out.transpose(1, 3)).transpose(1, 3)) 71 | out = self.dropout(out) 72 | out = out + inputs 73 | return out 74 | 75 | if __name__ == '__main__': 76 | net = Multihead_Attention_F_Branch_real(128, 64) 77 | inputs = torch.ones([10, 128, 4, 397]) 78 | y = net(inputs) 79 | print(y.shape) 80 | -------------------------------------------------------------------------------- /conv2d_cplx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | EPSILON = torch.finfo(torch.float32).eps 10 | 11 | class ComplexConv2d_Encoder(nn.Module): 12 | 13 | def __init__( 14 | self, 15 | in_channels, 16 | out_channels, 17 | kernel_size=(1,1), 18 | stride=(1,1), 19 | padding=(0,0), 20 | dilation=(1,1), 21 | groups=1, 22 | ): 23 | ''' 24 | in_channels: real+imag 25 | out_channels: real+imag 26 | ''' 27 | super(ComplexConv2d_Encoder, self).__init__() 28 | self.real_conv = nn.Conv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = padding, dilation = dilation, groups = groups) 29 | self.imag_conv = nn.Conv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = padding, dilation = dilation, groups = groups) 30 | 31 | def forward(self,inputs): 32 | # inputs : N C F T 2 33 | inputs_real, inputs_imag = inputs[...,0], inputs[...,1] 34 | out_real = self.real_conv(inputs_real) - self.imag_conv(inputs_imag) 35 | out_imag = self.real_conv(inputs_imag) + self.imag_conv(inputs_real) 36 | out_real = out_real[...,:inputs_real.shape[-1]] 37 | out_imag = 
out_imag[...,:inputs_imag.shape[-1]] 38 | return torch.stack([out_real, out_imag], -1) 39 | 40 | class ComplexConv2d_Decoder(nn.Module): 41 | 42 | def __init__( 43 | self, 44 | in_channels, 45 | out_channels, 46 | kernel_size=(1,1), 47 | stride=(1,1), 48 | padding=(0,0), 49 | output_padding=(0,0), 50 | dilation=(1,1), 51 | groups=1, 52 | ): 53 | ''' 54 | in_channels: real+imag 55 | out_channels: real+imag 56 | ''' 57 | super(ComplexConv2d_Decoder, self).__init__() 58 | self.real_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = padding, output_padding = output_padding, dilation = dilation, groups = groups) 59 | self.imag_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = padding, output_padding = output_padding, dilation = dilation, groups = groups) 60 | 61 | def forward(self,inputs): 62 | # inputs : N C F T 2 63 | inputs_real, inputs_imag = inputs[...,0], inputs[...,1] 64 | out_real = self.real_conv(inputs_real) - self.imag_conv(inputs_imag) 65 | out_imag = self.real_conv(inputs_imag) + self.imag_conv(inputs_real) 66 | out_real = out_real[...,:inputs_real.shape[-1]] 67 | out_imag = out_imag[...,:inputs_imag.shape[-1]] 68 | return torch.stack([out_real, out_imag], -1) 69 | 70 | 71 | -------------------------------------------------------------------------------- /t_att_real.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | from linear_real import Real_Linear 7 | 8 | EPSILON = torch.finfo(torch.float32).eps 9 | 10 | 11 | 12 | class T_att_real(nn.Module): 13 | def __init__(self, in_channel, hidden_channel): 14 | super(T_att_real, self).__init__() 15 | self.query = Real_Linear(in_channel, hidden_channel) 16 | self.key = Real_Linear(in_channel, hidden_channel) 17 | self.value = Real_Linear(in_channel, hidden_channel) 18 | self.softmax = nn.Softmax(dim = -1) 19 | self.hidden_channel = hidden_channel 20 | 21 | def forward(self, q, k, v): 22 | causal = False 23 | # NF * T * C 24 | query = self.query(q) 25 | key = self.key(k) 26 | value = self.value(v) 27 | energy = torch.einsum("...tf,...fy->...ty", [query, key.transpose(1, 2)]) / self.hidden_channel**0.5 28 | if causal: 29 | mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2]), diagonal=0) 30 | mask = mask.to(energy.device) 31 | energy = energy * mask 32 | energy = self.softmax(energy) # NF * T * T 33 | weighted_value = torch.einsum("...tf,...fy->...ty", [energy, value]) 34 | 35 | return weighted_value 36 | 37 | class Self_Attention_T_real(nn.Module): 38 | def __init__(self, in_channel, hidden_channel): 39 | super(Self_Attention_T_real, self).__init__() 40 | self.T_att = T_att_real(in_channel, hidden_channel) 41 | 42 | self.layernorm1 = nn.LayerNorm(in_channel) 43 | self.layernorm2 = nn.LayerNorm(hidden_channel) 44 | 45 | def forward(self, x): 46 | # N*F, T, C 47 | out = self.layernorm1(x) 48 | out = self.T_att(out, out, out) 49 | out = self.layernorm2(out) 50 | return out 51 | 52 | class Multihead_Attention_T_Branch_real(nn.Module): 53 | def __init__(self, in_channel, hidden_channel, n_heads=1): 54 | super(Multihead_Attention_T_Branch_real, self).__init__() 55 | self.attn_heads = nn.ModuleList([Self_Attention_T_real(in_channel, hidden_channel) for _ in range(n_heads)] ) 56 | self.transform_linear = Real_Linear(hidden_channel, in_channel) 57 | self.layernorm3 = nn.LayerNorm(in_channel) 58 | self.dropout = 
nn.Dropout(p=0.1) 59 | self.prelu = nn.PReLU() 60 | 61 | def forward(self, inputs): 62 | # N * C * F * T 63 | 64 | N, C, F, T = inputs.shape 65 | x = inputs.permute(0, 2, 3, 1) # N F T C 66 | x = x.contiguous().view([N*F, T, C]) 67 | x = [attn(x) for i, attn in enumerate(self.attn_heads)] 68 | x = torch.stack(x, -1) 69 | x = x.squeeze(-1) 70 | outs = self.transform_linear(x) 71 | outs = outs.contiguous().view([N, F, T, C]) 72 | outs = self.prelu(self.layernorm3(outs)) 73 | outs = self.dropout(outs) 74 | outs = outs.permute(0, 3, 1, 2) 75 | outs = outs + inputs 76 | return outs 77 | 78 | 79 | 80 | if __name__ == '__main__': 81 | net = Multihead_Attention_T_Branch_real(128, 64) 82 | inputs = torch.ones([10, 128, 4, 397]) 83 | y = net(inputs) 84 | print(y.shape) 85 | -------------------------------------------------------------------------------- /dilated_dualpath_conformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import warnings 8 | 9 | EPSILON = torch.finfo(torch.float32).eps 10 | 11 | from f_att_cplx import Multihead_Attention_F_Branch 12 | from t_att_cplx import Multihead_Attention_T_Branch 13 | from f_att_real import Multihead_Attention_F_Branch_real 14 | from t_att_real import Multihead_Attention_T_Branch_real 15 | from dsconv2d_cplx import DSConv2d 16 | from dsconv2d_real import DSConv2d_Real 17 | from ff_real import FF_Real 18 | from ff_cplx import FF_Cplx 19 | from fusion import fusion as fusion 20 | from show import show_model, show_params 21 | 22 | 23 | class Dilated_Dualpath_Conformer(nn.Module): 24 | 25 | def __init__(self, inchannel=128, hiddenchannel=64): 26 | super(Dilated_Dualpath_Conformer, self).__init__() 27 | 28 | self.ff1_cplx = FF_Cplx(inchannel, hiddenchannel) 29 | self.ff1_mag = FF_Real(inchannel, hiddenchannel) 30 | 31 | 32 | self.cplx_tatt = Multihead_Attention_T_Branch(inchannel, 16) 33 | self.cplx_fatt = Multihead_Attention_F_Branch(inchannel, 16) 34 | self.mag_tatt = Multihead_Attention_T_Branch_real(inchannel, 16) 35 | self.mag_fatt = Multihead_Attention_F_Branch_real(inchannel, 16) 36 | 37 | 38 | dilation = [1, 2, 4, 8, 16, 32, 64, 128] 39 | self.dsconv_cplx = nn.ModuleList() 40 | for idx in range(len(dilation)): 41 | self.dsconv_cplx.append(DSConv2d(inchannel, 32, dilation1=dilation[idx], dilation2=dilation[len(dilation)-idx-1])) 42 | self.dsconv_real = nn.ModuleList() 43 | for idx in range(len(dilation)): 44 | self.dsconv_real.append(DSConv2d_Real(inchannel, 32, dilation1=dilation[idx], dilation2=dilation[len(dilation)-idx-1])) 45 | 46 | 47 | self.ff2_cplx = FF_Cplx(inchannel, hiddenchannel) 48 | self.ff2_mag = FF_Real(inchannel, hiddenchannel) 49 | 50 | self.ln_conformer_cplx = nn.LayerNorm(inchannel) 51 | self.ln_conformer_mag = nn.LayerNorm(inchannel) 52 | 53 | def forward(self, cplx, mag): 54 | # N C F T 2 55 | # N C F T 56 | 57 | cplx = self.ff1_cplx(cplx) 58 | mag= self.ff1_mag(mag) 59 | cplx, mag = fusion(cplx, mag) 60 | 61 | cplx = self.cplx_tatt(cplx) 62 | mag = self.mag_tatt(mag) 63 | cplx, mag = fusion(cplx, mag) 64 | 65 | cplx = self.cplx_fatt(cplx) 66 | mag = self.mag_fatt(mag) 67 | cplx, mag = fusion(cplx, mag) 68 | 69 | for idx in range(len(self.dsconv_cplx)): 70 | cplx = self.dsconv_cplx[idx](cplx) 71 | mag = self.dsconv_real[idx](mag) 72 | cplx, mag = fusion(cplx, mag) 73 | 74 | cplx = self.ff2_cplx(cplx) 75 | mag= self.ff2_mag(mag) 76 | cplx, mag = fusion(cplx, mag) 77 
| cplx, mag = self.ln_conformer_cplx(cplx.transpose(1,4)).transpose(1,4), self.ln_conformer_mag(mag.transpose(1,3)).transpose(1,3) 78 | return cplx, mag 79 | 80 | 81 | if __name__ == '__main__': 82 | net = Dilated_Dualpath_Conformer(128, 64) 83 | inputs = torch.ones([10, 128, 4, 397, 2]) 84 | y = net(inputs, inputs[...,0]) 85 | print(y[0].shape, y[1].shape) 86 | -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python -u 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2018 Northwestern Polytechnical University (author: Ke Wang) 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import os 11 | 12 | import torch 13 | 14 | 15 | # def save_checkpoint(model1, model2, model3, optimizer1, optimizer2, optimizer3, epoch, step, checkpoint_dir): 16 | def save_checkpoint(model1, optimizer1, epoch, step, checkpoint_dir): 17 | checkpoint_path = os.path.join( 18 | checkpoint_dir, 'model.ckpt-{}-{}.pt'.format(epoch,step)) 19 | torch.save({'model1': model1.state_dict(), 20 | # 'model2': model2.state_dict(), 21 | # 'model3': model3.state_dict(), 22 | 'optimizer1': optimizer1.state_dict(), 23 | # 'optimizer2': optimizer2.state_dict(), 24 | # 'optimizer3': optimizer3.state_dict(), 25 | 'epoch': epoch, 26 | 'step': step}, checkpoint_path) 27 | with open(os.path.join(checkpoint_dir, 'checkpoint'), 'w') as f: 28 | f.write('model.ckpt-{}-{}.pt'.format(epoch,step)) 29 | print("=> Save checkpoint:", checkpoint_path) 30 | 31 | 32 | # def reload_model(model1, model2, model3, optimizer1, optimizer2, optimizer3, checkpoint_dir, use_cuda=True): 33 | def reload_model(model1, optimizer1, checkpoint_dir, use_cuda=True): 34 | ckpt_name = os.path.join(checkpoint_dir, 'checkpoint') 35 | if os.path.isfile(ckpt_name): 36 | with open(ckpt_name, 'r') as f: 37 | model_name = f.readline().strip() 38 | checkpoint_path = os.path.join(checkpoint_dir, model_name) 39 | checkpoint = load_checkpoint(checkpoint_path, use_cuda) 40 | model1.load_state_dict(checkpoint['model1'])#, strict=False 41 | optimizer1.load_state_dict(checkpoint['optimizer1']) 42 | epoch = checkpoint['epoch'] 43 | step = checkpoint['step'] 44 | print('=> Reload previous model and optimizer.',model_name) 45 | else: 46 | print('[!] checkpoint directory is empty. 
Train a new model ...') 47 | epoch = 0 48 | step = 0 49 | return epoch, step 50 | 51 | 52 | # def reload_for_eval(model1, model2, model3, checkpoint_dir, use_cuda): 53 | def reload_for_eval(model1, checkpoint_dir, use_cuda): 54 | ckpt_name = os.path.join(checkpoint_dir, 'checkpoint') 55 | if os.path.isfile(ckpt_name): 56 | with open(ckpt_name, 'r') as f: 57 | model_name = f.readline().strip() 58 | checkpoint_path = os.path.join(checkpoint_dir, model_name) 59 | checkpoint = load_checkpoint(checkpoint_path, use_cuda) 60 | model1.load_state_dict(checkpoint['model1']) # key written by save_checkpoint 61 | # model2.load_state_dict(checkpoint['model2']) 62 | # model3.load_state_dict(checkpoint['model3']) 63 | print('=> Reload well-trained model {} for decoding.'.format( 64 | model_name)) 65 | 66 | 67 | def load_checkpoint(checkpoint_path, use_cuda): 68 | if use_cuda: 69 | checkpoint = torch.load(checkpoint_path, map_location='cuda:0') 70 | else: 71 | checkpoint = torch.load( 72 | checkpoint_path, map_location=lambda storage, loc: storage) 73 | return checkpoint 74 | 75 | 76 | def learning_rate_decaying(optimizer, rate): 77 | """decaying the learning rate""" 78 | lr = get_learning_rate(optimizer) * rate 79 | for param_group in optimizer.param_groups: 80 | param_group["lr"] = lr 81 | 82 | 83 | def get_learning_rate(optimizer): 84 | """Get learning rate""" 85 | return optimizer.param_groups[0]["lr"] 86 | -------------------------------------------------------------------------------- /f_att_cplx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | from linear_real import Real_Linear 7 | from linear_cplx import Complex_Linear 8 | 9 | EPSILON = torch.finfo(torch.float32).eps 10 | 11 | 12 | 13 | class F_att(nn.Module): 14 | def __init__(self, in_channel, hidden_channel): 15 | super(F_att, self).__init__() 16 | self.query = Real_Linear(in_channel, hidden_channel) 17 | self.key = Real_Linear(in_channel, hidden_channel) 18 | self.value = Real_Linear(in_channel, hidden_channel) 19 | self.softmax = nn.Softmax(dim = -1) 20 | self.hidden_channel = hidden_channel 21 | 22 | def forward(self, q, k, v): 23 | # NT * F * C 24 | query = self.query(q) 25 | key = self.key(k) 26 | value = self.value(v) 27 | energy = torch.einsum("...tf,...fy->...ty", [query, key.transpose(1, 2)]) / self.hidden_channel**0.5 28 | energy = self.softmax(energy) # NT * F * F 29 | weighted_value = torch.einsum("...tf,...fy->...ty", [energy, value]) 30 | 31 | return weighted_value 32 | 33 | class Self_Attention_F(nn.Module): 34 | def __init__(self, in_channel, hidden_channel): 35 | super(Self_Attention_F, self).__init__() 36 | self.F_att1 = F_att(in_channel, hidden_channel) 37 | self.F_att2 = F_att(in_channel, hidden_channel) 38 | self.F_att3 = F_att(in_channel, hidden_channel) 39 | self.F_att4 = F_att(in_channel, hidden_channel) 40 | self.F_att5 = F_att(in_channel, hidden_channel) 41 | self.F_att6 = F_att(in_channel, hidden_channel) 42 | self.F_att7 = F_att(in_channel, hidden_channel) 43 | self.F_att8 = F_att(in_channel, hidden_channel) 44 | self.layernorm1 = nn.LayerNorm(in_channel) 45 | self.layernorm2 = nn.LayerNorm(hidden_channel) 46 | 47 | def forward(self, x): 48 | # N*T, F, C, 2 49 | x = self.layernorm1(x.transpose(2, 3)).transpose(2, 3) 50 | real, imag = x[...,0], x[...,1] 51 | A = self.F_att1(real, real, real) 52 | B = self.F_att2(real, imag, imag) 53 | C = self.F_att3(imag, real, imag) 54 | D = self.F_att4(imag, imag, real) 55 | 
E = self.F_att5(real, real, imag) 56 | F = self.F_att6(real, imag, real) 57 | G = self.F_att7(imag, real, real) 58 | H = self.F_att8(imag, imag, imag) 59 | real_att = A-B-C-D 60 | imag_att = E+F+G-H 61 | out = torch.stack([real_att, imag_att], -1) 62 | out = self.layernorm2(out.transpose(2, 3)).transpose(2, 3) 63 | return out 64 | 65 | class Multihead_Attention_F_Branch(nn.Module): 66 | def __init__(self, in_channel, hidden_channel, n_heads=1): 67 | super(Multihead_Attention_F_Branch, self).__init__() 68 | self.attn_heads = nn.ModuleList([Self_Attention_F(in_channel, hidden_channel) for _ in range(n_heads)] ) 69 | self.transform_linear = Complex_Linear(hidden_channel, in_channel) 70 | self.layernorm3 = nn.LayerNorm(in_channel) 71 | self.dropout = nn.Dropout(p=0.1) 72 | self.prelu = nn.PReLU() 73 | 74 | def forward(self, inputs): 75 | # N * C * F * T * 2 76 | N, C, F, T, ri = inputs.shape 77 | x = inputs.permute(0, 3, 2, 1, 4) # N T F C 2 78 | x = x.contiguous().view([N*T, F, C, ri]) 79 | x = [attn(x) for i, attn in enumerate(self.attn_heads)] 80 | x = torch.stack(x, -1) 81 | x = x.squeeze(-1) 82 | outs = self.transform_linear(x) 83 | outs = outs.contiguous().view([N, T, F, C, ri]) 84 | outs = outs.permute(0, 3, 2, 1, 4) 85 | outs = self.prelu(self.layernorm3(outs.transpose(1, 4)).transpose(1, 4)) 86 | outs = self.dropout(outs) 87 | outs = outs + inputs 88 | return outs 89 | 90 | 91 | if __name__ == '__main__': 92 | net = Multihead_Attention_F_Branch(128, 64) 93 | inputs = torch.ones([10, 128, 4, 397, 2]) 94 | y = net(inputs) 95 | print(y.shape) 96 | -------------------------------------------------------------------------------- /t_att_cplx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | from linear_real import Real_Linear 7 | from linear_cplx import Complex_Linear 8 | 9 | 10 | 11 | EPSILON = torch.finfo(torch.float32).eps 12 | 13 | 14 | 15 | class T_att(nn.Module): 16 | def __init__(self, in_channel, hidden_channel): 17 | super(T_att, self).__init__() 18 | self.query = Real_Linear(in_channel, hidden_channel) 19 | self.key = Real_Linear(in_channel, hidden_channel) 20 | self.value = Real_Linear(in_channel, hidden_channel) 21 | self.softmax = nn.Softmax(dim = -1) 22 | self.hidden_channel = hidden_channel 23 | 24 | def forward(self, q, k, v): 25 | causal = False 26 | # NF * T * C 27 | query = self.query(q) 28 | key = self.key(k) 29 | value = self.value(v) 30 | energy = torch.einsum("...tf,...fy->...ty", [query, key.transpose(1, 2)]) / self.hidden_channel**0.5 31 | if causal: 32 | mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2]), diagonal=0) 33 | mask = mask.to(energy.device) 34 | energy = energy * mask 35 | energy = self.softmax(energy) # NF * T * T 36 | weighted_value = torch.einsum("...tf,...fy->...ty", [energy, value]) 37 | 38 | return weighted_value 39 | 40 | class Self_Attention_T(nn.Module): 41 | def __init__(self, in_channel, hidden_channel): 42 | super(Self_Attention_T, self).__init__() 43 | self.T_att1 = T_att(in_channel, hidden_channel) 44 | self.T_att2 = T_att(in_channel, hidden_channel) 45 | self.T_att3 = T_att(in_channel, hidden_channel) 46 | self.T_att4 = T_att(in_channel, hidden_channel) 47 | self.T_att5 = T_att(in_channel, hidden_channel) 48 | self.T_att6 = T_att(in_channel, hidden_channel) 49 | self.T_att7 = T_att(in_channel, hidden_channel) 50 | self.T_att8 = T_att(in_channel, hidden_channel) 51 | self.layernorm1 = 
nn.LayerNorm(in_channel) 52 | self.layernorm2 = nn.LayerNorm(hidden_channel) 53 | 54 | def forward(self, x): 55 | # N*F, T, C, 2 56 | x = self.layernorm1(x.transpose(2, 3)).transpose(2, 3) 57 | real, imag = x[...,0], x[...,1] 58 | A = self.T_att1(real, real, real) 59 | B = self.T_att2(real, imag, imag) 60 | C = self.T_att3(imag, real, imag) 61 | D = self.T_att4(imag, imag, real) 62 | E = self.T_att5(real, real, imag) 63 | F = self.T_att6(real, imag, real) 64 | G = self.T_att7(imag, real, real) 65 | H = self.T_att8(imag, imag, imag) 66 | real_att = A-B-C-D 67 | imag_att = E+F+G-H 68 | out = torch.stack([real_att, imag_att], -1) 69 | out = self.layernorm2(out.transpose(2, 3)).transpose(2, 3) 70 | return out 71 | 72 | class Multihead_Attention_T_Branch(nn.Module): 73 | def __init__(self, in_channel, hidden_channel, n_heads=1): 74 | super(Multihead_Attention_T_Branch, self).__init__() 75 | self.attn_heads = nn.ModuleList([Self_Attention_T(in_channel, hidden_channel) for _ in range(n_heads)] ) 76 | self.transform_linear = Complex_Linear(hidden_channel, in_channel) 77 | self.layernorm3 = nn.LayerNorm(in_channel) 78 | self.dropout = nn.Dropout(p=0.1) 79 | self.prelu = nn.PReLU() 80 | 81 | def forward(self, inputs): 82 | # N * C * F * T * 2 83 | 84 | N, C, F, T, ri = inputs.shape 85 | x = inputs.permute(0, 2, 3, 1, 4) # N F T C 2 86 | x = x.contiguous().view([N*F, T, C, ri]) 87 | x = [attn(x) for i, attn in enumerate(self.attn_heads)] 88 | x = torch.stack(x, -1) 89 | x = x.squeeze(-1) 90 | outs = self.transform_linear(x) 91 | outs = outs.contiguous().view([N, F, T, C, ri]) 92 | outs = outs.permute(0, 3, 1, 2, 4) 93 | outs = self.prelu(self.layernorm3(outs.transpose(1, 4)).transpose(1, 4)) 94 | outs = self.dropout(outs) 95 | outs = outs + inputs 96 | return outs 97 | 98 | if __name__ == '__main__': 99 | net = Multihead_Attention_T_Branch(128, 64) 100 | inputs = torch.ones([10, 128, 4, 397, 2]) 101 | y = net(inputs) 102 | print(y.shape) 103 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import os 8 | import sys 9 | 10 | from trans import MelTransform 11 | EPSILON = torch.finfo(torch.float32).eps 12 | 13 | def sisnr(x, s, eps=EPSILON): 14 | """ 15 | Arguments 16 | x: separated signal, N x S tensor 17 | s: reference signal, N x S tensor 18 | Return 19 | sisnr: N tensor 20 | """ 21 | def l2norm(mat, keepdim=False): 22 | return torch.norm(mat, dim=-1, keepdim=keepdim) 23 | 24 | x_zm = x - torch.mean(x) 25 | s_zm = s - torch.mean(s) 26 | t = torch.sum(x_zm * s_zm) * s_zm / (l2norm(s_zm)**2 + eps) 27 | return 0.0-20 * torch.log10(eps + l2norm(t) / (l2norm(x_zm - t) + eps)) 28 | 29 | def calloss(output, source): 30 | loss_total = 0.0 31 | zerocount = 0 32 | for i in range(output.shape[0]): 33 | loss = sisnr(output[i], source[i]) 34 | if torch.mean(source[i]**2) < 1.2e-8: 35 | loss = 0.0 36 | zerocount = zerocount + 1 37 | loss_total = loss_total + loss 38 | loss_total = loss_total / (output.shape[0] - zerocount) 39 | return loss_total, loss_total, loss_total 40 | 41 | 42 | def calloss_cplxmse(output, source): 43 | # B 2 F T 44 | loss = 0 45 | output_real, output_imag = output[:,0], output[:,1] 46 | source_real, source_imag = source[:,0], source[:,1] 47 | for i in range(output.shape[0]): 48 | loss_real = F.mse_loss(output_real[i], source_real[i], 
reduction='sum') 49 | loss_real = loss_real / output_real.shape[-2] 50 | loss_imag = F.mse_loss(output_imag[i], source_imag[i], reduction='sum') 51 | loss_imag = loss_imag / output_imag.shape[-2] 52 | loss = loss + loss_real + loss_imag 53 | return loss / output.shape[0] / 2, loss / output.shape[0] / 2, loss / output.shape[0] / 2 54 | 55 | def calloss_magmse(output, source): 56 | output_mag = torch.sqrt(torch.clamp(output[:,0]**2 + output[:,1]**2, EPSILON)) 57 | source_mag = torch.sqrt(torch.clamp(source[:,0]**2 + source[:,1]**2, EPSILON)) 58 | loss = 0 59 | for i in range(output.shape[0]): 60 | loss_mag = F.mse_loss(output_mag[i], source_mag[i], reduction='sum') 61 | loss_mag = loss_mag / output_mag.shape[-2] 62 | loss = loss + loss_mag 63 | return loss / output.shape[0], loss / output.shape[0], loss / output.shape[0] 64 | 65 | def calloss_cplxmse_subband(output, source): 66 | loss = 0 67 | output = output[:, :, 1:] 68 | source = source[:, :, 1:] 69 | output_real, output_imag = output[:,0], output[:,1] 70 | source_real, source_imag = source[:,0], source[:,1] # N F T 71 | output_real = output_real.chunk(4, -2) 72 | output_imag = output_imag.chunk(4, -2) 73 | source_real = source_real.chunk(4, -2) 74 | source_imag = source_imag.chunk(4, -2) 75 | output_real = torch.stack(output_real, -1) #N F' T 4 76 | output_imag = torch.stack(output_imag, -1) 77 | source_real = torch.stack(source_real, -1) 78 | source_imag = torch.stack(source_imag, -1) 79 | # weight = [0.4, 0.2, 0.15, 0.1, 0.06, 0.04, 0.03, 0.02] 80 | weight = [1.5, 1.2, 0.8, 0.5] 81 | for i in range(output_real.shape[0]): 82 | for j in range(output_real.shape[-1]): 83 | loss_real = F.mse_loss(output_real[i, :, :, j], source_real[i, :, :, j], reduction='sum') 84 | loss_real = weight[j] * loss_real 85 | loss_imag = F.mse_loss(output_imag[i, :, :, j], source_imag[i, :, :, j], reduction='sum') 86 | loss_imag = weight[j] * loss_imag 87 | loss = loss + loss_real + loss_imag 88 | return loss / output.shape[0] / output.shape[2] / 2, loss / output.shape[0] / output.shape[2] / 2, loss / output.shape[0] / output.shape[2] / 2 89 | 90 | def calloss_magmse_subband(output, source): 91 | output_mag = torch.sqrt(torch.clamp(output[:,0]**2 + output[:,1]**2, EPSILON)) 92 | source_mag = torch.sqrt(torch.clamp(source[:,0]**2 + source[:,1]**2, EPSILON)) 93 | # output_mag = output 94 | # source_mag = source 95 | loss = 0 96 | output_mag = output_mag[:, 1:] # N F T 97 | source_mag = source_mag[:, 1:] 98 | # weight = [0.4, 0.2, 0.15, 0.1, 0.06, 0.04, 0.03, 0.02] 99 | weight = [1.5, 1.2, 0.8, 0.5] 100 | output_mag = output_mag.chunk(4, -2) 101 | output_mag = torch.stack(output_mag, -1) 102 | source_mag = source_mag.chunk(4, -2) 103 | source_mag = torch.stack(source_mag, -1) #N F' T 4 104 | for i in range(output_mag.shape[0]): 105 | for j in range(output_mag.shape[-1]): 106 | loss_mag = F.mse_loss(output_mag[i, :, :, j], source_mag[i, :, :, j], reduction='sum') 107 | loss_mag = weight[j] * loss_mag 108 | loss = loss + loss_mag 109 | return loss / output.shape[0] / output_mag.shape[2], loss / output.shape[0] / output_mag.shape[2], loss / output.shape[0] / output_mag.shape[2] 110 | 111 | def calloss_fbankmse_subband(output, source): 112 | mel = MelTransform(960, sr=48000, num_mels=128) 113 | output_mag = torch.sqrt(torch.clamp(output[:,0]**2 + output[:,1]**2, EPSILON)) 114 | source_mag = torch.sqrt(torch.clamp(source[:,0]**2 + source[:,1]**2, EPSILON)) 115 | output_mag = output_mag.transpose(1, 2) # N T F 116 | source_mag = source_mag.transpose(1, 2) 117 | 
| output_mag = mel(output_mag) 118 | source_mag = mel(source_mag) 119 | output_mag = output_mag.chunk(8, -1) 120 | output_mag = torch.stack(output_mag, -1) 121 | source_mag = source_mag.chunk(8, -1) 122 | source_mag = torch.stack(source_mag, -1) #N T F' 8 123 | weight = [0.4, 0.2, 0.15, 0.1, 0.06, 0.04, 0.03, 0.02] 124 | loss = 0 125 | for i in range(output_mag.shape[0]): 126 | for j in range(output_mag.shape[-1]): 127 | loss_mag = F.mse_loss(output_mag[i, :, :, j], source_mag[i, :, :, j], reduction='sum') 128 | loss_mag = weight[j] * loss_mag 129 | loss = loss + loss_mag 130 | return loss / output.shape[0] / output_mag.shape[2], loss / output.shape[0] / output_mag.shape[2], loss / output.shape[0] / output_mag.shape[2] 131 | 132 | def calloss_timemae(output, source): 133 | loss = 0 134 | for i in range(output.shape[0]): 135 | loss_time = F.l1_loss(output[i], source[i], reduction='sum') 136 | loss = loss + loss_time 137 | return loss / output.shape[0], loss / output.shape[0], loss / output.shape[0] 138 | 139 | def calloss_bce(output, source): 140 | loss = 0 141 | # for i in range(output.shape[0]): 142 | # loss_bce = F.binary_cross_entropy(F.sigmoid(output), source, reduction='sum') 143 | # loss = loss + loss_bce 144 | loss = F.binary_cross_entropy(output, source, reduction='sum') 145 | return loss/output.shape[0]/output.shape[1], loss/output.shape[0]/output.shape[1], loss/output.shape[0]/output.shape[1] 146 | 147 | def calacc(output, source): 148 | a = torch.zeros(output.shape, device = output.device) 149 | b = torch.ones(output.shape, device = output.device) 150 | output = torch.where(output<=0.5, a, b) 151 | error = int(torch.sum(torch.abs(output-source)).item()) 152 | total = output.shape[0] * output.shape[1] * output.shape[2] 153 | acc = (total-error) / total 154 | return acc, acc, acc 155 | 156 | def calloss_asrenc(output, source): 157 | # n t f 158 | loss = F.mse_loss(output, source, reduction='sum') 159 | loss = loss / output.shape[0]/output.shape[-2] 160 | return loss, loss, loss -------------------------------------------------------------------------------- /uformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import os 8 | import sys 9 | 10 | import torch_complex 11 | from torch_complex import ComplexTensor 12 | import warnings 13 | from time import time 14 | 15 | EPSILON = torch.finfo(torch.float32).eps 16 | sys.path.append(os.path.dirname(sys.path[0]) + '/model') 17 | 18 | from trans import STFT, iSTFT, MelTransform, inv_MelTransform 19 | from conv2d_cplx import ComplexConv2d_Encoder, ComplexConv2d_Decoder 20 | from conv2d_real import RealConv2d_Encoder, RealConv2d_Decoder 21 | from dilated_dualpath_conformer import Dilated_Dualpath_Conformer 22 | from fusion import fusion as fusion 23 | from show import show_model, show_params 24 | 25 | 26 | def tanhextern(input): 27 | return 10 * (1-torch.exp(-0.1*input)) / (1+torch.exp(-0.1*input)) 28 | 29 | class Uformer(nn.Module): 30 | 31 | def __init__(self, 32 | win_len=400, 33 | win_inc=160, 34 | fft_len=512, 35 | win_type='hanning', 36 | fid=None): 37 | super(Uformer, self).__init__() 38 | input_dim = win_len 39 | output_dim = win_len 40 | self.kernel_num = [1,8,16,32,64,128,128] 41 | self.kernel_num_real = [1,8,16,32,64,128] 42 | 43 | self.encoder = nn.ModuleList() 44 | self.decoder = nn.ModuleList() 45 | self.encoder_real = nn.ModuleList() 46 | self.decoder_real = 
nn.ModuleList() 47 | for idx in range(len(self.kernel_num)-1): 48 | self.encoder.append( 49 | nn.Sequential( 50 | ComplexConv2d_Encoder( 51 | self.kernel_num[idx], 52 | self.kernel_num[idx+1], 53 | kernel_size=(5, 2), 54 | stride=(2, 1), 55 | padding=(2, 1), 56 | dilation=(1, 1), 57 | groups = 1 58 | ), 59 | nn.BatchNorm3d(self.kernel_num[idx+1]), 60 | nn.PReLU() 61 | ) 62 | ) 63 | 64 | for idx in range(len(self.kernel_num)-1): 65 | self.encoder_real.append( 66 | nn.Sequential( 67 | RealConv2d_Encoder( 68 | self.kernel_num[idx], 69 | self.kernel_num[idx+1], 70 | kernel_size=(5, 2), 71 | stride=(2, 1), 72 | padding=(2,1), 73 | dilation=(1, 1), 74 | groups = 1 75 | ), 76 | nn.BatchNorm2d(self.kernel_num[idx+1]), 77 | nn.PReLU() 78 | ) 79 | ) 80 | 81 | 82 | self.conformer = Dilated_Dualpath_Conformer() 83 | 84 | 85 | for idx in range(len(self.kernel_num)-1, 0, -1): 86 | if idx >= 2: 87 | self.decoder.append( 88 | nn.Sequential( 89 | ComplexConv2d_Decoder( 90 | self.kernel_num[idx]*2, 91 | self.kernel_num[idx-1], 92 | kernel_size =(5, 2), 93 | stride=(2,1), 94 | padding=(2, 0), 95 | output_padding = (1, 0), 96 | dilation=(1, 1), 97 | groups = 1 98 | ), 99 | nn.BatchNorm3d(self.kernel_num[idx-1]), 100 | #nn.ELU() 101 | nn.PReLU() 102 | ) 103 | ) 104 | 105 | else: 106 | self.decoder.append( 107 | nn.Sequential( 108 | ComplexConv2d_Decoder( 109 | self.kernel_num[idx]*2, 110 | self.kernel_num[idx-1], 111 | kernel_size =(5, 2), 112 | stride=(2,1), 113 | padding=(2, 0), 114 | output_padding = (1, 0), 115 | dilation=(1, 1), 116 | groups = 1 117 | ), 118 | ) 119 | ) 120 | 121 | for idx in range(len(self.kernel_num)-1, 0, -1): 122 | if idx >= 2: 123 | self.decoder_real.append( 124 | nn.Sequential( 125 | RealConv2d_Decoder( 126 | self.kernel_num[idx]*2, 127 | self.kernel_num[idx-1], 128 | kernel_size =(5, 2), 129 | stride=(2,1), 130 | padding=(2,0), 131 | output_padding = (1, 0), 132 | dilation=(1, 1), 133 | groups = 1 134 | ), 135 | nn.BatchNorm2d(self.kernel_num[idx-1]), 136 | #nn.ELU() 137 | nn.PReLU() 138 | ) 139 | ) 140 | 141 | else: 142 | self.decoder_real.append( 143 | nn.Sequential( 144 | RealConv2d_Decoder( 145 | self.kernel_num[idx]*2, 146 | self.kernel_num[idx-1], 147 | kernel_size =(5, 2), 148 | stride=(2,1), 149 | padding=(2,0), 150 | output_padding = (1, 0), 151 | dilation=(1, 1), 152 | groups = 1 153 | ), 154 | ) 155 | ) 156 | 157 | 158 | 159 | self.stft = STFT(frame_len=win_len, frame_hop=win_inc) 160 | self.istft = iSTFT(frame_len=win_len, frame_hop=win_inc) 161 | 162 | show_model(self, fid) 163 | show_params(self, fid) 164 | 165 | def flatten_parameters(self): 166 | self.enhance.flatten_parameters() 167 | 168 | def forward(self, inputs, src): 169 | warnings.filterwarnings('ignore') 170 | 171 | 172 | inputs_real, inputs_imag = self.stft(inputs[:,0].unsqueeze(1)) 173 | src_real, src_imag = self.stft(src[:,0]) 174 | src = self.istft((src_real, src_imag)) 175 | src_mag, src_pha = torch.sqrt(torch.clamp(src_real ** 2 + src_imag ** 2, EPSILON)), torch.atan2(src_imag+EPSILON, src_real) 176 | 177 | src_mag = src_mag ** 0.5 178 | src_real, src_imag = src_mag * torch.cos(src_pha), src_mag * torch.sin(src_pha) 179 | src_cplx = torch.stack([src_real, src_imag], 1) 180 | 181 | 182 | mag, phase = torch.sqrt(torch.clamp(inputs_real ** 2 + inputs_imag ** 2, EPSILON)), torch.atan2(inputs_imag+EPSILON, inputs_real) 183 | mag = mag ** 0.5 184 | mag_input = [] 185 | 186 | mag_input.append(mag) 187 | 188 | 189 | 190 | 191 | inputs_real, inputs_imag = mag * torch.cos(phase), mag * torch.sin(phase) 192 | 193 
| 194 | 195 | 196 | 197 | out = torch.stack([inputs_real, inputs_imag], -1) # B C F T 2 198 | out = out[:, :, 1:] 199 | mag = mag[:, :, 1:] 200 | encoder_out = [] 201 | mag_out = [] 202 | 203 | for idx in range(len(self.encoder)): 204 | out = self.encoder[idx](out) 205 | mag = self.encoder_real[idx](mag) 206 | out, mag = fusion(out, mag) 207 | mag_out.append(mag) 208 | encoder_out.append(out) 209 | 210 | 211 | 212 | out, mag = self.conformer(out, mag) 213 | 214 | for idx in range(len(self.decoder)): 215 | out_cat = torch.cat([encoder_out[-1 - idx],out],1) 216 | out = self.decoder[idx](out_cat) 217 | 218 | mag_cat = torch.cat([mag_out[-1 - idx],mag],1) 219 | mag = self.decoder_real[idx](mag_cat) 220 | 221 | out, mag = fusion(out, mag) 222 | 223 | 224 | 225 | mag = torch.sigmoid(mag) 226 | mag = F.pad(mag, [0,0,1,0]) 227 | 228 | mag = mag[:,0] * mag_input[0][:,0] 229 | 230 | mask_real = out[...,0] 231 | mask_imag = out[...,1] 232 | 233 | mask_mags = torch.sqrt(torch.clamp(mask_real**2 + mask_imag**2, EPSILON)) 234 | real_phase = mask_real/(mask_mags+EPSILON) 235 | imag_phase = mask_imag/(mask_mags+EPSILON) 236 | mask_mags = torch.tanh(mask_mags+EPSILON) 237 | mask_phase = torch.atan2(imag_phase+EPSILON, real_phase) 238 | mask_mags = F.pad(mask_mags, [0,0,1,0]) 239 | mask_phase = F.pad(mask_phase, [0,0,1,0]) 240 | 241 | 242 | 243 | est_mags = mask_mags[:, 0]*mag_input[0][:,0] 244 | 245 | 246 | est_phase = phase[:, 0] + mask_phase[:, 0] 247 | 248 | 249 | 250 | mag_compress, pha_compress = est_mags, est_phase 251 | mag_compress = (mag_compress + mag)*0.5 252 | 253 | real, imag = mag_compress * torch.cos(pha_compress), mag_compress * torch.sin(pha_compress) 254 | 255 | output_real = [] 256 | output_imag = [] 257 | output = [] 258 | output_real.append(real) 259 | output_imag.append(imag) 260 | 261 | 262 | mag_compress = mag_compress ** 2 263 | real, imag = mag_compress * torch.cos(pha_compress), mag_compress * torch.sin(pha_compress) 264 | 265 | 266 | spk1 = self.istft((real, imag)) 267 | output.append(spk1) 268 | 269 | output = torch.stack(output, 1) 270 | output = output.squeeze(1) 271 | output_real = torch.stack(output_real, 1) 272 | output_imag = torch.stack(output_imag, 1) 273 | output_real = output_real.squeeze(1) # N x C x F x T 274 | output_imag = output_imag.squeeze(1) 275 | output_cplx = torch.stack([output_real, output_imag], 1) # N x 2 x C x F x T 276 | return output, src, output_cplx, src_cplx 277 | 278 | def get_params(self, weight_decay=0.0): 279 | 280 | weights, biases = [], [] 281 | for name, param in self.named_parameters(): 282 | if 'bias' in name: 283 | biases += [param] 284 | else: 285 | weights += [param] 286 | params = [{ 287 | 'params': weights, 288 | 'weight_decay': weight_decay, 289 | }, { 290 | 'params': biases, 291 | 'weight_decay': 0.0, 292 | }] 293 | return params 294 | 295 | 296 | if __name__ == '__main__': 297 | torch.manual_seed(10) 298 | torch.set_num_threads(4) 299 | 300 | import soundfile as sf 301 | import numpy as np 302 | 303 | net = Uformer() 304 | inputs = torch.randn([10,1,64000]) 305 | print(inputs.shape) 306 | outputs = net(inputs,inputs) 307 | 308 | 309 | -------------------------------------------------------------------------------- /trans.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Jian Wu 2 | # License: Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 3 | 4 | import math 5 | 6 | import numpy as np 7 | import torch as th 8 | import torch.nn as nn 9 | import torch.nn.functional as 
tf 10 | import librosa.filters as filters 11 | import soundfile as sf 12 | import librosa 13 | import sys 14 | import os 15 | sys.path.append(os.path.dirname(sys.path[0]) + '/model') 16 | 17 | EPSILON = 1e-5 18 | # from typing import Optional, Union, Tuple 19 | 20 | 21 | def init_window(wnd: str, frame_len: int) -> th.Tensor: 22 | """ 23 | Return window coefficient 24 | Args: 25 | wnd: window name 26 | frame_len: length of the frame 27 | """ 28 | 29 | def sqrthann(frame_len, periodic=True): 30 | return th.hann_window(frame_len, periodic=periodic)**0.5 31 | 32 | if wnd not in ["bartlett", "hann", "hamm", "blackman", "rect", "sqrthann"]: 33 | raise RuntimeError(f"Unknown window type: {wnd}") 34 | 35 | wnd_tpl = { 36 | "sqrthann": sqrthann, 37 | "hann": th.hann_window, 38 | "hamm": th.hamming_window, 39 | "blackman": th.blackman_window, 40 | "bartlett": th.bartlett_window, 41 | "rect": th.ones 42 | } 43 | if wnd != "rect": 44 | # match with librosa 45 | c = wnd_tpl[wnd](frame_len, periodic=True) 46 | else: 47 | c = wnd_tpl[wnd](frame_len) 48 | return c 49 | 50 | 51 | def init_kernel(frame_len: int, 52 | frame_hop: int, 53 | window: str, 54 | round_pow_of_two: bool = True, 55 | normalized: bool = False, 56 | inverse: bool = False, 57 | mode: str = "librosa") -> th.Tensor: 58 | """ 59 | Return STFT kernels 60 | Args: 61 | frame_len: length of the frame 62 | frame_hop: hop size between frames 63 | window: window name 64 | round_pow_of_two: if true, choose round(#power_of_two) as the FFT size 65 | normalized: return normalized DFT matrix 66 | inverse: return iDFT matrix 67 | mode: framing mode (librosa or kaldi) 68 | """ 69 | if mode not in ["librosa", "kaldi"]: 70 | raise ValueError(f"Unsupported mode: {mode}") 71 | # FFT points 72 | B = 2**math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len 73 | # center padding window if needed 74 | if mode == "librosa" and B != frame_len: 75 | lpad = (B - frame_len) // 2 76 | window = tf.pad(window, (lpad, B - frame_len - lpad)) 77 | if normalized: 78 | # make K^H * K = I 79 | S = B**0.5 80 | else: 81 | S = 1 82 | # B x B complex DFT matrix (new torch.fft API; replaces the old th.fft(I / S, 1)) 83 | K = th.fft.fft(th.eye(B) / S, dim=-1) 84 | # B x B x 2 85 | K = th.stack([K.real, K.imag], -1) 86 | if mode == "kaldi": 87 | K = K[:frame_len] 88 | if inverse and not normalized: 89 | # to make K^H * K = I 90 | K = K / B 91 | # 2 x B x W 92 | K = th.transpose(K, 0, 2) * window 93 | # 2B x 1 x W 94 | K = th.reshape(K, (B * 2, 1, K.shape[-1])) 95 | return K, window 96 | 97 | 98 | def mel_filter(frame_len: int, 99 | round_pow_of_two: bool = True, 100 | num_bins: int = None, 101 | sr: int = 16000, 102 | num_mels: int = 80, 103 | fmin: float = 0.0, 104 | fmax: float = None, 105 | norm: bool = False) -> th.Tensor: 106 | """ 107 | Return mel filter coefficients 108 | Args: 109 | frame_len: length of the frame 110 | round_pow_of_two: if true, choose round(#power_of_two) as the FFT size 111 | num_bins: number of the frequency bins produced by STFT 112 | num_mels: number of the mel bands 113 | fmin: lowest frequency (in Hz) 114 | fmax: highest frequency (in Hz) 115 | norm: normalize the mel filter coefficients 116 | """ 117 | # FFT points 118 | if num_bins is None: 119 | N = 2**math.ceil( 120 | math.log2(frame_len)) if round_pow_of_two else frame_len 121 | else: 122 | N = (num_bins - 1) * 2 123 | # fmin & fmax 124 | freq_upper = sr // 2 125 | if fmax is None: 126 | fmax = freq_upper 127 | else: 128 | fmax = min(fmax + freq_upper if fmax < 0 else fmax, freq_upper) 129 | fmin = max(0, fmin) 130 | # mel 
141 | def inv_mel_filter(frame_len: int,
142 |                    round_pow_of_two: bool = True,
143 |                    num_bins: int = None,
144 |                    sr: int = 16000,
145 |                    num_mels: int = 80,
146 |                    fmin: float = 0.0,
147 |                    fmax: float = None,
148 |                    norm: bool = False) -> th.Tensor:
149 |     """
150 |     Return the pseudo-inverse of the mel filter coefficients
151 |     Args:
152 |         frame_len: length of the frame
153 |         round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
154 |         num_bins: number of the frequency bins produced by STFT
155 |         num_mels: number of the mel bands
156 |         fmin: lowest frequency (in Hz)
157 |         fmax: highest frequency (in Hz)
158 |         norm: normalize the mel filter coefficients
159 |     """
160 |     # FFT points
161 |     if num_bins is None:
162 |         N = 2**math.ceil(
163 |             math.log2(frame_len)) if round_pow_of_two else frame_len
164 |     else:
165 |         N = (num_bins - 1) * 2
166 |     # fmin & fmax
167 |     freq_upper = sr // 2
168 |     if fmax is None:
169 |         fmax = freq_upper
170 |     else:
171 |         fmax = min(fmax + freq_upper if fmax < 0 else fmax, freq_upper)
172 |     fmin = max(0, fmin)
173 |     # mel filter coefficients (keyword args keep this working on newer librosa)
174 |     mel = filters.mel(sr=sr,
175 |                       n_fft=N,
176 |                       n_mels=num_mels,
177 |                       fmax=fmax,
178 |                       fmin=fmin,
179 |                       htk=True,
180 |                       norm="slaney" if norm else None)
181 |     mel = np.linalg.pinv(mel)
182 |     # (N // 2 + 1) x num_mels after the pseudo-inverse
183 |     return th.tensor(mel, dtype=th.float32)
184 |
185 |
186 | def speed_perturb_filter(src_sr: int,
187 |                          dst_sr: int,
188 |                          cutoff_ratio: float = 0.95,
189 |                          num_zeros: int = 64) -> th.Tensor:
190 |     """
191 |     Return speed perturb filters, reference:
192 |     https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
193 |     Args:
194 |         src_sr: sample rate of the source signal
195 |         dst_sr: sample rate of the target signal
196 |     Return:
197 |         weight (Tensor): coefficients of the filter
198 |     """
199 |     if src_sr == dst_sr:
200 |         raise ValueError(
201 |             f"src_sr should not be equal to dst_sr: {src_sr}/{dst_sr}")
202 |     gcd = math.gcd(src_sr, dst_sr)
203 |     src_sr = src_sr // gcd
204 |     dst_sr = dst_sr // gcd
205 |     if src_sr == 1 or dst_sr == 1:
206 |         raise ValueError("integer-factor resampling is not supported")
207 |     zeros_per_block = min(src_sr, dst_sr) * cutoff_ratio
208 |     padding = 1 + int(num_zeros / zeros_per_block)
209 |     # dst_sr x src_sr x K
210 |     times = (np.arange(dst_sr)[:, None, None] / float(dst_sr) -
211 |              np.arange(src_sr)[None, :, None] / float(src_sr) -
212 |              np.arange(2 * padding + 1)[None, None, :] + padding)
213 |     window = np.heaviside(1 - np.abs(times / padding),
214 |                           0.0) * (0.5 + 0.5 * np.cos(times / padding * math.pi))
215 |     weight = np.sinc(
216 |         times * zeros_per_block) * window * zeros_per_block / float(src_sr)
217 |     return th.tensor(weight, dtype=th.float32)
218 |
219 |
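# A small sanity-check sketch (editor addition): 16 kHz -> 17.6 kHz reduces
# to the ratio 11/10 after the gcd, and with the defaults (cutoff_ratio=0.95,
# num_zeros=64) the padding works out to 7, i.e. a dst_sr x src_sr x 15 bank.
if __name__ == "__main__":
    print(speed_perturb_filter(16000, 17600).shape)  # torch.Size([11, 10, 15])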
220 | def splice_feature(feats: th.Tensor,
221 |                    lctx: int = 1,
222 |                    rctx: int = 1,
223 |                    subsampling_factor: int = 1,
224 |                    op: str = "cat") -> th.Tensor:
225 |     """
226 |     Splice feature
227 |     Args:
228 |         feats (Tensor): N x ... x T x F, original feature
229 |         lctx: left context
230 |         rctx: right context
231 |         subsampling_factor: subsampling factor
232 |         op: operator on feature context
233 |     Return:
234 |         splice (Tensor): feature with context padded
235 |     """
236 |     if lctx + rctx == 0:
237 |         return feats
238 |     if op not in ["cat", "stack"]:
239 |         raise ValueError(f"Unknown op for feature splicing: {op}")
240 |     # [N x ... x T x F, ...]
241 |     ctx = []
242 |     T = feats.shape[-2]
243 |     T = T - T % subsampling_factor
244 |     for c in range(-lctx, rctx + 1):
245 |         idx = th.arange(c, c + T, device=feats.device, dtype=th.int64)
246 |         idx = th.clamp(idx, min=0, max=T - 1)
247 |         ctx.append(th.index_select(feats, -2, idx))
248 |     if op == "cat":
249 |         # N x ... x T x FD
250 |         splice = th.cat(ctx, -1)
251 |     else:
252 |         # N x ... x T x F x D
253 |         splice = th.stack(ctx, -1)
254 |     return splice
255 |
256 |
257 | def _forward_stft(
258 |         wav: th.Tensor,
259 |         kernel: th.Tensor,
260 |         output: str = "polar",
261 |         pre_emphasis: float = 0,
262 |         frame_hop: int = 256,
263 |         onesided: bool = False,
264 |         center: bool = False):
265 |     """
266 |     STFT inner function
267 |     Args:
268 |         wav (Tensor), N x (C) x S
269 |         kernel (Tensor), STFT transform kernels, from init_kernel(...)
270 |         output (str), output format:
271 |             polar: return (magnitude, phase) pair
272 |             complex: return (real, imag) pair
273 |             real: return [real; imag] Tensor
274 |         frame_hop: frame hop size in number samples
275 |         pre_emphasis: factor of preemphasis
276 |         onesided: return half FFT bins
277 |         center: if true, reflect-pad the signal so frames are centered (librosa-style)
278 |     Return:
279 |         transform (Tensor or [Tensor, Tensor]), STFT transform results
280 |     """
281 |     wav_dim = wav.dim()
282 |     if output not in ["polar", "complex", "real"]:
283 |         raise ValueError(f"Unknown output format: {output}")
284 |     if wav_dim not in [2, 3]:
285 |         raise RuntimeError(f"STFT expect 2D/3D tensor, but got {wav_dim:d}D")
286 |     # if N x S, reshape N x 1 x S
287 |     # else: reshape NC x 1 x S
288 |     N, S = wav.shape[0], wav.shape[-1]
289 |     wav = wav.contiguous().view(-1, 1, S)
290 |     # NC x 1 x S+2P
291 |     if center:
292 |         pad = kernel.shape[-1] // 2
293 |         # NOTE: match with librosa
294 |         wav = tf.pad(wav, (pad, pad), mode="reflect")
295 |     # STFT
296 |     if pre_emphasis > 0:
297 |         # NC x W x T
298 |         frames = tf.unfold(wav[:, None], (1, kernel.shape[-1]),
299 |                            stride=frame_hop,
300 |                            padding=0)
301 |         frames[:, 1:] = frames[:, 1:] - pre_emphasis * frames[:, :-1]
302 |         # 1 x 2B x W, NC x W x T, NC x 2B x T
303 |         packed = th.matmul(kernel[:, 0][None, ...], frames)
304 |     else:
305 |         packed = tf.conv1d(wav, kernel, stride=frame_hop, padding=0)
306 |     # NC x 2B x T => N x C x 2B x T
307 |     if wav_dim == 3:
308 |         packed = packed.contiguous().view(N, -1, packed.shape[-2], packed.shape[-1])
309 |     # N x (C) x B x T
310 |     real, imag = th.chunk(packed, 2, dim=-2)
311 |     # N x (C) x B/2+1 x T
312 |     if onesided:
313 |         num_bins = kernel.shape[0] // 4 + 1
314 |         real = real[..., :num_bins, :]
315 |         imag = imag[..., :num_bins, :]
316 |     if output == "complex":
317 |         return (real, imag)
318 |     elif output == "real":
319 |         return th.stack([real, imag], dim=-1)
320 |     else:
321 |         mag = (real**2 + imag**2 + EPSILON)**0.5
322 |         pha = th.atan2(imag, real)
323 |         return (mag, pha)
324 |
325 |
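# A minimal sketch (editor addition) of the inner STFT path: a 400-sample
# frame is rounded up to a 512-point FFT, so onesided=True keeps
# 512 // 2 + 1 = 257 bins and 8000 samples yield (8000 - 512) // 160 + 1 = 47
# frames.
if __name__ == "__main__":
    _K, _ = init_kernel(400, 160, init_window("hann", 400))
    _mag, _pha = _forward_stft(th.randn(2, 8000), _K, output="polar",
                               frame_hop=160, onesided=True)
    print(_mag.shape, _pha.shape)  # torch.Size([2, 257, 47]) x2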
326 | def _inverse_stft(transform,
327 |                   kernel: th.Tensor,
328 |                   window: th.Tensor,
329 |                   input: str = "polar",
330 |                   frame_hop: int = 256,
331 |                   onesided: bool = False,
332 |                   center: bool = False) -> th.Tensor:
333 |     """
334 |     iSTFT inner function
335 |     Args:
336 |         transform (Tensor or [Tensor, Tensor]), STFT transform results
337 |         kernel (Tensor), STFT transform kernels, from init_kernel(...)
338 |         input (str), input format:
339 |             polar: (magnitude, phase) pair
340 |             complex: (real, imag) pair
341 |             real: [real; imag] Tensor
342 |         frame_hop: frame hop size in number samples
343 |         onesided: transform only contains half of the FFT bins
344 |         center: used in _forward_stft
345 |     Return:
346 |         wav (Tensor), N x S
347 |     """
348 |     if input not in ["polar", "complex", "real"]:
349 |         raise ValueError(f"Unknown input format: {input}")
350 |
351 |     if input == "real":
352 |         real, imag = transform[..., 0], transform[..., 1]
353 |     elif input == "polar":
354 |         real = transform[0] * th.cos(transform[1])
355 |         imag = transform[0] * th.sin(transform[1])
356 |     else:
357 |         real, imag = transform
358 |
359 |     # (N) x F x T
360 |     imag_dim = imag.dim()
361 |     if imag_dim not in [2, 3]:
362 |         raise RuntimeError(f"Expect 2D/3D tensor, but got {imag_dim}D")
363 |
364 |     # if F x T, reshape 1 x F x T
365 |     if imag_dim == 2:
366 |         real = th.unsqueeze(real, 0)
367 |         imag = th.unsqueeze(imag, 0)
368 |
369 |     if onesided:
370 |         # [self.num_bins - 2, ..., 1]
371 |         reverse = range(kernel.shape[0] // 4 - 1, 0, -1)
372 |         # extend matrix: N x B x T
373 |         real = th.cat([real, real[:, reverse]], 1)
374 |         imag = th.cat([imag, -imag[:, reverse]], 1)
375 |     # pack: N x 2B x T
376 |     packed = th.cat([real, imag], dim=1)
377 |     # N x 1 x T
378 |     s = tf.conv_transpose1d(packed, kernel, stride=frame_hop, padding=0)
379 |     # normalized audio samples
380 |     # refer: https://github.com/pytorch/audio/blob/2ebbbf511fb1e6c47b59fd32ad7e66023fa0dff1/torchaudio/functional.py#L171
381 |     # 1 x W x T
382 |     win = th.repeat_interleave(window[None, ..., None],
383 |                                packed.shape[-1],
384 |                                dim=-1)
385 |     # W x 1 x W
386 |     I = th.eye(window.shape[0], device=win.device)[:, None]
387 |     # 1 x 1 x T
388 |     norm = tf.conv_transpose1d(win**2, I, stride=frame_hop, padding=0)
389 |     if center:
390 |         pad = kernel.shape[-1] // 2
391 |         s = s[..., pad:-pad]
392 |         norm = norm[..., pad:-pad]
393 |     s = s / (norm + EPSILON)
394 |     # N x S
395 |     s = s.squeeze(1)
396 |     return s
397 |
398 |
399 | def forward_stft(
400 |         wav: th.Tensor,
401 |         frame_len: int,
402 |         frame_hop: int,
403 |         output: str = "complex",
404 |         window: str = "sqrthann",
405 |         round_pow_of_two: bool = True,
406 |         pre_emphasis: float = 0,
407 |         normalized: bool = False,
408 |         onesided: bool = True,
409 |         center: bool = False,
410 |         mode: str = "librosa"):
411 |     """
412 |     STFT function implementation, equals to STFT layer
413 |     Args:
414 |         wav: source audio signal
415 |         frame_len: length of the frame
416 |         frame_hop: hop size between frames
417 |         output: output type (complex, real, polar)
418 |         window: window name
419 |         center: center flag (similar to that in librosa.stft)
420 |         round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
421 |         pre_emphasis: factor of preemphasis
422 |         normalized: use normalized DFT kernel
423 |         onesided: output onesided STFT
424 |         (the forward DFT kernel is always used here; see inverse_stft for iSTFT)
425 |         mode: "kaldi"|"librosa", slight difference on applying window function
426 |     """
427 |     K, _ = init_kernel(frame_len,
428 |                        frame_hop,
429 |                        init_window(window, frame_len),
430 |                        round_pow_of_two=round_pow_of_two,
431 |                        normalized=normalized,
432 |                        inverse=False,
433 |                        mode=mode)
434 |     return _forward_stft(wav,
435 |                          K.to(wav.device),
436 |                          output=output,
437 |                          frame_hop=frame_hop,
438 |                          pre_emphasis=pre_emphasis,
439 |                          onesided=onesided,
440 |                          center=center)
441 |
442 |
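# A minimal sketch (editor addition) for the functional wrapper: center=True
# reflect-pads by half the FFT size like librosa, giving S // hop + 1 frames
# (16000 // 160 + 1 = 101 here).
if __name__ == "__main__":
    _re, _im = forward_stft(th.randn(1, 16000), 400, 160, output="complex",
                            center=True)
    print(_re.shape, _im.shape)  # torch.Size([1, 257, 101]) x2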
443 | def inverse_stft(transform,
444 |                  frame_len: int,
445 |                  frame_hop: int,
446 |                  input: str = "complex",
447 |                  window: str = "sqrthann",
448 |                  round_pow_of_two: bool = True,
449 |                  normalized: bool = False,
450 |                  onesided: bool = True,
451 |                  center: bool = False,
452 |                  mode: str = "librosa") -> th.Tensor:
453 |     """
454 |     iSTFT function implementation, equals to iSTFT layer
455 |     Args:
456 |         transform: results of STFT
457 |         frame_len: length of the frame
458 |         frame_hop: hop size between frames
459 |         input: input format (complex, real, polar)
460 |         window: window name
461 |         center: center flag (similar to that in librosa.stft)
462 |         round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
463 |         normalized: use normalized DFT kernel
464 |         onesided: input is an onesided STFT
465 |         mode: "kaldi"|"librosa", slight difference on applying window function
466 |     """
467 |     if isinstance(transform, th.Tensor):
468 |         device = transform.device
469 |     else:
470 |         device = transform[0].device
471 |     K, w = init_kernel(frame_len,
472 |                        frame_hop,
473 |                        init_window(window, frame_len),
474 |                        round_pow_of_two=round_pow_of_two,
475 |                        normalized=normalized,
476 |                        inverse=True,
477 |                        mode=mode)
478 |     return _inverse_stft(transform,
479 |                          K.to(device),
480 |                          w.to(device),
481 |                          input=input,
482 |                          frame_hop=frame_hop,
483 |                          onesided=onesided,
484 |                          center=center)
485 |
486 |
487 | class STFTBase(nn.Module):
488 |     """
489 |     Base layer for (i)STFT
490 |
491 |     Args:
492 |         frame_len: length of the frame
493 |         frame_hop: hop size between frames
494 |         window: window name
495 |         center: center flag (similar to that in librosa.stft)
496 |         round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
497 |         normalized: use normalized DFT kernel
498 |         pre_emphasis: factor of preemphasis
499 |         mode: "kaldi"|"librosa", slight difference on applying window function
500 |         onesided: output onesided STFT
501 |         inverse: using iDFT kernel (for iSTFT)
502 |     """
503 |
504 |     def __init__(self,
505 |                  frame_len: int,
506 |                  frame_hop: int,
507 |                  window: str = "sqrthann",
508 |                  round_pow_of_two: bool = True,
509 |                  normalized: bool = False,
510 |                  pre_emphasis: float = 0,
511 |                  onesided: bool = True,
512 |                  inverse: bool = False,
513 |                  center: bool = False,
514 |                  mode="librosa") -> None:
515 |         super(STFTBase, self).__init__()
516 |         K, w = init_kernel(frame_len,
517 |                            frame_hop,
518 |                            init_window(window, frame_len),
519 |                            round_pow_of_two=round_pow_of_two,
520 |                            normalized=normalized,
521 |                            inverse=inverse,
522 |                            mode=mode)
523 |         self.K = nn.Parameter(K, requires_grad=False)
524 |         self.w = nn.Parameter(w, requires_grad=False)
525 |         self.frame_len = frame_len
526 |         self.frame_hop = frame_hop
527 |         self.onesided = onesided
528 |         self.pre_emphasis = pre_emphasis
529 |         self.center = center
530 |         self.mode = mode
531 |         self.num_bins = self.K.shape[0] // 4 + 1
532 |         self.expr = (
533 |             f"window={window}, stride={frame_hop}, onesided={onesided}, " +
534 |             f"pre_emphasis={self.pre_emphasis}, normalized={normalized}, " +
535 |             f"center={self.center}, mode={self.mode}, " +
536 |             f"kernel_size={self.num_bins}x{self.K.shape[2]}")
537 |
538 |     def num_frames(self, num_samples: th.Tensor) -> th.Tensor:
539 |         """
540 |         Compute number of the frames
541 |         """
542 |         if th.any(num_samples <= self.frame_len):
543 |             raise RuntimeError(
544 |                 f"Audio samples less than frame_len ({self.frame_len})")
545 |         num_ffts = self.K.shape[-1]
546 |         if self.center:
547 |             num_samples += num_ffts
548 |         return (num_samples - num_ffts) // self.frame_hop + 1
549 |
550 |     def extra_repr(self) -> str:
551 |         return self.expr
552 |
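# A quick consistency sketch (editor addition): num_frames mirrors the conv1d
# arithmetic in _forward_stft, (S - num_ffts) // frame_hop + 1 for center=False,
# e.g. (16000 - 512) // 160 + 1 = 97.
if __name__ == "__main__":
    print(STFTBase(400, 160).num_frames(th.tensor([16000])))  # tensor([97])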
553 |
554 | class STFT(STFTBase):
555 |     """
556 |     Short-time Fourier Transform as a Layer
557 |     """
558 |
559 |     def __init__(self, *args, **kwargs):
560 |         super(STFT, self).__init__(*args, inverse=False, **kwargs)
561 |
562 |     def forward(
563 |             self,
564 |             wav: th.Tensor,
565 |             output: str = "complex"
566 |     ):
567 |         """
568 |         Accept (single or multiple channel) raw waveform and output the STFT transform
569 |         Args
570 |             wav (Tensor) input signal, N x (C) x S
571 |         Return
572 |             transform (Tensor or [Tensor, Tensor]), N x (C) x F x T
573 |         """
574 |         return _forward_stft(wav,
575 |                              self.K,
576 |                              output=output,
577 |                              frame_hop=self.frame_hop,
578 |                              pre_emphasis=self.pre_emphasis,
579 |                              onesided=self.onesided,
580 |                              center=self.center)
581 |
582 |
583 | class iSTFT(STFTBase):
584 |     """
585 |     Inverse Short-time Fourier Transform as a Layer
586 |     """
587 |
588 |     def __init__(self, *args, **kwargs):
589 |         super(iSTFT, self).__init__(*args, inverse=True, **kwargs)
590 |
591 |     def forward(self,
592 |                 transform,
593 |                 input: str = "complex") -> th.Tensor:
594 |         """
595 |         Accept the STFT transform (format given by "input") and output raw waveform
596 |         Args
597 |             transform (Tensor or [Tensor, Tensor]), STFT output
598 |         Return
599 |             s (Tensor), N x S
600 |         """
601 |         return _inverse_stft(transform,
602 |                              self.K,
603 |                              self.w,
604 |                              input=input,
605 |                              frame_hop=self.frame_hop,
606 |                              onesided=self.onesided,
607 |                              center=self.center)
608 |
609 |
610 |
611 | class MelTransform(nn.Module):
612 |     """
613 |     Perform mel transform (multiply mel filters)
614 |     Args:
615 |         frame_len: length of the frame
616 |         round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
617 |         sr: sample rate of source signal
618 |         num_mels: number of the mel bands
619 |         fmin: lowest frequency (in Hz)
620 |         fmax: highest frequency (in Hz)
621 |         mel_matrix: if not "", load the mel filter matrix from this path
622 |         requires_grad: make it trainable or not
623 |     """
624 |
625 |     def __init__(self,
626 |                  frame_len: int,
627 |                  round_pow_of_two: bool = True,
628 |                  sr: int = 16000,
629 |                  num_mels: int = 40,
630 |                  fmin: float = 0.0,
631 |                  fmax: float = None,
632 |                  mel_matrix: str = "",
633 |                  coeff_norm: bool = False,
634 |                  requires_grad: bool = False) -> None:
635 |         super(MelTransform, self).__init__()
636 |         if mel_matrix:
637 |             # pass existed tensor for initialization
638 |             filters = th.load(mel_matrix)
639 |         else:
640 |             # NOTE: the following mel matrix is similar to (not equal to)
641 |             # the kaldi results
642 |             filters = mel_filter(frame_len,
643 |                                  round_pow_of_two=round_pow_of_two,
644 |                                  sr=sr,
645 |                                  num_mels=num_mels,
646 |                                  fmax=fmax,
647 |                                  fmin=fmin,
648 |                                  norm=coeff_norm)
649 |         self.num_mels, self.num_bins = filters.shape
650 |         # num_mels x (N // 2 + 1)
651 |         self.filters = nn.Parameter(filters, requires_grad=requires_grad)
652 |         self.fmin = fmin
653 |         self.fmax = sr // 2 if fmax is None else fmax
654 |         self.init = mel_matrix if mel_matrix else "librosa"
655 |
656 |     def dim(self) -> int:
657 |         return self.num_mels
658 |
659 |     def extra_repr(self) -> str:
660 |         shape = self.filters.shape
661 |         return (f"fmin={self.fmin}, fmax={self.fmax}, " +
662 |                 f"mel_filter={shape[0]}x{shape[1]}, init={self.init}")
663 |
664 |     def forward(self, linear: th.Tensor) -> th.Tensor:
665 |         """
666 |         Args:
667 |             linear (Tensor): linear spectrogram, N x (C) x T x F
668 |         Return:
669 |             fbank (Tensor): mel-fbank feature, N x (C) x T x B
670 |         """
671 |         if linear.dim() not in [3, 4]:
672 |             raise RuntimeError("MelTransform expect 3/4D tensor, " +
673 |                                f"but got {linear.dim()} instead")
674 |         # N x T x F => N x T x M (filters stay on the module's device; the hard-coded .cuda() broke CPU use)
675 |         fbank = tf.linear(linear, self.filters, bias=None)
676 |         return fbank
677 |
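# A round-trip sketch (editor addition): STFT -> mel-fbank and STFT -> iSTFT,
# with the 400/160 setup used elsewhere in this file.
if __name__ == "__main__":
    _stft, _istft = STFT(400, 160), iSTFT(400, 160)
    _mag, _pha = _stft(th.randn(1, 16000), output="polar")
    _fbank = MelTransform(400, num_mels=40)(_mag.transpose(-1, -2))
    print(_mag.shape, _fbank.shape)  # (1, 257, 97) and (1, 97, 40)
    print(_istft((_mag, _pha), input="polar").shape)  # torch.Size([1, 15872])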
678 | class inv_MelTransform(nn.Module):
679 |     """
680 |     Perform inverse mel transform (multiply the pseudo-inverse of the mel filters)
681 |     Args:
682 |         frame_len: length of the frame
683 |         round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
684 |         sr: sample rate of source signal
685 |         num_mels: number of the mel bands
686 |         fmin: lowest frequency (in Hz)
687 |         fmax: highest frequency (in Hz)
688 |         mel_matrix: if not "", load the (inverse) mel filter matrix from this path
689 |         requires_grad: make it trainable or not
690 |     """
691 |
692 |     def __init__(self,
693 |                  frame_len: int,
694 |                  round_pow_of_two: bool = True,
695 |                  sr: int = 16000,
696 |                  num_mels: int = 40,
697 |                  fmin: float = 0.0,
698 |                  fmax: float = None,
699 |                  mel_matrix: str = "",
700 |                  coeff_norm: bool = False,
701 |                  requires_grad: bool = False) -> None:
702 |         super(inv_MelTransform, self).__init__()
703 |         if mel_matrix:
704 |             # pass existed tensor for initialization
705 |             filters = th.load(mel_matrix)
706 |         else:
707 |             # NOTE: the following mel matrix is similar to (not equal to)
708 |             # the kaldi results
709 |             filters = inv_mel_filter(frame_len,
710 |                                      round_pow_of_two=round_pow_of_two,
711 |                                      sr=sr,
712 |                                      num_mels=num_mels,
713 |                                      fmax=fmax,
714 |                                      fmin=fmin,
715 |                                      norm=coeff_norm)
716 |         self.num_bins, self.num_mels = filters.shape
717 |         # (N // 2 + 1) x num_mels after the pseudo-inverse
718 |         self.filters = nn.Parameter(filters, requires_grad=requires_grad)
719 |         self.fmin = fmin
720 |         self.fmax = sr // 2 if fmax is None else fmax
721 |         self.init = mel_matrix if mel_matrix else "librosa"
722 |
723 |     def dim(self) -> int:
724 |         return self.num_bins
725 |
726 |     def extra_repr(self) -> str:
727 |         shape = self.filters.shape
728 |         return (f"fmin={self.fmin}, fmax={self.fmax}, " +
729 |                 f"mel_filter={shape[0]}x{shape[1]}, init={self.init}")
730 |
731 |     def forward(self, fbank: th.Tensor) -> th.Tensor:
732 |         """
733 |         Args:
734 |             fbank (Tensor): mel-fbank feature, N x (C) x T x B
735 |         Return:
736 |             linear (Tensor): approximate linear spectrogram, N x (C) x T x F
737 |         """
738 |         if fbank.dim() not in [3, 4]:
739 |             raise RuntimeError("inv_MelTransform expects 3/4D tensor, " +
740 |                                f"but got {fbank.dim()} instead")
741 |         # N x T x M => N x T x F (filters stay on the module's device; no .cuda())
742 |         linear = tf.linear(fbank, self.filters, bias=None)
743 |         return linear
744 |
745 | # if __name__ == '__main__':
746 | #     transform = STFT(400, 160)
747 | #     itransform = iSTFT(400, 160)
748 |
749 | #     y, fs = sf.read('in.wav')
750 |
751 | #     y = th.from_numpy(y)
752 | #     y = y.float()
753 | #     y = y.unsqueeze(0).unsqueeze(0)
754 | #     r, i = transform(y)
755 | #     r, i = r.transpose(2, 3), i.transpose(2, 3)
756 | #     mag, pha = (r**2 + i**2)**0.5, th.atan2(i, r)
757 |
758 |
759 |
760 | #     r, i = mag * th.cos(pha), mag * th.sin(pha)
761 | #     r, i = r.transpose(2, 3), i.transpose(2, 3)
762 | #     y = itransform((r.squeeze(0), i.squeeze(0)))
763 | #     y = y.squeeze().cpu()
764 | #     y = np.array(y)
765 | #     sf.write('out.wav', y, fs)
766 |
--------------------------------------------------------------------------------