├── requirements.txt ├── .gitattributes ├── __init__.py ├── .gitignore ├── pyproject.toml ├── readme.md ├── .github └── workflows │ └── publish.yml ├── model ├── attend.py ├── mel_band_roformer.py └── bs_roformer.py ├── example_workflows └── melband_example.json └── nodes.py /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa 2 | rotary_embedding_torch 3 | einops 4 | PyYAML 5 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output/ 2 | *__pycache__/ 3 | samples*/ 4 | runs/ 5 | checkpoints/ 6 | master_ip 7 | logs/ 8 | *.DS_Store 9 | .idea 10 | tools/ 11 | .vscode/ 12 | convert_* 13 | *.pt -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ComfyUI-MelBandRoFormer" 3 | description = "ComfyUI wrapper nodes for Mel-Band RoFormer and BS-RoFormer music source separation" 4 | version = "1.0.1" 5 | license = {file = "LICENSE"} 6 | dependencies = ["librosa", "rotary_embedding_torch", "einops", "PyYAML"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/kijai/ComfyUI-MelBandRoFormer" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "kijai" 14 | DisplayName = "ComfyUI-MelBandRoFormer" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## ComfyUI node for [Mel-Band RoFormer for Music Source Separation](https://arxiv.org/abs/2310.01809) 2 | 3 | Added support for [BS-RoFormer](https://arxiv.org/abs/2309.02612) models. 4 | 5 | https://huggingface.co/anvuew/karaoke_bs_roformer 6 | 7 | https://huggingface.co/models?search=bs_roformer 8 | https://huggingface.co/models?search=mel_band_roformer 9 | 10 | ----------------------------------------------- 11 | 12 | Dereverb Models: 13 | 14 | https://huggingface.co/anvuew/dereverb_mel_band_roformer 15 | 16 | Place models and any config files in `ComfyUI/models/diffusion_models`. 
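The loader nodes can optionally read a YAML config placed next to the checkpoint; only its `model:` section is used. A minimal sketch of such a config, mirroring the defaults built into the Mel-Band RoFormer loader node (values are assumptions — prefer the config shipped with your checkpoint when one is provided):

```yaml
model:
  dim: 384
  depth: 6
  stereo: true
  num_stems: 1
  time_transformer_depth: 1
  freq_transformer_depth: 1
  num_bands: 60
  sample_rate: 44100
  stft_n_fft: 2048
  stft_hop_length: 441
  stft_win_length: 2048
  mask_estimator_depth: 2
  multi_stft_resolutions_window_sizes: [4096, 2048, 1024, 512, 256]
```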
17 | 18 | image 19 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'kijai' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} -------------------------------------------------------------------------------- /model/attend.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from packaging import version 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torch import nn, einsum 7 | import torch.nn.functional as F 8 | 9 | from einops import rearrange, reduce 10 | 11 | # constants 12 | 13 | FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) 14 | 15 | # helpers 16 | 17 | def exists(val): 18 | return val is not None 19 | 20 | def default(v, d): 21 | return v if exists(v) else d 22 | 23 | def once(fn): 24 | called = False 25 | @wraps(fn) 26 | def inner(x): 27 | nonlocal called 28 | if called: 29 | return 30 | called = True 31 | return fn(x) 32 | return inner 33 | 34 | print_once = once(print) 35 | 36 | # main class 37 | 38 | class Attend(nn.Module): 39 | def __init__( 40 | self, 41 | dropout = 0., 42 | flash = False, 43 | scale = None 44 | ): 45 | super().__init__() 46 | self.scale = scale 47 | self.dropout = dropout 48 | self.attn_dropout = nn.Dropout(dropout) 49 | 50 | self.flash = flash 51 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 52 | 53 | # determine efficient attention configs for cuda and cpu 54 | 55 | self.cpu_config = FlashAttentionConfig(True, True, True) 56 | self.cuda_config = None 57 | 58 | if not torch.cuda.is_available() or not flash: 59 | return 60 | 61 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 62 | 63 | if device_properties.major == 8 and device_properties.minor == 0: 64 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda') 65 | self.cuda_config = FlashAttentionConfig(True, False, False) 66 | else: 67 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') 68 | self.cuda_config = FlashAttentionConfig(False, True, True) 69 | 70 | def flash_attn(self, q, k, v): 71 | _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device 72 | 73 | if exists(self.scale): 74 | default_scale = q.shape[-1] ** -0.5 75 | q = q * (self.scale / default_scale) 76 | 77 | # Check if there is a compatible device for flash attention 78 | 79 | config = self.cuda_config if is_cuda else self.cpu_config 80 | 81 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale 82 | 83 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 84 | out = 
F.scaled_dot_product_attention( 85 | q, k, v, 86 | dropout_p = self.dropout if self.training else 0. 87 | ) 88 | 89 | return out 90 | 91 | def forward(self, q, k, v): 92 | """ 93 | einstein notation 94 | b - batch 95 | h - heads 96 | n, i, j - sequence length (base sequence length, source, target) 97 | d - feature dimension 98 | """ 99 | 100 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device 101 | 102 | scale = default(self.scale, q.shape[-1] ** -0.5) 103 | 104 | if self.flash: 105 | return self.flash_attn(q, k, v) 106 | 107 | # similarity 108 | 109 | sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale 110 | 111 | # attention 112 | 113 | attn = sim.softmax(dim=-1) 114 | attn = self.attn_dropout(attn) 115 | 116 | # aggregate values 117 | 118 | out = einsum(f"b h i j, b h j d -> b h i d", attn, v) 119 | 120 | return out 121 | -------------------------------------------------------------------------------- /example_workflows/melband_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "id": "8b7a9a57-2303-4ef5-9fc2-bf41713bd1fc", 3 | "revision": 0, 4 | "last_node_id": 326, 5 | "last_link_id": 585, 6 | "nodes": [ 7 | { 8 | "id": 325, 9 | "type": "MelBandRoFormerModelLoader", 10 | "pos": [ 11 | 2141.159423828125, 12 | -1714.64501953125 13 | ], 14 | "size": [ 15 | 419.69976806640625, 16 | 58 17 | ], 18 | "flags": {}, 19 | "order": 0, 20 | "mode": 0, 21 | "inputs": [], 22 | "outputs": [ 23 | { 24 | "name": "model", 25 | "type": "MELROFORMERMODEL", 26 | "links": [ 27 | 581 28 | ] 29 | } 30 | ], 31 | "properties": { 32 | "aux_id": "kijai/ComfyUI-MelBandRoFormer", 33 | "ver": "7b12f8c6105666552ac4e4f08b73e0f96eb05c64", 34 | "Node name for S&R": "MelBandRoFormerModelLoader" 35 | }, 36 | "widgets_values": [ 37 | "MelRoFormer\\MelBandRoformer_fp16.safetensors" 38 | ] 39 | }, 40 | { 41 | "id": 317, 42 | "type": "LoadAudio", 43 | "pos": [ 44 | 2146.195068359375, 45 | -1579.27783203125 46 | ], 47 | "size": [ 48 | 274.080078125, 49 | 136 50 | ], 51 | "flags": {}, 52 | "order": 1, 53 | "mode": 0, 54 | "inputs": [], 55 | "outputs": [ 56 | { 57 | "name": "AUDIO", 58 | "type": "AUDIO", 59 | "links": [ 60 | 582 61 | ] 62 | } 63 | ], 64 | "properties": { 65 | "cnr_id": "comfy-core", 66 | "ver": "0.3.51", 67 | "Node name for S&R": "LoadAudio" 68 | }, 69 | "widgets_values": [ 70 | "0321. 
Alphaville - Big In Japan.mp3", 71 | null, 72 | null 73 | ] 74 | }, 75 | { 76 | "id": 319, 77 | "type": "PreviewAudio", 78 | "pos": [ 79 | 2892.392578125, 80 | -1695.0682373046875 81 | ], 82 | "size": [ 83 | 383.5520935058594, 84 | 88 85 | ], 86 | "flags": {}, 87 | "order": 3, 88 | "mode": 0, 89 | "inputs": [ 90 | { 91 | "name": "audio", 92 | "type": "AUDIO", 93 | "link": 583 94 | } 95 | ], 96 | "outputs": [], 97 | "properties": { 98 | "cnr_id": "comfy-core", 99 | "ver": "0.3.51", 100 | "Node name for S&R": "PreviewAudio" 101 | }, 102 | "widgets_values": [] 103 | }, 104 | { 105 | "id": 326, 106 | "type": "MelBandRoFormerSampler", 107 | "pos": [ 108 | 2610.4697265625, 109 | -1629.0616455078125 110 | ], 111 | "size": [ 112 | 222.73397827148438, 113 | 46 114 | ], 115 | "flags": {}, 116 | "order": 2, 117 | "mode": 0, 118 | "inputs": [ 119 | { 120 | "name": "model", 121 | "type": "MELROFORMERMODEL", 122 | "link": 581 123 | }, 124 | { 125 | "name": "audio", 126 | "type": "AUDIO", 127 | "link": 582 128 | } 129 | ], 130 | "outputs": [ 131 | { 132 | "name": "vocals", 133 | "type": "AUDIO", 134 | "links": [ 135 | 583 136 | ] 137 | }, 138 | { 139 | "name": "instruments", 140 | "type": "AUDIO", 141 | "links": [ 142 | 585 143 | ] 144 | } 145 | ], 146 | "properties": { 147 | "aux_id": "kijai/ComfyUI-MelBandRoFormer", 148 | "ver": "7b12f8c6105666552ac4e4f08b73e0f96eb05c64", 149 | "Node name for S&R": "MelBandRoFormerSampler" 150 | } 151 | }, 152 | { 153 | "id": 323, 154 | "type": "PreviewAudio", 155 | "pos": [ 156 | 2897.984619140625, 157 | -1545.7156982421875 158 | ], 159 | "size": [ 160 | 376.83953857421875, 161 | 91.35621643066406 162 | ], 163 | "flags": {}, 164 | "order": 4, 165 | "mode": 0, 166 | "inputs": [ 167 | { 168 | "name": "audio", 169 | "type": "AUDIO", 170 | "link": 585 171 | } 172 | ], 173 | "outputs": [], 174 | "properties": { 175 | "cnr_id": "comfy-core", 176 | "ver": "0.3.51", 177 | "Node name for S&R": "PreviewAudio" 178 | }, 179 | "widgets_values": [] 180 | } 181 | ], 182 | "links": [ 183 | [ 184 | 581, 185 | 325, 186 | 0, 187 | 326, 188 | 0, 189 | "MELROFORMERMODEL" 190 | ], 191 | [ 192 | 582, 193 | 317, 194 | 0, 195 | 326, 196 | 1, 197 | "AUDIO" 198 | ], 199 | [ 200 | 583, 201 | 326, 202 | 0, 203 | 319, 204 | 0, 205 | "AUDIO" 206 | ], 207 | [ 208 | 585, 209 | 326, 210 | 1, 211 | 323, 212 | 0, 213 | "AUDIO" 214 | ] 215 | ], 216 | "groups": [], 217 | "config": {}, 218 | "extra": { 219 | "ds": { 220 | "scale": 1.1918176537728358, 221 | "offset": [ 222 | -1936.9884239637872, 223 | 1969.438456906274 224 | ] 225 | }, 226 | "frontendVersion": "1.26.3", 227 | "node_versions": { 228 | "ComfyUI-WanVideoWrapper": "0a11c67a0c0062b534178920a0d6dcaa75e7b5fe", 229 | "comfy-core": "0.3.43", 230 | "audio-separation-nodes-comfyui": "31a4567726e035097cc2d1f767767908a6fda2ea", 231 | "ComfyUI-KJNodes": "f7eb33abc80a2aded1b46dff0dd14d07856a7d50", 232 | "comfyui-videohelpersuite": "a7ce59e381934733bfae03b1be029756d6ce936d" 233 | }, 234 | "VHS_latentpreview": true, 235 | "VHS_latentpreviewrate": 0, 236 | "VHS_MetadataImage": true, 237 | "VHS_KeepIntermediate": true 238 | }, 239 | "version": 0.4 240 | } -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | from tqdm import tqdm 5 | import yaml 6 | 7 | import librosa 8 | import folder_paths 9 | 10 | from .model.mel_band_roformer import MelBandRoformer 11 | from 
.model.bs_roformer import BSRoformer 12 | 13 | script_directory = os.path.dirname(os.path.abspath(__file__)) 14 | 15 | from comfy import model_management as mm 16 | from comfy.utils import load_torch_file, ProgressBar 17 | device = mm.get_torch_device() 18 | offload_device = mm.unet_offload_device() 19 | 20 | # Use the diffusion_models directory for config files 21 | folder_paths.add_model_folder_path("melband_configs", os.path.join(folder_paths.models_dir, "diffusion_models")) 22 | 23 | def get_windowing_array(window_size, fade_size, device): 24 | fadein = torch.linspace(0, 1, fade_size) 25 | fadeout = torch.linspace(1, 0, fade_size) 26 | window = torch.ones(window_size) 27 | window[-fade_size:] *= fadeout 28 | window[:fade_size] *= fadein 29 | return window.to(device) 30 | 31 | class MelBandRoFormerModelLoader: 32 | @classmethod 33 | def INPUT_TYPES(s): 34 | return { 35 | "required": { 36 | "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}), 37 | }, 38 | "optional": { 39 | "config_name": ([""] + folder_paths.get_filename_list("melband_configs"), {"tooltip": "Optional YAML config file from 'ComfyUI/models/diffusion_models' folder. If empty, uses default config."}), 40 | }, 41 | } 42 | 43 | RETURN_TYPES = ("MELROFORMERMODEL",) 44 | RETURN_NAMES = ("model", ) 45 | FUNCTION = "loadmodel" 46 | CATEGORY = "Mel-Band RoFormer" 47 | 48 | def loadmodel(self, model_name, config_name=""): 49 | # Default model configuration 50 | model_config = { 51 | "dim": 384, 52 | "depth": 6, 53 | "stereo": True, 54 | "num_stems": 1, 55 | "time_transformer_depth": 1, 56 | "freq_transformer_depth": 1, 57 | "num_bands": 60, 58 | "dim_head": 64, 59 | "heads": 8, 60 | "attn_dropout": 0, 61 | "ff_dropout": 0, 62 | "flash_attn": True, 63 | "dim_freqs_in": 1025, 64 | "sample_rate": 44100, # needed for mel filter bank from librosa 65 | "stft_n_fft": 2048, 66 | "stft_hop_length": 441, 67 | "stft_win_length": 2048, 68 | "stft_normalized": False, 69 | "mask_estimator_depth": 2, 70 | "multi_stft_resolution_loss_weight": 1.0, 71 | "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256), 72 | "multi_stft_hop_size": 147, 73 | "multi_stft_normalized": False, 74 | } 75 | 76 | # Load config from YAML file if provided 77 | if config_name: 78 | try: 79 | config_path = folder_paths.get_full_path_or_raise("melband_configs", config_name) 80 | with open(config_path, 'r', encoding='utf-8') as f: 81 | # Use full loader to handle Python-specific tags like !!python/tuple 82 | yaml_config = yaml.load(f, Loader=yaml.FullLoader) 83 | 84 | # Extract model configuration from YAML 85 | if 'model' in yaml_config: 86 | model_config.update(yaml_config['model']) 87 | 88 | print(f"Loaded model configuration from: {config_name}") 89 | 90 | except Exception as e: 91 | print(f"Error loading YAML config file '{config_name}': {e}") 92 | print("Using default configuration instead.") 93 | 94 | # Convert tuple from YAML to tuple for Python 95 | if 'multi_stft_resolutions_window_sizes' in model_config: 96 | if isinstance(model_config['multi_stft_resolutions_window_sizes'], list): 97 | model_config['multi_stft_resolutions_window_sizes'] = tuple(model_config['multi_stft_resolutions_window_sizes']) 98 | 99 | model = MelBandRoformer(**model_config).eval() 100 | model_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name) 101 | model.load_state_dict(load_torch_file(model_path), strict=True) 102 | 103 | return (model,) 104 | 105 
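# Shared config convention for the loader classes in this file: an optional YAML file placed
# alongside the checkpoint in ComfyUI/models/diffusion_models, of which only the 'model'
# section is read. MelBandRoFormerModelLoader falls back to the default dict above when no
# config is given; BSRoformerModelLoader below has no built-in defaults, so a config matching
# the checkpoint's architecture is effectively required. List-valued entries
# (multi_stft_resolutions_window_sizes, and freqs_per_bands for BS-RoFormer) are converted
# to tuples before being passed to the model constructors.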
| 106 | class BSRoformerModelLoader: 107 | @classmethod 108 | def INPUT_TYPES(s): 109 | return { 110 | "required": { 111 | "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from the 'ComfyUI/models/diffusion_models' -folder",}), 112 | "config_name": ([""] + folder_paths.get_filename_list("melband_configs"), {"tooltip": "Optional YAML config file from 'ComfyUI/models/diffusion_models' folder."}), 113 | }, 114 | } 115 | 116 | RETURN_TYPES = ("MELROFORMERMODEL",) 117 | RETURN_NAMES = ("model", ) 118 | FUNCTION = "loadmodel" 119 | CATEGORY = "Mel-Band RoFormer" 120 | 121 | def loadmodel(self, model_name, config_name=""): 122 | # Default model configuration 123 | model_config = { 124 | } 125 | 126 | # Load config from YAML file if provided 127 | if config_name: 128 | try: 129 | config_path = folder_paths.get_full_path_or_raise("melband_configs", config_name) 130 | with open(config_path, 'r', encoding='utf-8') as f: 131 | # Use full loader to handle Python-specific tags like !!python/tuple 132 | yaml_config = yaml.load(f, Loader=yaml.FullLoader) 133 | 134 | # Extract model configuration from YAML 135 | if 'model' in yaml_config: 136 | model_config.update(yaml_config['model']) 137 | 138 | print(f"Loaded model configuration from: {config_name}") 139 | 140 | except Exception as e: 141 | print(f"Error loading YAML config file '{config_name}': {e}") 142 | print("Using default configuration instead.") 143 | 144 | # Convert tuple from YAML to tuple for Python 145 | if 'multi_stft_resolutions_window_sizes' in model_config: 146 | if isinstance(model_config['multi_stft_resolutions_window_sizes'], list): 147 | model_config['multi_stft_resolutions_window_sizes'] = tuple(model_config['multi_stft_resolutions_window_sizes']) 148 | 149 | # freqs_per_bands 150 | if 'freqs_per_bands' in model_config: 151 | if isinstance(model_config['freqs_per_bands'], list): 152 | model_config['freqs_per_bands'] = tuple(model_config['freqs_per_bands']) 153 | 154 | model = BSRoformer(**model_config).eval() 155 | model_path = folder_paths.get_full_path_or_raise("diffusion_models", model_name) 156 | model.load_state_dict(load_torch_file(model_path), strict=True) 157 | 158 | return (model,) 159 | 160 | 161 | class MelBandRoFormerSampler: 162 | @classmethod 163 | def INPUT_TYPES(s): 164 | return { 165 | "required": { 166 | "model": ("MELROFORMERMODEL",), 167 | "audio": ("AUDIO",), 168 | }, 169 | } 170 | 171 | RETURN_TYPES = ("AUDIO","AUDIO",) 172 | RETURN_NAMES = ("vocals", "instruments") 173 | FUNCTION = "process" 174 | CATEGORY = "Mel-Band RoFormer" 175 | 176 | def process(self, model, audio): 177 | 178 | audio_input = audio["waveform"] 179 | sample_rate = audio["sample_rate"] 180 | 181 | B, audio_channels, audio_length = audio_input.shape 182 | 183 | sr = 44100 184 | 185 | if audio_channels == 1: 186 | # Convert mono to stereo by duplicating the channel 187 | audio_input = audio_input.repeat(1, 2, 1) 188 | audio_channels = 2 189 | print("Converted mono input to stereo.") 190 | 191 | if sample_rate != sr: 192 | print(f"Resampling input {sample_rate} to {sr}") 193 | audio_np = audio_input.cpu().numpy() 194 | resampled = librosa.resample(audio_np, orig_sr=sample_rate, target_sr=sr, axis=-1) 195 | audio_input = torch.from_numpy(resampled) 196 | audio_input = original_audio = audio_input[0] 197 | 198 | C = 352800 199 | N = 2 200 | step = C // N 201 | fade_size = C // 10 202 | border = C - step 203 | 204 | if audio_length > 2 * border and border > 0: 205 | audio_input = 
F.pad(audio_input, (border, border), mode='reflect') 206 | 207 | windowing_array = get_windowing_array(C, fade_size, device) 208 | 209 | 210 | audio_input = audio_input.to(device) 211 | vocals = torch.zeros_like(audio_input, dtype=torch.float32).to(device) 212 | counter = torch.zeros_like(audio_input, dtype=torch.float32).to(device) 213 | 214 | total_length = audio_input.shape[1] 215 | num_chunks = (total_length + step - 1) // step 216 | 217 | model.to(device) 218 | 219 | comfy_pbar = ProgressBar(num_chunks) 220 | 221 | for i in tqdm(range(0, total_length, step), desc="Processing chunks"): 222 | part = audio_input[:, i:i + C] 223 | length = part.shape[-1] 224 | if length < C: 225 | if length > C // 2 + 1: 226 | part = F.pad(input=part, pad=(0, C - length), mode='reflect') 227 | else: 228 | part = F.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) 229 | 230 | x = model(part.unsqueeze(0))[0] 231 | 232 | window = windowing_array.clone() 233 | if i == 0: 234 | window[:fade_size] = 1 235 | elif i + C >= total_length: 236 | window[-fade_size:] = 1 237 | 238 | vocals[..., i:i+length] += x[..., :length] * window[..., :length] 239 | counter[..., i:i+length] += window[..., :length] 240 | comfy_pbar.update(1) 241 | 242 | model.to(offload_device) 243 | 244 | estimated_sources = vocals / counter 245 | 246 | if audio_length > 2 * border and border > 0: 247 | estimated_sources = estimated_sources[..., border:-border] 248 | 249 | vocals_out = { 250 | "waveform": estimated_sources.unsqueeze(0).cpu(), 251 | "sample_rate": sr, 252 | } 253 | instruments_out = { 254 | "waveform": (original_audio.to(device) - estimated_sources).unsqueeze(0).cpu(), 255 | "sample_rate": sr, 256 | } 257 | 258 | return (vocals_out, instruments_out) 259 | 260 | NODE_CLASS_MAPPINGS = { 261 | "MelBandRoFormerModelLoader": MelBandRoFormerModelLoader, 262 | "BSRoformerModelLoader": BSRoformerModelLoader, 263 | "MelBandRoFormerSampler": MelBandRoFormerSampler, 264 | } 265 | NODE_DISPLAY_NAME_MAPPINGS = { 266 | "MelBandRoFormerModelLoader": "Mel-Band RoFormer Model Loader", 267 | "MelBandRoFormerSampler": "Mel-Band RoFormer Sampler", 268 | } 269 | -------------------------------------------------------------------------------- /model/mel_band_roformer.py: -------------------------------------------------------------------------------- 1 | # https://github.com/KimberleyJensen/Mel-Band-Roformer-Vocal-Model/blob/main/models/mel_band_roformer/mel_band_roformer.py 2 | 3 | from functools import partial 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import Module, ModuleList 8 | import torch.nn.functional as F 9 | 10 | from rotary_embedding_torch import RotaryEmbedding 11 | 12 | from einops import rearrange, pack, unpack, reduce, repeat 13 | 14 | from librosa import filters 15 | 16 | # helper functions 17 | 18 | def exists(val): 19 | return val is not None 20 | 21 | 22 | def default(v, d): 23 | return v if exists(v) else d 24 | 25 | 26 | def pack_one(t, pattern): 27 | return pack([t], pattern) 28 | 29 | 30 | def unpack_one(t, ps, pattern): 31 | return unpack(t, ps, pattern)[0] 32 | 33 | 34 | def pad_at_dim(t, pad, dim=-1, value=0.): 35 | dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) 36 | zeros = ((0, 0) * dims_from_right) 37 | return F.pad(t, (*zeros, *pad), value=value) 38 | 39 | 40 | # norm 41 | 42 | class RMSNorm(Module): 43 | def __init__(self, dim): 44 | super().__init__() 45 | self.scale = dim ** 0.5 46 | self.gamma = nn.Parameter(torch.ones(dim)) 47 | 48 | def forward(self, x): 49 | 
return F.normalize(x, dim=-1) * self.scale * self.gamma 50 | 51 | 52 | # attention 53 | 54 | class FeedForward(Module): 55 | def __init__( 56 | self, 57 | dim, 58 | mult=4, 59 | dropout=0. 60 | ): 61 | super().__init__() 62 | dim_inner = int(dim * mult) 63 | self.net = nn.Sequential( 64 | RMSNorm(dim), 65 | nn.Linear(dim, dim_inner), 66 | nn.GELU(), 67 | nn.Dropout(dropout), 68 | nn.Linear(dim_inner, dim), 69 | nn.Dropout(dropout) 70 | ) 71 | 72 | def forward(self, x): 73 | return self.net(x) 74 | 75 | 76 | class Attention(Module): 77 | def __init__( 78 | self, 79 | dim, 80 | heads=8, 81 | dim_head=64, 82 | dropout=0., 83 | rotary_embed=None, 84 | ): 85 | super().__init__() 86 | self.heads = heads 87 | self.scale = dim_head ** -0.5 88 | dim_inner = heads * dim_head 89 | 90 | self.rotary_embed = rotary_embed 91 | 92 | self.attend = F.scaled_dot_product_attention 93 | 94 | self.norm = RMSNorm(dim) 95 | self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) 96 | 97 | self.to_gates = nn.Linear(dim, heads) 98 | 99 | self.to_out = nn.Sequential( 100 | nn.Linear(dim_inner, dim, bias=False), 101 | nn.Dropout(dropout) 102 | ) 103 | 104 | def forward(self, x): 105 | x = self.norm(x) 106 | 107 | q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) 108 | 109 | if exists(self.rotary_embed): 110 | q = self.rotary_embed.rotate_queries_or_keys(q) 111 | k = self.rotary_embed.rotate_queries_or_keys(k) 112 | 113 | out = self.attend(q, k, v) 114 | 115 | gates = self.to_gates(x) 116 | out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() 117 | 118 | out = rearrange(out, 'b h n d -> b n (h d)') 119 | return self.to_out(out) 120 | 121 | 122 | class Transformer(Module): 123 | def __init__( 124 | self, 125 | *, 126 | dim, 127 | depth, 128 | dim_head=64, 129 | heads=8, 130 | attn_dropout=0., 131 | ff_dropout=0., 132 | ff_mult=4, 133 | norm_output=True, 134 | rotary_embed=None, 135 | flash_attn=True 136 | ): 137 | super().__init__() 138 | self.layers = ModuleList([]) 139 | 140 | for _ in range(depth): 141 | self.layers.append(ModuleList([ 142 | Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, rotary_embed=rotary_embed), 143 | FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) 144 | ])) 145 | 146 | self.norm = RMSNorm(dim) if norm_output else nn.Identity() 147 | 148 | def forward(self, x): 149 | 150 | for attn, ff in self.layers: 151 | x = attn(x) + x 152 | x = ff(x) + x 153 | 154 | return self.norm(x) 155 | 156 | 157 | # bandsplit module 158 | 159 | class BandSplit(Module): 160 | def __init__( 161 | self, 162 | dim, 163 | dim_inputs 164 | ): 165 | super().__init__() 166 | self.dim_inputs = dim_inputs 167 | self.to_features = ModuleList([]) 168 | 169 | for dim_in in dim_inputs: 170 | net = nn.Sequential( 171 | RMSNorm(dim_in), 172 | nn.Linear(dim_in, dim) 173 | ) 174 | 175 | self.to_features.append(net) 176 | 177 | def forward(self, x): 178 | x = x.split(self.dim_inputs, dim=-1) 179 | 180 | outs = [] 181 | for split_input, to_feature in zip(x, self.to_features): 182 | split_output = to_feature(split_input) 183 | outs.append(split_output) 184 | 185 | return torch.stack(outs, dim=-2) 186 | 187 | 188 | def MLP( 189 | dim_in, 190 | dim_out, 191 | dim_hidden=None, 192 | depth=1, 193 | activation=nn.Tanh 194 | ): 195 | dim_hidden = default(dim_hidden, dim_in) 196 | 197 | net = [] 198 | dims = (dim_in, *((dim_hidden,) * depth), dim_out) 199 | 200 | for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): 201 | is_last = ind == 
(len(dims) - 2) 202 | 203 | net.append(nn.Linear(layer_dim_in, layer_dim_out)) 204 | 205 | if is_last: 206 | continue 207 | 208 | net.append(activation()) 209 | 210 | return nn.Sequential(*net) 211 | 212 | 213 | class MaskEstimator(Module): 214 | def __init__( 215 | self, 216 | dim, 217 | dim_inputs, 218 | depth, 219 | mlp_expansion_factor=4 220 | ): 221 | super().__init__() 222 | self.dim_inputs = dim_inputs 223 | self.to_freqs = ModuleList([]) 224 | dim_hidden = dim * mlp_expansion_factor 225 | 226 | for dim_in in dim_inputs: 227 | net = [] 228 | 229 | mlp = nn.Sequential( 230 | MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), 231 | nn.GLU(dim=-1) 232 | ) 233 | 234 | self.to_freqs.append(mlp) 235 | 236 | def forward(self, x): 237 | x = x.unbind(dim=-2) 238 | 239 | outs = [] 240 | 241 | for band_features, mlp in zip(x, self.to_freqs): 242 | freq_out = mlp(band_features) 243 | outs.append(freq_out) 244 | 245 | return torch.cat(outs, dim=-1) 246 | 247 | 248 | # main class 249 | 250 | class MelBandRoformer(Module): 251 | def __init__( 252 | self, 253 | dim, 254 | *, 255 | depth, 256 | stereo=False, 257 | num_stems=1, 258 | time_transformer_depth=2, 259 | freq_transformer_depth=2, 260 | num_bands=60, 261 | dim_head=64, 262 | heads=8, 263 | attn_dropout=0.1, 264 | ff_dropout=0.1, 265 | flash_attn=True, 266 | dim_freqs_in=1025, 267 | sample_rate=44100, # needed for mel filter bank from librosa 268 | stft_n_fft=2048, 269 | stft_hop_length=512, 270 | # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction 271 | stft_win_length=2048, 272 | stft_normalized=False, 273 | stft_window_fn = None, 274 | mask_estimator_depth=1, 275 | multi_stft_resolution_loss_weight=1., 276 | multi_stft_resolutions_window_sizes = (4096, 2048, 1024, 512, 256), 277 | multi_stft_hop_size=147, 278 | multi_stft_normalized=False, 279 | multi_stft_window_fn = torch.hann_window, 280 | match_input_audio_length=False, # if True, pad output tensor to match length of input tensor 281 | ): 282 | super().__init__() 283 | 284 | self.stereo = stereo 285 | self.audio_channels = 2 if stereo else 1 286 | self.num_stems = num_stems 287 | 288 | self.layers = ModuleList([]) 289 | 290 | transformer_kwargs = dict( 291 | dim=dim, 292 | heads=heads, 293 | dim_head=dim_head, 294 | attn_dropout=attn_dropout, 295 | ff_dropout=ff_dropout, 296 | flash_attn=flash_attn 297 | ) 298 | 299 | time_rotary_embed = RotaryEmbedding(dim=dim_head) 300 | freq_rotary_embed = RotaryEmbedding(dim=dim_head) 301 | 302 | for _ in range(depth): 303 | self.layers.append(nn.ModuleList([ 304 | Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs), 305 | Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) 306 | ])) 307 | 308 | self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) 309 | 310 | self.stft_kwargs = dict( 311 | n_fft=stft_n_fft, 312 | hop_length=stft_hop_length, 313 | win_length=stft_win_length, 314 | normalized=stft_normalized 315 | ) 316 | 317 | freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, return_complex=True).shape[1] 318 | 319 | # create mel filter bank 320 | # with librosa.filters.mel as in section 2 of paper 321 | 322 | mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands) 323 | 324 | mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy) 325 | 326 | # for some reason, it doesn't include the first freq? 
just force a value for now 327 | 328 | mel_filter_bank[0][0] = 1. 329 | 330 | # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position, 331 | # so let's force a positive value 332 | 333 | mel_filter_bank[-1, -1] = 1. 334 | 335 | # binary as in paper (then estimated masks are averaged for overlapping regions) 336 | 337 | freqs_per_band = mel_filter_bank > 0 338 | assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now' 339 | 340 | repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands) 341 | freq_indices = repeated_freq_indices[freqs_per_band] 342 | 343 | if stereo: 344 | freq_indices = repeat(freq_indices, 'f -> f s', s=2) 345 | freq_indices = freq_indices * 2 + torch.arange(2) 346 | freq_indices = rearrange(freq_indices, 'f s -> (f s)') 347 | 348 | self.register_buffer('freq_indices', freq_indices, persistent=False) 349 | self.register_buffer('freqs_per_band', freqs_per_band, persistent=False) 350 | 351 | num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum') 352 | num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum') 353 | 354 | self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False) 355 | self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False) 356 | 357 | # band split and mask estimator 358 | 359 | freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist()) 360 | 361 | self.band_split = BandSplit( 362 | dim=dim, 363 | dim_inputs=freqs_per_bands_with_complex 364 | ) 365 | 366 | self.mask_estimators = nn.ModuleList([]) 367 | 368 | for _ in range(num_stems): 369 | mask_estimator = MaskEstimator( 370 | dim=dim, 371 | dim_inputs=freqs_per_bands_with_complex, 372 | depth=mask_estimator_depth 373 | ) 374 | 375 | self.mask_estimators.append(mask_estimator) 376 | 377 | # for the multi-resolution stft loss 378 | 379 | self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight 380 | self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes 381 | self.multi_stft_n_fft = stft_n_fft 382 | self.multi_stft_window_fn = multi_stft_window_fn 383 | 384 | self.multi_stft_kwargs = dict( 385 | hop_length=multi_stft_hop_size, 386 | normalized=multi_stft_normalized 387 | ) 388 | 389 | self.match_input_audio_length = match_input_audio_length 390 | 391 | def forward( 392 | self, 393 | raw_audio, 394 | target=None, 395 | return_loss_breakdown=False 396 | ): 397 | """ 398 | einops 399 | 400 | b - batch 401 | f - freq 402 | t - time 403 | s - audio channel (1 for mono, 2 for stereo) 404 | n - number of 'stems' 405 | c - complex (2) 406 | d - feature dimension 407 | """ 408 | 409 | device = raw_audio.device 410 | 411 | if raw_audio.ndim == 2: 412 | raw_audio = rearrange(raw_audio, 'b t -> b 1 t') 413 | 414 | batch, channels, raw_audio_length = raw_audio.shape 415 | 416 | istft_length = raw_audio_length if self.match_input_audio_length else None 417 | 418 | assert (not self.stereo and channels == 1) or ( 419 | self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)' 420 | 421 | # to stft 422 | 423 | raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') 424 | 425 | stft_window = self.stft_window_fn(device=device) 426 | 427 | stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) 428 | stft_repr = torch.view_as_real(stft_repr) 429 | 430 | stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') 431 | stft_repr = rearrange(stft_repr, 432 | 'b s f t c -> b (f s) t c') # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting 433 | 434 | # index out all frequencies for all frequency ranges across bands ascending in one go 435 | 436 | batch_arange = torch.arange(batch, device=device)[..., None] 437 | 438 | # account for stereo 439 | 440 | x = stft_repr[batch_arange, self.freq_indices] 441 | 442 | # fold the complex (real and imag) into the frequencies dimension 443 | 444 | x = rearrange(x, 'b f t c -> b t (f c)') 445 | 446 | x = self.band_split(x) 447 | 448 | # axial / hierarchical attention 449 | 450 | for time_transformer, freq_transformer in self.layers: 451 | x = rearrange(x, 'b t f d -> b f t d') 452 | x, ps = pack([x], '* t d') 453 | 454 | x = time_transformer(x) 455 | 456 | x, = unpack(x, ps, '* t d') 457 | x = rearrange(x, 'b f t d -> b t f d') 458 | x, ps = pack([x], '* f d') 459 | 460 | x = freq_transformer(x) 461 | 462 | x, = unpack(x, ps, '* f d') 463 | 464 | num_stems = len(self.mask_estimators) 465 | 466 | masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) 467 | masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2) 468 | 469 | # modulate frequency representation 470 | 471 | stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') 472 | 473 | # complex number multiplication 474 | 475 | stft_repr = torch.view_as_complex(stft_repr) 476 | masks = torch.view_as_complex(masks) 477 | 478 | masks = masks.type(stft_repr.dtype) 479 | 480 | # need to average the estimated mask for the overlapped frequencies 481 | 482 | scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1]) 483 | 484 | stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems) 485 | masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks) 486 | 487 | denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels) 488 | 489 | masks_averaged = masks_summed / denom.clamp(min=1e-8) 490 | 491 | # modulate stft repr with estimated mask 492 | 493 | stft_repr = stft_repr * masks_averaged 494 | 495 | # istft 496 | 497 | stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) 498 | 499 | recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, 500 | length=istft_length) 501 | 502 | recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems) 503 | 504 | if num_stems == 1: 505 | recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') 506 | 507 | # if a target is passed in, calculate loss for learning 508 | 509 | if not exists(target): 510 | return recon_audio 511 | 512 | if self.num_stems > 1: 513 | assert target.ndim == 4 and target.shape[1] == self.num_stems 514 | 515 | if target.ndim == 2: 516 | target = rearrange(target, '... t -> ... 
1 t') 517 | 518 | target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft 519 | 520 | loss = F.l1_loss(recon_audio, target) 521 | 522 | multi_stft_resolution_loss = 0. 523 | 524 | for window_size in self.multi_stft_resolutions_window_sizes: 525 | res_stft_kwargs = dict( 526 | n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft 527 | win_length=window_size, 528 | return_complex=True, 529 | window=self.multi_stft_window_fn(window_size, device=device), 530 | **self.multi_stft_kwargs, 531 | ) 532 | 533 | recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) 534 | target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs) 535 | 536 | multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) 537 | 538 | weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight 539 | 540 | total_loss = loss + weighted_multi_resolution_loss 541 | 542 | if not return_loss_breakdown: 543 | return total_loss 544 | 545 | return total_loss, (loss, multi_stft_resolution_loss) 546 | -------------------------------------------------------------------------------- /model/bs_roformer.py: -------------------------------------------------------------------------------- 1 | #https://github.com/morettt/my-neuro/blob/main/fine_tuning/tools/uvr5/bs_roformer/bs_roformer.py 2 | 3 | from functools import partial 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import Module, ModuleList 8 | import torch.nn.functional as F 9 | 10 | from .attend import Attend 11 | from torch.utils.checkpoint import checkpoint 12 | 13 | from typing import Tuple, Optional, Callable 14 | # from beartype.typing import Tuple, Optional, List, Callable 15 | # from beartype import beartype 16 | 17 | from rotary_embedding_torch import RotaryEmbedding 18 | 19 | from einops import rearrange, pack, unpack 20 | from einops.layers.torch import Rearrange 21 | 22 | # helper functions 23 | 24 | 25 | def exists(val): 26 | return val is not None 27 | 28 | 29 | def default(v, d): 30 | return v if exists(v) else d 31 | 32 | 33 | def pack_one(t, pattern): 34 | return pack([t], pattern) 35 | 36 | 37 | def unpack_one(t, ps, pattern): 38 | return unpack(t, ps, pattern)[0] 39 | 40 | 41 | # norm 42 | 43 | 44 | def l2norm(t): 45 | return F.normalize(t, dim=-1, p=2) 46 | 47 | 48 | class RMSNorm(Module): 49 | def __init__(self, dim): 50 | super().__init__() 51 | self.scale = dim**0.5 52 | self.gamma = nn.Parameter(torch.ones(dim)) 53 | 54 | def forward(self, x): 55 | return F.normalize(x, dim=-1) * self.scale * self.gamma 56 | 57 | 58 | # attention 59 | 60 | 61 | class FeedForward(Module): 62 | def __init__(self, dim, mult=4, dropout=0.0): 63 | super().__init__() 64 | dim_inner = int(dim * mult) 65 | self.net = nn.Sequential( 66 | RMSNorm(dim), 67 | nn.Linear(dim, dim_inner), 68 | nn.GELU(), 69 | nn.Dropout(dropout), 70 | nn.Linear(dim_inner, dim), 71 | nn.Dropout(dropout), 72 | ) 73 | 74 | def forward(self, x): 75 | return self.net(x) 76 | 77 | 78 | class Attention(Module): 79 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None, flash=True): 80 | super().__init__() 81 | self.heads = heads 82 | self.scale = dim_head**-0.5 83 | dim_inner = heads * dim_head 84 | 85 | self.rotary_embed = rotary_embed 86 | 87 | self.attend = Attend(flash=flash, dropout=dropout) 88 | 89 | self.norm = RMSNorm(dim) 90 | self.to_qkv = nn.Linear(dim, dim_inner * 
3, bias=False) 91 | 92 | self.to_gates = nn.Linear(dim, heads) 93 | 94 | self.to_out = nn.Sequential(nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)) 95 | 96 | def forward(self, x): 97 | x = self.norm(x) 98 | 99 | q, k, v = rearrange(self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads) 100 | 101 | if exists(self.rotary_embed): 102 | q = self.rotary_embed.rotate_queries_or_keys(q) 103 | k = self.rotary_embed.rotate_queries_or_keys(k) 104 | 105 | out = self.attend(q, k, v) 106 | 107 | gates = self.to_gates(x) 108 | out = out * rearrange(gates, "b n h -> b h n 1").sigmoid() 109 | 110 | out = rearrange(out, "b h n d -> b n (h d)") 111 | return self.to_out(out) 112 | 113 | 114 | class LinearAttention(Module): 115 | """ 116 | this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. 117 | """ 118 | 119 | # @beartype 120 | def __init__(self, *, dim, dim_head=32, heads=8, scale=8, flash=False, dropout=0.0): 121 | super().__init__() 122 | dim_inner = dim_head * heads 123 | self.norm = RMSNorm(dim) 124 | 125 | self.to_qkv = nn.Sequential( 126 | nn.Linear(dim, dim_inner * 3, bias=False), Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads) 127 | ) 128 | 129 | self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) 130 | 131 | self.attend = Attend(scale=scale, dropout=dropout, flash=flash) 132 | 133 | self.to_out = nn.Sequential(Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)) 134 | 135 | def forward(self, x): 136 | x = self.norm(x) 137 | 138 | q, k, v = self.to_qkv(x) 139 | 140 | q, k = map(l2norm, (q, k)) 141 | q = q * self.temperature.exp() 142 | 143 | out = self.attend(q, k, v) 144 | 145 | return self.to_out(out) 146 | 147 | 148 | class Transformer(Module): 149 | def __init__( 150 | self, 151 | *, 152 | dim, 153 | depth, 154 | dim_head=64, 155 | heads=8, 156 | attn_dropout=0.0, 157 | ff_dropout=0.0, 158 | ff_mult=4, 159 | norm_output=True, 160 | rotary_embed=None, 161 | flash_attn=True, 162 | linear_attn=False, 163 | ): 164 | super().__init__() 165 | self.layers = ModuleList([]) 166 | 167 | for _ in range(depth): 168 | if linear_attn: 169 | attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) 170 | else: 171 | attn = Attention( 172 | dim=dim, 173 | dim_head=dim_head, 174 | heads=heads, 175 | dropout=attn_dropout, 176 | rotary_embed=rotary_embed, 177 | flash=flash_attn, 178 | ) 179 | 180 | self.layers.append(ModuleList([attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)])) 181 | 182 | self.norm = RMSNorm(dim) if norm_output else nn.Identity() 183 | 184 | def forward(self, x): 185 | for attn, ff in self.layers: 186 | x = attn(x) + x 187 | x = ff(x) + x 188 | 189 | return self.norm(x) 190 | 191 | 192 | # bandsplit module 193 | 194 | 195 | class BandSplit(Module): 196 | # @beartype 197 | def __init__(self, dim, dim_inputs: Tuple[int, ...]): 198 | super().__init__() 199 | self.dim_inputs = dim_inputs 200 | self.to_features = ModuleList([]) 201 | 202 | for dim_in in dim_inputs: 203 | net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim)) 204 | 205 | self.to_features.append(net) 206 | 207 | def forward(self, x): 208 | x = x.split(self.dim_inputs, dim=-1) 209 | 210 | outs = [] 211 | for split_input, to_feature in zip(x, self.to_features): 212 | split_output = to_feature(split_input) 213 | outs.append(split_output) 214 | 215 | return torch.stack(outs, dim=-2) 216 | 217 | 218 | def MLP(dim_in, dim_out, dim_hidden=None, depth=1, 
activation=nn.Tanh): 219 | dim_hidden = default(dim_hidden, dim_in) 220 | 221 | net = [] 222 | dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) 223 | 224 | for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): 225 | is_last = ind == (len(dims) - 2) 226 | 227 | net.append(nn.Linear(layer_dim_in, layer_dim_out)) 228 | 229 | if is_last: 230 | continue 231 | 232 | net.append(activation()) 233 | 234 | return nn.Sequential(*net) 235 | 236 | 237 | class MaskEstimator(Module): 238 | # @beartype 239 | def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4): 240 | super().__init__() 241 | self.dim_inputs = dim_inputs 242 | self.to_freqs = ModuleList([]) 243 | dim_hidden = dim * mlp_expansion_factor 244 | 245 | for dim_in in dim_inputs: 246 | net = [] 247 | 248 | mlp = nn.Sequential(MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)) 249 | 250 | self.to_freqs.append(mlp) 251 | 252 | def forward(self, x): 253 | x = x.unbind(dim=-2) 254 | 255 | outs = [] 256 | 257 | for band_features, mlp in zip(x, self.to_freqs): 258 | freq_out = mlp(band_features) 259 | outs.append(freq_out) 260 | 261 | return torch.cat(outs, dim=-1) 262 | 263 | 264 | # main class 265 | 266 | DEFAULT_FREQS_PER_BANDS = ( 267 | 2, 268 | 2, 269 | 2, 270 | 2, 271 | 2, 272 | 2, 273 | 2, 274 | 2, 275 | 2, 276 | 2, 277 | 2, 278 | 2, 279 | 2, 280 | 2, 281 | 2, 282 | 2, 283 | 2, 284 | 2, 285 | 2, 286 | 2, 287 | 2, 288 | 2, 289 | 2, 290 | 2, 291 | 4, 292 | 4, 293 | 4, 294 | 4, 295 | 4, 296 | 4, 297 | 4, 298 | 4, 299 | 4, 300 | 4, 301 | 4, 302 | 4, 303 | 12, 304 | 12, 305 | 12, 306 | 12, 307 | 12, 308 | 12, 309 | 12, 310 | 12, 311 | 24, 312 | 24, 313 | 24, 314 | 24, 315 | 24, 316 | 24, 317 | 24, 318 | 24, 319 | 48, 320 | 48, 321 | 48, 322 | 48, 323 | 48, 324 | 48, 325 | 48, 326 | 48, 327 | 128, 328 | 129, 329 | ) 330 | 331 | 332 | class BSRoformer(Module): 333 | # @beartype 334 | def __init__( 335 | self, 336 | dim, 337 | *, 338 | depth, 339 | stereo=False, 340 | num_stems=1, 341 | time_transformer_depth=2, 342 | freq_transformer_depth=2, 343 | linear_transformer_depth=0, 344 | freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, 345 | # in the paper, they divide into ~60 bands, test with 1 for starters 346 | dim_head=64, 347 | heads=8, 348 | attn_dropout=0.0, 349 | ff_dropout=0.0, 350 | flash_attn=True, 351 | dim_freqs_in=1025, 352 | stft_n_fft=2048, 353 | stft_hop_length=512, 354 | # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction 355 | stft_win_length=2048, 356 | stft_normalized=False, 357 | stft_window_fn: Optional[Callable] = None, 358 | mask_estimator_depth=2, 359 | multi_stft_resolution_loss_weight=1.0, 360 | multi_stft_resolutions_window_sizes: Tuple[int, ...] 
= (4096, 2048, 1024, 512, 256), 361 | multi_stft_hop_size=147, 362 | multi_stft_normalized=False, 363 | multi_stft_window_fn: Callable = torch.hann_window, 364 | mlp_expansion_factor=4, 365 | use_torch_checkpoint=False, 366 | skip_connection=False, 367 | ): 368 | super().__init__() 369 | 370 | self.stereo = stereo 371 | self.audio_channels = 2 if stereo else 1 372 | self.num_stems = num_stems 373 | self.use_torch_checkpoint = use_torch_checkpoint 374 | self.skip_connection = skip_connection 375 | 376 | self.layers = ModuleList([]) 377 | 378 | transformer_kwargs = dict( 379 | dim=dim, 380 | heads=heads, 381 | dim_head=dim_head, 382 | attn_dropout=attn_dropout, 383 | ff_dropout=ff_dropout, 384 | flash_attn=flash_attn, 385 | norm_output=False, 386 | ) 387 | 388 | time_rotary_embed = RotaryEmbedding(dim=dim_head) 389 | freq_rotary_embed = RotaryEmbedding(dim=dim_head) 390 | 391 | for _ in range(depth): 392 | tran_modules = [] 393 | if linear_transformer_depth > 0: 394 | tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) 395 | tran_modules.append( 396 | Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) 397 | ) 398 | tran_modules.append( 399 | Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) 400 | ) 401 | self.layers.append(nn.ModuleList(tran_modules)) 402 | 403 | self.final_norm = RMSNorm(dim) 404 | 405 | self.stft_kwargs = dict( 406 | n_fft=stft_n_fft, hop_length=stft_hop_length, win_length=stft_win_length, normalized=stft_normalized 407 | ) 408 | 409 | self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) 410 | 411 | freqs = torch.stft( 412 | torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True 413 | ).shape[1] 414 | 415 | assert len(freqs_per_bands) > 1 416 | assert sum(freqs_per_bands) == freqs, ( 417 | f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}" 418 | ) 419 | 420 | freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands) 421 | 422 | self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex) 423 | 424 | self.mask_estimators = nn.ModuleList([]) 425 | 426 | for _ in range(num_stems): 427 | mask_estimator = MaskEstimator( 428 | dim=dim, 429 | dim_inputs=freqs_per_bands_with_complex, 430 | depth=mask_estimator_depth, 431 | mlp_expansion_factor=mlp_expansion_factor, 432 | ) 433 | 434 | self.mask_estimators.append(mask_estimator) 435 | 436 | # for the multi-resolution stft loss 437 | 438 | self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight 439 | self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes 440 | self.multi_stft_n_fft = stft_n_fft 441 | self.multi_stft_window_fn = multi_stft_window_fn 442 | 443 | self.multi_stft_kwargs = dict(hop_length=multi_stft_hop_size, normalized=multi_stft_normalized) 444 | 445 | def forward(self, raw_audio, target=None, return_loss_breakdown=False): 446 | """ 447 | einops 448 | 449 | b - batch 450 | f - freq 451 | t - time 452 | s - audio channel (1 for mono, 2 for stereo) 453 | n - number of 'stems' 454 | c - complex (2) 455 | d - feature dimension 456 | """ 457 | 458 | device = raw_audio.device 459 | 460 | # defining whether model is loaded on MPS (MacOS GPU accelerator) 461 | x_is_mps = True if device.type == "mps" else False 462 | 463 | if raw_audio.ndim == 2: 464 | 
raw_audio = rearrange(raw_audio, "b t -> b 1 t") 465 | 466 | channels = raw_audio.shape[1] 467 | assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), ( 468 | "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)" 469 | ) 470 | 471 | # to stft 472 | 473 | raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t") 474 | 475 | stft_window = self.stft_window_fn(device=device) 476 | 477 | # RuntimeError: FFT operations are only supported on MacOS 14+ 478 | # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used 479 | try: 480 | stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) 481 | except: 482 | stft_repr = torch.stft( 483 | raw_audio.cpu() if x_is_mps else raw_audio, 484 | **self.stft_kwargs, 485 | window=stft_window.cpu() if x_is_mps else stft_window, 486 | return_complex=True, 487 | ).to(device) 488 | 489 | stft_repr = torch.view_as_real(stft_repr) 490 | 491 | stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c") 492 | 493 | # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting 494 | stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c") 495 | 496 | x = rearrange(stft_repr, "b f t c -> b t (f c)") 497 | 498 | if self.use_torch_checkpoint: 499 | x = checkpoint(self.band_split, x, use_reentrant=False) 500 | else: 501 | x = self.band_split(x) 502 | 503 | # axial / hierarchical attention 504 | 505 | store = [None] * len(self.layers) 506 | for i, transformer_block in enumerate(self.layers): 507 | if len(transformer_block) == 3: 508 | linear_transformer, time_transformer, freq_transformer = transformer_block 509 | 510 | x, ft_ps = pack([x], "b * d") 511 | if self.use_torch_checkpoint: 512 | x = checkpoint(linear_transformer, x, use_reentrant=False) 513 | else: 514 | x = linear_transformer(x) 515 | (x,) = unpack(x, ft_ps, "b * d") 516 | else: 517 | time_transformer, freq_transformer = transformer_block 518 | 519 | if self.skip_connection: 520 | # Sum all previous 521 | for j in range(i): 522 | x = x + store[j] 523 | 524 | x = rearrange(x, "b t f d -> b f t d") 525 | x, ps = pack([x], "* t d") 526 | 527 | if self.use_torch_checkpoint: 528 | x = checkpoint(time_transformer, x, use_reentrant=False) 529 | else: 530 | x = time_transformer(x) 531 | 532 | (x,) = unpack(x, ps, "* t d") 533 | x = rearrange(x, "b f t d -> b t f d") 534 | x, ps = pack([x], "* f d") 535 | 536 | if self.use_torch_checkpoint: 537 | x = checkpoint(freq_transformer, x, use_reentrant=False) 538 | else: 539 | x = freq_transformer(x) 540 | 541 | (x,) = unpack(x, ps, "* f d") 542 | 543 | if self.skip_connection: 544 | store[i] = x 545 | 546 | x = self.final_norm(x) 547 | 548 | num_stems = len(self.mask_estimators) 549 | 550 | if self.use_torch_checkpoint: 551 | mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1) 552 | else: 553 | mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) 554 | mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2) 555 | 556 | # modulate frequency representation 557 | 558 | stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c") 559 | 560 | # complex number multiplication 561 | 562 | stft_repr = torch.view_as_complex(stft_repr) 563 | mask = torch.view_as_complex(mask) 564 | 565 | stft_repr = stft_repr * mask 566 | 567 | # istft 568 | 569 | stft_repr = 
rearrange(stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels) 570 | 571 | # same as torch.stft() fix for MacOS MPS above 572 | try: 573 | recon_audio = torch.istft( 574 | stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1] 575 | ) 576 | except: 577 | recon_audio = torch.istft( 578 | stft_repr.cpu() if x_is_mps else stft_repr, 579 | **self.stft_kwargs, 580 | window=stft_window.cpu() if x_is_mps else stft_window, 581 | return_complex=False, 582 | length=raw_audio.shape[-1], 583 | ).to(device) 584 | 585 | recon_audio = rearrange(recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems) 586 | 587 | if num_stems == 1: 588 | recon_audio = rearrange(recon_audio, "b 1 s t -> b s t") 589 | 590 | # if a target is passed in, calculate loss for learning 591 | 592 | if not exists(target): 593 | return recon_audio 594 | 595 | if self.num_stems > 1: 596 | assert target.ndim == 4 and target.shape[1] == self.num_stems 597 | 598 | if target.ndim == 2: 599 | target = rearrange(target, "... t -> ... 1 t") 600 | 601 | target = target[..., : recon_audio.shape[-1]] # protect against lost length on istft 602 | 603 | loss = F.l1_loss(recon_audio, target) 604 | 605 | multi_stft_resolution_loss = 0.0 606 | 607 | for window_size in self.multi_stft_resolutions_window_sizes: 608 | res_stft_kwargs = dict( 609 | n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft 610 | win_length=window_size, 611 | return_complex=True, 612 | window=self.multi_stft_window_fn(window_size, device=device), 613 | **self.multi_stft_kwargs, 614 | ) 615 | 616 | recon_Y = torch.stft(rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs) 617 | target_Y = torch.stft(rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs) 618 | 619 | multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) 620 | 621 | weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight 622 | 623 | total_loss = loss + weighted_multi_resolution_loss 624 | 625 | if not return_loss_breakdown: 626 | return total_loss 627 | 628 | return total_loss, (loss, multi_stft_resolution_loss) 629 | --------------------------------------------------------------------------------
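For reference, a minimal standalone sketch of how the `MelBandRoformer` class above could be exercised outside ComfyUI, using the same hyperparameters as the default loader configuration; the checkpoint path and the random input waveform are placeholders, not part of the repository:

```python
# Minimal standalone sketch (assumptions: repo root on sys.path; hyperparameters match the
# MelBandRoFormerModelLoader defaults; the checkpoint filename below is hypothetical).
import torch
from model.mel_band_roformer import MelBandRoformer

model = MelBandRoformer(
    dim=384, depth=6, stereo=True, num_stems=1,
    time_transformer_depth=1, freq_transformer_depth=1,
    num_bands=60, sample_rate=44100,
    stft_n_fft=2048, stft_hop_length=441, stft_win_length=2048,
    mask_estimator_depth=2,
).eval()

# Weights could be loaded from a .safetensors checkpoint, e.g.:
# from safetensors.torch import load_file
# model.load_state_dict(load_file("MelBandRoformer_fp16.safetensors"))

# (batch, channels, samples): 10 s of stereo noise stands in for real 44.1 kHz audio
waveform = torch.randn(1, 2, 441_000)

with torch.no_grad():
    vocals = model(waveform)          # estimated vocal stem, same (b, s, t) layout
instrumental = waveform - vocals      # residual, as the sampler node computes it
```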