├── .gitignore ├── LICENSE ├── README.md ├── flowchart.png ├── model.py ├── requirements.txt └── utility └── separator.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 JonathanDZ 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Time-Frequency Domain Filter-and-Sum Network for Multi-channel Speech Separation 2 | 3 | This repository contains the model implementation for the paper titled "Time-Frequency Domain Filter-and-Sum Network for Multi-channel Speech Separation." Our paper proposes a new approach to multi-channel speech separation, building upon the implicit Filter-and-Sum Network (iFaSNet). We achieve this by converting each module of the iFaSNet architecture to perform separation in the time-frequency domain. Our experimental results indicate that our method outperforms iFaSNet under the considered conditions. 4 | 5 | # Model 6 | 7 | We implement the Time-Frequency Domain Filter-and-Sum Network (TF-FaSNet) based on iFaSNet's overall structure. The network performs multi-channel speech separation in the time-frequency domain. Refer to the original paper for more information. 8 | 9 | We propose the following improvements to enhance the separation performance of the iFaSNet model: 10 | 11 | - Use a multi-path separation module for spectral mapping in the T-F domain 12 | - Add a 2D positional encoding to help the attention module learn spectro-temporal information 13 | - Use narrow-band feature extraction to exploit inter-channel cues of different speakers 14 | - Add a convolution module at the end of the separation module to capture local interactions and features 15 | 16 | The following flowchart depicts the TF-FaSNet model. 17 | 18 |

19 | ![TF-FaSNet flowchart](flowchart.png) 20 |

21 | 22 | # Usage 23 | 24 | A minimal implementation of the TF-FaSNet model can be found in `model.py`. 25 | 26 | ## Requirements 27 | 28 | - torch==1.13.1 29 | - torchaudio==0.13.1 30 | - positional-encodings==6.0.1 31 | 32 | ## Dataset 33 | 34 | The model is evaluated on a simulated 6-mic circular array dataset. The data generation script is available [here](https://github.com/yluo42/TAC/tree/master/data). 35 | 36 | ## Model configurations 37 | 38 | To use our model: 39 | ```python 40 | mix_audio = torch.randn(3,6,64000) 41 | test_model = make_TF_FaSNet( 42 | nmic=6, nspk=2, n_fft=256, embed_dim=16, 43 | dim_nb=32, dim_ffn=64, n_conv_layers=2, 44 | B=4, I=8, J=1, H=128, L=4 45 | ) 46 | separated_audio = test_model(mix_audio) 47 | ``` 48 | Each parameter stands for: 49 | 50 | - General config: 51 | - `nmic`: Number of microphones 52 | - `nspk`: Number of speakers 53 | - `n_fft`: Number of FFT points 54 | - `embed_dim`: Embedding dimension for each T-F unit 55 | - Encoder-decoder: 56 | - `dim_nb`: Number of hidden units in the narrow-band feature extraction module 57 | - `dim_ffn`: Number of hidden units between the two linear layers in the context decoding module 58 | - `n_conv_layers`: Number of convolution blocks in the context decoding module 59 | - Multi-path separation module: 60 | - `B`: Number of multi-path blocks 61 | - `I`: Kernel size for Unfold and Deconv 62 | - `J`: Stride size for Unfold and Deconv 63 | - `H`: Number of hidden units in the BLSTM 64 | - `L`: Number of heads in self-attention 65 | 66 | With these configurations, we achieve an average 15.5 dB SI-SNR improvement on the simulated 6-mic circular-array dataset with a model size of 2.5M parameters. 67 | 68 | # Miscellaneous 69 | 70 | Given a $D \times T \times F$ tensor, we apply 2D positional encoding as follows: 71 | ```math 72 | \begin{align*}PE(t,f,2i) &= \sin(t/10000^{4i/D})\\PE(t,f,2i+1) &= \cos(t/10000^{4i/D})\\PE(t,f,2j+D/2) &= \sin(f/10000^{4j/D})\\PE(t,f,2j+1+D/2) &= \cos(f/10000^{4j/D})\end{align*} 73 | ``` 74 | where $t$ indexes the $T$ frames, $f$ indexes the $F$ frequencies, and $i, j \in [0, D/4)$ index the channel dimension $D$.
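75 | 
76 | For reference, the snippet below is a minimal, self-contained sketch of this 2D positional encoding. It is for illustration only: the helper name `positional_encoding_2d` and the explicit loop are not part of this repository, which instead obtains its 2D positional encoding from `PositionalEncoding2D` in the `positional-encodings` package (see `utility/separator.py`).
77 | 
78 | ```python
79 | import torch
80 | 
81 | def positional_encoding_2d(D: int, T: int, F: int) -> torch.Tensor:
82 |     """Illustrative sketch of the 2D positional encoding defined above; returns a (D, T, F) tensor."""
83 |     assert D % 4 == 0, "D must be divisible by 4"
84 |     pe = torch.zeros(D, T, F)
85 |     t = torch.arange(T, dtype=torch.float32).unsqueeze(1)  # frame indices, shape (T, 1)
86 |     f = torch.arange(F, dtype=torch.float32).unsqueeze(0)  # frequency indices, shape (1, F)
87 |     for k in range(D // 4):               # i, j in [0, D/4)
88 |         div = 10000.0 ** (4 * k / D)      # 10000^{4i/D}
89 |         pe[2 * k] = torch.sin(t / div).expand(T, F)               # time term, even channels
90 |         pe[2 * k + 1] = torch.cos(t / div).expand(T, F)           # time term, odd channels
91 |         pe[2 * k + D // 2] = torch.sin(f / div).expand(T, F)      # frequency term, even channels
92 |         pe[2 * k + 1 + D // 2] = torch.cos(f / div).expand(T, F)  # frequency term, odd channels
93 |     return pe
94 | 
95 | # Example: pe = positional_encoding_2d(D=16, T=1001, F=129) gives a (16, 1001, 129) encoding
96 | # that can be added to a feature tensor of the same shape.
97 | ```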
-------------------------------------------------------------------------------- /flowchart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JonathanDZ/TF-FaSNet/843a456aaabadfeeec52610c058fd22f8fa90ef2/flowchart.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | 5 | from utility.separator import Separator, SeparatorLayer, RNNModule, PositionalMultiHeadAttention 6 | 7 | 8 | class EncoderSeparatorDecoder(nn.Module): 9 | """ 10 | A standard encoder-separator-decoder speech separation architecture 11 | """ 12 | 13 | def __init__(self, encoder, decoder, separator): 14 | super(EncoderSeparatorDecoder, self).__init__() 15 | self.encoder = encoder 16 | self.decoder = decoder 17 | self.separator = separator 18 | 19 | def forward(self, mix_audio): 20 | """ 21 | Input: 22 | mix_audio: Batch, nmic, T 23 | Output: 24 | separated_audio: Batch, nspk, T 25 | """ 26 | enc_output, XrMM, ref_enc = self.encode(mix_audio) 27 | sep_output = self.separate(enc_output) 28 | dec_output = self.decode(sep_output, XrMM, ref_enc) 29 | return dec_output 30 | 31 | def encode(self, mix_audio): 32 | """ 33 | Besides returning enc_output and ref_enc, the encoder also returns XrMM 34 | (Xr_magnitude_mean: mean of the magnitude of the reference channel of X) 35 | so that the decoder can properly perform inverse normalization. 36 | """ 37 | return self.encoder(mix_audio) 38 | 39 | def separate(self, enc_output): 40 | return self.separator(enc_output) 41 | 42 | def decode(self, sep_output, XrMM, ref_enc): 43 | """ 44 | The decoder receives XrMM to perform inverse normalization. 45 | """ 46 | return self.decoder(sep_output, XrMM, ref_enc) 47 | 48 | 49 | # Encoder 50 | class Encoder(nn.Module): 51 | """ 52 | STFT -> Normalization |-> view as Narrow-band -> BLSTM |-> concat -> conv2d -> gLN 53 | |-> take ref channel -> conv2d | 54 | """ 55 | 56 | def __init__(self, n_fft=256, embed_dim=32, nmic=6, dim_nb=64): 57 | """ 58 | The sampling rate is 16 kHz. The STFT window size and hop size are 16 ms and 4 ms, respectively.
59 | n_fft = 16*16 = 256, hop_length = 4*16 = 64 60 | """ 61 | super(Encoder, self).__init__() 62 | self.n_fft = n_fft 63 | self.window = torch.hann_window(n_fft) 64 | F = n_fft//2 + 1 65 | 66 | # For ref channel encoding 67 | self.conv2d_1 = nn.Conv2d(in_channels=2, out_channels=embed_dim, kernel_size=3, padding=1) 68 | 69 | # For narrow-band encoding 70 | # dim_nb = 2*embed_dim 71 | self.embed_dim = embed_dim 72 | self.conv1d = nn.Conv1d(nmic*2, dim_nb, 4) 73 | self.layernorm = nn.LayerNorm(dim_nb) 74 | self.rnn1 = nn.LSTM(input_size=dim_nb, hidden_size=dim_nb, batch_first=True, bidirectional=True) 75 | self.linear1 = nn.Linear(2*dim_nb, dim_nb) 76 | self.deconv1d = nn.ConvTranspose1d(dim_nb, embed_dim, 4) 77 | 78 | # Final gLN 79 | self.conv2d_2 = nn.Conv2d(in_channels=2*embed_dim, out_channels=embed_dim, kernel_size=3, padding=1) 80 | self.gLN = nn.GroupNorm(1, embed_dim, eps=1e-8) 81 | 82 | def forward(self, mix_audio): 83 | """ 84 | mix_audio: batch, nmic, T 85 | """ 86 | batch, nmic, T = mix_audio.shape 87 | output = mix_audio.view(-1, T) # batch*nmic, T 88 | output = torch.stft(output, self.n_fft, 89 | window=self.window, 90 | return_complex=True) # batch*nmic, n_fft/2 + 1, T' 91 | output = output.view(batch, nmic, output.shape[-2], output.shape[-1]) # batch, nmic, n_fft/2 + 1, T' 92 | output = output.permute(0,2,3,1).contiguous() # batch, n_fft/2 + 1, T', nmic 93 | 94 | # Normalization by using reference channel 95 | F, TF = output.shape[1], output.shape[2] 96 | ref_channel = 0 97 | Xr = output[... , ref_channel].clone() # Take a ref channel copy 98 | XrMM = torch.abs(Xr).mean(dim=2) # Xr_magnitude_mean: mean of the magnitude of the ref channel of X 99 | output[:, :, :, :] /= (XrMM.reshape(batch, F, 1, 1) + 1e-8) 100 | 101 | # View as real 102 | output = torch.view_as_real(output) # [B, F, TF, C, 2] 103 | 104 | # 1) Get ref channel encoding 105 | ref_enc = output[:,:,:,ref_channel,:].clone().permute(0,3,2,1).contiguous() # batch, 2, TF, n_fft/2 + 1 106 | ref_enc = self.conv2d_1(ref_enc) # batch, embed_dim, TF, n_fft/2 + 1 107 | 108 | # 2) Get all-channel narrow-band encoding 109 | nb_enc_input = output.view(batch*F, TF, nmic*2).transpose(1,2).contiguous() # batch*n_fft/2+1, nmic*2, TF 110 | nb_enc_input = self.conv1d(nb_enc_input) # batch*nfft/2+1, dim_nb, TF 111 | nb_enc = nb_enc_input.transpose(1,2).contiguous() # batch*nfft/2+1, TF, dim_nb 112 | nb_enc = self.layernorm(nb_enc) 113 | 114 | nb_enc, _ = self.rnn1(nb_enc) # batch*nfft/2+1, TF, dim_nb*2 115 | nb_enc = self.linear1(nb_enc) # batch*nfft/2+1, TF, dim_nb 116 | 117 | nb_enc = nb_enc.transpose(1,2).contiguous() # batch*nfft/2+1, dim_nb, TF 118 | nb_enc = nb_enc + nb_enc_input 119 | nb_enc = self.deconv1d(nb_enc) # batch*nfft/2+1, embed_dim, TF 120 | nb_enc = nb_enc.view(batch, F, self.embed_dim, TF).permute(0,2,3,1).contiguous() # batch, embed_dim, TF, n_fft/2+1 121 | 122 | # 3) Concat two encodings to get an iFaSNet-like encoding 123 | all_enc = torch.cat([ref_enc, nb_enc], 1) # batch, 2*embed_dim, TF, n_fft/2+1 124 | all_enc = self.conv2d_2(all_enc) 125 | all_enc = self.gLN(all_enc) 126 | 127 | return all_enc, XrMM, ref_enc 128 | 129 | 130 | # Decoder 131 | class Decoder(nn.Module): 132 | "Deconv2d -> linear -> view as full-band -> inverse normalization -> iSTFT" 133 | 134 | def __init__(self, n_fft=256, embed_dim=32, nspk=2, nmic=6, n_conv_layers=2, dropout=0.1, dim_ffn=128): 135 | super(Decoder, self).__init__() 136 | # For context decoding 137 | # dim_ffn = 4*embed_dim 138 | self.conv2d_in = 
nn.Conv2d(in_channels=2*embed_dim, out_channels=2*embed_dim, kernel_size=3, padding=1) 139 | self.gLN = nn.GroupNorm(1, 2*embed_dim, eps=1e-8) 140 | self.linear1 = nn.Linear(2*embed_dim, dim_ffn) 141 | self.activation = nn.functional.silu 142 | 143 | convs = [] 144 | for l in range(n_conv_layers): 145 | convs.append(nn.Conv2d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=3, padding='same', groups=dim_ffn, bias=True)) 146 | convs.append(nn.GroupNorm(4, dim_ffn, eps=1e-8)) 147 | convs.append(nn.SiLU()) 148 | self.conv = nn.Sequential(*convs) 149 | 150 | # self.dropout1 = nn.Dropout(dropout) 151 | self.linear2 = nn.Linear(dim_ffn, 2*embed_dim) 152 | self.dropout2 = nn.Dropout(dropout) 153 | 154 | # Decode 155 | self.nspk = nspk 156 | self.deconv2d = nn.ConvTranspose2d(in_channels=2*embed_dim, out_channels=2*nspk, kernel_size=3, padding=1) 157 | self.linear = nn.Linear(in_features=2*nspk, out_features=2*nspk) 158 | self.n_fft = n_fft 159 | self.window = torch.hann_window(n_fft) 160 | 161 | def forward(self, x, XrMM, ref_enc): 162 | """ 163 | x: batch, D, T, F 164 | ref_enc: batch, D, T, F 165 | """ 166 | batch, _, T, F = x.shape 167 | 168 | # Add a decode process, which utilizes ref_enc info to potentially enhance its performance 169 | embedding_input = torch.cat([ref_enc, x], 1) # batch, 2*D, T, F 170 | embedding_input = self.conv2d_in(embedding_input) # batch, 2*D, T, F 171 | embedding = self.gLN(embedding_input) # batch, 2*D, T, F 172 | embedding = self._ff_block(embedding) # batch, 2*D, T, F 173 | embedding = embedding + embedding_input 174 | 175 | output = self.deconv2d(embedding) # batch, 2*nspk, T, F 176 | output = output.permute(0,3,2,1).contiguous() # batch, F, T, 2*nspk 177 | output = self.linear(output) 178 | 179 | # To complex 180 | output = output.view(batch, F, T, self.nspk, 2) # batch, F, T, nspk, 2 181 | output = torch.view_as_complex(output) # batch, F, T, nspk 182 | 183 | # Inverse normalization 184 | Ys_hat = torch.empty(size=(batch, self.nspk, F, T), dtype=torch.complex64, device=output.device) 185 | XrMM = torch.unsqueeze(XrMM, dim=2).expand(-1, -1, T) 186 | for spk in range(self.nspk): 187 | Ys_hat[:, spk, :, :] = output[:, :, :, spk] * XrMM[:, :, :] 188 | 189 | # iSTFT with frequency binding 190 | ys_hat = torch.istft(Ys_hat.view(batch * self.nspk, F, T), n_fft=self.n_fft, window=self.window, win_length=self.n_fft) 191 | ys_hat = ys_hat.view(batch, self.nspk, -1) 192 | return ys_hat 193 | 194 | # Feed forward block 195 | def _ff_block(self, x: torch.Tensor) -> torch.Tensor: 196 | "x: B, 2*D, T, F" 197 | x = x.transpose(1,3).contiguous() # B, F, T, 2*D 198 | x = self.linear1(x) # B, F, T, 2*D 199 | x = self.activation(x) 200 | x = x.transpose(1,3).contiguous() # B, 2*D, T, F 201 | x = self.conv(x) 202 | x = x.transpose(1,3).contiguous() # B, F, T, 2*D 203 | # x = self.dropout1(x) 204 | x = self.linear2(x) 205 | x = x.transpose(1,3).contiguous() # B, 2*D, T, F 206 | return self.dropout2(x) 207 | 208 | 209 | def make_TF_FaSNet(nmic=6, nspk=2, n_fft=256, embed_dim=16, dim_nb=32, dim_ffn=64, n_conv_layers=2, B=4, I=8, J=1, H=128, L=4): 210 | "Helper: Construct TF-FaSNet model from hyperparameters" 211 | F = n_fft//2 + 1 212 | E = embed_dim//L 213 | 214 | c = copy.deepcopy 215 | RNN_module = RNNModule(hidden_size=H, kernel_size=I, stride=J, embed_dim=embed_dim) 216 | self_attn = PositionalMultiHeadAttention(h=L, d_model=embed_dim, d_q=E, F=F) 217 | model = EncoderSeparatorDecoder( 218 | encoder=Encoder(n_fft=n_fft, embed_dim=embed_dim, nmic=nmic, dim_nb=dim_nb), 219 | 
decoder=Decoder(n_fft=n_fft, embed_dim=embed_dim, nspk=nspk, n_conv_layers=n_conv_layers, dim_ffn=dim_ffn), 220 | separator=Separator(SeparatorLayer(c(RNN_module), c(RNN_module), c(self_attn)), N=B) 221 | ) 222 | 223 | # Initialize parameters with Glorot / fan_avg 224 | for p in model.parameters(): 225 | if p.dim() > 1: 226 | nn.init.xavier_uniform_(p) 227 | return model 228 | 229 | def check_parameters(net): 230 | ''' 231 | Returns the number of model parameters, in millions. 232 | ''' 233 | parameters = sum(param.numel() for param in net.parameters()) 234 | return parameters / 10**6 235 | 236 | if __name__ == "__main__": 237 | 238 | test_model = make_TF_FaSNet() 239 | test_model.eval() 240 | 241 | # Check model size 242 | print(check_parameters(test_model)) 243 | 244 | # Test full model 245 | mix_audio = torch.randn(3,6,64000) 246 | separated_audio = test_model(mix_audio) 247 | print(separated_audio.shape) 248 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | positional-encodings==6.0.1 2 | torch==1.13.1 3 | torchaudio==0.13.1 -------------------------------------------------------------------------------- /utility/separator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import copy 5 | from positional_encodings.torch_encodings import PositionalEncoding2D, Summer 6 | 7 | 8 | # Separator 9 | def clones(module, N): 10 | "Produce N identical layers" 11 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 12 | 13 | class Separator(nn.Module): 14 | "Separator is a stack of N layers" 15 | 16 | def __init__(self, layer, N=6): 17 | super(Separator, self).__init__() 18 | self.layers = clones(layer, N) 19 | self.N = N 20 | 21 | def forward(self, x): 22 | "Pass the input through each layer in turn" 23 | for layer in self.layers: 24 | x = layer(x) 25 | return x 26 | 27 | 28 | # Multi-path architecture 29 | class SeparatorLayer(nn.Module): 30 | "A SeparatorLayer is made up of a spectral module, a temporal module and a self-attention module (defined below)" 31 | 32 | def __init__(self, RNN_module_s, RNN_module_t, self_attn): 33 | super(SeparatorLayer, self).__init__() 34 | self.spectral = RNN_module_s 35 | self.temporal = RNN_module_t 36 | self.self_attn = self_attn 37 | 38 | def forward(self, x): 39 | """ 40 | x: batch, D, T, F 41 | """ 42 | output = self.spectral(x) # batch, D, T, F 43 | 44 | output = output.transpose(2,3).contiguous() 45 | output = self.temporal(output) # batch, D, F, T 46 | 47 | output = output.transpose(2,3).contiguous() 48 | output = self.self_attn(output) # batch, D, T, F 49 | 50 | return output 51 | 52 | 53 | # Spectral & temporal module 54 | class RNNModule(nn.Module): 55 | "Unfold -> LN -> BLSTM -> Deconv1D -> residual" 56 | 57 | def __init__(self, hidden_size=256, kernel_size=8, stride=1, embed_dim=32, dropout=0, bidirectional=True): 58 | super(RNNModule, self).__init__() 59 | self.stride = stride 60 | self.kernel_size = kernel_size 61 | 62 | self.layernorm = nn.LayerNorm(embed_dim*kernel_size) 63 | self.rnn = nn.LSTM(embed_dim*kernel_size, hidden_size, 1, dropout=dropout, batch_first=True, bidirectional=bidirectional) # N,L,F -> N,L,2H 64 | self.deconv1d = nn.ConvTranspose1d(2*hidden_size, embed_dim, kernel_size, stride=stride) 65 | 66 | 67 | def forward(self, x): 68 | "x: batch, D, dim1, dim2" 69 | batch, embed_dim, dim1, dim2 = x.shape 70 | 71 | output = 
x.unfold(-1, self.kernel_size, self.stride) # batch, D, dim1, dim2/stride, kernel_size 72 | 73 | output = output.permute(0,2,3,4,1).contiguous().view(batch, dim1, -1, self.kernel_size*embed_dim) 74 | output = self.layernorm(output) # batch, dim1, dim2/stride, D*kernel_size 75 | 76 | output = output.view(batch*dim1, -1, self.kernel_size*embed_dim) 77 | output, _ = self.rnn(output) # batch*dim1, dim2/stride, 2*hidden_size 78 | 79 | output = output.contiguous().transpose(1,2).contiguous() # batch*dim1, 2*hidden_size, dim2/stride 80 | output = self.deconv1d(output) # batch*dim1, D, dim2 81 | 82 | output = output.view(batch, dim1, embed_dim, dim2).permute(0,2,1,3).contiguous() 83 | output = output + x 84 | 85 | return output 86 | 87 | 88 | # Pre-processing method to generate final result in Attention module 89 | class Generator(nn.Module): 90 | "1x1 Conv2d -> PReLU -> cfLN" 91 | 92 | def __init__(self, input_dim=32, output_dim=4, F=129): 93 | super(Generator, self).__init__() 94 | self.conv2d = nn.Conv2d(input_dim, output_dim, 1) 95 | self.prelu = nn.PReLU() 96 | self.cfLN = nn.LayerNorm([output_dim, F]) 97 | 98 | def forward(self, x): 99 | "x: batch, embed_dim, T, F" 100 | output = self.conv2d(x) # batch, output_dim, T, F 101 | output = self.prelu(output) # batch, output_dim, T, F 102 | 103 | output = output.transpose(1,2).contiguous() # batch, T, output_dim, F 104 | output = self.cfLN(output) 105 | output = output.transpose(1,2).contiguous() 106 | 107 | return output 108 | 109 | 110 | # Pre-processing method to generate batched qkv 111 | class MultiHeadGenerator(nn.Module): 112 | "1x1 Conv2d -> PReLU -> cfLN" 113 | 114 | def __init__(self, input_dim=32, output_dim=4, F=129, h=4): 115 | super(MultiHeadGenerator, self).__init__() 116 | self.h = h 117 | self.output_dim = output_dim 118 | self.conv2d = nn.Conv2d(input_dim, h*output_dim, 1) 119 | self.prelu = nn.PReLU() 120 | self.cfLN = nn.LayerNorm([output_dim, F]) 121 | 122 | def forward(self, x): 123 | "x: batch, embed_dim, T, F" 124 | batch, _, T, F = x.shape 125 | output = self.conv2d(x) # batch, h*output_dim, T, F 126 | output = self.prelu(output) # batch, h*output_dim, T, F 127 | 128 | output = output.view(batch, self.h, self.output_dim, T, F).transpose(2,3).contiguous() # batch, h, T, output_dim, F 129 | output = self.cfLN(output) 130 | output = output.transpose(3,4).contiguous() # batch, h, T, F, output_dim 131 | output = output.view(batch, self.h, T, F*self.output_dim) 132 | 133 | return output 134 | 135 | 136 | # Dot-product Attention 137 | def attention(query, key, value, mask=None, dropout=None): 138 | """ 139 | Compute 'Scaled Dot Product Attention' 140 | Q: batch, T, FxE 141 | K: batch, T, FxE 142 | V: batch, T, FxD/L 143 | """ 144 | d_k = query.size(-1) 145 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 146 | if mask is not None: 147 | scores = scores.masked_fill(mask == 0, -1e9) 148 | p_attn = scores.softmax(dim=-1) 149 | if dropout is not None: 150 | p_attn = dropout(p_attn) 151 | return torch.matmul(p_attn, value), p_attn 152 | 153 | 154 | # Multi-head Attention with 2D positional embedding (full-band self-attention module) 155 | class PositionalMultiHeadAttention(nn.Module): 156 | """ 157 | Multi-head attention with 2D positional encoding. 158 | This concept (2D PE) was proposed in "Translating 159 | Math Formula Images to LaTeX Sequences Using 160 | Deep Neural Networks with Sequence-level Training". 
161 | """ 162 | 163 | def __init__(self, h=4, d_model=32, d_q=4, F=129, dropout=0.1): 164 | super(PositionalMultiHeadAttention, self).__init__() 165 | 166 | assert d_model % h == 0 167 | self.h = h 168 | self.d_q = d_q 169 | self.d_k = d_q 170 | self.d_v = d_model // h 171 | self.RPE_size = F*d_q # FXE 172 | 173 | self.q_proj = MultiHeadGenerator(d_model, self.d_q, F, h) 174 | self.k_proj = MultiHeadGenerator(d_model, self.d_k, F, h) 175 | self.v_proj = MultiHeadGenerator(d_model, self.d_v, F, h) 176 | self.out_proj = Generator(d_model, d_model, F) 177 | 178 | self.p_enc_2d = Summer(PositionalEncoding2D(F)) 179 | 180 | self.attn = None 181 | self.dropout = nn.Dropout(p=dropout) 182 | 183 | def forward(self, x, mask=None): 184 | "x: batch, D, T, F" 185 | if mask is not None: 186 | # Same mask applied to all h heads 187 | mask = mask.unsqueeze(1) 188 | batch, _, T, F = x.shape 189 | 190 | # Do all the projections in batch 191 | query = self.q_proj(self.p_enc_2d(x)) # B, N, T, FXE 192 | key = self.k_proj(self.p_enc_2d(x)) # B, N, T, FXE 193 | value = self.v_proj(x) # B, N, T, FX(D/L) 194 | 195 | output, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) 196 | 197 | output = output.view(batch, self.h, T, F, self.d_v).permute(0,1,4,2,3).contiguous().view(batch, self.h * self.d_v, T, F) 198 | output = self.out_proj(output) 199 | output = output + x 200 | 201 | del query 202 | del key 203 | del value 204 | 205 | return output 206 | 207 | if __name__ == "__main__": 208 | # Test RNNModule 209 | # rnn_module = RNNModule(hidden_size=256) 210 | # x = torch.rand(3, 32, 1001, 129) 211 | # y = rnn_module(x) 212 | # print(y.shape) 213 | 214 | # Test MultiheadAttention 215 | B, D, T, F = 3, 32, 1001, 129 216 | x = torch.randn(B,D,T,F) 217 | model = PositionalMultiHeadAttention(h=4, d_model=D, d_q=4, F=F) 218 | y = model(x) 219 | print(y.shape) 220 | --------------------------------------------------------------------------------