├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── bs-roformer.png ├── bs_roformer ├── __init__.py ├── attend.py ├── bs_roformer.py └── mel_band_roformer.py └── setup.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## BS-RoFormer
4 | 
5 | Implementation of Band Split RoFormer, the SOTA attention network for music source separation out of ByteDance AI Labs. They beat the previous first place by a large margin. The technique uses axial attention across frequency (hence multi-band) and time. Their experiments also show that rotary positional encoding led to a big improvement over learned absolute positions.
6 | 
7 | It also includes support for stereo training and outputting multiple stems (see the example after the Usage section below).
8 | 
9 | Please join us on Discord if you are interested in replicating a SOTA music source separator out in the open.
10 | 
11 | Update: This paper has been replicated by Roman and the weights open sourced here
12 | 
13 | Update 2: Used for this Katy Perry remix!
14 | 
15 | Update 3: Kimberley Jensen has open sourced a MelBand Roformer trained on vocals here!
16 | 
17 | ## Appreciation
18 | 
19 | - StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
20 | 
21 | - Roee and Fabian-Robert for sharing their audio expertise and fixing the audio hyperparameters
22 | 
23 | - @chenht2010 and Roman for working out the default band-splitting hyperparameters!
24 | 
25 | - Max Prod for reporting a major bug with Mel-Band Roformer stereo training!
26 | 
27 | - Roman for successfully training the model and open sourcing his training code and weights at this repository!
28 | 
29 | - Christopher for fixing an issue with multiple stems in Mel-Band Roformer
30 | 
31 | - Iver Jordal for identifying that the default stft window function was not correct
32 | 
33 | ## Install
34 | 
35 | ```bash
36 | $ pip install BS-RoFormer
37 | ```
38 | 
39 | ## Usage
40 | 
41 | ```python
42 | import torch
43 | from bs_roformer import BSRoformer
44 | 
45 | model = BSRoformer(
46 |     dim = 512,
47 |     depth = 12,
48 |     time_transformer_depth = 1,
49 |     freq_transformer_depth = 1
50 | )
51 | 
52 | x = torch.randn(2, 352800)
53 | target = torch.randn(2, 352800)
54 | 
55 | loss = model(x, target = target)
56 | loss.backward()
57 | 
58 | # after much training
59 | 
60 | out = model(x)
61 | ```
62 | 
63 | To use the Mel-Band Roformer proposed in a recent follow-up paper, simply import `MelBandRoformer` instead
64 | 
65 | ```python
66 | import torch
67 | from bs_roformer import MelBandRoformer
68 | 
69 | model = MelBandRoformer(
70 |     dim = 32,
71 |     depth = 1,
72 |     time_transformer_depth = 1,
73 |     freq_transformer_depth = 1
74 | )
75 | 
76 | x = torch.randn(2, 352800)
77 | target = torch.randn(2, 352800)
78 | 
79 | loss = model(x, target = target)
80 | loss.backward()
81 | 
82 | # after much training
83 | 
84 | out = model(x)
85 | ```
86 | 
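For stereo audio and multiple stems, set `stereo = True` and `num_stems` accordingly. A minimal sketch (the stem count of 4 here is only for illustration; any `num_stems` works):

```python
import torch
from bs_roformer import BSRoformer

model = BSRoformer(
    dim = 512,
    depth = 12,
    stereo = True,    # expects (batch, 2, time) input
    num_stems = 4,    # e.g. drums / bass / vocals / other
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)

x = torch.randn(2, 2, 352800)          # stereo mixtures
target = torch.randn(2, 4, 2, 352800)  # one stereo target per stem

loss = model(x, target = target)
loss.backward()

# after much training

stems = model(x)  # (batch, stems, channels, time)
```

87 | ## Todo
88 | 
89 | - [x] get the multiscale stft loss in there
90 | - [x] figure out what `n_fft` should be
91 | - [x] review band split + mask estimation modules
92 | 
93 | ## Citations
94 | 
95 | ```bibtex
96 | @inproceedings{Lu2023MusicSS,
97 |     title  = {Music Source Separation with Band-Split RoPE Transformer},
98 |     author = {Wei-Tsung Lu and Ju-Chiang Wang and Qiuqiang Kong and Yun-Ning Hung},
99 |     year   = {2023},
100 |     url    = {https://api.semanticscholar.org/CorpusID:261556702}
101 | }
102 | ```
103 | 
104 | ```bibtex
105 | @inproceedings{Wang2023MelBandRF,
106 |     title  = {Mel-Band RoFormer for Music Source Separation},
107 |     author = {Ju-Chiang Wang and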
Wei-Tsung Lu and Minz Won}, 108 | year = {2023}, 109 | url = {https://api.semanticscholar.org/CorpusID:263608675} 110 | } 111 | ``` 112 | 113 | ```bibtex 114 | @misc{ho2019axial, 115 | title = {Axial Attention in Multidimensional Transformers}, 116 | author = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans}, 117 | year = {2019}, 118 | archivePrefix = {arXiv} 119 | } 120 | ``` 121 | 122 | ```bibtex 123 | @misc{su2021roformer, 124 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 125 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu}, 126 | year = {2021}, 127 | eprint = {2104.09864}, 128 | archivePrefix = {arXiv}, 129 | primaryClass = {cs.CL} 130 | } 131 | ``` 132 | 133 | ```bibtex 134 | @inproceedings{dao2022flashattention, 135 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, 136 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, 137 | booktitle = {Advances in Neural Information Processing Systems}, 138 | year = {2022} 139 | } 140 | ``` 141 | 142 | ```bibtex 143 | @article{Bondarenko2023QuantizableTR, 144 | title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing}, 145 | author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort}, 146 | journal = {ArXiv}, 147 | year = {2023}, 148 | volume = {abs/2306.12929}, 149 | url = {https://api.semanticscholar.org/CorpusID:259224568} 150 | } 151 | ``` 152 | 153 | ```bibtex 154 | @inproceedings{ElNouby2021XCiTCI, 155 | title = {XCiT: Cross-Covariance Image Transformers}, 156 | author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou}, 157 | booktitle = {Neural Information Processing Systems}, 158 | year = {2021}, 159 | url = {https://api.semanticscholar.org/CorpusID:235458262} 160 | } 161 | ``` 162 | 163 | ```bibtex 164 | @inproceedings{Zhou2024ValueRL, 165 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers}, 166 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan}, 167 | year = {2024}, 168 | url = {https://api.semanticscholar.org/CorpusID:273532030} 169 | } 170 | ``` 171 | -------------------------------------------------------------------------------- /bs-roformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/BS-RoFormer/9c3ef86fed568434cafb4e5a0daecb2f5f2ab6f5/bs-roformer.png -------------------------------------------------------------------------------- /bs_roformer/__init__.py: -------------------------------------------------------------------------------- 1 | from bs_roformer.bs_roformer import BSRoformer 2 | from bs_roformer.mel_band_roformer import MelBandRoformer 3 | -------------------------------------------------------------------------------- /bs_roformer/attend.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from packaging import version 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torch import nn, einsum 7 | import torch.nn.functional as F 8 | 9 | from einops import rearrange, reduce 10 | 11 | # constants 12 | 13 | FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 
'enable_mem_efficient']) 14 | 15 | # helpers 16 | 17 | def exists(val): 18 | return val is not None 19 | 20 | def default(v, d): 21 | return v if exists(v) else d 22 | 23 | def once(fn): 24 | called = False 25 | @wraps(fn) 26 | def inner(x): 27 | nonlocal called 28 | if called: 29 | return 30 | called = True 31 | return fn(x) 32 | return inner 33 | 34 | print_once = once(print) 35 | 36 | # main class 37 | 38 | class Attend(nn.Module): 39 | def __init__( 40 | self, 41 | dropout = 0., 42 | flash = False, 43 | scale = None 44 | ): 45 | super().__init__() 46 | self.scale = scale 47 | self.dropout = dropout 48 | self.attn_dropout = nn.Dropout(dropout) 49 | 50 | self.flash = flash 51 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 52 | 53 | # determine efficient attention configs for cuda and cpu 54 | 55 | self.cpu_config = FlashAttentionConfig(True, True, True) 56 | self.cuda_config = None 57 | 58 | if not torch.cuda.is_available() or not flash: 59 | return 60 | 61 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 62 | 63 | if device_properties.major == 8 and device_properties.minor == 0: 64 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda') 65 | self.cuda_config = FlashAttentionConfig(True, False, False) 66 | else: 67 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') 68 | self.cuda_config = FlashAttentionConfig(False, True, True) 69 | 70 | def flash_attn(self, q, k, v): 71 | _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device 72 | 73 | if exists(self.scale): 74 | default_scale = q.shape[-1] ** -0.5 75 | q = q * (self.scale / default_scale) 76 | 77 | # Check if there is a compatible device for flash attention 78 | 79 | config = self.cuda_config if is_cuda else self.cpu_config 80 | 81 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale 82 | 83 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 84 | out = F.scaled_dot_product_attention( 85 | q, k, v, 86 | dropout_p = self.dropout if self.training else 0. 
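                # note: any custom softmax scale was already folded into the queries above, since this sdpa call does not pass an explicit scale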
87 | ) 88 | 89 | return out 90 | 91 | def forward(self, q, k, v): 92 | """ 93 | einstein notation 94 | b - batch 95 | h - heads 96 | n, i, j - sequence length (base sequence length, source, target) 97 | d - feature dimension 98 | """ 99 | 100 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device 101 | 102 | scale = default(self.scale, q.shape[-1] ** -0.5) 103 | 104 | if self.flash: 105 | return self.flash_attn(q, k, v) 106 | 107 | # similarity 108 | 109 | sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale 110 | 111 | # attention 112 | 113 | attn = sim.softmax(dim=-1) 114 | attn = self.attn_dropout(attn) 115 | 116 | # aggregate values 117 | 118 | out = einsum(f"b h i j, b h j d -> b h i d", attn, v) 119 | 120 | return out 121 | -------------------------------------------------------------------------------- /bs_roformer/bs_roformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from functools import partial 3 | 4 | import torch 5 | from torch import nn, einsum, Tensor 6 | from torch.nn import Module, ModuleList 7 | import torch.nn.functional as F 8 | 9 | from bs_roformer.attend import Attend 10 | 11 | from beartype.typing import Callable 12 | from beartype import beartype 13 | 14 | from rotary_embedding_torch import RotaryEmbedding 15 | 16 | from einops import rearrange, pack, unpack 17 | 18 | from hyper_connections import get_init_and_expand_reduce_stream_functions 19 | 20 | # helper functions 21 | 22 | def exists(val): 23 | return val is not None 24 | 25 | def default(v, d): 26 | return v if exists(v) else d 27 | 28 | def pack_one(t, pattern): 29 | return pack([t], pattern) 30 | 31 | def unpack_one(t, ps, pattern): 32 | return unpack(t, ps, pattern)[0] 33 | 34 | # norm 35 | 36 | class RMSNorm(Module): 37 | def __init__(self, dim): 38 | super().__init__() 39 | self.scale = dim ** 0.5 40 | self.gamma = nn.Parameter(torch.ones(dim)) 41 | 42 | def forward(self, x): 43 | return F.normalize(x, dim = -1) * self.scale * self.gamma 44 | 45 | # attention 46 | 47 | class FeedForward(Module): 48 | def __init__( 49 | self, 50 | dim, 51 | mult = 4, 52 | dropout = 0. 
53 | ): 54 | super().__init__() 55 | dim_inner = int(dim * mult) 56 | self.net = nn.Sequential( 57 | RMSNorm(dim), 58 | nn.Linear(dim, dim_inner), 59 | nn.GELU(), 60 | nn.Dropout(dropout), 61 | nn.Linear(dim_inner, dim), 62 | nn.Dropout(dropout) 63 | ) 64 | 65 | def forward(self, x): 66 | return self.net(x) 67 | 68 | class Attention(Module): 69 | def __init__( 70 | self, 71 | dim, 72 | heads = 8, 73 | dim_head = 64, 74 | dropout = 0., 75 | rotary_embed = None, 76 | flash = True, 77 | learned_value_residual_mix = False 78 | ): 79 | super().__init__() 80 | self.heads = heads 81 | self.scale = dim_head **-0.5 82 | dim_inner = heads * dim_head 83 | 84 | self.rotary_embed = rotary_embed 85 | 86 | self.attend = Attend(flash = flash, dropout = dropout) 87 | 88 | self.norm = RMSNorm(dim) 89 | self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False) 90 | 91 | self.to_value_residual_mix = nn.Linear(dim, heads) if learned_value_residual_mix else None 92 | 93 | self.to_gates = nn.Linear(dim, heads) 94 | 95 | self.to_out = nn.Sequential( 96 | nn.Linear(dim_inner, dim, bias = False), 97 | nn.Dropout(dropout) 98 | ) 99 | 100 | def forward(self, x, value_residual = None): 101 | x = self.norm(x) 102 | 103 | q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads) 104 | 105 | orig_v = v 106 | 107 | if exists(self.to_value_residual_mix): 108 | mix = self.to_value_residual_mix(x) 109 | mix = rearrange(mix, 'b n h -> b h n 1').sigmoid() 110 | 111 | assert exists(value_residual) 112 | v = v.lerp(value_residual, mix) 113 | 114 | if exists(self.rotary_embed): 115 | q = self.rotary_embed.rotate_queries_or_keys(q) 116 | k = self.rotary_embed.rotate_queries_or_keys(k) 117 | 118 | out = self.attend(q, k, v) 119 | 120 | gates = self.to_gates(x) 121 | out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() 122 | 123 | out = rearrange(out, 'b h n d -> b n (h d)') 124 | 125 | return self.to_out(out), orig_v 126 | 127 | class Transformer(Module): 128 | def __init__( 129 | self, 130 | *, 131 | dim, 132 | depth, 133 | dim_head = 64, 134 | heads = 8, 135 | attn_dropout = 0., 136 | ff_dropout = 0., 137 | ff_mult = 4, 138 | norm_output = True, 139 | rotary_embed = None, 140 | flash_attn = True, 141 | add_value_residual = False, 142 | num_residual_streams = 1 143 | ): 144 | super().__init__() 145 | self.layers = ModuleList([]) 146 | 147 | init_hyper_conn, *_ = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1) 148 | 149 | for _ in range(depth): 150 | self.layers.append(ModuleList([ 151 | init_hyper_conn(dim = dim, branch = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn, learned_value_residual_mix = add_value_residual)), 152 | init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)) 153 | ])) 154 | 155 | self.norm = RMSNorm(dim) if norm_output else nn.Identity() 156 | 157 | def forward(self, x, value_residual = None): 158 | 159 | first_values = None 160 | 161 | for attn, ff in self.layers: 162 | x, next_values = attn(x, value_residual = value_residual) 163 | 164 | first_values = default(first_values, next_values) 165 | 166 | x = ff(x) 167 | 168 | return self.norm(x), first_values 169 | 170 | # bandsplit module 171 | 172 | class BandSplit(Module): 173 | @beartype 174 | def __init__( 175 | self, 176 | dim, 177 | dim_inputs: tuple[int, ...] 
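        # dim_inputs: width of each band's flattened input, i.e. 2 (real/imag) * audio channels * number of stft bins in that band; each band gets its own RMSNorm + Linear into the shared model dimension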
178 | ): 179 | super().__init__() 180 | self.dim_inputs = dim_inputs 181 | self.to_features = ModuleList([]) 182 | 183 | for dim_in in dim_inputs: 184 | net = nn.Sequential( 185 | RMSNorm(dim_in), 186 | nn.Linear(dim_in, dim) 187 | ) 188 | 189 | self.to_features.append(net) 190 | 191 | def forward(self, x): 192 | x = x.split(self.dim_inputs, dim = -1) 193 | 194 | outs = [] 195 | for split_input, to_feature in zip(x, self.to_features): 196 | split_output = to_feature(split_input) 197 | outs.append(split_output) 198 | 199 | return torch.stack(outs, dim = -2) 200 | 201 | def MLP( 202 | dim_in, 203 | dim_out, 204 | dim_hidden = None, 205 | depth = 1, 206 | activation = nn.Tanh 207 | ): 208 | dim_hidden = default(dim_hidden, dim_in) 209 | 210 | net = [] 211 | dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) 212 | 213 | for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): 214 | is_last = ind == (len(dims) - 2) 215 | 216 | net.append(nn.Linear(layer_dim_in, layer_dim_out)) 217 | 218 | if is_last: 219 | continue 220 | 221 | net.append(activation()) 222 | 223 | return nn.Sequential(*net) 224 | 225 | class MaskEstimator(Module): 226 | @beartype 227 | def __init__( 228 | self, 229 | dim, 230 | dim_inputs: tuple[int, ...], 231 | depth, 232 | mlp_expansion_factor = 4 233 | ): 234 | super().__init__() 235 | self.dim_inputs = dim_inputs 236 | self.to_freqs = ModuleList([]) 237 | dim_hidden = dim * mlp_expansion_factor 238 | 239 | for dim_in in dim_inputs: 240 | net = [] 241 | 242 | mlp = nn.Sequential( 243 | MLP(dim, dim_in * 2, dim_hidden = dim_hidden, depth = depth), 244 | nn.GLU(dim = -1) 245 | ) 246 | 247 | self.to_freqs.append(mlp) 248 | 249 | def forward(self, x): 250 | x = x.unbind(dim = -2) 251 | 252 | outs = [] 253 | 254 | for band_features, mlp in zip(x, self.to_freqs): 255 | freq_out = mlp(band_features) 256 | outs.append(freq_out) 257 | 258 | return torch.cat(outs, dim = -1) 259 | 260 | # main class 261 | 262 | DEFAULT_FREQS_PER_BANDS = ( 263 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 264 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 265 | 2, 2, 2, 2, 266 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 267 | 12, 12, 12, 12, 12, 12, 12, 12, 268 | 24, 24, 24, 24, 24, 24, 24, 24, 269 | 48, 48, 48, 48, 48, 48, 48, 48, 270 | 128, 129, 271 | ) 272 | 273 | class BSRoformer(Module): 274 | 275 | @beartype 276 | def __init__( 277 | self, 278 | dim, 279 | *, 280 | depth, 281 | stereo = False, 282 | num_stems = 1, 283 | time_transformer_depth = 2, 284 | freq_transformer_depth = 2, 285 | freqs_per_bands: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, # in the paper, they divide into ~60 bands, test with 1 for starters 286 | dim_head = 64, 287 | heads = 8, 288 | attn_dropout = 0., 289 | ff_dropout = 0., 290 | flash_attn = True, 291 | num_residual_streams = 4, # set to 1. to disable hyper connections 292 | dim_freqs_in = 1025, 293 | stft_n_fft = 2048, 294 | stft_hop_length = 512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction 295 | stft_win_length = 2048, 296 | stft_normalized = False, 297 | stft_window_fn: Callable | None = None, 298 | mask_estimator_depth = 2, 299 | multi_stft_resolution_loss_weight = 1., 300 | multi_stft_resolutions_window_sizes: tuple[int, ...] 
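        # window sizes for the auxiliary multi-resolution stft loss computed in forward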
= (4096, 2048, 1024, 512, 256), 301 | multi_stft_hop_size = 147, 302 | multi_stft_normalized = False, 303 | multi_stft_window_fn: Callable = torch.hann_window 304 | ): 305 | super().__init__() 306 | 307 | self.stereo = stereo 308 | self.audio_channels = 2 if stereo else 1 309 | self.num_stems = num_stems 310 | 311 | _, self.expand_stream, self.reduce_stream = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1) 312 | 313 | self.layers = ModuleList([]) 314 | 315 | transformer_kwargs = dict( 316 | dim = dim, 317 | heads = heads, 318 | dim_head = dim_head, 319 | attn_dropout = attn_dropout, 320 | ff_dropout = ff_dropout, 321 | flash_attn = flash_attn, 322 | num_residual_streams = num_residual_streams, 323 | norm_output = False, 324 | ) 325 | 326 | time_rotary_embed = RotaryEmbedding(dim = dim_head) 327 | freq_rotary_embed = RotaryEmbedding(dim = dim_head) 328 | 329 | for layer_index in range(depth): 330 | is_first = layer_index == 0 331 | 332 | self.layers.append(nn.ModuleList([ 333 | Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, add_value_residual = not is_first, **transformer_kwargs), 334 | Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, add_value_residual = not is_first, **transformer_kwargs) 335 | ])) 336 | 337 | self.final_norm = RMSNorm(dim) 338 | 339 | self.stft_kwargs = dict( 340 | n_fft = stft_n_fft, 341 | hop_length = stft_hop_length, 342 | win_length = stft_win_length, 343 | normalized = stft_normalized 344 | ) 345 | 346 | self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) 347 | 348 | freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex = True).shape[1] 349 | 350 | assert len(freqs_per_bands) > 1 351 | assert sum(freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}' 352 | 353 | freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands) 354 | 355 | self.band_split = BandSplit( 356 | dim = dim, 357 | dim_inputs = freqs_per_bands_with_complex 358 | ) 359 | 360 | self.mask_estimators = nn.ModuleList([]) 361 | 362 | for _ in range(num_stems): 363 | mask_estimator = MaskEstimator( 364 | dim = dim, 365 | dim_inputs = freqs_per_bands_with_complex, 366 | depth = mask_estimator_depth 367 | ) 368 | 369 | self.mask_estimators.append(mask_estimator) 370 | 371 | # for the multi-resolution stft loss 372 | 373 | self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight 374 | self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes 375 | self.multi_stft_n_fft = stft_n_fft 376 | self.multi_stft_window_fn = multi_stft_window_fn 377 | 378 | self.multi_stft_kwargs = dict( 379 | hop_length = multi_stft_hop_size, 380 | normalized = multi_stft_normalized 381 | ) 382 | 383 | def forward( 384 | self, 385 | raw_audio, 386 | target = None, 387 | return_loss_breakdown = False 388 | ): 389 | """ 390 | einops 391 | 392 | b - batch 393 | f - freq 394 | t - time 395 | s - audio channel (1 for mono, 2 for stereo) 396 | n - number of 'stems' 397 | c - complex (2) 398 | d - feature dimension 399 | """ 400 | 401 | device = raw_audio.device 402 | 403 | if raw_audio.ndim == 2: 404 | raw_audio = rearrange(raw_audio, 'b t -> b 1 t') 405 | 406 | channels = raw_audio.shape[1] 407 | assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if 
passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)' 408 | 409 | # to stft 410 | 411 | raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') 412 | 413 | stft_window = self.stft_window_fn(device = device) 414 | 415 | stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window = stft_window, return_complex = True) 416 | stft_repr = torch.view_as_real(stft_repr) 417 | 418 | stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') 419 | stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting 420 | 421 | x = rearrange(stft_repr, 'b f t c -> b t (f c)') 422 | 423 | x = self.band_split(x) 424 | 425 | # value residuals 426 | 427 | time_v_residual = None 428 | freq_v_residual = None 429 | 430 | # maybe expand residual streams 431 | 432 | x = self.expand_stream(x) 433 | 434 | # axial / hierarchical attention 435 | 436 | for time_transformer, freq_transformer in self.layers: 437 | 438 | x = rearrange(x, 'b t f d -> b f t d') 439 | x, ps = pack([x], '* t d') 440 | 441 | x, next_time_v_residual = time_transformer(x, value_residual = time_v_residual) 442 | 443 | time_v_residual = default(time_v_residual, next_time_v_residual) 444 | 445 | x, = unpack(x, ps, '* t d') 446 | x = rearrange(x, 'b f t d -> b t f d') 447 | x, ps = pack([x], '* f d') 448 | 449 | x, next_freq_v_residual = freq_transformer(x, value_residual = freq_v_residual) 450 | 451 | freq_v_residual = default(freq_v_residual, next_freq_v_residual) 452 | 453 | x, = unpack(x, ps, '* f d') 454 | 455 | # maybe reduce residual streams 456 | 457 | x = self.reduce_stream(x) 458 | 459 | x = self.final_norm(x) 460 | 461 | num_stems = len(self.mask_estimators) 462 | 463 | mask = torch.stack([fn(x) for fn in self.mask_estimators], dim = 1) 464 | mask = rearrange(mask, 'b n t (f c) -> b n f t c', c = 2) 465 | 466 | # modulate frequency representation 467 | 468 | stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') 469 | 470 | # complex number multiplication 471 | 472 | stft_repr = torch.view_as_complex(stft_repr) 473 | mask = torch.view_as_complex(mask) 474 | 475 | stft_repr = stft_repr * mask 476 | 477 | # istft 478 | 479 | stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s = self.audio_channels) 480 | 481 | recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window = stft_window, return_complex = False) 482 | 483 | recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s = self.audio_channels, n = num_stems) 484 | 485 | if num_stems == 1: 486 | recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') 487 | 488 | # if a target is passed in, calculate loss for learning 489 | 490 | if not exists(target): 491 | return recon_audio 492 | 493 | if self.num_stems > 1: 494 | assert target.ndim == 4 and target.shape[1] == self.num_stems 495 | 496 | if target.ndim == 2: 497 | target = rearrange(target, '... t -> ... 1 t') 498 | 499 | target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft 500 | 501 | loss = F.l1_loss(recon_audio, target) 502 | 503 | multi_stft_resolution_loss = 0. 
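        # auxiliary multi-resolution stft loss: L1 between recon and target complex spectrograms
        # at each window size below; n_fft stays at max(window_size, multi_stft_n_fft), so smaller
        # windows are zero-padded rather than shrinking the fft size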
504 | 505 | for window_size in self.multi_stft_resolutions_window_sizes: 506 | 507 | res_stft_kwargs = dict( 508 | n_fft = max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft 509 | win_length = window_size, 510 | return_complex = True, 511 | window = self.multi_stft_window_fn(window_size, device = device), 512 | **self.multi_stft_kwargs, 513 | ) 514 | 515 | recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) 516 | target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs) 517 | 518 | multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) 519 | 520 | weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight 521 | 522 | total_loss = loss + weighted_multi_resolution_loss 523 | 524 | if not return_loss_breakdown: 525 | return total_loss 526 | 527 | return total_loss, (loss, multi_stft_resolution_loss) 528 | -------------------------------------------------------------------------------- /bs_roformer/mel_band_roformer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from functools import partial 3 | 4 | import torch 5 | from torch import nn, einsum, Tensor 6 | from torch.nn import Module, ModuleList 7 | import torch.nn.functional as F 8 | 9 | from bs_roformer.attend import Attend 10 | 11 | from beartype.typing import Callable 12 | from beartype import beartype 13 | 14 | from rotary_embedding_torch import RotaryEmbedding 15 | 16 | from einops import rearrange, pack, unpack, reduce, repeat 17 | from einops.layers.torch import Rearrange 18 | 19 | from librosa import filters 20 | 21 | from hyper_connections import get_init_and_expand_reduce_stream_functions 22 | 23 | # helper functions 24 | 25 | def exists(val): 26 | return val is not None 27 | 28 | def default(v, d): 29 | return v if exists(v) else d 30 | 31 | def pack_one(t, pattern): 32 | return pack([t], pattern) 33 | 34 | def unpack_one(t, ps, pattern): 35 | return unpack(t, ps, pattern)[0] 36 | 37 | def pad_at_dim(t, pad, dim = -1, value = 0.): 38 | dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) 39 | zeros = ((0, 0) * dims_from_right) 40 | return F.pad(t, (*zeros, *pad), value = value) 41 | 42 | def l2norm(t): 43 | return F.normalize(t, dim = -1, p = 2) 44 | 45 | # norm 46 | 47 | class RMSNorm(Module): 48 | def __init__(self, dim): 49 | super().__init__() 50 | self.scale = dim ** 0.5 51 | self.gamma = nn.Parameter(torch.ones(dim)) 52 | 53 | def forward(self, x): 54 | return F.normalize(x, dim = -1) * self.scale * self.gamma 55 | 56 | # attention 57 | 58 | class FeedForward(Module): 59 | def __init__( 60 | self, 61 | dim, 62 | mult = 4, 63 | dropout = 0. 
64 |     ):
65 |         super().__init__()
66 |         dim_inner = int(dim * mult)
67 |         self.net = nn.Sequential(
68 |             RMSNorm(dim),
69 |             nn.Linear(dim, dim_inner),
70 |             nn.GELU(),
71 |             nn.Dropout(dropout),
72 |             nn.Linear(dim_inner, dim),
73 |             nn.Dropout(dropout)
74 |         )
75 | 
76 |     def forward(self, x):
77 |         return self.net(x)
78 | 
79 | class Attention(Module):
80 |     def __init__(
81 |         self,
82 |         dim,
83 |         heads = 8,
84 |         dim_head = 64,
85 |         dropout = 0.,
86 |         rotary_embed = None,
87 |         flash = True,
88 |         add_value_residual = False
89 |     ):
90 |         super().__init__()
91 |         self.heads = heads
92 |         self.scale = dim_head ** -0.5
93 |         dim_inner = heads * dim_head
94 | 
95 |         self.rotary_embed = rotary_embed
96 | 
97 |         self.attend = Attend(flash = flash, dropout = dropout)
98 | 
99 |         self.norm = RMSNorm(dim)
100 |         self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
101 | 
102 |         self.to_gates = nn.Linear(dim, heads)
103 | 
104 |         self.learned_value_residual_mix = nn.Sequential(
105 |             nn.Linear(dim, heads),
106 |             Rearrange('b n h -> b h n 1'),
107 |             nn.Sigmoid()
108 |         ) if add_value_residual else None
109 | 
110 |         self.to_out = nn.Sequential(
111 |             nn.Linear(dim_inner, dim, bias = False),
112 |             nn.Dropout(dropout)
113 |         )
114 | 
115 |     def forward(self, x, value_residual = None):
116 |         x = self.norm(x)
117 | 
118 |         q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads)
119 | 
120 |         orig_v = v
121 | 
122 |         if exists(self.learned_value_residual_mix):
123 |             mix = self.learned_value_residual_mix(x)
124 |             assert exists(value_residual)
125 |             v = v.lerp(value_residual, mix)  # interpolate toward the first layer's values, weighted by the learned sigmoid mix
126 | 
127 |         if exists(self.rotary_embed):
128 |             q = self.rotary_embed.rotate_queries_or_keys(q)
129 |             k = self.rotary_embed.rotate_queries_or_keys(k)
130 | 
131 |         out = self.attend(q, k, v)
132 | 
133 |         gates = self.to_gates(x)
134 |         out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
135 | 
136 |         out = rearrange(out, 'b h n d -> b n (h d)')
137 |         return self.to_out(out), orig_v
138 | 
139 | class LinearAttention(Module):
140 |     """
141 |     this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
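    queries and keys are l2-normalized and attention is computed over the feature dimension (a d x d cross-covariance map rather than n x n), so cost grows linearly with sequence length; a learned per-head temperature scales the normalized queries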
142 | """ 143 | 144 | @beartype 145 | def __init__( 146 | self, 147 | *, 148 | dim, 149 | dim_head = 32, 150 | heads = 8, 151 | scale = 8, 152 | flash = False, 153 | dropout = 0., 154 | add_value_residual = False 155 | ): 156 | super().__init__() 157 | dim_inner = dim_head * heads 158 | self.norm = RMSNorm(dim) 159 | 160 | self.to_qkv = nn.Sequential( 161 | nn.Linear(dim, dim_inner * 3, bias = False), 162 | Rearrange('b n (qkv h d) -> qkv b h d n', qkv = 3, h = heads) 163 | ) 164 | 165 | self.temperature = nn.Parameter(torch.zeros(heads, 1, 1)) 166 | 167 | self.attend = Attend( 168 | scale = scale, 169 | dropout = dropout, 170 | flash = flash 171 | ) 172 | 173 | self.learned_value_residual_mix = nn.Sequential( 174 | nn.Linear(dim, heads), 175 | Rearrange('b n h -> b h 1 n'), 176 | nn.Sigmoid() 177 | ) if add_value_residual else None 178 | 179 | self.to_out = nn.Sequential( 180 | Rearrange('b h d n -> b n (h d)'), 181 | nn.Linear(dim_inner, dim, bias = False) 182 | ) 183 | 184 | def forward( 185 | self, 186 | x, 187 | value_residual = None 188 | ): 189 | x = self.norm(x) 190 | 191 | q, k, v = self.to_qkv(x) 192 | 193 | orig_v = v 194 | 195 | if exists(self.learned_value_residual_mix): 196 | mix = self.learned_value_residual_mix(x) 197 | assert exists(value_residual) 198 | v = v.lerp(mix, value_residual) 199 | 200 | q, k = map(l2norm, (q, k)) 201 | q = q * self.temperature.exp() 202 | 203 | out = self.attend(q, k, v) 204 | 205 | return self.to_out(out), orig_v 206 | 207 | class Transformer(Module): 208 | def __init__( 209 | self, 210 | *, 211 | dim, 212 | depth, 213 | dim_head = 64, 214 | heads = 8, 215 | attn_dropout = 0., 216 | ff_dropout = 0., 217 | ff_mult = 4, 218 | norm_output = True, 219 | rotary_embed = None, 220 | flash_attn = True, 221 | linear_attn = False, 222 | add_value_residual = False, 223 | num_residual_streams = 1 224 | ): 225 | super().__init__() 226 | 227 | init_hyper_conn, *_ = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1) 228 | 229 | self.layers = ModuleList([]) 230 | 231 | for _ in range(depth): 232 | if linear_attn: 233 | attn = LinearAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn, add_value_residual = add_value_residual) 234 | else: 235 | attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn, add_value_residual = add_value_residual) 236 | 237 | ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) 238 | 239 | self.layers.append(ModuleList([ 240 | init_hyper_conn(dim = dim, branch = attn), 241 | init_hyper_conn(dim = dim, branch = ff) 242 | ])) 243 | 244 | self.norm = RMSNorm(dim) if norm_output else nn.Identity() 245 | 246 | def forward(self, x, value_residual = None): 247 | 248 | first_values = None 249 | 250 | for attn, ff in self.layers: 251 | 252 | x, values = attn(x, value_residual = value_residual) 253 | first_values = default(first_values, values) 254 | 255 | x = ff(x) 256 | 257 | return self.norm(x), first_values 258 | 259 | # bandsplit module 260 | 261 | class BandSplit(Module): 262 | @beartype 263 | def __init__( 264 | self, 265 | dim, 266 | dim_inputs: tuple[int, ...] 
267 | ): 268 | super().__init__() 269 | self.dim_inputs = dim_inputs 270 | self.to_features = ModuleList([]) 271 | 272 | for dim_in in dim_inputs: 273 | net = nn.Sequential( 274 | RMSNorm(dim_in), 275 | nn.Linear(dim_in, dim) 276 | ) 277 | 278 | self.to_features.append(net) 279 | 280 | def forward(self, x): 281 | x = x.split(self.dim_inputs, dim = -1) 282 | 283 | outs = [] 284 | for split_input, to_feature in zip(x, self.to_features): 285 | split_output = to_feature(split_input) 286 | outs.append(split_output) 287 | 288 | return torch.stack(outs, dim = -2) 289 | 290 | def MLP( 291 | dim_in, 292 | dim_out, 293 | dim_hidden = None, 294 | depth = 1, 295 | activation = nn.Tanh 296 | ): 297 | dim_hidden = default(dim_hidden, dim_in) 298 | 299 | net = [] 300 | dims = (dim_in, *((dim_hidden,) * depth), dim_out) 301 | 302 | for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): 303 | is_last = ind == (len(dims) - 2) 304 | 305 | net.append(nn.Linear(layer_dim_in, layer_dim_out)) 306 | 307 | if is_last: 308 | continue 309 | 310 | net.append(activation()) 311 | 312 | return nn.Sequential(*net) 313 | 314 | class MaskEstimator(Module): 315 | @beartype 316 | def __init__( 317 | self, 318 | dim, 319 | dim_inputs: tuple[int, ...], 320 | depth, 321 | mlp_expansion_factor = 4 322 | ): 323 | super().__init__() 324 | self.dim_inputs = dim_inputs 325 | self.to_freqs = ModuleList([]) 326 | dim_hidden = dim * mlp_expansion_factor 327 | 328 | for dim_in in dim_inputs: 329 | net = [] 330 | 331 | mlp = nn.Sequential( 332 | MLP(dim, dim_in * 2, dim_hidden = dim_hidden, depth = depth), 333 | nn.GLU(dim = -1) 334 | ) 335 | 336 | self.to_freqs.append(mlp) 337 | 338 | def forward(self, x): 339 | x = x.unbind(dim = -2) 340 | 341 | outs = [] 342 | 343 | for band_features, mlp in zip(x, self.to_freqs): 344 | freq_out = mlp(band_features) 345 | outs.append(freq_out) 346 | 347 | return torch.cat(outs, dim = -1) 348 | 349 | # main class 350 | 351 | class MelBandRoformer(Module): 352 | 353 | @beartype 354 | def __init__( 355 | self, 356 | dim, 357 | *, 358 | depth, 359 | stereo = False, 360 | num_stems = 1, 361 | time_transformer_depth = 2, 362 | freq_transformer_depth = 2, 363 | linear_transformer_depth = 1, 364 | num_bands = 60, 365 | dim_head = 64, 366 | heads = 8, 367 | attn_dropout = 0.1, 368 | ff_dropout = 0.1, 369 | flash_attn = True, 370 | linear_flash_attn = None, 371 | dim_freqs_in = 1025, 372 | sample_rate = 44100, # needed for mel filter bank from librosa 373 | stft_n_fft = 2048, 374 | stft_hop_length = 512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction 375 | stft_win_length = 2048, 376 | stft_normalized = False, 377 | stft_window_fn: Callable | None = None, 378 | mask_estimator_depth = 1, 379 | multi_stft_resolution_loss_weight = 1., 380 | multi_stft_resolutions_window_sizes: tuple[int, ...] 
= (4096, 2048, 1024, 512, 256), 381 | multi_stft_hop_size = 147, 382 | multi_stft_normalized = False, 383 | multi_stft_window_fn: Callable = torch.hann_window, 384 | match_input_audio_length = False, # if True, pad output tensor to match length of input tensor 385 | add_value_residual = True, 386 | num_residual_streams = 4 387 | ): 388 | super().__init__() 389 | 390 | self.stereo = stereo 391 | self.audio_channels = 2 if stereo else 1 392 | self.num_stems = num_stems 393 | 394 | self.layers = ModuleList([]) 395 | 396 | transformer_kwargs = dict( 397 | dim = dim, 398 | heads = heads, 399 | dim_head = dim_head, 400 | attn_dropout = attn_dropout, 401 | ff_dropout = ff_dropout, 402 | num_residual_streams = num_residual_streams 403 | ) 404 | 405 | time_rotary_embed = RotaryEmbedding(dim = dim_head) 406 | freq_rotary_embed = RotaryEmbedding(dim = dim_head) 407 | 408 | linear_flash_attn = default(linear_flash_attn, flash_attn) 409 | 410 | # hyper connections 411 | 412 | _, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1) 413 | 414 | for layer_index in range(depth): 415 | is_first = layer_index == 0 416 | 417 | self.layers.append(nn.ModuleList([ 418 | Transformer(depth = linear_transformer_depth, linear_attn = True, flash_attn = linear_flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs) if linear_transformer_depth > 0 else None, 419 | Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs), 420 | Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs) 421 | ])) 422 | 423 | self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) 424 | 425 | self.stft_kwargs = dict( 426 | n_fft = stft_n_fft, 427 | hop_length = stft_hop_length, 428 | win_length = stft_win_length, 429 | normalized = stft_normalized 430 | ) 431 | 432 | freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex = True).shape[1] 433 | 434 | # create mel filter bank 435 | # with librosa.filters.mel as in section 2 of paper 436 | 437 | mel_filter_bank_numpy = filters.mel(sr = sample_rate, n_fft = stft_n_fft, n_mels = num_bands) 438 | 439 | mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy) 440 | 441 | # for some reason, it doesn't include the first freq? just force a value for now 442 | 443 | mel_filter_bank[0][0] = 1. 444 | 445 | # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position, 446 | # so let's force a positive value 447 | 448 | mel_filter_bank[-1, -1] = 1. 
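        # a sketch of what follows (assuming the default n_fft = 2048, num_bands = 60): mel_filter_bank
        # is a (60, 1025) matrix of non-negative weights; thresholding at > 0 turns it into a boolean
        # band-membership matrix with overlapping bands, and freq_indices flattens every band's member
        # stft bin indices so all bands can be gathered with a single indexing op in forward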
449 | 450 | # binary as in paper (then estimated masks are averaged for overlapping regions) 451 | 452 | freqs_per_band = mel_filter_bank > 0 453 | assert freqs_per_band.any(dim = 0).all(), 'all frequencies need to be covered by all bands for now' 454 | 455 | repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b = num_bands) 456 | freq_indices = repeated_freq_indices[freqs_per_band] 457 | 458 | if stereo: 459 | freq_indices = repeat(freq_indices, 'f -> f s', s = 2) 460 | freq_indices = freq_indices * 2 + torch.arange(2) 461 | freq_indices = rearrange(freq_indices, 'f s -> (f s)') 462 | 463 | self.register_buffer('freq_indices', freq_indices, persistent = False) 464 | self.register_buffer('freqs_per_band', freqs_per_band, persistent = False) 465 | 466 | num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum') 467 | num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum') 468 | 469 | self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent = False) 470 | self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent = False) 471 | 472 | # band split and mask estimator 473 | 474 | freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist()) 475 | 476 | self.band_split = BandSplit( 477 | dim = dim, 478 | dim_inputs = freqs_per_bands_with_complex 479 | ) 480 | 481 | self.mask_estimators = nn.ModuleList([]) 482 | 483 | for _ in range(num_stems): 484 | mask_estimator = MaskEstimator( 485 | dim = dim, 486 | dim_inputs = freqs_per_bands_with_complex, 487 | depth = mask_estimator_depth 488 | ) 489 | 490 | self.mask_estimators.append(mask_estimator) 491 | 492 | # for the multi-resolution stft loss 493 | 494 | self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight 495 | self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes 496 | self.multi_stft_n_fft = stft_n_fft 497 | self.multi_stft_window_fn = multi_stft_window_fn 498 | 499 | self.multi_stft_kwargs = dict( 500 | hop_length = multi_stft_hop_size, 501 | normalized = multi_stft_normalized 502 | ) 503 | 504 | self.match_input_audio_length = match_input_audio_length 505 | 506 | def forward( 507 | self, 508 | raw_audio, 509 | target = None, 510 | return_loss_breakdown = False 511 | ): 512 | """ 513 | einops 514 | 515 | b - batch 516 | f - freq 517 | t - time 518 | s - audio channel (1 for mono, 2 for stereo) 519 | n - number of 'stems' 520 | c - complex (2) 521 | d - feature dimension 522 | """ 523 | 524 | device = raw_audio.device 525 | 526 | if raw_audio.ndim == 2: 527 | raw_audio = rearrange(raw_audio, 'b t -> b 1 t') 528 | 529 | batch, channels, raw_audio_length = raw_audio.shape 530 | 531 | istft_length = raw_audio_length if self.match_input_audio_length else None 532 | 533 | assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)' 534 | 535 | # to stft 536 | 537 | raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') 538 | 539 | stft_window = self.stft_window_fn(device = device) 540 | 541 | stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window = stft_window, return_complex = True) 542 | stft_repr = torch.view_as_real(stft_repr) 543 | 544 | stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') 545 | stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting 546 | 547 | # index out all frequencies for all frequency ranges across bands ascending in one go 548 | 549 | batch_arange = torch.arange(batch, device = device)[..., None] 550 | 551 | # account for stereo 552 | 553 | x = stft_repr[batch_arange, self.freq_indices] 554 | 555 | # fold the complex (real and imag) into the frequencies dimension 556 | 557 | x = rearrange(x, 'b f t c -> b t (f c)') 558 | 559 | x = self.band_split(x) 560 | 561 | # value residuals 562 | 563 | linear_value_residual = None 564 | time_value_residual = None 565 | freq_value_residual = None 566 | 567 | # expand residual streams (hyper connections) 568 | 569 | x = self.expand_streams(x) 570 | 571 | # axial / hierarchical attention 572 | 573 | for linear_transformer, time_transformer, freq_transformer in self.layers: 574 | 575 | if exists(linear_transformer): 576 | x, ft_ps = pack([x], 'b * d') 577 | 578 | x, next_linear_values = linear_transformer(x, value_residual = linear_value_residual) 579 | linear_value_residual = default(linear_value_residual, next_linear_values) 580 | 581 | x, = unpack(x, ft_ps, 'b * d') 582 | 583 | x = rearrange(x, 'b t f d -> b f t d') 584 | x, ps = pack([x], '* t d') 585 | 586 | x, next_time_values = time_transformer(x, value_residual = time_value_residual) 587 | time_value_residual = default(time_value_residual, next_time_values) 588 | 589 | x, = unpack(x, ps, '* t d') 590 | x = rearrange(x, 'b f t d -> b t f d') 591 | x, ps = pack([x], '* f d') 592 | 593 | x, next_freq_values = freq_transformer(x, value_residual = freq_value_residual) 594 | freq_value_residual = default(freq_value_residual, next_freq_values) 595 | 596 | x, = unpack(x, ps, '* f d') 597 | 598 | # reduce residual streams 599 | 600 | x = self.reduce_streams(x) 601 | 602 | # mask estimators 603 | 604 | num_stems = len(self.mask_estimators) 605 | 606 | masks = torch.stack([fn(x) for fn in self.mask_estimators], dim = 1) 607 | masks = rearrange(masks, 'b n t (f c) -> b n f t c', c = 2) 608 | 609 | # modulate frequency representation 610 | 611 | stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') 612 | 613 | # complex number multiplication 614 | 615 | stft_repr = torch.view_as_complex(stft_repr) 616 | masks = torch.view_as_complex(masks) 617 | 618 | masks = masks.type(stft_repr.dtype) 619 | 620 | # need to average the estimated mask for the overlapped frequencies 621 | 622 | scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b = batch, n = num_stems, t = stft_repr.shape[-1]) 623 | 624 | stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... 
-> b n ...', n = num_stems) 625 | masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks) 626 | 627 | denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r = channels) 628 | 629 | masks_averaged = masks_summed / denom.clamp(min = 1e-8) 630 | 631 | # modulate stft repr with estimated mask 632 | 633 | stft_repr = stft_repr * masks_averaged 634 | 635 | # istft 636 | 637 | stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s = self.audio_channels) 638 | 639 | recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window = stft_window, return_complex = False, length = istft_length) 640 | 641 | recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s = self.audio_channels, n = num_stems) 642 | 643 | if num_stems == 1: 644 | recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') 645 | 646 | # if a target is passed in, calculate loss for learning 647 | 648 | if not exists(target): 649 | return recon_audio 650 | 651 | if self.num_stems > 1: 652 | assert target.ndim == 4 and target.shape[1] == self.num_stems 653 | 654 | if target.ndim == 2: 655 | target = rearrange(target, '... t -> ... 1 t') 656 | 657 | target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft 658 | 659 | loss = F.l1_loss(recon_audio, target) 660 | 661 | multi_stft_resolution_loss = 0. 662 | 663 | for window_size in self.multi_stft_resolutions_window_sizes: 664 | 665 | res_stft_kwargs = dict( 666 | n_fft = max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft 667 | win_length = window_size, 668 | return_complex = True, 669 | window = self.multi_stft_window_fn(window_size, device = device), 670 | **self.multi_stft_kwargs, 671 | ) 672 | 673 | recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) 674 | target_Y = torch.stft(rearrange(target, '... s t -> (... 
s) t'), **res_stft_kwargs) 675 | 676 | multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) 677 | 678 | weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight 679 | 680 | total_loss = loss + weighted_multi_resolution_loss 681 | 682 | if not return_loss_breakdown: 683 | return total_loss 684 | 685 | return total_loss, (loss, multi_stft_resolution_loss) 686 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'BS-RoFormer', 5 | packages = find_packages(exclude=[]), 6 | version = '0.6.1', 7 | license='MIT', 8 | description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', 12 | url = 'https://github.com/lucidrains/BS-RoFormer', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'attention mechanism', 18 | 'music source separation' 19 | ], 20 | install_requires=[ 21 | 'beartype', 22 | 'einops>=0.8.0', 23 | 'hyper-connections>=0.1.8', 24 | 'librosa', 25 | 'rotary-embedding-torch>=0.3.6', 26 | 'torch>=2.0', 27 | ], 28 | classifiers=[ 29 | 'Development Status :: 4 - Beta', 30 | 'Intended Audience :: Developers', 31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 32 | 'License :: OSI Approved :: MIT License', 33 | 'Programming Language :: Python :: 3.6', 34 | ], 35 | ) 36 | --------------------------------------------------------------------------------