├── .gitignore
├── LICENSE
├── README.md
├── flowchart.png
├── model.py
├── requirements.txt
└── utility
└── separator.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 JonathanDZ
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Time-Frequency Domain Filter-and-Sum Network for Multi-channel Speech Separation
2 |
3 | This repository contains the model implementation for the paper "Time-Frequency Domain Filter-and-Sum Network for Multi-channel Speech Separation." The paper proposes a new approach to multi-channel speech separation that builds upon the implicit Filter-and-Sum Network (iFaSNet): each module of the iFaSNet architecture is converted to perform separation in the time-frequency domain. Experimental results indicate that the proposed method is superior to iFaSNet under the considered conditions.
4 |
5 | # Model
6 |
7 | We implement the Time-Frequency Domain Filter-and-Sum Network (TF-FaSNet) based on iFaSNet's overall structure. The network performs multi-channel speech separation in the time-frequency domain. Refer to the original paper for more information.
8 |
9 | We propose the following improvements over iFaSNet to enhance its separation performance:
10 | 
11 | - Use a multi-path separation module for spectral mapping in the T-F domain
12 | - Add a 2D positional encoding to help the attention module learn spectro-temporal information
13 | - Use narrow-band feature extraction to exploit inter-channel cues of different speakers
14 | - Add a convolution module at the end of the separation module to capture local interactions and features
15 |
16 | The following flowchart depicts the TF-FaSNet model.
17 |
18 | ![TF-FaSNet flowchart](flowchart.png)
19 | 
20 | 
21 | 
22 | # Usage
23 |
24 | A minimal implementation of the TF-FaSNet model can be found in `model.py`.
25 |
26 | ## Requirements
27 |
28 | - torch==1.13.1
29 | - torchaudio==0.13.1
30 | - positional-encodings==6.0.1
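
All three dependencies are pinned in `requirements.txt`, so they can be installed with `pip install -r requirements.txt`.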
31 |
32 | ## Dataset
33 |
34 | The model is evaluated on a simulated 6-mic circular array dataset. The data generation script is available [here](https://github.com/yluo42/TAC/tree/master/data).
35 |
36 | ## Model configurations
37 |
38 | To use our model (after `import torch` and `from model import make_TF_FaSNet`):
39 | ``` python
40 | mix_audio = torch.randn(3,6,64000)  # batch of 3 six-channel mixtures, 4 s at 16 kHz
41 | test_model = make_TF_FaSNet(
42 | nmic=6, nspk=2, n_fft=256, embed_dim=16,
43 | dim_nb=32, dim_ffn=64, n_conv_layers=2,
44 | B=4, I=8, J=1, H=128, L=4
45 | )
46 | separated_audio = test_model(mix_audio)
47 | ```
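
The output has shape `(batch, nspk, samples)`; for the example above this is `(3, 2, 64000)`. Running `python model.py` executes a small built-in self-test that builds the default model, prints its parameter count (in millions), and separates a random 6-channel mixture.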
48 | Each argument stands for:
49 |
50 | - General config
51 | - `nmic`: Number of microphones
52 | - `nspk`: Number of speakers
53 | - `n_fft`: Number of FFT points
54 | - `embed_dim`: Embedding dimension for each T-F unit
55 | - Encoder-decoder:
56 |   - `dim_nb`: Number of hidden units in the narrow-band feature extraction module
57 |   - `dim_ffn`: Number of hidden units between the two linear layers in the context decoding module
58 | - `n_conv_layers`: Number of convolution blocks in the context decoding module
59 | - Multi-path separation module:
60 | - `B`: Number of multi-path blocks
61 | - `I`: Kernel size for Unfold and Deconv
62 | - `J`: Stride size for Unfold and Deconv
63 | - `H`: Number of hidden units in BLSTM
64 | - `L`: Number of heads in self-attention
65 |
66 | With these configurations, we achieve an average 15.5 dB SI-SNR improvement on the simulated 6-mic circular-array dataset with a model size of 2.5 M parameters.
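
As a quick sanity check of the model size, the following sketch reuses the helpers shipped in `model.py` (it does not reproduce the SI-SNR result, which requires training on the dataset above):

``` python
from model import make_TF_FaSNet, check_parameters

model = make_TF_FaSNet()  # the default arguments match the configuration listed above
print(f"{check_parameters(model):.2f} M parameters")  # expected to be roughly 2.5 M
```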
67 |
68 | # Miscellaneous
69 |
70 | Given a $D \times T \times F$ tensor, we apply 2D positional encoding as follows:
71 | ```math
72 | \begin{align*} PE(t,f,2i) &= \sin(t/10000^{4i/D}) \\ PE(t,f,2i+1) &= \cos(t/10000^{4i/D}) \\ PE(t,f,2j+D/2) &= \sin(f/10000^{4j/D}) \\ PE(t,f,2j+1+D/2) &= \cos(f/10000^{4j/D}) \end{align*}
73 | ```
74 | where $t$ indexes $T$ frames, $f$ indexes $F$ frequencies, and $i, j \in [0, D/4)$ index the feature dimension $D$.
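
For reference, here is a minimal PyTorch sketch of this encoding (assuming `D` is divisible by 4). Note that the repository itself applies 2D positional encoding through the `positional-encodings` package in `utility/separator.py`, whose channel ordering may differ from this sketch.

``` python
import torch

def positional_encoding_2d(D: int, T: int, F: int) -> torch.Tensor:
    """Return the D x T x F positional encoding defined above (assumes D % 4 == 0)."""
    pe = torch.zeros(D, T, F)
    i = torch.arange(D // 4, dtype=torch.float32)             # i, j in [0, D/4)
    div = 10000 ** (4 * i / D)                                 # 10000^{4i/D}
    t = torch.arange(T, dtype=torch.float32).unsqueeze(1)      # frame index, shape (T, 1)
    f = torch.arange(F, dtype=torch.float32).unsqueeze(1)      # frequency index, shape (F, 1)
    # The first D/2 channels encode time and are broadcast over frequency.
    pe[0:D // 2:2] = torch.sin(t / div).t().unsqueeze(-1)      # PE(t, f, 2i)
    pe[1:D // 2:2] = torch.cos(t / div).t().unsqueeze(-1)      # PE(t, f, 2i+1)
    # The last D/2 channels encode frequency and are broadcast over time.
    pe[D // 2::2] = torch.sin(f / div).t().unsqueeze(1)        # PE(t, f, 2j + D/2)
    pe[D // 2 + 1::2] = torch.cos(f / div).t().unsqueeze(1)    # PE(t, f, 2j + 1 + D/2)
    return pe

pe = positional_encoding_2d(D=16, T=1001, F=129)  # e.g. embed_dim=16, 4 s of 16 kHz audio, n_fft=256
print(pe.shape)  # torch.Size([16, 1001, 129])
```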
--------------------------------------------------------------------------------
/flowchart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JonathanDZ/TF-FaSNet/843a456aaabadfeeec52610c058fd22f8fa90ef2/flowchart.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import copy
4 |
5 | from utility.separator import Separator, SeparatorLayer, RNNModule, PositionalMultiHeadAttention
6 |
7 |
8 | class EncoderSeparatorDecoder(nn.Module):
9 | """
10 | A standard Speech Separation architecture
11 | """
12 |
13 | def __init__(self, encoder, decoder, separator):
14 | super(EncoderSeparatorDecoder, self).__init__()
15 | self.encoder = encoder
16 | self.decoder = decoder
17 | self.separator = separator
18 |
19 | def forward(self, mix_audio):
20 | """
21 | Input:
22 | mix_audio: Batch, nmic, T
23 | Output:
24 | separated_audio: Batch, nspk, T
25 | """
26 | enc_output, XrMM, ref_enc = self.encode(mix_audio)
27 | sep_output = self.separate(enc_output)
28 | dec_output = self.decode(sep_output, XrMM, ref_enc)
29 | return dec_output
30 |
31 | def encode(self, mix_audio):
32 | """
33 |         Besides returning enc_output and ref_enc, the encoder also returns XrMM
34 |         (Xr_magnitude_mean: the mean magnitude of the reference channel of X)
35 |         so that the decoder can properly perform inverse normalization.
36 | """
37 | return self.encoder(mix_audio)
38 |
39 | def separate(self, enc_output):
40 | return self.separator(enc_output)
41 |
42 | def decode(self, sep_output, XrMM, ref_enc):
43 | """
44 | Decoder receives XrMM to perform inverse normalization.
45 | """
46 | return self.decoder(sep_output, XrMM, ref_enc)
47 |
48 |
49 | # Encoder
50 | class Encoder(nn.Module):
51 | """
52 | STFT -> Normalization |-> view as Narrow-band -> BLSTM |-> concat -> conv2d -> gLN
53 | |-> take ref channel -> conv2d |
54 | """
55 |
56 | def __init__(self, n_fft=256, embed_dim=32, nmic=6, dim_nb=64):
57 | """
58 | The sampling rate is 16 kHz. The STFT window size and hop size are 16 ms and 4 ms, respectively.
59 |         n_fft = 16*16 = 256, hop_length = 4*16 = 64 (torch.stft's default hop of n_fft//4)
60 | """
61 | super(Encoder, self).__init__()
62 | self.n_fft = n_fft
63 | self.window = torch.hann_window(n_fft)
64 | F = n_fft//2 + 1
65 |
66 | # For ref encode
67 | self.conv2d_1 = nn.Conv2d(in_channels=2, out_channels=embed_dim, kernel_size=3, padding=1)
68 |
69 | # For narrow-band encode
70 | # dim_nb = 2*embed_dim
71 | self.embed_dim = embed_dim
72 | self.conv1d = nn.Conv1d(nmic*2, dim_nb, 4)
73 | self.layernorm = nn.LayerNorm(dim_nb)
74 | self.rnn1 = nn.LSTM(input_size=dim_nb, hidden_size=dim_nb, batch_first=True, bidirectional=True)
75 | self.linear1 = nn.Linear(2*dim_nb, dim_nb)
76 | self.deconv1d = nn.ConvTranspose1d(dim_nb, embed_dim, 4)
77 |
78 | # Final gLN
79 | self.conv2d_2 = nn.Conv2d(in_channels=2*embed_dim, out_channels=embed_dim, kernel_size=3, padding=1)
80 | self.gLN = nn.GroupNorm(1, embed_dim, eps=1e-8)
81 |
82 | def forward(self, mix_audio):
83 | """
84 | mix_audio: batch, nmic, T
85 | """
86 | batch, nmic, T = mix_audio.shape
87 | output = mix_audio.view(-1, T) # batch*nmic, T
88 | output = torch.stft(output, self.n_fft,
89 | window=self.window,
90 | return_complex=True) # batch*nmic, n_fft/2 + 1, T'
91 | output = output.view(batch, nmic, output.shape[-2], output.shape[-1]) # batch, nmic, n_fft/2 + 1, T'
92 | output = output.permute(0,2,3,1).contiguous() # batch, n_fft/2 + 1, T', nmic
93 |
94 | # Normalization by using reference channel
95 | F, TF = output.shape[1], output.shape[2]
96 | ref_channel = 0
97 | Xr = output[... , ref_channel].clone() # Take a ref channel copy
98 | XrMM = torch.abs(Xr).mean(dim=2) # Xr_magnitude_mean: mean of the magnitude of the ref channel of X
99 | output[:, :, :, :] /= (XrMM.reshape(batch, F, 1, 1) + 1e-8)
100 |
101 | # View as real
102 | output = torch.view_as_real(output) # [B, F, TF, C, 2]
103 |
104 | # 1) Get ref channel encoding
105 | ref_enc = output[:,:,:,ref_channel,:].clone().permute(0,3,2,1).contiguous() # batch, 2, TF, n_fft/2 + 1
106 | ref_enc = self.conv2d_1(ref_enc) # batch, embed_dim, TF, n_fft/2 + 1
107 |
108 |         # 2) Get all-channel narrow-band encoding
109 |         nb_enc_input = output.view(batch*F, TF, nmic*2).transpose(1,2).contiguous() # batch*(n_fft/2+1), nmic*2, TF
110 |         nb_enc_input = self.conv1d(nb_enc_input) # kernel 4, no padding: batch*(n_fft/2+1), dim_nb, TF-3
111 |         nb_enc = nb_enc_input.transpose(1,2).contiguous() # batch*(n_fft/2+1), TF-3, dim_nb
112 |         nb_enc = self.layernorm(nb_enc)
113 | 
114 |         nb_enc, _ = self.rnn1(nb_enc) # batch*(n_fft/2+1), TF-3, 2*dim_nb
115 |         nb_enc = self.linear1(nb_enc) # batch*(n_fft/2+1), TF-3, dim_nb
116 | 
117 |         nb_enc = nb_enc.transpose(1,2).contiguous() # batch*(n_fft/2+1), dim_nb, TF-3
118 |         nb_enc = nb_enc + nb_enc_input # residual connection
119 |         nb_enc = self.deconv1d(nb_enc) # deconv (kernel 4) restores the frame count: batch*(n_fft/2+1), embed_dim, TF
120 |         nb_enc = nb_enc.view(batch, F, self.embed_dim, TF).permute(0,2,3,1).contiguous() # batch, embed_dim, TF, n_fft/2+1
121 |
122 | # 3) Concat two encodings to get a ifasnet-like encoding
123 | all_enc = torch.cat([ref_enc, nb_enc], 1) # batch, 2*embed_dim, TF, n_fft/2+1
124 | all_enc = self.conv2d_2(all_enc)
125 | all_enc = self.gLN(all_enc)
126 |
127 | return all_enc, XrMM, ref_enc
128 |
129 |
130 | # Decoder
131 | class Decoder(nn.Module):
132 | "Deconv2d -> linear -> view as full-band -> inverse normalization -> iSTFT"
133 |
134 | def __init__(self, n_fft=256, embed_dim=32, nspk=2, nmic=6, n_conv_layers=2, dropout=0.1, dim_ffn=128):
135 | super(Decoder, self).__init__()
136 | # For context decoding
137 | # dim_ffn = 4*embed_dim
138 | self.conv2d_in = nn.Conv2d(in_channels=2*embed_dim, out_channels=2*embed_dim, kernel_size=3, padding=1)
139 | self.gLN = nn.GroupNorm(1, 2*embed_dim, eps=1e-8)
140 | self.linear1 = nn.Linear(2*embed_dim, dim_ffn)
141 | self.activation = nn.functional.silu
142 |
143 | convs = []
144 |         for _ in range(n_conv_layers):
145 | convs.append(nn.Conv2d(in_channels=dim_ffn, out_channels=dim_ffn, kernel_size=3, padding='same', groups=dim_ffn, bias=True))
146 | convs.append(nn.GroupNorm(4, dim_ffn, eps=1e-8))
147 | convs.append(nn.SiLU())
148 | self.conv = nn.Sequential(*convs)
149 |
150 | # self.dropout1 = nn.Dropout(dropout)
151 | self.linear2 = nn.Linear(dim_ffn, 2*embed_dim)
152 | self.dropout2 = nn.Dropout(dropout)
153 |
154 | # Decode
155 | self.nspk = nspk
156 | self.deconv2d = nn.ConvTranspose2d(in_channels=2*embed_dim, out_channels=2*nspk, kernel_size=3, padding=1)
157 | self.linear = nn.Linear(in_features=2*nspk, out_features=2*nspk)
158 | self.n_fft = n_fft
159 | self.window = torch.hann_window(n_fft)
160 |
161 | def forward(self, x, XrMM, ref_enc):
162 | """
163 | x: batch, D, T, F
164 | ref_enc: batch, D, T, F
165 | """
166 | batch, _, T, F = x.shape
167 |
168 |         # Context decoding: combine ref_enc with the separator output to refine the estimate
169 | embedding_input = torch.cat([ref_enc, x], 1) # batch, 2*D, T, F
170 | embedding_input = self.conv2d_in(embedding_input) # batch, 2*D, T, F
171 | embedding = self.gLN(embedding_input) # batch, 2*D, T, F
172 | embedding = self._ff_block(embedding) # batch, 2*D, T, F
173 | embedding = embedding + embedding_input
174 |
175 | output = self.deconv2d(embedding) # batch, 2*nspk, T, F
176 | output = output.permute(0,3,2,1).contiguous() # batch, F, T, 2*nspk
177 | output = self.linear(output)
178 |
179 | # To complex
180 | output = output.view(batch, F, T, self.nspk, 2) # batch, F, T, nspk, 2
181 | output = torch.view_as_complex(output) # batch, F, T, nspk
182 |
183 | # Inverse normalization
184 | Ys_hat = torch.empty(size=(batch, self.nspk, F, T), dtype=torch.complex64, device=output.device)
185 | XrMM = torch.unsqueeze(XrMM, dim=2).expand(-1, -1, T)
186 | for spk in range(self.nspk):
187 | Ys_hat[:, spk, :, :] = output[:, :, :, spk] * XrMM[:, :, :]
188 |
189 | # iSTFT with frequency binding
190 | ys_hat = torch.istft(Ys_hat.view(batch * self.nspk, F, T), n_fft=self.n_fft, window=self.window, win_length=self.n_fft)
191 | ys_hat = ys_hat.view(batch, self.nspk, -1)
192 | return ys_hat
193 |
194 | # Feed forward block
195 | def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
196 | "x: B, 2*D, T, F"
197 | x = x.transpose(1,3).contiguous() # B, F, T, 2*D
198 | x = self.linear1(x) # B, F, T, 2*D
199 | x = self.activation(x)
200 | x = x.transpose(1,3).contiguous() # B, 2*D, T, F
201 | x = self.conv(x)
202 | x = x.transpose(1,3).contiguous() # B, F, T, 2*D
203 | # x = self.dropout1(x)
204 | x = self.linear2(x)
205 | x = x.transpose(1,3).contiguous() # B, 2*D, T, F
206 | return self.dropout2(x)
207 |
208 |
209 | def make_TF_FaSNet(nmic=6, nspk=2, n_fft=256, embed_dim=16, dim_nb=32, dim_ffn=64, n_conv_layers=2, B=4, I=8, J=1, H=128, L=4):
210 | "Helper: Construct TF-FaSNet model from hyperparameters"
211 | F = n_fft//2 + 1
212 | E = embed_dim//L
213 |
214 | c = copy.deepcopy
215 | RNN_module = RNNModule(hidden_size=H, kernel_size=I, stride=J, embed_dim=embed_dim)
216 | self_attn = PositionalMultiHeadAttention(h=L, d_model=embed_dim, d_q=E, F=F)
217 | model = EncoderSeparatorDecoder(
218 | encoder=Encoder(n_fft=n_fft, embed_dim=embed_dim, nmic=nmic, dim_nb=dim_nb),
219 | decoder=Decoder(n_fft=n_fft, embed_dim=embed_dim, nspk=nspk, n_conv_layers=n_conv_layers, dim_ffn=dim_ffn),
220 | separator=Separator(SeparatorLayer(c(RNN_module), c(RNN_module), c(self_attn)), N=B)
221 | )
222 |
223 | # Initialize parameters with Glorot / fan_avg
224 | for p in model.parameters():
225 | if p.dim() > 1:
226 | nn.init.xavier_uniform_(p)
227 | return model
228 |
229 | def check_parameters(net):
230 | '''
231 |     Return the number of module parameters in millions (M).
232 | '''
233 | parameters = sum(param.numel() for param in net.parameters())
234 | return parameters / 10**6
235 |
236 | if __name__ == "__main__":
237 |
238 | test_model = make_TF_FaSNet()
239 | test_model.eval()
240 |
241 | # Check model size
242 | print(check_parameters(test_model))
243 |
244 | # Test full model
245 | mix_audio = torch.randn(3,6,64000)
246 | separated_audio = test_model(mix_audio)
247 | print(separated_audio.shape)
248 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | positional-encodings==6.0.1
2 | torch==1.13.1
3 | torchaudio==0.13.1
--------------------------------------------------------------------------------
/utility/separator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | import copy
5 | from positional_encodings.torch_encodings import PositionalEncoding2D, Summer
6 |
7 |
8 | # Separator
9 | def clones(module, N):
10 | "Produce N identical layers"
11 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
12 |
13 | class Separator(nn.Module):
14 | "Separator is a stack of N layers"
15 |
16 | def __init__(self, layer, N=6):
17 | super(Separator, self).__init__()
18 | self.layers = clones(layer, N)
19 | self.N = N
20 |
21 | def forward(self, x):
22 | "Pass the input through each layer in turn"
23 | for layer in self.layers:
24 | x = layer(x)
25 | return x
26 |
27 |
28 | # Multi-path architecture
29 | class SeparatorLayer(nn.Module):
30 |     "SeparatorLayer is made up of a spectral module, a temporal module, and a self-attention module (defined below)"
31 |
32 | def __init__(self, RNN_module_s, RNN_module_t, self_attn):
33 | super(SeparatorLayer, self).__init__()
34 | self.spectral = RNN_module_s
35 | self.temporal = RNN_module_t
36 | self.self_attn = self_attn
37 |
38 | def forward(self, x):
39 | """
40 | x: batch, D, T, F
41 | """
42 | output = self.spectral(x) # batch, D, T, F
43 |
44 | output = output.transpose(2,3).contiguous()
45 | output = self.temporal(output) # batch, D, F, T
46 |
47 | output = output.transpose(2,3).contiguous()
48 | output = self.self_attn(output) # batch, D, T, F
49 |
50 | return output
51 |
52 |
53 | # Spectral & temporal module
54 | class RNNModule(nn.Module):
55 | "Unfold -> LN -> BLSTM -> Deconv1D -> residual"
56 |
57 | def __init__(self, hidden_size=256, kernel_size=8, stride=1, embed_dim=32, dropout=0, bidirectional=True):
58 | super(RNNModule, self).__init__()
59 | self.stride = stride
60 | self.kernel_size = kernel_size
61 |
62 | self.layernorm = nn.LayerNorm(embed_dim*kernel_size)
63 | self.rnn = nn.LSTM(embed_dim*kernel_size, hidden_size, 1, dropout=dropout, batch_first=True, bidirectional=bidirectional) # N,L,F -> N,L,2H
64 | self.deconv1d = nn.ConvTranspose1d(2*hidden_size, embed_dim, kernel_size, stride=stride)
65 |
66 |
67 | def forward(self, x):
68 | "x: batch, D, dim1, dim2"
69 | batch, embed_dim, dim1, dim2 = x.shape
70 |
71 | output = x.unfold(-1, self.kernel_size, self.stride) # batch, D, dim1, dim2/stride, kernel_size
72 |
73 | output = output.permute(0,2,3,4,1).contiguous().view(batch, dim1, -1, self.kernel_size*embed_dim)
74 | output = self.layernorm(output) # batch, dim1, dim2/stride, D*kernel_size
75 |
76 | output = output.view(batch*dim1, -1, self.kernel_size*embed_dim)
77 | output, _ = self.rnn(output) # batch*dim1, dim2/stride, 2*hidden_size
78 |
79 | output = output.contiguous().transpose(1,2).contiguous() # batch*dim1, 2*hidden_size, dim2/stride
80 | output = self.deconv1d(output) # batch*dim1, D, dim2
81 |
82 | output = output.view(batch, dim1, embed_dim, dim2).permute(0,2,1,3).contiguous()
83 | output = output + x
84 |
85 | return output
86 |
87 |
88 | # Pre-processing method to generate final result in Attention module
89 | class Generator(nn.Module):
90 |     "1x1 Conv2d -> PReLU -> cfLN"
91 |
92 | def __init__(self, input_dim=32, output_dim=4, F=129):
93 | super(Generator, self).__init__()
94 | self.conv2d = nn.Conv2d(input_dim, output_dim, 1)
95 | self.prelu = nn.PReLU()
96 | self.cfLN = nn.LayerNorm([output_dim, F])
97 |
98 | def forward(self, x):
99 | "x: batch, embed_dim, T, F"
100 | output = self.conv2d(x) # batch, output_dim, T, F
101 | output = self.prelu(output) # batch, output_dim, T, F
102 |
103 | output = output.transpose(1,2).contiguous() # batch, T, output_dim, F
104 | output = self.cfLN(output)
105 | output = output.transpose(1,2).contiguous()
106 |
107 | return output
108 |
109 |
110 | # Pre-processing method to generate batched qkv
111 | class MultiHeadGenerator(nn.Module):
112 |     "1x1 Conv2d -> PReLU -> cfLN"
113 |
114 | def __init__(self, input_dim=32, output_dim=4, F=129, h=4):
115 | super(MultiHeadGenerator, self).__init__()
116 | self.h = h
117 | self.output_dim = output_dim
118 | self.conv2d = nn.Conv2d(input_dim, h*output_dim, 1)
119 | self.prelu = nn.PReLU()
120 | self.cfLN = nn.LayerNorm([output_dim, F])
121 |
122 | def forward(self, x):
123 | "x: batch, embed_dim, T, F"
124 | batch, _, T, F = x.shape
125 | output = self.conv2d(x) # batch, h*output_dim, T, F
126 | output = self.prelu(output) # batch, h*output_dim, T, F
127 |
128 |         output = output.view(batch, self.h, self.output_dim, T, F).transpose(2,3).contiguous() # batch, h, T, output_dim, F
129 |         output = self.cfLN(output)
130 |         output = output.transpose(3,4).contiguous() # batch, h, T, F, output_dim
131 | output = output.view(batch, self.h, T, F*self.output_dim)
132 |
133 | return output
134 |
135 |
136 | # Dot-product Attention
137 | def attention(query, key, value, mask=None, dropout=None):
138 | """
139 | Compute 'Scaled Dot Product Attention'
140 |     Q: batch, h, T, FxE
141 |     K: batch, h, T, FxE
142 |     V: batch, h, T, Fx(D/L)
143 | """
144 | d_k = query.size(-1)
145 | scores = torch.matmul(query, key.transpose(-2, -1) / math.sqrt(d_k))
146 | if mask is not None:
147 | scores = scores.masked_fill(mask == 0, -1e9)
148 | p_attn = scores.softmax(dim=-1)
149 | if dropout is not None:
150 | p_attn = dropout(p_attn)
151 | return torch.matmul(p_attn, value), p_attn
152 |
153 |
154 | # Multi-head Attention with 2D positional embedding (full-band self-attention module)
155 | class PositionalMultiHeadAttention(nn.Module):
156 | """
157 | Multi-head attention with 2D positional encoding.
158 |     This concept (2D PE) was proposed in the paper "Translating
159 | Math Formula Images to LaTeX Sequences Using
160 | Deep Neural Networks with Sequence-level Training".
161 | """
162 |
163 | def __init__(self, h=4, d_model=32, d_q=4, F=129, dropout=0.1):
164 | super(PositionalMultiHeadAttention, self).__init__()
165 |
166 | assert d_model % h == 0
167 | self.h = h
168 | self.d_q = d_q
169 | self.d_k = d_q
170 | self.d_v = d_model // h
171 | self.RPE_size = F*d_q # FXE
172 |
173 | self.q_proj = MultiHeadGenerator(d_model, self.d_q, F, h)
174 | self.k_proj = MultiHeadGenerator(d_model, self.d_k, F, h)
175 | self.v_proj = MultiHeadGenerator(d_model, self.d_v, F, h)
176 | self.out_proj = Generator(d_model, d_model, F)
177 |
178 | self.p_enc_2d = Summer(PositionalEncoding2D(F))
179 |
180 | self.attn = None
181 | self.dropout = nn.Dropout(p=dropout)
182 |
183 | def forward(self, x, mask=None):
184 | "x: batch, D, T, F"
185 | if mask is not None:
186 | # Same mask applied to all h heads
187 | mask = mask.unsqueeze(1)
188 | batch, _, T, F = x.shape
189 |
190 | # Do all the projections in batch
191 |         query = self.q_proj(self.p_enc_2d(x)) # B, h, T, FxE
192 |         key = self.k_proj(self.p_enc_2d(x)) # B, h, T, FxE
193 |         value = self.v_proj(x) # B, h, T, Fx(D/L)
194 |
195 | output, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
196 |
197 | output = output.view(batch, self.h, T, F, self.d_v).permute(0,1,4,2,3).contiguous().view(batch, self.h * self.d_v, T, F)
198 | output = self.out_proj(output)
199 | output = output + x
200 |
201 | del query
202 | del key
203 | del value
204 |
205 | return output
206 |
207 | if __name__ == "__main__":
208 | # Test RNNModule
209 | # rnn_module = RNNModule(hidden_size=256)
210 | # x = torch.rand(3, 32, 1001, 129)
211 | # y = rnn_module(x)
212 | # print(y.shape)
213 |
214 | # Test MultiheadAttention
215 | B, D, T, F = 3, 32, 1001, 129
216 | x = torch.randn(B,D,T,F)
217 | model = PositionalMultiHeadAttention(h=4, d_model=D, d_q=4, F=F)
218 | y = model(x)
219 | print(y.shape)
220 |
--------------------------------------------------------------------------------