├── .gitignore ├── LICENSE ├── README.md ├── data └── .gitkeep ├── images └── architecture.png ├── logs └── .gitkeep ├── requirements.txt ├── scripts └── .gitkeep └── src ├── conf ├── callbacks │ └── default.yaml ├── dataset │ └── default.yaml ├── hydra │ └── default.yaml ├── logger │ └── default.yaml ├── loss │ └── default.yaml ├── metrics │ └── default.yaml ├── optimizer │ └── default.yaml ├── scheduler │ └── .gitkeep ├── separator │ ├── default.yaml │ └── net │ │ └── default.yaml └── train.yaml ├── data ├── dataset.py └── utils.py ├── evaluate.py ├── inference.py ├── model ├── modules │ ├── __init__.py │ ├── dualpath_rnn.py │ ├── sd_encoder.py │ └── su_decoder.py ├── scnet.py ├── separator.py └── utils.py ├── train.py └── utils ├── lightning.py ├── loss.py └── metrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | **/*.ipynb_checkpoints/ 3 | **/*.pt 4 | **/*.ipynb 5 | *.pyc 6 | .DS_Store 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Amantur Amatov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SCNet-Pytorch 2 | 3 | Unofficial PyTorch implementation of the paper 4 | ["SCNet: Sparse Compression Network for Music Source Separation"](https://arxiv.org/abs/2401.13276.pdf). 5 | 6 | ![architecture](images/architecture.png) 7 | 8 | --- 9 | ## Table of Contents 10 | 11 | 1. [Changelog & ToDo's](#changelog) 12 | 2. [Dependencies](#dependencies) 13 | 3. [Training](#train) 14 | 4. [Inference](#inference) 15 | 5. [Evaluation](#eval) 16 | 6. [Repository structure](#structure) 17 | 7. [Citing](#cite) 18 | 19 | --- 20 | 21 | 22 | # Changelog 23 | 24 | - **10.02.2024** 25 | - Model itself is finished. The train script is on its way. 26 | - **21.02.2024** 27 | - Add part of the training pipeline. 28 | - **02.03.2024** 29 | - Finish the training pipeline and the separator. 30 | - **17.03.2024** 31 | - Finish inference.py and fill README.md 32 | - **13.04.2024** 33 | - Finish evaluation pipeline and fill README.md. 34 | 35 | ## ToDo's 36 | 37 | - Add trained model. 
38 | 
39 | ---
40 | 
41 | 
42 | # Dependencies
43 | 
44 | Before starting training, you need to install the requirements:
45 | 
46 | ```bash
47 | pip install -r requirements.txt
48 | ```
49 | 
50 | Then, download the MUSDB18HQ dataset:
51 | 
52 | ```bash
53 | wget -P /path/to/dataset https://zenodo.org/records/3338373/files/musdb18hq.zip
54 | unzip /path/to/dataset/musdb18hq.zip -d /path/to/dataset
55 | ```
56 | 
57 | Next, create environment variables with the paths to the audio data and the generated metadata `.pqt` file:
58 | 
59 | ```bash
60 | export DATASET_DIR=/path/to/dataset/musdb18hq
61 | export DATASET_PATH=/path/to/dataset/dataset.pqt
62 | ```
63 | 
64 | Finally, select which GPU should be visible:
65 | ```bash
66 | export CUDA_VISIBLE_DEVICES=0
67 | ```
68 | 
69 | Now, you can train the model.
70 | 
71 | ---
72 | 
73 | 
74 | # Training
75 | 
76 | Training combines `PyTorch-Lightning` and `hydra`.
77 | All configuration files are stored in the `src/conf` directory in a `hydra`-friendly format.
78 | 
79 | To start training a model with the given configuration, use the following script:
80 | ```
81 | python src/train.py
82 | ```
83 | To configure the training process, follow the `hydra` [instructions](https://hydra.cc/docs/advanced/override_grammar/basic/).
84 | You can modify/override the arguments by doing something like this:
85 | ```
86 | python src/train.py +trainer.overfit_batches=10 loader.train.batch_size=16
87 | ```
88 | 
89 | Once training starts, a logging folder for the experiment is created at the following path:
90 | ```
91 | logs/scnet/${now:%Y-%m-%d}_${now:%H-%M}/
92 | ```
93 | This folder has the following structure:
94 | ```
95 | ├── checkpoints
96 | │   └── *.ckpt - Lightning model checkpoint files
97 | ├── tensorboard
98 | │   └── tensorboard_log_file - main TensorBoard log file
99 | ├── yamls
100 | │   └── *.yaml - hydra configuration and override files
101 | └── train.log - logging file for train.py
102 | 
103 | ```
104 | 
105 | ---
106 | 
107 | 
108 | # Inference
109 | 
110 | After training a model, you can run inference using the following command:
111 | 
112 | ```bash
113 | python src/inference.py -i <INPUT_PATH> \
114 |                         -o <OUTPUT_PATH> \
115 |                         -c <CKPT_PATH>
116 | ```
117 | 
118 | This command will generate separated audio files in .wav format in the `<OUTPUT_PATH>` directory.
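For example, with a trained checkpoint and a single song as input (all paths below are illustrative), a full invocation could look like this:

```bash
# Paths are illustrative - point them at your own audio file and checkpoint.
python src/inference.py \
    -i /path/to/audio/song.wav \
    -o /path/to/output \
    -c logs/scnet/2024-04-13_10-00/checkpoints/epoch99-val_usdr9.00.ckpt \
    -d cuda -b 4 -w 11 -s 5.5 -p
```

For an input file named `song.wav`, the stems are written to `/path/to/output/song/` as `drums.wav`, `bass.wav`, `other.wav`, and `vocals.wav`.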
119 | 
120 | For more information about the script and its options, use:
121 | ```bash
122 | usage: inference.py [-h] -i INPUT_PATH -o OUTPUT_PATH -c CKPT_PATH [-d DEVICE] [-b BATCH_SIZE] [-w WINDOW_SIZE] [-s STEP_SIZE] [-p]
123 | 
124 | Argument Parser for Separator
125 | 
126 | optional arguments:
127 |   -h, --help            show this help message and exit
128 |   -i INPUT_PATH, --input-path INPUT_PATH
129 |                         Input path to .wav audio file/directory containing audio files
130 |   -o OUTPUT_PATH, --output-path OUTPUT_PATH
131 |                         Output directory to save separated audio files in .wav format
132 |   -c CKPT_PATH, --ckpt-path CKPT_PATH
133 |                         Path to the model checkpoint
134 |   -d DEVICE, --device DEVICE
135 |                         Device to run the model on (default: cuda)
136 |   -b BATCH_SIZE, --batch-size BATCH_SIZE
137 |                         Batch size for processing (default: 4)
138 |   -w WINDOW_SIZE, --window-size WINDOW_SIZE
139 |                         Window size (default: 11)
140 |   -s STEP_SIZE, --step-size STEP_SIZE
141 |                         Step size (default: 5.5)
142 |   -p, --use-progress-bar
143 |                         Use progress bar (default: True)
144 | ```
145 | 
146 | Additionally, you can run inference within Python using the following script:
147 | ```python
148 | import sys
149 | sys.path.append('src/')
150 | 
151 | import torchaudio
152 | from model.separator import Separator
153 | 
154 | device: str = 'cuda'
155 | 
156 | separator = Separator.load_from_checkpoint(
157 |     path="<CKPT_PATH>",     # path to the trained Lightning checkpoint
158 |     batch_size=4,           # adjust the batch size to fit into your GPU's memory
159 |     window_size=11,         # window size of the model (do not change)
160 |     step_size=5.5,          # the closer the step size is to the window size, the faster inference runs, but separation quality degrades
161 |     use_progress_bar=True   # show a progress bar per audio file
162 | ).to(device)
163 | 
164 | y, sr = torchaudio.load("<AUDIO_PATH>")
165 | y = y.to(device)
166 | 
167 | y_separated = separator.separate(y).cpu()
168 | ```
169 | 
170 | Make sure to replace the `<...>` placeholders with the appropriate paths for your setup.
171 | 
172 | ---
173 | 
174 | 
175 | # Evaluation
176 | 
177 | After training a model, you can run the evaluation pipeline using the following command:
178 | 
179 | ```bash
180 | python src/evaluate.py -c <CKPT_PATH>
181 | ```
182 | 
183 | This script loads the model from the checkpoint at `<CKPT_PATH>` and runs inference on the `test` tracks listed in the `DATASET_PATH` metadata file.
184 | 
185 | As a result, the script prints the mean SDR for each source to the console.
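For example, `evaluate.py` reads the metadata `.pqt` file from the `DATASET_PATH` environment variable and evaluates every `test` track listed in it; a complete run could look like this (the checkpoint path is illustrative):

```bash
# DATASET_PATH must point to the metadata parquet file generated from MUSDB18HQ.
export DATASET_PATH=/path/to/dataset/dataset.pqt

# The checkpoint path is illustrative - use your own trained checkpoint.
python src/evaluate.py \
    -c logs/scnet/2024-04-13_10-00/checkpoints/epoch99-val_usdr9.00.ckpt \
    -d cuda -b 4 -w 11 -s 5.5
```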
186 | 187 | For more information about the script and its options, use: 188 | 189 | ```bash 190 | usage: evaluate.py [-h] -c CKPT_PATH [-d DEVICE] [-b BATCH_SIZE] [-w WINDOW_SIZE] [-s STEP_SIZE] 191 | 192 | Argument Parser for Separator 193 | 194 | optional arguments: 195 | -h, --help show this help message and exit 196 | -c CKPT_PATH, --ckpt-path CKPT_PATH 197 | Path to the model checkpoint 198 | -d DEVICE, --device DEVICE 199 | Device to run the model on (default: cuda) 200 | -b BATCH_SIZE, --batch-size BATCH_SIZE 201 | Batch size for processing (default: 4) 202 | -w WINDOW_SIZE, --window-size WINDOW_SIZE 203 | Window size (default: 11) 204 | -s STEP_SIZE, --step-size STEP_SIZE 205 | Step size (default: 5.5) 206 | ``` 207 | 208 | --- 209 | 210 | 211 | # Citing 212 | 213 | To cite this paper, please use: 214 | ``` 215 | @misc{tong2024scnet, 216 | title={SCNet: Sparse Compression Network for Music Source Separation}, 217 | author={Weinan Tong and Jiaxu Zhu and Jun Chen and Shiyin Kang and Tao Jiang and Yang Li and Zhiyong Wu and Helen Meng}, 218 | year={2024}, 219 | eprint={2401.13276}, 220 | archivePrefix={arXiv}, 221 | primaryClass={eess.AS} 222 | } 223 | ``` 224 | 225 | 226 | 227 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanteur/SCNet-PyTorch/69bede853c7144e82bbe5aeb6432d8f43d2fce48/data/.gitkeep -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanteur/SCNet-PyTorch/69bede853c7144e82bbe5aeb6432d8f43d2fce48/images/architecture.png -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanteur/SCNet-PyTorch/69bede853c7144e82bbe5aeb6432d8f43d2fce48/logs/.gitkeep -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core~=1.3.2 2 | omegaconf~=2.2.3 3 | pandas~=1.5.3 4 | torch~=2.1.2 5 | torchaudio~=2.1.2 6 | lightning~=2.2.0 7 | soundfile==0.12.1 8 | tqdm==4.64.1 9 | fast-bss-eval==0.1.4 -------------------------------------------------------------------------------- /scripts/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanteur/SCNet-PyTorch/69bede853c7144e82bbe5aeb6432d8f43d2fce48/scripts/.gitkeep -------------------------------------------------------------------------------- /src/conf/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | lr_monitor: 2 | _target_: lightning.pytorch.callbacks.LearningRateMonitor 3 | logging_interval: epoch 4 | model_ckpt: 5 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 6 | monitor: val/usdr 7 | mode: max 8 | save_top_k: 3 9 | dirpath: ${hydra:runtime.output_dir}/checkpoints 10 | filename: epoch{epoch:02d}-val_usdr{val/usdr:.2f} 11 | auto_insert_metric_name: False -------------------------------------------------------------------------------- /src/conf/dataset/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: 
data.dataset.SourceSeparationDataset 2 | subset: train 3 | dataset_dirs: ${oc.env:DATASET_DIR} 4 | dataset_path: ${oc.env:DATASET_PATH} 5 | dataset_extension: wav 6 | window_size: 11 7 | step_size: 1 8 | sample_rate: 44100 9 | sources: ${sources} -------------------------------------------------------------------------------- /src/conf/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | job: 2 | name: scnet 3 | output_subdir: yamls 4 | run: 5 | dir: logs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M} 6 | sweep: 7 | dir: logs/${hydra.job.name}_multirun/${now:%Y-%m-%d}_${now:%H-%M} 8 | subdir: ${hydra.job.num} 9 | job_logging: 10 | handlers: 11 | file: 12 | filename: ${hydra.runtime.output_dir}/train.log -------------------------------------------------------------------------------- /src/conf/logger/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.loggers.TensorBoardLogger 2 | save_dir: ${hydra:runtime.output_dir}/tensorboard 3 | name: '' 4 | version: '' 5 | default_hp_metric: False -------------------------------------------------------------------------------- /src/conf/loss/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: utils.loss.RMSELoss -------------------------------------------------------------------------------- /src/conf/metrics/default.yaml: -------------------------------------------------------------------------------- 1 | usdr: 2 | _target_: utils.metrics.GlobalSignalDistortionRatio 3 | epsilon: 1e-7 -------------------------------------------------------------------------------- /src/conf/optimizer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | lr: 5e-4 -------------------------------------------------------------------------------- /src/conf/scheduler/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amanteur/SCNet-PyTorch/69bede853c7144e82bbe5aeb6432d8f43d2fce48/src/conf/scheduler/.gitkeep -------------------------------------------------------------------------------- /src/conf/separator/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - net: default 3 | 4 | window_size: 11 5 | step_size: 5.5 6 | sample_rate: 44100 7 | batch_size: 4 8 | 9 | return_spec: True 10 | 11 | stft: 12 | _target_: torchaudio.transforms.Spectrogram 13 | n_fft: 4096 14 | win_length: 4096 15 | hop_length: 1024 16 | power: null 17 | 18 | istft: 19 | _target_: torchaudio.transforms.InverseSpectrogram 20 | n_fft: ${separator.stft.n_fft} 21 | win_length: ${separator.stft.win_length} 22 | hop_length: ${separator.stft.hop_length} -------------------------------------------------------------------------------- /src/conf/separator/net/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.scnet.SCNet 2 | _convert_: object 3 | n_fft: ${separator.stft.n_fft} 4 | dims: [4, 32, 64, 128] 5 | bandsplit_ratios: [.175, .392, .433] 6 | downsample_strides: [1, 4, 16] 7 | n_conv_modules: [3, 2, 1] 8 | n_rnn_layers: 6 9 | rnn_hidden_dim: 128 10 | n_sources: 4 -------------------------------------------------------------------------------- /src/conf/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 
2 | - _self_ 3 | - dataset: default 4 | - separator: default 5 | 6 | - loss: default 7 | - optimizer: default 8 | - scheduler: null 9 | 10 | - metrics: default 11 | - logger: default 12 | - callbacks: default 13 | - hydra: default 14 | 15 | 16 | seed: 42 17 | sources: [drums, bass, other, vocals] 18 | output_dir: ${hydra:runtime.output_dir} 19 | 20 | use_validation: True 21 | train_val_split: 22 | lengths: [86, 14] 23 | loader: 24 | train: 25 | batch_size: 8 26 | num_workers: 8 27 | shuffle: True 28 | drop_last: True 29 | validation: 30 | batch_size: 4 31 | num_workers: 8 32 | shuffle: False 33 | drop_last: False 34 | 35 | trainer: 36 | _target_: lightning.pytorch.Trainer 37 | fast_dev_run: False 38 | accelerator: cuda 39 | max_epochs: 130 40 | check_val_every_n_epoch: 1 41 | num_sanity_val_steps: 5 42 | log_every_n_steps: 100 43 | devices: 1 44 | gradient_clip_val: 5 45 | precision: 32 46 | enable_progress_bar: True 47 | benchmark: False 48 | deterministic: False 49 | -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import pandas as pd 4 | import torch 5 | import torchaudio 6 | from torch.utils.data import Dataset, Subset, random_split 7 | 8 | from data.utils import construct_dataset 9 | 10 | 11 | class SourceSeparationDataset(Dataset): 12 | """ 13 | Dataset class for source separation tasks. 14 | 15 | Args: 16 | - dataset_dirs (str): Comma-separated paths to directories where datasets are located. 17 | - subset (str): Subset of the dataset to load. 18 | - window_size (int): Size of the sliding window in seconds. 19 | - step_size (int): Step size of the sliding window in seconds. 20 | - sample_rate (int): Sample rate of the audio. 21 | - dataset_extension (str, optional): Extension of the dataset files. Defaults to 'wav'. 22 | - dataset_path (str, optional): Path to cache the dataset. Defaults to None. 23 | If specified, dataset metadata will be cached in parquet format for future retrieval. 24 | Defaults to None. 25 | - mixture_name (str, optional): Name of the mixture. Defaults to None. 26 | - sources (List[str], optional): List of source names. Defaults to ['drums', 'bass', 'other', 'vocals']. 27 | """ 28 | 29 | MIXTURE_NAME: str = "mixture" 30 | SOURCE_NAMES: List[str] = ["drums", "bass", "other", "vocals"] 31 | 32 | def __init__( 33 | self, 34 | dataset_dirs: str, 35 | subset: str, 36 | window_size: int, 37 | step_size: int, 38 | sample_rate: int, 39 | dataset_extension: str = "wav", 40 | dataset_path: Optional[str] = None, 41 | mixture_name: Optional[str] = None, 42 | sources: Optional[List[str]] = None, 43 | ): 44 | """ 45 | Initializes the SourceSeparationDataset. 
46 | """ 47 | super().__init__() 48 | 49 | self.dataset_dirs: List[str] = dataset_dirs.split(",") 50 | self.subset = subset 51 | self.dataset_extension = dataset_extension 52 | self.dataset_path = dataset_path 53 | 54 | self.window_size = int(window_size * sample_rate) 55 | self.step_size = int(step_size * sample_rate) 56 | self.sample_rate = sample_rate 57 | 58 | self.mixture_name = mixture_name or self.MIXTURE_NAME 59 | self.sources = sources or self.SOURCE_NAMES 60 | 61 | self.df = self.load_df() 62 | self.segment_ids, self.track_ids = self.get_ids() 63 | 64 | def generate_offsets( 65 | self, 66 | total_frames: int, 67 | ) -> List[int]: 68 | """ 69 | Generates the offsets based on total frames of track, window size and step size of segments. 70 | 71 | Args: 72 | - total_frames (int): Total number of frames of audio. 73 | 74 | Returns: 75 | - List[int]: List of offsets. 76 | """ 77 | return [ 78 | start 79 | for start in range(0, total_frames - self.window_size + 1, self.step_size) 80 | ] 81 | 82 | def load_df(self) -> pd.DataFrame: 83 | """ 84 | Loads the DataFrame based on the train/test subset and populates data based on window/step sizes. 85 | 86 | Returns: 87 | - pd.DataFrame: Loaded DataFrame. 88 | """ 89 | df = construct_dataset( 90 | self.dataset_dirs, 91 | extension=self.dataset_extension, 92 | save_path=self.dataset_path, 93 | ) 94 | df = df[df["subset"].eq(self.subset)] 95 | df["offset"] = df["total_frames"].apply(self.generate_offsets) 96 | df = df.explode("offset") 97 | df["track_id"] = df["track_name"].factorize()[0] 98 | df["segment_id"] = df.set_index(["track_id", "offset"]).index.factorize()[0] 99 | return df 100 | 101 | def get_ids(self) -> Tuple[List[int], List[int]]: 102 | """ 103 | Gets the segment and track IDs. 104 | 105 | Returns: 106 | - Tuple[List[int], List[int]]: Tuple containing segment and track IDs. 107 | """ 108 | segment_ids = self.df["segment_id"].tolist() 109 | track_ids = self.df["track_id"].unique().tolist() 110 | return segment_ids, track_ids 111 | 112 | def load_audio(self, segment_info: Dict[str, Any]) -> torch.Tensor: 113 | """ 114 | Loads the audio based on segment information. 115 | 116 | Args: 117 | - segment_info (Dict[str, Any]): Segment information. 118 | 119 | Returns: 120 | - torch.Tensor: Loaded audio tensor. 121 | """ 122 | audio, sr = torchaudio.load( 123 | segment_info["path"], 124 | num_frames=self.window_size, 125 | frame_offset=segment_info["offset"], 126 | ) 127 | assert ( 128 | sr == self.sample_rate 129 | ), f"Sample rate of the audio should be {self.sample_rate}Hz instead of {sr}Hz." 130 | return audio 131 | 132 | def load_mixture(self, idx: int) -> torch.Tensor: 133 | """ 134 | Loads the audio mixture based on the provided index. 135 | 136 | Args: 137 | - idx (int): Index of the mixture. 138 | 139 | Returns: 140 | - torch.Tensor: Loaded audio mixture tensor. 141 | """ 142 | segment_info = ( 143 | self.df[ 144 | self.df["segment_id"].eq(idx) 145 | & self.df["source_type"].eq(self.mixture_name) 146 | ] 147 | .iloc[0] 148 | .to_dict() 149 | ) 150 | audio = self.load_audio(segment_info) 151 | return audio 152 | 153 | def load_sources(self, idx: int) -> torch.Tensor: 154 | """ 155 | Loads the separated sources based on the provided index. 156 | 157 | Args: 158 | - idx (int): Index of the source. 159 | 160 | Returns: 161 | - torch.Tensor: Loaded and stacked audio sources tensor. 
162 | """ 163 | audios = [] 164 | for source in self.sources: 165 | segment_info = ( 166 | self.df[ 167 | self.df["segment_id"].eq(idx) & self.df["source_type"].eq(source) 168 | ] 169 | .iloc[0] 170 | .to_dict() 171 | ) 172 | audio = self.load_audio(segment_info) 173 | audios.append(audio) 174 | return torch.stack(audios) 175 | 176 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 177 | """ 178 | Retrieves an item from the dataset based on the index. 179 | 180 | Args: 181 | - idx (int): Index of the item. 182 | 183 | Returns: 184 | - Dict[str, torch.Tensor]: Dictionary containing mixture and sources. 185 | """ 186 | segment_id = self.segment_ids[idx] 187 | 188 | mixture = self.load_mixture(segment_id) 189 | sources = self.load_sources(segment_id) 190 | 191 | return { 192 | "mixture": mixture, 193 | "sources": sources, 194 | } 195 | 196 | def __len__(self) -> int: 197 | """ 198 | Returns the length of the dataset. 199 | 200 | Returns: 201 | - int: Length of the dataset. 202 | """ 203 | return len(self.segment_ids) 204 | 205 | def get_train_val_split( 206 | self, lengths: List[float], seed: Optional[int] = None 207 | ) -> Tuple[Subset, Subset]: 208 | """ 209 | Splits the dataset into training and validation subsets. 210 | 211 | Args: 212 | - lengths (List[float]): List containing the lengths of the training and validation subsets. 213 | - seed (Optional[int]): Random seed for reproducibility. Defaults to None. 214 | 215 | Returns: 216 | - Tuple[Subset, Subset]: Tuple containing the training and validation subsets. 217 | """ 218 | assert ( 219 | self.subset == "train" 220 | ), "Only train subset of the dataset can be split into train and val." 221 | assert len(lengths) == 2, "Dataset can be only split into two subset." 222 | generator = torch.Generator().manual_seed(seed) if seed is not None else None 223 | 224 | train_track_ids, val_track_ids = random_split( 225 | self.track_ids, lengths=lengths, generator=generator 226 | ) 227 | train_segment_ids = self.df[ 228 | self.df.track_id.isin(train_track_ids.indices) 229 | ].segment_id.to_list() 230 | val_segment_ids = self.df[ 231 | self.df.track_id.isin(val_track_ids.indices) 232 | ].segment_id.to_list() 233 | 234 | train_subset = Subset(self, indices=train_segment_ids) 235 | val_subset = Subset(self, indices=val_segment_ids) 236 | 237 | return train_subset, val_subset 238 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | import pandas as pd 6 | import torchaudio 7 | 8 | PARQUET_EXTENSIONS: List[str] = [".pqt", ".parquet"] 9 | 10 | 11 | def construct_dataset( 12 | dataset_dirs: List[str], extension: str = "wav", save_path: Optional[str] = None 13 | ) -> pd.DataFrame: 14 | """ 15 | Constructs a dataset DataFrame from audio files in the specified directories. 16 | 17 | Args: 18 | - dataset_dirs (List[str]): List of directories containing audio files. 19 | - extension (str): Extension of the audio files to consider. Defaults to "wav". 20 | - save_path (Optional[str]): Optional path to save the constructed DataFrame as a parquet file. 21 | 22 | Returns: 23 | - pd.DataFrame: DataFrame containing information about the audio files. 
24 | """ 25 | if save_path is not None and not Path(save_path).suffix in PARQUET_EXTENSIONS: 26 | raise ValueError("'save_path' should be in .parquet/.pqt format.") 27 | 28 | if save_path is not None and Path(save_path).is_file(): 29 | dataset_df = pd.read_parquet(save_path) 30 | return dataset_df 31 | 32 | if not isinstance(dataset_dirs, list): 33 | raise TypeError( 34 | f"'dataset_dirs' should be a list of strings, but got {type(dataset_dirs)}" 35 | ) 36 | 37 | dataset = [] 38 | for dataset_dir in dataset_dirs: 39 | for path in Path(dataset_dir).glob(f"**/*.{extension}"): 40 | abs_path = str(path.resolve()) 41 | source_type = path.stem 42 | subset, track_name = path.relative_to(dataset_dir).parts[:2] 43 | audio_info = torchaudio.info(path) 44 | dataset.append( 45 | ( 46 | abs_path, 47 | track_name, 48 | source_type, 49 | subset, 50 | audio_info.num_frames, 51 | audio_info.num_frames / audio_info.sample_rate, 52 | audio_info.sample_rate, 53 | audio_info.num_channels, 54 | ) 55 | ) 56 | columns = [ 57 | "path", 58 | "track_name", 59 | "source_type", 60 | "subset", 61 | "total_frames", 62 | "total_seconds", 63 | "sample_rate", 64 | "num_channels", 65 | ] 66 | dataset_df = pd.DataFrame(dataset, columns=columns) 67 | if save_path is not None: 68 | dataset_df.to_parquet(save_path, index=False) 69 | return dataset_df 70 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import logging 4 | from typing import Dict, Iterable, List, Tuple 5 | 6 | import fast_bss_eval 7 | import pandas as pd 8 | import torchaudio 9 | import torch.nn as nn 10 | from tqdm import tqdm 11 | 12 | from model.separator import Separator 13 | 14 | SOURCES: List[str] = ["drums", "bass", "other", "vocals"] 15 | 16 | logging.basicConfig( 17 | level=logging.INFO, 18 | format="%(asctime)s - %(levelname)s - %(message)s", 19 | datefmt="%Y-%m-%d %H:%M:%S", 20 | ) 21 | 22 | log = logging.getLogger(__name__) 23 | 24 | 25 | def parse_arguments(): 26 | """ 27 | Parse command line arguments. 28 | 29 | Returns: 30 | argparse.Namespace: Parsed arguments. 31 | """ 32 | parser = argparse.ArgumentParser(description="Argument Parser for Separator") 33 | parser.add_argument( 34 | "-c", 35 | "--ckpt-path", 36 | type=str, 37 | required=True, 38 | help="Path to the model checkpoint", 39 | ) 40 | parser.add_argument( 41 | "-d", 42 | "--device", 43 | type=str, 44 | default="cuda", 45 | help="Device to run the model on (default: cuda)", 46 | ) 47 | parser.add_argument( 48 | "-b", 49 | "--batch-size", 50 | type=int, 51 | default=4, 52 | help="Batch size for processing (default: 4)", 53 | ) 54 | parser.add_argument( 55 | "-w", "--window-size", type=int, default=11, help="Window size (default: 11)" 56 | ) 57 | parser.add_argument( 58 | "-s", "--step-size", type=float, default=5.5, help="Step size (default: 5.5)" 59 | ) 60 | return parser.parse_args() 61 | 62 | 63 | def load_data(dataset_path: str) -> Iterable[Tuple[str, str, Dict[str, str]]]: 64 | """ 65 | Load data from the dataset. 66 | 67 | Args: 68 | dataset_path (str): Path to the dataset. 69 | 70 | Yields: 71 | Tuple[str, str, Dict[str, str]]: Tuple containing track name, mixture path, and source paths. 
72 | """ 73 | df = pd.read_parquet(dataset_path) 74 | df = df[df["subset"].eq("test")] 75 | track_names = df["track_name"].unique() 76 | for track_name in tqdm(track_names): 77 | rows = df[df["track_name"].eq(track_name)] 78 | mixture_path = rows[rows["source_type"].eq("mixture")]["path"].values[0] 79 | source_paths = ( 80 | rows[~rows["source_type"].eq("mixture")] 81 | .set_index("source_type")["path"] 82 | .to_dict() 83 | ) 84 | yield track_name, mixture_path, source_paths 85 | 86 | 87 | def compute_sdrs(separator: nn.Module, dataset_path: str, device: str) -> str: 88 | """ 89 | Compute evaluation SDRs. 90 | 91 | Args: 92 | separator (nn.Module): Separator model. 93 | dataset_path (str): Path to the dataset.pqt. 94 | device (str): Device to send tensors on. 95 | 96 | Returns: 97 | str: Evaluation SDRs table. 98 | """ 99 | sdrs = [] 100 | for track_name, mixture_path, source_paths in load_data(dataset_path): 101 | y, sr = torchaudio.load(mixture_path) 102 | y_separated = separator.separate(y.to(device)).cpu() 103 | for y_source_est, source_type in zip(y_separated, SOURCES): 104 | y_source_ref, _ = torchaudio.load(source_paths[source_type]) 105 | sdr, *_ = fast_bss_eval.bss_eval_sources( 106 | y_source_ref, y_source_est, compute_permutation=False, load_diag=1e-7 107 | ) 108 | sdrs.append((track_name, source_type, sdr.mean().item())) 109 | sdrs_df = pd.DataFrame(sdrs, columns=["track_name", "source_type", "sdr"]) 110 | 111 | return sdrs_df.groupby("source_type")["sdr"].mean().reset_index(name="sdr").to_string() 112 | 113 | 114 | def main(): 115 | args = parse_arguments() 116 | 117 | dataset_path = os.getenv("DATASET_PATH") 118 | if dataset_path is None: 119 | raise ValueError("DATASET_PATH environment variable is not set.") 120 | 121 | log.info(f"Initializing Separator with following checkpoint {args.ckpt_path}...") 122 | separator = Separator.load_from_checkpoint( 123 | path=args.ckpt_path, 124 | batch_size=args.batch_size, 125 | window_size=args.window_size, 126 | step_size=args.step_size, 127 | ).to(args.device) 128 | 129 | log.info("Starting evaluation...") 130 | metrics = compute_sdrs(separator, dataset_path, args.device) 131 | log.info(f"Evaluation completed with following metrics:\n{metrics}") 132 | 133 | 134 | if __name__ == "__main__": 135 | main() 136 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | from typing import Iterable, List, Tuple 5 | 6 | import soundfile as sf 7 | import torch.nn as nn 8 | import torchaudio 9 | 10 | from model.separator import Separator 11 | 12 | SOURCES: List[str] = ["drums", "bass", "other", "vocals"] 13 | SAVE_SAMPLE_RATE: int = 44100 14 | 15 | logging.basicConfig( 16 | level=logging.INFO, 17 | format="%(asctime)s - %(levelname)s - %(message)s", 18 | datefmt="%Y-%m-%d %H:%M:%S", 19 | ) 20 | 21 | log = logging.getLogger(__name__) 22 | 23 | 24 | def parse_arguments(): 25 | """ 26 | Parse command line arguments. 27 | 28 | Returns: 29 | argparse.Namespace: Parsed arguments. 
30 | """ 31 | parser = argparse.ArgumentParser(description="Argument Parser for Separator") 32 | parser.add_argument( 33 | "-i", 34 | "--input-path", 35 | type=str, 36 | required=True, 37 | help="Input path to .wav audio file/directory containing audio files", 38 | ) 39 | parser.add_argument( 40 | "-o", 41 | "--output-path", 42 | type=str, 43 | required=True, 44 | help="Output directory to save separated audio files in .wav format", 45 | ) 46 | parser.add_argument( 47 | "-c", 48 | "--ckpt-path", 49 | type=str, 50 | required=True, 51 | help="Path to the model checkpoint", 52 | ) 53 | parser.add_argument( 54 | "-d", 55 | "--device", 56 | type=str, 57 | default="cuda", 58 | help="Device to run the model on (default: cuda)", 59 | ) 60 | parser.add_argument( 61 | "-b", 62 | "--batch-size", 63 | type=int, 64 | default=4, 65 | help="Batch size for processing (default: 4)", 66 | ) 67 | parser.add_argument( 68 | "-w", "--window-size", type=int, default=11, help="Window size (default: 11)" 69 | ) 70 | parser.add_argument( 71 | "-s", "--step-size", type=float, default=5.5, help="Step size (default: 5.5)" 72 | ) 73 | parser.add_argument( 74 | "-p", 75 | "--use-progress-bar", 76 | action="store_true", 77 | help="Use progress bar (default: True)", 78 | ) 79 | return parser.parse_args() 80 | 81 | 82 | def load_paths(input_path: str, output_path: str) -> Iterable[Tuple[Path, Path]]: 83 | """ 84 | Load input and output paths. 85 | 86 | Args: 87 | input_path (str): Input path to audio files. 88 | output_path (str): Output directory to save separated audio files. 89 | 90 | Yields: 91 | Tuple[Path, Path]: Tuple of input and output file paths. 92 | """ 93 | input_path = Path(input_path) 94 | output_path = Path(output_path) 95 | 96 | if not input_path.exists(): 97 | raise FileNotFoundError(f"Input path '{input_path}' does not exist.") 98 | 99 | if input_path.is_file(): 100 | if not (input_path.suffix == ".wav" or input_path.suffix == ".mp3"): 101 | raise ValueError("Input audio file should be in .wav or .mp3 formats.") 102 | fp_out = output_path / input_path.stem 103 | fp_out.mkdir(exist_ok=True, parents=True) 104 | yield input_path, fp_out 105 | elif input_path.is_dir(): 106 | for fp_in in input_path.glob("*"): 107 | if fp_in.suffix in (".wav", ".mp3"): 108 | fp_out = output_path / fp_in.stem 109 | fp_out.mkdir(exist_ok=True, parents=True) 110 | yield fp_in, fp_out 111 | else: 112 | raise ValueError( 113 | f"Input path '{input_path}' is neither a file nor a directory." 
114 | ) 115 | 116 | 117 | def process_files( 118 | separator: nn.Module, device: str, input_path: str, output_path: str 119 | ) -> None: 120 | for fp_in, fp_out in load_paths(input_path, output_path): 121 | y, sr = torchaudio.load(fp_in) 122 | y_separated = separator.separate(y.to(device)).cpu() 123 | for y_source, source in zip(y_separated, SOURCES): 124 | sf.write(f"{fp_out}/{source}.wav", y_source.T, SAVE_SAMPLE_RATE) 125 | 126 | 127 | def main(): 128 | args = parse_arguments() 129 | 130 | log.info(f"Initializing Separator with following checkpoint {args.ckpt_path}...") 131 | separator = Separator.load_from_checkpoint( 132 | path=args.ckpt_path, 133 | batch_size=args.batch_size, 134 | window_size=args.window_size, 135 | step_size=args.step_size, 136 | use_progress_bar=args.use_progress_bar, 137 | ).to(args.device) 138 | 139 | log.info("Processing audio files...") 140 | process_files(separator, args.device, args.input_path, args.output_path) 141 | log.info(f"Audio files processing completed.") 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /src/model/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .dualpath_rnn import DualPathRNN 2 | from .sd_encoder import SDBlock 3 | from .su_decoder import SUBlock 4 | -------------------------------------------------------------------------------- /src/model/modules/dualpath_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RNNModule(nn.Module): 6 | """ 7 | RNNModule class implements a recurrent neural network module with LSTM cells. 8 | 9 | Args: 10 | - input_dim (int): Dimensionality of the input features. 11 | - hidden_dim (int): Dimensionality of the hidden state of the LSTM. 12 | - bidirectional (bool, optional): If True, uses bidirectional LSTM. Defaults to True. 13 | 14 | Shapes: 15 | - Input: (B, T, D) where 16 | B is batch size, 17 | T is sequence length, 18 | D is input dimensionality. 19 | - Output: (B, T, D) where 20 | B is batch size, 21 | T is sequence length, 22 | D is input dimensionality. 23 | """ 24 | 25 | def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True): 26 | """ 27 | Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag. 28 | """ 29 | super().__init__() 30 | self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim) 31 | self.rnn = nn.LSTM( 32 | input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional 33 | ) 34 | self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim) 35 | 36 | def forward(self, x: torch.Tensor) -> torch.Tensor: 37 | """ 38 | Performs forward pass through the RNNModule. 39 | 40 | Args: 41 | - x (torch.Tensor): Input tensor of shape (B, T, D). 42 | 43 | Returns: 44 | - torch.Tensor: Output tensor of shape (B, T, D). 45 | """ 46 | x = x.transpose(1, 2) 47 | x = self.groupnorm(x) 48 | x = x.transpose(1, 2) 49 | 50 | x, (hidden, _) = self.rnn(x) 51 | x = self.fc(x) 52 | return x 53 | 54 | 55 | class RFFTModule(nn.Module): 56 | """ 57 | RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT) 58 | or its inverse on input tensors. 59 | 60 | Args: 61 | - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False. 
62 | 63 | Shapes: 64 | - Input: (B, F, T, D) where 65 | B is batch size, 66 | F is the number of features, 67 | T is sequence length, 68 | D is input dimensionality. 69 | - Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT. 70 | (B, F, T, D // 2, 2) if performing inverse FFT. 71 | """ 72 | 73 | def __init__(self, inverse: bool = False): 74 | """ 75 | Initializes RFFTModule with inverse flag. 76 | """ 77 | super().__init__() 78 | self.inverse = inverse 79 | 80 | def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor: 81 | """ 82 | Performs forward or inverse FFT on the input tensor x. 83 | 84 | Args: 85 | - x (torch.Tensor): Input tensor of shape (B, F, T, D). 86 | - time_dim (int): Input size of time dimension. 87 | 88 | Returns: 89 | - torch.Tensor: Output tensor after FFT or its inverse operation. 90 | """ 91 | B, F, T, D = x.shape 92 | dtype = x.dtype 93 | # in case of training in fp16/bf16 and tensor is not a power of 2, tensor will be sent to the float32 94 | if dtype != torch.float and (T & (T - 1)): 95 | x = x.float() 96 | if not self.inverse: 97 | x = torch.fft.rfft(x, dim=2) 98 | x = torch.view_as_real(x) 99 | x = x.reshape(B, F, T // 2 + 1, D * 2) 100 | else: 101 | x = x.reshape(B, F, T, D // 2, 2) 102 | x = torch.view_as_complex(x) 103 | x = torch.fft.irfft(x, n=time_dim, dim=2) 104 | x = x.to(dtype) 105 | return x 106 | 107 | def extra_repr(self) -> str: 108 | """ 109 | Returns extra representation string with module's configuration. 110 | """ 111 | return f"inverse={self.inverse}" 112 | 113 | 114 | class DualPathRNN(nn.Module): 115 | """ 116 | DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule. 117 | 118 | Args: 119 | - n_layers (int): Number of layers in the network. 120 | - input_dim (int): Dimensionality of the input features. 121 | - hidden_dim (int): Dimensionality of the hidden state of the RNNModule. 122 | 123 | Shapes: 124 | - Input: (B, F, T, D) where 125 | B is batch size, 126 | F is the number of features (frequency dimension), 127 | T is sequence length (time dimension), 128 | D is input dimensionality (channel dimension). 129 | - Output: (B, F, T, D) where 130 | B is batch size, 131 | F is the number of features (frequency dimension), 132 | T is sequence length (time dimension), 133 | D is input dimensionality (channel dimension). 134 | """ 135 | 136 | def __init__( 137 | self, 138 | n_layers: int, 139 | input_dim: int, 140 | hidden_dim: int, 141 | ): 142 | """ 143 | Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension. 144 | """ 145 | super().__init__() 146 | 147 | self.layers = nn.ModuleList() 148 | for i in range(1, n_layers + 1): 149 | if i % 2 == 1: 150 | layer = nn.ModuleList( 151 | [ 152 | RNNModule(input_dim=input_dim, hidden_dim=hidden_dim), 153 | RNNModule(input_dim=input_dim, hidden_dim=hidden_dim), 154 | RFFTModule(inverse=False), 155 | ] 156 | ) 157 | else: 158 | layer = nn.ModuleList( 159 | [ 160 | RNNModule(input_dim=input_dim * 2, hidden_dim=hidden_dim * 2), 161 | RNNModule(input_dim=input_dim * 2, hidden_dim=hidden_dim * 2), 162 | RFFTModule(inverse=True), 163 | ] 164 | ) 165 | self.layers.append(layer) 166 | 167 | def forward(self, x: torch.Tensor) -> torch.Tensor: 168 | """ 169 | Performs forward pass through the DualPathRNN. 170 | 171 | Args: 172 | - x (torch.Tensor): Input tensor of shape (B, F, T, D). 173 | 174 | Returns: 175 | - torch.Tensor: Output tensor of shape (B, F, T, D). 
176 | """ 177 | time_dim = x.shape[2] 178 | 179 | for time_layer, freq_layer, rfft_layer in self.layers: 180 | B, F, T, D = x.shape 181 | 182 | x = x.reshape((B * F), T, D) 183 | x = time_layer(x) 184 | x = x.reshape(B, F, T, D) 185 | x = x.permute(0, 2, 1, 3) 186 | 187 | x = x.reshape((B * T), F, D) 188 | x = freq_layer(x) 189 | x = x.reshape(B, T, F, D) 190 | x = x.permute(0, 2, 1, 3) 191 | 192 | x = rfft_layer(x, time_dim) 193 | return x 194 | -------------------------------------------------------------------------------- /src/model/modules/sd_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from model.utils import create_intervals 7 | 8 | 9 | class Downsample(nn.Module): 10 | """ 11 | Downsample class implements a module for downsampling input tensors using 2D convolution. 12 | 13 | Args: 14 | - input_dim (int): Dimensionality of the input channels. 15 | - output_dim (int): Dimensionality of the output channels. 16 | - stride (int): Stride value for the convolution operation. 17 | 18 | Shapes: 19 | - Input: (B, C_in, F, T) where 20 | B is batch size, 21 | C_in is the number of input channels, 22 | F is the frequency dimension, 23 | T is the time dimension. 24 | - Output: (B, C_out, F // stride, T) where 25 | B is batch size, 26 | C_out is the number of output channels, 27 | F // stride is the downsampled frequency dimension. 28 | 29 | """ 30 | 31 | def __init__( 32 | self, 33 | input_dim: int, 34 | output_dim: int, 35 | stride: int, 36 | ): 37 | """ 38 | Initializes Downsample with input dimension, output dimension, and stride. 39 | """ 40 | super().__init__() 41 | self.conv = nn.Conv2d(input_dim, output_dim, 1, (stride, 1)) 42 | 43 | def forward(self, x: torch.Tensor) -> torch.Tensor: 44 | """ 45 | Performs forward pass through the Downsample module. 46 | 47 | Args: 48 | - x (torch.Tensor): Input tensor of shape (B, C_in, F, T). 49 | 50 | Returns: 51 | - torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T). 52 | """ 53 | return self.conv(x) 54 | 55 | 56 | class ConvolutionModule(nn.Module): 57 | """ 58 | ConvolutionModule class implements a module with a sequence of convolutional layers similar to Conformer. 59 | 60 | Args: 61 | - input_dim (int): Dimensionality of the input features. 62 | - hidden_dim (int): Dimensionality of the hidden features. 63 | - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers. 64 | - bias (bool, optional): If True, adds a learnable bias to the output. Default is False. 65 | 66 | Shapes: 67 | - Input: (B, T, D) where 68 | B is batch size, 69 | T is sequence length, 70 | D is input dimensionality. 71 | - Output: (B, T, D) where 72 | B is batch size, 73 | T is sequence length, 74 | D is input dimensionality. 75 | """ 76 | 77 | def __init__( 78 | self, 79 | input_dim: int, 80 | hidden_dim: int, 81 | kernel_sizes: List[int], 82 | bias: bool = False, 83 | ) -> None: 84 | """ 85 | Initializes ConvolutionModule with input dimension, hidden dimension, kernel sizes, and bias. 
86 | """ 87 | super().__init__() 88 | self.sequential = nn.Sequential( 89 | nn.GroupNorm(num_groups=1, num_channels=input_dim), 90 | nn.Conv1d( 91 | input_dim, 92 | 2 * hidden_dim, 93 | kernel_sizes[0], 94 | stride=1, 95 | padding=(kernel_sizes[0] - 1) // 2, 96 | bias=bias, 97 | ), 98 | nn.GLU(dim=1), 99 | nn.Conv1d( 100 | hidden_dim, 101 | hidden_dim, 102 | kernel_sizes[1], 103 | stride=1, 104 | padding=(kernel_sizes[1] - 1) // 2, 105 | groups=hidden_dim, 106 | bias=bias, 107 | ), 108 | nn.GroupNorm(num_groups=1, num_channels=hidden_dim), 109 | nn.SiLU(), 110 | nn.Conv1d( 111 | hidden_dim, 112 | input_dim, 113 | kernel_sizes[2], 114 | stride=1, 115 | padding=(kernel_sizes[2] - 1) // 2, 116 | bias=bias, 117 | ), 118 | ) 119 | 120 | def forward(self, x: torch.Tensor) -> torch.Tensor: 121 | """ 122 | Performs forward pass through the ConvolutionModule. 123 | 124 | Args: 125 | - x (torch.Tensor): Input tensor of shape (B, T, D). 126 | 127 | Returns: 128 | - torch.Tensor: Output tensor of shape (B, T, D). 129 | """ 130 | x = x.transpose(1, 2) 131 | x = x + self.sequential(x) 132 | x = x.transpose(1, 2) 133 | return x 134 | 135 | 136 | class SDLayer(nn.Module): 137 | """ 138 | SDLayer class implements a subband decomposition layer with downsampling and convolutional modules. 139 | 140 | Args: 141 | - subband_interval (Tuple[float, float]): Tuple representing the frequency interval for subband decomposition. 142 | - input_dim (int): Dimensionality of the input channels. 143 | - output_dim (int): Dimensionality of the output channels after downsampling. 144 | - downsample_stride (int): Stride value for the downsampling operation. 145 | - n_conv_modules (int): Number of convolutional modules. 146 | - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers. 147 | - bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True. 148 | 149 | Shapes: 150 | - Input: (B, Fi, T, Ci) where 151 | B is batch size, 152 | Fi is the number of input subbands, 153 | T is sequence length, and 154 | Ci is the number of input channels. 155 | - Output: (B, Fi+1, T, Ci+1) where 156 | B is batch size, 157 | Fi+1 is the number of output subbands, 158 | T is sequence length, 159 | Ci+1 is the number of output channels. 160 | """ 161 | 162 | def __init__( 163 | self, 164 | subband_interval: Tuple[float, float], 165 | input_dim: int, 166 | output_dim: int, 167 | downsample_stride: int, 168 | n_conv_modules: int, 169 | kernel_sizes: List[int], 170 | bias: bool = True, 171 | ): 172 | """ 173 | Initializes SDLayer with subband interval, input dimension, 174 | output dimension, downsample stride, number of convolutional modules, kernel sizes, and bias. 175 | """ 176 | super().__init__() 177 | self.subband_interval = subband_interval 178 | self.downsample = Downsample(input_dim, output_dim, downsample_stride) 179 | self.activation = nn.GELU() 180 | conv_modules = [ 181 | ConvolutionModule( 182 | input_dim=output_dim, 183 | hidden_dim=output_dim // 4, 184 | kernel_sizes=kernel_sizes, 185 | bias=bias, 186 | ) 187 | for _ in range(n_conv_modules) 188 | ] 189 | self.conv_modules = nn.Sequential(*conv_modules) 190 | 191 | def forward(self, x: torch.Tensor) -> torch.Tensor: 192 | """ 193 | Performs forward pass through the SDLayer. 194 | 195 | Args: 196 | - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci). 197 | 198 | Returns: 199 | - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1). 
200 | """ 201 | B, F, T, C = x.shape 202 | x = x[:, int(self.subband_interval[0] * F) : int(self.subband_interval[1] * F)] 203 | x = x.permute(0, 3, 1, 2) 204 | x = self.downsample(x) 205 | x = self.activation(x) 206 | x = x.permute(0, 2, 3, 1) 207 | 208 | B, F, T, C = x.shape 209 | x = x.reshape((B * F), T, C) 210 | x = self.conv_modules(x) 211 | x = x.reshape(B, F, T, C) 212 | 213 | return x 214 | 215 | 216 | class SDBlock(nn.Module): 217 | """ 218 | SDBlock class implements a block with subband decomposition layers and global convolution. 219 | 220 | Args: 221 | - input_dim (int): Dimensionality of the input channels. 222 | - output_dim (int): Dimensionality of the output channels. 223 | - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands. 224 | - downsample_strides (List[int]): List of stride values for downsampling in each subband layer. 225 | - n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer. 226 | - kernel_sizes (List[int], optional): List of kernel sizes for the convolutional layers. Default is None. 227 | 228 | Shapes: 229 | - Input: (B, Fi, T, Ci) where 230 | B is batch size, 231 | Fi is the number of input subbands, 232 | T is sequence length, 233 | Ci is the number of input channels. 234 | - Output: (B, Fi+1, T, Ci+1) where 235 | B is batch size, 236 | Fi+1 is the number of output subbands, 237 | T is sequence length, 238 | Ci+1 is the number of output channels. 239 | """ 240 | 241 | def __init__( 242 | self, 243 | input_dim: int, 244 | output_dim: int, 245 | bandsplit_ratios: List[float], 246 | downsample_strides: List[int], 247 | n_conv_modules: List[int], 248 | kernel_sizes: Optional[List[int]] = None, 249 | ): 250 | """ 251 | Initializes SDBlock with input dimension, output dimension, band split ratios, downsample strides, number of convolutional modules, and kernel sizes. 252 | """ 253 | super().__init__() 254 | if kernel_sizes is None: 255 | kernel_sizes = [3, 3, 1] 256 | assert sum(bandsplit_ratios) == 1, "The split ratios must sum up to 1." 257 | subband_intervals = create_intervals(bandsplit_ratios) 258 | self.sd_layers = nn.ModuleList( 259 | SDLayer( 260 | input_dim=input_dim, 261 | output_dim=output_dim, 262 | subband_interval=sbi, 263 | downsample_stride=dss, 264 | n_conv_modules=ncm, 265 | kernel_sizes=kernel_sizes, 266 | ) 267 | for sbi, dss, ncm in zip( 268 | subband_intervals, downsample_strides, n_conv_modules 269 | ) 270 | ) 271 | self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1) 272 | 273 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 274 | """ 275 | Performs forward pass through the SDBlock. 276 | 277 | Args: 278 | - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci). 279 | 280 | Returns: 281 | - Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor. 
282 | """ 283 | x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1) 284 | x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 285 | return x, x_skip 286 | -------------------------------------------------------------------------------- /src/model/modules/su_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from model.utils import get_convtranspose_output_padding 7 | 8 | 9 | class FusionLayer(nn.Module): 10 | """ 11 | FusionLayer class implements a module for fusing two input tensors using convolutional operations. 12 | 13 | Args: 14 | - input_dim (int): Dimensionality of the input channels. 15 | - kernel_size (int, optional): Kernel size for the convolutional layer. Default is 3. 16 | - stride (int, optional): Stride value for the convolutional layer. Default is 1. 17 | - padding (int, optional): Padding value for the convolutional layer. Default is 1. 18 | 19 | Shapes: 20 | - Input: (B, F, T, C) and (B, F, T, C) where 21 | B is batch size, 22 | F is the number of features, 23 | T is sequence length, 24 | C is input dimensionality. 25 | - Output: (B, F, T, C) where 26 | B is batch size, 27 | F is the number of features, 28 | T is sequence length, 29 | C is input dimensionality. 30 | """ 31 | 32 | def __init__( 33 | self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1 34 | ): 35 | """ 36 | Initializes FusionLayer with input dimension, kernel size, stride, and padding. 37 | """ 38 | super().__init__() 39 | self.conv = nn.Conv2d( 40 | input_dim * 2, 41 | input_dim * 2, 42 | kernel_size=(kernel_size, 1), 43 | stride=(stride, 1), 44 | padding=(padding, 0), 45 | ) 46 | self.activation = nn.GLU() 47 | 48 | def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: 49 | """ 50 | Performs forward pass through the FusionLayer. 51 | 52 | Args: 53 | - x1 (torch.Tensor): First input tensor of shape (B, F, T, C). 54 | - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C). 55 | 56 | Returns: 57 | - torch.Tensor: Output tensor of shape (B, F, T, C). 58 | """ 59 | x = x1 + x2 60 | x = x.repeat(1, 1, 1, 2) 61 | x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 62 | x = self.activation(x) 63 | return x 64 | 65 | 66 | class Upsample(nn.Module): 67 | """ 68 | Upsample class implements a module for upsampling input tensors using transposed 2D convolution. 69 | 70 | Args: 71 | - input_dim (int): Dimensionality of the input channels. 72 | - output_dim (int): Dimensionality of the output channels. 73 | - stride (int): Stride value for the transposed convolution operation. 74 | - output_padding (int): Output padding value for the transposed convolution operation. 75 | 76 | Shapes: 77 | - Input: (B, C_in, F, T) where 78 | B is batch size, 79 | C_in is the number of input channels, 80 | F is the frequency dimension, 81 | T is the time dimension. 82 | - Output: (B, C_out, F * stride + output_padding, T) where 83 | B is batch size, 84 | C_out is the number of output channels, 85 | F * stride + output_padding is the upsampled frequency dimension. 86 | """ 87 | 88 | def __init__( 89 | self, input_dim: int, output_dim: int, stride: int, output_padding: int 90 | ): 91 | """ 92 | Initializes Upsample with input dimension, output dimension, stride, and output padding. 
93 | """ 94 | super().__init__() 95 | self.conv = nn.ConvTranspose2d( 96 | input_dim, output_dim, 1, (stride, 1), output_padding=(output_padding, 0) 97 | ) 98 | 99 | def forward(self, x: torch.Tensor) -> torch.Tensor: 100 | """ 101 | Performs forward pass through the Upsample module. 102 | 103 | Args: 104 | - x (torch.Tensor): Input tensor of shape (B, C_in, F, T). 105 | 106 | Returns: 107 | - torch.Tensor: Output tensor of shape (B, C_out, F * stride + output_padding, T). 108 | """ 109 | return self.conv(x) 110 | 111 | 112 | class SULayer(nn.Module): 113 | """ 114 | SULayer class implements a subband upsampling layer using transposed convolution. 115 | 116 | Args: 117 | - input_dim (int): Dimensionality of the input channels. 118 | - output_dim (int): Dimensionality of the output channels. 119 | - upsample_stride (int): Stride value for the upsampling operation. 120 | - subband_shape (int): Shape of the subband. 121 | - sd_interval (Tuple[int, int]): Start and end indices of the subband interval. 122 | 123 | Shapes: 124 | - Input: (B, F, T, C) where 125 | B is batch size, 126 | F is the number of features, 127 | T is sequence length, 128 | C is input dimensionality. 129 | - Output: (B, F, T, C) where 130 | B is batch size, 131 | F is the number of features, 132 | T is sequence length, 133 | C is input dimensionality. 134 | """ 135 | 136 | def __init__( 137 | self, 138 | input_dim: int, 139 | output_dim: int, 140 | upsample_stride: int, 141 | subband_shape: int, 142 | sd_interval: Tuple[int, int], 143 | ): 144 | """ 145 | Initializes SULayer with input dimension, output dimension, upsample stride, subband shape, and subband interval. 146 | """ 147 | super().__init__() 148 | sd_shape = sd_interval[1] - sd_interval[0] 149 | upsample_output_padding = get_convtranspose_output_padding( 150 | input_shape=sd_shape, output_shape=subband_shape, stride=upsample_stride 151 | ) 152 | self.upsample = Upsample( 153 | input_dim=input_dim, 154 | output_dim=output_dim, 155 | stride=upsample_stride, 156 | output_padding=upsample_output_padding, 157 | ) 158 | self.sd_interval = sd_interval 159 | 160 | def forward(self, x: torch.Tensor) -> torch.Tensor: 161 | """ 162 | Performs forward pass through the SULayer. 163 | 164 | Args: 165 | - x (torch.Tensor): Input tensor of shape (B, F, T, C). 166 | 167 | Returns: 168 | - torch.Tensor: Output tensor of shape (B, F, T, C). 169 | """ 170 | x = x[:, self.sd_interval[0] : self.sd_interval[1]] 171 | x = x.permute(0, 3, 1, 2) 172 | x = self.upsample(x) 173 | x = x.permute(0, 2, 3, 1) 174 | return x 175 | 176 | 177 | class SUBlock(nn.Module): 178 | """ 179 | SUBlock class implements a block with fusion layer and subband upsampling layers. 180 | 181 | Args: 182 | - input_dim (int): Dimensionality of the input channels. 183 | - output_dim (int): Dimensionality of the output channels. 184 | - upsample_strides (List[int]): List of stride values for the upsampling operations. 185 | - subband_shapes (List[int]): List of shapes for the subbands. 186 | - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition. 187 | 188 | Shapes: 189 | - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where 190 | B is batch size, 191 | Fi-1 is the number of input subbands, 192 | T is sequence length, 193 | Ci-1 is the number of input channels. 194 | - Output: (B, Fi, T, Ci) where 195 | B is batch size, 196 | Fi is the number of output subbands, 197 | T is sequence length, 198 | Ci is the number of output channels. 
199 | """ 200 | 201 | def __init__( 202 | self, 203 | input_dim: int, 204 | output_dim: int, 205 | upsample_strides: List[int], 206 | subband_shapes: List[int], 207 | sd_intervals: List[Tuple[int, int]], 208 | ): 209 | """ 210 | Initializes SUBlock with input dimension, output dimension, 211 | upsample strides, subband shapes, and subband intervals. 212 | """ 213 | super().__init__() 214 | self.fusion_layer = FusionLayer(input_dim=input_dim) 215 | self.su_layers = nn.ModuleList( 216 | SULayer( 217 | input_dim=input_dim, 218 | output_dim=output_dim, 219 | upsample_stride=uss, 220 | subband_shape=sbs, 221 | sd_interval=sdi, 222 | ) 223 | for i, (uss, sbs, sdi) in enumerate( 224 | zip(upsample_strides, subband_shapes, sd_intervals) 225 | ) 226 | ) 227 | 228 | def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor: 229 | """ 230 | Performs forward pass through the SUBlock. 231 | 232 | Args: 233 | - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1). 234 | - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1). 235 | 236 | Returns: 237 | - torch.Tensor: Output tensor of shape (B, Fi, T, Ci). 238 | """ 239 | x = self.fusion_layer(x, x_skip) 240 | x = torch.concat([layer(x) for layer in self.su_layers], dim=1) 241 | return x 242 | -------------------------------------------------------------------------------- /src/model/scnet.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from model.modules import DualPathRNN, SDBlock, SUBlock 7 | from model.utils import compute_sd_layer_shapes, compute_gcr 8 | 9 | 10 | class SCNet(nn.Module): 11 | """ 12 | SCNet class implements a source separation network, 13 | which explicitly split the spectrogram of the mixture into several subbands 14 | and introduce a sparsity-based encoder to model different frequency bands. 15 | 16 | Paper: "SCNET: SPARSE COMPRESSION NETWORK FOR MUSIC SOURCE SEPARATION" 17 | Authors: Weinan Tong, Jiaxu Zhu et al. 18 | Link: https://arxiv.org/abs/2401.13276.pdf 19 | 20 | Args: 21 | - n_fft (int): Number of FFTs to determine the frequency dimension of the input. 22 | - dimes (List[int]): List of channel dimensions for each block. 23 | - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands. 24 | - downsample_strides (List[int]): List of stride values for downsampling in each block. 25 | - n_conv_modules (List[int]): List specifying the number of convolutional modules in each block. 26 | - n_rnn_layers (int): Number of recurrent layers in the dual path RNN. 27 | - rnn_hidden_dim (int): Dimensionality of the hidden state in the dual path RNN. 28 | - n_sources (int, optional): Number of sources to be separated. Default is 4. 29 | 30 | Shapes: 31 | - Input: (B, F, T, C) where 32 | B is batch size, 33 | F is the number of features, 34 | T is sequence length, 35 | C is input dimensionality. 36 | - Output: (B, F, T, C, S) where 37 | B is batch size, 38 | F is the number of features, 39 | T is sequence length, 40 | C is input dimensionality, 41 | S is the number of sources. 42 | """ 43 | 44 | def __init__( 45 | self, 46 | n_fft: int, 47 | dims: List[int], 48 | bandsplit_ratios: List[float], 49 | downsample_strides: List[int], 50 | n_conv_modules: List[int], 51 | n_rnn_layers: int, 52 | rnn_hidden_dim: int, 53 | n_sources: int = 4, 54 | ): 55 | """ 56 | Initializes SCNet with input parameters. 
57 | """ 58 | super().__init__() 59 | self.assert_input_data( 60 | bandsplit_ratios, 61 | downsample_strides, 62 | n_conv_modules, 63 | ) 64 | 65 | n_blocks = len(dims) - 1 66 | n_freq_bins = n_fft // 2 + 1 67 | subband_shapes, sd_intervals = compute_sd_layer_shapes( 68 | input_shape=n_freq_bins, 69 | bandsplit_ratios=bandsplit_ratios, 70 | downsample_strides=downsample_strides, 71 | n_layers=n_blocks, 72 | ) 73 | self.sd_blocks = nn.ModuleList( 74 | SDBlock( 75 | input_dim=dims[i], 76 | output_dim=dims[i + 1], 77 | bandsplit_ratios=bandsplit_ratios, 78 | downsample_strides=downsample_strides, 79 | n_conv_modules=n_conv_modules, 80 | ) 81 | for i in range(n_blocks) 82 | ) 83 | self.dualpath_blocks = DualPathRNN( 84 | n_layers=n_rnn_layers, 85 | input_dim=dims[-1], 86 | hidden_dim=rnn_hidden_dim, 87 | ) 88 | self.su_blocks = nn.ModuleList( 89 | SUBlock( 90 | input_dim=dims[i + 1], 91 | output_dim=dims[i] if i != 0 else dims[i] * n_sources, 92 | subband_shapes=subband_shapes[i], 93 | sd_intervals=sd_intervals[i], 94 | upsample_strides=downsample_strides, 95 | ) 96 | for i in reversed(range(n_blocks)) 97 | ) 98 | self.gcr = compute_gcr(subband_shapes) 99 | 100 | @staticmethod 101 | def assert_input_data(*args): 102 | """ 103 | Asserts that the shapes of input features are equal. 104 | """ 105 | for arg1 in args: 106 | for arg2 in args: 107 | if len(arg1) != len(arg2): 108 | raise ValueError( 109 | f"Shapes of input features {arg1} and {arg2} are not equal." 110 | ) 111 | 112 | def forward(self, x: torch.Tensor) -> torch.Tensor: 113 | """ 114 | Performs forward pass through the SCNet. 115 | 116 | Args: 117 | - x (torch.Tensor): Input tensor of shape (B, F, T, C). 118 | 119 | Returns: 120 | - torch.Tensor: Output tensor of shape (B, F, T, C, S). 121 | """ 122 | B, F, T, C = x.shape 123 | 124 | # encoder part 125 | x_skips = [] 126 | for sd_block in self.sd_blocks: 127 | x, x_skip = sd_block(x) 128 | x_skips.append(x_skip) 129 | 130 | # separation part 131 | x = self.dualpath_blocks(x) 132 | 133 | # decoder part 134 | for su_block, x_skip in zip(self.su_blocks, reversed(x_skips)): 135 | x = su_block(x, x_skip) 136 | 137 | # split into N sources 138 | x = x.reshape(B, F, T, C, -1) 139 | 140 | return x 141 | 142 | def count_parameters(self): 143 | """ 144 | Counts the total number of parameters in the SCNet. 145 | """ 146 | return sum(p.numel() for p in self.parameters()) 147 | 148 | 149 | if __name__ == "__main__": 150 | net_params = { 151 | "n_fft": 4096, 152 | "dims": [4, 32, 64, 128], 153 | "bandsplit_ratios": [0.175, 0.392, 0.433], 154 | "downsample_strides": [1, 4, 16], 155 | "n_conv_modules": [3, 2, 1], 156 | "n_rnn_layers": 6, 157 | "rnn_hidden_dim": 128, 158 | "n_sources": 4, 159 | } 160 | device = "cpu" 161 | B, F, T, C = 4, 2049, 474, 4 162 | 163 | net = SCNet(**net_params).to(device) 164 | _ = net.eval() 165 | 166 | test_input = torch.rand(B, F, T, C).to(device) 167 | out = net(test_input) 168 | -------------------------------------------------------------------------------- /src/model/separator.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from hydra.utils import instantiate 8 | from omegaconf import DictConfig 9 | from tqdm import tqdm 10 | 11 | 12 | class Separator(nn.Module): 13 | """ 14 | Neural Network Separator. 
15 | 16 | This class implements a neural network-based separator for audio source separation. 17 | 18 | Args: 19 | - return_spec (bool): Whether to return the spectrogram of separated sources along with audio. 20 | - batch_size (int): Number of chunks to process at the same time. 21 | - sample_rate (int): Sample rate of the audio. 22 | - window_size (float): Size of the sliding window in seconds. 23 | - step_size (float): Step size of the sliding window in seconds. 24 | - stft (DictConfig): Configuration for Short-Time Fourier Transform (STFT). 25 | - istft (DictConfig): Configuration for Inverse Short-Time Fourier Transform (ISTFT). 26 | - net (DictConfig): Configuration for the neural network architecture. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | return_spec: bool, 32 | batch_size: int, 33 | sample_rate: int, 34 | window_size: float, 35 | step_size: float, 36 | stft: DictConfig, 37 | istft: DictConfig, 38 | net: DictConfig, 39 | use_progress_bar: bool = False, 40 | ) -> None: 41 | """ 42 | Initialize Separator. 43 | """ 44 | super().__init__() 45 | 46 | assert step_size < window_size, "Step size should be smaller than window size." 47 | 48 | self.return_spec = return_spec 49 | 50 | self.bs = batch_size 51 | self.sr = sample_rate 52 | self.ws = int(window_size * sample_rate) 53 | self.ss = int(step_size * sample_rate) 54 | self.ps = self.ws - self.ss 55 | 56 | self.stft = instantiate(stft) 57 | self.net = instantiate(net) 58 | self.istft = instantiate(istft) 59 | 60 | self.use_progress_bar = use_progress_bar 61 | 62 | @classmethod 63 | def load_from_checkpoint(cls, path: str, **overrides) -> nn.Module: 64 | """ 65 | Initializes Separator from Lightning checkpoint. 66 | 67 | Args: 68 | - path (str): Checkpoint path. Keyword arguments override the stored separator hyperparameters. 69 | """ 70 | # load checkpoint 71 | assert Path(path).suffix == ".ckpt", "Checkpoint is not in preferred format." 72 | ckpt = torch.load(path) 73 | params = ckpt["hyper_parameters"].separator 74 | state_dict = ckpt["state_dict"] 75 | 76 | # initialize separator 77 | params = {**params, **overrides} 78 | model = cls(**params) 79 | 80 | # delete lightning-module related prefix from weight keys 81 | state_dict = {k.replace("sep.", ""): state_dict[k] for k in state_dict} 82 | # load state dict 83 | _ = model.load_state_dict(state_dict) 84 | return model 85 | 86 | def pad(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: 87 | """ 88 | Pad input tensor to fit STFT requirements. 89 | 90 | Args: 91 | - x (torch.Tensor): Input tensor. 92 | 93 | Returns: 94 | - x (torch.Tensor): Padded input tensor. 95 | - pad_size (int): Size of padding. 96 | """ 97 | pad_size = self.stft.hop_length - x.shape[-1] % self.stft.hop_length 98 | x = F.pad(x, (0, pad_size)) 99 | return x, pad_size 100 | 101 | def apply_stft(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: 102 | """ 103 | Apply Short-Time Fourier Transform (STFT) to input tensor. 104 | 105 | Args: 106 | - x (torch.Tensor): Input tensor. 107 | 108 | Returns: 109 | - x (torch.Tensor): Transformed tensor. 110 | - pad_size (int): Size of padding applied. 111 | """ 112 | x, pad_size = self.pad(x) 113 | x = self.stft(x) 114 | x = torch.view_as_real(x) 115 | return x, pad_size 116 | 117 | def apply_istft( 118 | self, x: torch.Tensor, pad_size: Optional[int] = None 119 | ) -> torch.Tensor: 120 | """ 121 | Apply Inverse Short-Time Fourier Transform (ISTFT) to input tensor. 122 | 123 | Args: 124 | - x (torch.Tensor): Input tensor. 125 | - pad_size (int): Size of padding applied.
126 | 127 | Returns: 128 | - x (torch.Tensor): Inverse transformed tensor. 129 | """ 130 | x = torch.view_as_complex(x) 131 | x = self.istft(x) 132 | if pad_size is not None: 133 | x = x[..., :-pad_size] 134 | return x 135 | 136 | def apply_net(self, x: torch.Tensor) -> torch.Tensor: 137 | """ 138 | Apply neural network to input tensor. 139 | 140 | Args: 141 | - x (torch.Tensor): Input tensor. 142 | 143 | Returns: 144 | - x (torch.Tensor): Transformed tensor. 145 | 146 | Shapes: 147 | - Input: (B, Ch, Fr, T, Co) where 148 | B is batch size, 149 | Ch is the number of channels, 150 | Fr is the number of frequencies, 151 | T is sequence length, 152 | Co is real/imag part. 153 | - Output: (B, S, Ch, Fr, T, Co) where 154 | B is batch size, 155 | S is number of separated sources, 156 | Ch is the number of channels, 157 | Fr is the number of frequencies, 158 | T is sequence length, 159 | Co is real/imag part. 160 | """ 161 | B, Ch, Fr, T, Co = x.shape 162 | x = x.permute(0, 2, 3, 1, 4).reshape(B, Fr, T, Ch * Co) 163 | 164 | x = self.net(x) 165 | 166 | S = x.shape[-1] 167 | x = x.reshape(B, Fr, T, Ch, Co, S).permute(0, 5, 3, 1, 2, 4).contiguous() 168 | return x 169 | 170 | def forward( 171 | self, wav_mixture: torch.Tensor 172 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 173 | """ 174 | Forward pass of the separator. 175 | 176 | Args: 177 | - wav_mixture (torch.Tensor): Input mixture tensor. 178 | 179 | Returns: 180 | - torch.Tensor: Separated sources. 181 | 182 | Shapes: 183 | input (B, Ch, T) 184 | -> stft (B, Ch, F, T, Co) 185 | -> net (B, S, Ch, F, T, Co) 186 | -> istft (B, S, Ch, T) 187 | """ 188 | spec_mixture, pad_size = self.apply_stft(wav_mixture) 189 | 190 | spec_sources = self.apply_net(spec_mixture) 191 | 192 | wav_sources = self.apply_istft(spec_sources, pad_size) 193 | 194 | if self.return_spec: 195 | return wav_sources, spec_sources 196 | return wav_sources, None 197 | 198 | def pad_whole(self, y: torch.Tensor) -> Tuple[torch.Tensor, int]: 199 | """ 200 | Pad the input tensor before overlap-add. 201 | 202 | Args: 203 | - y (torch.Tensor): Input tensor. 204 | 205 | Returns: 206 | - y (torch.Tensor): Padded input tensor. 207 | """ 208 | padding_add = self.ss - (y.shape[-1] + self.ps * 2 - self.ws) % self.ss 209 | y = F.pad(y, (self.ps, self.ps + padding_add), "constant") 210 | return y, padding_add 211 | 212 | def unpad_whole(self, y: torch.Tensor, padding_add: int) -> torch.Tensor: 213 | """ 214 | Unpad the input tensor after overlap-add. 215 | 216 | Args: 217 | - y (torch.Tensor): Input tensor. 218 | - padding_add (int): Size of padding applied. 219 | 220 | Returns: 221 | - y (torch.Tensor): Unpadded input tensor. 222 | """ 223 | return y[..., self.ps : -(self.ps + padding_add)] 224 | 225 | def unfold(self, y: torch.Tensor) -> torch.Tensor: 226 | """ 227 | Unfold the input tensor before applying the model. 228 | 229 | Args: 230 | - y (torch.Tensor): Input tensor. 231 | 232 | Returns: 233 | - y (torch.Tensor): Unfolded input tensor. 234 | """ 235 | y = y.unfold(-1, self.ws, self.ss).permute(1, 0, 2) 236 | return y 237 | 238 | def fold(self, y_chunks: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 239 | """ 240 | Perform overlap-add operation. 241 | 242 | Args: 243 | - y_chunks (torch.Tensor): Segmented chunks of input tensor. 244 | - y (torch.Tensor): Original input tensor. 245 | 246 | Returns: 247 | - y_out (torch.Tensor): Overlap-added output tensor. 
248 | """ 249 | n_chunks, n_sources, *_ = y_chunks.shape 250 | y_out = torch.zeros_like(y).unsqueeze(0).repeat(n_sources, 1, 1) 251 | start = 0 252 | for i in range(n_chunks): 253 | y_out[..., start : start + self.ws] += y_chunks[i] 254 | start += self.ss 255 | return y_out 256 | 257 | def forward_batches(self, y_chunks: torch.Tensor) -> torch.Tensor: 258 | """ 259 | Forward pass for batches of input chunks. 260 | 261 | Args: 262 | - y_chunks (torch.Tensor): Input tensor chunks. 263 | 264 | Returns: 265 | - y_chunks (torch.Tensor): Processed output tensor chunks. 266 | """ 267 | norm_value = self.ws / self.ss 268 | chunks = list(range(0, y_chunks.shape[0], self.bs)) 269 | if self.use_progress_bar: 270 | chunks = tqdm(chunks) 271 | y_chunks = torch.cat( 272 | [ 273 | self(y_chunks[start : start + self.bs])[0] / norm_value 274 | for start in chunks 275 | ] 276 | ) 277 | return y_chunks 278 | 279 | @torch.no_grad() 280 | def separate(self, y: torch.Tensor) -> torch.Tensor: 281 | """ 282 | Perform source separation on the input tensor. 283 | 284 | Args: 285 | - y (torch.Tensor): Input tensor. 286 | 287 | Returns: 288 | - y (torch.Tensor): Separated source tensor. 289 | """ 290 | y, padding_add = self.pad_whole(y) 291 | y_chunks = self.unfold(y) 292 | 293 | y_chunks = self.forward_batches(y_chunks) 294 | 295 | y = self.fold(y_chunks, y) 296 | y = self.unpad_whole(y, padding_add) 297 | 298 | return y 299 | -------------------------------------------------------------------------------- /src/model/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | import torch 4 | 5 | 6 | def create_intervals( 7 | splits: List[Union[float, int]] 8 | ) -> List[Union[Tuple[float, float], Tuple[int, int]]]: 9 | """ 10 | Create intervals based on splits provided. 11 | 12 | Args: 13 | - splits (List[Union[float, int]]): List of floats or integers representing splits. 14 | 15 | Returns: 16 | - List[Union[Tuple[float, float], Tuple[int, int]]]: List of tuples representing intervals. 17 | """ 18 | start = 0 19 | return [(start, start := start + split) for split in splits] 20 | 21 | 22 | def get_conv_output_shape( 23 | input_shape: int, 24 | kernel_size: int = 1, 25 | padding: int = 0, 26 | dilation: int = 1, 27 | stride: int = 1, 28 | ) -> int: 29 | """ 30 | Compute the output shape of a convolutional layer. 31 | 32 | Args: 33 | - input_shape (int): Input shape. 34 | - kernel_size (int, optional): Kernel size of the convolution. Default is 1. 35 | - padding (int, optional): Padding size. Default is 0. 36 | - dilation (int, optional): Dilation factor. Default is 1. 37 | - stride (int, optional): Stride value. Default is 1. 38 | 39 | Returns: 40 | - int: Output shape. 41 | """ 42 | return int( 43 | (input_shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1 44 | ) 45 | 46 | 47 | def get_convtranspose_output_padding( 48 | input_shape: int, 49 | output_shape: int, 50 | kernel_size: int = 1, 51 | padding: int = 0, 52 | dilation: int = 1, 53 | stride: int = 1, 54 | ) -> int: 55 | """ 56 | Compute the output padding for a convolution transpose operation. 57 | 58 | Args: 59 | - input_shape (int): Input shape. 60 | - output_shape (int): Desired output shape. 61 | - kernel_size (int, optional): Kernel size of the convolution. Default is 1. 62 | - padding (int, optional): Padding size. Default is 0. 63 | - dilation (int, optional): Dilation factor. Default is 1. 64 | - stride (int, optional): Stride value. Default is 1. 
65 | 66 | Returns: 67 | - int: Output padding. 68 | """ 69 | return ( 70 | output_shape 71 | - (input_shape - 1) * stride 72 | + 2 * padding 73 | - dilation * (kernel_size - 1) 74 | - 1 75 | ) 76 | 77 | 78 | def compute_sd_layer_shapes( 79 | input_shape: int, 80 | bandsplit_ratios: List[float], 81 | downsample_strides: List[int], 82 | n_layers: int, 83 | ) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]: 84 | """ 85 | Compute the shapes for the subband layers. 86 | 87 | Args: 88 | - input_shape (int): Input shape. 89 | - bandsplit_ratios (List[float]): Ratios for splitting the frequency bands. 90 | - downsample_strides (List[int]): Strides for downsampling in each layer. 91 | - n_layers (int): Number of layers. 92 | 93 | Returns: 94 | - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: Tuple containing subband shapes and convolution shapes. 95 | """ 96 | bandsplit_shapes_list = [] 97 | conv2d_shapes_list = [] 98 | for _ in range(n_layers): 99 | bandsplit_intervals = create_intervals(bandsplit_ratios) 100 | bandsplit_shapes = [ 101 | int(right * input_shape) - int(left * input_shape) 102 | for left, right in bandsplit_intervals 103 | ] 104 | conv2d_shapes = [ 105 | get_conv_output_shape(bs, stride=ds) 106 | for bs, ds in zip(bandsplit_shapes, downsample_strides) 107 | ] 108 | input_shape = sum(conv2d_shapes) 109 | bandsplit_shapes_list.append(bandsplit_shapes) 110 | conv2d_shapes_list.append(create_intervals(conv2d_shapes)) 111 | 112 | return bandsplit_shapes_list, conv2d_shapes_list 113 | 114 | 115 | def compute_gcr(subband_shapes: List[List[int]]) -> float: 116 | """ 117 | Compute the global compression ratio. 118 | 119 | Args: 120 | - subband_shapes (List[List[int]]): List of subband shapes. 121 | 122 | Returns: 123 | - float: Global compression ratio. 
124 | """ 125 | t = torch.Tensor(subband_shapes) 126 | gcr = torch.stack( 127 | [(1 - t[i + 1] / t[i]).mean() for i in range(0, len(t) - 1)] 128 | ).mean() 129 | return float(gcr) 130 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import traceback 3 | from shutil import rmtree 4 | 5 | import hydra 6 | from hydra.utils import instantiate 7 | from omegaconf import DictConfig 8 | from lightning.pytorch import seed_everything 9 | from torch.utils.data import DataLoader 10 | 11 | from utils.lightning import LightningWrapper 12 | 13 | log = logging.getLogger(__name__) 14 | 15 | 16 | @hydra.main(version_base=None, config_path="conf", config_name="train") 17 | def train(cfg: DictConfig) -> None: 18 | if cfg.get("seed"): 19 | log.info(f"Seed everything with <{cfg.seed}>...") 20 | seed_everything(cfg.seed, workers=True) 21 | 22 | log.info("Initializing DataLoaders...") 23 | dataset = instantiate(cfg.dataset) 24 | if cfg.use_validation: 25 | train_dataset, val_dataset = dataset.get_train_val_split(**cfg.train_val_split) 26 | train_dataloader = DataLoader(train_dataset, **cfg.loader.train) 27 | val_dataloader = DataLoader(val_dataset, **cfg.loader.validation) 28 | else: 29 | train_dataloader = DataLoader(dataset, **cfg.loader.train) 30 | val_dataloader = None 31 | 32 | log.info("Initializing LightningWrapper...") 33 | lt_wrapper = LightningWrapper(cfg) 34 | 35 | log.info("Initializing training utilities...") 36 | logger = instantiate(cfg.logger) 37 | callbacks = list(instantiate(cfg.callbacks).values()) 38 | 39 | log.info("Initializing trainer...") 40 | trainer = instantiate( 41 | cfg.trainer, 42 | logger=logger, 43 | callbacks=callbacks, 44 | ) 45 | 46 | log.info("Starting training...") 47 | try: 48 | trainer.fit(lt_wrapper, train_dataloader, val_dataloader) 49 | except Exception as e: 50 | log.error(f"Finished with error:\n{traceback.format_exc()}") 51 | 52 | # cleaning up if it was testrun 53 | if cfg.trainer.fast_dev_run: 54 | rmtree(cfg.output_dir) 55 | 56 | log.info("Training finished!") 57 | 58 | 59 | if __name__ == "__main__": 60 | train() 61 | -------------------------------------------------------------------------------- /src/utils/lightning.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | from hydra.utils import instantiate 6 | from lightning.pytorch import LightningModule 7 | from lightning.pytorch.utilities import grad_norm 8 | from omegaconf import DictConfig 9 | 10 | from model.separator import Separator 11 | 12 | 13 | class LightningWrapper(LightningModule): 14 | """ 15 | This class serves as a LightningModule for training and evaluating the source separation model. 16 | 17 | Args: 18 | - cfg (DictConfig): Configuration object containing wrapper settings. 19 | """ 20 | 21 | def __init__(self, cfg: DictConfig) -> None: 22 | """ 23 | Initializes the LightningWrapper. 
24 | """ 25 | super().__init__() 26 | 27 | self.cfg = cfg 28 | 29 | self.sep = Separator(**cfg.separator) 30 | 31 | self.loss = instantiate(cfg.loss) 32 | self.optimizer = instantiate(cfg.optimizer, params=self.sep.parameters()) 33 | self.scheduler = instantiate(cfg.scheduler) if cfg.get("scheduler") else None 34 | self.metrics = nn.ModuleDict(instantiate(cfg.metrics)) 35 | 36 | self.save_hyperparameters(cfg) 37 | 38 | def training_step( 39 | self, batch: Dict[str, torch.Tensor], batch_idx: int 40 | ) -> torch.Tensor: 41 | loss = self.step(batch, mode="train") 42 | return loss 43 | 44 | def validation_step( 45 | self, batch: Dict[str, torch.Tensor], batch_idx: int 46 | ) -> torch.Tensor: 47 | loss = self.step(batch, mode="val") 48 | return loss 49 | 50 | def step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor: 51 | """ 52 | Common step function for training and validation. 53 | 54 | Args: 55 | - batch (torch.Tensor): Batch of input data. 56 | - mode (str): Mode of operation. Defaults to 'train'. 57 | 58 | Returns: 59 | - torch.Tensor: Loss tensor. 60 | """ 61 | wav_mix, wav_src = batch["mixture"], batch["sources"] 62 | wav_src_hat, spec_src_hat = self.sep(wav_mix) 63 | spec_src, _ = self.sep.apply_stft(wav_src) 64 | 65 | loss = self.loss(spec_src_hat, spec_src) 66 | 67 | self.log(f"{mode}/loss", loss.detach(), prog_bar=True) 68 | 69 | if mode == "val": 70 | metrics = self.compute_metrics(wav_src_hat, wav_src) 71 | self.log_dict(metrics) 72 | return loss 73 | 74 | @torch.no_grad() 75 | def compute_metrics( 76 | self, preds: torch.Tensor, target: torch.Tensor 77 | ) -> Dict[str, torch.Tensor]: 78 | """ 79 | Computes metrics for evaluation. 80 | 81 | Args: 82 | - preds (torch.Tensor): Predicted sources tensor. 83 | - target (torch.Tensor): Target sources tensor. 84 | 85 | Returns: 86 | - Dict[str, torch.Tensor]: Dictionary containing computed metrics. 87 | """ 88 | metrics = {} 89 | for key in self.metrics: 90 | for i, source in enumerate(self.cfg.dataset.sources): 91 | metrics[f"val/{key}_{source}"] = self.metrics[key]( 92 | preds[:, i], target[:, i] 93 | ) 94 | metrics[f"val/{key}"] = self.metrics[key](preds, target) 95 | return metrics 96 | 97 | def on_before_optimizer_step(self, *args, **kwargs) -> None: 98 | norms = grad_norm(self, norm_type=2) 99 | norms = dict(filter(lambda elem: "_total" in elem[0], norms.items())) 100 | self.log_dict(norms) 101 | return 102 | 103 | def configure_optimizers(self): 104 | return [self.optimizer] 105 | 106 | def get_metrics(self, *args, **kwargs): 107 | items = super().get_metrics() 108 | items.pop("v_num", None) 109 | return items 110 | -------------------------------------------------------------------------------- /src/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RMSELoss(nn.Module): 6 | """ 7 | Root Mean Squared Error Loss for Complex Inputs. 8 | 9 | Args: 10 | - eps (float): A small value to prevent division by zero. 11 | """ 12 | 13 | def __init__(self, eps: float = 1e-6): 14 | """ 15 | Initialize RMSELoss. 16 | """ 17 | super().__init__() 18 | self.mse = nn.MSELoss() 19 | self.eps = eps 20 | 21 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 22 | """ 23 | Compute the forward pass of RMSELoss. 24 | 25 | Args: 26 | - pred (torch.Tensor): Predicted values, a tensor of shape (batch_size, ..., 2), 27 | where last dimension is real/imaginary part. 
28 | - target (torch.Tensor): Target values, a tensor of shape (batch_size, ..., 2), 29 | where last dimension is real/imaginary part. 30 | 31 | Returns: 32 | - loss (torch.Tensor): Computed RMSE loss. 33 | """ 34 | loss = torch.sqrt( 35 | self.mse(pred[..., 0], target[..., 0]) 36 | + self.mse(pred[..., 1], target[..., 1]) 37 | + self.eps 38 | ) 39 | return loss 40 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | from torchmetrics.audio import SignalDistortionRatio 5 | 6 | 7 | def global_signal_distortion_ratio( 8 | preds: torch.Tensor, target: torch.Tensor, epsilon: float = 1e-7 9 | ) -> torch.Tensor: 10 | """ 11 | Calculates the Global Signal Distortion Ratio (GSDR) between predicted and target signals. 12 | 13 | Args: 14 | - preds (torch.Tensor): Predicted signal tensor. 15 | - target (torch.Tensor): Target signal tensor. 16 | - epsilon (float, optional): Small value to avoid division by zero. Defaults to 1e-7. 17 | 18 | Returns: 19 | - torch.Tensor: Mean Global Signal Distortion Ratio (GSDR) over the batch. 20 | """ 21 | num = torch.sum(torch.square(target), dim=(-2, -1)) + epsilon 22 | den = torch.sum(torch.square(target - preds), dim=(-2, -1)) + epsilon 23 | usdr = 10 * torch.log10(num / den) 24 | return usdr.mean() 25 | 26 | 27 | class GlobalSignalDistortionRatio(SignalDistortionRatio): 28 | """ 29 | Computes the Global Signal Distortion Ratio (GSDR) metric for audio signals, 30 | as first described in the Sound Demixing Challenge. 31 | 32 | This metric calculates the ratio between the energy of the original signal 33 | and the energy of the difference between the original and the predicted signal, 34 | measured in decibels (dB). 35 | 36 | Paper: https://arxiv.org/pdf/2308.06979.pdf 37 | """ 38 | 39 | def __init__( 40 | self, 41 | epsilon: float = 1e-7, 42 | **kwargs: Any, 43 | ) -> None: 44 | """ 45 | Initializes the GlobalSignalDistortionRatio metric. 46 | """ 47 | super().__init__(**kwargs) 48 | 49 | self.epsilon = epsilon 50 | 51 | self.add_state("sum_sdr", default=torch.tensor(0.0), dist_reduce_fx="sum") 52 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 53 | 54 | def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: 55 | """Update state with predictions and targets.""" 56 | sdr_batch = global_signal_distortion_ratio(preds, target, self.epsilon) 57 | 58 | self.sum_sdr += sdr_batch.sum() 59 | self.total += sdr_batch.numel() 60 | --------------------------------------------------------------------------------
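Beyond the CLI in src/inference.py, the Separator class listed above can also be used programmatically. The sketch below is illustrative rather than part of the repository: the checkpoint and audio paths are hypothetical, torchaudio is assumed to be available for audio I/O, and src/ is assumed to be on PYTHONPATH so the model.* imports resolve.

```python
import torchaudio

from model.separator import Separator

# Load a Lightning checkpoint produced by src/train.py (path is hypothetical).
separator = Separator.load_from_checkpoint(
    "logs/scnet/2024-04-13_12-00/checkpoints/last.ckpt",
    use_progress_bar=True,
)
separator.eval()

# (channels, samples) float tensor; MUSDB18HQ mixtures are stereo at 44.1 kHz.
wav_mixture, sample_rate = torchaudio.load("mixture.wav")

# Chunked overlap-add separation, as implemented in Separator.separate;
# based on the code above, the result has shape (n_sources, channels, samples).
wav_sources = separator.separate(wav_mixture)

for idx in range(wav_sources.shape[0]):
    torchaudio.save(f"source_{idx}.wav", wav_sources[idx], sample_rate)
```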