├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── avlit_2.yaml
│   ├── avlit_4.yaml
│   └── avlit_8.yaml
├── docs
│   ├── AVLIT_Folded.png
│   └── AVLIT_Unfolded.png
├── src
│   ├── __init__.py
│   ├── avlit.py
│   └── modules
│       ├── __init__.py
│       ├── afrcnn.py
│       └── autoencoder.py
└── tests
    ├── __init__.py
    ├── test_models.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Héctor Martel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ______________________________________________________________________ 2 | 3 |
4 | 5 | # AVLIT: Audio-Visual Lightweight ITerative model 6 | 7 | PyTorch 8 | [![arXiv](https://img.shields.io/badge/arXiv-2306.00160-brightgreen.svg)](https://arxiv.org/abs/2306.00160) 9 | [![Samples](https://img.shields.io/badge/Website-Demo_Samples-blue.svg)](https://avlit-interspeech.github.io/) 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 |
18 | 
19 | ## Description
20 | Official PyTorch Lightning implementation of ["Audio-Visual Speech Separation in Noisy Environments with a Lightweight Iterative Model"](https://arxiv.org/abs/2306.00160), accepted at INTERSPEECH 2023.
21 | 
22 | | | |
23 | |:----------:|:----------:|
24 | | | |
25 | | ![AVLIT_Folded](docs/AVLIT_Folded.png) | ![AVLIT_Unfolded](docs/AVLIT_Unfolded.png) |
26 | | (A) Folded view of AVLIT | (B) Unfolded view of AVLIT |
27 | 
28 | Audio-Visual Lightweight ITerative model (AVLIT) uses the [A-FRCNN](https://github.com/JusperLee/AFRCNN-For-Speech-Separation) as its building block.
29 | AVLIT employs a homogeneous design with audio and video branches composed of A-FRCNN blocks applied iteratively. Within each branch, the block weights are shared across iterations, so the number of parameters stays constant regardless of the number of blocks. Please refer to the paper for details.
30 | 
31 | ## Quick start
32 | 
33 | ### Installation
34 | 
35 | Make sure to have ``pytorch`` with GPU support installed on your machine according to the [official installation guide](https://pytorch.org/get-started/locally/).
36 | 
37 | ### Basic usage
38 | 
39 | Here is a minimal example of how to use AVLIT in plain PyTorch. The default parameters produce the configuration for AVLIT-8, the best-performing model in the paper.
40 | 
41 | ```python
42 | from src.avlit import AVLIT
43 | 
44 | # Instantiate the model
45 | model = AVLIT(
46 |     num_sources = 2,
47 |     # Audio branch
48 |     audio_num_blocks = 8,
49 |     # Video branch
50 |     video_num_blocks = 4,
51 |     video_encoder_checkpoint = "path/to/ae.ckpt",
52 | )
53 | model.cuda()
54 | 
55 | # Training or inference logic here
56 | # ...
57 | 
58 | ```
59 | 
60 | ### Advanced usage
61 | 
62 | For more control over the architecture, additional parameters can be provided as follows:
63 | 
64 | ```python
65 | from src.avlit import AVLIT
66 | 
67 | # Instantiate the model
68 | model = AVLIT(
69 |     num_sources = 2,
70 |     # Audio branch
71 |     kernel_size = 40,
72 |     audio_hidden_channels = 512,
73 |     audio_bottleneck_channels = 128,
74 |     audio_num_blocks = 8,
75 |     audio_states = 5,
76 |     # Video branch
77 |     video_hidden_channels = 128,
78 |     video_bottleneck_channels = 128,
79 |     video_num_blocks = 4,
80 |     video_states = 5,
81 |     video_encoder_checkpoint = "path/to/ae.ckpt",
82 |     video_encoder_trainable = False,
83 |     video_embedding_dim = 1024,
84 |     # AV fusion
85 |     fusion_operation = "sum",
86 |     fusion_positions = [4],
87 | )
88 | model.cuda()
89 | 
90 | # Training or inference logic here
91 | # ...
92 | 
93 | ```
94 | 
95 | ### Tests
96 | The [tests/](https://github.com/hmartelb/avlit/blob/main/tests) folder contains unit tests for the AVLIT architecture.
97 | Running these tests is useful when customizing the configuration parameters, to verify that the input/output shapes are as expected and that the model can perform a forward pass correctly on CPU/GPU.
98 | 
99 | To run all the unit tests, make sure to install the ``pytest`` package and run:
100 | ```
101 | pytest tests/test_models.py
102 | ```
103 | 
104 | ## Cite
105 | 
106 | If you use AVLIT in your research, please cite our paper:
107 | ```bibtex
108 | @inproceedings{martel23_interspeech,
109 |   author={Héctor Martel and Julius Richter and Kai Li and Xiaolin Hu and Timo Gerkmann},
110 |   title={{Audio-Visual Speech Separation in Noisy Environments with a Lightweight Iterative Model}},
111 |   year=2023,
112 |   booktitle={Proc. 
INTERSPEECH 2023}, 113 | pages={1673--1677}, 114 | doi={10.21437/Interspeech.2023-1753} 115 | } 116 | ``` 117 | 118 | ## Contact 119 | 120 | * For **technical/academic questions** please write an email to the corresponding authors mentioned in the paper. Alternatively, use the [discussions](https://github.com/hmartelb/avlit/discussions) page. Do not open an issue. 121 | * For **bugs** or **problems with the code**, please [open an issue](https://github.com/hmartelb/avlit/issues) in this repository. 122 | * For **other inquiries**, contact me via email at hmartelb@hotmail.com. 123 | 124 | ## Changelog 125 | 126 | * [2023/07/26] 🎧 Demo samples website made public. 127 | * [2023/06/02] 🚀 Model code released. 128 | * [2023/05/31] 📰 Final version made public on arXiv.org. 129 | * [2023/05/17] 📰 Paper accepted at INTERSPEECH 2023! 130 | 131 | ## License 132 | 133 | This code is licensed under the terms of the MIT License. 134 | 135 | ``` 136 | MIT License 137 | Copyright (c) 2023 Héctor Martel 138 | 139 | Permission is hereby granted, free of charge, to any person obtaining a copy 140 | of this software and associated documentation files (the "Software"), to deal 141 | in the Software without restriction, including without limitation the rights 142 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 143 | copies of the Software, and to permit persons to whom the Software is 144 | furnished to do so, subject to the following conditions: 145 | 146 | The above copyright notice and this permission notice shall be included in all 147 | copies or substantial portions of the Software. 148 | 149 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 150 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 151 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 152 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 153 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 154 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 155 | SOFTWARE. 
156 | ``` 157 | -------------------------------------------------------------------------------- /configs/avlit_2.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.avlit.AVLIT 2 | 3 | num_sources: 2 4 | 5 | # Audio branch 6 | kernel_size: 40 7 | audio_num_blocks: 2 8 | audio_hidden_channels: 512 9 | audio_bottleneck_channels: 128 10 | audio_states: 5 11 | 12 | # Video branch 13 | video_num_blocks: 1 14 | video_hidden_channels: 128 15 | video_bottleneck_channels: 128 16 | video_states: 5 17 | video_embedding_dim: 1024 18 | 19 | # AV fusion 20 | fusion_op: "sum" 21 | fusion_positions: [0] -------------------------------------------------------------------------------- /configs/avlit_4.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.avlit.AVLIT 2 | 3 | num_sources: 2 4 | 5 | # Audio branch 6 | kernel_size: 40 7 | audio_num_blocks: 4 8 | audio_hidden_channels: 512 9 | audio_bottleneck_channels: 128 10 | audio_states: 5 11 | 12 | # Video branch 13 | video_num_blocks: 2 14 | video_hidden_channels: 128 15 | video_bottleneck_channels: 128 16 | video_states: 5 17 | video_embedding_dim: 1024 18 | 19 | # AV fusion 20 | fusion_op: "sum" 21 | fusion_positions: [0] -------------------------------------------------------------------------------- /configs/avlit_8.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.avlit.AVLIT 2 | 3 | num_sources: 2 4 | 5 | # Audio branch 6 | kernel_size: 40 7 | audio_num_blocks: 8 8 | audio_hidden_channels: 512 9 | audio_bottleneck_channels: 128 10 | audio_states: 5 11 | 12 | # Video branch 13 | video_num_blocks: 4 14 | video_hidden_channels: 128 15 | video_bottleneck_channels: 128 16 | video_states: 5 17 | video_embedding_dim: 1024 18 | 19 | # AV fusion 20 | fusion_op: "sum" 21 | fusion_positions: [0] -------------------------------------------------------------------------------- /docs/AVLIT_Folded.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmartelb/avlit/c7ffc0f4abd84b804f8b1e535bdfdd5c879401cf/docs/AVLIT_Folded.png -------------------------------------------------------------------------------- /docs/AVLIT_Unfolded.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmartelb/avlit/c7ffc0f4abd84b804f8b1e535bdfdd5c879401cf/docs/AVLIT_Unfolded.png -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmartelb/avlit/c7ffc0f4abd84b804f8b1e535bdfdd5c879401cf/src/__init__.py -------------------------------------------------------------------------------- /src/avlit.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from src.modules.autoencoder import FrameAutoEncoder 10 | from src.modules.afrcnn import AFRCNN, GlobLN 11 | 12 | 13 | class AVLIT(nn.Module): 14 | def __init__( 15 | self, 16 | num_sources: int = 2, 17 | # Audio branch 18 | kernel_size: int = 40, 19 | audio_hidden_channels: int = 512, 20 | audio_bottleneck_channels: int = 128, 21 | audio_num_blocks: int = 8, 22 | audio_states: int = 5, 23 | # Video branch 24 | 
video_hidden_channels: int = 128, 25 | video_bottleneck_channels: int = 128, 26 | video_num_blocks: int = 4, 27 | video_states: int = 5, 28 | video_encoder_checkpoint: str = "", 29 | video_encoder_trainable: bool = False, 30 | video_embedding_dim: int = 1024, 31 | # AV fusion 32 | fusion_operation: str = "sum", 33 | fusion_positions: list[int] = [4], 34 | ) -> None: 35 | super().__init__() 36 | 37 | self.num_sources = num_sources 38 | self.kernel_size = kernel_size 39 | self.audio_states = audio_states 40 | self.audio_hidden_channels = audio_hidden_channels 41 | self.video_embedding_dim = video_embedding_dim 42 | 43 | # Audio encoder 44 | self.audio_encoder = nn.Conv1d( 45 | in_channels=1, 46 | out_channels=audio_hidden_channels, 47 | kernel_size=kernel_size, 48 | stride=kernel_size // 2, 49 | padding=kernel_size // 2, 50 | bias=False, 51 | ) 52 | torch.nn.init.xavier_uniform_(self.audio_encoder.weight) 53 | 54 | # Audio decoder 55 | self.audio_decoder = nn.ConvTranspose1d( 56 | in_channels=audio_hidden_channels, 57 | out_channels=1, 58 | output_padding=kernel_size // 2 - 1, 59 | kernel_size=kernel_size, 60 | stride=kernel_size // 2, 61 | padding=kernel_size // 2, 62 | groups=1, 63 | bias=False, 64 | ) 65 | torch.nn.init.xavier_uniform_(self.audio_decoder.weight) 66 | 67 | # Video encoder 68 | self.video_encoder = FrameAutoEncoder() 69 | if os.path.isfile(video_encoder_checkpoint): 70 | self.video_encoder.load_state_dict(torch.load(video_encoder_checkpoint)) 71 | if not video_encoder_trainable: 72 | for p in self.video_encoder.parameters(): 73 | p.requires_grad = False 74 | 75 | # Audio adaptation 76 | self.audio_norm = GlobLN(audio_hidden_channels) 77 | self.audio_bottleneck = nn.Conv1d( 78 | in_channels=audio_hidden_channels, 79 | out_channels=audio_bottleneck_channels, 80 | kernel_size=1, 81 | ) 82 | 83 | # Video adaptation 84 | self.video_bottleneck = nn.Conv1d( 85 | in_channels=num_sources * video_embedding_dim, 86 | out_channels=video_bottleneck_channels, 87 | kernel_size=1, 88 | ) 89 | 90 | # Masking 91 | self.mask_net = nn.Sequential( 92 | nn.PReLU(), 93 | nn.Conv1d( 94 | in_channels=audio_bottleneck_channels, 95 | out_channels=num_sources * audio_hidden_channels, 96 | kernel_size=1, 97 | ), 98 | ) 99 | self.mask_activation = nn.ReLU() 100 | 101 | # Audio branch 102 | self.audio_branch = IterativeBranch( 103 | num_sources=num_sources, 104 | hidden_channels=audio_hidden_channels, 105 | bottleneck_channels=audio_bottleneck_channels, 106 | num_blocks=audio_num_blocks, 107 | states=audio_states, 108 | fusion_operation=fusion_operation, 109 | fusion_positions=fusion_positions, 110 | ) 111 | 112 | # Video branch 113 | self.video_branch = IterativeBranch( 114 | num_sources=num_sources, 115 | hidden_channels=video_hidden_channels, 116 | bottleneck_channels=video_bottleneck_channels, 117 | num_blocks=video_num_blocks, 118 | states=video_states, 119 | ) 120 | 121 | def forward(self, x: torch.Tensor, v: torch.Tensor): 122 | # Get sizes of inputs 123 | b, T = x.shape[0], x.shape[-1] 124 | M, F = v.shape[1], v.shape[2] 125 | 126 | # Get audio features, fa 127 | x = self._pad_input(x) 128 | fa_in = self.audio_encoder(x) 129 | fa = self.audio_norm(fa_in) 130 | fa = self.audio_bottleneck(fa) 131 | 132 | # Get video features, fv 133 | fv = self.video_encoder.encode(v) 134 | fv = fv.permute(0, 1, 3, 2).reshape(b, M * self.video_embedding_dim, -1) 135 | fv = self.video_bottleneck(fv) 136 | 137 | # Forward the video and audio branches 138 | fv_p = self.video_branch(fv) 139 | fa_p = 
self.audio_branch(fa, fv_p) 140 | 141 | # Apply masking 142 | fa_m = self._masking(fa_in, fa_p) 143 | 144 | # Decode audio 145 | fa_m = fa_m.view(b * self.num_sources, self.audio_hidden_channels, -1) 146 | s = self.audio_decoder(fa_m) 147 | s = s.view(b, self.num_sources, -1) 148 | s = self._trim_output(s, T) 149 | return s 150 | 151 | def _masking(self, f, m): 152 | m = self.mask_net(m) 153 | m = m.view( 154 | m.shape[0], 155 | self.num_sources, 156 | self.audio_hidden_channels, 157 | -1, 158 | ) 159 | m = self.mask_activation(m) 160 | masked = m * f.unsqueeze(1) 161 | return masked 162 | 163 | def lcm(self): 164 | half_kernel = self.kernel_size // 2 165 | pow_states = 2**self.audio_states 166 | return abs(half_kernel * pow_states) // math.gcd(half_kernel, pow_states) 167 | 168 | def _pad_input(self, x): 169 | values_to_pad = int(x.shape[-1]) % self.lcm() 170 | if values_to_pad: 171 | appropriate_shape = x.shape 172 | padded_x = torch.zeros( 173 | list(appropriate_shape[:-1]) 174 | + [appropriate_shape[-1] + self.lcm() - values_to_pad], 175 | dtype=torch.float32, 176 | ).to(x.device) 177 | padded_x[..., : x.shape[-1]] = x 178 | return padded_x 179 | return x 180 | 181 | def _trim_output(self, x, T): 182 | if x.shape[-1] >= T: 183 | return x[..., 0:T] 184 | return x 185 | 186 | 187 | class IterativeBranch(nn.Module): 188 | def __init__( 189 | self, 190 | num_sources: int = 2, 191 | hidden_channels: int = 512, 192 | bottleneck_channels: int = 128, 193 | num_blocks: int = 8, 194 | states: int = 5, 195 | fusion_operation: str = "sum", 196 | fusion_positions: list = [0], 197 | ) -> None: 198 | super().__init__() 199 | 200 | # Branch attributes 201 | self.num_sources = num_sources 202 | self.hidden_channels = hidden_channels 203 | self.bottleneck_channels = bottleneck_channels 204 | self.num_blocks = num_blocks 205 | self.states = states 206 | self.fusion_operation = fusion_operation 207 | assert fusion_operation in [ 208 | "sum", 209 | "prod", 210 | "concat", 211 | ], f"The specified fusion_operation is not supported, must be one of ['sum', 'prod', 'concat']." 212 | self.fusion_positions = list( 213 | filter(lambda x: x < num_blocks and x >= 0, fusion_positions) 214 | ) 215 | assert ( 216 | len(fusion_positions) > 0 217 | ), f"The length of the fusion positions must be non-zero. Make sure to specify values between 1 and num_blocks ({num_blocks})" 218 | 219 | # Modules 220 | self.afrcnn_block = AFRCNN( 221 | in_channels=hidden_channels, 222 | out_channels=bottleneck_channels, 223 | states=states, 224 | ) 225 | self.adapt_audio = nn.Sequential( 226 | nn.Conv1d( 227 | bottleneck_channels, 228 | bottleneck_channels, 229 | kernel_size=1, 230 | stride=1, 231 | groups=bottleneck_channels, 232 | ), 233 | nn.PReLU(), 234 | ) 235 | if len(self.fusion_positions) > 0: 236 | self.adapt_fusion = nn.Sequential( 237 | nn.Conv1d( 238 | bottleneck_channels * (2 if fusion_operation == "concat" else 1), 239 | bottleneck_channels, 240 | kernel_size=1, 241 | stride=1, 242 | groups=bottleneck_channels, 243 | ), 244 | nn.PReLU(), 245 | ) 246 | 247 | def forward( 248 | self, 249 | fa: torch.Tensor, 250 | fv_p: Optional[torch.Tensor] = None, 251 | ) -> torch.Tensor: 252 | for i in range(self.num_blocks): 253 | # 1) Get the input: base case fa, else last output + fa 254 | Ri = fa if i == 0 else self.adapt_audio(Ri + fa) 255 | 256 | # 2) Apply modality fusion ? 
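            # Fusion only happens at the configured positions and when video
            # features are available; _modality_fusion() interpolates fv_p along
            # time to match Ri when needed, before applying the sum/prod/concat
            # operation.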
257 | if i in self.fusion_positions and fv_p is not None: 258 | f = self._modality_fusion(Ri, fv_p) 259 | Ri = self.adapt_fusion(f) 260 | 261 | # 3) Apply the A-FRCNN block 262 | Ri = self.afrcnn_block(Ri) 263 | return Ri 264 | 265 | def _modality_fusion(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 266 | if a.shape[-1] > b.shape[-1]: 267 | b = F.interpolate(b, size=a.shape[2:]) 268 | 269 | if self.fusion_operation == "sum": 270 | return a + b 271 | elif self.fusion_operation == "prod": 272 | return a * b 273 | elif self.fusion_operation == "concat": 274 | return torch.cat([a, b], dim=1) 275 | return a 276 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmartelb/avlit/c7ffc0f4abd84b804f8b1e535bdfdd5c879401cf/src/modules/__init__.py -------------------------------------------------------------------------------- /src/modules/afrcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class GlobLN(nn.Module): 7 | """Global Layer Normalization (globLN).""" 8 | 9 | def __init__(self, channel_size: int): 10 | super().__init__() 11 | self.channel_size = channel_size 12 | self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True) 13 | self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True) 14 | 15 | def forward(self, x: torch.Tensor): 16 | """Applies forward pass. 17 | Works for any input size > 2D. 18 | 19 | Args: 20 | x (:class:`torch.Tensor`): Shape `[batch, chan, *]` 21 | 22 | Returns: 23 | :class:`torch.Tensor`: gLN_x `[batch, chan, *]` 24 | """ 25 | dims = list(range(1, len(x.shape))) 26 | mean = x.mean(dim=dims, keepdim=True) 27 | var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True) 28 | return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt()) 29 | 30 | def apply_gain_and_bias(self, normed_x): 31 | """Assumes input of size `[batch, chanel, *]`.""" 32 | return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1) 33 | 34 | 35 | class ConvNormAct(nn.Module): 36 | """ 37 | This class defines the convolution layer with normalization and a PReLU 38 | activation. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | in_channels: int, 44 | out_channels: int, 45 | kernel_size: int, 46 | stride: int = 1, 47 | groups: int = 1, 48 | ): 49 | """ 50 | :param in_channels: number of input channels 51 | :param out_channels: number of output channels 52 | :param kernel_size: kernel size 53 | :param stride: stride rate for down-sampling. Default is 1 54 | """ 55 | super().__init__() 56 | padding = int((kernel_size - 1) / 2) 57 | self.conv = nn.Conv1d( 58 | in_channels, 59 | out_channels, 60 | kernel_size, 61 | stride=stride, 62 | padding=padding, 63 | bias=True, 64 | groups=groups, 65 | ) 66 | self.norm = GlobLN(out_channels) 67 | self.act = nn.PReLU() 68 | 69 | def forward(self, input: torch.Tensor): 70 | output = self.conv(input) 71 | output = self.norm(output) 72 | return self.act(output) 73 | 74 | 75 | class DilatedConvNorm(nn.Module): 76 | """ 77 | This class defines the dilated convolution with normalized output. 
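    With stride 1, the padding ((kernel_size - 1) // 2) * d keeps the output
    length equal to the input length.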
78 | """ 79 | 80 | def __init__( 81 | self, 82 | in_channels: int, 83 | out_channels: int, 84 | kernel_size: int, 85 | stride: int = 1, 86 | d: int = 1, 87 | groups: int = 1, 88 | ): 89 | """ 90 | :param in_channels: number of input channels 91 | :param out_channels: number of output channels 92 | :param kernel_size: kernel size 93 | :param stride: optional stride rate for down-sampling 94 | :param d: optional dilation rate 95 | """ 96 | super().__init__() 97 | self.conv = nn.Conv1d( 98 | in_channels, 99 | out_channels, 100 | kernel_size, 101 | stride=stride, 102 | dilation=d, 103 | padding=((kernel_size - 1) // 2) * d, 104 | groups=groups, 105 | ) 106 | self.norm = GlobLN(out_channels) 107 | 108 | def forward(self, input: torch.Tensor): 109 | output = self.conv(input) 110 | return self.norm(output) 111 | 112 | 113 | class AFRCNN(nn.Module): 114 | def __init__( 115 | self, 116 | in_channels: int = 512, 117 | out_channels: int = 128, 118 | states: int = 4, 119 | ): 120 | super().__init__() 121 | self.proj_1x1 = ConvNormAct( 122 | out_channels, 123 | in_channels, 124 | 1, 125 | stride=1, 126 | groups=1, 127 | ) 128 | self.depth = states 129 | self.spp_dw = nn.ModuleList([]) 130 | self.spp_dw.append( 131 | DilatedConvNorm( 132 | in_channels, 133 | in_channels, 134 | kernel_size=5, 135 | stride=1, 136 | groups=in_channels, 137 | d=1, 138 | ) 139 | ) 140 | # ----------Down Sample Layer---------- 141 | for i in range(1, states): 142 | self.spp_dw.append( 143 | DilatedConvNorm( 144 | in_channels, 145 | in_channels, 146 | kernel_size=5, 147 | stride=2, 148 | groups=in_channels, 149 | d=1, 150 | ) 151 | ) 152 | # ----------Fusion Layer---------- 153 | self.fuse_layers = nn.ModuleList([]) 154 | for i in range(states): 155 | fuse_layer = nn.ModuleList([]) 156 | for j in range(states): 157 | if i == j: 158 | fuse_layer.append(None) 159 | elif j - i == 1: 160 | fuse_layer.append(None) 161 | elif i - j == 1: 162 | fuse_layer.append( 163 | DilatedConvNorm( 164 | in_channels, 165 | in_channels, 166 | kernel_size=5, 167 | stride=2, 168 | groups=in_channels, 169 | d=1, 170 | ) 171 | ) 172 | self.fuse_layers.append(fuse_layer) 173 | self.concat_layer = nn.ModuleList([]) 174 | # ----------Concat Layer---------- 175 | for i in range(states): 176 | if i == 0 or i == states - 1: 177 | self.concat_layer.append( 178 | ConvNormAct(in_channels * 2, in_channels, 1, 1) 179 | ) 180 | else: 181 | self.concat_layer.append( 182 | ConvNormAct(in_channels * 3, in_channels, 1, 1) 183 | ) 184 | 185 | self.last_layer = nn.Sequential( 186 | ConvNormAct(in_channels * states, in_channels, 1, 1) 187 | ) 188 | self.res_conv = nn.Conv1d(in_channels, out_channels, 1) 189 | # # ----------parameters------------- 190 | # self.depth = states # Already defined! 
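        # Overview of the block: proj_1x1 adapts the channel dimension, spp_dw
        # builds a pyramid of `states` temporal resolutions via stride-2
        # depthwise convolutions, fuse_layers/concat_layer mix each resolution
        # with its neighbours, last_layer merges all resolutions at the top
        # resolution, and res_conv restores the input channel count for the
        # residual connection applied in forward().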
191 | 192 | def forward(self, x: torch.Tensor): 193 | """ 194 | :param x: input feature map 195 | :return: transformed feature map 196 | """ 197 | residual = x.clone() 198 | # Reduce --> project high-dimensional feature maps to low-dimensional space 199 | output1 = self.proj_1x1(x) 200 | output = [self.spp_dw[0](output1)] 201 | for k in range(1, self.depth): 202 | out_k = self.spp_dw[k](output[-1]) 203 | output.append(out_k) 204 | 205 | x_fuse = [] 206 | for i in range(len(self.fuse_layers)): 207 | wav_length = output[i].shape[-1] 208 | y = torch.cat( 209 | ( 210 | self.fuse_layers[i][0](output[i - 1]) 211 | if i - 1 >= 0 212 | else torch.Tensor().to(output1.device), 213 | output[i], 214 | F.interpolate(output[i + 1], size=wav_length, mode="nearest") 215 | if i + 1 < self.depth 216 | else torch.Tensor().to(output1.device), 217 | ), 218 | dim=1, 219 | ) 220 | x_fuse.append(self.concat_layer[i](y)) 221 | 222 | wav_length = output[0].shape[-1] 223 | for i in range(1, len(x_fuse)): 224 | x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest") 225 | 226 | concat = self.last_layer(torch.cat(x_fuse, dim=1)) 227 | expanded = self.res_conv(concat) 228 | return expanded + residual -------------------------------------------------------------------------------- /src/modules/autoencoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class EncoderBlock(nn.Module): 5 | def __init__( 6 | self, 7 | in_channels: int, 8 | out_channels: int, 9 | kernel_size: int, 10 | stride: int, 11 | leaky_slope: float = 0.3, 12 | ): 13 | """ 14 | Encoder block module that performs convolution, instance normalization, and leaky ReLU activation. 15 | 16 | Args: 17 | in_channels (int): Number of input channels. 18 | out_channels (int): Number of output channels. 19 | kernel_size (int or tuple): Size of the convolution kernel. 20 | stride (int or tuple): Stride of the convolution operation. 21 | leaky_slope (float): Negative slope coefficient for the leaky ReLU activation. 22 | 23 | """ 24 | super().__init__() 25 | 26 | self.conv = nn.Conv2d( 27 | in_channels, 28 | out_channels, 29 | kernel_size, 30 | stride, 31 | ) 32 | self.norm = nn.InstanceNorm2d(out_channels, affine=True) 33 | self.act = nn.LeakyReLU(leaky_slope) 34 | 35 | def forward(self, x): 36 | """ 37 | Forward pass of the encoder block. 38 | 39 | Args: 40 | x (tensor): Input tensor. 41 | 42 | Returns: 43 | tensor: Output tensor after convolution, instance normalization, and activation. 44 | 45 | """ 46 | x = self.conv(x) 47 | x = self.norm(x) 48 | x = self.act(x) 49 | return x 50 | 51 | 52 | class DecoderBlock(nn.Module): 53 | def __init__( 54 | self, 55 | in_channels: int, 56 | out_channels: int, 57 | kernel_size: int, 58 | stride: int, 59 | leaky_slope: float = 0.3, 60 | ): 61 | """ 62 | Decoder block module that performs transposed convolution, instance normalization, and leaky ReLU activation. 63 | 64 | Args: 65 | in_channels (int): Number of input channels. 66 | out_channels (int): Number of output channels. 67 | kernel_size (int or tuple): Size of the transposed convolution kernel. 68 | stride (int or tuple): Stride of the transposed convolution operation. 69 | leaky_slope (float): Negative slope coefficient for the leaky ReLU activation. 
70 | 71 | """ 72 | super().__init__() 73 | 74 | self.conv = nn.ConvTranspose2d( 75 | in_channels, 76 | out_channels, 77 | kernel_size, 78 | stride, 79 | ) 80 | self.norm = nn.InstanceNorm2d(out_channels, affine=True) 81 | self.act = nn.LeakyReLU(leaky_slope) 82 | 83 | def forward(self, x): 84 | x = self.conv(x) 85 | x = self.norm(x) 86 | x = self.act(x) 87 | return x 88 | 89 | 90 | class EncoderAE(nn.Module): 91 | def __init__( 92 | self, 93 | in_channels: int = 3, 94 | base_channels: int = 8, 95 | num_layers: int = 3, 96 | ): 97 | super().__init__() 98 | 99 | self.layers = nn.ModuleList() 100 | for i in range(num_layers): 101 | cout = base_channels * (2**i) 102 | cin = in_channels if i == 0 else cout // 2 103 | self.layers.append(EncoderBlock(cin, cout, 2, 2)) 104 | 105 | def forward(self, x): 106 | for layer in self.layers: 107 | x = layer(x) 108 | return x 109 | 110 | 111 | class DecoderAE(nn.Module): 112 | def __init__( 113 | self, 114 | in_channels: int = 3, 115 | base_channels: int = 8, 116 | num_layers: int = 3, 117 | ): 118 | super().__init__() 119 | 120 | self.layers = nn.ModuleList() 121 | for i in range(num_layers): 122 | cin = base_channels * (2 ** (num_layers - i - 1)) 123 | cout = in_channels if i == num_layers - 1 else cin // 2 124 | self.layers.append(DecoderBlock(cin, cout, 2, 2)) 125 | 126 | def forward(self, x): 127 | for layer in self.layers: 128 | x = layer(x) 129 | return x 130 | 131 | 132 | class FrameAutoEncoder(nn.Module): 133 | def __init__( 134 | self, 135 | in_channels: int = 1, 136 | base_channels: int = 8, 137 | num_layers: int = 3, 138 | ): 139 | """ 140 | Single-frame autoencoder used to obtain frame-level embeddings from a video. 141 | 142 | Args: 143 | - in_channels (int, optional): Number of video channels (1: grayscale, 3: rgb, ...). Defaults to 1. 144 | - base_channels (int, optional): Number of channels in the first convolutional layer, multiplied by 2 each subsequent layer. Defaults to 8. 145 | - num_layers (int, optional): Number of layers (i.e. depth of the autoencoder). Defaults to 3. 146 | """ 147 | super().__init__() 148 | self.encoder = EncoderAE(in_channels, base_channels, num_layers) 149 | self.decoder = DecoderAE(in_channels, base_channels, num_layers) 150 | 151 | def forward(self, x): 152 | return self.reconstruct(x) 153 | 154 | def encode(self, x): 155 | # x is expected to be a tensor of shape [batch, num_sources, frames, w, h]. 156 | # Convert it to [batch * num_sources * frames, w, h] 157 | batch, num_sources, frames, w, h = ( 158 | x.shape[0], 159 | x.shape[1], 160 | x.shape[2], 161 | x.shape[3], 162 | x.shape[4], 163 | ) 164 | x = x.contiguous().view(batch * num_sources * frames, 1, w, h) 165 | 166 | z = self.encoder(x) 167 | 168 | # Undo the view of x. z has [batch * num_sources * frames, c', w', h'] 169 | # Convert it to [batch, num_sources, frames, c' * w' * h'] 170 | z = z.view(batch, num_sources, frames, -1) 171 | return z 172 | 173 | def reconstruct(self, x): 174 | # x is expected to be a tensor of shape [batch, num_sources, frames, w, h]. 175 | # Convert it to [batch * num_sources * frames, w, h] 176 | batch, num_sources, frames, w, h = ( 177 | x.shape[0], 178 | x.shape[1], 179 | x.shape[2], 180 | x.shape[3], 181 | x.shape[4], 182 | ) 183 | x = x.contiguous().view(batch * num_sources * frames, 1, w, h) 184 | 185 | z = self.encoder(x) 186 | y = self.decoder(z) 187 | 188 | # Undo the view of x. 
y has [batch * num_sources * frames, w, h] 189 | # Convert it to [batch, num_sources, frames, w, h] 190 | y = y.view(batch, num_sources, frames, w, h) 191 | return y 192 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmartelb/avlit/c7ffc0f4abd84b804f8b1e535bdfdd5c879401cf/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from src.avlit import AVLIT 4 | from tests.utils import model_forward_test 5 | 6 | 7 | class TestModels: 8 | """ 9 | Class to test all the models, where each model is defined in an individual function. 10 | Model forward test ``_model_forward_test()`` is called after instantiating the model given the expected input and output shapes. 11 | """ 12 | 13 | @pytest.mark.parametrize("batch_size", [1, 3]) 14 | @pytest.mark.parametrize("num_sources", [1, 2]) 15 | @pytest.mark.parametrize("sr", [8000, 16000]) 16 | @pytest.mark.parametrize("segment_length", [4]) 17 | @pytest.mark.parametrize("fps", [25]) 18 | @pytest.mark.parametrize("audio_num_blocks", [2, 4, 8]) 19 | @pytest.mark.parametrize("video_num_blocks", [1, 2, 4]) 20 | @pytest.mark.parametrize("video_embedding_dim", [1024]) 21 | @pytest.mark.parametrize("fusion_operation", ["sum", "prod", "concat"]) 22 | def test_avlit( 23 | self, 24 | batch_size, 25 | num_sources, 26 | sr, 27 | segment_length, 28 | fps, 29 | audio_num_blocks, 30 | video_num_blocks, 31 | video_embedding_dim, 32 | fusion_operation, 33 | ): 34 | # Instantiate the model 35 | model = AVLIT( 36 | num_sources=num_sources, 37 | audio_num_blocks=audio_num_blocks, 38 | video_num_blocks=video_num_blocks, 39 | video_embedding_dim=video_embedding_dim, 40 | fusion_operation=fusion_operation, 41 | fusion_positions=[0], 42 | ) 43 | 44 | # Generate expected I/O shapes 45 | input_shape = [ 46 | (batch_size, 1, segment_length * sr), # Audio mixture 47 | ( 48 | batch_size, 49 | num_sources, 50 | fps * segment_length, 51 | 64, 52 | 64, 53 | ), # Video inputs (1 video per speaker) 54 | ] 55 | output_shape = [ 56 | (batch_size, num_sources, segment_length * sr), 57 | ] 58 | 59 | # Test the model 60 | model_forward_test(model, input_shape, output_shape, strict=False) 61 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def shapes_equal(x: torch.Tensor, y: torch.Tensor): 5 | assert ( 6 | x.shape == y.shape 7 | ), f"Shapes do not match, expected {x.shape} but got {y.shape}" 8 | 9 | 10 | def data_equal(x: torch.Tensor, y: torch.Tensor, eps=1e-6): 11 | assert torch.allclose( 12 | x, y, atol=eps 13 | ), f"Inputs are not equal, mean absolute difference = {(x-y).abs().sum() / torch.numel(x)}" 14 | 15 | 16 | def model_forward_test( 17 | model, 18 | input_shape, 19 | output_shape, 20 | strict=True, 21 | ): 22 | """ 23 | Generic method to test the forward function of a given model. 24 | The expected input and output shape(s) should be provided. 25 | 26 | Args: 27 | - model (torch.nn.Module): Model to be tested, must implement ``forward()`` or ``__call__()`` methods. 28 | - input_shape (Tuple, list): Shape of the input tensor. 29 | - output_shape (Tuple, list): Shape of the output tensor. 
30 |     - strict (bool, default: True): Check that the results of forward for ``train()`` and ``eval()`` are consistent. Must be set to ``False`` for the tests to pass if the model contains instances of ``nn.Dropout`` or ``nn.BatchNorm`` since they have different behaviors for training and inference.
31 |     """
32 |     if isinstance(input_shape, tuple):
33 |         input_shape = [input_shape]
34 | 
35 |     if isinstance(output_shape, tuple):
36 |         output_shape = [output_shape]
37 | 
38 |     # Generate some random input data
39 |     x = [torch.randn(s) for s in input_shape]
40 | 
41 |     last_output = None
42 |     for mode in ["train", "eval"]:
43 |         # Set the model to training or inference mode
44 |         model = model.train() if mode == "train" else model.eval()
45 | 
46 |         # Forward operation
47 |         y = model(*x)
48 | 
49 |         if isinstance(y, torch.Tensor):
50 |             y = [y]
51 | 
52 |         # Tests:
53 |         # 1) Check that the output shape is the one expected
54 |         for i in range(len(output_shape)):
55 |             assert (
56 |                 y[i].shape == output_shape[i]
57 |             ), f"[{mode.upper()}] Output shape does not match. Expected {output_shape[i]}, got {y[i].shape}."
58 | 
59 |         # 2) Check that the outputs are consistent for the same input in train() and eval()
60 |         if last_output is not None and strict:
61 |             for i in range(len(output_shape)):
62 |                 # data_equal() raises an AssertionError with a descriptive
63 |                 # message if the outputs differ between train() and eval()
64 |                 data_equal(y[i], last_output[i])
65 | 
66 |         # Save the last output
67 |         last_output = y
68 | 
--------------------------------------------------------------------------------
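As a quick sanity check of the interfaces above, here is a minimal forward-pass sketch. It assumes the dummy input shapes used in tests/test_models.py (a 4-second mixture at 16 kHz and a 25 fps, single-channel, 64×64 video per speaker) and the default AVLIT-8 configuration; since no video_encoder_checkpoint is provided, the frame autoencoder keeps its random initialization.

```python
import torch

from src.avlit import AVLIT

# Defaults correspond to AVLIT-8 (8 audio blocks, 4 video blocks, fusion at position 4).
model = AVLIT(num_sources=2, audio_num_blocks=8, video_num_blocks=4)
model.eval()

batch, sr, seconds, fps = 1, 16000, 4, 25
mixture = torch.randn(batch, 1, sr * seconds)         # [batch, 1, T]
video = torch.randn(batch, 2, fps * seconds, 64, 64)  # [batch, num_sources, frames, H, W]

with torch.no_grad():
    separated = model(mixture, video)

print(separated.shape)  # expected: torch.Size([1, 2, 64000]), as asserted in tests/test_models.py
```

The same path is exercised across more configurations by running ``pytest tests/test_models.py``.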