├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── audio_encoders_pytorch
│   ├── __init__.py
│   ├── modules.py
│   ├── pipelines.py
│   └── utils.py
└── setup.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
.mypy_cache
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v2.3.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace

  # Formats code correctly
  - repo: https://github.com/psf/black
    rev: 22.3.0
    hooks:
      - id: black
        args: [
          '--experimental-string-processing'
        ]

  # Sorts imports
  - repo: https://github.com/pycqa/isort
    rev: 5.10.1
    hooks:
      - id: isort
        name: isort (python)
        args: ["--profile", "black"]

  # Checks unused imports, line lengths, etc.
  - repo: https://gitlab.com/pycqa/flake8
    rev: 4.0.0
    hooks:
      - id: flake8
        args: [
          '--per-file-ignores=__init__.py:F401',
          '--max-line-length=88',
          '--ignore=E203,W503'
        ]

  # Checks types
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: 'v0.971'
    hooks:
      - id: mypy
        additional_dependencies: [data-science-types>=0.2, torch>=1.6]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 archinet.ai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Audio Encoders - PyTorch

A collection of audio autoencoders, in PyTorch. Pretrained models can be found at [`archisound`](https://github.com/archinetai/archisound).

## Install

```bash
pip install audio-encoders-pytorch
```

[![PyPI - Python Version](https://img.shields.io/pypi/v/audio-encoders-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/audio-encoders-pytorch/)


## Usage

### AutoEncoder1d
```py
import torch
from audio_encoders_pytorch import AutoEncoder1d

autoencoder = AutoEncoder1d(
    in_channels=2,              # Number of input channels
    channels=32,                # Number of base channels
    multipliers=[1, 1, 2, 2],   # Channel multiplier between layers (i.e. channels * multipliers[i] -> channels * multipliers[i+1])
    factors=[4, 4, 4],          # Downsampling/upsampling factor per layer
    num_blocks=[2, 2, 2]        # Number of resnet blocks per layer
)

x = torch.randn(1, 2, 2**18)  # [1, 2, 262144]
x_recon = autoencoder(x)      # [1, 2, 262144]
```
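The encoder and decoder halves can also be called separately via the `encode`/`decode` methods that `AutoEncoder1d` exposes. A minimal sketch; the shapes follow from the configuration above (32 * multipliers[-1] = 64 latent channels, length 2**18 / (4*4*4) = 4096):

```py
z = autoencoder.encode(x)          # [1, 64, 4096]
x_decoded = autoencoder.decode(z)  # [1, 2, 262144]
```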
### Discriminator1d
```py
import torch
from audio_encoders_pytorch import Discriminator1d

discriminator = Discriminator1d(
    in_channels=2,              # Number of input channels
    channels=32,                # Number of base channels
    multipliers=[1, 1, 2, 2],   # Channel multiplier between layers (i.e. channels * multipliers[i] -> channels * multipliers[i+1])
    factors=[8, 8, 8],          # Downsampling factor per layer
    num_blocks=[2, 2, 2],       # Number of resnet blocks per layer
    use_loss=[True, True, True] # Whether to use this layer as GAN loss
)

wave_true = torch.randn(1, 2, 2**18)
wave_fake = torch.randn(1, 2, 2**18)

loss_generator, loss_discriminator = discriminator(wave_true, wave_fake)
# tensor(0.613949, grad_fn=<MeanBackward0>) tensor(0.097330, grad_fn=<MeanBackward0>)
```
--------------------------------------------------------------------------------
/audio_encoders_pytorch/__init__.py:
--------------------------------------------------------------------------------
from .modules import (
    STFT,
    AutoEncoder1d,
    BitcodesBottleneck,
    Bottleneck,
    Decoder1d,
    Discriminator1d,
    Encoder1d,
    MAE1d,
    ME1d,
    MelE1d,
    NoiserBottleneck,
    TanhBottleneck,
    VariationalBottleneck,
)
--------------------------------------------------------------------------------
/audio_encoders_pytorch/modules.py:
--------------------------------------------------------------------------------
from math import floor
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, reduce, unpack
from einops_exts import rearrange_many
from torch import Tensor
from torchaudio import transforms

from .utils import closest_power_2, default, exists, groupby, prefix_dict, prod, to_list

"""
Convolutional Modules
"""


def Conv1d(*args, **kwargs) -> nn.Module:
    return nn.Conv1d(*args, **kwargs)


def ConvTranspose1d(*args, **kwargs) -> nn.Module:
    return nn.ConvTranspose1d(*args, **kwargs)


def Downsample1d(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    return Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=factor * kernel_multiplier + 1,
        stride=factor,
        padding=factor * (kernel_multiplier // 2),
    )


def Upsample1d(in_channels: int, out_channels: int, factor: int) -> nn.Module:
    if factor == 1:
        return Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
        )
    return ConvTranspose1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=factor * 2,
        stride=factor,
        padding=factor // 2 + factor % 2,
        output_padding=factor % 2,
    )


class ConvBlock1d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        dilation: int = 1,
        num_groups: int = 8,
        use_norm: bool = True,
    ) -> None:
        super().__init__()

        self.groupnorm = (
            nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
            if use_norm
            else nn.Identity()
        )
        self.activation = nn.SiLU()
        self.project = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.groupnorm(x)
        x = self.activation(x)
        return
self.project(x) 90 | 91 | 92 | class ResnetBlock1d(nn.Module): 93 | def __init__( 94 | self, 95 | in_channels: int, 96 | out_channels: int, 97 | *, 98 | kernel_size: int = 3, 99 | stride: int = 1, 100 | padding: int = 1, 101 | dilation: int = 1, 102 | use_norm: bool = True, 103 | num_groups: int = 8, 104 | ) -> None: 105 | super().__init__() 106 | 107 | self.block1 = ConvBlock1d( 108 | in_channels=in_channels, 109 | out_channels=out_channels, 110 | kernel_size=kernel_size, 111 | stride=stride, 112 | padding=padding, 113 | dilation=dilation, 114 | use_norm=use_norm, 115 | num_groups=num_groups, 116 | ) 117 | 118 | self.block2 = ConvBlock1d( 119 | in_channels=out_channels, 120 | out_channels=out_channels, 121 | use_norm=use_norm, 122 | num_groups=num_groups, 123 | ) 124 | 125 | self.to_out = ( 126 | Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) 127 | if in_channels != out_channels 128 | else nn.Identity() 129 | ) 130 | 131 | def forward(self, x: Tensor) -> Tensor: 132 | h = self.block1(x) 133 | h = self.block2(h) 134 | return h + self.to_out(x) 135 | 136 | 137 | class Patcher(nn.Module): 138 | def __init__(self, in_channels: int, out_channels: int, patch_size: int): 139 | super().__init__() 140 | assert_message = f"out_channels must be divisible by patch_size ({patch_size})" 141 | assert out_channels % patch_size == 0, assert_message 142 | self.patch_size = patch_size 143 | 144 | self.block = ResnetBlock1d( 145 | in_channels=in_channels, 146 | out_channels=out_channels // patch_size, 147 | num_groups=min(patch_size, in_channels), 148 | ) 149 | 150 | def forward(self, x: Tensor) -> Tensor: 151 | x = self.block(x) 152 | x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) 153 | return x 154 | 155 | 156 | class Unpatcher(nn.Module): 157 | def __init__(self, in_channels: int, out_channels: int, patch_size: int): 158 | super().__init__() 159 | assert_message = f"in_channels must be divisible by patch_size ({patch_size})" 160 | assert in_channels % patch_size == 0, assert_message 161 | self.patch_size = patch_size 162 | 163 | self.block = ResnetBlock1d( 164 | in_channels=in_channels // patch_size, 165 | out_channels=out_channels, 166 | num_groups=min(patch_size, out_channels), 167 | ) 168 | 169 | def forward(self, x: Tensor) -> Tensor: 170 | x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) 171 | x = self.block(x) 172 | return x 173 | 174 | 175 | class DownsampleBlock1d(nn.Module): 176 | def __init__( 177 | self, 178 | in_channels: int, 179 | out_channels: int, 180 | *, 181 | factor: int, 182 | num_groups: int, 183 | num_layers: int, 184 | ): 185 | super().__init__() 186 | 187 | self.downsample = Downsample1d( 188 | in_channels=in_channels, out_channels=out_channels, factor=factor 189 | ) 190 | 191 | self.blocks = nn.ModuleList( 192 | [ 193 | ResnetBlock1d( 194 | in_channels=out_channels, 195 | out_channels=out_channels, 196 | num_groups=num_groups, 197 | ) 198 | for i in range(num_layers) 199 | ] 200 | ) 201 | 202 | def forward(self, x: Tensor) -> Tensor: 203 | x = self.downsample(x) 204 | for block in self.blocks: 205 | x = block(x) 206 | return x 207 | 208 | 209 | class UpsampleBlock1d(nn.Module): 210 | def __init__( 211 | self, 212 | in_channels: int, 213 | out_channels: int, 214 | *, 215 | factor: int, 216 | num_layers: int, 217 | num_groups: int, 218 | ): 219 | super().__init__() 220 | 221 | self.blocks = nn.ModuleList( 222 | [ 223 | ResnetBlock1d( 224 | in_channels=in_channels, 225 | out_channels=in_channels, 226 | num_groups=num_groups, 
227 | ) 228 | for _ in range(num_layers) 229 | ] 230 | ) 231 | 232 | self.upsample = Upsample1d( 233 | in_channels=in_channels, out_channels=out_channels, factor=factor 234 | ) 235 | 236 | def forward(self, x: Tensor) -> Tensor: 237 | for block in self.blocks: 238 | x = block(x) 239 | x = self.upsample(x) 240 | return x 241 | 242 | 243 | """ 244 | Encoders / Decoders 245 | """ 246 | 247 | 248 | class Bottleneck(nn.Module): 249 | def forward( 250 | self, x: Tensor, with_info: bool = False 251 | ) -> Union[Tensor, Tuple[Tensor, Any]]: 252 | raise NotImplementedError() 253 | 254 | 255 | class Encoder1d(nn.Module): 256 | def __init__( 257 | self, 258 | in_channels: int, 259 | channels: int, 260 | multipliers: Sequence[int], 261 | factors: Sequence[int], 262 | num_blocks: Sequence[int], 263 | patch_size: int = 1, 264 | resnet_groups: int = 8, 265 | out_channels: Optional[int] = None, 266 | bottleneck: Union[Bottleneck, List[Bottleneck]] = [], 267 | ): 268 | super().__init__() 269 | self.bottlenecks = nn.ModuleList(to_list(bottleneck)) 270 | self.num_layers = len(multipliers) - 1 271 | self.downsample_factor = patch_size * prod(factors) 272 | self.out_channels = ( 273 | out_channels if exists(out_channels) else channels * multipliers[-1] 274 | ) 275 | assert len(factors) == self.num_layers and len(num_blocks) == self.num_layers 276 | 277 | self.to_in = Patcher( 278 | in_channels=in_channels, 279 | out_channels=channels * multipliers[0], 280 | patch_size=patch_size, 281 | ) 282 | 283 | self.downsamples = nn.ModuleList( 284 | [ 285 | DownsampleBlock1d( 286 | in_channels=channels * multipliers[i], 287 | out_channels=channels * multipliers[i + 1], 288 | factor=factors[i], 289 | num_groups=resnet_groups, 290 | num_layers=num_blocks[i], 291 | ) 292 | for i in range(self.num_layers) 293 | ] 294 | ) 295 | 296 | self.to_out = ( 297 | nn.Conv1d( 298 | in_channels=channels * multipliers[-1], 299 | out_channels=out_channels, 300 | kernel_size=1, 301 | ) 302 | if exists(out_channels) 303 | else nn.Identity() 304 | ) 305 | 306 | def forward( 307 | self, x: Tensor, with_info: bool = False 308 | ) -> Union[Tensor, Tuple[Tensor, Any]]: 309 | xs = [x] 310 | x = self.to_in(x) 311 | xs += [x] 312 | 313 | for downsample in self.downsamples: 314 | x = downsample(x) 315 | xs += [x] 316 | 317 | x = self.to_out(x) 318 | xs += [x] 319 | info = dict(xs=xs) 320 | 321 | for bottleneck in self.bottlenecks: 322 | x, info_bottleneck = bottleneck(x, with_info=True) 323 | info = {**info, **prefix_dict("bottleneck_", info_bottleneck)} 324 | 325 | return (x, info) if with_info else x 326 | 327 | 328 | class Decoder1d(nn.Module): 329 | def __init__( 330 | self, 331 | out_channels: int, 332 | channels: int, 333 | multipliers: Sequence[int], 334 | factors: Sequence[int], 335 | num_blocks: Sequence[int], 336 | patch_size: int = 1, 337 | resnet_groups: int = 8, 338 | in_channels: Optional[int] = None, 339 | ): 340 | super().__init__() 341 | num_layers = len(multipliers) - 1 342 | 343 | assert len(factors) == num_layers and len(num_blocks) == num_layers 344 | 345 | self.to_in = ( 346 | Conv1d( 347 | in_channels=in_channels, 348 | out_channels=channels * multipliers[0], 349 | kernel_size=1, 350 | ) 351 | if exists(in_channels) 352 | else nn.Identity() 353 | ) 354 | 355 | self.upsamples = nn.ModuleList( 356 | [ 357 | UpsampleBlock1d( 358 | in_channels=channels * multipliers[i], 359 | out_channels=channels * multipliers[i + 1], 360 | factor=factors[i], 361 | num_groups=resnet_groups, 362 | num_layers=num_blocks[i], 363 | ) 364 | for i in 
range(num_layers) 365 | ] 366 | ) 367 | 368 | self.to_out = Unpatcher( 369 | in_channels=channels * multipliers[-1], 370 | out_channels=out_channels, 371 | patch_size=patch_size, 372 | ) 373 | 374 | def forward( 375 | self, x: Tensor, with_info: bool = False 376 | ) -> Union[Tensor, Tuple[Tensor, Any]]: 377 | xs = [x] 378 | x = self.to_in(x) 379 | xs += [x] 380 | 381 | for upsample in self.upsamples: 382 | x = upsample(x) 383 | xs += [x] 384 | 385 | x = self.to_out(x) 386 | xs += [x] 387 | 388 | info = dict(xs=xs) 389 | return (x, info) if with_info else x 390 | 391 | 392 | class AutoEncoder1d(nn.Module): 393 | def __init__( 394 | self, 395 | in_channels: int, 396 | channels: int, 397 | multipliers: Sequence[int], 398 | factors: Sequence[int], 399 | num_blocks: Sequence[int], 400 | patch_size: int = 1, 401 | resnet_groups: int = 8, 402 | out_channels: Optional[int] = None, 403 | bottleneck: Union[Bottleneck, List[Bottleneck]] = [], 404 | bottleneck_channels: Optional[int] = None, 405 | ): 406 | super().__init__() 407 | out_channels = default(out_channels, in_channels) 408 | 409 | self.encoder = Encoder1d( 410 | in_channels=in_channels, 411 | out_channels=bottleneck_channels, 412 | channels=channels, 413 | multipliers=multipliers, 414 | factors=factors, 415 | num_blocks=num_blocks, 416 | patch_size=patch_size, 417 | resnet_groups=resnet_groups, 418 | bottleneck=bottleneck, 419 | ) 420 | 421 | self.decoder = Decoder1d( 422 | in_channels=bottleneck_channels, 423 | out_channels=out_channels, 424 | channels=channels, 425 | multipliers=multipliers[::-1], 426 | factors=factors[::-1], 427 | num_blocks=num_blocks[::-1], 428 | patch_size=patch_size, 429 | resnet_groups=resnet_groups, 430 | ) 431 | 432 | def forward( 433 | self, x: Tensor, with_info: bool = False 434 | ) -> Union[Tensor, Tuple[Tensor, Any]]: 435 | z, info_encoder = self.encode(x, with_info=True) 436 | y, info_decoder = self.decode(z, with_info=True) 437 | info = { 438 | **dict(latent=z), 439 | **prefix_dict("encoder_", info_encoder), 440 | **prefix_dict("decoder_", info_decoder), 441 | } 442 | return (y, info) if with_info else y 443 | 444 | def encode( 445 | self, x: Tensor, with_info: bool = False 446 | ) -> Union[Tensor, Tuple[Tensor, Any]]: 447 | return self.encoder(x, with_info=with_info) 448 | 449 | def decode(self, x: Tensor, with_info: bool = False) -> Tensor: 450 | return self.decoder(x, with_info=with_info) 451 | 452 | 453 | class STFT(nn.Module): 454 | """Helper for torch stft and istft""" 455 | 456 | def __init__( 457 | self, 458 | num_fft: int = 1023, 459 | hop_length: int = 256, 460 | window_length: Optional[int] = None, 461 | length: Optional[int] = None, 462 | use_complex: bool = False, 463 | ): 464 | super().__init__() 465 | self.num_fft = num_fft 466 | self.hop_length = default(hop_length, floor(num_fft // 4)) 467 | self.window_length = default(window_length, num_fft) 468 | self.length = length 469 | self.register_buffer("window", torch.hann_window(self.window_length)) 470 | self.use_complex = use_complex 471 | 472 | def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: 473 | b = wave.shape[0] 474 | wave = rearrange(wave, "b c t -> (b c) t") 475 | 476 | stft = torch.stft( 477 | wave, 478 | n_fft=self.num_fft, 479 | hop_length=self.hop_length, 480 | win_length=self.window_length, 481 | window=self.window, # type: ignore 482 | return_complex=True, 483 | normalized=True, 484 | ) 485 | 486 | if self.use_complex: 487 | # Returns real and imaginary 488 | stft_a, stft_b = stft.real, stft.imag 489 | else: 490 | # Returns 
magnitude and phase matrices 491 | magnitude, phase = torch.abs(stft), torch.angle(stft) 492 | stft_a, stft_b = magnitude, phase 493 | 494 | return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) 495 | 496 | def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: 497 | b, l = stft_a.shape[0], stft_a.shape[-1] # noqa 498 | length = closest_power_2(l * self.hop_length) 499 | 500 | stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") 501 | 502 | if self.use_complex: 503 | real, imag = stft_a, stft_b 504 | else: 505 | magnitude, phase = stft_a, stft_b 506 | real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) 507 | 508 | stft = torch.stack([real, imag], dim=-1) 509 | 510 | wave = torch.istft( 511 | stft, 512 | n_fft=self.num_fft, 513 | hop_length=self.hop_length, 514 | win_length=self.window_length, 515 | window=self.window, # type: ignore 516 | length=default(self.length, length), 517 | normalized=True, 518 | ) 519 | 520 | return rearrange(wave, "(b c) t -> b c t", b=b) 521 | 522 | def encode1d( 523 | self, wave: Tensor, stacked: bool = True 524 | ) -> Union[Tensor, Tuple[Tensor, Tensor]]: 525 | stft_a, stft_b = self.encode(wave) 526 | stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") 527 | return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) 528 | 529 | def decode1d(self, stft_pair: Tensor) -> Tensor: 530 | f = self.num_fft // 2 + 1 531 | stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) 532 | stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) 533 | return self.decode(stft_a, stft_b) 534 | 535 | 536 | class ME1d(Encoder1d): 537 | """Magnitude Encoder""" 538 | 539 | def __init__( 540 | self, in_channels: int, stft_num_fft: int, use_log: bool = False, **kwargs 541 | ): 542 | self.use_log = use_log 543 | self.frequency_channels = stft_num_fft // 2 + 1 544 | stft_kwargs, kwargs = groupby("stft_", kwargs) 545 | super().__init__(in_channels=in_channels * self.frequency_channels, **kwargs) 546 | self.stft = STFT(num_fft=stft_num_fft, **stft_kwargs) 547 | self.downsample_factor *= self.stft.hop_length 548 | 549 | def forward(self, x: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Any]]: # type: ignore # noqa 550 | magnitude, _ = self.stft.encode(x) 551 | magnitude = rearrange(magnitude, "b c f l -> b (c f) l") 552 | magnitude = torch.log(magnitude) if self.use_log else magnitude 553 | return super().forward(magnitude, **kwargs) 554 | 555 | 556 | class MAE1d(AutoEncoder1d): 557 | """Magnitude Auto Encoder""" 558 | 559 | def __init__(self, in_channels: int, stft_num_fft: int = 1023, **kwargs): 560 | self.frequency_channels = stft_num_fft // 2 + 1 561 | stft_kwargs, kwargs = groupby("stft_", kwargs) 562 | super().__init__(in_channels=in_channels * self.frequency_channels, **kwargs) 563 | self.stft = STFT(num_fft=stft_num_fft, **stft_kwargs) 564 | 565 | def encode(self, magnitude: Tensor, **kwargs): # type: ignore 566 | log_magnitude = torch.log(magnitude) 567 | log_magnitude_flat = rearrange(log_magnitude, "b c f l -> b (c f) l") 568 | return super().encode(log_magnitude_flat, **kwargs) 569 | 570 | def decode( # type: ignore 571 | self, latent: Tensor, with_info: bool = False 572 | ) -> Union[Tensor, Tuple[Tensor, Dict]]: 573 | f = self.frequency_channels 574 | log_magnitude_flat, info = super().decode(latent, with_info=True) 575 | log_magnitude = rearrange(log_magnitude_flat, "b (c f) l -> b c f l", f=f) 576 | log_magnitude = torch.clamp(log_magnitude, -30.0, 20.0) 577 
        magnitude = torch.exp(log_magnitude)
        info = dict(log_magnitude=log_magnitude, **info)
        return (magnitude, info) if with_info else magnitude

    def loss(
        self, wave: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Dict]]:
        magnitude, _ = self.stft.encode(wave)
        magnitude_pred, info = self(magnitude, with_info=True)
        loss = F.l1_loss(torch.log(magnitude), torch.log(magnitude_pred))
        return (loss, info) if with_info else loss


class MelSpectrogram(nn.Module):
    def __init__(
        self,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        sample_rate: int = 48000,
        n_mel_channels: int = 80,
        center: bool = False,
        normalize: bool = False,
        normalize_log: bool = False,
    ):
        super().__init__()
        self.padding = (n_fft - hop_length) // 2
        self.normalize = normalize
        self.normalize_log = normalize_log
        self.hop_length = hop_length

        self.to_spectrogram = transforms.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=center,
            power=None,
        )

        self.to_mel_scale = transforms.MelScale(
            n_mels=n_mel_channels, n_stft=n_fft // 2 + 1, sample_rate=sample_rate
        )

    def forward(self, waveform: Tensor) -> Tensor:
        # Pack non-time dimension
        waveform, ps = pack([waveform], "* t")
        # Pad waveform
        waveform = F.pad(waveform, [self.padding] * 2, mode="reflect")
        # Compute STFT
        spectrogram = self.to_spectrogram(waveform)
        # Compute magnitude
        spectrogram = torch.abs(spectrogram)
        # Convert to mel scale
        mel_spectrogram = self.to_mel_scale(spectrogram)
        # Normalize
        if self.normalize:
            mel_spectrogram = mel_spectrogram / torch.max(mel_spectrogram)
            mel_spectrogram = 2 * torch.pow(mel_spectrogram, 0.25) - 1
        if self.normalize_log:
            mel_spectrogram = torch.log(torch.clamp(mel_spectrogram, min=1e-5))
        # Unpack non-spectrogram dimension
        return unpack(mel_spectrogram, ps, "* f l")[0]


class MelE1d(Encoder1d):
    """Mel Spectrogram Encoder"""

    def __init__(self, in_channels: int, mel_channels: int, **kwargs):
        mel_kwargs, kwargs = groupby("mel_", kwargs)
        super().__init__(in_channels=in_channels * mel_channels, **kwargs)
        self.mel = MelSpectrogram(n_mel_channels=mel_channels, **mel_kwargs)
        self.downsample_factor *= self.mel.hop_length

    def forward(self, x: Tensor, **kwargs) -> Union[Tensor, Tuple[Tensor, Any]]:  # type: ignore # noqa
        mel = rearrange(self.mel(x), "b c f l -> b (c f) l")
        return super().forward(mel, **kwargs)


"""
Bottlenecks
"""


def gaussian_sample(mean: Tensor, logvar: Tensor) -> Tensor:
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    sample = mean + std * eps
    return sample


def kl_loss(mean: Tensor, logvar: Tensor) -> Tensor:
    losses = mean**2 + logvar.exp() - logvar - 1
    loss = reduce(losses, "b ... -> 1", "mean")
    return loss
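# Example (illustrative sketch, not part of the original module): plugging a
# VariationalBottleneck into an AutoEncoder1d and reading the KL term back out
# of `info`. The bottleneck `channels` must match the encoder output width,
# i.e. channels * multipliers[-1] = 64 below, since no `bottleneck_channels`
# override is given.
#
#   autoencoder = AutoEncoder1d(
#       in_channels=2,
#       channels=32,
#       multipliers=[1, 1, 2, 2],
#       factors=[4, 4, 4],
#       num_blocks=[2, 2, 2],
#       bottleneck=VariationalBottleneck(channels=64),
#   )
#   x = torch.randn(1, 2, 2**18)
#   y, info = autoencoder(x, with_info=True)
#   # Encoder1d prefixes bottleneck info with "bottleneck_", and AutoEncoder1d
#   # prefixes encoder info with "encoder_":
#   kl = info["encoder_bottleneck_variational_kl_loss"]
#   loss = F.l1_loss(y, x) + kl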
class VariationalBottleneck(Bottleneck):
    def __init__(self, channels: int, loss_weight: float = 1.0):
        super().__init__()
        self.loss_weight = loss_weight
        self.to_mean_and_std = Conv1d(
            in_channels=channels,
            out_channels=channels * 2,
            kernel_size=1,
        )

    def forward(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        mean_and_std = self.to_mean_and_std(x)
        mean, std = mean_and_std.chunk(chunks=2, dim=1)
        mean = torch.tanh(mean)  # mean in range [-1, 1]
        std = torch.tanh(std) + 1.0  # std in range [0, 2]
        out = gaussian_sample(mean, std)
        info = dict(
            variational_kl_loss=kl_loss(mean, std) * self.loss_weight,
            variational_mean=mean,
            variational_std=std,
        )
        return (out, info) if with_info else out


class TanhBottleneck(Bottleneck):
    def forward(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        x = torch.tanh(x)
        info: Dict = dict()
        return (x, info) if with_info else x


class NoiserBottleneck(Bottleneck):
    def __init__(self, sigma: float = 1.0):
        super().__init__()
        self.sigma = sigma

    def forward(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        if self.training:
            x = torch.randn_like(x) * self.sigma + x
        info: Dict = dict()
        return (x, info) if with_info else x


class BitcodesBottleneck(Bottleneck):
    def __init__(self, channels: int, num_bits: int, temperature: float = 1.0):
        super().__init__()
        from bitcodes_pytorch import Bitcodes

        self.bitcodes = Bitcodes(
            features=channels, num_bits=num_bits, temperature=temperature
        )

    def forward(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        x = rearrange(x, "b c t -> b t c")
        x, bits = self.bitcodes(x)
        x = rearrange(x, "b t c -> b c t")
        info: Dict = dict(bits=bits)
        return (x, info) if with_info else x


"""
Discriminators
"""


class Discriminator1d(nn.Module):
    def __init__(self, use_loss: Optional[Sequence[bool]] = None, **kwargs):
        super().__init__()
        self.discriminator = Encoder1d(**kwargs)
        num_layers = self.discriminator.num_layers
        # By default we activate discrimination loss extraction on all layers
        self.use_loss = default(use_loss, [True] * num_layers)
        # Check correct length
        msg = f"use_loss length must match the number of layers ({num_layers})"
        assert len(self.use_loss) == num_layers, msg

    def forward(
        self, true: Tensor, fake: Tensor, with_info: bool = False
    ) -> Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Dict]]:
        # Get discriminator outputs for true/fake scores
        _, info_true = self.discriminator(true, with_info=True)
        _, info_fake = self.discriminator(fake, with_info=True)

        # Get all intermediate layer features (ignore input)
        xs_true = info_true["xs"][1:]
        xs_fake = info_fake["xs"][1:]

        loss_gs, loss_ds, scores_true, scores_fake = [], [], [], []

        for use_loss, x_true, x_fake in zip(self.use_loss, xs_true, xs_fake):
            if use_loss:
                # Half the channels are used for scores, the other for features
                score_true, feat_true = x_true.chunk(chunks=2, dim=1)
                score_fake, feat_fake = x_fake.chunk(chunks=2, dim=1)
                # Generator must match features with true sample and fool discriminator
                loss_gs += [F.l1_loss(feat_true, feat_fake) - score_fake.mean()]
                # Discriminator must give high score to true samples, low to fake
                loss_ds += [((1 - score_true).relu() + (1 + score_fake).relu()).mean()]
                # Save scores
                scores_true += [score_true.mean()]
                scores_fake += [score_fake.mean()]

        # Average all generator/discriminator losses over all layers
        loss_g = torch.stack(loss_gs).mean()
        loss_d = torch.stack(loss_ds).mean()

        info = dict(scores_true=scores_true, scores_fake=scores_fake)

        return (loss_g, loss_d, info) if with_info else (loss_g, loss_d)
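# Example (illustrative, not part of the original module): the STFT helper's
# encode/decode round trip, as used by MAE1d and ME1d. Frame count follows
# torch.stft's centered framing (1 + t / hop_length).
#
#   stft = STFT(num_fft=1023, hop_length=256)
#   magnitude, phase = stft.encode(torch.randn(1, 2, 2**16))  # [1, 2, 512, 257] each
#   wave = stft.decode(magnitude, phase)                      # [1, 2, 65536]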
--------------------------------------------------------------------------------
/audio_encoders_pytorch/pipelines.py:
--------------------------------------------------------------------------------
from typing import Dict, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from .modules import AutoEncoder1d


class StackedPipeline(nn.Module):
    def __init__(
        self,
        autoencoders: Sequence[AutoEncoder1d],
        num_stage_steps: Sequence[int],
        use_inner_loss: bool = False,
    ):
        super().__init__()
        assert_message = "len(num_stage_steps)+1 must equal len(autoencoders)"
        assert len(autoencoders) == len(num_stage_steps) + 1, assert_message

        self.autoencoders = nn.ModuleList(autoencoders)
        self.num_stage_steps = num_stage_steps
        self.use_inner_loss = use_inner_loss
        self.register_buffer("step_id", torch.tensor(0))
        self.register_buffer("stage_id", torch.tensor(0))

        # Init multi resolution stft loss
        import auraloss

        scales = [2048, 1024, 512, 256, 128]
        hop_sizes, win_lengths, overlap = [], [], 0.75
        for scale in scales:
            hop_sizes += [int(scale * (1.0 - overlap))]
            win_lengths += [scale]
        self.loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
            fft_sizes=scales, hop_sizes=hop_sizes, win_lengths=win_lengths
        )

    def step(self):
        # Check if next pipeline stage has to be activated
        for i, step in enumerate(self.num_stage_steps):
            if self.step_id == step:
                self.stage_id += 1
                self.stage_changed()
                print(f"Stage {self.stage_id-1} completed.")
        self.step_id += 1

    def stage_changed(self) -> None:
        num_stages = len(self.num_stage_steps)
        for i in range(self.stage_id):  # type: ignore
            self.autoencoders[i].requires_grad_(False)
            self.autoencoders[i].eval()
        for i in range(self.stage_id, num_stages):  # type: ignore
            self.autoencoders[i].requires_grad_(True)
            self.autoencoders[i].train()

    def encode(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Dict]]:
        info: Dict = dict(encoders=[])
        for i in range(self.stage_id + 1):  # type: ignore
            x, info_encoder = self.autoencoders[i].encode(x, with_info=True)
            info["encoders"] += [info_encoder]
        return (x, info) if with_info else x

    def decode(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Dict]]:
        info: Dict = dict(decoders=[])
        for i in reversed(range(self.stage_id + 1)):  # type: ignore
            x, info_decoder = self.autoencoders[i].decode(x, with_info=True)
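# Example (illustrative sketch, not from the repository): a two-stage pipeline.
# The stage configurations and step count below are made up for illustration,
# and `auraloss` must be installed. Stage 0 trains first on the multi-resolution
# STFT loss; stage 1 activates (and stage 0 is frozen) after `num_stage_steps`.
#
#   from audio_encoders_pytorch import AutoEncoder1d
#   from audio_encoders_pytorch.pipelines import StackedPipeline
#
#   stage0 = AutoEncoder1d(in_channels=2, channels=32, multipliers=[1, 1, 2, 2],
#                          factors=[4, 4, 4], num_blocks=[2, 2, 2])
#   stage1 = AutoEncoder1d(in_channels=64, channels=64, multipliers=[1, 2],
#                          factors=[2], num_blocks=[2])  # takes stage0's 64-ch latent
#
#   pipeline = StackedPipeline(autoencoders=[stage0, stage1], num_stage_steps=[100_000])
#
#   x = torch.randn(1, 2, 2**18)
#   loss = pipeline(x)
#   loss.backward()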
            info["decoders"] += [info_decoder]
        return (x, info) if with_info else x

    def forward(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Dict]]:
        if self.training:
            self.step()

        z, info_encoders = self.encode(x, with_info=True)
        y, info_decoders = self.decode(z, with_info=True)
        info = dict(**info_encoders, **info_decoders, latent=z)

        if self.use_inner_loss and self.stage_id > 0:  # type: ignore
            inner_input = info["encoders"][-1]["xs"][0]
            inner_output = info["decoders"][0]["xs"][-1]
            loss = F.mse_loss(inner_input, inner_output)
        else:
            loss = self.loss_fn(x, y)

        return (loss, info) if with_info else loss
--------------------------------------------------------------------------------
/audio_encoders_pytorch/utils.py:
--------------------------------------------------------------------------------
from functools import reduce
from inspect import isfunction
from math import ceil, floor, log2
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

from typing_extensions import TypeGuard

T = TypeVar("T")


def exists(val: Optional[T]) -> TypeGuard[T]:
    return val is not None


def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    if exists(val):
        return val
    return d() if isfunction(d) else d


def to_list(val: Union[T, Sequence[T]]) -> List[T]:
    if isinstance(val, tuple):
        return list(val)
    if isinstance(val, list):
        return val
    return [val]  # type: ignore


def closest_power_2(x: float) -> int:
    exponent = log2(x)
    distance_fn = lambda z: abs(x - 2**z)  # noqa
    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
    return 2 ** int(exponent_closest)


def prod(vals: Sequence[int]) -> int:
    return reduce(lambda x, y: x * y, vals)


"""
Kwargs Utils
"""


def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    return_dicts: Tuple[Dict, Dict] = ({}, {})
    for key in d.keys():
        no_prefix = int(not key.startswith(prefix))
        return_dicts[no_prefix][key] = d[key]
    return return_dicts


def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
    if keep_prefix:
        return kwargs_with_prefix, kwargs
    kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
    return kwargs_no_prefix, kwargs


def prefix_dict(prefix: str, d: Dict) -> Dict:
    return {prefix + str(k): v for k, v in d.items()}
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import find_packages, setup

setup(
    name="audio-encoders-pytorch",
    packages=find_packages(exclude=[]),
    version="0.0.22",
    license="MIT",
    description="Audio Encoders - PyTorch",
    long_description_content_type="text/markdown",
    author="Flavio Schneider",
    author_email="archinetai@protonmail.com",
    url="https://github.com/archinetai/audio-encoders-pytorch",
    keywords=["artificial intelligence", "deep learning", "audio"],
    install_requires=[
        "torch>=1.6",
        "torchaudio",
        "data-science-types>=0.2",
        "einops>=0.6",
        "einops-exts>=0.0.3",
    ],
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.6",
    ],
)
--------------------------------------------------------------------------------