├── LICENSE ├── asa.py ├── erb.py ├── f_sampling.py ├── mtfaa.py ├── phase_encoder.py ├── readme.md ├── stft.py └── tfcm.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Shimin Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /asa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Axial Soft Attention (ASA). 3 | NOTE: I recommend removing the t-attention and keeping only 4 | the f-attention when using it, because the TFCMs already model 5 | the time axis, and doing so allows a much larger batch size.
# shmzhang@aslp-npu.org, 2022

import torch as th
import torch.nn as nn


def max_neg_value(t):
    """Return the most negative finite value of ``t``'s (floating) dtype.

    ``ASA`` uses it to fill masked attention scores before the softmax.
    """
    return -th.finfo(t.dtype).max


class ASA(nn.Module):
    """Axial Soft Attention (ASA) over the frequency and time axes.

    Input and output are [B, C, F, T]. Attention runs on ``d_c = c // 4``
    channels:

    * f-attention: within each frame ``t``, every bin attends to all bins;
    * t-attention: within each bin ``f``, every frame attends to the frames
      (only current and past ones when ``causal``) and aggregates the
      f-attention output.

    The attended features are projected back to ``c`` channels and added to
    the input (residual).

    Args:
        c: number of input/output channels.
        causal: mask future frames in the t-attention.
        t_attention: when False, the t-attention and its projection layer
            are skipped, and the block only attends along frequency. The
            TFCMs already model the time axis, and dropping the t-attention
            saves a lot of memory. The default (True) keeps the original
            model and checkpoint layout.
    """

    def __init__(self, c=64, causal=True, t_attention=True):
        super(ASA, self).__init__()
        self.d_c = c // 4
        self.f_qkv = nn.Sequential(
            nn.Conv2d(c, self.d_c * 3, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(self.d_c * 3),
            nn.PReLU(self.d_c * 3),
        )
        # Built only when used, so an f-only block carries no unused weights.
        # Creation order matches the original, so seeded inits are unchanged.
        if t_attention:
            self.t_qk = nn.Sequential(
                nn.Conv2d(c, self.d_c * 2, kernel_size=(1, 1), bias=False),
                nn.BatchNorm2d(self.d_c * 2),
                nn.PReLU(self.d_c * 2),
            )
        else:
            self.t_qk = None
        self.proj = nn.Sequential(
            nn.Conv2d(self.d_c, c, kernel_size=(1, 1), bias=False),
            nn.BatchNorm2d(c),
            nn.PReLU(c),
        )
        self.causal = causal
        self.t_attention = t_attention

    @staticmethod
    def _split_interleaved(x, k):
        """Split dim 1 of ``x`` into ``k`` interleaved groups.

        Group ``i`` holds channels ``i, i + k, i + 2k, ...``. This is the same
        split as ``einops.rearrange(x, "b (c k) f t -> k b c f t", k=k)``
        (``k`` is the inner factor of ``(c k)``), so q/k/v keep their channel
        assignment and existing checkpoints remain valid.
        """
        return tuple(x[:, i::k] for i in range(k))

    def forward(self, inp):
        """
        inp: B C F T
        returns: B C F T (attention output + inp)
        """
        scale = self.d_c ** 0.5
        # f-attention: scores between bins f and y of the same frame t.
        qf, kf, v = self._split_interleaved(self.f_qkv(inp), 3)
        f_score = th.einsum("bcft,bcyt->btfy", qf, kf) / scale
        f_score = f_score.softmax(dim=-1)
        f_out = th.einsum("btfy,bcyt->bcft", f_score, v)
        if not self.t_attention:
            return self.proj(f_out) + inp
        # t-attention: scores between frames t and y of the same bin f; the
        # values are the f-attention output.
        qt, kt = self._split_interleaved(self.t_qk(inp), 2)
        t_score = th.einsum("bcft,bcfy->bfty", qt, kt) / scale
        if self.causal:
            # Hide keys that lie in the future of each query (strict upper
            # triangle when the query and key lengths are equal).
            i, j = t_score.shape[-2:]
            mask = th.ones(i, j, device=t_score.device).triu_(j - i + 1).bool()
            t_score.masked_fill_(mask, max_neg_value(t_score))
        t_score = t_score.softmax(dim=-1)
        t_out = th.einsum("bfty,bcfy->bcft", t_score, f_out)
        return self.proj(t_out) + inp


def test_asa():
    nnet = ASA(c=64)
    inp = th.randn(2, 64, 256, 100)
    out = nnet(inp)
    print('out: ', out.shape)


if __name__ == "__main__":
    test_asa()
# -------------------------------------------------------------------------------- /erb.py: --------------------------------------------------------------------------------
"""
Linear FBank instead of ERB scale.
NOTE: a linear fbank is used (instead of ERB) to reduce the reconstruction error.
shmzhang@aslp-npu.org, 2022
"""

import torch as th
import torch.nn as nn
from spafe.fbanks import linear_fbanks


class Banks(nn.Module):
    """Project spectra onto a linear filter bank and back.

    The filter matrix comes from ``spafe``'s ``linear_filter_banks``; the
    einsums below treat it as [nfilters, F] (F presumably ``nfft // 2 + 1``,
    as produced by spafe — TODO confirm). The way back uses its
    pseudo-inverse.

    With ``learnable=False`` both matrices are fixed buffers, and the forward
    matrix is scaled by 1.3 ("30% energy compensation").
    NOTE(review): the learnable branch does not apply that 1.3 factor, and
    ``filter_inv`` is the pseudo-inverse of the *unscaled* matrix in both
    branches — confirm this asymmetry is intended.
    """

    def __init__(self, nfilters, nfft, fs, low_freq=None, high_freq=None, learnable=False):
        super(Banks, self).__init__()
        self.nfilters, self.nfft, self.fs = nfilters, nfft, fs
        fbank_np, _ = linear_fbanks.linear_filter_banks(
            nfilts=self.nfilters,
            nfft=self.nfft,
            low_freq=low_freq,
            high_freq=high_freq,
            fs=self.fs,
        )
        fbank = th.from_numpy(fbank_np).float()
        fbank_inv = th.pinverse(fbank)
        if learnable:
            self.filter = nn.Parameter(fbank)
            self.filter_inv = nn.Parameter(fbank_inv)
        else:
            # 30% energy compensation.
            self.register_buffer('filter', fbank * 1.3)
            self.register_buffer('filter_inv', fbank_inv)

    def amp2bank(self, amp):
        """Map [B, C, F, T] spectra to [B, C, nfilters, T] band features."""
        banded = th.einsum("bcft,kf->bckt", amp, self.filter)
        return banded

    def bank2amp(self, inputs):
        """Map [B, C, nfilters, T] band features back to [B, C, F, T]."""
        restored = th.einsum("bckt,fk->bcft", inputs, self.filter_inv)
        return restored


def test_bank():
    # Manual round-trip check; needs a real 48 kHz file at the path below.
    import soundfile as sf
    import numpy as np
    from stft import STFT
    stft = STFT(32*48, 8*48, 32*48, "hann")
    net = Banks(256, 32*48, 48000)
    sig_raw, sr = sf.read("path/to/48k.wav")
    sig = th.from_numpy(sig_raw)[None, :].float()
    cspec = stft.transform(sig)
    phase = th.atan2(cspec[:, 1], cspec[:, 0])
    mag = th.norm(cspec, dim=1).unsqueeze(dim=1)
    rec_mag = net.bank2amp(net.amp2bank(mag))
    print(th.nn.functional.mse_loss(rec_mag, mag))
    rec_mag = rec_mag.squeeze(dim=1)
    sig_rec = stft.inverse(rec_mag * th.cos(phase), rec_mag * th.sin(phase))
    sig_rec = sig_rec.cpu().data.numpy()[0]
    n = min(len(sig_rec), len(sig_raw))
    sf.write("res.wav", np.stack([sig_rec[:n], sig_raw[:n]], axis=1), sr)
    print(np.mean(np.square(sig_rec[:n] - sig_raw[:n])))


if __name__ == '__main__':
    test_bank()

# -------------------------------------------------------------------------------- /f_sampling.py: --------------------------------------------------------------------------------
"""
Frequency Down/Up Sampling.

shmzhang@aslp-npu.org, 2022
"""


import torch as th
import torch.nn as nn


class FD(nn.Module):
    """Frequency down-sampling: strided grouped conv (default kernel 7, stride 4
    along F) followed by BatchNorm and PReLU. Input/output: [B, C, F, T]."""

    def __init__(self, cin, cout, K=(7, 1), S=(4, 1), P=(2, 0)):
        super(FD, self).__init__()
        layers = [
            nn.Conv2d(cin, cout, K, S, P, groups=2),
            nn.BatchNorm2d(cout),
            nn.PReLU(cout),
        ]
        self.fd = nn.Sequential(*layers)

    def forward(self, x):
        return self.fd(x)


class FU(nn.Module):
    """Frequency up-sampling with a gated skip connection.

    ``fu`` (decoder path) and ``fd`` (encoder skip) are concatenated; a
    1x1 conv + tanh turns that into a ``cin``-channel gate applied to ``fd``.
    The gated skip goes through a 1x1 conv to ``cout`` channels and a grouped
    transposed conv that up-samples along F (``O`` is its output padding).
    """

    def __init__(self, cin, cout, K=(7, 1), S=(4, 1), P=(2, 0), O=(1, 0)):
        super(FU, self).__init__()
        self.pconv1 = nn.Sequential(
            nn.Conv2d(cin * 2, cin, (1, 1)),
            nn.BatchNorm2d(cin),
            nn.Tanh(),
        )
        self.pconv2 = nn.Sequential(
            nn.Conv2d(cin, cout, (1, 1)),
            nn.BatchNorm2d(cout),
            nn.PReLU(cout),
        )
        # 22/06/13 update, add groups = 2
        self.conv3 = nn.Sequential(
            nn.ConvTranspose2d(cout, cout, K, S, P, O, groups=2),
            nn.BatchNorm2d(cout),
            nn.PReLU(cout),
        )

    def forward(self, fu, fd):
        """
        fu, fd: B C F T
        """
        gate = self.pconv1(th.cat([fu, fd], dim=1))
        return self.conv3(self.pconv2(gate * fd))


def test_fd():
    net = FD(4, 8)
    inps = th.randn(3, 4, 256, 101)
    print(net(inps).shape)


if __name__ == "__main__":
    test_fd()
# -------------------------------------------------------------------------------- /mtfaa.py: --------------------------------------------------------------------------------
"""
Multi-scale temporal frequency axial attention neural network (MTFAA).

shmzhang@aslp-npu.org, 2022
"""

import torch as th
import torch.nn as nn
import torch.nn.functional as tf
from typing import List

from tfcm import TFCM
from asa import ASA
from phase_encoder import PhaseEncoder
from f_sampling import FD, FU
from erb import Banks
from stft import STFT


def parse_1dstr(sstr: str) -> List[int]:
    """Parse comma-separated ints, e.g. ``"48,96,192"`` -> ``[48, 96, 192]``."""
    return list(map(int, sstr.split(",")))


def parse_2dstr(sstr: str) -> List[List[int]]:
    """Parse ``;``-separated rows of comma-separated ints, e.g. ``"1,2;3,4"``."""
    return [parse_1dstr(tok) for tok in sstr.split(";")]


eps = 1e-10


class MTFAANet(nn.Module):
    """MTFAA speech-enhancement network.

    ``forward`` runs the following steps:
      1. STFT of every input signal.
      2. Phase encoder.
      3. Linear filter bank (``amp2bank``).
      4. Encoder: FD, then TFCM + ASA, per stage.
      5. Bottleneck: TFCM + ASA.
      6. Decoder: FU with the matching encoder skip, then TFCM + ASA.
      7. Inverse filter bank.
      8. Two-stage mask applied to the first signal's spectrum.
      9. iSTFT.

    Args:
        n_sig: number of input signals (1 for noise suppression, 3 for echo
            cancellation as in ``test_nnet``).
        PEc: phase-encoder channels per signal; the encoder input has
            ``PEc // 2 * n_sig`` channels.
        Co: comma-separated encoder channel widths.
        O: comma-separated frequency output paddings of the decoder FU layers.
        causal: make the TFCM and ASA blocks causal along time.
        bottleneck_layer: number of TFCM + ASA blocks in the bottleneck.
        tfcm_layer: number of blocks in each TFCM stack.
        mag_f_dim: number of neighbouring bins mixed by the stage-1 mask.
        win_len, win_hop: STFT window and hop; the FFT size equals ``win_len``.
        nerb: number of filter-bank bands.
        sr: sample rate used to design the filter bank.
        win_type: STFT window name.
    """

    def __init__(self,
                 n_sig=1,
                 PEc=4,
                 Co="48,96,192",
                 O="1,1,1",
                 causal=True,
                 bottleneck_layer=2,
                 tfcm_layer=6,
                 mag_f_dim=3,
                 win_len=32*48,
                 win_hop=8*48,
                 nerb=256,
                 sr=48000,
                 win_type="hann",
                 ):
        super(MTFAANet, self).__init__()
        # NOTE: submodules are created in the original order so parameter
        # names and seeded initialisation are unchanged.
        self.PE = PhaseEncoder(PEc, n_sig)
        # 32ms @ 48kHz
        self.stft = STFT(win_len, win_hop, win_len, win_type)
        self.ERB = Banks(nerb, win_len, sr)
        self.encoder_fd = nn.ModuleList()
        self.encoder_bn = nn.ModuleList()
        self.bottleneck = nn.ModuleList()
        self.decoder_fu = nn.ModuleList()
        self.decoder_bn = nn.ModuleList()
        C_en = [PEc // 2 * n_sig] + parse_1dstr(Co)
        # The decoder ends with 4 channels: the input width of the mask convs.
        C_de = [4] + parse_1dstr(Co)
        out_pads = parse_1dstr(O)
        for idx in range(len(C_en) - 1):
            self.encoder_fd.append(FD(C_en[idx], C_en[idx + 1]))
            self.encoder_bn.append(
                nn.Sequential(
                    TFCM(C_en[idx + 1], (3, 3),
                         tfcm_layer=tfcm_layer, causal=causal),
                    ASA(C_en[idx + 1], causal=causal),
                )
            )

        for _ in range(bottleneck_layer):
            self.bottleneck.append(
                nn.Sequential(
                    TFCM(C_en[-1], (3, 3),
                         tfcm_layer=tfcm_layer, causal=causal),
                    ASA(C_en[-1], causal=causal),
                )
            )

        for idx in range(len(C_de) - 1, 0, -1):
            self.decoder_fu.append(
                FU(C_de[idx], C_de[idx - 1], O=(out_pads[idx - 1], 0)),
            )
            self.decoder_bn.append(
                nn.Sequential(
                    TFCM(C_de[idx - 1], (3, 3),
                         tfcm_layer=tfcm_layer, causal=causal),
                    ASA(C_de[idx - 1], causal=causal),
                )
            )
        # MEA is causal, so mag_t_dim = 1.
        self.mag_mask = nn.Conv2d(
            4, mag_f_dim, kernel_size=(3, 1), padding=(1, 0))
        self.real_mask = nn.Conv2d(4, 1, kernel_size=(3, 1), padding=(1, 0))
        self.imag_mask = nn.Conv2d(4, 1, kernel_size=(3, 1), padding=(1, 0))
        # Identity kernel: convolving the padded magnitude with it copies the
        # mag_f_dim neighbouring bins of every bin into separate channels.
        kernel = th.eye(mag_f_dim)
        kernel = kernel.reshape(mag_f_dim, 1, mag_f_dim, 1)
        self.register_buffer('kernel', kernel)
        self.mag_f_dim = mag_f_dim

    def forward(self, sigs):
        """Enhance the first signal of ``sigs``.

        Args:
            sigs: list of ``n_sig`` waveforms, each [B, N].

        Returns:
            ``(mag, cspec, wav)``, where:
              * ``mag`` is the stage-1 masked magnitude, [B, F, T];
              * ``cspec`` is the final spectrum stacked as
                [B, 2 (real, imag), F, T];
              * ``wav`` is its iSTFT.

        Raises:
            ValueError: (from the phase encoder) if ``len(sigs) != n_sig``.
        """
        cspecs = [self.stft.transform(sig) for sig in sigs]
        # The masks are applied to the first signal (the microphone signal
        # in test_nnet); its magnitude and phase are kept here.
        D_cspec = cspecs[0]
        mag = th.norm(D_cspec, dim=1)
        pha = th.atan2(D_cspec[:, -1, ...], D_cspec[:, 0, ...])
        out = self.ERB.amp2bank(self.PE(cspecs))
        encoder_out = []
        for fd_layer, bn_layer in zip(self.encoder_fd, self.encoder_bn):
            out = fd_layer(out)
            encoder_out.append(out)
            out = bn_layer(out)

        for block in self.bottleneck:
            out = block(out)

        # Skips are consumed in reverse order: the deepest encoder output first.
        for fu_layer, bn_layer, skip in zip(self.decoder_fu, self.decoder_bn,
                                            reversed(encoder_out)):
            out = fu_layer(out, skip)
            out = bn_layer(out)
        out = self.ERB.bank2amp(out)
        # Stage 1: each output bin is a sigmoid-weighted sum of the input
        # magnitudes of its mag_f_dim neighbouring bins.
        mag_mask = self.mag_mask(out)
        half = (self.mag_f_dim - 1) // 2
        mag_pad = tf.pad(mag[:, None], [0, 0, half, half])
        mag = tf.conv2d(mag_pad, self.kernel)
        mag = mag * mag_mask.sigmoid()
        mag = mag.sum(dim=1)
        # Stage 2: a complex mask; its tanh-bounded magnitude scales the
        # stage-1 magnitude and its angle rotates the noisy phase.
        real_mask = self.real_mask(out).squeeze(1)
        imag_mask = self.imag_mask(out).squeeze(1)

        cmask_mag = th.sqrt(th.clamp(real_mask**2 + imag_mask**2, eps))
        cmask_pha = th.atan2(imag_mask + eps, real_mask + eps)
        real = mag * cmask_mag.tanh() * th.cos(pha + cmask_pha)
        imag = mag * cmask_mag.tanh() * th.sin(pha + cmask_pha)
        return mag, th.stack([real, imag], dim=1), self.stft.inverse(real, imag)


def test_nnet():
    # noise supression (microphone, )
    nnet = MTFAANet(n_sig=1)
    inp = th.randn(3, 48000)
    mag, cspec, wav = nnet([inp])
    print(mag.shape, cspec.shape, wav.shape)
    # echo cancellation (microphone, error, reference,)
    nnet = MTFAANet(n_sig=3)
    mag, cspec, wav = nnet([inp, inp, inp])
    print(mag.shape, cspec.shape, wav.shape)


def test_mac():
    from thop import profile, clever_format
    nnet = MTFAANet(n_sig=3)
    # hop=8ms, win=32ms@48KHz, process 1s.
    inp = th.randn(1, 48000)
    macs, params = profile(nnet, inputs=([inp, inp, inp],), verbose=False)
    macs, params = clever_format([macs, params], "%.3f")
    print('macs: ', macs)
    print('params: ', params)


if __name__ == "__main__":
    # test_nnet()
    test_mac()

# -------------------------------------------------------------------------------- /phase_encoder.py: --------------------------------------------------------------------------------
"""
Phase Encoder (PE).

shmzhang@aslp-npu.org, 2022
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class ComplexConv2d(nn.Module):
    """Complex 2-D convolution built from two real convolutions.

    The input holds its real and imaginary parts concatenated along
    ``complex_axis`` (dim 1 by default). ``in_channels`` and ``out_channels``
    count both halves, so each real conv maps ``in_channels // 2`` to
    ``out_channels // 2`` channels.

    For ``complex_axis != 0`` the output is
    ``[real_conv(re) - imag_conv(im), imag_conv(re) + real_conv(im)]``,
    concatenated along ``complex_axis``.

    Padding is split between the convs and ``forward``:
      * frequency padding (``padding[0]``) is done by the convs;
      * time padding (``padding[1]``) is done in ``forward``, on the left
        only when ``causal`` and on both sides otherwise.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        dilation=1,
        groups=1,
        causal=True,
        complex_axis=1,
    ):
        super(ComplexConv2d, self).__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.causal = causal
        self.groups = groups
        self.dilation = dilation
        self.complex_axis = complex_axis
        self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                   padding=[self.padding[0], 0], dilation=self.dilation,
                                   groups=self.groups)
        self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride,
                                   padding=[self.padding[0], 0], dilation=self.dilation,
                                   groups=self.groups)

        nn.init.normal_(self.real_conv.weight.data, std=0.05)
        nn.init.normal_(self.imag_conv.weight.data, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):
        if self.padding[1] != 0 and self.causal:
            inputs = F.pad(inputs, [self.padding[1], 0, 0, 0])
        else:
            inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0])

        if self.complex_axis == 0:
            # NOTE(review): kept as-is. It splits the conv *outputs* along
            # dim 0, presumably for real/imag stacked on the batch axis. No
            # caller in this file sets complex_axis=0; verify before use.
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)
        else:
            # F.pad above already required a Tensor, so split it directly.
            real, imag = torch.chunk(inputs, 2, self.complex_axis)

            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)

            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        real = real2real - imag2imag
        imag = real2imag + imag2real
        return torch.cat([real, imag], self.complex_axis)


def complex_cat(inps, dim=1):
    """Split each tensor into (real, imag) halves along ``dim``.

    Returns the concatenation of all real halves and the concatenation of all
    imaginary halves.
    """
    reals, imags = [], []
    for inp in inps:
        real, imag = inp.chunk(2, dim)
        reals.append(real)
        imags.append(imag)
    return torch.cat(reals, dim), torch.cat(imags, dim)


class ComplexLinearProjection(nn.Module):
    """Complex 1x1 projection followed by the magnitude of the result."""

    def __init__(self, cin):
        super(ComplexLinearProjection, self).__init__()
        self.clp = ComplexConv2d(cin, cin)

    def forward(self, real, imag):
        """
        real, imag: B C F T
        returns: sqrt(real'^2 + imag'^2 + 1e-8) of the projected spectrum, B C F T
        """
        inputs = torch.cat([real, imag], 1)
        outputs = self.clp(inputs)
        real, imag = outputs.chunk(2, dim=1)
        return torch.sqrt(real**2 + imag**2 + 1e-8)


class PhaseEncoder(nn.Module):
    """Phase encoder for ``n_sig`` complex spectra.

    The steps are:
      1. Each signal goes through its own complex conv with kernel (1, 3)
         along time, left-padded by 2 frames.
      2. The outputs of all signals are concatenated.
      3. A complex 1x1 projection is applied.
      4. The result is reduced to ``magnitude ** alpha``.
    """

    def __init__(self, cout, n_sig, cin=2, alpha=0.5):
        super(PhaseEncoder, self).__init__()
        self.complexnn = nn.ModuleList()
        for _ in range(n_sig):
            self.complexnn.append(
                nn.Sequential(
                    nn.ConstantPad2d((2, 0, 0, 0), 0.0),
                    ComplexConv2d(cin, cout, (1, 3))
                )
            )
        self.clp = ComplexLinearProjection(cout * n_sig)
        self.alpha = alpha

    def forward(self, cspecs):
        """
        cspecs: list of n_sig spectra, each B x 2 (real, imag) x F x T.
        returns: compressed magnitude, B x (cout // 2 * n_sig) x F x T.
        Raises ValueError when len(cspecs) != n_sig; previously extra signals
        were silently ignored and missing ones raised a bare IndexError.
        """
        if len(cspecs) != len(self.complexnn):
            raise ValueError(
                "PhaseEncoder expects %d spectra (n_sig), got %d"
                % (len(self.complexnn), len(cspecs)))
        outs = [layer(cspec) for layer, cspec in zip(self.complexnn, cspecs)]
        real, imag = complex_cat(outs, dim=1)
        amp = self.clp(real, imag)
        return amp**self.alpha


if __name__ == "__main__":
    net = PhaseEncoder(cout=4, n_sig=1)
    # 32ms@48kHz, concatenation of [real, imag], dim=1
    inps = torch.randn(3, 2, 769, 126)
    outs = net([inps])
    print(outs.shape)

# -------------------------------------------------------------------------------- /readme.md: --------------------------------------------------------------------------------
# # MTFAA-Net
#
# Unofficial PyTorch implementation of Baidu's MTFAA-Net: "[Multi-Scale Temporal Frequency Convolutional Network With Axial Attention for Speech Enhancement](https://ieeexplore.ieee.org/document/9746610)".
#
# ## some whls
# ```shell
# # under your python env.
# pip install einops
# pip install spafe
# ```
#
# ## some bugs?
#
# Implementation details may not be consistent with the paper, and any comments are welcome (shmzhang@npu-aslp.org)

# -------------------------------------------------------------------------------- /stft.py: --------------------------------------------------------------------------------
# A simple wrapper for torch built-in STFT.
# shmzhang@aslp-npu.org, 2022

import torch as th
import torch.nn as nn


class STFT(nn.Module):
    """Thin wrapper around ``torch.stft`` / ``torch.istft``.

    ``transform`` returns the spectrum as a real tensor
    [B, 2 (real, imag), F, T]. ``inverse`` takes the real and imaginary parts
    ([B, F, T] each) and returns the waveform.

    PyTorch is called with complex tensors. ``return_complex=False`` is
    deprecated in ``torch.stft``, and recent releases require a complex
    input for ``torch.istft``, so the previous real ``(..., 2)`` input
    to ``istft`` fails there.
    """

    _WINDOWS = {"hann": th.hann_window, "hamm": th.hamming_window}

    def __init__(self, win_len, hop_len, fft_len, win_type):
        super(STFT, self).__init__()
        self.win, self.hop = win_len, hop_len
        self.nfft = fft_len
        if win_type not in self._WINDOWS:
            raise ValueError("unsupported win_type %r, expected one of %s"
                             % (win_type, sorted(self._WINDOWS)))
        # Kept as a plain attribute (not a buffer) so state_dict is unchanged;
        # it is moved to the input's device/dtype on every call.
        self.window = self._WINDOWS[win_type](win_len)

    def transform(self, inp):
        """
        inp: B N
        returns: B 2 F T (real, imag stacked on dim 1)
        """
        window = self.window.to(device=inp.device, dtype=inp.dtype)
        spec = th.stft(inp, self.nfft, hop_length=self.hop, win_length=self.win,
                       window=window, return_complex=True)
        return th.view_as_real(spec).permute(0, 3, 1, 2)

    def inverse(self, real, imag, length=None):
        """
        real, imag: B F T
        length: optional output length in samples (forwarded to torch.istft)
        returns: B N
        """
        window = self.window.to(device=real.device, dtype=real.dtype)
        spec = th.complex(real, imag)
        return th.istft(spec, self.nfft, hop_length=self.hop, win_length=self.win,
                        window=window, length=length)

# -------------------------------------------------------------------------------- /tfcm.py: --------------------------------------------------------------------------------
"""
TCN modules (TCM) -> TFCN modules (TFCM).

shmzhang@aslp-npu.org, 2022
"""

import torch as th
import torch.nn as nn


class TFCM_Block(nn.Module):
    """Residual block: 1x1 conv -> depthwise dilated 2-D conv -> 1x1 conv, plus input.

    Args:
        cin: number of channels (unchanged by the block).
        K: (freq, time) kernel size of the depthwise conv.
        dila: dilation along time.
        causal: pad time on the left only, so no output frame depends on
            future frames; otherwise split the time padding over both sides.

    The padding always preserves F and T, which the residual add requires.
    This holds for even kernel sizes and for frequency kernels other than 3;
    the old fixed (1, 1) / ``dila_pad // 2`` padding changed the shape
    there. For the default K=(3, 3) the padding is exactly the same as
    before.
    """

    def __init__(self,
                 cin=24,
                 K=(3, 3),
                 dila=1,
                 causal=True,
                 ):
        super(TFCM_Block, self).__init__()
        self.pconv1 = nn.Sequential(
            nn.Conv2d(cin, cin, kernel_size=(1, 1)),
            nn.BatchNorm2d(cin),
            nn.PReLU(cin),
        )
        # Total time padding that keeps T for a dilated kernel of size K[1].
        dila_pad = dila * (K[1] - 1)
        if causal:
            t_left, t_right = dila_pad, 0
        else:
            # update 22/06/21, add groups for non-casual
            t_left = dila_pad // 2
            t_right = dila_pad - t_left
        f_top = (K[0] - 1) // 2
        f_bottom = K[0] - 1 - f_top
        self.dila_conv = nn.Sequential(
            # ConstantPad2d order: (T left, T right, F top, F bottom).
            nn.ConstantPad2d((t_left, t_right, f_top, f_bottom), 0.0),
            nn.Conv2d(cin, cin, K, 1, dilation=(1, dila), groups=cin),
            nn.BatchNorm2d(cin),
            nn.PReLU(cin),
        )
        self.pconv2 = nn.Conv2d(cin, cin, kernel_size=(1, 1))
        self.causal = causal
        self.dila_pad = dila_pad

    def forward(self, inps):
        """
        inp: B x C x F x T
        """
        outs = self.pconv1(inps)
        outs = self.dila_conv(outs)
        outs = self.pconv2(outs)
        return outs + inps


class TFCM(nn.Module):
    """Stack of ``tfcm_layer`` TFCM blocks with time dilations 1, 2, 4, ..."""

    def __init__(self,
                 cin=24,
                 K=(3, 3),
                 tfcm_layer=6,
                 causal=True,
                 ):
        super(TFCM, self).__init__()
        self.tfcm = nn.ModuleList(
            TFCM_Block(cin, K, 2 ** idx, causal=causal) for idx in range(tfcm_layer)
        )

    def forward(self, inp):
        out = inp
        for block in self.tfcm:
            out = block(out)
        return out


def test_tfcm():
    nnet = TFCM(24)
    inp = th.randn(2, 24, 256, 101)
    out = nnet(inp)
    print(out.shape)


if __name__ == "__main__":
    test_tfcm()