├── .gitignore ├── mss ├── models │ ├── bandit │ │ ├── core │ │ │ ├── data │ │ │ │ ├── dnr │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── preprocess.py │ │ │ │ │ └── datamodule.py │ │ │ │ ├── musdb │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── validation.yaml │ │ │ │ │ ├── datamodule.py │ │ │ │ │ ├── preprocess.py │ │ │ │ │ └── dataset.py │ │ │ │ ├── __init__.py │ │ │ │ ├── _types.py │ │ │ │ ├── augmented.py │ │ │ │ ├── base.py │ │ │ │ └── augmentation.py │ │ │ ├── utils │ │ │ │ └── __init__.py │ │ │ ├── model │ │ │ │ ├── __init__.py │ │ │ │ ├── bsrnn │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── bandsplit.py │ │ │ │ └── _spectral.py │ │ │ ├── loss │ │ │ │ ├── __init__.py │ │ │ │ ├── _complex.py │ │ │ │ ├── _multistem.py │ │ │ │ ├── _timefreq.py │ │ │ │ └── snr.py │ │ │ └── metrics │ │ │ │ ├── __init__.py │ │ │ │ └── snr.py │ │ └── model_from_config.py │ ├── scnet │ │ ├── __init__.py │ │ └── separation.py │ ├── scnet_unofficial │ │ ├── __init__.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── dualpath_rnn.py │ │ │ └── su_decoder.py │ │ ├── utils.py │ │ └── scnet.py │ ├── bs_roformer │ │ ├── __init__.py │ │ └── attend.py │ ├── upernet_swin_transformers.py │ ├── mdx23c_tfc_tdf_v3.py │ ├── torchseg_models.py │ └── segm_models.py ├── configs │ ├── config_vocals_swin_upernet.yaml │ ├── viperx │ │ ├── model_mel_band_roformer_ep_3005_sdr_11.4360.yaml │ │ ├── model_bs_roformer_ep_317_sdr_12.9755.yaml │ │ └── model_bs_roformer_ep_937_sdr_10.5309.yaml │ ├── config_vocals_scnet_unofficial.yaml │ ├── config_vocals_torchseg.yaml │ ├── config_drumsep.yaml │ ├── config_musdb18_demucs3_mmi.yaml │ ├── config_musdb18_mel_band_roformer.yaml │ ├── config_vocals_scnet.yaml │ ├── config_musdb18_scnet.yaml │ ├── config_vocals_segm_models.yaml │ ├── config_vocals_bandit_bsrnn_multi_mus64.yaml │ ├── config_vocals_mel_band_roformer.yaml │ ├── config_dnr_bandit_bsrnn_multi_mus64.yaml │ ├── config_musdb18_bs_roformer.yaml │ ├── config_musdb18_torchseg.yaml │ ├── config_musdb18_segm_models.yaml │ ├── config_vocals_bs_roformer.yaml │ ├── config_vocals_mdx23c.yaml │ ├── config_musdb18_htdemucs.yaml │ ├── config_vocals_htdemucs.yaml │ ├── config_htdemucs_6stems.yaml │ └── config_musdb18_mdx23c.yaml ├── inference.py ├── ensemble.py └── utils.py ├── requirements.txt ├── README.md └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | pretrained_models -------------------------------------------------------------------------------- /mss/models/bandit/core/data/dnr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mss/models/bandit/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mss/models/bandit/core/data/musdb/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mss/models/scnet/__init__.py: -------------------------------------------------------------------------------- 1 | from .scnet import SCNet 2 | -------------------------------------------------------------------------------- /mss/models/scnet_unofficial/__init__.py: -------------------------------------------------------------------------------- 1 | from .scnet import SCNet 
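Both `scnet/__init__.py` and `scnet_unofficial/__init__.py` above only re-export the `SCNet` class. For reference, a minimal, hypothetical usage sketch (not part of the repo): it assumes `SCNet(**config.model)` is a valid call, i.e. that the constructor accepts the keys of the `model:` section of `configs/config_vocals_scnet_unofficial.yaml` as keyword arguments, mirroring the `Model(**config.model)` pattern used in `model_from_config.py` further down.

```python
# Hypothetical sketch, not part of this repo: build the unofficial SCNet from
# its YAML config. Assumes the constructor takes the `model:` keys (dims,
# bandsplit_ratios, n_fft, ...) as keyword arguments.
import yaml
import torch
from ml_collections import ConfigDict
from models.scnet_unofficial import SCNet  # requires the mss/ directory on sys.path

with open("configs/config_vocals_scnet_unofficial.yaml") as f:
    config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

model = SCNet(**config.model)
mix = torch.randn(1, config.audio.num_channels, config.audio.chunk_size)  # (batch, channels, samples)
with torch.no_grad():
    stems = model(mix)  # assumed: one separated source per entry in training.instruments
```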
-------------------------------------------------------------------------------- /mss/models/bandit/core/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .bsrnn.wrapper import ( 2 | MultiMaskMultiSourceBandSplitRNNSimple, 3 | ) 4 | -------------------------------------------------------------------------------- /mss/models/bs_roformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .bs_roformer import BSRoformer 2 | from .mel_band_roformer import MelBandRoformer 3 | -------------------------------------------------------------------------------- /mss/models/bandit/core/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dnr.datamodule import DivideAndRemasterDataModule 2 | from .musdb.datamodule import MUSDB18DataModule -------------------------------------------------------------------------------- /mss/models/bandit/core/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from ._multistem import MultiStemWrapperFromConfig 2 | from ._timefreq import ReImL1Loss, ReImL2Loss, TimeFreqL1Loss, TimeFreqL2Loss, TimeFreqSignalNoisePNormRatioLoss 3 | -------------------------------------------------------------------------------- /mss/models/scnet_unofficial/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from models.scnet_unofficial.modules.dualpath_rnn import DualPathRNN 2 | from models.scnet_unofficial.modules.sd_encoder import SDBlock 3 | from models.scnet_unofficial.modules.su_decoder import SUBlock 4 | -------------------------------------------------------------------------------- /mss/models/bandit/core/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .snr import ( 2 | ChunkMedianScaleInvariantSignalDistortionRatio, 3 | ChunkMedianScaleInvariantSignalNoiseRatio, 4 | ChunkMedianSignalDistortionRatio, 5 | ChunkMedianSignalNoiseRatio, 6 | SafeSignalDistortionRatio, 7 | ) 8 | 9 | # from .mushra import EstimatedMushraScore 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | scipy 4 | ml_collections 5 | tqdm 6 | segmentation_models_pytorch 7 | timm 8 | audiomentations 9 | pedalboard 10 | omegaconf 11 | beartype 12 | rotary_embedding_torch 13 | einops 14 | demucs 15 | transformers 16 | torchmetrics 17 | spafe 18 | protobuf 19 | torch_audiomentations 20 | asteroid 21 | auraloss 22 | torchseg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VocalSeparation-ComfyUI 2 | A custom node for separating vocals from music, based on [ZFTurbo/Music-Source-Separation-Training](https://github.com/ZFTurbo/Music-Source-Separation-Training) 3 | 4 | ## How to use 5 | ``` 6 | ## in ComfyUI/custom_nodes 7 | git clone https://github.com/AIFSH/VocalSeparation-ComfyUI.git 8 | cd VocalSeparation-ComfyUI 9 | pip install -r requirements.txt 10 | ``` 11 | Weights will be downloaded from GitHub automatically. -------------------------------------------------------------------------------- /mss/models/bandit/core/data/_types.py:
-------------------------------------------------------------------------------- 1 | from typing import Dict, Sequence, TypedDict 2 | 3 | import torch 4 | 5 | AudioDict = Dict[str, torch.Tensor] 6 | 7 | DataDict = TypedDict('DataDict', {'audio': AudioDict, 'track': str}) 8 | 9 | BatchedDataDict = TypedDict( 10 | 'BatchedDataDict', 11 | {'audio': AudioDict, 'track': Sequence[str]} 12 | ) 13 | 14 | 15 | class DataDictWithLanguage(TypedDict): 16 | audio: AudioDict 17 | track: str 18 | language: str 19 | -------------------------------------------------------------------------------- /mss/models/bandit/core/data/musdb/validation.yaml: -------------------------------------------------------------------------------- 1 | validation: 2 | - 'Actions - One Minute Smile' 3 | - 'Clara Berry And Wooldog - Waltz For My Victims' 4 | - 'Johnny Lokke - Promises & Lies' 5 | - 'Patrick Talbot - A Reason To Leave' 6 | - 'Triviul - Angelsaint' 7 | - 'Alexander Ross - Goodbye Bolero' 8 | - 'Fergessen - Nos Palpitants' 9 | - 'Leaf - Summerghost' 10 | - 'Skelpolu - Human Mistakes' 11 | - 'Young Griffo - Pennies' 12 | - 'ANiMAL - Rockshow' 13 | - 'James May - On The Line' 14 | - 'Meaxic - Take A Step' 15 | - 'Traffic Experiment - Sirens' -------------------------------------------------------------------------------- /mss/models/bandit/core/model/bsrnn/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Iterable, Mapping, Union 3 | 4 | from torch import nn 5 | 6 | from models.bandit.core.model.bsrnn.bandsplit import BandSplitModule 7 | from models.bandit.core.model.bsrnn.tfmodel import ( 8 | SeqBandModellingModule, 9 | TransformerTimeFreqModule, 10 | ) 11 | 12 | 13 | class BandsplitCoreBase(nn.Module, ABC): 14 | band_split: nn.Module 15 | tf_model: nn.Module 16 | mask_estim: Union[nn.Module, Mapping[str, nn.Module], Iterable[nn.Module]] 17 | 18 | def __init__(self) -> None: 19 | super().__init__() 20 | 21 | @staticmethod 22 | def mask(x, m): 23 | return x * m 24 | -------------------------------------------------------------------------------- /mss/models/bandit/model_from_config.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path 3 | import torch 4 | 5 | code_path = os.path.dirname(os.path.abspath(__file__)) + '/' 6 | sys.path.append(code_path) 7 | 8 | import yaml 9 | from ml_collections import ConfigDict 10 | 11 | torch.set_float32_matmul_precision("medium") 12 | 13 | 14 | def get_model( 15 | config_path, 16 | weights_path, 17 | device, 18 | ): 19 | from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple 20 | 21 | f = open(config_path) 22 | config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) 23 | f.close() 24 | 25 | model = MultiMaskMultiSourceBandSplitRNNSimple( 26 | **config.model 27 | ) 28 | d = torch.load(code_path + 'model_bandit_plus_dnr_sdr_11.47.chpt') 29 | model.load_state_dict(d) 30 | model.to(device) 31 | return model, config 32 | -------------------------------------------------------------------------------- /mss/models/bandit/core/loss/_complex.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.modules import loss as _loss 6 | from torch.nn.modules.loss import _Loss 7 | 8 | 9 | class ReImLossWrapper(_Loss): 10 | def __init__(self, module: _Loss) -> None: 11 | super().__init__() 12 | self.module = 
module 13 | 14 | def forward( 15 | self, 16 | preds: torch.Tensor, 17 | target: torch.Tensor 18 | ) -> torch.Tensor: 19 | return self.module( 20 | torch.view_as_real(preds), 21 | torch.view_as_real(target) 22 | ) 23 | 24 | 25 | class ReImL1Loss(ReImLossWrapper): 26 | def __init__(self, **kwargs: Any) -> None: 27 | l1_loss = _loss.L1Loss(**kwargs) 28 | super().__init__(module=(l1_loss)) 29 | 30 | 31 | class ReImL2Loss(ReImLossWrapper): 32 | def __init__(self, **kwargs: Any) -> None: 33 | l2_loss = _loss.MSELoss(**kwargs) 34 | super().__init__(module=(l2_loss)) 35 | -------------------------------------------------------------------------------- /mss/models/bandit/core/data/augmented.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Dict, Optional, Union 3 | 4 | import torch 5 | from torch import nn 6 | from torch.utils import data 7 | 8 | 9 | class AugmentedDataset(data.Dataset): 10 | def __init__( 11 | self, 12 | dataset: data.Dataset, 13 | augmentation: nn.Module = nn.Identity(), 14 | target_length: Optional[int] = None, 15 | ) -> None: 16 | warnings.warn( 17 | "This class is no longer used. Attach augmentation to " 18 | "the LightningSystem instead.", 19 | DeprecationWarning, 20 | ) 21 | 22 | self.dataset = dataset 23 | self.augmentation = augmentation 24 | 25 | self.ds_length: int = len(dataset) # type: ignore[arg-type] 26 | self.length = target_length if target_length is not None else self.ds_length 27 | 28 | def __getitem__(self, index: int) -> Dict[str, Union[str, Dict[str, 29 | torch.Tensor]]]: 30 | item = self.dataset[index % self.ds_length] 31 | item = self.augmentation(item) 32 | return item 33 | 34 | def __len__(self) -> int: 35 | return self.length 36 | -------------------------------------------------------------------------------- /mss/configs/config_vocals_swin_upernet.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 261632 3 | dim_f: 4096 4 | dim_t: 512 5 | hop_length: 512 6 | n_fft: 8192 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | act: gelu 13 | num_channels: 16 14 | num_subbands: 8 15 | 16 | training: 17 | batch_size: 14 18 | gradient_accumulation_steps: 4 19 | grad_clip: 0 20 | instruments: 21 | - vocals 22 | - other 23 | lr: 3.0e-05 24 | patience: 2 25 | reduce_factor: 0.95 26 | target_instrument: null 27 | num_epochs: 1000 28 | num_steps: 1000 29 | q: 0.95 30 | coarse_loss_clip: true 31 | ema_momentum: 0.999 32 | optimizer: adamw 33 | other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental 34 | 35 | augmentations: 36 | enable: true # enable or disable all augmentations (to fast disable if needed) 37 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 38 | loudness_min: 0.5 39 | loudness_max: 1.5 40 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 41 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 42 | - 0.2 43 | - 0.02 44 | mixup_loudness_min: 0.5 45 | mixup_loudness_max: 1.5 46 | 47 | inference: 48 | batch_size: 1 49 | dim_t: 512 50 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/models/bandit/core/data/dnr/preprocess.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 
import os 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import torchaudio as ta 7 | from tqdm.contrib.concurrent import process_map 8 | 9 | 10 | def process_one(inputs: Tuple[str, str, int]) -> None: 11 | infile, outfile, target_fs = inputs 12 | 13 | dir = os.path.dirname(outfile) 14 | os.makedirs(dir, exist_ok=True) 15 | 16 | data, fs = ta.load(infile) 17 | 18 | if fs != target_fs: 19 | data = ta.functional.resample(data, fs, target_fs, resampling_method="sinc_interp_kaiser") 20 | fs = target_fs 21 | 22 | data = data.numpy() 23 | data = data.astype(np.float32) 24 | 25 | if os.path.exists(outfile): 26 | data_ = np.load(outfile) 27 | if np.allclose(data, data_): 28 | return 29 | 30 | np.save(outfile, data) 31 | 32 | 33 | def preprocess( 34 | data_path: str, 35 | output_path: str, 36 | fs: int 37 | ) -> None: 38 | files = glob.glob(os.path.join(data_path, "**", "*.wav"), recursive=True) 39 | print(files) 40 | outfiles = [ 41 | f.replace(data_path, output_path).replace(".wav", ".npy") for f in 42 | files 43 | ] 44 | 45 | os.makedirs(output_path, exist_ok=True) 46 | inputs = list(zip(files, outfiles, [fs] * len(files))) 47 | 48 | process_map(process_one, inputs, chunksize=32) 49 | 50 | 51 | if __name__ == "__main__": 52 | import fire 53 | 54 | fire.Fire() 55 | -------------------------------------------------------------------------------- /mss/models/bandit/core/loss/_multistem.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import torch 4 | from asteroid import losses as asteroid_losses 5 | from torch import nn 6 | from torch.nn.modules.loss import _Loss 7 | 8 | from . import snr 9 | 10 | 11 | def parse_loss(name: str, kwargs: Dict[str, Any]) -> _Loss: 12 | 13 | for module in [nn.modules.loss, snr, asteroid_losses, asteroid_losses.sdr]: 14 | if name in module.__dict__: 15 | return module.__dict__[name](**kwargs) 16 | 17 | raise NameError 18 | 19 | 20 | class MultiStemWrapper(_Loss): 21 | def __init__(self, module: _Loss, modality: str = "audio") -> None: 22 | super().__init__() 23 | self.loss = module 24 | self.modality = modality 25 | 26 | def forward( 27 | self, 28 | preds: Dict[str, Dict[str, torch.Tensor]], 29 | target: Dict[str, Dict[str, torch.Tensor]], 30 | ) -> torch.Tensor: 31 | loss = { 32 | stem: self.loss( 33 | preds[self.modality][stem], 34 | target[self.modality][stem] 35 | ) 36 | for stem in preds[self.modality] if stem in target[self.modality] 37 | } 38 | 39 | return sum(list(loss.values())) 40 | 41 | 42 | class MultiStemWrapperFromConfig(MultiStemWrapper): 43 | def __init__(self, name: str, kwargs: Any, modality: str = "audio") -> None: 44 | loss = parse_loss(name, kwargs) 45 | super().__init__(module=loss, modality=modality) 46 | -------------------------------------------------------------------------------- /mss/configs/viperx/model_mel_band_roformer_ep_3005_sdr_11.4360.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 352800 3 | dim_f: 1024 4 | dim_t: 801 # don't work (use in model) 5 | hop_length: 441 # don't work (use in model) 6 | n_fft: 2048 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.000 10 | 11 | model: 12 | dim: 384 13 | depth: 12 14 | stereo: true 15 | num_stems: 1 16 | time_transformer_depth: 1 17 | freq_transformer_depth: 1 18 | linear_transformer_depth: 0 19 | num_bands: 60 20 | dim_head: 64 21 | heads: 8 22 | attn_dropout: 0.1 23 | ff_dropout: 0.1 24 | flash_attn: True 25 | dim_freqs_in: 
1025 26 | sample_rate: 44100 # needed for mel filter bank from librosa 27 | stft_n_fft: 2048 28 | stft_hop_length: 441 29 | stft_win_length: 2048 30 | stft_normalized: False 31 | mask_estimator_depth: 2 32 | multi_stft_resolution_loss_weight: 1.0 33 | multi_stft_resolutions_window_sizes: !!python/tuple 34 | - 4096 35 | - 2048 36 | - 1024 37 | - 512 38 | - 256 39 | multi_stft_hop_size: 147 40 | multi_stft_normalized: False 41 | 42 | training: 43 | batch_size: 1 44 | gradient_accumulation_steps: 8 45 | grad_clip: 0 46 | instruments: 47 | - vocals 48 | - other 49 | lr: 4.0e-05 50 | patience: 2 51 | reduce_factor: 0.95 52 | target_instrument: vocals 53 | num_epochs: 1000 54 | num_steps: 1000 55 | q: 0.95 56 | coarse_loss_clip: true 57 | ema_momentum: 0.999 58 | optimizer: adam 59 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 60 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 61 | 62 | inference: 63 | batch_size: 4 64 | dim_t: 801 65 | num_overlap: 2 -------------------------------------------------------------------------------- /mss/configs/config_vocals_scnet_unofficial.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 264600 3 | num_channels: 2 4 | sample_rate: 44100 5 | min_mean_abs: 0.000 6 | 7 | model: 8 | dims: [4, 32, 64, 128] 9 | bandsplit_ratios: [.175, .392, .433] 10 | downsample_strides: [1, 4, 16] 11 | n_conv_modules: [3, 2, 1] 12 | n_rnn_layers: 6 13 | rnn_hidden_dim: 128 14 | n_sources: 2 15 | 16 | n_fft: 4096 17 | hop_length: 1024 18 | win_length: 4096 19 | stft_normalized: false 20 | 21 | use_mamba: false 22 | d_state: 16 23 | d_conv: 4 24 | d_expand: 2 25 | 26 | training: 27 | batch_size: 10 28 | gradient_accumulation_steps: 2 29 | grad_clip: 0 30 | instruments: 31 | - vocals 32 | - other 33 | lr: 5.0e-04 34 | patience: 2 35 | reduce_factor: 0.95 36 | target_instrument: null 37 | num_epochs: 1000 38 | num_steps: 1000 39 | q: 0.95 40 | coarse_loss_clip: true 41 | ema_momentum: 0.999 42 | optimizer: adam 43 | other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental 44 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 45 | 46 | augmentations: 47 | enable: true # enable or disable all augmentations (to fast disable if needed) 48 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 49 | loudness_min: 0.5 50 | loudness_max: 1.5 51 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 52 | mixup_probs: 53 | !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 54 | - 0.2 55 | - 0.02 56 | mixup_loudness_min: 0.5 57 | mixup_loudness_max: 1.5 58 | 59 | inference: 60 | batch_size: 8 61 | dim_t: 256 62 | num_overlap: 4 63 | -------------------------------------------------------------------------------- /mss/configs/config_vocals_torchseg.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 261632 3 | dim_f: 4096 4 | dim_t: 512 5 | hop_length: 512 6 | n_fft: 8192 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.000 10 | 11 | model: 12 | encoder_name: maxvit_tiny_tf_512 # look with torchseg.list_encoders(). 
Currently 858 available 13 | decoder_type: unet # unet, fpn 14 | act: gelu 15 | num_channels: 128 16 | num_subbands: 8 17 | 18 | training: 19 | batch_size: 18 20 | gradient_accumulation_steps: 1 21 | grad_clip: 1.0 22 | instruments: 23 | - vocals 24 | - other 25 | lr: 1.0e-04 26 | patience: 2 27 | reduce_factor: 0.95 28 | target_instrument: null 29 | num_epochs: 1000 30 | num_steps: 1000 31 | q: 0.95 32 | coarse_loss_clip: true 33 | ema_momentum: 0.999 34 | optimizer: radam 35 | other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental 36 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 37 | 38 | augmentations: 39 | enable: false # enable or disable all augmentations (to fast disable if needed) 40 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 41 | loudness_min: 0.5 42 | loudness_max: 1.5 43 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 44 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 45 | - 0.2 46 | - 0.02 47 | mixup_loudness_min: 0.5 48 | mixup_loudness_max: 1.5 49 | 50 | all: 51 | channel_shuffle: 0.5 # Set 0 or lower to disable 52 | random_inverse: 0.1 # inverse track (better lower probability) 53 | random_polarity: 0.5 # polarity change (multiply waveform to -1) 54 | 55 | inference: 56 | batch_size: 8 57 | dim_t: 512 58 | num_overlap: 2 -------------------------------------------------------------------------------- /mss/configs/config_drumsep.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 485100 # samplerate * segment 3 | min_mean_abs: 0.000 4 | hop_length: 1024 5 | 6 | training: 7 | batch_size: 8 8 | gradient_accumulation_steps: 1 9 | grad_clip: 0 10 | segment: 11 11 | shift: 1 12 | samplerate: 44100 13 | channels: 2 14 | normalize: true 15 | instruments: ['kick', 'snare', 'cymbals', 'toms'] 16 | target_instrument: null 17 | num_epochs: 1000 18 | num_steps: 1000 19 | optimizer: adam 20 | lr: 9.0e-05 21 | patience: 2 22 | reduce_factor: 0.95 23 | q: 0.95 24 | coarse_loss_clip: true 25 | ema_momentum: 0.999 26 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 27 | use_amp: false # enable or disable usage of mixed precision (float16) - usually it must be true 28 | 29 | augmentations: 30 | enable: true # enable or disable all augmentations (to fast disable if needed) 31 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 32 | loudness_min: 0.5 33 | loudness_max: 1.5 34 | 35 | inference: 36 | num_overlap: 4 37 | batch_size: 8 38 | 39 | model: hdemucs 40 | 41 | hdemucs: # see demucs/hdemucs.py for a detailed description 42 | channels: 48 43 | channels_time: null 44 | growth: 2 45 | nfft: 4096 46 | wiener_iters: 0 47 | end_iters: 0 48 | wiener_residual: False 49 | cac: True 50 | depth: 6 51 | rewrite: True 52 | hybrid: True 53 | hybrid_old: False 54 | multi_freqs: [] 55 | multi_freqs_depth: 3 56 | freq_emb: 0.2 57 | emb_scale: 10 58 | emb_smooth: True 59 | kernel_size: 8 60 | stride: 4 61 | time_stride: 2 62 | context: 1 63 | context_enc: 0 64 | norm_starts: 4 65 | norm_groups: 4 66 | dconv_mode: 1 67 | dconv_depth: 2 68 | dconv_comp: 4 69 | dconv_attn: 4 70 | dconv_lstm: 4 71 | dconv_init: 0.001 72 | rescale: 0.1 73 | 
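`config_drumsep.yaml` above describes a 4-stem drum separator (kick/snare/cymbals/toms) built on the `hdemucs` architecture; its `chunk_size: 485100` is simply `samplerate * segment` = 44100 * 11. A hypothetical construction sketch (not part of the repo), assuming `demucs.hdemucs.HDemucs` accepts the keys of the `hdemucs:` section as keyword arguments, as the config comment pointing at `demucs/hdemucs.py` suggests:

```python
# Hypothetical sketch, not part of this repo: instantiate the drum-separation
# HDemucs described by configs/config_drumsep.yaml. Assumes HDemucs accepts the
# `hdemucs:` keys plus sources/audio_channels/samplerate as keyword arguments.
import yaml
import torch
from ml_collections import ConfigDict
from demucs.hdemucs import HDemucs

with open("configs/config_drumsep.yaml") as f:
    config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))

model = HDemucs(
    sources=config.training.instruments,      # ['kick', 'snare', 'cymbals', 'toms']
    audio_channels=config.training.channels,  # 2
    samplerate=config.training.samplerate,    # 44100
    **config.hdemucs,
)
mix = torch.randn(1, 2, config.audio.chunk_size)  # 44100 Hz * 11 s = 485100 samples
with torch.no_grad():
    stems = model(mix)  # expected: (batch, 4 stems, 2 channels, samples)
```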
-------------------------------------------------------------------------------- /mss/configs/config_musdb18_demucs3_mmi.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 485100 # samplerate * segment 3 | min_mean_abs: 0.000 4 | hop_length: 1024 5 | 6 | training: 7 | batch_size: 8 8 | gradient_accumulation_steps: 1 9 | grad_clip: 0 10 | segment: 11 11 | shift: 1 12 | samplerate: 44100 13 | channels: 2 14 | normalize: true 15 | instruments: ['drums', 'bass', 'other', 'vocals'] 16 | target_instrument: null 17 | num_epochs: 1000 18 | num_steps: 1000 19 | optimizer: adam 20 | lr: 9.0e-05 21 | patience: 2 22 | reduce_factor: 0.95 23 | q: 0.95 24 | coarse_loss_clip: true 25 | ema_momentum: 0.999 26 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 27 | use_amp: false # enable or disable usage of mixed precision (float16) - usually it must be true 28 | 29 | augmentations: 30 | enable: true # enable or disable all augmentations (to fast disable if needed) 31 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 32 | loudness_min: 0.5 33 | loudness_max: 1.5 34 | 35 | inference: 36 | num_overlap: 4 37 | batch_size: 8 38 | 39 | model: hdemucs 40 | 41 | hdemucs: # see demucs/hdemucs.py for a detailed description 42 | channels: 48 43 | channels_time: null 44 | growth: 2 45 | nfft: 4096 46 | wiener_iters: 0 47 | end_iters: 0 48 | wiener_residual: False 49 | cac: True 50 | depth: 6 51 | rewrite: True 52 | hybrid: True 53 | hybrid_old: False 54 | multi_freqs: [] 55 | multi_freqs_depth: 3 56 | freq_emb: 0.2 57 | emb_scale: 10 58 | emb_smooth: True 59 | kernel_size: 8 60 | stride: 4 61 | time_stride: 2 62 | context: 1 63 | context_enc: 0 64 | norm_starts: 4 65 | norm_groups: 4 66 | dconv_mode: 1 67 | dconv_depth: 2 68 | dconv_comp: 4 69 | dconv_attn: 4 70 | dconv_lstm: 4 71 | dconv_init: 0.001 72 | rescale: 0.1 73 | -------------------------------------------------------------------------------- /mss/configs/config_musdb18_mel_band_roformer.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 131584 3 | dim_f: 1024 4 | dim_t: 256 5 | hop_length: 512 6 | n_fft: 2048 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | dim: 192 13 | depth: 8 14 | stereo: true 15 | num_stems: 1 16 | time_transformer_depth: 1 17 | freq_transformer_depth: 1 18 | linear_transformer_depth: 0 19 | num_bands: 60 20 | dim_head: 64 21 | heads: 8 22 | attn_dropout: 0.1 23 | ff_dropout: 0.1 24 | flash_attn: True 25 | dim_freqs_in: 1025 26 | sample_rate: 44100 # needed for mel filter bank from librosa 27 | stft_n_fft: 2048 28 | stft_hop_length: 512 29 | stft_win_length: 2048 30 | stft_normalized: False 31 | mask_estimator_depth: 2 32 | multi_stft_resolution_loss_weight: 1.0 33 | multi_stft_resolutions_window_sizes: !!python/tuple 34 | - 4096 35 | - 2048 36 | - 1024 37 | - 512 38 | - 256 39 | multi_stft_hop_size: 147 40 | multi_stft_normalized: False 41 | 42 | training: 43 | batch_size: 7 44 | gradient_accumulation_steps: 1 45 | grad_clip: 0 46 | instruments: 47 | - vocals 48 | - bass 49 | - drums 50 | - other 51 | lr: 5.0e-05 52 | patience: 2 53 | reduce_factor: 0.95 54 | target_instrument: vocals 55 | num_epochs: 1000 56 | num_steps: 1000 57 | q: 0.95 58 | coarse_loss_clip: true 59 | ema_momentum: 0.999 60 | optimizer: adam 61 | other_fix: false # it's needed for checking on multisong 
dataset if other is actually instrumental 62 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 63 | 64 | augmentations: 65 | enable: true # enable or disable all augmentations (to fast disable if needed) 66 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 67 | loudness_min: 0.5 68 | loudness_max: 1.5 69 | 70 | inference: 71 | batch_size: 1 72 | dim_t: 256 73 | num_overlap: 4 74 | -------------------------------------------------------------------------------- /mss/configs/config_vocals_scnet.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 485100 # 44100 * 11 3 | num_channels: 2 4 | sample_rate: 44100 5 | min_mean_abs: 0.000 6 | 7 | model: 8 | sources: 9 | - vocals 10 | - other 11 | audio_channels: 2 12 | dims: 13 | - 4 14 | - 32 15 | - 64 16 | - 128 17 | nfft: 4096 18 | hop_size: 1024 19 | win_size: 4096 20 | normalized: True 21 | band_SR: 22 | - 0.175 23 | - 0.392 24 | - 0.433 25 | band_stride: 26 | - 1 27 | - 4 28 | - 16 29 | band_kernel: 30 | - 3 31 | - 4 32 | - 16 33 | conv_depths: 34 | - 3 35 | - 2 36 | - 1 37 | compress: 4 38 | conv_kernel: 3 39 | num_dplayer: 6 40 | expand: 1 41 | 42 | training: 43 | batch_size: 10 44 | gradient_accumulation_steps: 1 45 | grad_clip: 0 46 | instruments: 47 | - vocals 48 | - other 49 | lr: 5.0e-04 50 | patience: 2 51 | reduce_factor: 0.95 52 | target_instrument: null 53 | num_epochs: 1000 54 | num_steps: 1000 55 | q: 0.95 56 | coarse_loss_clip: true 57 | ema_momentum: 0.999 58 | optimizer: adam 59 | other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental 60 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 61 | 62 | augmentations: 63 | enable: true # enable or disable all augmentations (to fast disable if needed) 64 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 65 | loudness_min: 0.5 66 | loudness_max: 1.5 67 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 68 | mixup_probs: 69 | !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 70 | - 0.2 71 | - 0.02 72 | mixup_loudness_min: 0.5 73 | mixup_loudness_max: 1.5 74 | 75 | inference: 76 | batch_size: 8 77 | dim_t: 256 78 | num_overlap: 4 79 | normalize: false 80 | -------------------------------------------------------------------------------- /mss/models/bandit/core/model/_spectral.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | import torch 4 | import torchaudio as ta 5 | from torch import nn 6 | 7 | 8 | class _SpectralComponent(nn.Module): 9 | def __init__( 10 | self, 11 | n_fft: int = 2048, 12 | win_length: Optional[int] = 2048, 13 | hop_length: int = 512, 14 | window_fn: str = "hann_window", 15 | wkwargs: Optional[Dict] = None, 16 | power: Optional[int] = None, 17 | center: bool = True, 18 | normalized: bool = True, 19 | pad_mode: str = "constant", 20 | onesided: bool = True, 21 | **kwargs, 22 | ) -> None: 23 | super().__init__() 24 | 25 | assert power is None 26 | 27 | window_fn = torch.__dict__[window_fn] 28 | 29 | self.stft = ( 30 | ta.transforms.Spectrogram( 31 | n_fft=n_fft, 32 | win_length=win_length, 33 | hop_length=hop_length, 34 | pad_mode=pad_mode, 35 | pad=0, 36 | window_fn=window_fn, 37 | wkwargs=wkwargs, 38 
| power=power, 39 | normalized=normalized, 40 | center=center, 41 | onesided=onesided, 42 | ) 43 | ) 44 | 45 | self.istft = ( 46 | ta.transforms.InverseSpectrogram( 47 | n_fft=n_fft, 48 | win_length=win_length, 49 | hop_length=hop_length, 50 | pad_mode=pad_mode, 51 | pad=0, 52 | window_fn=window_fn, 53 | wkwargs=wkwargs, 54 | normalized=normalized, 55 | center=center, 56 | onesided=onesided, 57 | ) 58 | ) 59 | -------------------------------------------------------------------------------- /mss/configs/config_musdb18_scnet.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 485100 # 44100 * 11 3 | num_channels: 2 4 | sample_rate: 44100 5 | min_mean_abs: 0.000 6 | 7 | model: 8 | sources: 9 | - drums 10 | - bass 11 | - other 12 | - vocals 13 | audio_channels: 2 14 | dims: 15 | - 4 16 | - 32 17 | - 64 18 | - 128 19 | nfft: 4096 20 | hop_size: 1024 21 | win_size: 4096 22 | normalized: True 23 | band_SR: 24 | - 0.175 25 | - 0.392 26 | - 0.433 27 | band_stride: 28 | - 1 29 | - 4 30 | - 16 31 | band_kernel: 32 | - 3 33 | - 4 34 | - 16 35 | conv_depths: 36 | - 3 37 | - 2 38 | - 1 39 | compress: 4 40 | conv_kernel: 3 41 | num_dplayer: 6 42 | expand: 1 43 | 44 | training: 45 | batch_size: 10 46 | gradient_accumulation_steps: 1 47 | grad_clip: 0 48 | instruments: 49 | - drums 50 | - bass 51 | - other 52 | - vocals 53 | lr: 5.0e-04 54 | patience: 2 55 | reduce_factor: 0.95 56 | target_instrument: null 57 | num_epochs: 1000 58 | num_steps: 1000 59 | q: 0.95 60 | coarse_loss_clip: true 61 | ema_momentum: 0.999 62 | optimizer: adam 63 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 64 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 65 | 66 | augmentations: 67 | enable: true # enable or disable all augmentations (to fast disable if needed) 68 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 69 | loudness_min: 0.5 70 | loudness_max: 1.5 71 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 72 | mixup_probs: 73 | !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 74 | - 0.2 75 | - 0.02 76 | mixup_loudness_min: 0.5 77 | mixup_loudness_max: 1.5 78 | 79 | inference: 80 | batch_size: 8 81 | dim_t: 256 82 | num_overlap: 4 83 | normalize: true 84 | -------------------------------------------------------------------------------- /mss/configs/config_vocals_segm_models.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 261632 3 | dim_f: 4096 4 | dim_t: 512 5 | hop_length: 512 6 | n_fft: 8192 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | encoder_name: tu-maxvit_large_tf_512 # look here for possibilities: https://github.com/qubvel/segmentation_models.pytorch#encoders- 13 | decoder_type: unet # unet, fpn 14 | act: gelu 15 | num_channels: 128 16 | num_subbands: 8 17 | 18 | loss_multistft: 19 | fft_sizes: 20 | - 1024 21 | - 2048 22 | - 4096 23 | hop_sizes: 24 | - 512 25 | - 1024 26 | - 2048 27 | win_lengths: 28 | - 1024 29 | - 2048 30 | - 4096 31 | window: "hann_window" 32 | scale: "mel" 33 | n_bins: 128 34 | sample_rate: 44100 35 | perceptual_weighting: true 36 | w_sc: 1.0 37 | w_log_mag: 1.0 38 | w_lin_mag: 0.0 39 | w_phs: 0.0 40 | mag_distance: "L1" 41 | 42 | 43 | training: 44 | batch_size: 8 45 | 
gradient_accumulation_steps: 1 46 | grad_clip: 0 47 | instruments: 48 | - vocals 49 | - other 50 | lr: 5.0e-05 51 | patience: 2 52 | reduce_factor: 0.95 53 | target_instrument: null 54 | num_epochs: 1000 55 | num_steps: 2000 56 | q: 0.95 57 | coarse_loss_clip: true 58 | ema_momentum: 0.999 59 | optimizer: adamw 60 | other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental 61 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 62 | 63 | augmentations: 64 | enable: true # enable or disable all augmentations (to fast disable if needed) 65 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 66 | loudness_min: 0.5 67 | loudness_max: 1.5 68 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 69 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 70 | - 0.2 71 | - 0.02 72 | mixup_loudness_min: 0.5 73 | mixup_loudness_max: 1.5 74 | 75 | inference: 76 | batch_size: 1 77 | dim_t: 512 78 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/configs/config_vocals_bandit_bsrnn_multi_mus64.yaml: -------------------------------------------------------------------------------- 1 | name: "MultiMaskMultiSourceBandSplitRNN" 2 | audio: 3 | chunk_size: 264600 4 | num_channels: 2 5 | sample_rate: 44100 6 | min_mean_abs: 0.001 7 | 8 | model: 9 | in_channel: 1 10 | stems: ['vocals', 'other'] 11 | band_specs: "musical" 12 | n_bands: 64 13 | fs: 44100 14 | require_no_overlap: false 15 | require_no_gap: true 16 | normalize_channel_independently: false 17 | treat_channel_as_feature: true 18 | n_sqm_modules: 8 19 | emb_dim: 128 20 | rnn_dim: 256 21 | bidirectional: true 22 | rnn_type: "GRU" 23 | mlp_dim: 512 24 | hidden_activation: "Tanh" 25 | hidden_activation_kwargs: null 26 | complex_mask: true 27 | n_fft: 2048 28 | win_length: 2048 29 | hop_length: 512 30 | window_fn: "hann_window" 31 | wkwargs: null 32 | power: null 33 | center: true 34 | normalized: true 35 | pad_mode: "constant" 36 | onesided: true 37 | 38 | training: 39 | batch_size: 4 40 | gradient_accumulation_steps: 4 41 | grad_clip: 0 42 | instruments: 43 | - vocals 44 | - other 45 | lr: 9.0e-05 46 | patience: 2 47 | reduce_factor: 0.95 48 | target_instrument: null 49 | num_epochs: 1000 50 | num_steps: 1000 51 | q: 0.95 52 | coarse_loss_clip: true 53 | ema_momentum: 0.999 54 | optimizer: adam 55 | other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental 56 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 57 | 58 | augmentations: 59 | enable: true # enable or disable all augmentations (to fast disable if needed) 60 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 61 | loudness_min: 0.5 62 | loudness_max: 1.5 63 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 64 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 65 | - 0.2 66 | - 0.02 67 | mixup_loudness_min: 0.5 68 | mixup_loudness_max: 1.5 69 | 70 | inference: 71 | batch_size: 1 72 | dim_t: 256 73 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/models/bandit/core/data/base.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | from typing import Any, Dict, List, Optional 4 | 5 | import numpy as np 6 | import pedalboard as pb 7 | import torch 8 | import torchaudio as ta 9 | from torch.utils import data 10 | 11 | from models.bandit.core.data._types import AudioDict, DataDict 12 | 13 | 14 | class BaseSourceSeparationDataset(data.Dataset, ABC): 15 | def __init__( 16 | self, split: str, 17 | stems: List[str], 18 | files: List[str], 19 | data_path: str, 20 | fs: int, 21 | npy_memmap: bool, 22 | recompute_mixture: bool 23 | ): 24 | self.split = split 25 | self.stems = stems 26 | self.stems_no_mixture = [s for s in stems if s != "mixture"] 27 | self.files = files 28 | self.data_path = data_path 29 | self.fs = fs 30 | self.npy_memmap = npy_memmap 31 | self.recompute_mixture = recompute_mixture 32 | 33 | @abstractmethod 34 | def get_stem( 35 | self, 36 | *, 37 | stem: str, 38 | identifier: Dict[str, Any] 39 | ) -> torch.Tensor: 40 | raise NotImplementedError 41 | 42 | def _get_audio(self, stems, identifier: Dict[str, Any]): 43 | audio = {} 44 | for stem in stems: 45 | audio[stem] = self.get_stem(stem=stem, identifier=identifier) 46 | 47 | return audio 48 | 49 | def get_audio(self, identifier: Dict[str, Any]) -> AudioDict: 50 | 51 | if self.recompute_mixture: 52 | audio = self._get_audio( 53 | self.stems_no_mixture, 54 | identifier=identifier 55 | ) 56 | audio["mixture"] = self.compute_mixture(audio) 57 | return audio 58 | else: 59 | return self._get_audio(self.stems, identifier=identifier) 60 | 61 | @abstractmethod 62 | def get_identifier(self, index: int) -> Dict[str, Any]: 63 | pass 64 | 65 | def compute_mixture(self, audio: AudioDict) -> torch.Tensor: 66 | 67 | return sum( 68 | audio[stem] for stem in audio if stem != "mixture" 69 | ) 70 | -------------------------------------------------------------------------------- /mss/configs/config_vocals_mel_band_roformer.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 131584 3 | dim_f: 1024 4 | dim_t: 256 5 | hop_length: 512 6 | n_fft: 2048 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | dim: 192 13 | depth: 8 14 | stereo: true 15 | num_stems: 1 16 | time_transformer_depth: 1 17 | freq_transformer_depth: 1 18 | linear_transformer_depth: 0 19 | num_bands: 60 20 | dim_head: 64 21 | heads: 8 22 | attn_dropout: 0.1 23 | ff_dropout: 0.1 24 | flash_attn: True 25 | dim_freqs_in: 1025 26 | sample_rate: 44100 # needed for mel filter bank from librosa 27 | stft_n_fft: 2048 28 | stft_hop_length: 512 29 | stft_win_length: 2048 30 | stft_normalized: False 31 | mask_estimator_depth: 2 32 | multi_stft_resolution_loss_weight: 1.0 33 | multi_stft_resolutions_window_sizes: !!python/tuple 34 | - 4096 35 | - 2048 36 | - 1024 37 | - 512 38 | - 256 39 | multi_stft_hop_size: 147 40 | multi_stft_normalized: False 41 | 42 | training: 43 | batch_size: 7 44 | gradient_accumulation_steps: 1 45 | grad_clip: 0 46 | instruments: 47 | - vocals 48 | - other 49 | lr: 5.0e-05 50 | patience: 2 51 | reduce_factor: 0.95 52 | target_instrument: vocals 53 | num_epochs: 1000 54 | num_steps: 1000 55 | q: 0.95 56 | coarse_loss_clip: true 57 | ema_momentum: 0.999 58 | optimizer: adam 59 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 60 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be 
true 61 | 62 | augmentations: 63 | enable: true # enable or disable all augmentations (to fast disable if needed) 64 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 65 | loudness_min: 0.5 66 | loudness_max: 1.5 67 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 68 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 69 | - 0.2 70 | - 0.02 71 | mixup_loudness_min: 0.5 72 | mixup_loudness_max: 1.5 73 | 74 | inference: 75 | batch_size: 1 76 | dim_t: 256 77 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/configs/config_dnr_bandit_bsrnn_multi_mus64.yaml: -------------------------------------------------------------------------------- 1 | name: "MultiMaskMultiSourceBandSplitRNN" 2 | audio: 3 | chunk_size: 264600 4 | num_channels: 2 5 | sample_rate: 44100 6 | min_mean_abs: 0.001 7 | 8 | model: 9 | in_channel: 1 10 | stems: ['speech', 'music', 'effects'] 11 | band_specs: "musical" 12 | n_bands: 64 13 | fs: 44100 14 | require_no_overlap: false 15 | require_no_gap: true 16 | normalize_channel_independently: false 17 | treat_channel_as_feature: true 18 | n_sqm_modules: 8 19 | emb_dim: 128 20 | rnn_dim: 256 21 | bidirectional: true 22 | rnn_type: "GRU" 23 | mlp_dim: 512 24 | hidden_activation: "Tanh" 25 | hidden_activation_kwargs: null 26 | complex_mask: true 27 | n_fft: 2048 28 | win_length: 2048 29 | hop_length: 512 30 | window_fn: "hann_window" 31 | wkwargs: null 32 | power: null 33 | center: true 34 | normalized: true 35 | pad_mode: "constant" 36 | onesided: true 37 | 38 | training: 39 | batch_size: 4 40 | gradient_accumulation_steps: 4 41 | grad_clip: 0 42 | instruments: 43 | - speech 44 | - music 45 | - effects 46 | lr: 9.0e-05 47 | patience: 2 48 | reduce_factor: 0.95 49 | target_instrument: null 50 | num_epochs: 1000 51 | num_steps: 1000 52 | q: 0.95 53 | coarse_loss_clip: true 54 | ema_momentum: 0.999 55 | optimizer: adam 56 | other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental 57 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 58 | 59 | augmentations: 60 | enable: true # enable or disable all augmentations (to fast disable if needed) 61 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 62 | loudness_min: 0.5 63 | loudness_max: 1.5 64 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 65 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 66 | - 0.2 67 | - 0.02 68 | mixup_loudness_min: 0.5 69 | mixup_loudness_max: 1.5 70 | all: 71 | channel_shuffle: 0.5 # Set 0 or lower to disable 72 | random_inverse: 0.1 # inverse track (better lower probability) 73 | random_polarity: 0.5 # polarity change (multiply waveform to -1) 74 | 75 | inference: 76 | batch_size: 1 77 | dim_t: 256 78 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/models/bandit/core/data/dnr/datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Mapping, Optional 3 | 4 | import pytorch_lightning as pl 5 | 6 | from .dataset import ( 7 | DivideAndRemasterDataset, 8 | DivideAndRemasterDeterministicChunkDataset, 9 | 
DivideAndRemasterRandomChunkDataset, 10 | DivideAndRemasterRandomChunkDatasetWithSpeechReverb 11 | ) 12 | 13 | 14 | def DivideAndRemasterDataModule( 15 | data_root: str = "$DATA_ROOT/DnR/v2", 16 | batch_size: int = 2, 17 | num_workers: int = 8, 18 | train_kwargs: Optional[Mapping] = None, 19 | val_kwargs: Optional[Mapping] = None, 20 | test_kwargs: Optional[Mapping] = None, 21 | datamodule_kwargs: Optional[Mapping] = None, 22 | use_speech_reverb: bool = False 23 | # augmentor=None 24 | ) -> pl.LightningDataModule: 25 | if train_kwargs is None: 26 | train_kwargs = {} 27 | 28 | if val_kwargs is None: 29 | val_kwargs = {} 30 | 31 | if test_kwargs is None: 32 | test_kwargs = {} 33 | 34 | if datamodule_kwargs is None: 35 | datamodule_kwargs = {} 36 | 37 | if num_workers is None: 38 | num_workers = os.cpu_count() 39 | 40 | if num_workers is None: 41 | num_workers = 32 42 | 43 | num_workers = min(num_workers, 64) 44 | 45 | if use_speech_reverb: 46 | train_cls = DivideAndRemasterRandomChunkDatasetWithSpeechReverb 47 | else: 48 | train_cls = DivideAndRemasterRandomChunkDataset 49 | 50 | train_dataset = train_cls( 51 | data_root, "train", **train_kwargs 52 | ) 53 | 54 | # if augmentor is not None: 55 | # train_dataset = AugmentedDataset(train_dataset, augmentor) 56 | 57 | datamodule = pl.LightningDataModule.from_datasets( 58 | train_dataset=train_dataset, 59 | val_dataset=DivideAndRemasterDeterministicChunkDataset( 60 | data_root, "val", **val_kwargs 61 | ), 62 | test_dataset=DivideAndRemasterDataset( 63 | data_root, 64 | "test", 65 | **test_kwargs 66 | ), 67 | batch_size=batch_size, 68 | num_workers=num_workers, 69 | **datamodule_kwargs 70 | ) 71 | 72 | datamodule.predict_dataloader = datamodule.test_dataloader # type: ignore[method-assign] 73 | 74 | return datamodule 75 | -------------------------------------------------------------------------------- /mss/models/bandit/core/data/musdb/datamodule.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Mapping, Optional 3 | 4 | import pytorch_lightning as pl 5 | 6 | from models.bandit.core.data.musdb.dataset import ( 7 | MUSDB18BaseDataset, 8 | MUSDB18FullTrackDataset, 9 | MUSDB18SadDataset, 10 | MUSDB18SadOnTheFlyAugmentedDataset 11 | ) 12 | 13 | 14 | def MUSDB18DataModule( 15 | data_root: str = "$DATA_ROOT/MUSDB18/HQ", 16 | target_stem: str = "vocals", 17 | batch_size: int = 2, 18 | num_workers: int = 8, 19 | train_kwargs: Optional[Mapping] = None, 20 | val_kwargs: Optional[Mapping] = None, 21 | test_kwargs: Optional[Mapping] = None, 22 | datamodule_kwargs: Optional[Mapping] = None, 23 | use_on_the_fly: bool = True, 24 | npy_memmap: bool = True 25 | ) -> pl.LightningDataModule: 26 | if train_kwargs is None: 27 | train_kwargs = {} 28 | 29 | if val_kwargs is None: 30 | val_kwargs = {} 31 | 32 | if test_kwargs is None: 33 | test_kwargs = {} 34 | 35 | if datamodule_kwargs is None: 36 | datamodule_kwargs = {} 37 | 38 | train_dataset: MUSDB18BaseDataset 39 | 40 | if use_on_the_fly: 41 | train_dataset = MUSDB18SadOnTheFlyAugmentedDataset( 42 | data_root=os.path.join(data_root, "saded-np"), 43 | split="train", 44 | target_stem=target_stem, 45 | **train_kwargs 46 | ) 47 | else: 48 | train_dataset = MUSDB18SadDataset( 49 | data_root=os.path.join(data_root, "saded-np"), 50 | split="train", 51 | target_stem=target_stem, 52 | **train_kwargs 53 | ) 54 | 55 | datamodule = pl.LightningDataModule.from_datasets( 56 | train_dataset=train_dataset, 57 | val_dataset=MUSDB18SadDataset( 58 | 
data_root=os.path.join(data_root, "saded-np"), 59 | split="val", 60 | target_stem=target_stem, 61 | **val_kwargs 62 | ), 63 | test_dataset=MUSDB18FullTrackDataset( 64 | data_root=os.path.join(data_root, "canonical"), 65 | split="test", 66 | **test_kwargs 67 | ), 68 | batch_size=batch_size, 69 | num_workers=num_workers, 70 | **datamodule_kwargs 71 | ) 72 | 73 | datamodule.predict_dataloader = ( # type: ignore[method-assign] 74 | datamodule.test_dataloader 75 | ) 76 | 77 | return datamodule 78 | -------------------------------------------------------------------------------- /mss/configs/viperx/model_bs_roformer_ep_317_sdr_12.9755.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 352800 3 | dim_f: 1024 4 | dim_t: 801 # don't work (use in model) 5 | hop_length: 441 # don't work (use in model) 6 | n_fft: 2048 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.000 10 | 11 | model: 12 | dim: 512 13 | depth: 12 14 | stereo: true 15 | num_stems: 1 16 | time_transformer_depth: 1 17 | freq_transformer_depth: 1 18 | linear_transformer_depth: 0 19 | freqs_per_bands: !!python/tuple 20 | - 2 21 | - 2 22 | - 2 23 | - 2 24 | - 2 25 | - 2 26 | - 2 27 | - 2 28 | - 2 29 | - 2 30 | - 2 31 | - 2 32 | - 2 33 | - 2 34 | - 2 35 | - 2 36 | - 2 37 | - 2 38 | - 2 39 | - 2 40 | - 2 41 | - 2 42 | - 2 43 | - 2 44 | - 4 45 | - 4 46 | - 4 47 | - 4 48 | - 4 49 | - 4 50 | - 4 51 | - 4 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | - 12 57 | - 12 58 | - 12 59 | - 12 60 | - 12 61 | - 12 62 | - 12 63 | - 12 64 | - 24 65 | - 24 66 | - 24 67 | - 24 68 | - 24 69 | - 24 70 | - 24 71 | - 24 72 | - 48 73 | - 48 74 | - 48 75 | - 48 76 | - 48 77 | - 48 78 | - 48 79 | - 48 80 | - 128 81 | - 129 82 | dim_head: 64 83 | heads: 8 84 | attn_dropout: 0.1 85 | ff_dropout: 0.1 86 | flash_attn: true 87 | dim_freqs_in: 1025 88 | stft_n_fft: 2048 89 | stft_hop_length: 441 90 | stft_win_length: 2048 91 | stft_normalized: false 92 | mask_estimator_depth: 2 93 | multi_stft_resolution_loss_weight: 1.0 94 | multi_stft_resolutions_window_sizes: !!python/tuple 95 | - 4096 96 | - 2048 97 | - 1024 98 | - 512 99 | - 256 100 | multi_stft_hop_size: 147 101 | multi_stft_normalized: False 102 | 103 | training: 104 | batch_size: 2 105 | gradient_accumulation_steps: 1 106 | grad_clip: 0 107 | instruments: 108 | - vocals 109 | - other 110 | lr: 1.0e-05 111 | patience: 2 112 | reduce_factor: 0.95 113 | target_instrument: vocals 114 | num_epochs: 1000 115 | num_steps: 1000 116 | q: 0.95 117 | coarse_loss_clip: true 118 | ema_momentum: 0.999 119 | optimizer: adam 120 | other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental 121 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 122 | 123 | inference: 124 | batch_size: 4 125 | dim_t: 801 126 | num_overlap: 2 -------------------------------------------------------------------------------- /mss/configs/config_musdb18_bs_roformer.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 131584 3 | dim_f: 1024 4 | dim_t: 256 5 | hop_length: 512 6 | n_fft: 2048 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | dim: 192 13 | depth: 6 14 | stereo: true 15 | num_stems: 1 16 | time_transformer_depth: 1 17 | freq_transformer_depth: 1 18 | linear_transformer_depth: 0 19 | freqs_per_bands: !!python/tuple 20 | - 2 21 | - 2 22 | - 2 23 | - 2 24 | - 2 25 | - 2 26 | - 2 27 | - 2 28 | 
- 2 29 | - 2 30 | - 2 31 | - 2 32 | - 2 33 | - 2 34 | - 2 35 | - 2 36 | - 2 37 | - 2 38 | - 2 39 | - 2 40 | - 2 41 | - 2 42 | - 2 43 | - 2 44 | - 4 45 | - 4 46 | - 4 47 | - 4 48 | - 4 49 | - 4 50 | - 4 51 | - 4 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | - 12 57 | - 12 58 | - 12 59 | - 12 60 | - 12 61 | - 12 62 | - 12 63 | - 12 64 | - 24 65 | - 24 66 | - 24 67 | - 24 68 | - 24 69 | - 24 70 | - 24 71 | - 24 72 | - 48 73 | - 48 74 | - 48 75 | - 48 76 | - 48 77 | - 48 78 | - 48 79 | - 48 80 | - 128 81 | - 129 82 | dim_head: 64 83 | heads: 8 84 | attn_dropout: 0.1 85 | ff_dropout: 0.1 86 | flash_attn: true 87 | dim_freqs_in: 1025 88 | stft_n_fft: 2048 89 | stft_hop_length: 512 90 | stft_win_length: 2048 91 | stft_normalized: false 92 | mask_estimator_depth: 2 93 | multi_stft_resolution_loss_weight: 1.0 94 | multi_stft_resolutions_window_sizes: !!python/tuple 95 | - 4096 96 | - 2048 97 | - 1024 98 | - 512 99 | - 256 100 | multi_stft_hop_size: 147 101 | multi_stft_normalized: False 102 | 103 | training: 104 | batch_size: 10 105 | gradient_accumulation_steps: 1 106 | grad_clip: 0 107 | instruments: 108 | - vocals 109 | - bass 110 | - drums 111 | - other 112 | lr: 5.0e-05 113 | patience: 2 114 | reduce_factor: 0.95 115 | target_instrument: vocals 116 | num_epochs: 1000 117 | num_steps: 1000 118 | q: 0.95 119 | coarse_loss_clip: true 120 | ema_momentum: 0.999 121 | optimizer: adam 122 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 123 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 124 | 125 | augmentations: 126 | enable: true # enable or disable all augmentations (to fast disable if needed) 127 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 128 | loudness_min: 0.5 129 | loudness_max: 1.5 130 | 131 | inference: 132 | batch_size: 1 133 | dim_t: 256 134 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/configs/config_musdb18_torchseg.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 261632 3 | dim_f: 4096 4 | dim_t: 512 5 | hop_length: 512 6 | n_fft: 8192 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | encoder_name: maxvit_tiny_tf_512 # look with torchseg.list_encoders(). 
Currently 858 available 13 | decoder_type: unet # unet, fpn 14 | act: gelu 15 | num_channels: 128 16 | num_subbands: 8 17 | 18 | training: 19 | batch_size: 18 20 | gradient_accumulation_steps: 1 21 | grad_clip: 0 22 | instruments: 23 | - vocals 24 | - bass 25 | - drums 26 | - other 27 | lr: 5.0e-05 28 | patience: 2 29 | reduce_factor: 0.95 30 | target_instrument: null 31 | num_epochs: 1000 32 | num_steps: 2000 33 | q: 0.95 34 | coarse_loss_clip: true 35 | ema_momentum: 0.999 36 | optimizer: adamw 37 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 38 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 39 | 40 | augmentations: 41 | enable: true # enable or disable all augmentations (to fast disable if needed) 42 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 43 | loudness_min: 0.5 44 | loudness_max: 1.5 45 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 46 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 47 | - 0.2 48 | - 0.02 49 | mixup_loudness_min: 0.5 50 | mixup_loudness_max: 1.5 51 | 52 | # apply mp3 compression to mixture only (emulate downloading mp3 from internet) 53 | mp3_compression_on_mixture: 0.01 54 | mp3_compression_on_mixture_bitrate_min: 32 55 | mp3_compression_on_mixture_bitrate_max: 320 56 | mp3_compression_on_mixture_backend: "lameenc" 57 | 58 | all: 59 | channel_shuffle: 0.5 # Set 0 or lower to disable 60 | random_inverse: 0.1 # inverse track (better lower probability) 61 | random_polarity: 0.5 # polarity change (multiply waveform to -1) 62 | mp3_compression: 0.01 63 | mp3_compression_min_bitrate: 32 64 | mp3_compression_max_bitrate: 320 65 | mp3_compression_backend: "lameenc" 66 | 67 | vocals: 68 | pitch_shift: 0.1 69 | pitch_shift_min_semitones: -5 70 | pitch_shift_max_semitones: 5 71 | seven_band_parametric_eq: 0.25 72 | seven_band_parametric_eq_min_gain_db: -9 73 | seven_band_parametric_eq_max_gain_db: 9 74 | tanh_distortion: 0.1 75 | tanh_distortion_min: 0.1 76 | tanh_distortion_max: 0.7 77 | other: 78 | pitch_shift: 0.1 79 | pitch_shift_min_semitones: -4 80 | pitch_shift_max_semitones: 4 81 | gaussian_noise: 0.1 82 | gaussian_noise_min_amplitude: 0.001 83 | gaussian_noise_max_amplitude: 0.015 84 | time_stretch: 0.01 85 | time_stretch_min_rate: 0.8 86 | time_stretch_max_rate: 1.25 87 | 88 | 89 | inference: 90 | batch_size: 1 91 | dim_t: 512 92 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/configs/config_musdb18_segm_models.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 261632 3 | dim_f: 4096 4 | dim_t: 512 5 | hop_length: 512 6 | n_fft: 8192 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | encoder_name: tu-maxvit_large_tf_512 # look here for possibilities: https://github.com/qubvel/segmentation_models.pytorch#encoders- 13 | decoder_type: unet # unet, fpn 14 | act: gelu 15 | num_channels: 128 16 | num_subbands: 8 17 | 18 | training: 19 | batch_size: 7 20 | gradient_accumulation_steps: 1 21 | grad_clip: 0 22 | instruments: 23 | - vocals 24 | - bass 25 | - drums 26 | - other 27 | lr: 5.0e-05 28 | patience: 2 29 | reduce_factor: 0.95 30 | target_instrument: null 31 | num_epochs: 1000 32 | num_steps: 2000 33 | q: 0.95 34 | coarse_loss_clip: true 
35 | ema_momentum: 0.999 36 | optimizer: adamw 37 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 38 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 39 | 40 | augmentations: 41 | enable: true # enable or disable all augmentations (to fast disable if needed) 42 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 43 | loudness_min: 0.5 44 | loudness_max: 1.5 45 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 46 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 47 | - 0.2 48 | - 0.02 49 | mixup_loudness_min: 0.5 50 | mixup_loudness_max: 1.5 51 | 52 | # apply mp3 compression to mixture only (emulate downloading mp3 from internet) 53 | mp3_compression_on_mixture: 0.01 54 | mp3_compression_on_mixture_bitrate_min: 32 55 | mp3_compression_on_mixture_bitrate_max: 320 56 | mp3_compression_on_mixture_backend: "lameenc" 57 | 58 | all: 59 | channel_shuffle: 0.5 # Set 0 or lower to disable 60 | random_inverse: 0.1 # inverse track (better lower probability) 61 | random_polarity: 0.5 # polarity change (multiply waveform to -1) 62 | mp3_compression: 0.01 63 | mp3_compression_min_bitrate: 32 64 | mp3_compression_max_bitrate: 320 65 | mp3_compression_backend: "lameenc" 66 | 67 | vocals: 68 | pitch_shift: 0.1 69 | pitch_shift_min_semitones: -5 70 | pitch_shift_max_semitones: 5 71 | seven_band_parametric_eq: 0.25 72 | seven_band_parametric_eq_min_gain_db: -9 73 | seven_band_parametric_eq_max_gain_db: 9 74 | tanh_distortion: 0.1 75 | tanh_distortion_min: 0.1 76 | tanh_distortion_max: 0.7 77 | other: 78 | pitch_shift: 0.1 79 | pitch_shift_min_semitones: -4 80 | pitch_shift_max_semitones: 4 81 | gaussian_noise: 0.1 82 | gaussian_noise_min_amplitude: 0.001 83 | gaussian_noise_max_amplitude: 0.015 84 | time_stretch: 0.01 85 | time_stretch_min_rate: 0.8 86 | time_stretch_max_rate: 1.25 87 | 88 | 89 | inference: 90 | batch_size: 1 91 | dim_t: 512 92 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/configs/config_vocals_bs_roformer.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 131584 3 | dim_f: 1024 4 | dim_t: 256 5 | hop_length: 512 6 | n_fft: 2048 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | dim: 192 13 | depth: 6 14 | stereo: true 15 | num_stems: 1 16 | time_transformer_depth: 1 17 | freq_transformer_depth: 1 18 | linear_transformer_depth: 0 19 | freqs_per_bands: !!python/tuple 20 | - 2 21 | - 2 22 | - 2 23 | - 2 24 | - 2 25 | - 2 26 | - 2 27 | - 2 28 | - 2 29 | - 2 30 | - 2 31 | - 2 32 | - 2 33 | - 2 34 | - 2 35 | - 2 36 | - 2 37 | - 2 38 | - 2 39 | - 2 40 | - 2 41 | - 2 42 | - 2 43 | - 2 44 | - 4 45 | - 4 46 | - 4 47 | - 4 48 | - 4 49 | - 4 50 | - 4 51 | - 4 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | - 12 57 | - 12 58 | - 12 59 | - 12 60 | - 12 61 | - 12 62 | - 12 63 | - 12 64 | - 24 65 | - 24 66 | - 24 67 | - 24 68 | - 24 69 | - 24 70 | - 24 71 | - 24 72 | - 48 73 | - 48 74 | - 48 75 | - 48 76 | - 48 77 | - 48 78 | - 48 79 | - 48 80 | - 128 81 | - 129 82 | dim_head: 64 83 | heads: 8 84 | attn_dropout: 0.1 85 | ff_dropout: 0.1 86 | flash_attn: true 87 | dim_freqs_in: 1025 88 | stft_n_fft: 2048 89 | stft_hop_length: 512 90 | stft_win_length: 2048 91 | stft_normalized: false 92 | 
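As a quick sanity check on the `freqs_per_bands` spec above: the band widths are meant to tile the full STFT, so they should sum to `dim_freqs_in = stft_n_fft // 2 + 1 = 1025`. Written compactly:

```
# 24 bands of width 2, 12 of width 4, 8 of width 12, 8 of width 24, 8 of width 48, then 128 and 129.
freqs_per_bands = (2,) * 24 + (4,) * 12 + (12,) * 8 + (24,) * 8 + (48,) * 8 + (128, 129)
assert len(freqs_per_bands) == 62
assert sum(freqs_per_bands) == 2048 // 2 + 1  # 1025 frequency bins for stft_n_fft: 2048
```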
mask_estimator_depth: 2 93 | multi_stft_resolution_loss_weight: 1.0 94 | multi_stft_resolutions_window_sizes: !!python/tuple 95 | - 4096 96 | - 2048 97 | - 1024 98 | - 512 99 | - 256 100 | multi_stft_hop_size: 147 101 | multi_stft_normalized: False 102 | 103 | training: 104 | batch_size: 10 105 | gradient_accumulation_steps: 1 106 | grad_clip: 0 107 | instruments: 108 | - vocals 109 | - other 110 | lr: 5.0e-05 111 | patience: 2 112 | reduce_factor: 0.95 113 | target_instrument: vocals 114 | num_epochs: 1000 115 | num_steps: 1000 116 | q: 0.95 117 | coarse_loss_clip: true 118 | ema_momentum: 0.999 119 | optimizer: adam 120 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 121 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 122 | 123 | augmentations: 124 | enable: true # enable or disable all augmentations (to fast disable if needed) 125 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 126 | loudness_min: 0.5 127 | loudness_max: 1.5 128 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 129 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 130 | - 0.2 131 | - 0.02 132 | mixup_loudness_min: 0.5 133 | mixup_loudness_max: 1.5 134 | 135 | inference: 136 | batch_size: 1 137 | dim_t: 256 138 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/configs/config_vocals_mdx23c.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 261120 3 | dim_f: 4096 4 | dim_t: 256 5 | hop_length: 1024 6 | n_fft: 8192 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | act: gelu 13 | bottleneck_factor: 4 14 | growth: 128 15 | norm: InstanceNorm 16 | num_blocks_per_scale: 2 17 | num_channels: 128 18 | num_scales: 5 19 | num_subbands: 4 20 | scale: 21 | - 2 22 | - 2 23 | 24 | training: 25 | batch_size: 6 26 | gradient_accumulation_steps: 1 27 | grad_clip: 0 28 | instruments: 29 | - vocals 30 | - other 31 | lr: 9.0e-05 32 | patience: 2 33 | reduce_factor: 0.95 34 | target_instrument: null 35 | num_epochs: 1000 36 | num_steps: 1000 37 | q: 0.95 38 | coarse_loss_clip: true 39 | ema_momentum: 0.999 40 | optimizer: adam 41 | read_metadata_procs: 8 # Number of processes to use during metadata reading for dataset. 
Can speed up metadata generation 42 | other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental 43 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 44 | 45 | augmentations: 46 | enable: true # enable or disable all augmentations (to fast disable if needed) 47 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 48 | loudness_min: 0.5 49 | loudness_max: 1.5 50 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 51 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 52 | - 0.2 53 | - 0.02 54 | mixup_loudness_min: 0.5 55 | mixup_loudness_max: 1.5 56 | 57 | # apply mp3 compression to mixture only (emulate downloading mp3 from internet) 58 | mp3_compression_on_mixture: 0.01 59 | mp3_compression_on_mixture_bitrate_min: 32 60 | mp3_compression_on_mixture_bitrate_max: 320 61 | mp3_compression_on_mixture_backend: "lameenc" 62 | 63 | all: 64 | channel_shuffle: 0.5 # Set 0 or lower to disable 65 | random_inverse: 0.1 # inverse track (better lower probability) 66 | random_polarity: 0.5 # polarity change (multiply waveform to -1) 67 | mp3_compression: 0.01 68 | mp3_compression_min_bitrate: 32 69 | mp3_compression_max_bitrate: 320 70 | mp3_compression_backend: "lameenc" 71 | 72 | vocals: 73 | pitch_shift: 0.1 74 | pitch_shift_min_semitones: -5 75 | pitch_shift_max_semitones: 5 76 | seven_band_parametric_eq: 0.25 77 | seven_band_parametric_eq_min_gain_db: -9 78 | seven_band_parametric_eq_max_gain_db: 9 79 | tanh_distortion: 0.1 80 | tanh_distortion_min: 0.1 81 | tanh_distortion_max: 0.7 82 | other: 83 | pitch_shift: 0.1 84 | pitch_shift_min_semitones: -4 85 | pitch_shift_max_semitones: 4 86 | gaussian_noise: 0.1 87 | gaussian_noise_min_amplitude: 0.001 88 | gaussian_noise_max_amplitude: 0.015 89 | time_stretch: 0.01 90 | time_stretch_min_rate: 0.8 91 | time_stretch_max_rate: 1.25 92 | 93 | inference: 94 | batch_size: 1 95 | dim_t: 256 96 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/configs/viperx/model_bs_roformer_ep_937_sdr_10.5309.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 131584 3 | dim_f: 1024 4 | dim_t: 256 5 | hop_length: 512 6 | n_fft: 2048 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | dim: 384 13 | depth: 12 14 | stereo: true 15 | num_stems: 1 16 | time_transformer_depth: 1 17 | freq_transformer_depth: 1 18 | linear_transformer_depth: 0 19 | freqs_per_bands: !!python/tuple 20 | - 2 21 | - 2 22 | - 2 23 | - 2 24 | - 2 25 | - 2 26 | - 2 27 | - 2 28 | - 2 29 | - 2 30 | - 2 31 | - 2 32 | - 2 33 | - 2 34 | - 2 35 | - 2 36 | - 2 37 | - 2 38 | - 2 39 | - 2 40 | - 2 41 | - 2 42 | - 2 43 | - 2 44 | - 4 45 | - 4 46 | - 4 47 | - 4 48 | - 4 49 | - 4 50 | - 4 51 | - 4 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | - 12 57 | - 12 58 | - 12 59 | - 12 60 | - 12 61 | - 12 62 | - 12 63 | - 12 64 | - 24 65 | - 24 66 | - 24 67 | - 24 68 | - 24 69 | - 24 70 | - 24 71 | - 24 72 | - 48 73 | - 48 74 | - 48 75 | - 48 76 | - 48 77 | - 48 78 | - 48 79 | - 48 80 | - 128 81 | - 129 82 | dim_head: 64 83 | heads: 8 84 | attn_dropout: 0.1 85 | ff_dropout: 0.1 86 | flash_attn: true 87 | dim_freqs_in: 1025 88 | stft_n_fft: 2048 89 | stft_hop_length: 512 90 | stft_win_length: 2048 91 | stft_normalized: false 92 | 
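The `loudness` augmentation described in the augmentation blocks of these configs amounts to a random per-stem gain drawn from the range `(loudness_min; loudness_max)`. An illustrative sketch of that idea (not the repository's implementation):

```
import numpy as np

def random_loudness(stem: np.ndarray, loudness_min: float = 0.5, loudness_max: float = 1.5) -> np.ndarray:
    """Scale one stem by a random gain drawn from [loudness_min, loudness_max]."""
    gain = np.random.uniform(loudness_min, loudness_max)
    return stem * gain

# e.g. each stem would be rescaled independently before the stems are summed back into a mixture
```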
mask_estimator_depth: 2 93 | multi_stft_resolution_loss_weight: 1.0 94 | multi_stft_resolutions_window_sizes: !!python/tuple 95 | - 4096 96 | - 2048 97 | - 1024 98 | - 512 99 | - 256 100 | multi_stft_hop_size: 147 101 | multi_stft_normalized: False 102 | 103 | training: 104 | batch_size: 4 105 | gradient_accumulation_steps: 1 106 | grad_clip: 0 107 | instruments: 108 | - vocals 109 | - other 110 | lr: 5.0e-05 111 | patience: 2 112 | reduce_factor: 0.95 113 | target_instrument: other 114 | num_epochs: 1000 115 | num_steps: 1000 116 | q: 0.95 117 | coarse_loss_clip: true 118 | ema_momentum: 0.999 119 | optimizer: adam 120 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 121 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 122 | 123 | augmentations: 124 | enable: true # enable or disable all augmentations (to fast disable if needed) 125 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 126 | loudness_min: 0.5 127 | loudness_max: 1.5 128 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 129 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 130 | - 0.2 131 | - 0.02 132 | mixup_loudness_min: 0.5 133 | mixup_loudness_max: 1.5 134 | 135 | inference: 136 | batch_size: 8 137 | dim_t: 512 138 | num_overlap: 2 -------------------------------------------------------------------------------- /mss/configs/config_musdb18_htdemucs.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 485100 # samplerate * segment 3 | min_mean_abs: 0.001 4 | hop_length: 1024 5 | 6 | training: 7 | batch_size: 8 8 | gradient_accumulation_steps: 1 9 | grad_clip: 0 10 | segment: 11 11 | shift: 1 12 | samplerate: 44100 13 | channels: 2 14 | normalize: true 15 | instruments: ['drums', 'bass', 'other', 'vocals'] 16 | target_instrument: null 17 | num_epochs: 1000 18 | num_steps: 1000 19 | optimizer: adam 20 | lr: 9.0e-05 21 | patience: 2 22 | reduce_factor: 0.95 23 | q: 0.95 24 | coarse_loss_clip: true 25 | ema_momentum: 0.999 26 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 27 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 28 | 29 | augmentations: 30 | enable: true # enable or disable all augmentations (to fast disable if needed) 31 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 32 | loudness_min: 0.5 33 | loudness_max: 1.5 34 | 35 | inference: 36 | num_overlap: 4 37 | batch_size: 8 38 | 39 | model: htdemucs 40 | 41 | htdemucs: # see demucs/htdemucs.py for a detailed description 42 | # Channels 43 | channels: 48 44 | channels_time: 45 | growth: 2 46 | # STFT 47 | num_subbands: 1 48 | nfft: 4096 49 | wiener_iters: 0 50 | end_iters: 0 51 | wiener_residual: false 52 | cac: true 53 | # Main structure 54 | depth: 4 55 | rewrite: true 56 | # Frequency Branch 57 | multi_freqs: [] 58 | multi_freqs_depth: 3 59 | freq_emb: 0.2 60 | emb_scale: 10 61 | emb_smooth: true 62 | # Convolutions 63 | kernel_size: 8 64 | stride: 4 65 | time_stride: 2 66 | context: 1 67 | context_enc: 0 68 | # normalization 69 | norm_starts: 4 70 | norm_groups: 4 71 | # DConv residual branch 72 | dconv_mode: 3 73 | dconv_depth: 2 74 | dconv_comp: 8 75 | dconv_init: 1e-3 76 | # Before the 
Transformer 77 | bottom_channels: 512 78 | # CrossTransformer 79 | # ------ Common to all 80 | # Regular parameters 81 | t_layers: 5 82 | t_hidden_scale: 4.0 83 | t_heads: 8 84 | t_dropout: 0.0 85 | t_layer_scale: True 86 | t_gelu: True 87 | # ------------- Positional Embedding 88 | t_emb: sin 89 | t_max_positions: 10000 # for the scaled embedding 90 | t_max_period: 10000.0 91 | t_weight_pos_embed: 1.0 92 | t_cape_mean_normalize: True 93 | t_cape_augment: True 94 | t_cape_glob_loc_scale: [5000.0, 1.0, 1.4] 95 | t_sin_random_shift: 0 96 | # ------------- norm before a transformer encoder 97 | t_norm_in: True 98 | t_norm_in_group: False 99 | # ------------- norm inside the encoder 100 | t_group_norm: False 101 | t_norm_first: True 102 | t_norm_out: True 103 | # ------------- optim 104 | t_weight_decay: 0.0 105 | t_lr: 106 | # ------------- sparsity 107 | t_sparse_self_attn: False 108 | t_sparse_cross_attn: False 109 | t_mask_type: diag 110 | t_mask_random_seed: 42 111 | t_sparse_attn_window: 400 112 | t_global_window: 100 113 | t_sparsity: 0.95 114 | t_auto_sparsity: False 115 | # Cross Encoder First (False) 116 | t_cross_first: False 117 | # Weight init 118 | rescale: 0.1 119 | 120 | -------------------------------------------------------------------------------- /mss/configs/config_vocals_htdemucs.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 485100 # samplerate * segment 3 | min_mean_abs: 0.001 4 | hop_length: 1024 5 | 6 | training: 7 | batch_size: 10 8 | gradient_accumulation_steps: 1 9 | grad_clip: 0 10 | segment: 11 11 | shift: 1 12 | samplerate: 44100 13 | channels: 2 14 | normalize: true 15 | instruments: ['vocals', 'other'] 16 | target_instrument: null 17 | num_epochs: 1000 18 | num_steps: 1000 19 | optimizer: adam 20 | lr: 9.0e-05 21 | patience: 2 22 | reduce_factor: 0.95 23 | q: 0.95 24 | coarse_loss_clip: true 25 | ema_momentum: 0.999 26 | other_fix: true # it's needed for checking on multisong dataset if other is actually instrumental 27 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 28 | 29 | augmentations: 30 | enable: true # enable or disable all augmentations (to fast disable if needed) 31 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 32 | loudness_min: 0.5 33 | loudness_max: 1.5 34 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 35 | mixup_probs: [0.2, 0.02] 36 | mixup_loudness_min: 0.5 37 | mixup_loudness_max: 1.5 38 | 39 | inference: 40 | num_overlap: 2 41 | batch_size: 8 42 | 43 | model: htdemucs 44 | 45 | htdemucs: # see demucs/htdemucs.py for a detailed description 46 | # Channels 47 | channels: 48 48 | channels_time: 49 | growth: 2 50 | # STFT 51 | num_subbands: 1 52 | nfft: 4096 53 | wiener_iters: 0 54 | end_iters: 0 55 | wiener_residual: false 56 | cac: true 57 | # Main structure 58 | depth: 4 59 | rewrite: true 60 | # Frequency Branch 61 | multi_freqs: [] 62 | multi_freqs_depth: 3 63 | freq_emb: 0.2 64 | emb_scale: 10 65 | emb_smooth: true 66 | # Convolutions 67 | kernel_size: 8 68 | stride: 4 69 | time_stride: 2 70 | context: 1 71 | context_enc: 0 72 | # normalization 73 | norm_starts: 4 74 | norm_groups: 4 75 | # DConv residual branch 76 | dconv_mode: 3 77 | dconv_depth: 2 78 | dconv_comp: 8 79 | dconv_init: 1e-3 80 | # Before the Transformer 81 | bottom_channels: 512 82 | # CrossTransformer 83 | # ------ Common to all 84 | # Regular 
parameters 85 | t_layers: 5 86 | t_hidden_scale: 4.0 87 | t_heads: 8 88 | t_dropout: 0.0 89 | t_layer_scale: True 90 | t_gelu: True 91 | # ------------- Positional Embedding 92 | t_emb: sin 93 | t_max_positions: 10000 # for the scaled embedding 94 | t_max_period: 10000.0 95 | t_weight_pos_embed: 1.0 96 | t_cape_mean_normalize: True 97 | t_cape_augment: True 98 | t_cape_glob_loc_scale: [5000.0, 1.0, 1.4] 99 | t_sin_random_shift: 0 100 | # ------------- norm before a transformer encoder 101 | t_norm_in: True 102 | t_norm_in_group: False 103 | # ------------- norm inside the encoder 104 | t_group_norm: False 105 | t_norm_first: True 106 | t_norm_out: True 107 | # ------------- optim 108 | t_weight_decay: 0.0 109 | t_lr: 110 | # ------------- sparsity 111 | t_sparse_self_attn: False 112 | t_sparse_cross_attn: False 113 | t_mask_type: diag 114 | t_mask_random_seed: 42 115 | t_sparse_attn_window: 400 116 | t_global_window: 100 117 | t_sparsity: 0.95 118 | t_auto_sparsity: False 119 | # Cross Encoder First (False) 120 | t_cross_first: False 121 | # Weight init 122 | rescale: 0.1 123 | 124 | -------------------------------------------------------------------------------- /mss/configs/config_htdemucs_6stems.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 485100 # samplerate * segment 3 | min_mean_abs: 0.001 4 | hop_length: 1024 5 | 6 | training: 7 | batch_size: 8 8 | gradient_accumulation_steps: 1 9 | grad_clip: 0 10 | segment: 11 11 | shift: 1 12 | samplerate: 44100 13 | channels: 2 14 | normalize: true 15 | instruments: ['drums', 'bass', 'other', 'vocals', 'guitar', 'piano'] 16 | target_instrument: null 17 | num_epochs: 1000 18 | num_steps: 1000 19 | optimizer: adam 20 | lr: 9.0e-05 21 | patience: 2 22 | reduce_factor: 0.95 23 | q: 0.95 24 | coarse_loss_clip: true 25 | ema_momentum: 0.999 26 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 27 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 28 | 29 | augmentations: 30 | enable: true # enable or disable all augmentations (to fast disable if needed) 31 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 32 | loudness_min: 0.5 33 | loudness_max: 1.5 34 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 35 | mixup_probs: [0.2, 0.02] 36 | mixup_loudness_min: 0.5 37 | mixup_loudness_max: 1.5 38 | all: 39 | channel_shuffle: 0.5 # Set 0 or lower to disable 40 | random_inverse: 0.1 # inverse track (better lower probability) 41 | random_polarity: 0.5 # polarity change (multiply waveform to -1) 42 | 43 | inference: 44 | num_overlap: 4 45 | batch_size: 8 46 | 47 | model: htdemucs 48 | 49 | htdemucs: # see demucs/htdemucs.py for a detailed description 50 | # Channels 51 | channels: 48 52 | channels_time: 53 | growth: 2 54 | # STFT 55 | num_subbands: 1 56 | nfft: 4096 57 | wiener_iters: 0 58 | end_iters: 0 59 | wiener_residual: false 60 | cac: true 61 | # Main structure 62 | depth: 4 63 | rewrite: true 64 | # Frequency Branch 65 | multi_freqs: [] 66 | multi_freqs_depth: 3 67 | freq_emb: 0.2 68 | emb_scale: 10 69 | emb_smooth: true 70 | # Convolutions 71 | kernel_size: 8 72 | stride: 4 73 | time_stride: 2 74 | context: 1 75 | context_enc: 0 76 | # normalization 77 | norm_starts: 4 78 | norm_groups: 4 79 | # DConv residual branch 80 | dconv_mode: 3 81 | dconv_depth: 2 82 | dconv_comp: 8 83 | 
dconv_init: 1e-3 84 | # Before the Transformer 85 | bottom_channels: 0 86 | # CrossTransformer 87 | # ------ Common to all 88 | # Regular parameters 89 | t_layers: 5 90 | t_hidden_scale: 4.0 91 | t_heads: 8 92 | t_dropout: 0.0 93 | t_layer_scale: True 94 | t_gelu: True 95 | # ------------- Positional Embedding 96 | t_emb: sin 97 | t_max_positions: 10000 # for the scaled embedding 98 | t_max_period: 10000.0 99 | t_weight_pos_embed: 1.0 100 | t_cape_mean_normalize: True 101 | t_cape_augment: True 102 | t_cape_glob_loc_scale: [5000.0, 1.0, 1.4] 103 | t_sin_random_shift: 0 104 | # ------------- norm before a transformer encoder 105 | t_norm_in: True 106 | t_norm_in_group: False 107 | # ------------- norm inside the encoder 108 | t_group_norm: False 109 | t_norm_first: True 110 | t_norm_out: True 111 | # ------------- optim 112 | t_weight_decay: 0.0 113 | t_lr: 114 | # ------------- sparsity 115 | t_sparse_self_attn: False 116 | t_sparse_cross_attn: False 117 | t_mask_type: diag 118 | t_mask_random_seed: 42 119 | t_sparse_attn_window: 400 120 | t_global_window: 100 121 | t_sparsity: 0.95 122 | t_auto_sparsity: False 123 | # Cross Encoder First (False) 124 | t_cross_first: False 125 | # Weight init 126 | rescale: 0.1 127 | 128 | -------------------------------------------------------------------------------- /mss/models/bs_roformer/attend.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from packaging import version 3 | from collections import namedtuple 4 | 5 | import torch 6 | from torch import nn, einsum 7 | import torch.nn.functional as F 8 | 9 | from einops import rearrange, reduce 10 | 11 | # constants 12 | 13 | FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) 14 | 15 | # helpers 16 | 17 | def exists(val): 18 | return val is not None 19 | 20 | def default(v, d): 21 | return v if exists(v) else d 22 | 23 | def once(fn): 24 | called = False 25 | @wraps(fn) 26 | def inner(x): 27 | nonlocal called 28 | if called: 29 | return 30 | called = True 31 | return fn(x) 32 | return inner 33 | 34 | print_once = once(print) 35 | 36 | # main class 37 | 38 | class Attend(nn.Module): 39 | def __init__( 40 | self, 41 | dropout = 0., 42 | flash = False, 43 | scale = None 44 | ): 45 | super().__init__() 46 | self.scale = scale 47 | self.dropout = dropout 48 | self.attn_dropout = nn.Dropout(dropout) 49 | 50 | self.flash = flash 51 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 52 | 53 | # determine efficient attention configs for cuda and cpu 54 | 55 | self.cpu_config = FlashAttentionConfig(True, True, True) 56 | self.cuda_config = None 57 | 58 | if not torch.cuda.is_available() or not flash: 59 | return 60 | 61 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 62 | 63 | if device_properties.major == 8 and device_properties.minor == 0: 64 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda') 65 | self.cuda_config = FlashAttentionConfig(True, False, False) 66 | else: 67 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') 68 | self.cuda_config = FlashAttentionConfig(False, True, True) 69 | 70 | def flash_attn(self, q, k, v): 71 | _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device 72 | 73 | if exists(self.scale): 74 | 
default_scale = q.shape[-1] ** -0.5 75 | q = q * (self.scale / default_scale) 76 | 77 | # Check if there is a compatible device for flash attention 78 | 79 | config = self.cuda_config if is_cuda else self.cpu_config 80 | 81 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale 82 | 83 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 84 | out = F.scaled_dot_product_attention( 85 | q, k, v, 86 | dropout_p = self.dropout if self.training else 0. 87 | ) 88 | 89 | return out 90 | 91 | def forward(self, q, k, v): 92 | """ 93 | einstein notation 94 | b - batch 95 | h - heads 96 | n, i, j - sequence length (base sequence length, source, target) 97 | d - feature dimension 98 | """ 99 | 100 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device 101 | 102 | scale = default(self.scale, q.shape[-1] ** -0.5) 103 | 104 | if self.flash: 105 | return self.flash_attn(q, k, v) 106 | 107 | # similarity 108 | 109 | sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale 110 | 111 | # attention 112 | 113 | attn = sim.softmax(dim=-1) 114 | attn = self.attn_dropout(attn) 115 | 116 | # aggregate values 117 | 118 | out = einsum(f"b h i j, b h j d -> b h i d", attn, v) 119 | 120 | return out 121 | -------------------------------------------------------------------------------- /mss/models/bandit/core/loss/_timefreq.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn.modules.loss import _Loss 6 | 7 | from models.bandit.core.loss._multistem import MultiStemWrapper 8 | from models.bandit.core.loss._complex import ReImL1Loss, ReImL2Loss, ReImLossWrapper 9 | from models.bandit.core.loss.snr import SignalNoisePNormRatio 10 | 11 | class TimeFreqWrapper(_Loss): 12 | def __init__( 13 | self, 14 | time_module: _Loss, 15 | freq_module: Optional[_Loss] = None, 16 | time_weight: float = 1.0, 17 | freq_weight: float = 1.0, 18 | multistem: bool = True, 19 | ) -> None: 20 | super().__init__() 21 | 22 | if freq_module is None: 23 | freq_module = time_module 24 | 25 | if multistem: 26 | time_module = MultiStemWrapper(time_module, modality="audio") 27 | freq_module = MultiStemWrapper(freq_module, modality="spectrogram") 28 | 29 | self.time_module = time_module 30 | self.freq_module = freq_module 31 | 32 | self.time_weight = time_weight 33 | self.freq_weight = freq_weight 34 | 35 | # TODO: add better type hints 36 | def forward(self, preds: Any, target: Any) -> torch.Tensor: 37 | 38 | return self.time_weight * self.time_module( 39 | preds, target 40 | ) + self.freq_weight * self.freq_module(preds, target) 41 | 42 | 43 | class TimeFreqL1Loss(TimeFreqWrapper): 44 | def __init__( 45 | self, 46 | time_weight: float = 1.0, 47 | freq_weight: float = 1.0, 48 | tkwargs: Optional[Dict[str, Any]] = None, 49 | fkwargs: Optional[Dict[str, Any]] = None, 50 | multistem: bool = True, 51 | ) -> None: 52 | if tkwargs is None: 53 | tkwargs = {} 54 | if fkwargs is None: 55 | fkwargs = {} 56 | time_module = (nn.L1Loss(**tkwargs)) 57 | freq_module = ReImL1Loss(**fkwargs) 58 | super().__init__( 59 | time_module, 60 | freq_module, 61 | time_weight, 62 | freq_weight, 63 | multistem 64 | ) 65 | 66 | 67 | class TimeFreqL2Loss(TimeFreqWrapper): 68 | def __init__( 69 | self, 70 | time_weight: float = 1.0, 71 | freq_weight: float = 1.0, 72 | tkwargs: Optional[Dict[str, Any]] = None, 73 | fkwargs: Optional[Dict[str, Any]] = None, 74 | multistem: bool = True, 75 | ) -> None: 76 | if tkwargs is None: 
77 | tkwargs = {} 78 | if fkwargs is None: 79 | fkwargs = {} 80 | time_module = nn.MSELoss(**tkwargs) 81 | freq_module = ReImL2Loss(**fkwargs) 82 | super().__init__( 83 | time_module, 84 | freq_module, 85 | time_weight, 86 | freq_weight, 87 | multistem 88 | ) 89 | 90 | 91 | 92 | class TimeFreqSignalNoisePNormRatioLoss(TimeFreqWrapper): 93 | def __init__( 94 | self, 95 | time_weight: float = 1.0, 96 | freq_weight: float = 1.0, 97 | tkwargs: Optional[Dict[str, Any]] = None, 98 | fkwargs: Optional[Dict[str, Any]] = None, 99 | multistem: bool = True, 100 | ) -> None: 101 | if tkwargs is None: 102 | tkwargs = {} 103 | if fkwargs is None: 104 | fkwargs = {} 105 | time_module = SignalNoisePNormRatio(**tkwargs) 106 | freq_module = SignalNoisePNormRatio(**fkwargs) 107 | super().__init__( 108 | time_module, 109 | freq_module, 110 | time_weight, 111 | freq_weight, 112 | multistem 113 | ) 114 | -------------------------------------------------------------------------------- /mss/models/bandit/core/data/augmentation.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Any, Dict, Union 3 | 4 | import torch 5 | import torch_audiomentations as tam 6 | from torch import nn 7 | 8 | from models.bandit.core.data._types import BatchedDataDict, DataDict 9 | 10 | 11 | class BaseAugmentor(nn.Module, ABC): 12 | def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[ 13 | DataDict, BatchedDataDict]: 14 | raise NotImplementedError 15 | 16 | 17 | class StemAugmentor(BaseAugmentor): 18 | def __init__( 19 | self, 20 | audiomentations: Dict[str, Dict[str, Any]], 21 | fix_clipping: bool = True, 22 | scaler_margin: float = 0.5, 23 | apply_both_default_and_common: bool = False, 24 | ) -> None: 25 | super().__init__() 26 | 27 | augmentations = {} 28 | 29 | self.has_default = "[default]" in audiomentations 30 | self.has_common = "[common]" in audiomentations 31 | self.apply_both_default_and_common = apply_both_default_and_common 32 | 33 | for stem in audiomentations: 34 | if audiomentations[stem]["name"] == "Compose": 35 | augmentations[stem] = getattr( 36 | tam, 37 | audiomentations[stem]["name"] 38 | )( 39 | [ 40 | getattr(tam, aug["name"])(**aug["kwargs"]) 41 | for aug in 42 | audiomentations[stem]["kwargs"]["transforms"] 43 | ], 44 | **audiomentations[stem]["kwargs"]["kwargs"], 45 | ) 46 | else: 47 | augmentations[stem] = getattr( 48 | tam, 49 | audiomentations[stem]["name"] 50 | )( 51 | **audiomentations[stem]["kwargs"] 52 | ) 53 | 54 | self.augmentations = nn.ModuleDict(augmentations) 55 | self.fix_clipping = fix_clipping 56 | self.scaler_margin = scaler_margin 57 | 58 | def check_and_fix_clipping( 59 | self, item: Union[DataDict, BatchedDataDict] 60 | ) -> Union[DataDict, BatchedDataDict]: 61 | max_abs = [] 62 | 63 | for stem in item["audio"]: 64 | max_abs.append(item["audio"][stem].abs().max().item()) 65 | 66 | if max(max_abs) > 1.0: 67 | scaler = 1.0 / (max(max_abs) + torch.rand( 68 | (1,), 69 | device=item["audio"]["mixture"].device 70 | ) * self.scaler_margin) 71 | 72 | for stem in item["audio"]: 73 | item["audio"][stem] *= scaler 74 | 75 | return item 76 | 77 | def forward(self, item: Union[DataDict, BatchedDataDict]) -> Union[ 78 | DataDict, BatchedDataDict]: 79 | 80 | for stem in item["audio"]: 81 | if stem == "mixture": 82 | continue 83 | 84 | if self.has_common: 85 | item["audio"][stem] = self.augmentations["[common]"]( 86 | item["audio"][stem] 87 | ).samples 88 | 89 | if stem in self.augmentations: 90 | 
item["audio"][stem] = self.augmentations[stem]( 91 | item["audio"][stem] 92 | ).samples 93 | elif self.has_default: 94 | if not self.has_common or self.apply_both_default_and_common: 95 | item["audio"][stem] = self.augmentations["[default]"]( 96 | item["audio"][stem] 97 | ).samples 98 | 99 | item["audio"]["mixture"] = sum( 100 | [item["audio"][stem] for stem in item["audio"] 101 | if stem != "mixture"] 102 | ) # type: ignore[call-overload, assignment] 103 | 104 | if self.fix_clipping: 105 | item = self.check_and_fix_clipping(item) 106 | 107 | return item 108 | -------------------------------------------------------------------------------- /mss/models/scnet/separation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.modules.rnn import LSTM 4 | 5 | 6 | class FeatureConversion(nn.Module): 7 | """ 8 | Integrates into the adjacent Dual-Path layer. 9 | 10 | Args: 11 | channels (int): Number of input channels. 12 | inverse (bool): If True, uses ifft; otherwise, uses rfft. 13 | """ 14 | 15 | def __init__(self, channels, inverse): 16 | super().__init__() 17 | self.inverse = inverse 18 | self.channels = channels 19 | 20 | def forward(self, x): 21 | # B, C, F, T = x.shape 22 | if self.inverse: 23 | x = x.float() 24 | x_r = x[:, :self.channels // 2, :, :] 25 | x_i = x[:, self.channels // 2:, :, :] 26 | x = torch.complex(x_r, x_i) 27 | x = torch.fft.irfft(x, dim=3, norm="ortho") 28 | else: 29 | x = x.float() 30 | x = torch.fft.rfft(x, dim=3, norm="ortho") 31 | x_real = x.real 32 | x_imag = x.imag 33 | x = torch.cat([x_real, x_imag], dim=1) 34 | return x 35 | 36 | 37 | class DualPathRNN(nn.Module): 38 | """ 39 | Dual-Path RNN in Separation Network. 40 | 41 | Args: 42 | d_model (int): The number of expected features in the input (input_size). 43 | expand (int): Expansion factor used to calculate the hidden_size of LSTM. 44 | bidirectional (bool): If True, becomes a bidirectional LSTM. 
45 | """ 46 | 47 | def __init__(self, d_model, expand, bidirectional=True): 48 | super(DualPathRNN, self).__init__() 49 | 50 | self.d_model = d_model 51 | self.hidden_size = d_model * expand 52 | self.bidirectional = bidirectional 53 | # Initialize LSTM layers and normalization layers 54 | self.lstm_layers = nn.ModuleList([self._init_lstm_layer(self.d_model, self.hidden_size) for _ in range(2)]) 55 | self.linear_layers = nn.ModuleList([nn.Linear(self.hidden_size * 2, self.d_model) for _ in range(2)]) 56 | self.norm_layers = nn.ModuleList([nn.GroupNorm(1, d_model) for _ in range(2)]) 57 | 58 | def _init_lstm_layer(self, d_model, hidden_size): 59 | return LSTM(d_model, hidden_size, num_layers=1, bidirectional=self.bidirectional, batch_first=True) 60 | 61 | def forward(self, x): 62 | B, C, F, T = x.shape 63 | 64 | # Process dual-path rnn 65 | original_x = x 66 | # Frequency-path 67 | x = self.norm_layers[0](x) 68 | x = x.transpose(1, 3).contiguous().view(B * T, F, C) 69 | x, _ = self.lstm_layers[0](x) 70 | x = self.linear_layers[0](x) 71 | x = x.view(B, T, F, C).transpose(1, 3) 72 | x = x + original_x 73 | 74 | original_x = x 75 | # Time-path 76 | x = self.norm_layers[1](x) 77 | x = x.transpose(1, 2).contiguous().view(B * F, C, T).transpose(1, 2) 78 | x, _ = self.lstm_layers[1](x) 79 | x = self.linear_layers[1](x) 80 | x = x.transpose(1, 2).contiguous().view(B, F, C, T).transpose(1, 2) 81 | x = x + original_x 82 | 83 | return x 84 | 85 | 86 | class SeparationNet(nn.Module): 87 | """ 88 | Implements a simplified Sparse Down-sample block in an encoder architecture. 89 | 90 | Args: 91 | - channels (int): Number input channels. 92 | - expand (int): Expansion factor used to calculate the hidden_size of LSTM. 93 | - num_layers (int): Number of dual-path layers. 94 | """ 95 | 96 | def __init__(self, channels, expand=1, num_layers=6): 97 | super(SeparationNet, self).__init__() 98 | 99 | self.num_layers = num_layers 100 | 101 | self.dp_modules = nn.ModuleList([ 102 | DualPathRNN(channels * (2 if i % 2 == 1 else 1), expand) for i in range(num_layers) 103 | ]) 104 | 105 | self.feature_conversion = nn.ModuleList([ 106 | FeatureConversion(channels * 2, inverse=False if i % 2 == 0 else True) for i in range(num_layers) 107 | ]) 108 | 109 | def forward(self, x): 110 | for i in range(self.num_layers): 111 | x = self.dp_modules[i](x) 112 | x = self.feature_conversion[i](x) 113 | return x 114 | -------------------------------------------------------------------------------- /mss/models/scnet_unofficial/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | SCNet - great paper, great implementation 3 | https://arxiv.org/pdf/2401.13276.pdf 4 | https://github.com/amanteur/SCNet-PyTorch 5 | ''' 6 | 7 | from typing import List, Tuple, Union 8 | 9 | import torch 10 | 11 | 12 | def create_intervals( 13 | splits: List[Union[float, int]] 14 | ) -> List[Union[Tuple[float, float], Tuple[int, int]]]: 15 | """ 16 | Create intervals based on splits provided. 17 | 18 | Args: 19 | - splits (List[Union[float, int]]): List of floats or integers representing splits. 20 | 21 | Returns: 22 | - List[Union[Tuple[float, float], Tuple[int, int]]]: List of tuples representing intervals. 
23 | """ 24 | start = 0 25 | return [(start, start := start + split) for split in splits] 26 | 27 | 28 | def get_conv_output_shape( 29 | input_shape: int, 30 | kernel_size: int = 1, 31 | padding: int = 0, 32 | dilation: int = 1, 33 | stride: int = 1, 34 | ) -> int: 35 | """ 36 | Compute the output shape of a convolutional layer. 37 | 38 | Args: 39 | - input_shape (int): Input shape. 40 | - kernel_size (int, optional): Kernel size of the convolution. Default is 1. 41 | - padding (int, optional): Padding size. Default is 0. 42 | - dilation (int, optional): Dilation factor. Default is 1. 43 | - stride (int, optional): Stride value. Default is 1. 44 | 45 | Returns: 46 | - int: Output shape. 47 | """ 48 | return int( 49 | (input_shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 50 | ) 51 | 52 | 53 | def get_convtranspose_output_padding( 54 | input_shape: int, 55 | output_shape: int, 56 | kernel_size: int = 1, 57 | padding: int = 0, 58 | dilation: int = 1, 59 | stride: int = 1, 60 | ) -> int: 61 | """ 62 | Compute the output padding for a convolution transpose operation. 63 | 64 | Args: 65 | - input_shape (int): Input shape. 66 | - output_shape (int): Desired output shape. 67 | - kernel_size (int, optional): Kernel size of the convolution. Default is 1. 68 | - padding (int, optional): Padding size. Default is 0. 69 | - dilation (int, optional): Dilation factor. Default is 1. 70 | - stride (int, optional): Stride value. Default is 1. 71 | 72 | Returns: 73 | - int: Output padding. 74 | """ 75 | return ( 76 | output_shape 77 | - (input_shape - 1) * stride 78 | + 2 * padding 79 | - dilation * (kernel_size - 1) 80 | - 1 81 | ) 82 | 83 | 84 | def compute_sd_layer_shapes( 85 | input_shape: int, 86 | bandsplit_ratios: List[float], 87 | downsample_strides: List[int], 88 | n_layers: int, 89 | ) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]: 90 | """ 91 | Compute the shapes for the subband layers. 92 | 93 | Args: 94 | - input_shape (int): Input shape. 95 | - bandsplit_ratios (List[float]): Ratios for splitting the frequency bands. 96 | - downsample_strides (List[int]): Strides for downsampling in each layer. 97 | - n_layers (int): Number of layers. 98 | 99 | Returns: 100 | - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: Tuple containing subband shapes and convolution shapes. 101 | """ 102 | bandsplit_shapes_list = [] 103 | conv2d_shapes_list = [] 104 | for _ in range(n_layers): 105 | bandsplit_intervals = create_intervals(bandsplit_ratios) 106 | bandsplit_shapes = [ 107 | int(right * input_shape) - int(left * input_shape) 108 | for left, right in bandsplit_intervals 109 | ] 110 | conv2d_shapes = [ 111 | get_conv_output_shape(bs, stride=ds) 112 | for bs, ds in zip(bandsplit_shapes, downsample_strides) 113 | ] 114 | input_shape = sum(conv2d_shapes) 115 | bandsplit_shapes_list.append(bandsplit_shapes) 116 | conv2d_shapes_list.append(create_intervals(conv2d_shapes)) 117 | 118 | return bandsplit_shapes_list, conv2d_shapes_list 119 | 120 | 121 | def compute_gcr(subband_shapes: List[List[int]]) -> float: 122 | """ 123 | Compute the global compression ratio. 124 | 125 | Args: 126 | - subband_shapes (List[List[int]]): List of subband shapes. 127 | 128 | Returns: 129 | - float: Global compression ratio. 
130 | """ 131 | t = torch.Tensor(subband_shapes) 132 | gcr = torch.stack( 133 | [(1 - t[i + 1] / t[i]).mean() for i in range(0, len(t) - 1)] 134 | ).mean() 135 | return float(gcr) -------------------------------------------------------------------------------- /mss/models/bandit/core/metrics/snr.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import numpy as np 4 | import torch 5 | import torchmetrics as tm 6 | from torch._C import _LinAlgError 7 | from torchmetrics import functional as tmF 8 | 9 | 10 | class SafeSignalDistortionRatio(tm.SignalDistortionRatio): 11 | def __init__(self, **kwargs) -> None: 12 | super().__init__(**kwargs) 13 | 14 | def update(self, *args, **kwargs) -> Any: 15 | try: 16 | super().update(*args, **kwargs) 17 | except: 18 | pass 19 | 20 | def compute(self) -> Any: 21 | if self.total == 0: 22 | return torch.tensor(torch.nan) 23 | return super().compute() 24 | 25 | 26 | class BaseChunkMedianSignalRatio(tm.Metric): 27 | def __init__( 28 | self, 29 | func: Callable, 30 | window_size: int, 31 | hop_size: int = None, 32 | zero_mean: bool = False, 33 | ) -> None: 34 | super().__init__() 35 | 36 | # self.zero_mean = zero_mean 37 | self.func = func 38 | self.window_size = window_size 39 | if hop_size is None: 40 | hop_size = window_size 41 | self.hop_size = hop_size 42 | 43 | self.add_state( 44 | "sum_snr", 45 | default=torch.tensor(0.0), 46 | dist_reduce_fx="sum" 47 | ) 48 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 49 | 50 | def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: 51 | 52 | n_samples = target.shape[-1] 53 | 54 | n_chunks = int( 55 | np.ceil((n_samples - self.window_size) / self.hop_size) + 1 56 | ) 57 | 58 | snr_chunk = [] 59 | 60 | for i in range(n_chunks): 61 | start = i * self.hop_size 62 | 63 | if n_samples - start < self.window_size: 64 | continue 65 | 66 | end = start + self.window_size 67 | 68 | try: 69 | chunk_snr = self.func( 70 | preds[..., start:end], 71 | target[..., start:end] 72 | ) 73 | 74 | # print(preds.shape, chunk_snr.shape) 75 | 76 | if torch.all(torch.isfinite(chunk_snr)): 77 | snr_chunk.append(chunk_snr) 78 | except _LinAlgError: 79 | pass 80 | 81 | snr_chunk = torch.stack(snr_chunk, dim=-1) 82 | snr_batch, _ = torch.nanmedian(snr_chunk, dim=-1) 83 | 84 | self.sum_snr += snr_batch.sum() 85 | self.total += snr_batch.numel() 86 | 87 | def compute(self) -> Any: 88 | return self.sum_snr / self.total 89 | 90 | 91 | class ChunkMedianSignalNoiseRatio(BaseChunkMedianSignalRatio): 92 | def __init__( 93 | self, 94 | window_size: int, 95 | hop_size: int = None, 96 | zero_mean: bool = False 97 | ) -> None: 98 | super().__init__( 99 | func=tmF.signal_noise_ratio, 100 | window_size=window_size, 101 | hop_size=hop_size, 102 | zero_mean=zero_mean, 103 | ) 104 | 105 | 106 | class ChunkMedianScaleInvariantSignalNoiseRatio(BaseChunkMedianSignalRatio): 107 | def __init__( 108 | self, 109 | window_size: int, 110 | hop_size: int = None, 111 | zero_mean: bool = False 112 | ) -> None: 113 | super().__init__( 114 | func=tmF.scale_invariant_signal_noise_ratio, 115 | window_size=window_size, 116 | hop_size=hop_size, 117 | zero_mean=zero_mean, 118 | ) 119 | 120 | 121 | class ChunkMedianSignalDistortionRatio(BaseChunkMedianSignalRatio): 122 | def __init__( 123 | self, 124 | window_size: int, 125 | hop_size: int = None, 126 | zero_mean: bool = False 127 | ) -> None: 128 | super().__init__( 129 | func=tmF.signal_distortion_ratio, 130 | 
window_size=window_size, 131 | hop_size=hop_size, 132 | zero_mean=zero_mean, 133 | ) 134 | 135 | 136 | class ChunkMedianScaleInvariantSignalDistortionRatio( 137 | BaseChunkMedianSignalRatio 138 | ): 139 | def __init__( 140 | self, 141 | window_size: int, 142 | hop_size: int = None, 143 | zero_mean: bool = False 144 | ) -> None: 145 | super().__init__( 146 | func=tmF.scale_invariant_signal_distortion_ratio, 147 | window_size=window_size, 148 | hop_size=hop_size, 149 | zero_mean=zero_mean, 150 | ) 151 | -------------------------------------------------------------------------------- /mss/models/bandit/core/model/bsrnn/bandsplit.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from models.bandit.core.model.bsrnn.utils import ( 7 | band_widths_from_specs, 8 | check_no_gap, 9 | check_no_overlap, 10 | check_nonzero_bandwidth, 11 | ) 12 | 13 | 14 | class NormFC(nn.Module): 15 | def __init__( 16 | self, 17 | emb_dim: int, 18 | bandwidth: int, 19 | in_channel: int, 20 | normalize_channel_independently: bool = False, 21 | treat_channel_as_feature: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | 25 | self.treat_channel_as_feature = treat_channel_as_feature 26 | 27 | if normalize_channel_independently: 28 | raise NotImplementedError 29 | 30 | reim = 2 31 | 32 | self.norm = nn.LayerNorm(in_channel * bandwidth * reim) 33 | 34 | fc_in = bandwidth * reim 35 | 36 | if treat_channel_as_feature: 37 | fc_in *= in_channel 38 | else: 39 | assert emb_dim % in_channel == 0 40 | emb_dim = emb_dim // in_channel 41 | 42 | self.fc = nn.Linear(fc_in, emb_dim) 43 | 44 | def forward(self, xb): 45 | # xb = (batch, n_time, in_chan, reim * band_width) 46 | 47 | batch, n_time, in_chan, ribw = xb.shape 48 | xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw)) 49 | # (batch, n_time, in_chan * reim * band_width) 50 | 51 | if not self.treat_channel_as_feature: 52 | xb = xb.reshape(batch, n_time, in_chan, ribw) 53 | # (batch, n_time, in_chan, reim * band_width) 54 | 55 | zb = self.fc(xb) 56 | # (batch, n_time, emb_dim) 57 | # OR 58 | # (batch, n_time, in_chan, emb_dim_per_chan) 59 | 60 | if not self.treat_channel_as_feature: 61 | batch, n_time, in_chan, emb_dim_per_chan = zb.shape 62 | # (batch, n_time, in_chan, emb_dim_per_chan) 63 | zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan)) 64 | 65 | return zb # (batch, n_time, emb_dim) 66 | 67 | 68 | class BandSplitModule(nn.Module): 69 | def __init__( 70 | self, 71 | band_specs: List[Tuple[float, float]], 72 | emb_dim: int, 73 | in_channel: int, 74 | require_no_overlap: bool = False, 75 | require_no_gap: bool = True, 76 | normalize_channel_independently: bool = False, 77 | treat_channel_as_feature: bool = True, 78 | ) -> None: 79 | super().__init__() 80 | 81 | check_nonzero_bandwidth(band_specs) 82 | 83 | if require_no_gap: 84 | check_no_gap(band_specs) 85 | 86 | if require_no_overlap: 87 | check_no_overlap(band_specs) 88 | 89 | self.band_specs = band_specs 90 | # list of [fstart, fend) in index. 91 | # Note that fend is exclusive. 
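        # e.g. band_specs = [(0, 2), (2, 4), (4, 7)] describes three bands of widths
        # 2, 2 and 3 covering bins 0..6 with no gaps and no overlaps
        # (illustrative values, not taken from an actual config).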
92 | self.band_widths = band_widths_from_specs(band_specs) 93 | self.n_bands = len(band_specs) 94 | self.emb_dim = emb_dim 95 | 96 | self.norm_fc_modules = nn.ModuleList( 97 | [ # type: ignore 98 | ( 99 | NormFC( 100 | emb_dim=emb_dim, 101 | bandwidth=bw, 102 | in_channel=in_channel, 103 | normalize_channel_independently=normalize_channel_independently, 104 | treat_channel_as_feature=treat_channel_as_feature, 105 | ) 106 | ) 107 | for bw in self.band_widths 108 | ] 109 | ) 110 | 111 | def forward(self, x: torch.Tensor): 112 | # x = complex spectrogram (batch, in_chan, n_freq, n_time) 113 | 114 | batch, in_chan, _, n_time = x.shape 115 | 116 | z = torch.zeros( 117 | size=(batch, self.n_bands, n_time, self.emb_dim), 118 | device=x.device 119 | ) 120 | 121 | xr = torch.view_as_real(x) # batch, in_chan, n_freq, n_time, 2 122 | xr = torch.permute( 123 | xr, 124 | (0, 3, 1, 4, 2) 125 | ) # batch, n_time, in_chan, 2, n_freq 126 | batch, n_time, in_chan, reim, band_width = xr.shape 127 | for i, nfm in enumerate(self.norm_fc_modules): 128 | # print(f"bandsplit/band{i:02d}") 129 | fstart, fend = self.band_specs[i] 130 | xb = xr[..., fstart:fend] 131 | # (batch, n_time, in_chan, reim, band_width) 132 | xb = torch.reshape(xb, (batch, n_time, in_chan, -1)) 133 | # (batch, n_time, in_chan, reim * band_width) 134 | # z.append(nfm(xb)) # (batch, n_time, emb_dim) 135 | z[:, i, :, :] = nfm(xb.contiguous()) 136 | 137 | # z = torch.stack(z, dim=1) 138 | 139 | return z 140 | -------------------------------------------------------------------------------- /mss/inference.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' 3 | 4 | import argparse 5 | import time 6 | import librosa 7 | from tqdm import tqdm 8 | import sys 9 | import os 10 | import glob 11 | import torch 12 | import numpy as np 13 | import soundfile as sf 14 | import torch.nn as nn 15 | from utils import demix_track, demix_track_demucs, get_model_from_config 16 | 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | 21 | def run_folder(model, args, config, device, verbose=False): 22 | start_time = time.time() 23 | model.eval() 24 | all_mixtures_path = glob.glob(args.input_folder + '/*.*') 25 | print('Total files found: {}'.format(len(all_mixtures_path))) 26 | 27 | instruments = config.training.instruments 28 | if config.training.target_instrument is not None: 29 | instruments = [config.training.target_instrument] 30 | 31 | if not os.path.isdir(args.store_dir): 32 | os.mkdir(args.store_dir) 33 | 34 | if not verbose: 35 | all_mixtures_path = tqdm(all_mixtures_path) 36 | 37 | for path in all_mixtures_path: 38 | if not verbose: 39 | all_mixtures_path.set_postfix({'track': os.path.basename(path)}) 40 | try: 41 | # mix, sr = sf.read(path) 42 | mix, sr = librosa.load(path, sr=44100, mono=False) 43 | except Exception as e: 44 | print('Can read track: {}'.format(path)) 45 | print('Error message: {}'.format(str(e))) 46 | continue 47 | 48 | # Convert mono to stereo if needed 49 | if len(mix.shape) == 1: 50 | mix = np.stack([mix, mix], axis=0) 51 | 52 | mix_orig = mix.copy() 53 | if 'normalize' in config.inference: 54 | if config.inference['normalize'] is True: 55 | mono = mix.mean(0) 56 | mean = mono.mean() 57 | std = mono.std() 58 | mix = (mix - mean) / std 59 | 60 | mixture = torch.tensor(mix, dtype=torch.float32) 61 | if args.model_type == 'htdemucs': 62 | res = demix_track_demucs(config, model, mixture, device) 63 | 
else: 64 | res = demix_track(config, model, mixture, device) 65 | 66 | for instr in instruments: 67 | estimates = res[instr].T 68 | if 'normalize' in config.inference: 69 | if config.inference['normalize'] is True: 70 | estimates = estimates * std + mean 71 | sf.write("{}/{}_{}.wav".format(args.store_dir, os.path.basename(path)[:-4], instr), estimates, sr, subtype='FLOAT') 72 | 73 | if 'vocals' in instruments and args.extract_instrumental: 74 | instrum_file_name = "{}/{}_{}.wav".format(args.store_dir, os.path.basename(path)[:-4], 'instrumental') 75 | estimates = res['vocals'].T 76 | if 'normalize' in config.inference: 77 | if config.inference['normalize'] is True: 78 | estimates = estimates * std + mean 79 | sf.write(instrum_file_name, mix_orig.T - estimates, sr, subtype='FLOAT') 80 | 81 | time.sleep(1) 82 | print("Elapsed time: {:.2f} sec".format(time.time() - start_time)) 83 | 84 | 85 | def proc_folder(args): 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument("--model_type", type=str, default='mdx23c', help="One of mdx23c, htdemucs, segm_models, mel_band_roformer, bs_roformer, swin_upernet, bandit") 88 | parser.add_argument("--config_path", type=str, help="path to config file") 89 | parser.add_argument("--start_check_point", type=str, default='', help="Initial checkpoint to valid weights") 90 | parser.add_argument("--input_folder", type=str, help="folder with mixtures to process") 91 | parser.add_argument("--store_dir", default="", type=str, help="path to store results as wav file") 92 | parser.add_argument("--device_ids", nargs='+', type=int, default=0, help='list of gpu ids') 93 | parser.add_argument("--extract_instrumental", action='store_true', help="invert vocals to get instrumental if provided") 94 | if args is None: 95 | args = parser.parse_args() 96 | else: 97 | args = parser.parse_args(args) 98 | 99 | torch.backends.cudnn.benchmark = True 100 | 101 | model, config = get_model_from_config(args.model_type, args.config_path) 102 | if args.start_check_point != '': 103 | print('Start from checkpoint: {}'.format(args.start_check_point)) 104 | state_dict = torch.load(args.start_check_point) 105 | if args.model_type == 'htdemucs': 106 | # Fix for htdemucs pround etrained models 107 | if 'state' in state_dict: 108 | state_dict = state_dict['state'] 109 | model.load_state_dict(state_dict) 110 | print("Instruments: {}".format(config.training.instruments)) 111 | 112 | if torch.cuda.is_available(): 113 | device_ids = args.device_ids 114 | if type(device_ids)==int: 115 | device = torch.device(f'cuda:{device_ids}') 116 | model = model.to(device) 117 | else: 118 | device = torch.device(f'cuda:{device_ids[0]}') 119 | model = nn.DataParallel(model, device_ids=device_ids).to(device) 120 | else: 121 | device = 'cpu' 122 | print('CUDA is not avilable. Run inference on CPU. 
It will be very slow...') 123 | model = model.to(device) 124 | 125 | run_folder(model, args, config, device, verbose=False) 126 | 127 | 128 | if __name__ == "__main__": 129 | proc_folder(None) 130 | -------------------------------------------------------------------------------- /mss/models/bandit/core/loss/snr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.loss import _Loss 3 | from torch.nn import functional as F 4 | 5 | class SignalNoisePNormRatio(_Loss): 6 | def __init__( 7 | self, 8 | p: float = 1.0, 9 | scale_invariant: bool = False, 10 | zero_mean: bool = False, 11 | take_log: bool = True, 12 | reduction: str = "mean", 13 | EPS: float = 1e-3, 14 | ) -> None: 15 | assert reduction != "sum", NotImplementedError 16 | super().__init__(reduction=reduction) 17 | assert not zero_mean 18 | 19 | self.p = p 20 | 21 | self.EPS = EPS 22 | self.take_log = take_log 23 | 24 | self.scale_invariant = scale_invariant 25 | 26 | def forward( 27 | self, 28 | est_target: torch.Tensor, 29 | target: torch.Tensor 30 | ) -> torch.Tensor: 31 | 32 | target_ = target 33 | if self.scale_invariant: 34 | ndim = target.ndim 35 | dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True) 36 | s_target_energy = ( 37 | torch.sum(target * torch.conj(target), dim=-1, keepdim=True) 38 | ) 39 | 40 | if ndim > 2: 41 | dot = torch.sum(dot, dim=list(range(1, ndim)), keepdim=True) 42 | s_target_energy = torch.sum(s_target_energy, dim=list(range(1, ndim)), keepdim=True) 43 | 44 | target_scaler = (dot + 1e-8) / (s_target_energy + 1e-8) 45 | target = target_ * target_scaler 46 | 47 | if torch.is_complex(est_target): 48 | est_target = torch.view_as_real(est_target) 49 | target = torch.view_as_real(target) 50 | 51 | 52 | batch_size = est_target.shape[0] 53 | est_target = est_target.reshape(batch_size, -1) 54 | target = target.reshape(batch_size, -1) 55 | # target_ = target_.reshape(batch_size, -1) 56 | 57 | if self.p == 1: 58 | e_error = torch.abs(est_target-target).mean(dim=-1) 59 | e_target = torch.abs(target).mean(dim=-1) 60 | elif self.p == 2: 61 | e_error = torch.square(est_target-target).mean(dim=-1) 62 | e_target = torch.square(target).mean(dim=-1) 63 | else: 64 | raise NotImplementedError 65 | 66 | if self.take_log: 67 | loss = 10*(torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS)) 68 | else: 69 | loss = (e_error + self.EPS)/(e_target + self.EPS) 70 | 71 | if self.reduction == "mean": 72 | loss = loss.mean() 73 | elif self.reduction == "sum": 74 | loss = loss.sum() 75 | 76 | return loss 77 | 78 | 79 | 80 | class MultichannelSingleSrcNegSDR(_Loss): 81 | def __init__( 82 | self, 83 | sdr_type: str, 84 | p: float = 2.0, 85 | zero_mean: bool = True, 86 | take_log: bool = True, 87 | reduction: str = "mean", 88 | EPS: float = 1e-8, 89 | ) -> None: 90 | assert reduction != "sum", NotImplementedError 91 | super().__init__(reduction=reduction) 92 | 93 | assert sdr_type in ["snr", "sisdr", "sdsdr"] 94 | self.sdr_type = sdr_type 95 | self.zero_mean = zero_mean 96 | self.take_log = take_log 97 | self.EPS = 1e-8 98 | 99 | self.p = p 100 | 101 | def forward( 102 | self, 103 | est_target: torch.Tensor, 104 | target: torch.Tensor 105 | ) -> torch.Tensor: 106 | if target.size() != est_target.size() or target.ndim != 3: 107 | raise TypeError( 108 | f"Inputs must be of shape [batch, time], got {target.size()} and {est_target.size()} instead" 109 | ) 110 | # Step 1. 
Zero-mean norm 111 | if self.zero_mean: 112 | mean_source = torch.mean(target, dim=[1, 2], keepdim=True) 113 | mean_estimate = torch.mean(est_target, dim=[1, 2], keepdim=True) 114 | target = target - mean_source 115 | est_target = est_target - mean_estimate 116 | # Step 2. Pair-wise SI-SDR. 117 | if self.sdr_type in ["sisdr", "sdsdr"]: 118 | # [batch, 1] 119 | dot = torch.sum(est_target * target, dim=[1, 2], keepdim=True) 120 | # [batch, 1] 121 | s_target_energy = ( 122 | torch.sum(target ** 2, dim=[1, 2], keepdim=True) + self.EPS 123 | ) 124 | # [batch, time] 125 | scaled_target = dot * target / s_target_energy 126 | else: 127 | # [batch, time] 128 | scaled_target = target 129 | if self.sdr_type in ["sdsdr", "snr"]: 130 | e_noise = est_target - target 131 | else: 132 | e_noise = est_target - scaled_target 133 | # [batch] 134 | 135 | if self.p == 2.0: 136 | losses = torch.sum(scaled_target ** 2, dim=[1, 2]) / ( 137 | torch.sum(e_noise ** 2, dim=[1, 2]) + self.EPS 138 | ) 139 | else: 140 | losses = torch.norm(scaled_target, p=self.p, dim=[1, 2]) / ( 141 | torch.linalg.vector_norm(e_noise, p=self.p, dim=[1, 2]) + self.EPS 142 | ) 143 | if self.take_log: 144 | losses = 10 * torch.log10(losses + self.EPS) 145 | losses = losses.mean() if self.reduction == "mean" else losses 146 | return -losses 147 | -------------------------------------------------------------------------------- /mss/configs/config_musdb18_mdx23c.yaml: -------------------------------------------------------------------------------- 1 | audio: 2 | chunk_size: 261120 3 | dim_f: 4096 4 | dim_t: 256 5 | hop_length: 1024 6 | n_fft: 8192 7 | num_channels: 2 8 | sample_rate: 44100 9 | min_mean_abs: 0.001 10 | 11 | model: 12 | act: gelu 13 | bottleneck_factor: 4 14 | growth: 128 15 | norm: InstanceNorm 16 | num_blocks_per_scale: 2 17 | num_channels: 128 18 | num_scales: 5 19 | num_subbands: 4 20 | scale: 21 | - 2 22 | - 2 23 | 24 | training: 25 | batch_size: 6 26 | gradient_accumulation_steps: 1 27 | grad_clip: 0 28 | instruments: 29 | - vocals 30 | - bass 31 | - drums 32 | - other 33 | lr: 9.0e-05 34 | patience: 2 35 | reduce_factor: 0.95 36 | target_instrument: null 37 | num_epochs: 1000 38 | num_steps: 1000 39 | q: 0.95 40 | coarse_loss_clip: true 41 | ema_momentum: 0.999 42 | optimizer: adam 43 | other_fix: false # it's needed for checking on multisong dataset if other is actually instrumental 44 | use_amp: true # enable or disable usage of mixed precision (float16) - usually it must be true 45 | 46 | augmentations: 47 | enable: true # enable or disable all augmentations (to fast disable if needed) 48 | loudness: true # randomly change loudness of each stem on the range (loudness_min; loudness_max) 49 | loudness_min: 0.5 50 | loudness_max: 1.5 51 | mixup: true # mix several stems of same type with some probability (only works for dataset types: 1, 2, 3) 52 | mixup_probs: !!python/tuple # 2 additional stems of the same type (1st with prob 0.2, 2nd with prob 0.02) 53 | - 0.2 54 | - 0.02 55 | mixup_loudness_min: 0.5 56 | mixup_loudness_max: 1.5 57 | 58 | # apply mp3 compression to mixture only (emulate downloading mp3 from internet) 59 | mp3_compression_on_mixture: 0.01 60 | mp3_compression_on_mixture_bitrate_min: 32 61 | mp3_compression_on_mixture_bitrate_max: 320 62 | mp3_compression_on_mixture_backend: "lameenc" 63 | 64 | all: 65 | channel_shuffle: 0.5 # Set 0 or lower to disable 66 | random_inverse: 0.1 # inverse track (better lower probability) 67 | random_polarity: 0.5 # polarity change (multiply waveform to -1) 68 | 
mp3_compression: 0.01 69 | mp3_compression_min_bitrate: 32 70 | mp3_compression_max_bitrate: 320 71 | mp3_compression_backend: "lameenc" 72 | 73 | # pedalboard reverb block 74 | pedalboard_reverb: 0.01 75 | pedalboard_reverb_room_size_min: 0.1 76 | pedalboard_reverb_room_size_max: 0.9 77 | pedalboard_reverb_damping_min: 0.1 78 | pedalboard_reverb_damping_max: 0.9 79 | pedalboard_reverb_wet_level_min: 0.1 80 | pedalboard_reverb_wet_level_max: 0.9 81 | pedalboard_reverb_dry_level_min: 0.1 82 | pedalboard_reverb_dry_level_max: 0.9 83 | pedalboard_reverb_width_min: 0.9 84 | pedalboard_reverb_width_max: 1.0 85 | 86 | # pedalboard chorus block 87 | pedalboard_chorus: 0.01 88 | pedalboard_chorus_rate_hz_min: 1.0 89 | pedalboard_chorus_rate_hz_max: 7.0 90 | pedalboard_chorus_depth_min: 0.25 91 | pedalboard_chorus_depth_max: 0.95 92 | pedalboard_chorus_centre_delay_ms_min: 3 93 | pedalboard_chorus_centre_delay_ms_max: 10 94 | pedalboard_chorus_feedback_min: 0.0 95 | pedalboard_chorus_feedback_max: 0.5 96 | pedalboard_chorus_mix_min: 0.1 97 | pedalboard_chorus_mix_max: 0.9 98 | 99 | # pedalboard phazer block 100 | pedalboard_phazer: 0.01 101 | pedalboard_phazer_rate_hz_min: 1.0 102 | pedalboard_phazer_rate_hz_max: 10.0 103 | pedalboard_phazer_depth_min: 0.25 104 | pedalboard_phazer_depth_max: 0.95 105 | pedalboard_phazer_centre_frequency_hz_min: 200 106 | pedalboard_phazer_centre_frequency_hz_max: 12000 107 | pedalboard_phazer_feedback_min: 0.0 108 | pedalboard_phazer_feedback_max: 0.5 109 | pedalboard_phazer_mix_min: 0.1 110 | pedalboard_phazer_mix_max: 0.9 111 | 112 | # pedalboard distortion block 113 | pedalboard_distortion: 0.01 114 | pedalboard_distortion_drive_db_min: 1.0 115 | pedalboard_distortion_drive_db_max: 25.0 116 | 117 | # pedalboard pitch shift block 118 | pedalboard_pitch_shift: 0.01 119 | pedalboard_pitch_shift_semitones_min: -7 120 | pedalboard_pitch_shift_semitones_max: 7 121 | 122 | # pedalboard resample block 123 | pedalboard_resample: 0.01 124 | pedalboard_resample_target_sample_rate_min: 4000 125 | pedalboard_resample_target_sample_rate_max: 44100 126 | 127 | # pedalboard bitcrash block 128 | pedalboard_bitcrash: 0.01 129 | pedalboard_bitcrash_bit_depth_min: 4 130 | pedalboard_bitcrash_bit_depth_max: 16 131 | 132 | # pedalboard mp3 compressor block 133 | pedalboard_mp3_compressor: 0.01 134 | pedalboard_mp3_compressor_pedalboard_mp3_compressor_min: 0 135 | pedalboard_mp3_compressor_pedalboard_mp3_compressor_max: 9.999 136 | 137 | vocals: 138 | pitch_shift: 0.1 139 | pitch_shift_min_semitones: -5 140 | pitch_shift_max_semitones: 5 141 | seven_band_parametric_eq: 0.25 142 | seven_band_parametric_eq_min_gain_db: -9 143 | seven_band_parametric_eq_max_gain_db: 9 144 | tanh_distortion: 0.1 145 | tanh_distortion_min: 0.1 146 | tanh_distortion_max: 0.7 147 | bass: 148 | pitch_shift: 0.1 149 | pitch_shift_min_semitones: -2 150 | pitch_shift_max_semitones: 2 151 | seven_band_parametric_eq: 0.25 152 | seven_band_parametric_eq_min_gain_db: -3 153 | seven_band_parametric_eq_max_gain_db: 6 154 | tanh_distortion: 0.2 155 | tanh_distortion_min: 0.1 156 | tanh_distortion_max: 0.5 157 | drums: 158 | pitch_shift: 0.33 159 | pitch_shift_min_semitones: -5 160 | pitch_shift_max_semitones: 5 161 | seven_band_parametric_eq: 0.25 162 | seven_band_parametric_eq_min_gain_db: -9 163 | seven_band_parametric_eq_max_gain_db: 9 164 | tanh_distortion: 0.33 165 | tanh_distortion_min: 0.1 166 | tanh_distortion_max: 0.6 167 | other: 168 | pitch_shift: 0.1 169 | pitch_shift_min_semitones: -4 170 | 
pitch_shift_max_semitones: 4 171 | gaussian_noise: 0.1 172 | gaussian_noise_min_amplitude: 0.001 173 | gaussian_noise_max_amplitude: 0.015 174 | time_stretch: 0.01 175 | time_stretch_min_rate: 0.8 176 | time_stretch_max_rate: 1.25 177 | 178 | 179 | inference: 180 | batch_size: 1 181 | dim_t: 256 182 | num_overlap: 4 -------------------------------------------------------------------------------- /mss/ensemble.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' 3 | 4 | import os 5 | import librosa 6 | import soundfile as sf 7 | import numpy as np 8 | import argparse 9 | 10 | 11 | def stft(wave, nfft, hl): 12 | wave_left = np.asfortranarray(wave[0]) 13 | wave_right = np.asfortranarray(wave[1]) 14 | spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl) 15 | spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl) 16 | spec = np.asfortranarray([spec_left, spec_right]) 17 | return spec 18 | 19 | 20 | def istft(spec, hl, length): 21 | spec_left = np.asfortranarray(spec[0]) 22 | spec_right = np.asfortranarray(spec[1]) 23 | wave_left = librosa.istft(spec_left, hop_length=hl, length=length) 24 | wave_right = librosa.istft(spec_right, hop_length=hl, length=length) 25 | wave = np.asfortranarray([wave_left, wave_right]) 26 | return wave 27 | 28 | 29 | def absmax(a, *, axis): 30 | dims = list(a.shape) 31 | dims.pop(axis) 32 | indices = np.ogrid[tuple(slice(0, d) for d in dims)] 33 | argmax = np.abs(a).argmax(axis=axis) 34 | indices.insert((len(a.shape) + axis) % len(a.shape), argmax) 35 | return a[tuple(indices)] 36 | 37 | 38 | def absmin(a, *, axis): 39 | dims = list(a.shape) 40 | dims.pop(axis) 41 | indices = np.ogrid[tuple(slice(0, d) for d in dims)] 42 | argmax = np.abs(a).argmin(axis=axis) 43 | indices.insert((len(a.shape) + axis) % len(a.shape), argmax) 44 | return a[tuple(indices)] 45 | 46 | 47 | def lambda_max(arr, axis=None, key=None, keepdims=False): 48 | idxs = np.argmax(key(arr), axis) 49 | if axis is not None: 50 | idxs = np.expand_dims(idxs, axis) 51 | result = np.take_along_axis(arr, idxs, axis) 52 | if not keepdims: 53 | result = np.squeeze(result, axis=axis) 54 | return result 55 | else: 56 | return arr.flatten()[idxs] 57 | 58 | 59 | def lambda_min(arr, axis=None, key=None, keepdims=False): 60 | idxs = np.argmin(key(arr), axis) 61 | if axis is not None: 62 | idxs = np.expand_dims(idxs, axis) 63 | result = np.take_along_axis(arr, idxs, axis) 64 | if not keepdims: 65 | result = np.squeeze(result, axis=axis) 66 | return result 67 | else: 68 | return arr.flatten()[idxs] 69 | 70 | 71 | def average_waveforms(pred_track, weights, algorithm): 72 | """ 73 | :param pred_track: shape = (num, channels, length) 74 | :param weights: shape = (num, ) 75 | :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft 76 | :return: averaged waveform in shape (channels, length) 77 | """ 78 | 79 | pred_track = np.array(pred_track) 80 | final_length = pred_track.shape[-1] 81 | 82 | mod_track = [] 83 | for i in range(pred_track.shape[0]): 84 | if algorithm == 'avg_wave': 85 | mod_track.append(pred_track[i] * weights[i]) 86 | elif algorithm in ['median_wave', 'min_wave', 'max_wave']: 87 | mod_track.append(pred_track[i]) 88 | elif algorithm in ['avg_fft', 'min_fft', 'max_fft', 'median_fft']: 89 | spec = stft(pred_track[i], nfft=2048, hl=1024) 90 | if algorithm in ['avg_fft']: 91 | mod_track.append(spec * weights[i]) 92 | 
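# Summary of the ensemble modes handled above and below: the *_wave modes combine the raw
# waveforms directly (weighted mean, median, or per-sample min/max by absolute value),
# while the *_fft modes first take an STFT (n_fft=2048, hop=1024), combine the complex
# spectrograms the same way, and invert the result back to a waveform.
# Illustrative command line (file names are placeholders):
#   python ensemble.py --files vocals_a.wav vocals_b.wav --type avg_wave --weights 2 1 --output res.wav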
else: 93 | mod_track.append(spec) 94 | pred_track = np.array(mod_track) 95 | 96 | if algorithm in ['avg_wave']: 97 | pred_track = pred_track.sum(axis=0) 98 | pred_track /= np.array(weights).sum().T 99 | elif algorithm in ['median_wave']: 100 | pred_track = np.median(pred_track, axis=0) 101 | elif algorithm in ['min_wave']: 102 | pred_track = np.array(pred_track) 103 | pred_track = lambda_min(pred_track, axis=0, key=np.abs) 104 | elif algorithm in ['max_wave']: 105 | pred_track = np.array(pred_track) 106 | pred_track = lambda_max(pred_track, axis=0, key=np.abs) 107 | elif algorithm in ['avg_fft']: 108 | pred_track = pred_track.sum(axis=0) 109 | pred_track /= np.array(weights).sum() 110 | pred_track = istft(pred_track, 1024, final_length) 111 | elif algorithm in ['min_fft']: 112 | pred_track = np.array(pred_track) 113 | pred_track = lambda_min(pred_track, axis=0, key=np.abs) 114 | pred_track = istft(pred_track, 1024, final_length) 115 | elif algorithm in ['max_fft']: 116 | pred_track = np.array(pred_track) 117 | pred_track = absmax(pred_track, axis=0) 118 | pred_track = istft(pred_track, 1024, final_length) 119 | elif algorithm in ['median_fft']: 120 | pred_track = np.array(pred_track) 121 | pred_track = np.median(pred_track, axis=0) 122 | pred_track = istft(pred_track, 1024, final_length) 123 | return pred_track 124 | 125 | 126 | def ensemble_files(args): 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument("--files", type=str, required=True, nargs='+', help="Path to all audio-files to ensemble") 129 | parser.add_argument("--type", type=str, default='avg_wave', help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft") 130 | parser.add_argument("--weights", type=float, nargs='+', help="Weights to create ensemble. Number of weights must be equal to number of files") 131 | parser.add_argument("--output", default="res.wav", type=str, help="Path to wav file where ensemble result will be stored") 132 | if args is None: 133 | args = parser.parse_args() 134 | else: 135 | args = parser.parse_args(args) 136 | 137 | print('Ensemble type: {}'.format(args.type)) 138 | print('Number of input files: {}'.format(len(args.files))) 139 | if args.weights is not None: 140 | weights = args.weights 141 | else: 142 | weights = np.ones(len(args.files)) 143 | print('Weights: {}'.format(weights)) 144 | print('Output file: {}'.format(args.output)) 145 | data = [] 146 | for f in args.files: 147 | if not os.path.isfile(f): 148 | print('Error. Can\'t find file: {}. 
Check paths.'.format(f)) 149 | exit() 150 | print('Reading file: {}'.format(f)) 151 | wav, sr = librosa.load(f, sr=None, mono=False) 152 | # wav, sr = sf.read(f) 153 | print("Waveform shape: {} sample rate: {}".format(wav.shape, sr)) 154 | data.append(wav) 155 | data = np.array(data) 156 | res = average_waveforms(data, weights, args.type) 157 | print('Result shape: {}'.format(res.shape)) 158 | sf.write(args.output, res.T, sr, 'FLOAT') 159 | 160 | 161 | if __name__ == "__main__": 162 | ensemble_files(None) 163 | -------------------------------------------------------------------------------- /mss/models/upernet_swin_transformers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | import torch.nn as nn 4 | from transformers import UperNetForSemanticSegmentation 5 | 6 | 7 | class STFT: 8 | def __init__(self, config): 9 | self.n_fft = config.n_fft 10 | self.hop_length = config.hop_length 11 | self.window = torch.hann_window(window_length=self.n_fft, periodic=True) 12 | self.dim_f = config.dim_f 13 | 14 | def __call__(self, x): 15 | window = self.window.to(x.device) 16 | batch_dims = x.shape[:-2] 17 | c, t = x.shape[-2:] 18 | x = x.reshape([-1, t]) 19 | x = torch.stft( 20 | x, 21 | n_fft=self.n_fft, 22 | hop_length=self.hop_length, 23 | window=window, 24 | center=True, 25 | return_complex=True 26 | ) 27 | x = torch.view_as_real(x) 28 | x = x.permute([0, 3, 1, 2]) 29 | x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) 30 | return x[..., :self.dim_f, :] 31 | 32 | def inverse(self, x): 33 | window = self.window.to(x.device) 34 | batch_dims = x.shape[:-3] 35 | c, f, t = x.shape[-3:] 36 | n = self.n_fft // 2 + 1 37 | f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) 38 | x = torch.cat([x, f_pad], -2) 39 | x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) 40 | x = x.permute([0, 2, 3, 1]) 41 | x = x[..., 0] + x[..., 1] * 1.j 42 | x = torch.istft( 43 | x, 44 | n_fft=self.n_fft, 45 | hop_length=self.hop_length, 46 | window=window, 47 | center=True 48 | ) 49 | x = x.reshape([*batch_dims, 2, -1]) 50 | return x 51 | 52 | 53 | def get_norm(norm_type): 54 | def norm(c, norm_type): 55 | if norm_type == 'BatchNorm': 56 | return nn.BatchNorm2d(c) 57 | elif norm_type == 'InstanceNorm': 58 | return nn.InstanceNorm2d(c, affine=True) 59 | elif 'GroupNorm' in norm_type: 60 | g = int(norm_type.replace('GroupNorm', '')) 61 | return nn.GroupNorm(num_groups=g, num_channels=c) 62 | else: 63 | return nn.Identity() 64 | 65 | return partial(norm, norm_type=norm_type) 66 | 67 | 68 | def get_act(act_type): 69 | if act_type == 'gelu': 70 | return nn.GELU() 71 | elif act_type == 'relu': 72 | return nn.ReLU() 73 | elif act_type[:3] == 'elu': 74 | alpha = float(act_type.replace('elu', '')) 75 | return nn.ELU(alpha) 76 | else: 77 | raise Exception 78 | 79 | 80 | class Upscale(nn.Module): 81 | def __init__(self, in_c, out_c, scale, norm, act): 82 | super().__init__() 83 | self.conv = nn.Sequential( 84 | norm(in_c), 85 | act, 86 | nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) 87 | ) 88 | 89 | def forward(self, x): 90 | return self.conv(x) 91 | 92 | 93 | class Downscale(nn.Module): 94 | def __init__(self, in_c, out_c, scale, norm, act): 95 | super().__init__() 96 | self.conv = nn.Sequential( 97 | norm(in_c), 98 | act, 99 | nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) 
100 | ) 101 | 102 | def forward(self, x): 103 | return self.conv(x) 104 | 105 | 106 | class TFC_TDF(nn.Module): 107 | def __init__(self, in_c, c, l, f, bn, norm, act): 108 | super().__init__() 109 | 110 | self.blocks = nn.ModuleList() 111 | for i in range(l): 112 | block = nn.Module() 113 | 114 | block.tfc1 = nn.Sequential( 115 | norm(in_c), 116 | act, 117 | nn.Conv2d(in_c, c, 3, 1, 1, bias=False), 118 | ) 119 | block.tdf = nn.Sequential( 120 | norm(c), 121 | act, 122 | nn.Linear(f, f // bn, bias=False), 123 | norm(c), 124 | act, 125 | nn.Linear(f // bn, f, bias=False), 126 | ) 127 | block.tfc2 = nn.Sequential( 128 | norm(c), 129 | act, 130 | nn.Conv2d(c, c, 3, 1, 1, bias=False), 131 | ) 132 | block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False) 133 | 134 | self.blocks.append(block) 135 | in_c = c 136 | 137 | def forward(self, x): 138 | for block in self.blocks: 139 | s = block.shortcut(x) 140 | x = block.tfc1(x) 141 | x = x + block.tdf(x) 142 | x = block.tfc2(x) 143 | x = x + s 144 | return x 145 | 146 | 147 | class Swin_UperNet_Model(nn.Module): 148 | def __init__(self, config): 149 | super().__init__() 150 | self.config = config 151 | 152 | act = get_act(act_type=config.model.act) 153 | 154 | self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) 155 | self.num_subbands = config.model.num_subbands 156 | 157 | dim_c = self.num_subbands * config.audio.num_channels * 2 158 | c = config.model.num_channels 159 | f = config.audio.dim_f // self.num_subbands 160 | 161 | self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) 162 | 163 | self.swin_upernet_model = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-swin-large") 164 | 165 | self.swin_upernet_model.auxiliary_head.classifier = nn.Conv2d(256, c, kernel_size=(1, 1), stride=(1, 1)) 166 | self.swin_upernet_model.decode_head.classifier = nn.Conv2d(512, c, kernel_size=(1, 1), stride=(1, 1)) 167 | self.swin_upernet_model.backbone.embeddings.patch_embeddings.projection = nn.Conv2d(c, 192, kernel_size=(4, 4), stride=(4, 4)) 168 | 169 | self.final_conv = nn.Sequential( 170 | nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), 171 | act, 172 | nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) 173 | ) 174 | 175 | self.stft = STFT(config.audio) 176 | 177 | def cac2cws(self, x): 178 | k = self.num_subbands 179 | b, c, f, t = x.shape 180 | x = x.reshape(b, c, k, f // k, t) 181 | x = x.reshape(b, c * k, f // k, t) 182 | return x 183 | 184 | def cws2cac(self, x): 185 | k = self.num_subbands 186 | b, c, f, t = x.shape 187 | x = x.reshape(b, c // k, k, f, t) 188 | x = x.reshape(b, c // k, f * k, t) 189 | return x 190 | 191 | def forward(self, x): 192 | 193 | x = self.stft(x) 194 | 195 | mix = x = self.cac2cws(x) 196 | 197 | first_conv_out = x = self.first_conv(x) 198 | 199 | x = x.transpose(-1, -2) 200 | 201 | x = self.swin_upernet_model(x).logits 202 | 203 | x = x.transpose(-1, -2) 204 | 205 | x = x * first_conv_out # reduce artifacts 206 | 207 | x = self.final_conv(torch.cat([mix, x], 1)) 208 | 209 | x = self.cws2cac(x) 210 | 211 | if self.num_target_instruments > 1: 212 | b, c, f, t = x.shape 213 | x = x.reshape(b, self.num_target_instruments, -1, f, t) 214 | 215 | x = self.stft.inverse(x) 216 | return x 217 | 218 | 219 | if __name__ == "__main__": 220 | model = UperNetForSemanticSegmentation.from_pretrained("./results/", ignore_mismatched_sizes=True) 221 | print(model) 222 | print(model.auxiliary_head.classifier) 223 | print(model.decode_head.classifier) 224 | 
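# The quick check below loads a plain UperNet checkpoint from ./results/, pushes a dummy
# (2, 16, 512, 512) tensor through it and prints the logits shape; it exercises only the
# segmentation backbone, not the full STFT -> sub-band split -> UperNet -> iSTFT pipeline
# of Swin_UperNet_Model defined above.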
225 | x = torch.zeros((2, 16, 512, 512), dtype=torch.float32) 226 | res = model(x) 227 | print(res.logits.shape) 228 | model.save_pretrained('./results/') -------------------------------------------------------------------------------- /mss/models/scnet_unofficial/modules/dualpath_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as Func 4 | 5 | class RMSNorm(nn.Module): 6 | def __init__(self, dim): 7 | super().__init__() 8 | self.scale = dim ** 0.5 9 | self.gamma = nn.Parameter(torch.ones(dim)) 10 | 11 | def forward(self, x): 12 | return Func.normalize(x, dim=-1) * self.scale * self.gamma 13 | 14 | 15 | class MambaModule(nn.Module): 16 | def __init__(self, d_model, d_state, d_conv, d_expand): 17 | super().__init__() 18 | self.norm = RMSNorm(dim=d_model) 19 | self.mamba = Mamba( 20 | d_model=d_model, 21 | d_state=d_state, 22 | d_conv=d_conv, 23 | d_expand=d_expand 24 | ) 25 | 26 | def forward(self, x): 27 | x = x + self.mamba(self.norm(x)) 28 | return x 29 | 30 | 31 | class RNNModule(nn.Module): 32 | """ 33 | RNNModule class implements a recurrent neural network module with LSTM cells. 34 | 35 | Args: 36 | - input_dim (int): Dimensionality of the input features. 37 | - hidden_dim (int): Dimensionality of the hidden state of the LSTM. 38 | - bidirectional (bool, optional): If True, uses bidirectional LSTM. Defaults to True. 39 | 40 | Shapes: 41 | - Input: (B, T, D) where 42 | B is batch size, 43 | T is sequence length, 44 | D is input dimensionality. 45 | - Output: (B, T, D) where 46 | B is batch size, 47 | T is sequence length, 48 | D is input dimensionality. 49 | """ 50 | 51 | def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True): 52 | """ 53 | Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag. 54 | """ 55 | super().__init__() 56 | self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim) 57 | self.rnn = nn.LSTM( 58 | input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional 59 | ) 60 | self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim) 61 | 62 | def forward(self, x: torch.Tensor) -> torch.Tensor: 63 | """ 64 | Performs forward pass through the RNNModule. 65 | 66 | Args: 67 | - x (torch.Tensor): Input tensor of shape (B, T, D). 68 | 69 | Returns: 70 | - torch.Tensor: Output tensor of shape (B, T, D). 71 | """ 72 | x = x.transpose(1, 2) 73 | x = self.groupnorm(x) 74 | x = x.transpose(1, 2) 75 | 76 | x, (hidden, _) = self.rnn(x) 77 | x = self.fc(x) 78 | return x 79 | 80 | 81 | class RFFTModule(nn.Module): 82 | """ 83 | RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT) 84 | or its inverse on input tensors. 85 | 86 | Args: 87 | - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False. 88 | 89 | Shapes: 90 | - Input: (B, F, T, D) where 91 | B is batch size, 92 | F is the number of features, 93 | T is sequence length, 94 | D is input dimensionality. 95 | - Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT. 96 | (B, F, T, D // 2, 2) if performing inverse FFT. 97 | """ 98 | 99 | def __init__(self, inverse: bool = False): 100 | """ 101 | Initializes RFFTModule with inverse flag. 
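        With inverse=False the forward pass applies a real FFT along the time axis,
        shrinking it to T // 2 + 1 bins and doubling the last dimension by laying real and
        imaginary parts side by side; with inverse=True it undoes that packing with an
        inverse real FFT back to `time_dim` frames.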
102 | """ 103 | super().__init__() 104 | self.inverse = inverse 105 | 106 | def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor: 107 | """ 108 | Performs forward or inverse FFT on the input tensor x. 109 | 110 | Args: 111 | - x (torch.Tensor): Input tensor of shape (B, F, T, D). 112 | - time_dim (int): Input size of time dimension. 113 | 114 | Returns: 115 | - torch.Tensor: Output tensor after FFT or its inverse operation. 116 | """ 117 | dtype = x.dtype 118 | B, F, T, D = x.shape 119 | 120 | # RuntimeError: cuFFT only supports dimensions whose sizes are powers of two when computing in half precision 121 | x = x.float() 122 | 123 | if not self.inverse: 124 | x = torch.fft.rfft(x, dim=2) 125 | x = torch.view_as_real(x) 126 | x = x.reshape(B, F, T // 2 + 1, D * 2) 127 | else: 128 | x = x.reshape(B, F, T, D // 2, 2) 129 | x = torch.view_as_complex(x) 130 | x = torch.fft.irfft(x, n=time_dim, dim=2) 131 | 132 | x = x.to(dtype) 133 | return x 134 | 135 | def extra_repr(self) -> str: 136 | """ 137 | Returns extra representation string with module's configuration. 138 | """ 139 | return f"inverse={self.inverse}" 140 | 141 | 142 | class DualPathRNN(nn.Module): 143 | """ 144 | DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule. 145 | 146 | Args: 147 | - n_layers (int): Number of layers in the network. 148 | - input_dim (int): Dimensionality of the input features. 149 | - hidden_dim (int): Dimensionality of the hidden state of the RNNModule. 150 | 151 | Shapes: 152 | - Input: (B, F, T, D) where 153 | B is batch size, 154 | F is the number of features (frequency dimension), 155 | T is sequence length (time dimension), 156 | D is input dimensionality (channel dimension). 157 | - Output: (B, F, T, D) where 158 | B is batch size, 159 | F is the number of features (frequency dimension), 160 | T is sequence length (time dimension), 161 | D is input dimensionality (channel dimension). 162 | """ 163 | 164 | def __init__( 165 | self, 166 | n_layers: int, 167 | input_dim: int, 168 | hidden_dim: int, 169 | 170 | use_mamba: bool = False, 171 | d_state: int = 16, 172 | d_conv: int = 4, 173 | d_expand: int = 2 174 | ): 175 | """ 176 | Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension. 177 | """ 178 | super().__init__() 179 | 180 | if use_mamba: 181 | from mamba_ssm.modules.mamba_simple import Mamba 182 | net = MambaModule 183 | dkwargs = {"d_model": input_dim, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand} 184 | ukwargs = {"d_model": input_dim * 2, "d_state": d_state, "d_conv": d_conv, "d_expand": d_expand * 2} 185 | else: 186 | net = RNNModule 187 | dkwargs = {"input_dim": input_dim, "hidden_dim": hidden_dim} 188 | ukwargs = {"input_dim": input_dim * 2, "hidden_dim": hidden_dim * 2} 189 | 190 | self.layers = nn.ModuleList() 191 | for i in range(1, n_layers + 1): 192 | kwargs = dkwargs if i % 2 == 1 else ukwargs 193 | layer = nn.ModuleList([ 194 | net(**kwargs), 195 | net(**kwargs), 196 | RFFTModule(inverse=(i % 2 == 0)), 197 | ]) 198 | self.layers.append(layer) 199 | 200 | def forward(self, x: torch.Tensor) -> torch.Tensor: 201 | """ 202 | Performs forward pass through the DualPathRNN. 203 | 204 | Args: 205 | - x (torch.Tensor): Input tensor of shape (B, F, T, D). 206 | 207 | Returns: 208 | - torch.Tensor: Output tensor of shape (B, F, T, D). 
209 | """ 210 | 211 | time_dim = x.shape[2] 212 | 213 | for time_layer, freq_layer, rfft_layer in self.layers: 214 | B, F, T, D = x.shape 215 | 216 | x = x.reshape((B * F), T, D) 217 | x = time_layer(x) 218 | x = x.reshape(B, F, T, D) 219 | x = x.permute(0, 2, 1, 3) 220 | 221 | x = x.reshape((B * T), F, D) 222 | x = freq_layer(x) 223 | x = x.reshape(B, T, F, D) 224 | x = x.permute(0, 2, 1, 3) 225 | 226 | x = rfft_layer(x, time_dim) 227 | 228 | return x 229 | -------------------------------------------------------------------------------- /mss/models/mdx23c_tfc_tdf_v3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | 7 | class STFT: 8 | def __init__(self, config): 9 | self.n_fft = config.n_fft 10 | self.hop_length = config.hop_length 11 | self.window = torch.hann_window(window_length=self.n_fft, periodic=True) 12 | self.dim_f = config.dim_f 13 | 14 | def __call__(self, x): 15 | window = self.window.to(x.device) 16 | batch_dims = x.shape[:-2] 17 | c, t = x.shape[-2:] 18 | x = x.reshape([-1, t]) 19 | x = torch.stft( 20 | x, 21 | n_fft=self.n_fft, 22 | hop_length=self.hop_length, 23 | window=window, 24 | center=True, 25 | return_complex=True 26 | ) 27 | x = torch.view_as_real(x) 28 | x = x.permute([0, 3, 1, 2]) 29 | x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) 30 | return x[..., :self.dim_f, :] 31 | 32 | def inverse(self, x): 33 | window = self.window.to(x.device) 34 | batch_dims = x.shape[:-3] 35 | c, f, t = x.shape[-3:] 36 | n = self.n_fft // 2 + 1 37 | f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) 38 | x = torch.cat([x, f_pad], -2) 39 | x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) 40 | x = x.permute([0, 2, 3, 1]) 41 | x = x[..., 0] + x[..., 1] * 1.j 42 | x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=window, center=True) 43 | x = x.reshape([*batch_dims, 2, -1]) 44 | return x 45 | 46 | 47 | def get_norm(norm_type): 48 | def norm(c, norm_type): 49 | if norm_type == 'BatchNorm': 50 | return nn.BatchNorm2d(c) 51 | elif norm_type == 'InstanceNorm': 52 | return nn.InstanceNorm2d(c, affine=True) 53 | elif 'GroupNorm' in norm_type: 54 | g = int(norm_type.replace('GroupNorm', '')) 55 | return nn.GroupNorm(num_groups=g, num_channels=c) 56 | else: 57 | return nn.Identity() 58 | 59 | return partial(norm, norm_type=norm_type) 60 | 61 | 62 | def get_act(act_type): 63 | if act_type == 'gelu': 64 | return nn.GELU() 65 | elif act_type == 'relu': 66 | return nn.ReLU() 67 | elif act_type[:3] == 'elu': 68 | alpha = float(act_type.replace('elu', '')) 69 | return nn.ELU(alpha) 70 | else: 71 | raise Exception 72 | 73 | 74 | class Upscale(nn.Module): 75 | def __init__(self, in_c, out_c, scale, norm, act): 76 | super().__init__() 77 | self.conv = nn.Sequential( 78 | norm(in_c), 79 | act, 80 | nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) 81 | ) 82 | 83 | def forward(self, x): 84 | return self.conv(x) 85 | 86 | 87 | class Downscale(nn.Module): 88 | def __init__(self, in_c, out_c, scale, norm, act): 89 | super().__init__() 90 | self.conv = nn.Sequential( 91 | norm(in_c), 92 | act, 93 | nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=scale, stride=scale, bias=False) 94 | ) 95 | 96 | def forward(self, x): 97 | return self.conv(x) 98 | 99 | 100 | class TFC_TDF(nn.Module): 101 | def 
__init__(self, in_c, c, l, f, bn, norm, act): 102 | super().__init__() 103 | 104 | self.blocks = nn.ModuleList() 105 | for i in range(l): 106 | block = nn.Module() 107 | 108 | block.tfc1 = nn.Sequential( 109 | norm(in_c), 110 | act, 111 | nn.Conv2d(in_c, c, 3, 1, 1, bias=False), 112 | ) 113 | block.tdf = nn.Sequential( 114 | norm(c), 115 | act, 116 | nn.Linear(f, f // bn, bias=False), 117 | norm(c), 118 | act, 119 | nn.Linear(f // bn, f, bias=False), 120 | ) 121 | block.tfc2 = nn.Sequential( 122 | norm(c), 123 | act, 124 | nn.Conv2d(c, c, 3, 1, 1, bias=False), 125 | ) 126 | block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False) 127 | 128 | self.blocks.append(block) 129 | in_c = c 130 | 131 | def forward(self, x): 132 | for block in self.blocks: 133 | s = block.shortcut(x) 134 | x = block.tfc1(x) 135 | x = x + block.tdf(x) 136 | x = block.tfc2(x) 137 | x = x + s 138 | return x 139 | 140 | 141 | class TFC_TDF_net(nn.Module): 142 | def __init__(self, config): 143 | super().__init__() 144 | self.config = config 145 | 146 | norm = get_norm(norm_type=config.model.norm) 147 | act = get_act(act_type=config.model.act) 148 | 149 | self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) 150 | self.num_subbands = config.model.num_subbands 151 | 152 | dim_c = self.num_subbands * config.audio.num_channels * 2 153 | n = config.model.num_scales 154 | scale = config.model.scale 155 | l = config.model.num_blocks_per_scale 156 | c = config.model.num_channels 157 | g = config.model.growth 158 | bn = config.model.bottleneck_factor 159 | f = config.audio.dim_f // self.num_subbands 160 | 161 | self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) 162 | 163 | self.encoder_blocks = nn.ModuleList() 164 | for i in range(n): 165 | block = nn.Module() 166 | block.tfc_tdf = TFC_TDF(c, c, l, f, bn, norm, act) 167 | block.downscale = Downscale(c, c + g, scale, norm, act) 168 | f = f // scale[1] 169 | c += g 170 | self.encoder_blocks.append(block) 171 | 172 | self.bottleneck_block = TFC_TDF(c, c, l, f, bn, norm, act) 173 | 174 | self.decoder_blocks = nn.ModuleList() 175 | for i in range(n): 176 | block = nn.Module() 177 | block.upscale = Upscale(c, c - g, scale, norm, act) 178 | f = f * scale[1] 179 | c -= g 180 | block.tfc_tdf = TFC_TDF(2 * c, c, l, f, bn, norm, act) 181 | self.decoder_blocks.append(block) 182 | 183 | self.final_conv = nn.Sequential( 184 | nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), 185 | act, 186 | nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) 187 | ) 188 | 189 | self.stft = STFT(config.audio) 190 | 191 | def cac2cws(self, x): 192 | k = self.num_subbands 193 | b, c, f, t = x.shape 194 | x = x.reshape(b, c, k, f // k, t) 195 | x = x.reshape(b, c * k, f // k, t) 196 | return x 197 | 198 | def cws2cac(self, x): 199 | k = self.num_subbands 200 | b, c, f, t = x.shape 201 | x = x.reshape(b, c // k, k, f, t) 202 | x = x.reshape(b, c // k, f * k, t) 203 | return x 204 | 205 | def forward(self, x): 206 | 207 | x = self.stft(x) 208 | 209 | mix = x = self.cac2cws(x) 210 | 211 | first_conv_out = x = self.first_conv(x) 212 | 213 | x = x.transpose(-1, -2) 214 | 215 | encoder_outputs = [] 216 | for block in self.encoder_blocks: 217 | x = block.tfc_tdf(x) 218 | encoder_outputs.append(x) 219 | x = block.downscale(x) 220 | 221 | x = self.bottleneck_block(x) 222 | 223 | for block in self.decoder_blocks: 224 | x = block.upscale(x) 225 | x = torch.cat([x, encoder_outputs.pop()], 1) 226 | x = block.tfc_tdf(x) 227 | 228 | x = x.transpose(-1, 
-2) 229 | 230 | x = x * first_conv_out # reduce artifacts 231 | 232 | x = self.final_conv(torch.cat([mix, x], 1)) 233 | 234 | x = self.cws2cac(x) 235 | 236 | if self.num_target_instruments > 1: 237 | b, c, f, t = x.shape 238 | x = x.reshape(b, self.num_target_instruments, -1, f, t) 239 | 240 | x = self.stft.inverse(x) 241 | 242 | return x 243 | -------------------------------------------------------------------------------- /mss/models/torchseg_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchseg as smp 4 | 5 | 6 | class STFT: 7 | def __init__(self, config): 8 | self.n_fft = config.n_fft 9 | self.hop_length = config.hop_length 10 | self.window = torch.hann_window(window_length=self.n_fft, periodic=True) 11 | self.dim_f = config.dim_f 12 | 13 | def __call__(self, x): 14 | window = self.window.to(x.device) 15 | batch_dims = x.shape[:-2] 16 | c, t = x.shape[-2:] 17 | x = x.reshape([-1, t]) 18 | x = torch.stft( 19 | x, 20 | n_fft=self.n_fft, 21 | hop_length=self.hop_length, 22 | window=window, 23 | center=True, 24 | return_complex=True 25 | ) 26 | x = torch.view_as_real(x) 27 | x = x.permute([0, 3, 1, 2]) 28 | x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) 29 | return x[..., :self.dim_f, :] 30 | 31 | def inverse(self, x): 32 | window = self.window.to(x.device) 33 | batch_dims = x.shape[:-3] 34 | c, f, t = x.shape[-3:] 35 | n = self.n_fft // 2 + 1 36 | f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) 37 | x = torch.cat([x, f_pad], -2) 38 | x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) 39 | x = x.permute([0, 2, 3, 1]) 40 | x = x[..., 0] + x[..., 1] * 1.j 41 | x = torch.istft( 42 | x, 43 | n_fft=self.n_fft, 44 | hop_length=self.hop_length, 45 | window=window, 46 | center=True 47 | ) 48 | x = x.reshape([*batch_dims, 2, -1]) 49 | return x 50 | 51 | 52 | def get_act(act_type): 53 | if act_type == 'gelu': 54 | return nn.GELU() 55 | elif act_type == 'relu': 56 | return nn.ReLU() 57 | elif act_type[:3] == 'elu': 58 | alpha = float(act_type.replace('elu', '')) 59 | return nn.ELU(alpha) 60 | else: 61 | raise Exception 62 | 63 | 64 | def get_decoder(config, c): 65 | decoder = None 66 | decoder_options = dict() 67 | if config.model.decoder_type == 'unet': 68 | try: 69 | decoder_options = dict(config.decoder_unet) 70 | except: 71 | pass 72 | decoder = smp.Unet( 73 | encoder_name=config.model.encoder_name, 74 | encoder_weights="imagenet", 75 | in_channels=c, 76 | classes=c, 77 | **decoder_options, 78 | ) 79 | elif config.model.decoder_type == 'fpn': 80 | try: 81 | decoder_options = dict(config.decoder_fpn) 82 | except: 83 | pass 84 | decoder = smp.FPN( 85 | encoder_name=config.model.encoder_name, 86 | encoder_weights="imagenet", 87 | in_channels=c, 88 | classes=c, 89 | **decoder_options, 90 | ) 91 | elif config.model.decoder_type == 'unet++': 92 | try: 93 | decoder_options = dict(config.decoder_unet_plus_plus) 94 | except: 95 | pass 96 | decoder = smp.UnetPlusPlus( 97 | encoder_name=config.model.encoder_name, 98 | encoder_weights="imagenet", 99 | in_channels=c, 100 | classes=c, 101 | **decoder_options, 102 | ) 103 | elif config.model.decoder_type == 'manet': 104 | try: 105 | decoder_options = dict(config.decoder_manet) 106 | except: 107 | pass 108 | decoder = smp.MAnet( 109 | encoder_name=config.model.encoder_name, 110 | encoder_weights="imagenet", 111 | in_channels=c, 112 | classes=c, 113 | **decoder_options, 114 | ) 115 
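# get_decoder dispatches on config.model.decoder_type and builds the matching torchseg
# architecture with an ImageNet-pretrained encoder; decoder-specific kwargs are read from
# an optional config section (e.g. config.decoder_unet) and silently skipped when absent.
# Note that the 'pspnet' branch below appears twice; the second copy is unreachable.
# A minimal config sketch (values are for illustration only):
#   model:
#     decoder_type: unet
#     encoder_name: resnet34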
| elif config.model.decoder_type == 'linknet': 116 | try: 117 | decoder_options = dict(config.decoder_linknet) 118 | except: 119 | pass 120 | decoder = smp.Linknet( 121 | encoder_name=config.model.encoder_name, 122 | encoder_weights="imagenet", 123 | in_channels=c, 124 | classes=c, 125 | **decoder_options, 126 | ) 127 | elif config.model.decoder_type == 'pspnet': 128 | try: 129 | decoder_options = dict(config.decoder_pspnet) 130 | except: 131 | pass 132 | decoder = smp.PSPNet( 133 | encoder_name=config.model.encoder_name, 134 | encoder_weights="imagenet", 135 | in_channels=c, 136 | classes=c, 137 | **decoder_options, 138 | ) 139 | elif config.model.decoder_type == 'pspnet': 140 | try: 141 | decoder_options = dict(config.decoder_pspnet) 142 | except: 143 | pass 144 | decoder = smp.PSPNet( 145 | encoder_name=config.model.encoder_name, 146 | encoder_weights="imagenet", 147 | in_channels=c, 148 | classes=c, 149 | **decoder_options, 150 | ) 151 | elif config.model.decoder_type == 'pan': 152 | try: 153 | decoder_options = dict(config.decoder_pan) 154 | except: 155 | pass 156 | decoder = smp.PAN( 157 | encoder_name=config.model.encoder_name, 158 | encoder_weights="imagenet", 159 | in_channels=c, 160 | classes=c, 161 | **decoder_options, 162 | ) 163 | elif config.model.decoder_type == 'deeplabv3': 164 | try: 165 | decoder_options = dict(config.decoder_deeplabv3) 166 | except: 167 | pass 168 | decoder = smp.DeepLabV3( 169 | encoder_name=config.model.encoder_name, 170 | encoder_weights="imagenet", 171 | in_channels=c, 172 | classes=c, 173 | **decoder_options, 174 | ) 175 | elif config.model.decoder_type == 'deeplabv3plus': 176 | try: 177 | decoder_options = dict(config.decoder_deeplabv3plus) 178 | except: 179 | pass 180 | decoder = smp.DeepLabV3Plus( 181 | encoder_name=config.model.encoder_name, 182 | encoder_weights="imagenet", 183 | in_channels=c, 184 | classes=c, 185 | **decoder_options, 186 | ) 187 | return decoder 188 | 189 | 190 | class Torchseg_Net(nn.Module): 191 | def __init__(self, config): 192 | super().__init__() 193 | self.config = config 194 | 195 | act = get_act(act_type=config.model.act) 196 | 197 | self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) 198 | self.num_subbands = config.model.num_subbands 199 | 200 | dim_c = self.num_subbands * config.audio.num_channels * 2 201 | c = config.model.num_channels 202 | f = config.audio.dim_f // self.num_subbands 203 | 204 | self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) 205 | 206 | self.unet_model = get_decoder(config, c) 207 | 208 | self.final_conv = nn.Sequential( 209 | nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), 210 | act, 211 | nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) 212 | ) 213 | 214 | self.stft = STFT(config.audio) 215 | 216 | def cac2cws(self, x): 217 | k = self.num_subbands 218 | b, c, f, t = x.shape 219 | x = x.reshape(b, c, k, f // k, t) 220 | x = x.reshape(b, c * k, f // k, t) 221 | return x 222 | 223 | def cws2cac(self, x): 224 | k = self.num_subbands 225 | b, c, f, t = x.shape 226 | x = x.reshape(b, c // k, k, f, t) 227 | x = x.reshape(b, c // k, f * k, t) 228 | return x 229 | 230 | def forward(self, x): 231 | 232 | x = self.stft(x) 233 | 234 | mix = x = self.cac2cws(x) 235 | 236 | first_conv_out = x = self.first_conv(x) 237 | 238 | x = x.transpose(-1, -2) 239 | 240 | x = self.unet_model(x) 241 | 242 | x = x.transpose(-1, -2) 243 | 244 | x = x * first_conv_out # reduce artifacts 245 | 246 | x = self.final_conv(torch.cat([mix, x], 
1)) 247 | 248 | x = self.cws2cac(x) 249 | 250 | if self.num_target_instruments > 1: 251 | b, c, f, t = x.shape 252 | x = x.reshape(b, self.num_target_instruments, -1, f, t) 253 | 254 | x = self.stft.inverse(x) 255 | return x 256 | -------------------------------------------------------------------------------- /mss/models/segm_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import segmentation_models_pytorch as smp 4 | 5 | 6 | class STFT: 7 | def __init__(self, config): 8 | self.n_fft = config.n_fft 9 | self.hop_length = config.hop_length 10 | self.window = torch.hann_window(window_length=self.n_fft, periodic=True) 11 | self.dim_f = config.dim_f 12 | 13 | def __call__(self, x): 14 | window = self.window.to(x.device) 15 | batch_dims = x.shape[:-2] 16 | c, t = x.shape[-2:] 17 | x = x.reshape([-1, t]) 18 | x = torch.stft( 19 | x, 20 | n_fft=self.n_fft, 21 | hop_length=self.hop_length, 22 | window=window, 23 | center=True, 24 | return_complex=True 25 | ) 26 | x = torch.view_as_real(x) 27 | x = x.permute([0, 3, 1, 2]) 28 | x = x.reshape([*batch_dims, c, 2, -1, x.shape[-1]]).reshape([*batch_dims, c * 2, -1, x.shape[-1]]) 29 | return x[..., :self.dim_f, :] 30 | 31 | def inverse(self, x): 32 | window = self.window.to(x.device) 33 | batch_dims = x.shape[:-3] 34 | c, f, t = x.shape[-3:] 35 | n = self.n_fft // 2 + 1 36 | f_pad = torch.zeros([*batch_dims, c, n - f, t]).to(x.device) 37 | x = torch.cat([x, f_pad], -2) 38 | x = x.reshape([*batch_dims, c // 2, 2, n, t]).reshape([-1, 2, n, t]) 39 | x = x.permute([0, 2, 3, 1]) 40 | x = x[..., 0] + x[..., 1] * 1.j 41 | x = torch.istft( 42 | x, 43 | n_fft=self.n_fft, 44 | hop_length=self.hop_length, 45 | window=window, 46 | center=True 47 | ) 48 | x = x.reshape([*batch_dims, 2, -1]) 49 | return x 50 | 51 | 52 | def get_act(act_type): 53 | if act_type == 'gelu': 54 | return nn.GELU() 55 | elif act_type == 'relu': 56 | return nn.ReLU() 57 | elif act_type[:3] == 'elu': 58 | alpha = float(act_type.replace('elu', '')) 59 | return nn.ELU(alpha) 60 | else: 61 | raise Exception 62 | 63 | 64 | def get_decoder(config, c): 65 | decoder = None 66 | decoder_options = dict() 67 | if config.model.decoder_type == 'unet': 68 | try: 69 | decoder_options = dict(config.decoder_unet) 70 | except: 71 | pass 72 | decoder = smp.Unet( 73 | encoder_name=config.model.encoder_name, 74 | encoder_weights="imagenet", 75 | in_channels=c, 76 | classes=c, 77 | **decoder_options, 78 | ) 79 | elif config.model.decoder_type == 'fpn': 80 | try: 81 | decoder_options = dict(config.decoder_fpn) 82 | except: 83 | pass 84 | decoder = smp.FPN( 85 | encoder_name=config.model.encoder_name, 86 | encoder_weights="imagenet", 87 | in_channels=c, 88 | classes=c, 89 | **decoder_options, 90 | ) 91 | elif config.model.decoder_type == 'unet++': 92 | try: 93 | decoder_options = dict(config.decoder_unet_plus_plus) 94 | except: 95 | pass 96 | decoder = smp.UnetPlusPlus( 97 | encoder_name=config.model.encoder_name, 98 | encoder_weights="imagenet", 99 | in_channels=c, 100 | classes=c, 101 | **decoder_options, 102 | ) 103 | elif config.model.decoder_type == 'manet': 104 | try: 105 | decoder_options = dict(config.decoder_manet) 106 | except: 107 | pass 108 | decoder = smp.MAnet( 109 | encoder_name=config.model.encoder_name, 110 | encoder_weights="imagenet", 111 | in_channels=c, 112 | classes=c, 113 | **decoder_options, 114 | ) 115 | elif config.model.decoder_type == 'linknet': 116 | try: 117 | decoder_options = 
dict(config.decoder_linknet) 118 | except: 119 | pass 120 | decoder = smp.Linknet( 121 | encoder_name=config.model.encoder_name, 122 | encoder_weights="imagenet", 123 | in_channels=c, 124 | classes=c, 125 | **decoder_options, 126 | ) 127 | elif config.model.decoder_type == 'pspnet': 128 | try: 129 | decoder_options = dict(config.decoder_pspnet) 130 | except: 131 | pass 132 | decoder = smp.PSPNet( 133 | encoder_name=config.model.encoder_name, 134 | encoder_weights="imagenet", 135 | in_channels=c, 136 | classes=c, 137 | **decoder_options, 138 | ) 139 | elif config.model.decoder_type == 'pspnet': 140 | try: 141 | decoder_options = dict(config.decoder_pspnet) 142 | except: 143 | pass 144 | decoder = smp.PSPNet( 145 | encoder_name=config.model.encoder_name, 146 | encoder_weights="imagenet", 147 | in_channels=c, 148 | classes=c, 149 | **decoder_options, 150 | ) 151 | elif config.model.decoder_type == 'pan': 152 | try: 153 | decoder_options = dict(config.decoder_pan) 154 | except: 155 | pass 156 | decoder = smp.PAN( 157 | encoder_name=config.model.encoder_name, 158 | encoder_weights="imagenet", 159 | in_channels=c, 160 | classes=c, 161 | **decoder_options, 162 | ) 163 | elif config.model.decoder_type == 'deeplabv3': 164 | try: 165 | decoder_options = dict(config.decoder_deeplabv3) 166 | except: 167 | pass 168 | decoder = smp.DeepLabV3( 169 | encoder_name=config.model.encoder_name, 170 | encoder_weights="imagenet", 171 | in_channels=c, 172 | classes=c, 173 | **decoder_options, 174 | ) 175 | elif config.model.decoder_type == 'deeplabv3plus': 176 | try: 177 | decoder_options = dict(config.decoder_deeplabv3plus) 178 | except: 179 | pass 180 | decoder = smp.DeepLabV3Plus( 181 | encoder_name=config.model.encoder_name, 182 | encoder_weights="imagenet", 183 | in_channels=c, 184 | classes=c, 185 | **decoder_options, 186 | ) 187 | return decoder 188 | 189 | 190 | class Segm_Models_Net(nn.Module): 191 | def __init__(self, config): 192 | super().__init__() 193 | self.config = config 194 | 195 | act = get_act(act_type=config.model.act) 196 | 197 | self.num_target_instruments = 1 if config.training.target_instrument else len(config.training.instruments) 198 | self.num_subbands = config.model.num_subbands 199 | 200 | dim_c = self.num_subbands * config.audio.num_channels * 2 201 | c = config.model.num_channels 202 | f = config.audio.dim_f // self.num_subbands 203 | 204 | self.first_conv = nn.Conv2d(dim_c, c, 1, 1, 0, bias=False) 205 | 206 | self.unet_model = get_decoder(config, c) 207 | 208 | self.final_conv = nn.Sequential( 209 | nn.Conv2d(c + dim_c, c, 1, 1, 0, bias=False), 210 | act, 211 | nn.Conv2d(c, self.num_target_instruments * dim_c, 1, 1, 0, bias=False) 212 | ) 213 | 214 | self.stft = STFT(config.audio) 215 | 216 | def cac2cws(self, x): 217 | k = self.num_subbands 218 | b, c, f, t = x.shape 219 | x = x.reshape(b, c, k, f // k, t) 220 | x = x.reshape(b, c * k, f // k, t) 221 | return x 222 | 223 | def cws2cac(self, x): 224 | k = self.num_subbands 225 | b, c, f, t = x.shape 226 | x = x.reshape(b, c // k, k, f, t) 227 | x = x.reshape(b, c // k, f * k, t) 228 | return x 229 | 230 | def forward(self, x): 231 | 232 | x = self.stft(x) 233 | 234 | mix = x = self.cac2cws(x) 235 | 236 | first_conv_out = x = self.first_conv(x) 237 | 238 | x = x.transpose(-1, -2) 239 | 240 | x = self.unet_model(x) 241 | 242 | x = x.transpose(-1, -2) 243 | 244 | x = x * first_conv_out # reduce artifacts 245 | 246 | x = self.final_conv(torch.cat([mix, x], 1)) 247 | 248 | x = self.cws2cac(x) 249 | 250 | if self.num_target_instruments 
> 1: 251 | b, c, f, t = x.shape 252 | x = x.reshape(b, self.num_target_instruments, -1, f, t) 253 | 254 | x = self.stft.inverse(x) 255 | return x 256 | -------------------------------------------------------------------------------- /mss/models/scnet_unofficial/scnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | SCNet - great paper, great implementation 3 | https://arxiv.org/pdf/2401.13276.pdf 4 | https://github.com/amanteur/SCNet-PyTorch 5 | ''' 6 | 7 | from typing import List 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torchaudio 13 | 14 | from .modules import DualPathRNN, SDBlock, SUBlock 15 | from .utils import compute_sd_layer_shapes, compute_gcr 16 | 17 | from einops import rearrange, pack, unpack 18 | from functools import partial 19 | 20 | from beartype.typing import Tuple, Optional, List, Callable 21 | from beartype import beartype 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | 27 | def default(v, d): 28 | return v if exists(v) else d 29 | 30 | 31 | def pack_one(t, pattern): 32 | return pack([t], pattern) 33 | 34 | 35 | def unpack_one(t, ps, pattern): 36 | return unpack(t, ps, pattern)[0] 37 | 38 | 39 | class RMSNorm(nn.Module): 40 | def __init__(self, dim): 41 | super().__init__() 42 | self.scale = dim ** 0.5 43 | self.gamma = nn.Parameter(torch.ones(dim)) 44 | 45 | def forward(self, x): 46 | return F.normalize(x, dim=-1) * self.scale * self.gamma 47 | 48 | 49 | class BandSplit(nn.Module): 50 | @beartype 51 | def __init__( 52 | self, 53 | dim, 54 | dim_inputs: Tuple[int, ...] 55 | ): 56 | super().__init__() 57 | self.dim_inputs = dim_inputs 58 | self.to_features = ModuleList([]) 59 | 60 | for dim_in in dim_inputs: 61 | net = nn.Sequential( 62 | RMSNorm(dim_in), 63 | nn.Linear(dim_in, dim) 64 | ) 65 | 66 | self.to_features.append(net) 67 | 68 | def forward(self, x): 69 | x = x.split(self.dim_inputs, dim=-1) 70 | 71 | outs = [] 72 | for split_input, to_feature in zip(x, self.to_features): 73 | split_output = to_feature(split_input) 74 | outs.append(split_output) 75 | 76 | return torch.stack(outs, dim=-2) 77 | 78 | 79 | class SCNet(nn.Module): 80 | """ 81 | SCNet class implements a source separation network, 82 | which explicitly split the spectrogram of the mixture into several subbands 83 | and introduce a sparsity-based encoder to model different frequency bands. 84 | 85 | Paper: "SCNET: SPARSE COMPRESSION NETWORK FOR MUSIC SOURCE SEPARATION" 86 | Authors: Weinan Tong, Jiaxu Zhu et al. 87 | Link: https://arxiv.org/abs/2401.13276.pdf 88 | 89 | Args: 90 | - n_fft (int): Number of FFTs to determine the frequency dimension of the input. 91 | - dims (List[int]): List of channel dimensions for each block. 92 | - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands. 93 | - downsample_strides (List[int]): List of stride values for downsampling in each block. 94 | - n_conv_modules (List[int]): List specifying the number of convolutional modules in each block. 95 | - n_rnn_layers (int): Number of recurrent layers in the dual path RNN. 96 | - rnn_hidden_dim (int): Dimensionality of the hidden state in the dual path RNN. 97 | - n_sources (int, optional): Number of sources to be separated. Default is 4. 98 | 99 | Shapes: 100 | - Input: (B, C, T) where 101 | B is batch size, 102 | C is channel dim (mono / stereo), 103 | T is time dim 104 | - Output: (B, N, C, T) where 105 | B is batch size, 106 | N is the number of sources. 
107 | C is channel dim (mono / stereo), 108 | T is sequence length, 109 | """ 110 | @beartype 111 | def __init__( 112 | self, 113 | n_fft: int, 114 | dims: List[int], 115 | bandsplit_ratios: List[float], 116 | downsample_strides: List[int], 117 | n_conv_modules: List[int], 118 | n_rnn_layers: int, 119 | rnn_hidden_dim: int, 120 | n_sources: int = 4, 121 | hop_length: int = 1024, 122 | win_length: int = 4096, 123 | stft_window_fn: Optional[Callable] = None, 124 | stft_normalized: bool = False, 125 | **kwargs 126 | ): 127 | """ 128 | Initializes SCNet with input parameters. 129 | """ 130 | super().__init__() 131 | self.assert_input_data( 132 | bandsplit_ratios, 133 | downsample_strides, 134 | n_conv_modules, 135 | ) 136 | 137 | n_blocks = len(dims) - 1 138 | n_freq_bins = n_fft // 2 + 1 139 | subband_shapes, sd_intervals = compute_sd_layer_shapes( 140 | input_shape=n_freq_bins, 141 | bandsplit_ratios=bandsplit_ratios, 142 | downsample_strides=downsample_strides, 143 | n_layers=n_blocks, 144 | ) 145 | self.sd_blocks = nn.ModuleList( 146 | SDBlock( 147 | input_dim=dims[i], 148 | output_dim=dims[i + 1], 149 | bandsplit_ratios=bandsplit_ratios, 150 | downsample_strides=downsample_strides, 151 | n_conv_modules=n_conv_modules, 152 | ) 153 | for i in range(n_blocks) 154 | ) 155 | self.dualpath_blocks = DualPathRNN( 156 | n_layers=n_rnn_layers, 157 | input_dim=dims[-1], 158 | hidden_dim=rnn_hidden_dim, 159 | **kwargs 160 | ) 161 | self.su_blocks = nn.ModuleList( 162 | SUBlock( 163 | input_dim=dims[i + 1], 164 | output_dim=dims[i] if i != 0 else dims[i] * n_sources, 165 | subband_shapes=subband_shapes[i], 166 | sd_intervals=sd_intervals[i], 167 | upsample_strides=downsample_strides, 168 | ) 169 | for i in reversed(range(n_blocks)) 170 | ) 171 | self.gcr = compute_gcr(subband_shapes) 172 | 173 | self.stft_kwargs = dict( 174 | n_fft=n_fft, 175 | hop_length=hop_length, 176 | win_length=win_length, 177 | normalized=stft_normalized 178 | ) 179 | 180 | self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), win_length) 181 | self.n_sources = n_sources 182 | self.hop_length = hop_length 183 | 184 | @staticmethod 185 | def assert_input_data(*args): 186 | """ 187 | Asserts that the shapes of input features are equal. 188 | """ 189 | for arg1 in args: 190 | for arg2 in args: 191 | if len(arg1) != len(arg2): 192 | raise ValueError( 193 | f"Shapes of input features {arg1} and {arg2} are not equal." 194 | ) 195 | 196 | def forward(self, x: torch.Tensor) -> torch.Tensor: 197 | """ 198 | Performs forward pass through the SCNet. 199 | 200 | Args: 201 | - x (torch.Tensor): Input tensor of shape (B, C, T). 202 | 203 | Returns: 204 | - torch.Tensor: Output tensor of shape (B, N, C, T). 
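        Note: the input is right-padded to a multiple of hop_length, transformed with a
        per-channel STFT whose real/imaginary parts are packed into the feature axis,
        encoded by the SD blocks (keeping skip connections), processed by the dual-path
        RNN, decoded by the SU blocks with those skips, reshaped into n_sources complex
        spectrograms, and inverted with iSTFT before the padding is trimmed off.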
205 | """ 206 | 207 | device = x.device 208 | stft_window = self.stft_window_fn(device=device) 209 | 210 | if x.ndim == 2: 211 | x = rearrange(x, 'b t -> b 1 t') 212 | 213 | c = x.shape[1] 214 | 215 | stft_pad = self.hop_length - x.shape[-1] % self.hop_length 216 | x = F.pad(x, (0, stft_pad)) 217 | 218 | # stft 219 | x, ps = pack_one(x, '* t') 220 | x = torch.stft(x, **self.stft_kwargs, window=stft_window, return_complex=True) 221 | x = torch.view_as_real(x) 222 | x = unpack_one(x, ps, '* c f t') 223 | x = rearrange(x, 'b c f t r -> b f t (c r)') 224 | 225 | # encoder part 226 | x_skips = [] 227 | for sd_block in self.sd_blocks: 228 | x, x_skip = sd_block(x) 229 | x_skips.append(x_skip) 230 | 231 | # separation part 232 | x = self.dualpath_blocks(x) 233 | 234 | # decoder part 235 | for su_block, x_skip in zip(self.su_blocks, reversed(x_skips)): 236 | x = su_block(x, x_skip) 237 | 238 | # istft 239 | x = rearrange(x, 'b f t (c r n) -> b n c f t r', c=c, n=self.n_sources, r=2) 240 | x = x.contiguous() 241 | 242 | x = torch.view_as_complex(x) 243 | x = rearrange(x, 'b n c f t -> (b n c) f t') 244 | x = torch.istft(x, **self.stft_kwargs, window=stft_window, return_complex=False) 245 | x = rearrange(x, '(b n c) t -> b n c t', c=c, n=self.n_sources) 246 | 247 | x = x[..., :-stft_pad] 248 | 249 | return x 250 | -------------------------------------------------------------------------------- /mss/models/bandit/core/data/musdb/preprocess.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torchaudio as ta 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from tqdm.contrib.concurrent import process_map 10 | 11 | from core.data._types import DataDict 12 | from core.data.musdb.dataset import MUSDB18FullTrackDataset 13 | import pyloudnorm as pyln 14 | 15 | class SourceActivityDetector(nn.Module): 16 | def __init__( 17 | self, 18 | analysis_stem: str, 19 | output_path: str, 20 | fs: int = 44100, 21 | segment_length_second: float = 6.0, 22 | hop_length_second: float = 3.0, 23 | n_chunks: int = 10, 24 | chunk_epsilon: float = 1e-5, 25 | energy_threshold_quantile: float = 0.15, 26 | segment_epsilon: float = 1e-3, 27 | salient_proportion_threshold: float = 0.5, 28 | target_lufs: float = -24 29 | ) -> None: 30 | super().__init__() 31 | 32 | self.fs = fs 33 | self.segment_length = int(segment_length_second * self.fs) 34 | self.hop_length = int(hop_length_second * self.fs) 35 | self.n_chunks = n_chunks 36 | assert self.segment_length % self.n_chunks == 0 37 | self.chunk_size = self.segment_length // self.n_chunks 38 | self.chunk_epsilon = chunk_epsilon 39 | self.energy_threshold_quantile = energy_threshold_quantile 40 | self.segment_epsilon = segment_epsilon 41 | self.salient_proportion_threshold = salient_proportion_threshold 42 | self.analysis_stem = analysis_stem 43 | 44 | self.meter = pyln.Meter(self.fs) 45 | self.target_lufs = target_lufs 46 | 47 | self.output_path = output_path 48 | 49 | def forward(self, data: DataDict) -> None: 50 | 51 | stem_ = self.analysis_stem if ( 52 | self.analysis_stem != "none") else "mixture" 53 | 54 | x = data["audio"][stem_] 55 | 56 | xnp = x.numpy() 57 | loudness = self.meter.integrated_loudness(xnp.T) 58 | 59 | for stem in data["audio"]: 60 | s = data["audio"][stem] 61 | s = pyln.normalize.loudness(s.numpy().T, loudness, self.target_lufs).T 62 | s = torch.as_tensor(s) 63 | data["audio"][stem] = s 64 | 65 | if x.ndim == 3: 66 | assert 
x.shape[0] == 1 67 | x = x[0] 68 | 69 | n_chan, n_samples = x.shape 70 | 71 | n_segments = ( 72 | int( 73 | np.ceil((n_samples - self.segment_length) / self.hop_length) 74 | ) + 1 75 | ) 76 | 77 | segments = torch.zeros((n_segments, n_chan, self.segment_length)) 78 | for i in range(n_segments): 79 | start = i * self.hop_length 80 | end = start + self.segment_length 81 | end = min(end, n_samples) 82 | 83 | xseg = x[:, start:end] 84 | 85 | if end - start < self.segment_length: 86 | xseg = F.pad( 87 | xseg, 88 | pad=(0, self.segment_length - (end - start)), 89 | value=torch.nan 90 | ) 91 | 92 | segments[i, :, :] = xseg 93 | 94 | chunks = segments.reshape( 95 | (n_segments, n_chan, self.n_chunks, self.chunk_size) 96 | ) 97 | 98 | if self.analysis_stem != "none": 99 | chunk_energies = torch.mean(torch.square(chunks), dim=(1, 3)) 100 | chunk_energies = torch.nan_to_num(chunk_energies, nan=0) 101 | chunk_energies[chunk_energies == 0] = self.chunk_epsilon 102 | 103 | energy_threshold = torch.nanquantile( 104 | chunk_energies, q=self.energy_threshold_quantile 105 | ) 106 | 107 | if energy_threshold < self.segment_epsilon: 108 | energy_threshold = self.segment_epsilon # type: ignore[assignment] 109 | 110 | chunks_above_threshold = chunk_energies > energy_threshold 111 | n_chunks_above_threshold = torch.mean( 112 | chunks_above_threshold.to(torch.float), dim=-1 113 | ) 114 | 115 | segment_above_threshold = ( 116 | n_chunks_above_threshold > self.salient_proportion_threshold 117 | ) 118 | 119 | if torch.sum(segment_above_threshold) == 0: 120 | return 121 | 122 | else: 123 | segment_above_threshold = torch.ones((n_segments,)) 124 | 125 | for i in range(n_segments): 126 | if not segment_above_threshold[i]: 127 | continue 128 | 129 | outpath = os.path.join( 130 | self.output_path, 131 | self.analysis_stem, 132 | f"{data['track']} - {self.analysis_stem}{i:03d}", 133 | ) 134 | os.makedirs(outpath, exist_ok=True) 135 | 136 | for stem in data["audio"]: 137 | if stem == self.analysis_stem: 138 | segment = torch.nan_to_num(segments[i, :, :], nan=0) 139 | else: 140 | start = i * self.hop_length 141 | end = start + self.segment_length 142 | end = min(n_samples, end) 143 | 144 | segment = data["audio"][stem][:, start:end] 145 | 146 | if end - start < self.segment_length: 147 | segment = F.pad( 148 | segment, 149 | (0, self.segment_length - (end - start)) 150 | ) 151 | 152 | assert segment.shape[-1] == self.segment_length, segment.shape 153 | 154 | # ta.save(os.path.join(outpath, f"{stem}.wav"), segment, self.fs) 155 | 156 | np.save(os.path.join(outpath, f"{stem}.wav"), segment) 157 | 158 | 159 | def preprocess( 160 | analysis_stem: str, 161 | output_path: str = "/data/MUSDB18/HQ/saded-np", 162 | fs: int = 44100, 163 | segment_length_second: float = 6.0, 164 | hop_length_second: float = 3.0, 165 | n_chunks: int = 10, 166 | chunk_epsilon: float = 1e-5, 167 | energy_threshold_quantile: float = 0.15, 168 | segment_epsilon: float = 1e-3, 169 | salient_proportion_threshold: float = 0.5, 170 | ) -> None: 171 | 172 | sad = SourceActivityDetector( 173 | analysis_stem=analysis_stem, 174 | output_path=output_path, 175 | fs=fs, 176 | segment_length_second=segment_length_second, 177 | hop_length_second=hop_length_second, 178 | n_chunks=n_chunks, 179 | chunk_epsilon=chunk_epsilon, 180 | energy_threshold_quantile=energy_threshold_quantile, 181 | segment_epsilon=segment_epsilon, 182 | salient_proportion_threshold=salient_proportion_threshold, 183 | ) 184 | 185 | for split in ["train", "val", "test"]: 186 | ds = 
MUSDB18FullTrackDataset( 187 | data_root="/data/MUSDB18/HQ/canonical", 188 | split=split, 189 | ) 190 | 191 | tracks = [] 192 | for i, track in enumerate(tqdm(ds, total=len(ds))): 193 | if i % 32 == 0 and tracks: 194 | process_map(sad, tracks, max_workers=8) 195 | tracks = [] 196 | tracks.append(track) 197 | process_map(sad, tracks, max_workers=8) 198 | 199 | def loudness_norm_one( 200 | inputs 201 | ): 202 | infile, outfile, target_lufs = inputs 203 | 204 | audio, fs = ta.load(infile) 205 | audio = audio.mean(dim=0, keepdim=True).numpy().T 206 | 207 | meter = pyln.Meter(fs) 208 | loudness = meter.integrated_loudness(audio) 209 | audio = pyln.normalize.loudness(audio, loudness, target_lufs) 210 | 211 | os.makedirs(os.path.dirname(outfile), exist_ok=True) 212 | np.save(outfile, audio.T) 213 | 214 | def loudness_norm( 215 | data_path: str, 216 | # output_path: str, 217 | target_lufs = -17.0, 218 | ): 219 | files = glob.glob( 220 | os.path.join(data_path, "**", "*.wav"), recursive=True 221 | ) 222 | 223 | outfiles = [ 224 | f.replace(".wav", ".npy").replace("saded", "saded-np") for f in files 225 | ] 226 | 227 | files = [(f, o, target_lufs) for f, o in zip(files, outfiles)] 228 | 229 | process_map(loudness_norm_one, files, chunksize=2) 230 | 231 | 232 | 233 | if __name__ == "__main__": 234 | 235 | from tqdm import tqdm 236 | import fire 237 | 238 | fire.Fire() 239 | -------------------------------------------------------------------------------- /mss/utils.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' 3 | 4 | import time 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import yaml 9 | from ml_collections import ConfigDict 10 | from omegaconf import OmegaConf 11 | 12 | 13 | def get_model_from_config(model_type, config_path): 14 | with open(config_path) as f: 15 | if model_type == 'htdemucs': 16 | config = OmegaConf.load(config_path) 17 | else: 18 | config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) 19 | 20 | if model_type == 'mdx23c': 21 | from .models.mdx23c_tfc_tdf_v3 import TFC_TDF_net 22 | model = TFC_TDF_net(config) 23 | elif model_type == 'htdemucs': 24 | from .models.demucs4ht import get_model 25 | model = get_model(config) 26 | elif model_type == 'segm_models': 27 | from .models.segm_models import Segm_Models_Net 28 | model = Segm_Models_Net(config) 29 | elif model_type == 'torchseg': 30 | from .models.torchseg_models import Torchseg_Net 31 | model = Torchseg_Net(config) 32 | elif model_type == 'mel_band_roformer': 33 | from .models.bs_roformer import MelBandRoformer 34 | model = MelBandRoformer( 35 | **dict(config.model) 36 | ) 37 | elif model_type == 'bs_roformer': 38 | from .models.bs_roformer import BSRoformer 39 | model = BSRoformer( 40 | **dict(config.model) 41 | ) 42 | print(dict(config.model)) 43 | elif model_type == 'swin_upernet': 44 | from .models.upernet_swin_transformers import Swin_UperNet_Model 45 | model = Swin_UperNet_Model(config) 46 | elif model_type == 'bandit': 47 | from .models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple 48 | model = MultiMaskMultiSourceBandSplitRNNSimple( 49 | **config.model 50 | ) 51 | elif model_type == 'scnet_unofficial': 52 | from .models.scnet_unofficial import SCNet 53 | model = SCNet( 54 | **config.model 55 | ) 56 | elif model_type == 'scnet': 57 | from .models.scnet import SCNet 58 | model = SCNet( 59 | **config.model 60 | ) 61 | else: 62 | print('Unknown model: 
{}'.format(model_type))
 63 |         model = None
 64 | 
 65 |     return model, config
 66 | 
 67 | 
 68 | def demix_track(config, model, mix, device):
 69 |     C = config.audio.chunk_size
 70 |     N = config.inference.num_overlap
 71 |     fade_size = C // 10
 72 |     step = int(C // N)
 73 |     border = C - step
 74 |     batch_size = config.inference.batch_size
 75 | 
 76 |     length_init = mix.shape[-1]
 77 | 
 78 |     # Pad the beginning and end so the sliding windows cover the borders better
 79 |     if length_init > 2 * border and (border > 0):
 80 |         mix = nn.functional.pad(mix, (border, border), mode='reflect')
 81 | 
 82 |     # Prepare the window arrays once (for speed). The cross-fade windows remove clicks at segment edges
 83 |     window_size = C
 84 |     fadein = torch.linspace(0, 1, fade_size)
 85 |     fadeout = torch.linspace(1, 0, fade_size)
 86 |     window_start = torch.ones(window_size)
 87 |     window_middle = torch.ones(window_size)
 88 |     window_finish = torch.ones(window_size)
 89 |     window_start[-fade_size:] *= fadeout # First audio chunk, no fadein
 90 |     window_finish[:fade_size] *= fadein # Last audio chunk, no fadeout
 91 |     window_middle[-fade_size:] *= fadeout
 92 |     window_middle[:fade_size] *= fadein
 93 | 
 94 |     with torch.cuda.amp.autocast():
 95 |         with torch.inference_mode():
 96 |             if config.training.target_instrument is not None:
 97 |                 req_shape = (1, ) + tuple(mix.shape)
 98 |             else:
 99 |                 req_shape = (len(config.training.instruments),) + tuple(mix.shape)
100 | 
101 |             result = torch.zeros(req_shape, dtype=torch.float32)
102 |             counter = torch.zeros(req_shape, dtype=torch.float32)
103 |             i = 0
104 |             batch_data = []
105 |             batch_locations = []
106 |             while i < mix.shape[1]:
107 |                 # print(i, i + C, mix.shape[1])
108 |                 part = mix[:, i:i + C].to(device)
109 |                 length = part.shape[-1]
110 |                 if length < C:
111 |                     if length > C // 2 + 1:
112 |                         part = nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
113 |                     else:
114 |                         part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
115 |                 batch_data.append(part)
116 |                 batch_locations.append((i, length))
117 |                 i += step
118 | 
119 |                 if len(batch_data) >= batch_size or (i >= mix.shape[1]):
120 |                     arr = torch.stack(batch_data, dim=0)
121 |                     x = model(arr)
122 | 
123 |                     window = window_middle
124 |                     if i - step == 0: # First audio chunk, no fadein
125 |                         window = window_start
126 |                     elif i >= mix.shape[1]: # Last audio chunk, no fadeout
127 |                         window = window_finish
128 | 
129 |                     for j in range(len(batch_locations)):
130 |                         start, l = batch_locations[j]
131 |                         result[..., start:start+l] += x[j][..., :l].cpu() * window[..., :l]
132 |                         counter[..., start:start+l] += window[..., :l]
133 | 
134 |                     batch_data = []
135 |                     batch_locations = []
136 | 
137 |             estimated_sources = result / counter
138 |             estimated_sources = estimated_sources.cpu().numpy()
139 |             np.nan_to_num(estimated_sources, copy=False, nan=0.0)
140 | 
141 |     if length_init > 2 * border and (border > 0):
142 |         # Remove padding
143 |         estimated_sources = estimated_sources[..., border:-border]
144 | 
145 |     if config.training.target_instrument is None:
146 |         return {k: v for k, v in zip(config.training.instruments, estimated_sources)}
147 |     else:
148 |         return {k: v for k, v in zip([config.training.target_instrument], estimated_sources)}
149 | 
150 | 
151 | def demix_track_demucs(config, model, mix, device):
152 |     S = len(config.training.instruments)
153 |     C = config.training.samplerate * config.training.segment
154 |     N = config.inference.num_overlap
155 |     batch_size = config.inference.batch_size
156 |     step = C // N
157 |     # print(S, C, N, step, 
mix.shape, mix.device) 158 | 159 | with torch.cuda.amp.autocast(enabled=config.training.use_amp): 160 | with torch.inference_mode(): 161 | req_shape = (S, ) + tuple(mix.shape) 162 | result = torch.zeros(req_shape, dtype=torch.float32) 163 | counter = torch.zeros(req_shape, dtype=torch.float32) 164 | i = 0 165 | batch_data = [] 166 | batch_locations = [] 167 | while i < mix.shape[1]: 168 | # print(i, i + C, mix.shape[1]) 169 | part = mix[:, i:i + C].to(device) 170 | length = part.shape[-1] 171 | if length < C: 172 | part = nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0) 173 | batch_data.append(part) 174 | batch_locations.append((i, length)) 175 | i += step 176 | 177 | if len(batch_data) >= batch_size or (i >= mix.shape[1]): 178 | arr = torch.stack(batch_data, dim=0) 179 | x = model(arr) 180 | for j in range(len(batch_locations)): 181 | start, l = batch_locations[j] 182 | result[..., start:start+l] += x[j][..., :l].cpu() 183 | counter[..., start:start+l] += 1. 184 | batch_data = [] 185 | batch_locations = [] 186 | 187 | estimated_sources = result / counter 188 | estimated_sources = estimated_sources.cpu().numpy() 189 | np.nan_to_num(estimated_sources, copy=False, nan=0.0) 190 | 191 | if S > 1: 192 | return {k: v for k, v in zip(config.training.instruments, estimated_sources)} 193 | else: 194 | return estimated_sources 195 | 196 | 197 | def sdr(references, estimates): 198 | # compute SDR for one song 199 | delta = 1e-7 # avoid numerical errors 200 | num = np.sum(np.square(references), axis=(1, 2)) 201 | den = np.sum(np.square(references - estimates), axis=(1, 2)) 202 | num += delta 203 | den += delta 204 | return 10 * np.log10(num / den) 205 | -------------------------------------------------------------------------------- /mss/models/scnet_unofficial/modules/su_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.scnet_unofficial.utils import get_convtranspose_output_padding 7 | 8 | 9 | class FusionLayer(nn.Module): 10 | """ 11 | FusionLayer class implements a module for fusing two input tensors using convolutional operations. 12 | 13 | Args: 14 | - input_dim (int): Dimensionality of the input channels. 15 | - kernel_size (int, optional): Kernel size for the convolutional layer. Default is 3. 16 | - stride (int, optional): Stride value for the convolutional layer. Default is 1. 17 | - padding (int, optional): Padding value for the convolutional layer. Default is 1. 18 | 19 | Shapes: 20 | - Input: (B, F, T, C) and (B, F, T, C) where 21 | B is batch size, 22 | F is the number of features, 23 | T is sequence length, 24 | C is input dimensionality. 25 | - Output: (B, F, T, C) where 26 | B is batch size, 27 | F is the number of features, 28 | T is sequence length, 29 | C is input dimensionality. 30 | """ 31 | 32 | def __init__( 33 | self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1 34 | ): 35 | """ 36 | Initializes FusionLayer with input dimension, kernel size, stride, and padding. 37 | """ 38 | super().__init__() 39 | self.conv = nn.Conv2d( 40 | input_dim * 2, 41 | input_dim * 2, 42 | kernel_size=(kernel_size, 1), 43 | stride=(stride, 1), 44 | padding=(padding, 0), 45 | ) 46 | self.activation = nn.GLU() 47 | 48 | def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: 49 | """ 50 | Performs forward pass through the FusionLayer. 
51 | 52 | Args: 53 | - x1 (torch.Tensor): First input tensor of shape (B, F, T, C). 54 | - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C). 55 | 56 | Returns: 57 | - torch.Tensor: Output tensor of shape (B, F, T, C). 58 | """ 59 | x = x1 + x2 60 | x = x.repeat(1, 1, 1, 2) 61 | x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 62 | x = self.activation(x) 63 | return x 64 | 65 | 66 | class Upsample(nn.Module): 67 | """ 68 | Upsample class implements a module for upsampling input tensors using transposed 2D convolution. 69 | 70 | Args: 71 | - input_dim (int): Dimensionality of the input channels. 72 | - output_dim (int): Dimensionality of the output channels. 73 | - stride (int): Stride value for the transposed convolution operation. 74 | - output_padding (int): Output padding value for the transposed convolution operation. 75 | 76 | Shapes: 77 | - Input: (B, C_in, F, T) where 78 | B is batch size, 79 | C_in is the number of input channels, 80 | F is the frequency dimension, 81 | T is the time dimension. 82 | - Output: (B, C_out, F * stride + output_padding, T) where 83 | B is batch size, 84 | C_out is the number of output channels, 85 | F * stride + output_padding is the upsampled frequency dimension. 86 | """ 87 | 88 | def __init__( 89 | self, input_dim: int, output_dim: int, stride: int, output_padding: int 90 | ): 91 | """ 92 | Initializes Upsample with input dimension, output dimension, stride, and output padding. 93 | """ 94 | super().__init__() 95 | self.conv = nn.ConvTranspose2d( 96 | input_dim, output_dim, 1, (stride, 1), output_padding=(output_padding, 0) 97 | ) 98 | 99 | def forward(self, x: torch.Tensor) -> torch.Tensor: 100 | """ 101 | Performs forward pass through the Upsample module. 102 | 103 | Args: 104 | - x (torch.Tensor): Input tensor of shape (B, C_in, F, T). 105 | 106 | Returns: 107 | - torch.Tensor: Output tensor of shape (B, C_out, F * stride + output_padding, T). 108 | """ 109 | return self.conv(x) 110 | 111 | 112 | class SULayer(nn.Module): 113 | """ 114 | SULayer class implements a subband upsampling layer using transposed convolution. 115 | 116 | Args: 117 | - input_dim (int): Dimensionality of the input channels. 118 | - output_dim (int): Dimensionality of the output channels. 119 | - upsample_stride (int): Stride value for the upsampling operation. 120 | - subband_shape (int): Shape of the subband. 121 | - sd_interval (Tuple[int, int]): Start and end indices of the subband interval. 122 | 123 | Shapes: 124 | - Input: (B, F, T, C) where 125 | B is batch size, 126 | F is the number of features, 127 | T is sequence length, 128 | C is input dimensionality. 129 | - Output: (B, F, T, C) where 130 | B is batch size, 131 | F is the number of features, 132 | T is sequence length, 133 | C is input dimensionality. 134 | """ 135 | 136 | def __init__( 137 | self, 138 | input_dim: int, 139 | output_dim: int, 140 | upsample_stride: int, 141 | subband_shape: int, 142 | sd_interval: Tuple[int, int], 143 | ): 144 | """ 145 | Initializes SULayer with input dimension, output dimension, upsample stride, subband shape, and subband interval. 
146 | """ 147 | super().__init__() 148 | sd_shape = sd_interval[1] - sd_interval[0] 149 | upsample_output_padding = get_convtranspose_output_padding( 150 | input_shape=sd_shape, output_shape=subband_shape, stride=upsample_stride 151 | ) 152 | self.upsample = Upsample( 153 | input_dim=input_dim, 154 | output_dim=output_dim, 155 | stride=upsample_stride, 156 | output_padding=upsample_output_padding, 157 | ) 158 | self.sd_interval = sd_interval 159 | 160 | def forward(self, x: torch.Tensor) -> torch.Tensor: 161 | """ 162 | Performs forward pass through the SULayer. 163 | 164 | Args: 165 | - x (torch.Tensor): Input tensor of shape (B, F, T, C). 166 | 167 | Returns: 168 | - torch.Tensor: Output tensor of shape (B, F, T, C). 169 | """ 170 | x = x[:, self.sd_interval[0] : self.sd_interval[1]] 171 | x = x.permute(0, 3, 1, 2) 172 | x = self.upsample(x) 173 | x = x.permute(0, 2, 3, 1) 174 | return x 175 | 176 | 177 | class SUBlock(nn.Module): 178 | """ 179 | SUBlock class implements a block with fusion layer and subband upsampling layers. 180 | 181 | Args: 182 | - input_dim (int): Dimensionality of the input channels. 183 | - output_dim (int): Dimensionality of the output channels. 184 | - upsample_strides (List[int]): List of stride values for the upsampling operations. 185 | - subband_shapes (List[int]): List of shapes for the subbands. 186 | - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition. 187 | 188 | Shapes: 189 | - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where 190 | B is batch size, 191 | Fi-1 is the number of input subbands, 192 | T is sequence length, 193 | Ci-1 is the number of input channels. 194 | - Output: (B, Fi, T, Ci) where 195 | B is batch size, 196 | Fi is the number of output subbands, 197 | T is sequence length, 198 | Ci is the number of output channels. 199 | """ 200 | 201 | def __init__( 202 | self, 203 | input_dim: int, 204 | output_dim: int, 205 | upsample_strides: List[int], 206 | subband_shapes: List[int], 207 | sd_intervals: List[Tuple[int, int]], 208 | ): 209 | """ 210 | Initializes SUBlock with input dimension, output dimension, 211 | upsample strides, subband shapes, and subband intervals. 212 | """ 213 | super().__init__() 214 | self.fusion_layer = FusionLayer(input_dim=input_dim) 215 | self.su_layers = nn.ModuleList( 216 | SULayer( 217 | input_dim=input_dim, 218 | output_dim=output_dim, 219 | upsample_stride=uss, 220 | subband_shape=sbs, 221 | sd_interval=sdi, 222 | ) 223 | for i, (uss, sbs, sdi) in enumerate( 224 | zip(upsample_strides, subband_shapes, sd_intervals) 225 | ) 226 | ) 227 | 228 | def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor: 229 | """ 230 | Performs forward pass through the SUBlock. 231 | 232 | Args: 233 | - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1). 234 | - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1). 235 | 236 | Returns: 237 | - torch.Tensor: Output tensor of shape (B, Fi, T, Ci). 
238 | """ 239 | x = self.fusion_layer(x, x_skip) 240 | x = torch.concat([layer(x) for layer in self.su_layers], dim=1) 241 | return x 242 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import urllib.request 4 | import torch 5 | import torchaudio 6 | import cuda_malloc 7 | import numpy as np 8 | import urllib 9 | import folder_paths 10 | from .mss.utils import get_model_from_config,demix_track,demix_track_demucs 11 | 12 | now_dir = os.path.dirname(os.path.abspath(__file__)) 13 | models_dir = os.path.join(folder_paths.models_dir,"AIFSH", "Music-Source-Separation-Training") 14 | os.makedirs(models_dir,exist_ok=True) 15 | 16 | base_url = "https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.0/" 17 | configs_dir = os.path.join(now_dir,"mss","configs") 18 | device = "cuda" if cuda_malloc.cuda_malloc_supported() else "cpu" 19 | config_path_dict = { 20 | 'htdemucs':["config_vocals_htdemucs.yaml","model_vocals_htdemucs_sdr_8.78.ckpt"], 21 | 'mdx23c':["config_vocals_mdx23c.yaml","model_vocals_mdx23c_sdr_10.17.ckpt",], 22 | 'segm_models':["config_vocals_segm_models.yaml","model_vocals_segm_models_sdr_9.77.ckpt"], 23 | 'mel_band_roformer':["model_mel_band_roformer_ep_3005_sdr_11.4360.yaml","https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"], 24 | 'bs_roformer':["model_bs_roformer_ep_317_sdr_12.9755.yaml","https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_317_sdr_12.9755.ckpt"], 25 | # 'swin_upernet':["config_vocals_swin_upernet.yaml","https://github.com/ZFTurbo/Music-Source-Separation-Training/releases/download/v1.0.2/model_swin_upernet_ep_56_sdr_10.6703.ckpt"] 26 | } 27 | 28 | def download_from_url(url,path): 29 | def progressbar(cur, cursize,totalsize): 30 | percent = '{:.2%}'.format(cur / 100) 31 | sys.stdout.write('\r') 32 | # sys.stdout.write("[%-50s] %s" % ('=' * int(math.floor(cur * 50 / total)),percent)) 33 | sys.stdout.write("Downloading [%-50s] %s %s/%sMB" % ('=' * int(cur), percent,cursize//1024//1024,totalsize//1024//1024)) 34 | sys.stdout.flush() 35 | 36 | 37 | def schedule(blocknum,blocksize,totalsize): 38 | """ 39 | blocknum:当前已经下载的块 40 | blocksize:每次传输的块大小 41 | totalsize:网页文件总大小 42 | """ 43 | if totalsize == 0: 44 | percent = 0 45 | else: 46 | percent = blocknum * blocksize / totalsize 47 | if percent > 1.0: 48 | percent = 1.0 49 | percent = percent * 100 50 | # print("download : %.2f%%" %(percent)) 51 | progressbar(percent,blocknum * blocksize,totalsize) 52 | urllib.request.urlretrieve(url,path,schedule) 53 | 54 | class VocalSeparationNode: 55 | def __init__(self) -> None: 56 | self.model_type = None 57 | self.model = None 58 | self.config = None 59 | @classmethod 60 | def INPUT_TYPES(s): 61 | return { 62 | "required":{ 63 | "music":("AUDIO",), 64 | "model_type":(['htdemucs','mdx23c','segm_models', 65 | 'mel_band_roformer','bs_roformer'],{ 66 | "default": 'bs_roformer' 67 | }), 68 | "batch_size":("INT",{ 69 | "default": 4 70 | }), 71 | "if_mirror":("BOOLEAN",{ 72 | "default": True 73 | }) 74 | } 75 | } 76 | 77 | RETURN_TYPES = ("AUDIO","AUDIO",) 78 | RETURN_NAMES = ("vocals_AUDIO","instrumental_AUDIO",) 79 | 80 | FUNCTION = "separate" 81 | 82 | #OUTPUT_NODE = False 83 | 84 | CATEGORY = "AIFSH_VocalSeparation" 85 | 86 | def 
separate(self,music,model_type,batch_size,if_mirror):
 87 |         torch.backends.cudnn.benchmark = True
 88 |         if model_type in ['mel_band_roformer','bs_roformer']:
 89 |             config_path = os.path.join(configs_dir,"viperx",config_path_dict[model_type][0])
 90 |             model_url = ("https://mirror.ghproxy.com/" if if_mirror else "") + config_path_dict[model_type][1]
 91 |             model_path = os.path.join(models_dir,model_url.split("/")[-1])
 92 |         else:
 93 |             config_path = os.path.join(configs_dir,config_path_dict[model_type][0])
 94 |             model_url = ("https://mirror.ghproxy.com/" if if_mirror else "") + base_url + config_path_dict[model_type][1]
 95 |             model_path = os.path.join(models_dir,config_path_dict[model_type][1])
 96 | 
 97 |         if not os.path.isfile(model_path):
 98 |             print(f"Downloading {model_path} from {model_url}")
 99 |             download_from_url(model_url,model_path)
100 | 
101 |         if self.model_type != model_type:
102 |             self.model_type = model_type
103 |             self.model, self.config = get_model_from_config(model_type, config_path)
104 |             print('Start from checkpoint: {}'.format(model_path))
105 |             state_dict = torch.load(model_path)
106 |             if model_type == 'htdemucs':
107 |                 # Fix for htdemucs pretrained models
108 |                 if 'state' in state_dict:
109 |                     state_dict = state_dict['state']
110 |             self.model.load_state_dict(state_dict)
111 | 
112 |         print("Instruments: {}".format(self.config.training.instruments))
113 |         self.model.to(device)
114 |         self.model.eval()
115 |         self.config.inference.batch_size = batch_size
116 | 
117 |         audio_data = music["waveform"].squeeze(0)
118 |         audio_rate = music['sample_rate']
119 |         target_sr = 44100
120 |         if audio_rate != target_sr:
121 |             audio_data = torchaudio.transforms.Resample(audio_rate,target_sr)(audio_data)
122 | 
123 |         mix = audio_data.numpy()[0]
124 |         # print(mix.shape)
125 | 
126 |         # Convert mono to stereo if needed
127 |         if len(mix.shape) == 1:
128 |             mix = np.stack([mix, mix], axis=0)
129 |         # print(mix.shape)
130 |         mix_orig = mix.copy()
131 |         if 'normalize' in self.config.inference:
132 |             if self.config.inference['normalize'] is True:
133 |                 mono = mix.mean(0)
134 |                 mean = mono.mean()
135 |                 std = mono.std()
136 |                 mix = (mix - mean) / std
137 |         mixture = torch.tensor(mix, dtype=torch.float32)
138 | 
139 |         if self.model_type == 'htdemucs':
140 |             res = demix_track_demucs(self.config, self.model, mixture, device)
141 |         else:
142 |             res = demix_track(self.config, self.model, mixture, device)
143 | 
144 |         estimates = res['vocals'].T
145 |         # print(estimates.shape)
146 |         if 'normalize' in self.config.inference:
147 |             if self.config.inference['normalize'] is True:
148 |                 estimates = estimates * std + mean
149 | 
150 |         estimates_mono = estimates.mean(1)
151 |         estimates_t = torch.tensor(estimates_mono,dtype=torch.float32).unsqueeze(0).unsqueeze(0)
152 |         vocals = {
153 |             "waveform":estimates_t,
154 |             "sample_rate": target_sr
155 |         }
156 |         instru_mono = (mix_orig.T-estimates).mean(1)
157 |         instru_t = torch.Tensor(instru_mono).unsqueeze(0).unsqueeze(0)
158 |         instrumental = {
159 |             "waveform":instru_t,
160 |             "sample_rate": target_sr
161 |         }
162 |         self.model.to("cpu")
163 |         torch.backends.cudnn.benchmark = False
164 |         return (vocals,instrumental,)
165 | 
166 | class CombineAudioNode:
167 |     @classmethod
168 |     def INPUT_TYPES(s):
169 |         return {
170 |             "required":{
171 |                 "vocal":("AUDIO",),
172 |                 "instrumental":("AUDIO",)
173 |             }
174 |         }
175 | 
176 |     RETURN_TYPES = ("AUDIO",)
177 | 
178 |     FUNCTION = "combine"
179 | 
180 |     #OUTPUT_NODE = False
181 | 
182 |     CATEGORY = "AIFSH_VocalSeparation"
183 | 
184 |     def
audio2numpy(self,audio,target_sr): 185 | audio_data = audio["waveform"].squeeze(0) 186 | audio_rate = audio['sample_rate'] 187 | 188 | if audio_rate != target_sr: 189 | audio_data = torchaudio.transforms.Resample(audio_rate,target_sr)(audio_data) 190 | 191 | mix = audio_data.numpy()[0] 192 | print(mix.shape) 193 | return mix 194 | 195 | def combine(self,vocal,instrumental): 196 | target_sr = 44100 197 | vocal_np = self.audio2numpy(vocal,target_sr) 198 | instrumental_np = self.audio2numpy(instrumental,target_sr) 199 | dur = vocal_np.shape[0] - instrumental_np.shape[0] 200 | slient_np = np.zeros(abs(dur)) 201 | if dur < 0: 202 | vocal_np = np.concatenate([vocal_np, slient_np])[:instrumental_np.shape[0]] 203 | else: 204 | instrumental_np = np.concatenate([instrumental_np,slient_np])[:vocal_np.shape[0]] 205 | 206 | total_np = vocal_np + instrumental_np 207 | 208 | res_audio = { 209 | "waveform":torch.tensor(total_np).unsqueeze(0).unsqueeze(0), 210 | "sample_rate": target_sr 211 | } 212 | return (res_audio,) 213 | 214 | 215 | NODE_CLASS_MAPPINGS = { 216 | "CombineAudioNode":CombineAudioNode, 217 | "VocalSeparationNode": VocalSeparationNode 218 | } 219 | -------------------------------------------------------------------------------- /mss/models/bandit/core/data/musdb/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC 3 | from typing import List, Optional, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import torchaudio as ta 8 | from torch.utils import data 9 | 10 | from models.bandit.core.data._types import AudioDict, DataDict 11 | from models.bandit.core.data.base import BaseSourceSeparationDataset 12 | 13 | 14 | class MUSDB18BaseDataset(BaseSourceSeparationDataset, ABC): 15 | 16 | ALLOWED_STEMS = ["mixture", "vocals", "bass", "drums", "other"] 17 | 18 | def __init__( 19 | self, 20 | split: str, 21 | stems: List[str], 22 | files: List[str], 23 | data_path: str, 24 | fs: int = 44100, 25 | npy_memmap=False, 26 | ) -> None: 27 | super().__init__( 28 | split=split, 29 | stems=stems, 30 | files=files, 31 | data_path=data_path, 32 | fs=fs, 33 | npy_memmap=npy_memmap, 34 | recompute_mixture=False 35 | ) 36 | 37 | def get_stem(self, *, stem: str, identifier) -> torch.Tensor: 38 | track = identifier["track"] 39 | path = os.path.join(self.data_path, track) 40 | # noinspection PyUnresolvedReferences 41 | 42 | if self.npy_memmap: 43 | audio = np.load(os.path.join(path, f"{stem}.wav.npy"), mmap_mode="r") 44 | else: 45 | audio, _ = ta.load(os.path.join(path, f"{stem}.wav")) 46 | 47 | return audio 48 | 49 | def get_identifier(self, index): 50 | return dict(track=self.files[index]) 51 | 52 | def __getitem__(self, index: int) -> DataDict: 53 | identifier = self.get_identifier(index) 54 | audio = self.get_audio(identifier) 55 | 56 | return {"audio": audio, "track": f"{self.split}/{identifier['track']}"} 57 | 58 | 59 | class MUSDB18FullTrackDataset(MUSDB18BaseDataset): 60 | 61 | N_TRAIN_TRACKS = 100 62 | N_TEST_TRACKS = 50 63 | VALIDATION_FILES = [ 64 | "Actions - One Minute Smile", 65 | "Clara Berry And Wooldog - Waltz For My Victims", 66 | "Johnny Lokke - Promises & Lies", 67 | "Patrick Talbot - A Reason To Leave", 68 | "Triviul - Angelsaint", 69 | "Alexander Ross - Goodbye Bolero", 70 | "Fergessen - Nos Palpitants", 71 | "Leaf - Summerghost", 72 | "Skelpolu - Human Mistakes", 73 | "Young Griffo - Pennies", 74 | "ANiMAL - Rockshow", 75 | "James May - On The Line", 76 | "Meaxic - Take A Step", 77 | "Traffic Experiment - 
Sirens", 78 | ] 79 | 80 | def __init__( 81 | self, data_root: str, split: str, stems: Optional[List[ 82 | str]] = None 83 | ) -> None: 84 | 85 | if stems is None: 86 | stems = self.ALLOWED_STEMS 87 | self.stems = stems 88 | 89 | if split == "test": 90 | subset = "test" 91 | elif split in ["train", "val"]: 92 | subset = "train" 93 | else: 94 | raise NameError 95 | 96 | data_path = os.path.join(data_root, subset) 97 | 98 | files = sorted(os.listdir(data_path)) 99 | files = [f for f in files if not f.startswith(".")] 100 | # pprint(list(enumerate(files))) 101 | if subset == "train": 102 | assert len(files) == 100, len(files) 103 | if split == "train": 104 | files = [f for f in files if f not in self.VALIDATION_FILES] 105 | assert len(files) == 100 - len(self.VALIDATION_FILES) 106 | else: 107 | files = [f for f in files if f in self.VALIDATION_FILES] 108 | assert len(files) == len(self.VALIDATION_FILES) 109 | else: 110 | split = "test" 111 | assert len(files) == 50 112 | 113 | self.n_tracks = len(files) 114 | 115 | super().__init__( 116 | data_path=data_path, 117 | split=split, 118 | stems=stems, 119 | files=files 120 | ) 121 | 122 | def __len__(self) -> int: 123 | return self.n_tracks 124 | 125 | class MUSDB18SadDataset(MUSDB18BaseDataset): 126 | def __init__( 127 | self, 128 | data_root: str, 129 | split: str, 130 | target_stem: str, 131 | stems: Optional[List[str]] = None, 132 | target_length: Optional[int] = None, 133 | npy_memmap=False, 134 | ) -> None: 135 | 136 | if stems is None: 137 | stems = self.ALLOWED_STEMS 138 | 139 | data_path = os.path.join(data_root, target_stem, split) 140 | 141 | files = sorted(os.listdir(data_path)) 142 | files = [f for f in files if not f.startswith(".")] 143 | 144 | super().__init__( 145 | data_path=data_path, 146 | split=split, 147 | stems=stems, 148 | files=files, 149 | npy_memmap=npy_memmap 150 | ) 151 | self.n_segments = len(files) 152 | self.target_stem = target_stem 153 | self.target_length = ( 154 | target_length if target_length is not None else self.n_segments 155 | ) 156 | 157 | def __len__(self) -> int: 158 | return self.target_length 159 | 160 | def __getitem__(self, index: int) -> DataDict: 161 | 162 | index = index % self.n_segments 163 | 164 | return super().__getitem__(index) 165 | 166 | def get_identifier(self, index): 167 | return super().get_identifier(index % self.n_segments) 168 | 169 | 170 | class MUSDB18SadOnTheFlyAugmentedDataset(MUSDB18SadDataset): 171 | def __init__( 172 | self, 173 | data_root: str, 174 | split: str, 175 | target_stem: str, 176 | stems: Optional[List[str]] = None, 177 | target_length: int = 20000, 178 | apply_probability: Optional[float] = None, 179 | chunk_size_second: float = 3.0, 180 | random_scale_range_db: Tuple[float, float] = (-10, 10), 181 | drop_probability: float = 0.1, 182 | rescale: bool = True, 183 | ) -> None: 184 | super().__init__(data_root, split, target_stem, stems) 185 | 186 | if apply_probability is None: 187 | apply_probability = ( 188 | target_length - self.n_segments) / target_length 189 | 190 | self.apply_probability = apply_probability 191 | self.drop_probability = drop_probability 192 | self.chunk_size_second = chunk_size_second 193 | self.random_scale_range_db = random_scale_range_db 194 | self.rescale = rescale 195 | 196 | self.chunk_size_sample = int(self.chunk_size_second * self.fs) 197 | self.target_length = target_length 198 | 199 | def __len__(self) -> int: 200 | return self.target_length 201 | 202 | def __getitem__(self, index: int) -> DataDict: 203 | 204 | index = index % 
self.n_segments 205 | 206 | # if np.random.rand() > self.apply_probability: 207 | # return super().__getitem__(index) 208 | 209 | audio = {} 210 | identifier = self.get_identifier(index) 211 | 212 | # assert self.target_stem in self.stems_no_mixture 213 | for stem in self.stems_no_mixture: 214 | if stem == self.target_stem: 215 | identifier_ = identifier 216 | else: 217 | if np.random.rand() < self.apply_probability: 218 | index_ = np.random.randint(self.n_segments) 219 | identifier_ = self.get_identifier(index_) 220 | else: 221 | identifier_ = identifier 222 | 223 | audio[stem] = self.get_stem(stem=stem, identifier=identifier_) 224 | 225 | # if stem == self.target_stem: 226 | 227 | if self.chunk_size_sample < audio[stem].shape[-1]: 228 | chunk_start = np.random.randint( 229 | audio[stem].shape[-1] - self.chunk_size_sample 230 | ) 231 | else: 232 | chunk_start = 0 233 | 234 | if np.random.rand() < self.drop_probability: 235 | # db_scale = "-inf" 236 | linear_scale = 0.0 237 | else: 238 | db_scale = np.random.uniform(*self.random_scale_range_db) 239 | linear_scale = np.power(10, db_scale / 20) 240 | # db_scale = f"{db_scale:+2.1f}" 241 | # print(linear_scale) 242 | audio[stem][..., 243 | chunk_start: chunk_start + self.chunk_size_sample] = ( 244 | linear_scale 245 | * audio[stem][..., 246 | chunk_start: chunk_start + self.chunk_size_sample] 247 | ) 248 | 249 | audio["mixture"] = self.compute_mixture(audio) 250 | 251 | if self.rescale: 252 | max_abs_val = max( 253 | [torch.max(torch.abs(audio[stem])) for stem in self.stems] 254 | ) # type: ignore[type-var] 255 | if max_abs_val > 1: 256 | audio = {k: v / max_abs_val for k, v in audio.items()} 257 | 258 | track = identifier["track"] 259 | 260 | return {"audio": audio, "track": f"{self.split}/{track}"} 261 | 262 | # if __name__ == "__main__": 263 | # 264 | # from pprint import pprint 265 | # from tqdm import tqdm 266 | # 267 | # for split_ in ["train", "val", "test"]: 268 | # ds = MUSDB18SadOnTheFlyAugmentedDataset( 269 | # data_root="$DATA_ROOT/MUSDB18/HQ/saded", 270 | # split=split_, 271 | # target_stem="vocals" 272 | # ) 273 | # 274 | # print(split_, len(ds)) 275 | # 276 | # for track_ in tqdm(ds): 277 | # track_["audio"] = { 278 | # k: v.shape for k, v in track_["audio"].items() 279 | # } 280 | # pprint(track_) 281 | --------------------------------------------------------------------------------
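
For orientation, the sketch below shows how the pieces dumped above are typically wired together outside of ComfyUI: `get_model_from_config` and `demix_track` from `mss/utils.py` drive the separation, in the same way `VocalSeparationNode.separate` does. This is a minimal, hypothetical example, not a file from this repository; the config and checkpoint paths, the input file name, and the `'vocals'` stem key are placeholders, and it assumes the repository root is importable so that `mss.utils` resolves as a package.

```python
# Hypothetical standalone usage sketch (not part of the repository).
# Assumes: repo root on sys.path, a locally downloaded checkpoint, and a stereo input file.
import torch
import torchaudio

from mss.utils import get_model_from_config, demix_track

CONFIG = "mss/configs/config_vocals_mdx23c.yaml"               # placeholder path
CKPT = "pretrained_models/model_vocals_mdx23c_sdr_10.17.ckpt"  # placeholder path
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model from its YAML config, then load the checkpoint weights.
model, config = get_model_from_config("mdx23c", CONFIG)
model.load_state_dict(torch.load(CKPT, map_location="cpu"))
model.to(device).eval()

wave, sr = torchaudio.load("song.wav")                         # (channels, samples)
if sr != 44100:
    wave = torchaudio.transforms.Resample(sr, 44100)(wave)
if wave.shape[0] == 1:                                         # demix_track expects stereo input
    wave = wave.repeat(2, 1)

# demix_track returns {instrument_name: np.ndarray of shape (channels, samples)}.
stems = demix_track(config, model, wave.to(torch.float32), device)
torchaudio.save("vocals.wav", torch.from_numpy(stems["vocals"]), 44100)
```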