├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── bs-roformer.png
├── bs_roformer
│   ├── __init__.py
│   ├── attend.py
│   ├── bs_roformer.py
│   └── mel_band_roformer.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## BS-RoFormer
4 |
5 | Implementation of Band Split Roformer, the SOTA attention network for music source separation out of ByteDance AI Labs. They beat the previous first place by a large margin. The technique uses axial attention across frequency (hence multi-band) and time. They also show experimentally that rotary positional encoding led to a huge improvement over learned absolute positions.
6 |
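The core idea is easy to sketch: after band splitting, the spectrogram features form a (batch, time, bands, dim) tensor, and the model alternates attention along the time axis (within each band) and along the band axis (within each time frame). Below is a minimal, hypothetical illustration of that axial pattern using plain `nn.MultiheadAttention`; the shapes and the 62-band count are assumptions for illustration, not this repository's actual modules.

```python
import torch
from torch import nn
from einops import rearrange

dim, heads = 64, 4

time_attn = nn.MultiheadAttention(dim, heads, batch_first = True)
freq_attn = nn.MultiheadAttention(dim, heads, batch_first = True)

x = torch.randn(1, 128, 62, dim)  # (batch, time frames, bands, features)

# attend across time, independently within each band
t = rearrange(x, 'b t f d -> (b f) t d')
t, _ = time_attn(t, t, t)
x = rearrange(t, '(b f) t d -> b t f d', b = 1)

# attend across bands (frequency), independently within each time frame
f = rearrange(x, 'b t f d -> (b t) f d')
f, _ = freq_attn(f, f, f)
x = rearrange(f, '(b t) f d -> b t f d', b = 1)
```
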
7 | This repository also includes support for stereo training and outputting multiple stems (see the example under Usage below).
8 |
9 | Please join the Discord community if you are interested in replicating a SOTA music source separator out in the open
10 |
11 | Update: This paper has been replicated by Roman, with the weights open sourced here
12 |
13 | Update 2: Used for this Katy Perry remix!
14 |
15 | Update 3: Kimberley Jensen has open sourced a MelBand Roformer trained on vocals here!
16 |
17 | ## Appreciation
18 |
19 | - StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
20 |
21 | - Roee and Fabian-Robert for sharing their audio expertise and fixing audio hyperparameters
22 |
23 | - @chenht2010 and Roman for working out the default band splitting hyperparameter!
24 |
25 | - Max Prod for reporting a big bug with Mel-Band Roformer with stereo training!
26 |
27 | - Roman for successfully training the model and open sourcing his training code and weights at this repository!
28 |
29 | - Christopher for fixing an issue with multiple stems in Mel-Band Roformer
30 |
31 | - Iver Jordal for identifying that the default stft window function is not correct
32 |
33 | ## Install
34 |
35 | ```bash
36 | $ pip install BS-RoFormer
37 | ```
38 |
39 | ## Usage
40 |
41 | ```python
42 | import torch
43 | from bs_roformer import BSRoformer
44 |
45 | model = BSRoformer(
46 | dim = 512,
47 | depth = 12,
48 | time_transformer_depth = 1,
49 | freq_transformer_depth = 1
50 | )
51 |
52 | x = torch.randn(2, 352800)
53 | target = torch.randn(2, 352800)
54 |
55 | loss = model(x, target = target)
56 | loss.backward()
57 |
58 | # after much training
59 |
60 | out = model(x)
61 | ```
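
Stereo input and multiple output stems only require the corresponding constructor flags. A minimal sketch follows; the choice of four stems (e.g. vocals / drums / bass / other) is purely illustrative.

```python
import torch
from bs_roformer import BSRoformer

model = BSRoformer(
    dim = 512,
    depth = 12,
    stereo = True,      # 2-channel audio in and out
    num_stems = 4,      # e.g. vocals / drums / bass / other
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)

x = torch.randn(2, 2, 352800)          # (batch, channels, time)
target = torch.randn(2, 4, 2, 352800)  # (batch, stems, channels, time)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)  # (batch, stems, channels, time)
```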
62 |
63 | To use the Mel-Band Roformer proposed in a recent follow-up paper, simply import `MelBandRoformer` instead
64 |
65 | ```python
66 | import torch
67 | from bs_roformer import MelBandRoformer
68 |
69 | model = MelBandRoformer(
70 | dim = 32,
71 | depth = 1,
72 | time_transformer_depth = 1,
73 | freq_transformer_depth = 1
74 | )
75 |
76 | x = torch.randn(2, 352800)
77 | target = torch.randn(2, 352800)
78 |
79 | loss = model(x, target = target)
80 | loss.backward()
81 |
82 | # after much training
83 |
84 | out = model(x)
85 | ```
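
Both models also accept a `return_loss_breakdown` flag when a target is passed, returning the total loss along with its raw l1 and multi-resolution STFT components:

```python
total_loss, (l1_loss, multi_stft_loss) = model(
    x,
    target = target,
    return_loss_breakdown = True
)
```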
86 |
87 | ## Todo
88 |
89 | - [x] get the multiscale stft loss in there
90 | - [x] figure out what `n_fft` should be
91 | - [x] review band split + mask estimation modules
92 |
93 | ## Citations
94 |
95 | ```bibtex
96 | @inproceedings{Lu2023MusicSS,
97 | title = {Music Source Separation with Band-Split RoPE Transformer},
98 | author = {Wei-Tsung Lu and Ju-Chiang Wang and Qiuqiang Kong and Yun-Ning Hung},
99 | year = {2023},
100 | url = {https://api.semanticscholar.org/CorpusID:261556702}
101 | }
102 | ```
103 |
104 | ```bibtex
105 | @inproceedings{Wang2023MelBandRF,
106 | title = {Mel-Band RoFormer for Music Source Separation},
107 | author = {Ju-Chiang Wang and Wei-Tsung Lu and Minz Won},
108 | year = {2023},
109 | url = {https://api.semanticscholar.org/CorpusID:263608675}
110 | }
111 | ```
112 |
113 | ```bibtex
114 | @misc{ho2019axial,
115 | title = {Axial Attention in Multidimensional Transformers},
116 | author = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
117 | year = {2019},
118 | archivePrefix = {arXiv}
119 | }
120 | ```
121 |
122 | ```bibtex
123 | @misc{su2021roformer,
124 | title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
125 | author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
126 | year = {2021},
127 | eprint = {2104.09864},
128 | archivePrefix = {arXiv},
129 | primaryClass = {cs.CL}
130 | }
131 | ```
132 |
133 | ```bibtex
134 | @inproceedings{dao2022flashattention,
135 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
136 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
137 | booktitle = {Advances in Neural Information Processing Systems},
138 | year = {2022}
139 | }
140 | ```
141 |
142 | ```bibtex
143 | @article{Bondarenko2023QuantizableTR,
144 | title = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
145 | author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
146 | journal = {ArXiv},
147 | year = {2023},
148 | volume = {abs/2306.12929},
149 | url = {https://api.semanticscholar.org/CorpusID:259224568}
150 | }
151 | ```
152 |
153 | ```bibtex
154 | @inproceedings{ElNouby2021XCiTCI,
155 | title = {XCiT: Cross-Covariance Image Transformers},
156 | author = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
157 | booktitle = {Neural Information Processing Systems},
158 | year = {2021},
159 | url = {https://api.semanticscholar.org/CorpusID:235458262}
160 | }
161 | ```
162 |
163 | ```bibtex
164 | @inproceedings{Zhou2024ValueRL,
165 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
166 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
167 | year = {2024},
168 | url = {https://api.semanticscholar.org/CorpusID:273532030}
169 | }
170 | ```
171 |
--------------------------------------------------------------------------------
/bs-roformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/BS-RoFormer/9c3ef86fed568434cafb4e5a0daecb2f5f2ab6f5/bs-roformer.png
--------------------------------------------------------------------------------
/bs_roformer/__init__.py:
--------------------------------------------------------------------------------
1 | from bs_roformer.bs_roformer import BSRoformer
2 | from bs_roformer.mel_band_roformer import MelBandRoformer
3 |
--------------------------------------------------------------------------------
/bs_roformer/attend.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from packaging import version
3 | from collections import namedtuple
4 |
5 | import torch
6 | from torch import nn, einsum
7 | import torch.nn.functional as F
8 |
9 | from einops import rearrange, reduce
10 |
11 | # constants
12 |
13 | FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
14 |
15 | # helpers
16 |
17 | def exists(val):
18 | return val is not None
19 |
20 | def default(v, d):
21 | return v if exists(v) else d
22 |
23 | def once(fn):
24 | called = False
25 | @wraps(fn)
26 | def inner(x):
27 | nonlocal called
28 | if called:
29 | return
30 | called = True
31 | return fn(x)
32 | return inner
33 |
34 | print_once = once(print)
35 |
36 | # main class
37 |
38 | class Attend(nn.Module):
39 | def __init__(
40 | self,
41 | dropout = 0.,
42 | flash = False,
43 | scale = None
44 | ):
45 | super().__init__()
46 | self.scale = scale
47 | self.dropout = dropout
48 | self.attn_dropout = nn.Dropout(dropout)
49 |
50 | self.flash = flash
51 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
52 |
53 | # determine efficient attention configs for cuda and cpu
54 |
55 | self.cpu_config = FlashAttentionConfig(True, True, True)
56 | self.cuda_config = None
57 |
58 | if not torch.cuda.is_available() or not flash:
59 | return
60 |
61 | device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
62 |
63 | if device_properties.major == 8 and device_properties.minor == 0:
64 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
65 | self.cuda_config = FlashAttentionConfig(True, False, False)
66 | else:
67 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
68 | self.cuda_config = FlashAttentionConfig(False, True, True)
69 |
70 | def flash_attn(self, q, k, v):
71 | _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device
72 |
73 | if exists(self.scale):
74 | default_scale = q.shape[-1] ** -0.5
75 | q = q * (self.scale / default_scale)
76 |
77 | # Check if there is a compatible device for flash attention
78 |
79 | config = self.cuda_config if is_cuda else self.cpu_config
80 |
81 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale
82 |
83 | with torch.backends.cuda.sdp_kernel(**config._asdict()):
84 | out = F.scaled_dot_product_attention(
85 | q, k, v,
86 | dropout_p = self.dropout if self.training else 0.
87 | )
88 |
89 | return out
90 |
91 | def forward(self, q, k, v):
92 | """
93 | einstein notation
94 | b - batch
95 | h - heads
96 | n, i, j - sequence length (base sequence length, source, target)
97 | d - feature dimension
98 | """
99 |
100 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
101 |
102 | scale = default(self.scale, q.shape[-1] ** -0.5)
103 |
104 | if self.flash:
105 | return self.flash_attn(q, k, v)
106 |
107 | # similarity
108 |
109 | sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
110 |
111 | # attention
112 |
113 | attn = sim.softmax(dim=-1)
114 | attn = self.attn_dropout(attn)
115 |
116 | # aggregate values
117 |
118 | out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
119 |
120 | return out
121 |
--------------------------------------------------------------------------------
/bs_roformer/bs_roformer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from functools import partial
3 |
4 | import torch
5 | from torch import nn, einsum, Tensor
6 | from torch.nn import Module, ModuleList
7 | import torch.nn.functional as F
8 |
9 | from bs_roformer.attend import Attend
10 |
11 | from beartype.typing import Callable
12 | from beartype import beartype
13 |
14 | from rotary_embedding_torch import RotaryEmbedding
15 |
16 | from einops import rearrange, pack, unpack
17 |
18 | from hyper_connections import get_init_and_expand_reduce_stream_functions
19 |
20 | # helper functions
21 |
22 | def exists(val):
23 | return val is not None
24 |
25 | def default(v, d):
26 | return v if exists(v) else d
27 |
28 | def pack_one(t, pattern):
29 | return pack([t], pattern)
30 |
31 | def unpack_one(t, ps, pattern):
32 | return unpack(t, ps, pattern)[0]
33 |
34 | # norm
35 |
36 | class RMSNorm(Module):
37 | def __init__(self, dim):
38 | super().__init__()
39 | self.scale = dim ** 0.5
40 | self.gamma = nn.Parameter(torch.ones(dim))
41 |
42 | def forward(self, x):
43 | return F.normalize(x, dim = -1) * self.scale * self.gamma
44 |
45 | # attention
46 |
47 | class FeedForward(Module):
48 | def __init__(
49 | self,
50 | dim,
51 | mult = 4,
52 | dropout = 0.
53 | ):
54 | super().__init__()
55 | dim_inner = int(dim * mult)
56 | self.net = nn.Sequential(
57 | RMSNorm(dim),
58 | nn.Linear(dim, dim_inner),
59 | nn.GELU(),
60 | nn.Dropout(dropout),
61 | nn.Linear(dim_inner, dim),
62 | nn.Dropout(dropout)
63 | )
64 |
65 | def forward(self, x):
66 | return self.net(x)
67 |
68 | class Attention(Module):
69 | def __init__(
70 | self,
71 | dim,
72 | heads = 8,
73 | dim_head = 64,
74 | dropout = 0.,
75 | rotary_embed = None,
76 | flash = True,
77 | learned_value_residual_mix = False
78 | ):
79 | super().__init__()
80 | self.heads = heads
81 | self.scale = dim_head **-0.5
82 | dim_inner = heads * dim_head
83 |
84 | self.rotary_embed = rotary_embed
85 |
86 | self.attend = Attend(flash = flash, dropout = dropout)
87 |
88 | self.norm = RMSNorm(dim)
89 | self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
90 |
91 | self.to_value_residual_mix = nn.Linear(dim, heads) if learned_value_residual_mix else None
92 |
93 | self.to_gates = nn.Linear(dim, heads)
94 |
95 | self.to_out = nn.Sequential(
96 | nn.Linear(dim_inner, dim, bias = False),
97 | nn.Dropout(dropout)
98 | )
99 |
100 | def forward(self, x, value_residual = None):
101 | x = self.norm(x)
102 |
103 | q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads)
104 |
105 | orig_v = v
106 |
107 | if exists(self.to_value_residual_mix):
108 | mix = self.to_value_residual_mix(x)
109 | mix = rearrange(mix, 'b n h -> b h n 1').sigmoid()
110 |
111 | assert exists(value_residual)
112 | v = v.lerp(value_residual, mix)
113 |
114 | if exists(self.rotary_embed):
115 | q = self.rotary_embed.rotate_queries_or_keys(q)
116 | k = self.rotary_embed.rotate_queries_or_keys(k)
117 |
118 | out = self.attend(q, k, v)
119 |
120 | gates = self.to_gates(x)
121 | out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
122 |
123 | out = rearrange(out, 'b h n d -> b n (h d)')
124 |
125 | return self.to_out(out), orig_v
126 |
127 | class Transformer(Module):
128 | def __init__(
129 | self,
130 | *,
131 | dim,
132 | depth,
133 | dim_head = 64,
134 | heads = 8,
135 | attn_dropout = 0.,
136 | ff_dropout = 0.,
137 | ff_mult = 4,
138 | norm_output = True,
139 | rotary_embed = None,
140 | flash_attn = True,
141 | add_value_residual = False,
142 | num_residual_streams = 1
143 | ):
144 | super().__init__()
145 | self.layers = ModuleList([])
146 |
147 | init_hyper_conn, *_ = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
148 |
149 | for _ in range(depth):
150 | self.layers.append(ModuleList([
151 | init_hyper_conn(dim = dim, branch = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn, learned_value_residual_mix = add_value_residual)),
152 | init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout))
153 | ]))
154 |
155 | self.norm = RMSNorm(dim) if norm_output else nn.Identity()
156 |
157 | def forward(self, x, value_residual = None):
158 |
159 | first_values = None
160 |
161 | for attn, ff in self.layers:
162 | x, next_values = attn(x, value_residual = value_residual)
163 |
164 | first_values = default(first_values, next_values)
165 |
166 | x = ff(x)
167 |
168 | return self.norm(x), first_values
169 |
170 | # bandsplit module
171 |
172 | class BandSplit(Module):
173 | @beartype
174 | def __init__(
175 | self,
176 | dim,
177 | dim_inputs: tuple[int, ...]
178 | ):
179 | super().__init__()
180 | self.dim_inputs = dim_inputs
181 | self.to_features = ModuleList([])
182 |
183 | for dim_in in dim_inputs:
184 | net = nn.Sequential(
185 | RMSNorm(dim_in),
186 | nn.Linear(dim_in, dim)
187 | )
188 |
189 | self.to_features.append(net)
190 |
191 | def forward(self, x):
192 | x = x.split(self.dim_inputs, dim = -1)
193 |
194 | outs = []
195 | for split_input, to_feature in zip(x, self.to_features):
196 | split_output = to_feature(split_input)
197 | outs.append(split_output)
198 |
199 | return torch.stack(outs, dim = -2)
200 |
201 | def MLP(
202 | dim_in,
203 | dim_out,
204 | dim_hidden = None,
205 | depth = 1,
206 | activation = nn.Tanh
207 | ):
208 | dim_hidden = default(dim_hidden, dim_in)
209 |
210 | net = []
211 | dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
212 |
213 | for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
214 | is_last = ind == (len(dims) - 2)
215 |
216 | net.append(nn.Linear(layer_dim_in, layer_dim_out))
217 |
218 | if is_last:
219 | continue
220 |
221 | net.append(activation())
222 |
223 | return nn.Sequential(*net)
224 |
225 | class MaskEstimator(Module):
226 | @beartype
227 | def __init__(
228 | self,
229 | dim,
230 | dim_inputs: tuple[int, ...],
231 | depth,
232 | mlp_expansion_factor = 4
233 | ):
234 | super().__init__()
235 | self.dim_inputs = dim_inputs
236 | self.to_freqs = ModuleList([])
237 | dim_hidden = dim * mlp_expansion_factor
238 |
239 | for dim_in in dim_inputs:
240 | net = []
241 |
242 | mlp = nn.Sequential(
243 | MLP(dim, dim_in * 2, dim_hidden = dim_hidden, depth = depth),
244 | nn.GLU(dim = -1)
245 | )
246 |
247 | self.to_freqs.append(mlp)
248 |
249 | def forward(self, x):
250 | x = x.unbind(dim = -2)
251 |
252 | outs = []
253 |
254 | for band_features, mlp in zip(x, self.to_freqs):
255 | freq_out = mlp(band_features)
256 | outs.append(freq_out)
257 |
258 | return torch.cat(outs, dim = -1)
259 |
260 | # main class
261 |
262 | DEFAULT_FREQS_PER_BANDS = (
263 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
264 | 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
265 | 2, 2, 2, 2,
266 | 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
267 | 12, 12, 12, 12, 12, 12, 12, 12,
268 | 24, 24, 24, 24, 24, 24, 24, 24,
269 | 48, 48, 48, 48, 48, 48, 48, 48,
270 | 128, 129,
271 | )  # 62 bands summing to 1025 frequency bins (stft_n_fft // 2 + 1)
272 |
273 | class BSRoformer(Module):
274 |
275 | @beartype
276 | def __init__(
277 | self,
278 | dim,
279 | *,
280 | depth,
281 | stereo = False,
282 | num_stems = 1,
283 | time_transformer_depth = 2,
284 | freq_transformer_depth = 2,
285 | freqs_per_bands: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, # in the paper, they divide into ~60 bands, test with 1 for starters
286 | dim_head = 64,
287 | heads = 8,
288 | attn_dropout = 0.,
289 | ff_dropout = 0.,
290 | flash_attn = True,
291 | num_residual_streams = 4, # set to 1 to disable hyper connections
292 | dim_freqs_in = 1025,
293 | stft_n_fft = 2048,
294 | stft_hop_length = 512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
295 | stft_win_length = 2048,
296 | stft_normalized = False,
297 | stft_window_fn: Callable | None = None,
298 | mask_estimator_depth = 2,
299 | multi_stft_resolution_loss_weight = 1.,
300 | multi_stft_resolutions_window_sizes: tuple[int, ...] = (4096, 2048, 1024, 512, 256),
301 | multi_stft_hop_size = 147,
302 | multi_stft_normalized = False,
303 | multi_stft_window_fn: Callable = torch.hann_window
304 | ):
305 | super().__init__()
306 |
307 | self.stereo = stereo
308 | self.audio_channels = 2 if stereo else 1
309 | self.num_stems = num_stems
310 |
311 | _, self.expand_stream, self.reduce_stream = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
312 |
313 | self.layers = ModuleList([])
314 |
315 | transformer_kwargs = dict(
316 | dim = dim,
317 | heads = heads,
318 | dim_head = dim_head,
319 | attn_dropout = attn_dropout,
320 | ff_dropout = ff_dropout,
321 | flash_attn = flash_attn,
322 | num_residual_streams = num_residual_streams,
323 | norm_output = False,
324 | )
325 |
326 | time_rotary_embed = RotaryEmbedding(dim = dim_head)
327 | freq_rotary_embed = RotaryEmbedding(dim = dim_head)
328 |
329 | for layer_index in range(depth):
330 | is_first = layer_index == 0
331 |
332 | self.layers.append(nn.ModuleList([
333 | Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, add_value_residual = not is_first, **transformer_kwargs),
334 | Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, add_value_residual = not is_first, **transformer_kwargs)
335 | ]))
336 |
337 | self.final_norm = RMSNorm(dim)
338 |
339 | self.stft_kwargs = dict(
340 | n_fft = stft_n_fft,
341 | hop_length = stft_hop_length,
342 | win_length = stft_win_length,
343 | normalized = stft_normalized
344 | )
345 |
346 | self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
347 |
348 | freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex = True).shape[1]
349 |
350 | assert len(freqs_per_bands) > 1
351 | assert sum(freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
352 |
353 | freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
354 |
355 | self.band_split = BandSplit(
356 | dim = dim,
357 | dim_inputs = freqs_per_bands_with_complex
358 | )
359 |
360 | self.mask_estimators = nn.ModuleList([])
361 |
362 | for _ in range(num_stems):
363 | mask_estimator = MaskEstimator(
364 | dim = dim,
365 | dim_inputs = freqs_per_bands_with_complex,
366 | depth = mask_estimator_depth
367 | )
368 |
369 | self.mask_estimators.append(mask_estimator)
370 |
371 | # for the multi-resolution stft loss
372 |
373 | self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
374 | self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
375 | self.multi_stft_n_fft = stft_n_fft
376 | self.multi_stft_window_fn = multi_stft_window_fn
377 |
378 | self.multi_stft_kwargs = dict(
379 | hop_length = multi_stft_hop_size,
380 | normalized = multi_stft_normalized
381 | )
382 |
383 | def forward(
384 | self,
385 | raw_audio,
386 | target = None,
387 | return_loss_breakdown = False
388 | ):
389 | """
390 | einops
391 |
392 | b - batch
393 | f - freq
394 | t - time
395 | s - audio channel (1 for mono, 2 for stereo)
396 | n - number of 'stems'
397 | c - complex (2)
398 | d - feature dimension
399 | """
400 |
401 | device = raw_audio.device
402 |
403 | if raw_audio.ndim == 2:
404 | raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
405 |
406 | channels = raw_audio.shape[1]
407 | assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
408 |
409 | # to stft
410 |
411 | raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
412 |
413 | stft_window = self.stft_window_fn(device = device)
414 |
415 | stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window = stft_window, return_complex = True)
416 | stft_repr = torch.view_as_real(stft_repr)
417 |
418 | stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
419 | stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
420 |
421 | x = rearrange(stft_repr, 'b f t c -> b t (f c)')
422 |
423 | x = self.band_split(x)
424 |
425 | # value residuals
426 |
427 | time_v_residual = None
428 | freq_v_residual = None
429 |
430 | # maybe expand residual streams
431 |
432 | x = self.expand_stream(x)
433 |
434 | # axial / hierarchical attention
435 |
436 | for time_transformer, freq_transformer in self.layers:
437 |
438 | x = rearrange(x, 'b t f d -> b f t d')
439 | x, ps = pack([x], '* t d')
440 |
441 | x, next_time_v_residual = time_transformer(x, value_residual = time_v_residual)
442 |
443 | time_v_residual = default(time_v_residual, next_time_v_residual)
444 |
445 | x, = unpack(x, ps, '* t d')
446 | x = rearrange(x, 'b f t d -> b t f d')
447 | x, ps = pack([x], '* f d')
448 |
449 | x, next_freq_v_residual = freq_transformer(x, value_residual = freq_v_residual)
450 |
451 | freq_v_residual = default(freq_v_residual, next_freq_v_residual)
452 |
453 | x, = unpack(x, ps, '* f d')
454 |
455 | # maybe reduce residual streams
456 |
457 | x = self.reduce_stream(x)
458 |
459 | x = self.final_norm(x)
460 |
461 | num_stems = len(self.mask_estimators)
462 |
463 | mask = torch.stack([fn(x) for fn in self.mask_estimators], dim = 1)
464 | mask = rearrange(mask, 'b n t (f c) -> b n f t c', c = 2)
465 |
466 | # modulate frequency representation
467 |
468 | stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
469 |
470 | # complex number multiplication
471 |
472 | stft_repr = torch.view_as_complex(stft_repr)
473 | mask = torch.view_as_complex(mask)
474 |
475 | stft_repr = stft_repr * mask
476 |
477 | # istft
478 |
479 | stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s = self.audio_channels)
480 |
481 | recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window = stft_window, return_complex = False)
482 |
483 | recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s = self.audio_channels, n = num_stems)
484 |
485 | if num_stems == 1:
486 | recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
487 |
488 | # if a target is passed in, calculate loss for learning
489 |
490 | if not exists(target):
491 | return recon_audio
492 |
493 | if self.num_stems > 1:
494 | assert target.ndim == 4 and target.shape[1] == self.num_stems
495 |
496 | if target.ndim == 2:
497 | target = rearrange(target, '... t -> ... 1 t')
498 |
499 | target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
500 |
501 | loss = F.l1_loss(recon_audio, target)
502 |
503 | multi_stft_resolution_loss = 0.
504 |
505 | for window_size in self.multi_stft_resolutions_window_sizes:
506 |
507 | res_stft_kwargs = dict(
508 | n_fft = max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
509 | win_length = window_size,
510 | return_complex = True,
511 | window = self.multi_stft_window_fn(window_size, device = device),
512 | **self.multi_stft_kwargs,
513 | )
514 |
515 | recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
516 | target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
517 |
518 | multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
519 |
520 | weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
521 |
522 | total_loss = loss + weighted_multi_resolution_loss
523 |
524 | if not return_loss_breakdown:
525 | return total_loss
526 |
527 | return total_loss, (loss, multi_stft_resolution_loss)
528 |
--------------------------------------------------------------------------------
/bs_roformer/mel_band_roformer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from functools import partial
3 |
4 | import torch
5 | from torch import nn, einsum, Tensor
6 | from torch.nn import Module, ModuleList
7 | import torch.nn.functional as F
8 |
9 | from bs_roformer.attend import Attend
10 |
11 | from beartype.typing import Callable
12 | from beartype import beartype
13 |
14 | from rotary_embedding_torch import RotaryEmbedding
15 |
16 | from einops import rearrange, pack, unpack, reduce, repeat
17 | from einops.layers.torch import Rearrange
18 |
19 | from librosa import filters
20 |
21 | from hyper_connections import get_init_and_expand_reduce_stream_functions
22 |
23 | # helper functions
24 |
25 | def exists(val):
26 | return val is not None
27 |
28 | def default(v, d):
29 | return v if exists(v) else d
30 |
31 | def pack_one(t, pattern):
32 | return pack([t], pattern)
33 |
34 | def unpack_one(t, ps, pattern):
35 | return unpack(t, ps, pattern)[0]
36 |
37 | def pad_at_dim(t, pad, dim = -1, value = 0.):
38 | dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
39 | zeros = ((0, 0) * dims_from_right)
40 | return F.pad(t, (*zeros, *pad), value = value)
41 |
42 | def l2norm(t):
43 | return F.normalize(t, dim = -1, p = 2)
44 |
45 | # norm
46 |
47 | class RMSNorm(Module):
48 | def __init__(self, dim):
49 | super().__init__()
50 | self.scale = dim ** 0.5
51 | self.gamma = nn.Parameter(torch.ones(dim))
52 |
53 | def forward(self, x):
54 | return F.normalize(x, dim = -1) * self.scale * self.gamma
55 |
56 | # attention
57 |
58 | class FeedForward(Module):
59 | def __init__(
60 | self,
61 | dim,
62 | mult = 4,
63 | dropout = 0.
64 | ):
65 | super().__init__()
66 | dim_inner = int(dim * mult)
67 | self.net = nn.Sequential(
68 | RMSNorm(dim),
69 | nn.Linear(dim, dim_inner),
70 | nn.GELU(),
71 | nn.Dropout(dropout),
72 | nn.Linear(dim_inner, dim),
73 | nn.Dropout(dropout)
74 | )
75 |
76 | def forward(self, x):
77 | return self.net(x)
78 |
79 | class Attention(Module):
80 | def __init__(
81 | self,
82 | dim,
83 | heads = 8,
84 | dim_head = 64,
85 | dropout = 0.,
86 | rotary_embed = None,
87 | flash = True,
88 | add_value_residual = False
89 | ):
90 | super().__init__()
91 | self.heads = heads
92 | self.scale = dim_head **-0.5
93 | dim_inner = heads * dim_head
94 |
95 | self.rotary_embed = rotary_embed
96 |
97 | self.attend = Attend(flash = flash, dropout = dropout)
98 |
99 | self.norm = RMSNorm(dim)
100 | self.to_qkv = nn.Linear(dim, dim_inner * 3, bias = False)
101 |
102 | self.to_gates = nn.Linear(dim, heads)
103 |
104 | self.learned_value_residual_mix = nn.Sequential(
105 | nn.Linear(dim, heads),
106 | Rearrange('b n h -> b h n 1'),
107 | nn.Sigmoid()
108 | ) if add_value_residual else None
109 |
110 | self.to_out = nn.Sequential(
111 | nn.Linear(dim_inner, dim, bias = False),
112 | nn.Dropout(dropout)
113 | )
114 |
115 | def forward(self, x, value_residual = None):
116 | x = self.norm(x)
117 |
118 | q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = self.heads)
119 |
120 | orig_v = v
121 |
122 | if exists(self.learned_value_residual_mix):
123 | mix = self.learned_value_residual_mix(x)
124 | assert exists(value_residual)
125 | v = v.lerp(value_residual, mix)
126 |
127 | if exists(self.rotary_embed):
128 | q = self.rotary_embed.rotate_queries_or_keys(q)
129 | k = self.rotary_embed.rotate_queries_or_keys(k)
130 |
131 | out = self.attend(q, k, v)
132 |
133 | gates = self.to_gates(x)
134 | out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
135 |
136 | out = rearrange(out, 'b h n d -> b n (h d)')
137 | return self.to_out(out), orig_v
138 |
139 | class LinearAttention(Module):
140 | """
141 | this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
142 | """
143 |
144 | @beartype
145 | def __init__(
146 | self,
147 | *,
148 | dim,
149 | dim_head = 32,
150 | heads = 8,
151 | scale = 8,
152 | flash = False,
153 | dropout = 0.,
154 | add_value_residual = False
155 | ):
156 | super().__init__()
157 | dim_inner = dim_head * heads
158 | self.norm = RMSNorm(dim)
159 |
160 | self.to_qkv = nn.Sequential(
161 | nn.Linear(dim, dim_inner * 3, bias = False),
162 | Rearrange('b n (qkv h d) -> qkv b h d n', qkv = 3, h = heads)
163 | )
164 |
165 | self.temperature = nn.Parameter(torch.zeros(heads, 1, 1))
166 |
167 | self.attend = Attend(
168 | scale = scale,
169 | dropout = dropout,
170 | flash = flash
171 | )
172 |
173 | self.learned_value_residual_mix = nn.Sequential(
174 | nn.Linear(dim, heads),
175 | Rearrange('b n h -> b h 1 n'),
176 | nn.Sigmoid()
177 | ) if add_value_residual else None
178 |
179 | self.to_out = nn.Sequential(
180 | Rearrange('b h d n -> b n (h d)'),
181 | nn.Linear(dim_inner, dim, bias = False)
182 | )
183 |
184 | def forward(
185 | self,
186 | x,
187 | value_residual = None
188 | ):
189 | x = self.norm(x)
190 |
191 | q, k, v = self.to_qkv(x)
192 |
193 | orig_v = v
194 |
195 | if exists(self.learned_value_residual_mix):
196 | mix = self.learned_value_residual_mix(x)
197 | assert exists(value_residual)
198 | v = v.lerp(value_residual, mix)
199 |
200 | q, k = map(l2norm, (q, k))
201 | q = q * self.temperature.exp()
202 |
203 | out = self.attend(q, k, v)
204 |
205 | return self.to_out(out), orig_v
206 |
207 | class Transformer(Module):
208 | def __init__(
209 | self,
210 | *,
211 | dim,
212 | depth,
213 | dim_head = 64,
214 | heads = 8,
215 | attn_dropout = 0.,
216 | ff_dropout = 0.,
217 | ff_mult = 4,
218 | norm_output = True,
219 | rotary_embed = None,
220 | flash_attn = True,
221 | linear_attn = False,
222 | add_value_residual = False,
223 | num_residual_streams = 1
224 | ):
225 | super().__init__()
226 |
227 | init_hyper_conn, *_ = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
228 |
229 | self.layers = ModuleList([])
230 |
231 | for _ in range(depth):
232 | if linear_attn:
233 | attn = LinearAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn, add_value_residual = add_value_residual)
234 | else:
235 | attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_embed = rotary_embed, flash = flash_attn, add_value_residual = add_value_residual)
236 |
237 | ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
238 |
239 | self.layers.append(ModuleList([
240 | init_hyper_conn(dim = dim, branch = attn),
241 | init_hyper_conn(dim = dim, branch = ff)
242 | ]))
243 |
244 | self.norm = RMSNorm(dim) if norm_output else nn.Identity()
245 |
246 | def forward(self, x, value_residual = None):
247 |
248 | first_values = None
249 |
250 | for attn, ff in self.layers:
251 |
252 | x, values = attn(x, value_residual = value_residual)
253 | first_values = default(first_values, values)
254 |
255 | x = ff(x)
256 |
257 | return self.norm(x), first_values
258 |
259 | # bandsplit module
260 |
261 | class BandSplit(Module):
262 | @beartype
263 | def __init__(
264 | self,
265 | dim,
266 | dim_inputs: tuple[int, ...]
267 | ):
268 | super().__init__()
269 | self.dim_inputs = dim_inputs
270 | self.to_features = ModuleList([])
271 |
272 | for dim_in in dim_inputs:
273 | net = nn.Sequential(
274 | RMSNorm(dim_in),
275 | nn.Linear(dim_in, dim)
276 | )
277 |
278 | self.to_features.append(net)
279 |
280 | def forward(self, x):
281 | x = x.split(self.dim_inputs, dim = -1)
282 |
283 | outs = []
284 | for split_input, to_feature in zip(x, self.to_features):
285 | split_output = to_feature(split_input)
286 | outs.append(split_output)
287 |
288 | return torch.stack(outs, dim = -2)
289 |
290 | def MLP(
291 | dim_in,
292 | dim_out,
293 | dim_hidden = None,
294 | depth = 1,
295 | activation = nn.Tanh
296 | ):
297 | dim_hidden = default(dim_hidden, dim_in)
298 |
299 | net = []
300 | dims = (dim_in, *((dim_hidden,) * depth), dim_out)
301 |
302 | for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
303 | is_last = ind == (len(dims) - 2)
304 |
305 | net.append(nn.Linear(layer_dim_in, layer_dim_out))
306 |
307 | if is_last:
308 | continue
309 |
310 | net.append(activation())
311 |
312 | return nn.Sequential(*net)
313 |
314 | class MaskEstimator(Module):
315 | @beartype
316 | def __init__(
317 | self,
318 | dim,
319 | dim_inputs: tuple[int, ...],
320 | depth,
321 | mlp_expansion_factor = 4
322 | ):
323 | super().__init__()
324 | self.dim_inputs = dim_inputs
325 | self.to_freqs = ModuleList([])
326 | dim_hidden = dim * mlp_expansion_factor
327 |
328 | for dim_in in dim_inputs:
329 | net = []
330 |
331 | mlp = nn.Sequential(
332 | MLP(dim, dim_in * 2, dim_hidden = dim_hidden, depth = depth),
333 | nn.GLU(dim = -1)
334 | )
335 |
336 | self.to_freqs.append(mlp)
337 |
338 | def forward(self, x):
339 | x = x.unbind(dim = -2)
340 |
341 | outs = []
342 |
343 | for band_features, mlp in zip(x, self.to_freqs):
344 | freq_out = mlp(band_features)
345 | outs.append(freq_out)
346 |
347 | return torch.cat(outs, dim = -1)
348 |
349 | # main class
350 |
351 | class MelBandRoformer(Module):
352 |
353 | @beartype
354 | def __init__(
355 | self,
356 | dim,
357 | *,
358 | depth,
359 | stereo = False,
360 | num_stems = 1,
361 | time_transformer_depth = 2,
362 | freq_transformer_depth = 2,
363 | linear_transformer_depth = 1,
364 | num_bands = 60,
365 | dim_head = 64,
366 | heads = 8,
367 | attn_dropout = 0.1,
368 | ff_dropout = 0.1,
369 | flash_attn = True,
370 | linear_flash_attn = None,
371 | dim_freqs_in = 1025,
372 | sample_rate = 44100, # needed for mel filter bank from librosa
373 | stft_n_fft = 2048,
374 | stft_hop_length = 512, # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
375 | stft_win_length = 2048,
376 | stft_normalized = False,
377 | stft_window_fn: Callable | None = None,
378 | mask_estimator_depth = 1,
379 | multi_stft_resolution_loss_weight = 1.,
380 | multi_stft_resolutions_window_sizes: tuple[int, ...] = (4096, 2048, 1024, 512, 256),
381 | multi_stft_hop_size = 147,
382 | multi_stft_normalized = False,
383 | multi_stft_window_fn: Callable = torch.hann_window,
384 | match_input_audio_length = False, # if True, pad output tensor to match length of input tensor
385 | add_value_residual = True,
386 | num_residual_streams = 4
387 | ):
388 | super().__init__()
389 |
390 | self.stereo = stereo
391 | self.audio_channels = 2 if stereo else 1
392 | self.num_stems = num_stems
393 |
394 | self.layers = ModuleList([])
395 |
396 | transformer_kwargs = dict(
397 | dim = dim,
398 | heads = heads,
399 | dim_head = dim_head,
400 | attn_dropout = attn_dropout,
401 | ff_dropout = ff_dropout,
402 | num_residual_streams = num_residual_streams
403 | )
404 |
405 | time_rotary_embed = RotaryEmbedding(dim = dim_head)
406 | freq_rotary_embed = RotaryEmbedding(dim = dim_head)
407 |
408 | linear_flash_attn = default(linear_flash_attn, flash_attn)
409 |
410 | # hyper connections
411 |
412 | _, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
413 |
414 | for layer_index in range(depth):
415 | is_first = layer_index == 0
416 |
417 | self.layers.append(nn.ModuleList([
418 | Transformer(depth = linear_transformer_depth, linear_attn = True, flash_attn = linear_flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs) if linear_transformer_depth > 0 else None,
419 | Transformer(depth = time_transformer_depth, rotary_embed = time_rotary_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs),
420 | Transformer(depth = freq_transformer_depth, rotary_embed = freq_rotary_embed, flash_attn = flash_attn, add_value_residual = add_value_residual and not is_first, **transformer_kwargs)
421 | ]))
422 |
423 | self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
424 |
425 | self.stft_kwargs = dict(
426 | n_fft = stft_n_fft,
427 | hop_length = stft_hop_length,
428 | win_length = stft_win_length,
429 | normalized = stft_normalized
430 | )
431 |
432 | freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex = True).shape[1]
433 |
434 | # create mel filter bank
435 | # with librosa.filters.mel as in section 2 of paper
436 |
437 | mel_filter_bank_numpy = filters.mel(sr = sample_rate, n_fft = stft_n_fft, n_mels = num_bands)
438 |
439 | mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)
440 |
441 | # for some reason, it doesn't include the first freq? just force a value for now
442 |
443 | mel_filter_bank[0][0] = 1.
444 |
445 | # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
446 | # so let's force a positive value
447 |
448 | mel_filter_bank[-1, -1] = 1.
449 |
450 | # binary as in paper (then estimated masks are averaged for overlapping regions)
451 |
452 | freqs_per_band = mel_filter_bank > 0
453 | assert freqs_per_band.any(dim = 0).all(), 'all frequencies need to be covered by all bands for now'
454 |
455 | repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b = num_bands)
456 | freq_indices = repeated_freq_indices[freqs_per_band]
457 |
458 | if stereo:
459 | freq_indices = repeat(freq_indices, 'f -> f s', s = 2)
460 | freq_indices = freq_indices * 2 + torch.arange(2)
461 | freq_indices = rearrange(freq_indices, 'f s -> (f s)')
462 |
463 | self.register_buffer('freq_indices', freq_indices, persistent = False)
464 | self.register_buffer('freqs_per_band', freqs_per_band, persistent = False)
465 |
466 | num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
467 | num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
468 |
469 | self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent = False)
470 | self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent = False)
471 |
472 | # band split and mask estimator
473 |
474 | freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())
475 |
476 | self.band_split = BandSplit(
477 | dim = dim,
478 | dim_inputs = freqs_per_bands_with_complex
479 | )
480 |
481 | self.mask_estimators = nn.ModuleList([])
482 |
483 | for _ in range(num_stems):
484 | mask_estimator = MaskEstimator(
485 | dim = dim,
486 | dim_inputs = freqs_per_bands_with_complex,
487 | depth = mask_estimator_depth
488 | )
489 |
490 | self.mask_estimators.append(mask_estimator)
491 |
492 | # for the multi-resolution stft loss
493 |
494 | self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
495 | self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
496 | self.multi_stft_n_fft = stft_n_fft
497 | self.multi_stft_window_fn = multi_stft_window_fn
498 |
499 | self.multi_stft_kwargs = dict(
500 | hop_length = multi_stft_hop_size,
501 | normalized = multi_stft_normalized
502 | )
503 |
504 | self.match_input_audio_length = match_input_audio_length
505 |
506 | def forward(
507 | self,
508 | raw_audio,
509 | target = None,
510 | return_loss_breakdown = False
511 | ):
512 | """
513 | einops
514 |
515 | b - batch
516 | f - freq
517 | t - time
518 | s - audio channel (1 for mono, 2 for stereo)
519 | n - number of 'stems'
520 | c - complex (2)
521 | d - feature dimension
522 | """
523 |
524 | device = raw_audio.device
525 |
526 | if raw_audio.ndim == 2:
527 | raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
528 |
529 | batch, channels, raw_audio_length = raw_audio.shape
530 |
531 | istft_length = raw_audio_length if self.match_input_audio_length else None
532 |
533 | assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'
534 |
535 | # to stft
536 |
537 | raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
538 |
539 | stft_window = self.stft_window_fn(device = device)
540 |
541 | stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window = stft_window, return_complex = True)
542 | stft_repr = torch.view_as_real(stft_repr)
543 |
544 | stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
545 | stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
546 |
547 | # index out all frequencies for all frequency ranges across bands ascending in one go
548 |
549 | batch_arange = torch.arange(batch, device = device)[..., None]
550 |
551 | # account for stereo
552 |
553 | x = stft_repr[batch_arange, self.freq_indices]
554 |
555 | # fold the complex (real and imag) into the frequencies dimension
556 |
557 | x = rearrange(x, 'b f t c -> b t (f c)')
558 |
559 | x = self.band_split(x)
560 |
561 | # value residuals
562 |
563 | linear_value_residual = None
564 | time_value_residual = None
565 | freq_value_residual = None
566 |
567 | # expand residual streams (hyper connections)
568 |
569 | x = self.expand_streams(x)
570 |
571 | # axial / hierarchical attention
572 |
573 | for linear_transformer, time_transformer, freq_transformer in self.layers:
574 |
575 | if exists(linear_transformer):
576 | x, ft_ps = pack([x], 'b * d')
577 |
578 | x, next_linear_values = linear_transformer(x, value_residual = linear_value_residual)
579 | linear_value_residual = default(linear_value_residual, next_linear_values)
580 |
581 | x, = unpack(x, ft_ps, 'b * d')
582 |
583 | x = rearrange(x, 'b t f d -> b f t d')
584 | x, ps = pack([x], '* t d')
585 |
586 | x, next_time_values = time_transformer(x, value_residual = time_value_residual)
587 | time_value_residual = default(time_value_residual, next_time_values)
588 |
589 | x, = unpack(x, ps, '* t d')
590 | x = rearrange(x, 'b f t d -> b t f d')
591 | x, ps = pack([x], '* f d')
592 |
593 | x, next_freq_values = freq_transformer(x, value_residual = freq_value_residual)
594 | freq_value_residual = default(freq_value_residual, next_freq_values)
595 |
596 | x, = unpack(x, ps, '* f d')
597 |
598 | # reduce residual streams
599 |
600 | x = self.reduce_streams(x)
601 |
602 | # mask estimators
603 |
604 | num_stems = len(self.mask_estimators)
605 |
606 | masks = torch.stack([fn(x) for fn in self.mask_estimators], dim = 1)
607 | masks = rearrange(masks, 'b n t (f c) -> b n f t c', c = 2)
608 |
609 | # modulate frequency representation
610 |
611 | stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
612 |
613 | # complex number multiplication
614 |
615 | stft_repr = torch.view_as_complex(stft_repr)
616 | masks = torch.view_as_complex(masks)
617 |
618 | masks = masks.type(stft_repr.dtype)
619 |
620 | # need to average the estimated mask for the overlapped frequencies
621 |
622 | scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b = batch, n = num_stems, t = stft_repr.shape[-1])
623 |
624 | stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n = num_stems)
625 | masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)
626 |
627 | denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r = channels)
628 |
629 | masks_averaged = masks_summed / denom.clamp(min = 1e-8)
630 |
631 | # modulate stft repr with estimated mask
632 |
633 | stft_repr = stft_repr * masks_averaged
634 |
635 | # istft
636 |
637 | stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s = self.audio_channels)
638 |
639 | recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window = stft_window, return_complex = False, length = istft_length)
640 |
641 | recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b = batch, s = self.audio_channels, n = num_stems)
642 |
643 | if num_stems == 1:
644 | recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')
645 |
646 | # if a target is passed in, calculate loss for learning
647 |
648 | if not exists(target):
649 | return recon_audio
650 |
651 | if self.num_stems > 1:
652 | assert target.ndim == 4 and target.shape[1] == self.num_stems
653 |
654 | if target.ndim == 2:
655 | target = rearrange(target, '... t -> ... 1 t')
656 |
657 | target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
658 |
659 | loss = F.l1_loss(recon_audio, target)
660 |
661 | multi_stft_resolution_loss = 0.
662 |
663 | for window_size in self.multi_stft_resolutions_window_sizes:
664 |
665 | res_stft_kwargs = dict(
666 | n_fft = max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
667 | win_length = window_size,
668 | return_complex = True,
669 | window = self.multi_stft_window_fn(window_size, device = device),
670 | **self.multi_stft_kwargs,
671 | )
672 |
673 | recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs)
674 | target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs)
675 |
676 | multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
677 |
678 | weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
679 |
680 | total_loss = loss + weighted_multi_resolution_loss
681 |
682 | if not return_loss_breakdown:
683 | return total_loss
684 |
685 | return total_loss, (loss, multi_stft_resolution_loss)
686 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'BS-RoFormer',
5 | packages = find_packages(exclude=[]),
6 | version = '0.6.1',
7 | license='MIT',
8 | description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/BS-RoFormer',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention mechanism',
18 | 'music source separation'
19 | ],
20 | install_requires=[
21 | 'beartype',
22 | 'einops>=0.8.0',
23 | 'hyper-connections>=0.1.8',
24 | 'librosa',
25 | 'rotary-embedding-torch>=0.3.6',
26 | 'torch>=2.0',
27 | ],
28 | classifiers=[
29 | 'Development Status :: 4 - Beta',
30 | 'Intended Audience :: Developers',
31 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
32 | 'License :: OSI Approved :: MIT License',
33 | 'Programming Language :: Python :: 3.8',
34 | ],
35 | )
36 |
--------------------------------------------------------------------------------