├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── avlit_2.yaml
│   ├── avlit_4.yaml
│   └── avlit_8.yaml
├── docs
│   ├── AVLIT_Folded.png
│   └── AVLIT_Unfolded.png
├── src
│   ├── __init__.py
│   ├── avlit.py
│   └── modules
│       ├── __init__.py
│       ├── afrcnn.py
│       └── autoencoder.py
└── tests
    ├── __init__.py
    ├── test_models.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Héctor Martel
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ______________________________________________________________________
2 |
3 |
4 |
5 | # AVLIT: Audio-Visual Lightweight ITerative model
6 |
7 |
8 | [Paper (arXiv)](https://arxiv.org/abs/2306.00160)
9 | [Demo samples](https://avlit-interspeech.github.io/)
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | ## Description
20 | Official PyTorch Lightning implementation of ["Audio-Visual Speech Separation in Noisy Environments with a Lightweight Iterative Model"](https://arxiv.org/abs/2306.00160), accepted at INTERSPEECH 2023.
21 |
22 | |                                                |                                                    |
23 | |:----------------------------------------------:|:--------------------------------------------------:|
24 | | ![Folded view of AVLIT](docs/AVLIT_Folded.png) | ![Unfolded view of AVLIT](docs/AVLIT_Unfolded.png) |
25 | | (A) Folded view of AVLIT                       | (B) Unfolded view of AVLIT                         |
26 |
27 |
28 | The Audio-Visual Lightweight ITerative model (AVLIT) uses the [A-FRCNN](https://github.com/JusperLee/AFRCNN-For-Speech-Separation) as its building block.
29 | AVLIT employs a homogeneous design in which the audio and video branches are composed of A-FRCNN blocks applied iteratively. The weights are shared across iterations within each modality, so the number of parameters stays constant regardless of the number of blocks. Please refer to the paper for details.
30 |
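As a quick sanity check of the constant-parameter property, the parameter counts of the AVLIT-2 and AVLIT-8 configurations can be compared directly. This is a minimal sketch, assuming the fusion position ``[0]`` used in the provided ``configs/`` files:

```python
from src.avlit import AVLIT

def count_params(model):
    return sum(p.numel() for p in model.parameters())

# The A-FRCNN block of each branch is reused at every iteration,
# so adding blocks (iterations) does not add parameters.
avlit_2 = AVLIT(audio_num_blocks=2, video_num_blocks=1, fusion_positions=[0])
avlit_8 = AVLIT(audio_num_blocks=8, video_num_blocks=4, fusion_positions=[0])
print(count_params(avlit_2) == count_params(avlit_8))  # True
```
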
31 | ## Quick start
32 |
33 | ### Installation
34 |
35 | Make sure to have ``pytorch`` with GPU support installed on your machine according to the [official installation guide](https://pytorch.org/get-started/locally/).
36 |
37 | ### Basic usage
38 |
39 | Here is a minimal example of how to use AVLIT in plain PyTorch. The default parameters produce the configuration for AVLIT-8, which is the best-performing model in the paper.
40 |
41 | ```python
42 | from src.avlit import AVLIT
43 |
44 | # Instantiate the model
45 | model = AVLIT(
46 | num_sources = 2,
47 | # Audio branch
48 | audio_num_blocks = 8,
49 | # Video branch
50 | video_num_blocks = 4,
51 | video_encoder_checkpoint = "path/to/ae.ckpt",
52 | )
53 | model.cuda()
54 |
55 | # Training or inference logic here
56 | # ...
57 |
58 | ```
59 |
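The inputs to the forward pass are an audio mixture and one video stream per source. The following sketch uses the shapes from the unit tests in ``tests/test_models.py``; the 4-second 16 kHz audio and 64x64, 25 fps video are illustrative assumptions:

```python
import torch
from src.avlit import AVLIT

model = AVLIT(num_sources=2)  # default parameters -> AVLIT-8
model.eval()

batch, sr, seconds, fps = 1, 16000, 4, 25
audio = torch.randn(batch, 1, seconds * sr)           # mixture waveform: (batch, 1, samples)
video = torch.randn(batch, 2, seconds * fps, 64, 64)  # one 64x64 video stream per source

with torch.no_grad():
    estimates = model(audio, video)  # (batch, num_sources, samples)

print(estimates.shape)  # torch.Size([1, 2, 64000])
```
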
60 | ### Advanced usage
61 |
62 | For finer control over the architecture, additional parameters can be specified as follows:
63 |
64 | ```python
65 | from src.avlit import AVLIT
66 |
67 | # Instantiate the model
68 | model = AVLIT(
69 | num_sources = 2,
70 | # Audio branch
71 | kernel_size = 40,
72 | audio_hidden_channels = 512,
73 | audio_bottleneck_channels = 128,
74 | audio_num_blocks = 8,
75 | audio_states = 5,
76 | # Video branch
77 | video_hidden_channels = 128,
78 | video_bottleneck_channels = 128,
79 | video_num_blocks = 4,
80 | video_states = 5,
81 | video_encoder_checkpoint = "path/to/ae.ckpt",
82 | video_encoder_trainable = False,
83 | video_embedding_dim = 1024,
84 | # AV fusion
85 | fusion_operation = "sum",
86 | fusion_positions = [4],
87 | )
88 | model.cuda()
89 |
90 | # Training or inference logic here
91 | # ...
92 |
93 | ```
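
``fusion_operation`` must be one of ``"sum"``, ``"prod"`` or ``"concat"``, and ``fusion_positions`` lists the indices of the audio A-FRCNN iterations (valid values range from ``0`` to ``audio_num_blocks - 1``) at which the video features are injected into the audio branch.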
94 |
95 | ### Tests
96 | The [tests/](https://github.com/hmartelb/avlit/blob/main/tests) folder contains unit tests for the AVLIT architecture.
97 | Running these tests is useful when customizing the configuration parameters, to verify that the input/output shapes are as expected and that the model can perform a forward pass correctly on CPU/GPU.
98 |
99 | To run all the unit tests, make sure to install the ``pytest`` package and run:
100 | ```
101 | pytest tests/test_models.py
102 | ```
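
The helper ``model_forward_test`` from ``tests/utils.py`` can also be reused to check a custom configuration. A sketch, using the same 8 kHz / 25 fps shapes as the parametrized test:

```python
from src.avlit import AVLIT
from tests.utils import model_forward_test

model = AVLIT(
    num_sources=2,
    audio_num_blocks=4,
    video_num_blocks=2,
    fusion_positions=[0],
)
model_forward_test(
    model,
    input_shape=[(1, 1, 32000), (1, 2, 100, 64, 64)],  # 4 s mixture at 8 kHz + 25 fps video per source
    output_shape=[(1, 2, 32000)],
    strict=False,  # same setting as tests/test_models.py
)
```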
103 |
104 | ## Cite
105 |
106 | If you use AVLIT in your research, please cite our paper:
107 | ```bibtex
108 | @inproceedings{martel23_interspeech,
109 | author={Héctor Martel and Julius Richter and Kai Li and Xiaolin Hu and Timo Gerkmann},
110 | title={{Audio-Visual Speech Separation in Noisy Environments with a Lightweight Iterative Model}},
111 | year=2023,
112 | booktitle={Proc. INTERSPEECH 2023},
113 | pages={1673--1677},
114 | doi={10.21437/Interspeech.2023-1753}
115 | }
116 | ```
117 |
118 | ## Contact
119 |
120 | * For **technical/academic questions** please write an email to the corresponding authors mentioned in the paper. Alternatively, use the [discussions](https://github.com/hmartelb/avlit/discussions) page. Do not open an issue.
121 | * For **bugs** or **problems with the code**, please [open an issue](https://github.com/hmartelb/avlit/issues) in this repository.
122 | * For **other inquiries**, contact me via email at hmartelb@hotmail.com.
123 |
124 | ## Changelog
125 |
126 | * [2023/07/26] 🎧 Demo samples website made public.
127 | * [2023/06/02] 🚀 Model code released.
128 | * [2023/05/31] 📰 Final version made public on arXiv.org.
129 | * [2023/05/17] 📰 Paper accepted at INTERSPEECH 2023!
130 |
131 | ## License
132 |
133 | This code is licensed under the terms of the MIT License.
134 |
135 | ```
136 | MIT License
137 | Copyright (c) 2023 Héctor Martel
138 |
139 | Permission is hereby granted, free of charge, to any person obtaining a copy
140 | of this software and associated documentation files (the "Software"), to deal
141 | in the Software without restriction, including without limitation the rights
142 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
143 | copies of the Software, and to permit persons to whom the Software is
144 | furnished to do so, subject to the following conditions:
145 |
146 | The above copyright notice and this permission notice shall be included in all
147 | copies or substantial portions of the Software.
148 |
149 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
150 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
151 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
152 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
153 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
154 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
155 | SOFTWARE.
156 | ```
157 |
--------------------------------------------------------------------------------
/configs/avlit_2.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.avlit.AVLIT
2 |
3 | num_sources: 2
4 |
5 | # Audio branch
6 | kernel_size: 40
7 | audio_num_blocks: 2
8 | audio_hidden_channels: 512
9 | audio_bottleneck_channels: 128
10 | audio_states: 5
11 |
12 | # Video branch
13 | video_num_blocks: 1
14 | video_hidden_channels: 128
15 | video_bottleneck_channels: 128
16 | video_states: 5
17 | video_embedding_dim: 1024
18 |
19 | # AV fusion
20 | fusion_operation: "sum"
21 | fusion_positions: [0]
--------------------------------------------------------------------------------
/configs/avlit_4.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.avlit.AVLIT
2 |
3 | num_sources: 2
4 |
5 | # Audio branch
6 | kernel_size: 40
7 | audio_num_blocks: 4
8 | audio_hidden_channels: 512
9 | audio_bottleneck_channels: 128
10 | audio_states: 5
11 |
12 | # Video branch
13 | video_num_blocks: 2
14 | video_hidden_channels: 128
15 | video_bottleneck_channels: 128
16 | video_states: 5
17 | video_embedding_dim: 1024
18 |
19 | # AV fusion
20 | fusion_operation: "sum"
21 | fusion_positions: [0]
--------------------------------------------------------------------------------
/configs/avlit_8.yaml:
--------------------------------------------------------------------------------
1 | _target_: src.avlit.AVLIT
2 |
3 | num_sources: 2
4 |
5 | # Audio branch
6 | kernel_size: 40
7 | audio_num_blocks: 8
8 | audio_hidden_channels: 512
9 | audio_bottleneck_channels: 128
10 | audio_states: 5
11 |
12 | # Video branch
13 | video_num_blocks: 4
14 | video_hidden_channels: 128
15 | video_bottleneck_channels: 128
16 | video_states: 5
17 | video_embedding_dim: 1024
18 |
19 | # AV fusion
20 | fusion_operation: "sum"
21 | fusion_positions: [0]
--------------------------------------------------------------------------------
/docs/AVLIT_Folded.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmartelb/avlit/c7ffc0f4abd84b804f8b1e535bdfdd5c879401cf/docs/AVLIT_Folded.png
--------------------------------------------------------------------------------
/docs/AVLIT_Unfolded.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmartelb/avlit/c7ffc0f4abd84b804f8b1e535bdfdd5c879401cf/docs/AVLIT_Unfolded.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmartelb/avlit/c7ffc0f4abd84b804f8b1e535bdfdd5c879401cf/src/__init__.py
--------------------------------------------------------------------------------
/src/avlit.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from typing import Optional
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 | from src.modules.autoencoder import FrameAutoEncoder
10 | from src.modules.afrcnn import AFRCNN, GlobLN
11 |
12 |
13 | class AVLIT(nn.Module):
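    """
    Audio-Visual Lightweight ITerative model (AVLIT).

    An audio branch and a video branch, each consisting of a single A-FRCNN block
    applied iteratively with shared weights, separate ``num_sources`` speakers from
    a noisy mixture. The video features are fused into the audio branch at the
    iterations given by ``fusion_positions``.
    """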
14 | def __init__(
15 | self,
16 | num_sources: int = 2,
17 | # Audio branch
18 | kernel_size: int = 40,
19 | audio_hidden_channels: int = 512,
20 | audio_bottleneck_channels: int = 128,
21 | audio_num_blocks: int = 8,
22 | audio_states: int = 5,
23 | # Video branch
24 | video_hidden_channels: int = 128,
25 | video_bottleneck_channels: int = 128,
26 | video_num_blocks: int = 4,
27 | video_states: int = 5,
28 | video_encoder_checkpoint: str = "",
29 | video_encoder_trainable: bool = False,
30 | video_embedding_dim: int = 1024,
31 | # AV fusion
32 | fusion_operation: str = "sum",
33 | fusion_positions: list[int] = [4],
34 | ) -> None:
35 | super().__init__()
36 |
37 | self.num_sources = num_sources
38 | self.kernel_size = kernel_size
39 | self.audio_states = audio_states
40 | self.audio_hidden_channels = audio_hidden_channels
41 | self.video_embedding_dim = video_embedding_dim
42 |
43 | # Audio encoder
44 | self.audio_encoder = nn.Conv1d(
45 | in_channels=1,
46 | out_channels=audio_hidden_channels,
47 | kernel_size=kernel_size,
48 | stride=kernel_size // 2,
49 | padding=kernel_size // 2,
50 | bias=False,
51 | )
52 | torch.nn.init.xavier_uniform_(self.audio_encoder.weight)
53 |
54 | # Audio decoder
55 | self.audio_decoder = nn.ConvTranspose1d(
56 | in_channels=audio_hidden_channels,
57 | out_channels=1,
58 | output_padding=kernel_size // 2 - 1,
59 | kernel_size=kernel_size,
60 | stride=kernel_size // 2,
61 | padding=kernel_size // 2,
62 | groups=1,
63 | bias=False,
64 | )
65 | torch.nn.init.xavier_uniform_(self.audio_decoder.weight)
66 |
67 | # Video encoder
68 | self.video_encoder = FrameAutoEncoder()
69 | if os.path.isfile(video_encoder_checkpoint):
70 | self.video_encoder.load_state_dict(torch.load(video_encoder_checkpoint))
71 | if not video_encoder_trainable:
72 | for p in self.video_encoder.parameters():
73 | p.requires_grad = False
74 |
75 | # Audio adaptation
76 | self.audio_norm = GlobLN(audio_hidden_channels)
77 | self.audio_bottleneck = nn.Conv1d(
78 | in_channels=audio_hidden_channels,
79 | out_channels=audio_bottleneck_channels,
80 | kernel_size=1,
81 | )
82 |
83 | # Video adaptation
84 | self.video_bottleneck = nn.Conv1d(
85 | in_channels=num_sources * video_embedding_dim,
86 | out_channels=video_bottleneck_channels,
87 | kernel_size=1,
88 | )
89 |
90 | # Masking
91 | self.mask_net = nn.Sequential(
92 | nn.PReLU(),
93 | nn.Conv1d(
94 | in_channels=audio_bottleneck_channels,
95 | out_channels=num_sources * audio_hidden_channels,
96 | kernel_size=1,
97 | ),
98 | )
99 | self.mask_activation = nn.ReLU()
100 |
101 | # Audio branch
102 | self.audio_branch = IterativeBranch(
103 | num_sources=num_sources,
104 | hidden_channels=audio_hidden_channels,
105 | bottleneck_channels=audio_bottleneck_channels,
106 | num_blocks=audio_num_blocks,
107 | states=audio_states,
108 | fusion_operation=fusion_operation,
109 | fusion_positions=fusion_positions,
110 | )
111 |
112 | # Video branch
113 | self.video_branch = IterativeBranch(
114 | num_sources=num_sources,
115 | hidden_channels=video_hidden_channels,
116 | bottleneck_channels=video_bottleneck_channels,
117 | num_blocks=video_num_blocks,
118 | states=video_states,
119 | )
120 |
121 | def forward(self, x: torch.Tensor, v: torch.Tensor):
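        """
        Separate the sources in the audio mixture ``x``, guided by the videos ``v``.

        Args:
            x: audio mixture of shape ``[batch, 1, samples]``.
            v: video frames of shape ``[batch, num_sources, frames, width, height]``
               (one single-channel stream per source).

        Returns:
            Estimated sources of shape ``[batch, num_sources, samples]``.
        """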
122 | # Get sizes of inputs
123 | b, T = x.shape[0], x.shape[-1]
124 |         M = v.shape[1]  # one video stream per source (do not shadow torch.nn.functional as F)
125 |
126 | # Get audio features, fa
127 | x = self._pad_input(x)
128 | fa_in = self.audio_encoder(x)
129 | fa = self.audio_norm(fa_in)
130 | fa = self.audio_bottleneck(fa)
131 |
132 | # Get video features, fv
133 | fv = self.video_encoder.encode(v)
134 | fv = fv.permute(0, 1, 3, 2).reshape(b, M * self.video_embedding_dim, -1)
135 | fv = self.video_bottleneck(fv)
136 |
137 | # Forward the video and audio branches
138 | fv_p = self.video_branch(fv)
139 | fa_p = self.audio_branch(fa, fv_p)
140 |
141 | # Apply masking
142 | fa_m = self._masking(fa_in, fa_p)
143 |
144 | # Decode audio
145 | fa_m = fa_m.view(b * self.num_sources, self.audio_hidden_channels, -1)
146 | s = self.audio_decoder(fa_m)
147 | s = s.view(b, self.num_sources, -1)
148 | s = self._trim_output(s, T)
149 | return s
150 |
151 | def _masking(self, f, m):
152 | m = self.mask_net(m)
153 | m = m.view(
154 | m.shape[0],
155 | self.num_sources,
156 | self.audio_hidden_channels,
157 | -1,
158 | )
159 | m = self.mask_activation(m)
160 | masked = m * f.unsqueeze(1)
161 | return masked
162 |
163 | def lcm(self):
164 | half_kernel = self.kernel_size // 2
165 | pow_states = 2**self.audio_states
166 | return abs(half_kernel * pow_states) // math.gcd(half_kernel, pow_states)
167 |
168 | def _pad_input(self, x):
169 | values_to_pad = int(x.shape[-1]) % self.lcm()
170 | if values_to_pad:
171 | appropriate_shape = x.shape
172 | padded_x = torch.zeros(
173 | list(appropriate_shape[:-1])
174 | + [appropriate_shape[-1] + self.lcm() - values_to_pad],
175 | dtype=torch.float32,
176 | ).to(x.device)
177 | padded_x[..., : x.shape[-1]] = x
178 | return padded_x
179 | return x
180 |
181 | def _trim_output(self, x, T):
182 | if x.shape[-1] >= T:
183 | return x[..., 0:T]
184 | return x
185 |
186 |
187 | class IterativeBranch(nn.Module):
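    """
    Applies a single (weight-shared) A-FRCNN block for ``num_blocks`` iterations.

    At every iteration listed in ``fusion_positions``, the features of the other
    modality (if provided) are merged into the branch using ``fusion_operation``.
    """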
188 | def __init__(
189 | self,
190 | num_sources: int = 2,
191 | hidden_channels: int = 512,
192 | bottleneck_channels: int = 128,
193 | num_blocks: int = 8,
194 | states: int = 5,
195 | fusion_operation: str = "sum",
196 | fusion_positions: list = [0],
197 | ) -> None:
198 | super().__init__()
199 |
200 | # Branch attributes
201 | self.num_sources = num_sources
202 | self.hidden_channels = hidden_channels
203 | self.bottleneck_channels = bottleneck_channels
204 | self.num_blocks = num_blocks
205 | self.states = states
206 | self.fusion_operation = fusion_operation
207 |         assert fusion_operation in [
208 |             "sum",
209 |             "prod",
210 |             "concat",
211 |         ], f"The specified fusion_operation '{fusion_operation}' is not supported, must be one of ['sum', 'prod', 'concat']."
212 |         self.fusion_positions = list(
213 |             filter(lambda x: 0 <= x < num_blocks, fusion_positions)
214 |         )
215 |         assert (
216 |             len(fusion_positions) > 0
217 |         ), f"The length of the fusion positions must be non-zero. Make sure to specify values between 0 and num_blocks - 1 ({num_blocks - 1})"
218 |
219 | # Modules
220 | self.afrcnn_block = AFRCNN(
221 | in_channels=hidden_channels,
222 | out_channels=bottleneck_channels,
223 | states=states,
224 | )
225 | self.adapt_audio = nn.Sequential(
226 | nn.Conv1d(
227 | bottleneck_channels,
228 | bottleneck_channels,
229 | kernel_size=1,
230 | stride=1,
231 | groups=bottleneck_channels,
232 | ),
233 | nn.PReLU(),
234 | )
235 | if len(self.fusion_positions) > 0:
236 | self.adapt_fusion = nn.Sequential(
237 | nn.Conv1d(
238 | bottleneck_channels * (2 if fusion_operation == "concat" else 1),
239 | bottleneck_channels,
240 | kernel_size=1,
241 | stride=1,
242 | groups=bottleneck_channels,
243 | ),
244 | nn.PReLU(),
245 | )
246 |
247 | def forward(
248 | self,
249 | fa: torch.Tensor,
250 | fv_p: Optional[torch.Tensor] = None,
251 | ) -> torch.Tensor:
252 | for i in range(self.num_blocks):
253 | # 1) Get the input: base case fa, else last output + fa
254 | Ri = fa if i == 0 else self.adapt_audio(Ri + fa)
255 |
256 |             # 2) Apply modality fusion at the configured positions (when video features are provided)
257 | if i in self.fusion_positions and fv_p is not None:
258 | f = self._modality_fusion(Ri, fv_p)
259 | Ri = self.adapt_fusion(f)
260 |
261 | # 3) Apply the A-FRCNN block
262 | Ri = self.afrcnn_block(Ri)
263 | return Ri
264 |
265 | def _modality_fusion(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
266 | if a.shape[-1] > b.shape[-1]:
267 | b = F.interpolate(b, size=a.shape[2:])
268 |
269 | if self.fusion_operation == "sum":
270 | return a + b
271 | elif self.fusion_operation == "prod":
272 | return a * b
273 | elif self.fusion_operation == "concat":
274 | return torch.cat([a, b], dim=1)
275 | return a
276 |
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmartelb/avlit/c7ffc0f4abd84b804f8b1e535bdfdd5c879401cf/src/modules/__init__.py
--------------------------------------------------------------------------------
/src/modules/afrcnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | class GlobLN(nn.Module):
7 | """Global Layer Normalization (globLN)."""
8 |
9 | def __init__(self, channel_size: int):
10 | super().__init__()
11 | self.channel_size = channel_size
12 | self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
13 | self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)
14 |
15 | def forward(self, x: torch.Tensor):
16 | """Applies forward pass.
17 | Works for any input size > 2D.
18 |
19 | Args:
20 | x (:class:`torch.Tensor`): Shape `[batch, chan, *]`
21 |
22 | Returns:
23 | :class:`torch.Tensor`: gLN_x `[batch, chan, *]`
24 | """
25 | dims = list(range(1, len(x.shape)))
26 | mean = x.mean(dim=dims, keepdim=True)
27 | var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
28 | return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt())
29 |
30 | def apply_gain_and_bias(self, normed_x):
31 |         """Assumes input of size `[batch, channel, *]`."""
32 | return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)
33 |
34 |
35 | class ConvNormAct(nn.Module):
36 | """
37 | This class defines the convolution layer with normalization and a PReLU
38 | activation.
39 | """
40 |
41 | def __init__(
42 | self,
43 | in_channels: int,
44 | out_channels: int,
45 | kernel_size: int,
46 | stride: int = 1,
47 | groups: int = 1,
48 | ):
49 | """
50 | :param in_channels: number of input channels
51 | :param out_channels: number of output channels
52 | :param kernel_size: kernel size
53 | :param stride: stride rate for down-sampling. Default is 1
54 | """
55 | super().__init__()
56 | padding = int((kernel_size - 1) / 2)
57 | self.conv = nn.Conv1d(
58 | in_channels,
59 | out_channels,
60 | kernel_size,
61 | stride=stride,
62 | padding=padding,
63 | bias=True,
64 | groups=groups,
65 | )
66 | self.norm = GlobLN(out_channels)
67 | self.act = nn.PReLU()
68 |
69 | def forward(self, input: torch.Tensor):
70 | output = self.conv(input)
71 | output = self.norm(output)
72 | return self.act(output)
73 |
74 |
75 | class DilatedConvNorm(nn.Module):
76 | """
77 | This class defines the dilated convolution with normalized output.
78 | """
79 |
80 | def __init__(
81 | self,
82 | in_channels: int,
83 | out_channels: int,
84 | kernel_size: int,
85 | stride: int = 1,
86 | d: int = 1,
87 | groups: int = 1,
88 | ):
89 | """
90 | :param in_channels: number of input channels
91 | :param out_channels: number of output channels
92 | :param kernel_size: kernel size
93 | :param stride: optional stride rate for down-sampling
94 | :param d: optional dilation rate
95 | """
96 | super().__init__()
97 | self.conv = nn.Conv1d(
98 | in_channels,
99 | out_channels,
100 | kernel_size,
101 | stride=stride,
102 | dilation=d,
103 | padding=((kernel_size - 1) // 2) * d,
104 | groups=groups,
105 | )
106 | self.norm = GlobLN(out_channels)
107 |
108 | def forward(self, input: torch.Tensor):
109 | output = self.conv(input)
110 | return self.norm(output)
111 |
112 |
113 | class AFRCNN(nn.Module):
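    """
    A-FRCNN (Asynchronous Fully Recurrent Convolutional Neural Network) block.

    Projects the input (``out_channels`` channels) to ``in_channels`` channels,
    builds a pyramid of ``states`` temporal resolutions with depthwise convolutions,
    fuses neighbouring resolutions, and returns a tensor with the same shape as the
    input through a residual connection.
    """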
114 | def __init__(
115 | self,
116 | in_channels: int = 512,
117 | out_channels: int = 128,
118 | states: int = 4,
119 | ):
120 | super().__init__()
121 | self.proj_1x1 = ConvNormAct(
122 | out_channels,
123 | in_channels,
124 | 1,
125 | stride=1,
126 | groups=1,
127 | )
128 | self.depth = states
129 | self.spp_dw = nn.ModuleList([])
130 | self.spp_dw.append(
131 | DilatedConvNorm(
132 | in_channels,
133 | in_channels,
134 | kernel_size=5,
135 | stride=1,
136 | groups=in_channels,
137 | d=1,
138 | )
139 | )
140 | # ----------Down Sample Layer----------
141 | for i in range(1, states):
142 | self.spp_dw.append(
143 | DilatedConvNorm(
144 | in_channels,
145 | in_channels,
146 | kernel_size=5,
147 | stride=2,
148 | groups=in_channels,
149 | d=1,
150 | )
151 | )
152 | # ----------Fusion Layer----------
153 | self.fuse_layers = nn.ModuleList([])
154 | for i in range(states):
155 | fuse_layer = nn.ModuleList([])
156 | for j in range(states):
157 | if i == j:
158 | fuse_layer.append(None)
159 | elif j - i == 1:
160 | fuse_layer.append(None)
161 | elif i - j == 1:
162 | fuse_layer.append(
163 | DilatedConvNorm(
164 | in_channels,
165 | in_channels,
166 | kernel_size=5,
167 | stride=2,
168 | groups=in_channels,
169 | d=1,
170 | )
171 | )
172 | self.fuse_layers.append(fuse_layer)
173 | self.concat_layer = nn.ModuleList([])
174 | # ----------Concat Layer----------
175 | for i in range(states):
176 | if i == 0 or i == states - 1:
177 | self.concat_layer.append(
178 | ConvNormAct(in_channels * 2, in_channels, 1, 1)
179 | )
180 | else:
181 | self.concat_layer.append(
182 | ConvNormAct(in_channels * 3, in_channels, 1, 1)
183 | )
184 |
185 | self.last_layer = nn.Sequential(
186 | ConvNormAct(in_channels * states, in_channels, 1, 1)
187 | )
188 | self.res_conv = nn.Conv1d(in_channels, out_channels, 1)
189 | # # ----------parameters-------------
190 | # self.depth = states # Already defined!
191 |
192 | def forward(self, x: torch.Tensor):
193 | """
194 | :param x: input feature map
195 | :return: transformed feature map
196 | """
197 | residual = x.clone()
198 | # Reduce --> project high-dimensional feature maps to low-dimensional space
199 | output1 = self.proj_1x1(x)
200 | output = [self.spp_dw[0](output1)]
201 | for k in range(1, self.depth):
202 | out_k = self.spp_dw[k](output[-1])
203 | output.append(out_k)
204 |
205 | x_fuse = []
206 | for i in range(len(self.fuse_layers)):
207 | wav_length = output[i].shape[-1]
208 | y = torch.cat(
209 | (
210 | self.fuse_layers[i][0](output[i - 1])
211 | if i - 1 >= 0
212 | else torch.Tensor().to(output1.device),
213 | output[i],
214 | F.interpolate(output[i + 1], size=wav_length, mode="nearest")
215 | if i + 1 < self.depth
216 | else torch.Tensor().to(output1.device),
217 | ),
218 | dim=1,
219 | )
220 | x_fuse.append(self.concat_layer[i](y))
221 |
222 | wav_length = output[0].shape[-1]
223 | for i in range(1, len(x_fuse)):
224 | x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")
225 |
226 | concat = self.last_layer(torch.cat(x_fuse, dim=1))
227 | expanded = self.res_conv(concat)
228 | return expanded + residual
--------------------------------------------------------------------------------
/src/modules/autoencoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class EncoderBlock(nn.Module):
5 | def __init__(
6 | self,
7 | in_channels: int,
8 | out_channels: int,
9 | kernel_size: int,
10 | stride: int,
11 | leaky_slope: float = 0.3,
12 | ):
13 | """
14 | Encoder block module that performs convolution, instance normalization, and leaky ReLU activation.
15 |
16 | Args:
17 | in_channels (int): Number of input channels.
18 | out_channels (int): Number of output channels.
19 | kernel_size (int or tuple): Size of the convolution kernel.
20 | stride (int or tuple): Stride of the convolution operation.
21 | leaky_slope (float): Negative slope coefficient for the leaky ReLU activation.
22 |
23 | """
24 | super().__init__()
25 |
26 | self.conv = nn.Conv2d(
27 | in_channels,
28 | out_channels,
29 | kernel_size,
30 | stride,
31 | )
32 | self.norm = nn.InstanceNorm2d(out_channels, affine=True)
33 | self.act = nn.LeakyReLU(leaky_slope)
34 |
35 | def forward(self, x):
36 | """
37 | Forward pass of the encoder block.
38 |
39 | Args:
40 | x (tensor): Input tensor.
41 |
42 | Returns:
43 | tensor: Output tensor after convolution, instance normalization, and activation.
44 |
45 | """
46 | x = self.conv(x)
47 | x = self.norm(x)
48 | x = self.act(x)
49 | return x
50 |
51 |
52 | class DecoderBlock(nn.Module):
53 | def __init__(
54 | self,
55 | in_channels: int,
56 | out_channels: int,
57 | kernel_size: int,
58 | stride: int,
59 | leaky_slope: float = 0.3,
60 | ):
61 | """
62 | Decoder block module that performs transposed convolution, instance normalization, and leaky ReLU activation.
63 |
64 | Args:
65 | in_channels (int): Number of input channels.
66 | out_channels (int): Number of output channels.
67 | kernel_size (int or tuple): Size of the transposed convolution kernel.
68 | stride (int or tuple): Stride of the transposed convolution operation.
69 | leaky_slope (float): Negative slope coefficient for the leaky ReLU activation.
70 |
71 | """
72 | super().__init__()
73 |
74 | self.conv = nn.ConvTranspose2d(
75 | in_channels,
76 | out_channels,
77 | kernel_size,
78 | stride,
79 | )
80 | self.norm = nn.InstanceNorm2d(out_channels, affine=True)
81 | self.act = nn.LeakyReLU(leaky_slope)
82 |
83 | def forward(self, x):
84 | x = self.conv(x)
85 | x = self.norm(x)
86 | x = self.act(x)
87 | return x
88 |
89 |
90 | class EncoderAE(nn.Module):
91 | def __init__(
92 | self,
93 | in_channels: int = 3,
94 | base_channels: int = 8,
95 | num_layers: int = 3,
96 | ):
97 | super().__init__()
98 |
99 | self.layers = nn.ModuleList()
100 | for i in range(num_layers):
101 | cout = base_channels * (2**i)
102 | cin = in_channels if i == 0 else cout // 2
103 | self.layers.append(EncoderBlock(cin, cout, 2, 2))
104 |
105 | def forward(self, x):
106 | for layer in self.layers:
107 | x = layer(x)
108 | return x
109 |
110 |
111 | class DecoderAE(nn.Module):
112 | def __init__(
113 | self,
114 | in_channels: int = 3,
115 | base_channels: int = 8,
116 | num_layers: int = 3,
117 | ):
118 | super().__init__()
119 |
120 | self.layers = nn.ModuleList()
121 | for i in range(num_layers):
122 | cin = base_channels * (2 ** (num_layers - i - 1))
123 | cout = in_channels if i == num_layers - 1 else cin // 2
124 | self.layers.append(DecoderBlock(cin, cout, 2, 2))
125 |
126 | def forward(self, x):
127 | for layer in self.layers:
128 | x = layer(x)
129 | return x
130 |
131 |
132 | class FrameAutoEncoder(nn.Module):
133 | def __init__(
134 | self,
135 | in_channels: int = 1,
136 | base_channels: int = 8,
137 | num_layers: int = 3,
138 | ):
139 | """
140 | Single-frame autoencoder used to obtain frame-level embeddings from a video.
141 |
142 | Args:
143 | - in_channels (int, optional): Number of video channels (1: grayscale, 3: rgb, ...). Defaults to 1.
144 | - base_channels (int, optional): Number of channels in the first convolutional layer, multiplied by 2 each subsequent layer. Defaults to 8.
145 | - num_layers (int, optional): Number of layers (i.e. depth of the autoencoder). Defaults to 3.
146 | """
147 | super().__init__()
148 | self.encoder = EncoderAE(in_channels, base_channels, num_layers)
149 | self.decoder = DecoderAE(in_channels, base_channels, num_layers)
150 |
151 | def forward(self, x):
152 | return self.reconstruct(x)
153 |
154 | def encode(self, x):
155 | # x is expected to be a tensor of shape [batch, num_sources, frames, w, h].
156 |         # Convert it to [batch * num_sources * frames, 1, w, h]
157 | batch, num_sources, frames, w, h = (
158 | x.shape[0],
159 | x.shape[1],
160 | x.shape[2],
161 | x.shape[3],
162 | x.shape[4],
163 | )
164 | x = x.contiguous().view(batch * num_sources * frames, 1, w, h)
165 |
166 | z = self.encoder(x)
167 |
168 | # Undo the view of x. z has [batch * num_sources * frames, c', w', h']
169 | # Convert it to [batch, num_sources, frames, c' * w' * h']
170 | z = z.view(batch, num_sources, frames, -1)
171 | return z
172 |
173 | def reconstruct(self, x):
174 | # x is expected to be a tensor of shape [batch, num_sources, frames, w, h].
175 |         # Convert it to [batch * num_sources * frames, 1, w, h]
176 | batch, num_sources, frames, w, h = (
177 | x.shape[0],
178 | x.shape[1],
179 | x.shape[2],
180 | x.shape[3],
181 | x.shape[4],
182 | )
183 | x = x.contiguous().view(batch * num_sources * frames, 1, w, h)
184 |
185 | z = self.encoder(x)
186 | y = self.decoder(z)
187 |
188 |         # Undo the view of x. y has [batch * num_sources * frames, 1, w, h]
189 | # Convert it to [batch, num_sources, frames, w, h]
190 | y = y.view(batch, num_sources, frames, w, h)
191 | return y
192 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmartelb/avlit/c7ffc0f4abd84b804f8b1e535bdfdd5c879401cf/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from src.avlit import AVLIT
4 | from tests.utils import model_forward_test
5 |
6 |
7 | class TestModels:
8 | """
9 |     Class to test all the models, where each model is tested in an individual function.
10 |     The model forward test ``model_forward_test()`` is called after instantiating the model, given the expected input and output shapes.
11 | """
12 |
13 | @pytest.mark.parametrize("batch_size", [1, 3])
14 | @pytest.mark.parametrize("num_sources", [1, 2])
15 | @pytest.mark.parametrize("sr", [8000, 16000])
16 | @pytest.mark.parametrize("segment_length", [4])
17 | @pytest.mark.parametrize("fps", [25])
18 | @pytest.mark.parametrize("audio_num_blocks", [2, 4, 8])
19 | @pytest.mark.parametrize("video_num_blocks", [1, 2, 4])
20 | @pytest.mark.parametrize("video_embedding_dim", [1024])
21 | @pytest.mark.parametrize("fusion_operation", ["sum", "prod", "concat"])
22 | def test_avlit(
23 | self,
24 | batch_size,
25 | num_sources,
26 | sr,
27 | segment_length,
28 | fps,
29 | audio_num_blocks,
30 | video_num_blocks,
31 | video_embedding_dim,
32 | fusion_operation,
33 | ):
34 | # Instantiate the model
35 | model = AVLIT(
36 | num_sources=num_sources,
37 | audio_num_blocks=audio_num_blocks,
38 | video_num_blocks=video_num_blocks,
39 | video_embedding_dim=video_embedding_dim,
40 | fusion_operation=fusion_operation,
41 | fusion_positions=[0],
42 | )
43 |
44 | # Generate expected I/O shapes
45 | input_shape = [
46 | (batch_size, 1, segment_length * sr), # Audio mixture
47 | (
48 | batch_size,
49 | num_sources,
50 | fps * segment_length,
51 | 64,
52 | 64,
53 | ), # Video inputs (1 video per speaker)
54 | ]
55 | output_shape = [
56 | (batch_size, num_sources, segment_length * sr),
57 | ]
58 |
59 | # Test the model
60 | model_forward_test(model, input_shape, output_shape, strict=False)
61 |
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def shapes_equal(x: torch.Tensor, y: torch.Tensor):
5 | assert (
6 | x.shape == y.shape
7 | ), f"Shapes do not match, expected {x.shape} but got {y.shape}"
8 |
9 |
10 | def data_equal(x: torch.Tensor, y: torch.Tensor, eps=1e-6):
11 | assert torch.allclose(
12 | x, y, atol=eps
13 | ), f"Inputs are not equal, mean absolute difference = {(x-y).abs().sum() / torch.numel(x)}"
14 |
15 |
16 | def model_forward_test(
17 | model,
18 | input_shape,
19 | output_shape,
20 | strict=True,
21 | ):
22 | """
23 | Generic method to test the forward function of a given model.
24 | The expected input and output shape(s) should be provided.
25 |
26 | Args:
27 | - model (torch.nn.Module): Model to be tested, must implement ``forward()`` or ``__call__()`` methods.
28 | - input_shape (Tuple, list): Shape of the input tensor.
29 | - output_shape (Tuple, list): Shape of the output tensor.
30 | - strict (bool, default: True): Check that the results of forward for ``train()`` and ``eval()`` are consistent. Must be set to ``False`` for the tests to pass if the model contains instances of ``nn.Dropout`` or ``nn.BatchNorm`` since they have different behaviors for training and inference.
31 | """
32 | if isinstance(input_shape, tuple):
33 | input_shape = [input_shape]
34 |
35 | if isinstance(output_shape, tuple):
36 | output_shape = [output_shape]
37 |
38 | # Generate some random input data
39 | x = [torch.randn(s) for s in input_shape]
40 |
41 | last_output = None
42 | for mode in ["train", "eval"]:
43 | # Set the model to training or inference mode
44 | model = model.train() if mode == "train" else model.eval()
45 |
46 | # Forward operation
47 | y = model(*x)
48 |
49 | if isinstance(y, torch.Tensor):
50 | y = [y]
51 |
52 | # Tests:
53 | # 1) Check that the output shape is the one expected
54 | for i in range(len(output_shape)):
55 | assert (
56 | y[i].shape == output_shape[i]
57 | ), f"[{mode.upper()}] Output shape does not match. Expected {output_shape[i]}, got {y[i].shape}."
58 |
59 | # 2) Check that the outputs are consistent for the same input in train() and eval()
60 | if last_output is not None and strict:
61 |             for i in range(len(output_shape)):
62 |                 # data_equal() asserts internally, so calling it directly raises
63 |                 # a descriptive error if the TRAIN and EVAL outputs do not match.
64 |                 data_equal(y[i], last_output[i])
65 |
66 | # Save the last output
67 | last_output = y
68 |
--------------------------------------------------------------------------------