├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── .gitkeep
├── images
│   └── architecture.png
├── logs
│   └── .gitkeep
├── requirements.txt
├── scripts
│   └── .gitkeep
└── src
    ├── conf
    │   ├── callbacks
    │   │   └── default.yaml
    │   ├── dataset
    │   │   └── default.yaml
    │   ├── hydra
    │   │   └── default.yaml
    │   ├── logger
    │   │   └── default.yaml
    │   ├── loss
    │   │   └── default.yaml
    │   ├── metrics
    │   │   └── default.yaml
    │   ├── optimizer
    │   │   └── default.yaml
    │   ├── scheduler
    │   │   └── .gitkeep
    │   ├── separator
    │   │   ├── default.yaml
    │   │   └── net
    │   │       └── default.yaml
    │   └── train.yaml
    ├── data
    │   ├── dataset.py
    │   └── utils.py
    ├── evaluate.py
    ├── inference.py
    ├── model
    │   ├── modules
    │   │   ├── __init__.py
    │   │   ├── dualpath_rnn.py
    │   │   ├── sd_encoder.py
    │   │   └── su_decoder.py
    │   ├── scnet.py
    │   ├── separator.py
    │   └── utils.py
    ├── train.py
    └── utils
        ├── lightning.py
        ├── loss.py
        └── metrics.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | **/*.ipynb_checkpoints/
3 | **/*.pt
4 | **/*.ipynb
5 | *.pyc
6 | .DS_Store
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Amantur Amatov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SCNet-Pytorch
2 |
3 | Unofficial PyTorch implementation of the paper
4 | ["SCNet: Sparse Compression Network for Music Source Separation"](https://arxiv.org/abs/2401.13276.pdf).
5 |
6 | 
7 |
8 | ---
9 | ## Table of Contents
10 |
11 | 1. [Changelog & ToDo's](#changelog)
12 | 2. [Dependencies](#dependencies)
13 | 3. [Training](#training)
14 | 4. [Inference](#inference)
15 | 5. [Evaluation](#evaluation)
16 | 6. [Repository structure](#structure)
17 | 7. [Citing](#citing)
18 |
19 | ---
20 |
21 |
22 | # Changelog
23 |
24 | - **10.02.2024**
25 | - The model itself is finished. The training script is on its way.
26 | - **21.02.2024**
27 | - Add part of the training pipeline.
28 | - **02.03.2024**
29 | - Finish the training pipeline and the separator.
30 | - **17.03.2024**
31 | - Finish inference.py and fill README.md.
32 | - **13.04.2024**
33 | - Finish evaluation pipeline and fill README.md.
34 |
35 | ## ToDo's
36 |
37 | - Add trained model.
38 |
39 | ---
40 |
41 |
42 | # Dependencies
43 |
44 | Before starting training, you need to install the requirements:
45 |
46 | ```bash
47 | pip install -r requirements.txt
48 | ```
49 |
50 | Then, download the MUSDB18HQ dataset:
51 |
52 | ```bash
53 | wget -P /path/to/dataset https://zenodo.org/records/3338373/files/musdb18hq.zip
54 | unzip /path/to/dataset/musdb18hq.zip -d /path/to/dataset
55 | ```
56 |
57 | Next, set environment variables with the paths to the audio data and to the generated metadata `.pqt` file:
58 |
59 | ```bash
60 | export DATASET_DIR=/path/to/dataset/musdb18hq
61 | export DATASET_PATH=/path/to/dataset/dataset.pqt
62 | ```
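
If the metadata `.pqt` file does not exist yet, it is built automatically from the audio files on the first training run (see `construct_dataset` in `src/data/utils.py`). A minimal sketch of generating it ahead of time, assuming the environment variables above are set and the command is run from the repository root:

```python
import os
import sys

sys.path.append("src/")

from data.utils import construct_dataset

# Scans DATASET_DIR for .wav stems and caches the metadata as a parquet file
# at DATASET_PATH (or simply loads it if the file already exists).
df = construct_dataset(
    dataset_dirs=[os.environ["DATASET_DIR"]],
    extension="wav",
    save_path=os.environ["DATASET_PATH"],
)
print(df.head())
```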
63 |
64 | Finally, export `CUDA_VISIBLE_DEVICES` to select the GPU to use:
65 | ```bash
66 | export CUDA_VISIBLE_DEVICES=0
67 | ```
68 |
69 | Now, you can train the model.
70 |
71 | ---
72 |
73 |
74 | # Training
75 |
76 | The training pipeline is built on a combination of `PyTorch-Lightning` and `hydra`.
77 | All configuration files are stored in the `src/conf` directory in `hydra`-friendly format.
78 |
79 | To start training a model with the provided configuration, run the following command:
80 | ```
81 | python src/train.py
82 | ```
83 | To configure the training process, follow `hydra` [instructions](https://hydra.cc/docs/advanced/override_grammar/basic/).
84 | You can modify or override arguments like this:
85 | ```
86 | python src/train.py +trainer.overfit_batches=10 loader.train.batch_size=16
87 | ```
88 |
89 | Once training starts, a logging folder for the experiment is created at the following path:
90 | ```
91 | logs/scnet/${now:%Y-%m-%d}_${now:%H-%M}/
92 | ```
93 | This folder will have the following structure:
94 | ```
95 | ├── checkpoints
96 | │   └── *.ckpt - Lightning model checkpoint files
97 | ├── tensorboard
98 | │   └── tensorboard_log_file - main TensorBoard log file
99 | ├── yamls
100 | │   └── *.yaml - hydra configuration and override files
101 | └── train.log - logging file for train.py
102 |
103 | ```
104 |
105 | ---
106 |
107 |
108 | # Inference
109 |
110 | After training a model, you can run inference using the following command:
111 |
112 | ```bash
113 | python src/inference.py -i <INPUT_PATH> \
114 |                         -o <OUTPUT_PATH> \
115 |                         -c <CKPT_PATH>
116 | ```
117 |
118 | This command will generate separated audio files in .wav format in the `<OUTPUT_PATH>` directory.
119 |
120 | For more information about the script and its options, use:
121 | ```bash
122 | usage: inference.py [-h] -i INPUT_PATH -o OUTPUT_PATH -c CKPT_PATH [-d DEVICE] [-b BATCH_SIZE] [-w WINDOW_SIZE] [-s STEP_SIZE] [-p]
123 |
124 | Argument Parser for Separator
125 |
126 | optional arguments:
127 | -h, --help show this help message and exit
128 | -i INPUT_PATH, --input-path INPUT_PATH
129 | Input path to .wav audio file/directory containing audio files
130 | -o OUTPUT_PATH, --output-path OUTPUT_PATH
131 | Output directory to save separated audio files in .wav format
132 | -c CKPT_PATH, --ckpt-path CKPT_PATH
133 | Path to the model checkpoint
134 | -d DEVICE, --device DEVICE
135 | Device to run the model on (default: cuda)
136 | -b BATCH_SIZE, --batch-size BATCH_SIZE
137 | Batch size for processing (default: 4)
138 | -w WINDOW_SIZE, --window-size WINDOW_SIZE
139 | Window size (default: 11)
140 | -s STEP_SIZE, --step-size STEP_SIZE
141 | Step size (default: 5.5)
142 | -p, --use-progress-bar
143 | Use progress bar (default: False)
144 | ```
145 |
146 | Additionally, you can run inference within Python using the following script:
147 | ```python
148 | import sys
149 | sys.path.append('src/')
150 |
151 | import torchaudio
152 | from src.model.separator import Separator
153 |
154 | device: str = 'cuda'
155 |
156 | separator = Separator.load_from_checkpoint(
157 | path="", # path to trained Lightning checkpoint
158 | batch_size=4, # adjust batch size to fit into your GPU's memory
159 | window_size=11, # window size of the model (do not change)
160 | step_size=5.5, # the closer the step size is to the window size, the faster the inference, but the worse the results
161 | use_progress_bar=True # show progress bar per audio file
162 | ).to(device)
163 |
164 | y, sr = torchaudio.load("<INPUT_PATH>")  # path to the audio file to separate
165 | y = y.to(device)
166 |
167 | y_separated = separator.separate(y).cpu()
168 | ```
169 |
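To write each estimated stem to disk, you can mirror what `src/inference.py` does. A small sketch continuing the snippet above; the `separated/` output directory is just an example name:

```python
from pathlib import Path

import soundfile as sf

output_dir = Path("separated")  # example output directory
output_dir.mkdir(parents=True, exist_ok=True)

for y_source, name in zip(y_separated, ["drums", "bass", "other", "vocals"]):
    # Each estimate is (channels, time); soundfile expects (time, channels), hence the transpose.
    sf.write(f"{output_dir}/{name}.wav", y_source.T, 44100)
```
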
170 | Make sure to replace `<INPUT_PATH>`, `<OUTPUT_PATH>`, and `<CKPT_PATH>` with the appropriate paths for your setup.
171 |
172 | ---
173 |
174 |
175 | # Evaluation
176 |
177 | After training a model, you can run the evaluation pipeline using the following command:
178 |
179 | ```bash
180 | python src/evaluate.py -c <CKPT_PATH>
181 | ```
182 |
183 | This script loads the model from the `<CKPT_PATH>` checkpoint and runs inference on the `test` tracks listed in the `DATASET_PATH` metadata file.
184 |
185 | As a result, the script prints the mean SDR for each source to the console.
186 |
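For a single track, the per-source SDR computation is roughly the following (a simplified sketch of what `src/evaluate.py` does; `<CKPT_PATH>` and `<TRACK_DIR>` are placeholders for a checkpoint and a MUSDB18HQ test track directory):

```python
import sys

sys.path.append("src/")

import fast_bss_eval
import torchaudio
from model.separator import Separator

device = "cuda"
separator = Separator.load_from_checkpoint(
    path="<CKPT_PATH>", batch_size=4, window_size=11, step_size=5.5
).to(device)

# Separate one test mixture and compare each estimate to its reference stem.
mixture, _ = torchaudio.load("<TRACK_DIR>/mixture.wav")
estimates = separator.separate(mixture.to(device)).cpu()

for estimate, source in zip(estimates, ["drums", "bass", "other", "vocals"]):
    reference, _ = torchaudio.load(f"<TRACK_DIR>/{source}.wav")
    sdr, *_ = fast_bss_eval.bss_eval_sources(
        reference, estimate, compute_permutation=False, load_diag=1e-7
    )
    print(source, sdr.mean().item())
```
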
187 | For more information about the script and its options, use:
188 |
189 | ```bash
190 | usage: evaluate.py [-h] -c CKPT_PATH [-d DEVICE] [-b BATCH_SIZE] [-w WINDOW_SIZE] [-s STEP_SIZE]
191 |
192 | Argument Parser for Separator
193 |
194 | optional arguments:
195 | -h, --help show this help message and exit
196 | -c CKPT_PATH, --ckpt-path CKPT_PATH
197 | Path to the model checkpoint
198 | -d DEVICE, --device DEVICE
199 | Device to run the model on (default: cuda)
200 | -b BATCH_SIZE, --batch-size BATCH_SIZE
201 | Batch size for processing (default: 4)
202 | -w WINDOW_SIZE, --window-size WINDOW_SIZE
203 | Window size (default: 11)
204 | -s STEP_SIZE, --step-size STEP_SIZE
205 | Step size (default: 5.5)
206 | ```
207 |
208 | ---
209 |
210 |
211 | # Citing
212 |
213 | To cite the original paper, please use:
214 | ```
215 | @misc{tong2024scnet,
216 | title={SCNet: Sparse Compression Network for Music Source Separation},
217 | author={Weinan Tong and Jiaxu Zhu and Jun Chen and Shiyin Kang and Tao Jiang and Yang Li and Zhiyong Wu and Helen Meng},
218 | year={2024},
219 | eprint={2401.13276},
220 | archivePrefix={arXiv},
221 | primaryClass={eess.AS}
222 | }
223 | ```
224 |
225 |
226 |
227 |
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amanteur/SCNet-PyTorch/69bede853c7144e82bbe5aeb6432d8f43d2fce48/data/.gitkeep
--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amanteur/SCNet-PyTorch/69bede853c7144e82bbe5aeb6432d8f43d2fce48/images/architecture.png
--------------------------------------------------------------------------------
/logs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amanteur/SCNet-PyTorch/69bede853c7144e82bbe5aeb6432d8f43d2fce48/logs/.gitkeep
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | hydra-core~=1.3.2
2 | omegaconf~=2.2.3
3 | pandas~=1.5.3
4 | torch~=2.1.2
5 | torchaudio~=2.1.2
6 | lightning~=2.2.0
7 | soundfile==0.12.1
8 | tqdm==4.64.1
9 | fast-bss-eval==0.1.4
--------------------------------------------------------------------------------
/scripts/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amanteur/SCNet-PyTorch/69bede853c7144e82bbe5aeb6432d8f43d2fce48/scripts/.gitkeep
--------------------------------------------------------------------------------
/src/conf/callbacks/default.yaml:
--------------------------------------------------------------------------------
1 | lr_monitor:
2 | _target_: lightning.pytorch.callbacks.LearningRateMonitor
3 | logging_interval: epoch
4 | model_ckpt:
5 | _target_: lightning.pytorch.callbacks.ModelCheckpoint
6 | monitor: val/usdr
7 | mode: max
8 | save_top_k: 3
9 | dirpath: ${hydra:runtime.output_dir}/checkpoints
10 | filename: epoch{epoch:02d}-val_usdr{val/usdr:.2f}
11 | auto_insert_metric_name: False
--------------------------------------------------------------------------------
/src/conf/dataset/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: data.dataset.SourceSeparationDataset
2 | subset: train
3 | dataset_dirs: ${oc.env:DATASET_DIR}
4 | dataset_path: ${oc.env:DATASET_PATH}
5 | dataset_extension: wav
6 | window_size: 11
7 | step_size: 1
8 | sample_rate: 44100
9 | sources: ${sources}
--------------------------------------------------------------------------------
/src/conf/hydra/default.yaml:
--------------------------------------------------------------------------------
1 | job:
2 | name: scnet
3 | output_subdir: yamls
4 | run:
5 | dir: logs/${hydra.job.name}/${now:%Y-%m-%d}_${now:%H-%M}
6 | sweep:
7 | dir: logs/${hydra.job.name}_multirun/${now:%Y-%m-%d}_${now:%H-%M}
8 | subdir: ${hydra.job.num}
9 | job_logging:
10 | handlers:
11 | file:
12 | filename: ${hydra.runtime.output_dir}/train.log
--------------------------------------------------------------------------------
/src/conf/logger/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.loggers.TensorBoardLogger
2 | save_dir: ${hydra:runtime.output_dir}/tensorboard
3 | name: ''
4 | version: ''
5 | default_hp_metric: False
--------------------------------------------------------------------------------
/src/conf/loss/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: utils.loss.RMSELoss
--------------------------------------------------------------------------------
/src/conf/metrics/default.yaml:
--------------------------------------------------------------------------------
1 | usdr:
2 | _target_: utils.metrics.GlobalSignalDistortionRatio
3 | epsilon: 1e-7
--------------------------------------------------------------------------------
/src/conf/optimizer/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.Adam
2 | lr: 5e-4
--------------------------------------------------------------------------------
/src/conf/scheduler/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amanteur/SCNet-PyTorch/69bede853c7144e82bbe5aeb6432d8f43d2fce48/src/conf/scheduler/.gitkeep
--------------------------------------------------------------------------------
/src/conf/separator/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - net: default
3 |
4 | window_size: 11
5 | step_size: 5.5
6 | sample_rate: 44100
7 | batch_size: 4
8 |
9 | return_spec: True
10 |
11 | stft:
12 | _target_: torchaudio.transforms.Spectrogram
13 | n_fft: 4096
14 | win_length: 4096
15 | hop_length: 1024
16 | power: null
17 |
18 | istft:
19 | _target_: torchaudio.transforms.InverseSpectrogram
20 | n_fft: ${separator.stft.n_fft}
21 | win_length: ${separator.stft.win_length}
22 | hop_length: ${separator.stft.hop_length}
--------------------------------------------------------------------------------
/src/conf/separator/net/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.scnet.SCNet
2 | _convert_: object
3 | n_fft: ${separator.stft.n_fft}
4 | dims: [4, 32, 64, 128]
5 | bandsplit_ratios: [.175, .392, .433]
6 | downsample_strides: [1, 4, 16]
7 | n_conv_modules: [3, 2, 1]
8 | n_rnn_layers: 6
9 | rnn_hidden_dim: 128
10 | n_sources: 4
--------------------------------------------------------------------------------
/src/conf/train.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - _self_
3 | - dataset: default
4 | - separator: default
5 |
6 | - loss: default
7 | - optimizer: default
8 | - scheduler: null
9 |
10 | - metrics: default
11 | - logger: default
12 | - callbacks: default
13 | - hydra: default
14 |
15 |
16 | seed: 42
17 | sources: [drums, bass, other, vocals]
18 | output_dir: ${hydra:runtime.output_dir}
19 |
20 | use_validation: True
21 | train_val_split:
22 | lengths: [86, 14]
23 | loader:
24 | train:
25 | batch_size: 8
26 | num_workers: 8
27 | shuffle: True
28 | drop_last: True
29 | validation:
30 | batch_size: 4
31 | num_workers: 8
32 | shuffle: False
33 | drop_last: False
34 |
35 | trainer:
36 | _target_: lightning.pytorch.Trainer
37 | fast_dev_run: False
38 | accelerator: cuda
39 | max_epochs: 130
40 | check_val_every_n_epoch: 1
41 | num_sanity_val_steps: 5
42 | log_every_n_steps: 100
43 | devices: 1
44 | gradient_clip_val: 5
45 | precision: 32
46 | enable_progress_bar: True
47 | benchmark: False
48 | deterministic: False
49 |
--------------------------------------------------------------------------------
/src/data/dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple
2 |
3 | import pandas as pd
4 | import torch
5 | import torchaudio
6 | from torch.utils.data import Dataset, Subset, random_split
7 |
8 | from data.utils import construct_dataset
9 |
10 |
11 | class SourceSeparationDataset(Dataset):
12 | """
13 | Dataset class for source separation tasks.
14 |
15 | Args:
16 | - dataset_dirs (str): Comma-separated paths to directories where datasets are located.
17 | - subset (str): Subset of the dataset to load.
18 | - window_size (int): Size of the sliding window in seconds.
19 | - step_size (int): Step size of the sliding window in seconds.
20 | - sample_rate (int): Sample rate of the audio.
21 | - dataset_extension (str, optional): Extension of the dataset files. Defaults to 'wav'.
22 | - dataset_path (str, optional): Path to cache the dataset metadata. If specified, the metadata
23 | will be cached in parquet format for future retrieval. Defaults to None.
25 | - mixture_name (str, optional): Name of the mixture. Defaults to None.
26 | - sources (List[str], optional): List of source names. Defaults to ['drums', 'bass', 'other', 'vocals'].
27 | """
28 |
29 | MIXTURE_NAME: str = "mixture"
30 | SOURCE_NAMES: List[str] = ["drums", "bass", "other", "vocals"]
31 |
32 | def __init__(
33 | self,
34 | dataset_dirs: str,
35 | subset: str,
36 | window_size: int,
37 | step_size: int,
38 | sample_rate: int,
39 | dataset_extension: str = "wav",
40 | dataset_path: Optional[str] = None,
41 | mixture_name: Optional[str] = None,
42 | sources: Optional[List[str]] = None,
43 | ):
44 | """
45 | Initializes the SourceSeparationDataset.
46 | """
47 | super().__init__()
48 |
49 | self.dataset_dirs: List[str] = dataset_dirs.split(",")
50 | self.subset = subset
51 | self.dataset_extension = dataset_extension
52 | self.dataset_path = dataset_path
53 |
54 | self.window_size = int(window_size * sample_rate)
55 | self.step_size = int(step_size * sample_rate)
56 | self.sample_rate = sample_rate
57 |
58 | self.mixture_name = mixture_name or self.MIXTURE_NAME
59 | self.sources = sources or self.SOURCE_NAMES
60 |
61 | self.df = self.load_df()
62 | self.segment_ids, self.track_ids = self.get_ids()
63 |
64 | def generate_offsets(
65 | self,
66 | total_frames: int,
67 | ) -> List[int]:
68 | """
69 | Generates the offsets based on total frames of track, window size and step size of segments.
70 |
71 | Args:
72 | - total_frames (int): Total number of frames of audio.
73 |
74 | Returns:
75 | - List[int]: List of offsets.
76 | """
77 | return [
78 | start
79 | for start in range(0, total_frames - self.window_size + 1, self.step_size)
80 | ]
81 |
82 | def load_df(self) -> pd.DataFrame:
83 | """
84 | Loads the DataFrame based on the train/test subset and populates data based on window/step sizes.
85 |
86 | Returns:
87 | - pd.DataFrame: Loaded DataFrame.
88 | """
89 | df = construct_dataset(
90 | self.dataset_dirs,
91 | extension=self.dataset_extension,
92 | save_path=self.dataset_path,
93 | )
94 | df = df[df["subset"].eq(self.subset)]
95 | df["offset"] = df["total_frames"].apply(self.generate_offsets)
96 | df = df.explode("offset")
97 | df["track_id"] = df["track_name"].factorize()[0]
98 | df["segment_id"] = df.set_index(["track_id", "offset"]).index.factorize()[0]
99 | return df
100 |
101 | def get_ids(self) -> Tuple[List[int], List[int]]:
102 | """
103 | Gets the segment and track IDs.
104 |
105 | Returns:
106 | - Tuple[List[int], List[int]]: Tuple containing segment and track IDs.
107 | """
108 | segment_ids = self.df["segment_id"].tolist()
109 | track_ids = self.df["track_id"].unique().tolist()
110 | return segment_ids, track_ids
111 |
112 | def load_audio(self, segment_info: Dict[str, Any]) -> torch.Tensor:
113 | """
114 | Loads the audio based on segment information.
115 |
116 | Args:
117 | - segment_info (Dict[str, Any]): Segment information.
118 |
119 | Returns:
120 | - torch.Tensor: Loaded audio tensor.
121 | """
122 | audio, sr = torchaudio.load(
123 | segment_info["path"],
124 | num_frames=self.window_size,
125 | frame_offset=segment_info["offset"],
126 | )
127 | assert (
128 | sr == self.sample_rate
129 | ), f"Sample rate of the audio should be {self.sample_rate}Hz instead of {sr}Hz."
130 | return audio
131 |
132 | def load_mixture(self, idx: int) -> torch.Tensor:
133 | """
134 | Loads the audio mixture based on the provided index.
135 |
136 | Args:
137 | - idx (int): Index of the mixture.
138 |
139 | Returns:
140 | - torch.Tensor: Loaded audio mixture tensor.
141 | """
142 | segment_info = (
143 | self.df[
144 | self.df["segment_id"].eq(idx)
145 | & self.df["source_type"].eq(self.mixture_name)
146 | ]
147 | .iloc[0]
148 | .to_dict()
149 | )
150 | audio = self.load_audio(segment_info)
151 | return audio
152 |
153 | def load_sources(self, idx: int) -> torch.Tensor:
154 | """
155 | Loads the separated sources based on the provided index.
156 |
157 | Args:
158 | - idx (int): Index of the source.
159 |
160 | Returns:
161 | - torch.Tensor: Loaded and stacked audio sources tensor.
162 | """
163 | audios = []
164 | for source in self.sources:
165 | segment_info = (
166 | self.df[
167 | self.df["segment_id"].eq(idx) & self.df["source_type"].eq(source)
168 | ]
169 | .iloc[0]
170 | .to_dict()
171 | )
172 | audio = self.load_audio(segment_info)
173 | audios.append(audio)
174 | return torch.stack(audios)
175 |
176 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
177 | """
178 | Retrieves an item from the dataset based on the index.
179 |
180 | Args:
181 | - idx (int): Index of the item.
182 |
183 | Returns:
184 | - Dict[str, torch.Tensor]: Dictionary containing mixture and sources.
185 | """
186 | segment_id = self.segment_ids[idx]
187 |
188 | mixture = self.load_mixture(segment_id)
189 | sources = self.load_sources(segment_id)
190 |
191 | return {
192 | "mixture": mixture,
193 | "sources": sources,
194 | }
195 |
196 | def __len__(self) -> int:
197 | """
198 | Returns the length of the dataset.
199 |
200 | Returns:
201 | - int: Length of the dataset.
202 | """
203 | return len(self.segment_ids)
204 |
205 | def get_train_val_split(
206 | self, lengths: List[float], seed: Optional[int] = None
207 | ) -> Tuple[Subset, Subset]:
208 | """
209 | Splits the dataset into training and validation subsets.
210 |
211 | Args:
212 | - lengths (List[float]): List containing the lengths of the training and validation subsets.
213 | - seed (Optional[int]): Random seed for reproducibility. Defaults to None.
214 |
215 | Returns:
216 | - Tuple[Subset, Subset]: Tuple containing the training and validation subsets.
217 | """
218 | assert (
219 | self.subset == "train"
220 | ), "Only train subset of the dataset can be split into train and val."
221 | assert len(lengths) == 2, "Dataset can only be split into two subsets."
222 | generator = torch.Generator().manual_seed(seed) if seed is not None else None
223 |
224 | train_track_ids, val_track_ids = random_split(
225 | self.track_ids, lengths=lengths, generator=generator
226 | )
227 | train_segment_ids = self.df[
228 | self.df.track_id.isin(train_track_ids.indices)
229 | ].segment_id.to_list()
230 | val_segment_ids = self.df[
231 | self.df.track_id.isin(val_track_ids.indices)
232 | ].segment_id.to_list()
233 |
234 | train_subset = Subset(self, indices=train_segment_ids)
235 | val_subset = Subset(self, indices=val_segment_ids)
236 |
237 | return train_subset, val_subset
238 |
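# Example usage (illustrative only; the paths are placeholders for a local MUSDB18HQ copy):
#
#   dataset = SourceSeparationDataset(
#       dataset_dirs="/path/to/musdb18hq",
#       subset="train",
#       window_size=11,
#       step_size=1,
#       sample_rate=44100,
#   )
#   item = dataset[0]
#   item["mixture"].shape   # torch.Size([2, 485100])   -> (channels, 11 s * 44100 Hz)
#   item["sources"].shape   # torch.Size([4, 2, 485100]) -> (sources, channels, time)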
--------------------------------------------------------------------------------
/src/data/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Optional
4 |
5 | import pandas as pd
6 | import torchaudio
7 |
8 | PARQUET_EXTENSIONS: List[str] = [".pqt", ".parquet"]
9 |
10 |
11 | def construct_dataset(
12 | dataset_dirs: List[str], extension: str = "wav", save_path: Optional[str] = None
13 | ) -> pd.DataFrame:
14 | """
15 | Constructs a dataset DataFrame from audio files in the specified directories.
16 |
17 | Args:
18 | - dataset_dirs (List[str]): List of directories containing audio files.
19 | - extension (str): Extension of the audio files to consider. Defaults to "wav".
20 | - save_path (Optional[str]): Optional path to save the constructed DataFrame as a parquet file.
21 |
22 | Returns:
23 | - pd.DataFrame: DataFrame containing information about the audio files.
24 | """
25 | if save_path is not None and Path(save_path).suffix not in PARQUET_EXTENSIONS:
26 | raise ValueError("'save_path' should be in .parquet/.pqt format.")
27 |
28 | if save_path is not None and Path(save_path).is_file():
29 | dataset_df = pd.read_parquet(save_path)
30 | return dataset_df
31 |
32 | if not isinstance(dataset_dirs, list):
33 | raise TypeError(
34 | f"'dataset_dirs' should be a list of strings, but got {type(dataset_dirs)}"
35 | )
36 |
37 | dataset = []
38 | for dataset_dir in dataset_dirs:
39 | for path in Path(dataset_dir).glob(f"**/*.{extension}"):
40 | abs_path = str(path.resolve())
41 | source_type = path.stem
42 | subset, track_name = path.relative_to(dataset_dir).parts[:2]
43 | audio_info = torchaudio.info(path)
44 | dataset.append(
45 | (
46 | abs_path,
47 | track_name,
48 | source_type,
49 | subset,
50 | audio_info.num_frames,
51 | audio_info.num_frames / audio_info.sample_rate,
52 | audio_info.sample_rate,
53 | audio_info.num_channels,
54 | )
55 | )
56 | columns = [
57 | "path",
58 | "track_name",
59 | "source_type",
60 | "subset",
61 | "total_frames",
62 | "total_seconds",
63 | "sample_rate",
64 | "num_channels",
65 | ]
66 | dataset_df = pd.DataFrame(dataset, columns=columns)
67 | if save_path is not None:
68 | dataset_df.to_parquet(save_path, index=False)
69 | return dataset_df
70 |
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import logging
4 | from typing import Dict, Iterable, List, Tuple
5 |
6 | import fast_bss_eval
7 | import pandas as pd
8 | import torchaudio
9 | import torch.nn as nn
10 | from tqdm import tqdm
11 |
12 | from model.separator import Separator
13 |
14 | SOURCES: List[str] = ["drums", "bass", "other", "vocals"]
15 |
16 | logging.basicConfig(
17 | level=logging.INFO,
18 | format="%(asctime)s - %(levelname)s - %(message)s",
19 | datefmt="%Y-%m-%d %H:%M:%S",
20 | )
21 |
22 | log = logging.getLogger(__name__)
23 |
24 |
25 | def parse_arguments():
26 | """
27 | Parse command line arguments.
28 |
29 | Returns:
30 | argparse.Namespace: Parsed arguments.
31 | """
32 | parser = argparse.ArgumentParser(description="Argument Parser for Separator")
33 | parser.add_argument(
34 | "-c",
35 | "--ckpt-path",
36 | type=str,
37 | required=True,
38 | help="Path to the model checkpoint",
39 | )
40 | parser.add_argument(
41 | "-d",
42 | "--device",
43 | type=str,
44 | default="cuda",
45 | help="Device to run the model on (default: cuda)",
46 | )
47 | parser.add_argument(
48 | "-b",
49 | "--batch-size",
50 | type=int,
51 | default=4,
52 | help="Batch size for processing (default: 4)",
53 | )
54 | parser.add_argument(
55 | "-w", "--window-size", type=int, default=11, help="Window size (default: 11)"
56 | )
57 | parser.add_argument(
58 | "-s", "--step-size", type=float, default=5.5, help="Step size (default: 5.5)"
59 | )
60 | return parser.parse_args()
61 |
62 |
63 | def load_data(dataset_path: str) -> Iterable[Tuple[str, str, Dict[str, str]]]:
64 | """
65 | Load data from the dataset.
66 |
67 | Args:
68 | dataset_path (str): Path to the dataset.
69 |
70 | Yields:
71 | Tuple[str, str, Dict[str, str]]: Tuple containing track name, mixture path, and source paths.
72 | """
73 | df = pd.read_parquet(dataset_path)
74 | df = df[df["subset"].eq("test")]
75 | track_names = df["track_name"].unique()
76 | for track_name in tqdm(track_names):
77 | rows = df[df["track_name"].eq(track_name)]
78 | mixture_path = rows[rows["source_type"].eq("mixture")]["path"].values[0]
79 | source_paths = (
80 | rows[~rows["source_type"].eq("mixture")]
81 | .set_index("source_type")["path"]
82 | .to_dict()
83 | )
84 | yield track_name, mixture_path, source_paths
85 |
86 |
87 | def compute_sdrs(separator: nn.Module, dataset_path: str, device: str) -> str:
88 | """
89 | Compute evaluation SDRs.
90 |
91 | Args:
92 | separator (nn.Module): Separator model.
93 | dataset_path (str): Path to the dataset.pqt.
94 | device (str): Device to send tensors on.
95 |
96 | Returns:
97 | str: Evaluation SDRs table.
98 | """
99 | sdrs = []
100 | for track_name, mixture_path, source_paths in load_data(dataset_path):
101 | y, sr = torchaudio.load(mixture_path)
102 | y_separated = separator.separate(y.to(device)).cpu()
103 | for y_source_est, source_type in zip(y_separated, SOURCES):
104 | y_source_ref, _ = torchaudio.load(source_paths[source_type])
105 | sdr, *_ = fast_bss_eval.bss_eval_sources(
106 | y_source_ref, y_source_est, compute_permutation=False, load_diag=1e-7
107 | )
108 | sdrs.append((track_name, source_type, sdr.mean().item()))
109 | sdrs_df = pd.DataFrame(sdrs, columns=["track_name", "source_type", "sdr"])
110 |
111 | return sdrs_df.groupby("source_type")["sdr"].mean().reset_index(name="sdr").to_string()
112 |
113 |
114 | def main():
115 | args = parse_arguments()
116 |
117 | dataset_path = os.getenv("DATASET_PATH")
118 | if dataset_path is None:
119 | raise ValueError("DATASET_PATH environment variable is not set.")
120 |
121 | log.info(f"Initializing Separator with following checkpoint {args.ckpt_path}...")
122 | separator = Separator.load_from_checkpoint(
123 | path=args.ckpt_path,
124 | batch_size=args.batch_size,
125 | window_size=args.window_size,
126 | step_size=args.step_size,
127 | ).to(args.device)
128 |
129 | log.info("Starting evaluation...")
130 | metrics = compute_sdrs(separator, dataset_path, args.device)
131 | log.info(f"Evaluation completed with following metrics:\n{metrics}")
132 |
133 |
134 | if __name__ == "__main__":
135 | main()
136 |
--------------------------------------------------------------------------------
/src/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from pathlib import Path
4 | from typing import Iterable, List, Tuple
5 |
6 | import soundfile as sf
7 | import torch.nn as nn
8 | import torchaudio
9 |
10 | from model.separator import Separator
11 |
12 | SOURCES: List[str] = ["drums", "bass", "other", "vocals"]
13 | SAVE_SAMPLE_RATE: int = 44100
14 |
15 | logging.basicConfig(
16 | level=logging.INFO,
17 | format="%(asctime)s - %(levelname)s - %(message)s",
18 | datefmt="%Y-%m-%d %H:%M:%S",
19 | )
20 |
21 | log = logging.getLogger(__name__)
22 |
23 |
24 | def parse_arguments():
25 | """
26 | Parse command line arguments.
27 |
28 | Returns:
29 | argparse.Namespace: Parsed arguments.
30 | """
31 | parser = argparse.ArgumentParser(description="Argument Parser for Separator")
32 | parser.add_argument(
33 | "-i",
34 | "--input-path",
35 | type=str,
36 | required=True,
37 | help="Input path to .wav audio file/directory containing audio files",
38 | )
39 | parser.add_argument(
40 | "-o",
41 | "--output-path",
42 | type=str,
43 | required=True,
44 | help="Output directory to save separated audio files in .wav format",
45 | )
46 | parser.add_argument(
47 | "-c",
48 | "--ckpt-path",
49 | type=str,
50 | required=True,
51 | help="Path to the model checkpoint",
52 | )
53 | parser.add_argument(
54 | "-d",
55 | "--device",
56 | type=str,
57 | default="cuda",
58 | help="Device to run the model on (default: cuda)",
59 | )
60 | parser.add_argument(
61 | "-b",
62 | "--batch-size",
63 | type=int,
64 | default=4,
65 | help="Batch size for processing (default: 4)",
66 | )
67 | parser.add_argument(
68 | "-w", "--window-size", type=int, default=11, help="Window size (default: 11)"
69 | )
70 | parser.add_argument(
71 | "-s", "--step-size", type=float, default=5.5, help="Step size (default: 5.5)"
72 | )
73 | parser.add_argument(
74 | "-p",
75 | "--use-progress-bar",
76 | action="store_true",
77 | help="Use progress bar (default: True)",
78 | )
79 | return parser.parse_args()
80 |
81 |
82 | def load_paths(input_path: str, output_path: str) -> Iterable[Tuple[Path, Path]]:
83 | """
84 | Load input and output paths.
85 |
86 | Args:
87 | input_path (str): Input path to audio files.
88 | output_path (str): Output directory to save separated audio files.
89 |
90 | Yields:
91 | Tuple[Path, Path]: Tuple of input and output file paths.
92 | """
93 | input_path = Path(input_path)
94 | output_path = Path(output_path)
95 |
96 | if not input_path.exists():
97 | raise FileNotFoundError(f"Input path '{input_path}' does not exist.")
98 |
99 | if input_path.is_file():
100 | if input_path.suffix not in (".wav", ".mp3"):
101 | raise ValueError("Input audio file should be in .wav or .mp3 format.")
102 | fp_out = output_path / input_path.stem
103 | fp_out.mkdir(exist_ok=True, parents=True)
104 | yield input_path, fp_out
105 | elif input_path.is_dir():
106 | for fp_in in input_path.glob("*"):
107 | if fp_in.suffix in (".wav", ".mp3"):
108 | fp_out = output_path / fp_in.stem
109 | fp_out.mkdir(exist_ok=True, parents=True)
110 | yield fp_in, fp_out
111 | else:
112 | raise ValueError(
113 | f"Input path '{input_path}' is neither a file nor a directory."
114 | )
115 |
116 |
117 | def process_files(
118 | separator: nn.Module, device: str, input_path: str, output_path: str
119 | ) -> None:
120 | for fp_in, fp_out in load_paths(input_path, output_path):
121 | y, sr = torchaudio.load(fp_in)
122 | y_separated = separator.separate(y.to(device)).cpu()
123 | for y_source, source in zip(y_separated, SOURCES):
124 | sf.write(f"{fp_out}/{source}.wav", y_source.T, SAVE_SAMPLE_RATE)
125 |
126 |
127 | def main():
128 | args = parse_arguments()
129 |
130 | log.info(f"Initializing Separator with following checkpoint {args.ckpt_path}...")
131 | separator = Separator.load_from_checkpoint(
132 | path=args.ckpt_path,
133 | batch_size=args.batch_size,
134 | window_size=args.window_size,
135 | step_size=args.step_size,
136 | use_progress_bar=args.use_progress_bar,
137 | ).to(args.device)
138 |
139 | log.info("Processing audio files...")
140 | process_files(separator, args.device, args.input_path, args.output_path)
141 | log.info(f"Audio files processing completed.")
142 |
143 |
144 | if __name__ == "__main__":
145 | main()
146 |
--------------------------------------------------------------------------------
/src/model/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .dualpath_rnn import DualPathRNN
2 | from .sd_encoder import SDBlock
3 | from .su_decoder import SUBlock
4 |
--------------------------------------------------------------------------------
/src/model/modules/dualpath_rnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class RNNModule(nn.Module):
6 | """
7 | RNNModule class implements a recurrent neural network module with LSTM cells.
8 |
9 | Args:
10 | - input_dim (int): Dimensionality of the input features.
11 | - hidden_dim (int): Dimensionality of the hidden state of the LSTM.
12 | - bidirectional (bool, optional): If True, uses bidirectional LSTM. Defaults to True.
13 |
14 | Shapes:
15 | - Input: (B, T, D) where
16 | B is batch size,
17 | T is sequence length,
18 | D is input dimensionality.
19 | - Output: (B, T, D) where
20 | B is batch size,
21 | T is sequence length,
22 | D is input dimensionality.
23 | """
24 |
25 | def __init__(self, input_dim: int, hidden_dim: int, bidirectional: bool = True):
26 | """
27 | Initializes RNNModule with input dimension, hidden dimension, and bidirectional flag.
28 | """
29 | super().__init__()
30 | self.groupnorm = nn.GroupNorm(num_groups=1, num_channels=input_dim)
31 | self.rnn = nn.LSTM(
32 | input_dim, hidden_dim, batch_first=True, bidirectional=bidirectional
33 | )
34 | self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, input_dim)
35 |
36 | def forward(self, x: torch.Tensor) -> torch.Tensor:
37 | """
38 | Performs forward pass through the RNNModule.
39 |
40 | Args:
41 | - x (torch.Tensor): Input tensor of shape (B, T, D).
42 |
43 | Returns:
44 | - torch.Tensor: Output tensor of shape (B, T, D).
45 | """
46 | x = x.transpose(1, 2)
47 | x = self.groupnorm(x)
48 | x = x.transpose(1, 2)
49 |
50 | x, (hidden, _) = self.rnn(x)
51 | x = self.fc(x)
52 | return x
53 |
54 |
55 | class RFFTModule(nn.Module):
56 | """
57 | RFFTModule class implements a module for performing real-valued Fast Fourier Transform (FFT)
58 | or its inverse on input tensors.
59 |
60 | Args:
61 | - inverse (bool, optional): If False, performs forward FFT. If True, performs inverse FFT. Defaults to False.
62 |
63 | Shapes:
64 | - Input: (B, F, T, D) where
65 | B is batch size,
66 | F is the number of features,
67 | T is sequence length,
68 | D is input dimensionality.
69 | - Output: (B, F, T // 2 + 1, D * 2) if performing forward FFT.
70 | (B, F, T, D // 2, 2) if performing inverse FFT.
71 | """
72 |
73 | def __init__(self, inverse: bool = False):
74 | """
75 | Initializes RFFTModule with inverse flag.
76 | """
77 | super().__init__()
78 | self.inverse = inverse
79 |
80 | def forward(self, x: torch.Tensor, time_dim: int) -> torch.Tensor:
81 | """
82 | Performs forward or inverse FFT on the input tensor x.
83 |
84 | Args:
85 | - x (torch.Tensor): Input tensor of shape (B, F, T, D).
86 | - time_dim (int): Input size of time dimension.
87 |
88 | Returns:
89 | - torch.Tensor: Output tensor after FFT or its inverse operation.
90 | """
91 | B, F, T, D = x.shape
92 | dtype = x.dtype
93 | # when training in fp16/bf16, an FFT over a non-power-of-two length is not supported, so cast to float32
94 | if dtype != torch.float and (T & (T - 1)):
95 | x = x.float()
96 | if not self.inverse:
97 | x = torch.fft.rfft(x, dim=2)
98 | x = torch.view_as_real(x)
99 | x = x.reshape(B, F, T // 2 + 1, D * 2)
100 | else:
101 | x = x.reshape(B, F, T, D // 2, 2)
102 | x = torch.view_as_complex(x)
103 | x = torch.fft.irfft(x, n=time_dim, dim=2)
104 | x = x.to(dtype)
105 | return x
106 |
107 | def extra_repr(self) -> str:
108 | """
109 | Returns extra representation string with module's configuration.
110 | """
111 | return f"inverse={self.inverse}"
112 |
113 |
114 | class DualPathRNN(nn.Module):
115 | """
116 | DualPathRNN class implements a neural network with alternating layers of RNNModule and RFFTModule.
117 |
118 | Args:
119 | - n_layers (int): Number of layers in the network.
120 | - input_dim (int): Dimensionality of the input features.
121 | - hidden_dim (int): Dimensionality of the hidden state of the RNNModule.
122 |
123 | Shapes:
124 | - Input: (B, F, T, D) where
125 | B is batch size,
126 | F is the number of features (frequency dimension),
127 | T is sequence length (time dimension),
128 | D is input dimensionality (channel dimension).
129 | - Output: (B, F, T, D) where
130 | B is batch size,
131 | F is the number of features (frequency dimension),
132 | T is sequence length (time dimension),
133 | D is input dimensionality (channel dimension).
134 | """
135 |
136 | def __init__(
137 | self,
138 | n_layers: int,
139 | input_dim: int,
140 | hidden_dim: int,
141 | ):
142 | """
143 | Initializes DualPathRNN with the specified number of layers, input dimension, and hidden dimension.
144 | """
145 | super().__init__()
146 |
147 | self.layers = nn.ModuleList()
148 | for i in range(1, n_layers + 1):
149 | if i % 2 == 1:
150 | layer = nn.ModuleList(
151 | [
152 | RNNModule(input_dim=input_dim, hidden_dim=hidden_dim),
153 | RNNModule(input_dim=input_dim, hidden_dim=hidden_dim),
154 | RFFTModule(inverse=False),
155 | ]
156 | )
157 | else:
158 | layer = nn.ModuleList(
159 | [
160 | RNNModule(input_dim=input_dim * 2, hidden_dim=hidden_dim * 2),
161 | RNNModule(input_dim=input_dim * 2, hidden_dim=hidden_dim * 2),
162 | RFFTModule(inverse=True),
163 | ]
164 | )
165 | self.layers.append(layer)
166 |
167 | def forward(self, x: torch.Tensor) -> torch.Tensor:
168 | """
169 | Performs forward pass through the DualPathRNN.
170 |
171 | Args:
172 | - x (torch.Tensor): Input tensor of shape (B, F, T, D).
173 |
174 | Returns:
175 | - torch.Tensor: Output tensor of shape (B, F, T, D).
176 | """
177 | time_dim = x.shape[2]
178 |
179 | for time_layer, freq_layer, rfft_layer in self.layers:
180 | B, F, T, D = x.shape
181 |
182 | x = x.reshape((B * F), T, D)
183 | x = time_layer(x)
184 | x = x.reshape(B, F, T, D)
185 | x = x.permute(0, 2, 1, 3)
186 |
187 | x = x.reshape((B * T), F, D)
188 | x = freq_layer(x)
189 | x = x.reshape(B, T, F, D)
190 | x = x.permute(0, 2, 1, 3)
191 |
192 | x = rfft_layer(x, time_dim)
193 | return x
194 |
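# Shape sanity check (illustrative only, using the default config values):
#
#   net = DualPathRNN(n_layers=6, input_dim=128, hidden_dim=128)
#   x = torch.randn(2, 64, 100, 128)   # (B, F, T, D)
#   net(x).shape                       # torch.Size([2, 64, 100, 128]) -- each forward RFFT
#                                      # (odd layer) is undone by the following inverse RFFT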
--------------------------------------------------------------------------------
/src/model/modules/sd_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from model.utils import create_intervals
7 |
8 |
9 | class Downsample(nn.Module):
10 | """
11 | Downsample class implements a module for downsampling input tensors using 2D convolution.
12 |
13 | Args:
14 | - input_dim (int): Dimensionality of the input channels.
15 | - output_dim (int): Dimensionality of the output channels.
16 | - stride (int): Stride value for the convolution operation.
17 |
18 | Shapes:
19 | - Input: (B, C_in, F, T) where
20 | B is batch size,
21 | C_in is the number of input channels,
22 | F is the frequency dimension,
23 | T is the time dimension.
24 | - Output: (B, C_out, F // stride, T) where
25 | B is batch size,
26 | C_out is the number of output channels,
27 | F // stride is the downsampled frequency dimension.
28 |
29 | """
30 |
31 | def __init__(
32 | self,
33 | input_dim: int,
34 | output_dim: int,
35 | stride: int,
36 | ):
37 | """
38 | Initializes Downsample with input dimension, output dimension, and stride.
39 | """
40 | super().__init__()
41 | self.conv = nn.Conv2d(input_dim, output_dim, 1, (stride, 1))
42 |
43 | def forward(self, x: torch.Tensor) -> torch.Tensor:
44 | """
45 | Performs forward pass through the Downsample module.
46 |
47 | Args:
48 | - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).
49 |
50 | Returns:
51 | - torch.Tensor: Downsampled tensor of shape (B, C_out, F // stride, T).
52 | """
53 | return self.conv(x)
54 |
55 |
56 | class ConvolutionModule(nn.Module):
57 | """
58 | ConvolutionModule class implements a module with a sequence of convolutional layers similar to Conformer.
59 |
60 | Args:
61 | - input_dim (int): Dimensionality of the input features.
62 | - hidden_dim (int): Dimensionality of the hidden features.
63 | - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
64 | - bias (bool, optional): If True, adds a learnable bias to the output. Default is False.
65 |
66 | Shapes:
67 | - Input: (B, T, D) where
68 | B is batch size,
69 | T is sequence length,
70 | D is input dimensionality.
71 | - Output: (B, T, D) where
72 | B is batch size,
73 | T is sequence length,
74 | D is input dimensionality.
75 | """
76 |
77 | def __init__(
78 | self,
79 | input_dim: int,
80 | hidden_dim: int,
81 | kernel_sizes: List[int],
82 | bias: bool = False,
83 | ) -> None:
84 | """
85 | Initializes ConvolutionModule with input dimension, hidden dimension, kernel sizes, and bias.
86 | """
87 | super().__init__()
88 | self.sequential = nn.Sequential(
89 | nn.GroupNorm(num_groups=1, num_channels=input_dim),
90 | nn.Conv1d(
91 | input_dim,
92 | 2 * hidden_dim,
93 | kernel_sizes[0],
94 | stride=1,
95 | padding=(kernel_sizes[0] - 1) // 2,
96 | bias=bias,
97 | ),
98 | nn.GLU(dim=1),
99 | nn.Conv1d(
100 | hidden_dim,
101 | hidden_dim,
102 | kernel_sizes[1],
103 | stride=1,
104 | padding=(kernel_sizes[1] - 1) // 2,
105 | groups=hidden_dim,
106 | bias=bias,
107 | ),
108 | nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
109 | nn.SiLU(),
110 | nn.Conv1d(
111 | hidden_dim,
112 | input_dim,
113 | kernel_sizes[2],
114 | stride=1,
115 | padding=(kernel_sizes[2] - 1) // 2,
116 | bias=bias,
117 | ),
118 | )
119 |
120 | def forward(self, x: torch.Tensor) -> torch.Tensor:
121 | """
122 | Performs forward pass through the ConvolutionModule.
123 |
124 | Args:
125 | - x (torch.Tensor): Input tensor of shape (B, T, D).
126 |
127 | Returns:
128 | - torch.Tensor: Output tensor of shape (B, T, D).
129 | """
130 | x = x.transpose(1, 2)
131 | x = x + self.sequential(x)
132 | x = x.transpose(1, 2)
133 | return x
134 |
135 |
136 | class SDLayer(nn.Module):
137 | """
138 | SDLayer class implements a subband decomposition layer with downsampling and convolutional modules.
139 |
140 | Args:
141 | - subband_interval (Tuple[float, float]): Tuple representing the frequency interval for subband decomposition.
142 | - input_dim (int): Dimensionality of the input channels.
143 | - output_dim (int): Dimensionality of the output channels after downsampling.
144 | - downsample_stride (int): Stride value for the downsampling operation.
145 | - n_conv_modules (int): Number of convolutional modules.
146 | - kernel_sizes (List[int]): List of kernel sizes for the convolutional layers.
147 | - bias (bool, optional): If True, adds a learnable bias to the convolutional layers. Default is True.
148 |
149 | Shapes:
150 | - Input: (B, Fi, T, Ci) where
151 | B is batch size,
152 | Fi is the number of input subbands,
153 | T is sequence length, and
154 | Ci is the number of input channels.
155 | - Output: (B, Fi+1, T, Ci+1) where
156 | B is batch size,
157 | Fi+1 is the number of output subbands,
158 | T is sequence length,
159 | Ci+1 is the number of output channels.
160 | """
161 |
162 | def __init__(
163 | self,
164 | subband_interval: Tuple[float, float],
165 | input_dim: int,
166 | output_dim: int,
167 | downsample_stride: int,
168 | n_conv_modules: int,
169 | kernel_sizes: List[int],
170 | bias: bool = True,
171 | ):
172 | """
173 | Initializes SDLayer with subband interval, input dimension,
174 | output dimension, downsample stride, number of convolutional modules, kernel sizes, and bias.
175 | """
176 | super().__init__()
177 | self.subband_interval = subband_interval
178 | self.downsample = Downsample(input_dim, output_dim, downsample_stride)
179 | self.activation = nn.GELU()
180 | conv_modules = [
181 | ConvolutionModule(
182 | input_dim=output_dim,
183 | hidden_dim=output_dim // 4,
184 | kernel_sizes=kernel_sizes,
185 | bias=bias,
186 | )
187 | for _ in range(n_conv_modules)
188 | ]
189 | self.conv_modules = nn.Sequential(*conv_modules)
190 |
191 | def forward(self, x: torch.Tensor) -> torch.Tensor:
192 | """
193 | Performs forward pass through the SDLayer.
194 |
195 | Args:
196 | - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).
197 |
198 | Returns:
199 | - torch.Tensor: Output tensor of shape (B, Fi+1, T, Ci+1).
200 | """
201 | B, F, T, C = x.shape
202 | x = x[:, int(self.subband_interval[0] * F) : int(self.subband_interval[1] * F)]
203 | x = x.permute(0, 3, 1, 2)
204 | x = self.downsample(x)
205 | x = self.activation(x)
206 | x = x.permute(0, 2, 3, 1)
207 |
208 | B, F, T, C = x.shape
209 | x = x.reshape((B * F), T, C)
210 | x = self.conv_modules(x)
211 | x = x.reshape(B, F, T, C)
212 |
213 | return x
214 |
215 |
216 | class SDBlock(nn.Module):
217 | """
218 | SDBlock class implements a block with subband decomposition layers and global convolution.
219 |
220 | Args:
221 | - input_dim (int): Dimensionality of the input channels.
222 | - output_dim (int): Dimensionality of the output channels.
223 | - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
224 | - downsample_strides (List[int]): List of stride values for downsampling in each subband layer.
225 | - n_conv_modules (List[int]): List specifying the number of convolutional modules in each subband layer.
226 | - kernel_sizes (List[int], optional): List of kernel sizes for the convolutional layers. Default is None.
227 |
228 | Shapes:
229 | - Input: (B, Fi, T, Ci) where
230 | B is batch size,
231 | Fi is the number of input subbands,
232 | T is sequence length,
233 | Ci is the number of input channels.
234 | - Output: (B, Fi+1, T, Ci+1) where
235 | B is batch size,
236 | Fi+1 is the number of output subbands,
237 | T is sequence length,
238 | Ci+1 is the number of output channels.
239 | """
240 |
241 | def __init__(
242 | self,
243 | input_dim: int,
244 | output_dim: int,
245 | bandsplit_ratios: List[float],
246 | downsample_strides: List[int],
247 | n_conv_modules: List[int],
248 | kernel_sizes: Optional[List[int]] = None,
249 | ):
250 | """
251 | Initializes SDBlock with input dimension, output dimension, band split ratios, downsample strides, number of convolutional modules, and kernel sizes.
252 | """
253 | super().__init__()
254 | if kernel_sizes is None:
255 | kernel_sizes = [3, 3, 1]
256 | assert sum(bandsplit_ratios) == 1, "The split ratios must sum up to 1."
257 | subband_intervals = create_intervals(bandsplit_ratios)
258 | self.sd_layers = nn.ModuleList(
259 | SDLayer(
260 | input_dim=input_dim,
261 | output_dim=output_dim,
262 | subband_interval=sbi,
263 | downsample_stride=dss,
264 | n_conv_modules=ncm,
265 | kernel_sizes=kernel_sizes,
266 | )
267 | for sbi, dss, ncm in zip(
268 | subband_intervals, downsample_strides, n_conv_modules
269 | )
270 | )
271 | self.global_conv2d = nn.Conv2d(output_dim, output_dim, 1, 1)
272 |
273 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
274 | """
275 | Performs forward pass through the SDBlock.
276 |
277 | Args:
278 | - x (torch.Tensor): Input tensor of shape (B, Fi, T, Ci).
279 |
280 | Returns:
281 | - Tuple[torch.Tensor, torch.Tensor]: Output tensor and skip connection tensor.
282 | """
283 | x_skip = torch.concat([layer(x) for layer in self.sd_layers], dim=1)
284 | x = self.global_conv2d(x_skip.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
285 | return x, x_skip
286 |
--------------------------------------------------------------------------------
/src/model/modules/su_decoder.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from model.utils import get_convtranspose_output_padding
7 |
8 |
9 | class FusionLayer(nn.Module):
10 | """
11 | FusionLayer class implements a module for fusing two input tensors using convolutional operations.
12 |
13 | Args:
14 | - input_dim (int): Dimensionality of the input channels.
15 | - kernel_size (int, optional): Kernel size for the convolutional layer. Default is 3.
16 | - stride (int, optional): Stride value for the convolutional layer. Default is 1.
17 | - padding (int, optional): Padding value for the convolutional layer. Default is 1.
18 |
19 | Shapes:
20 | - Input: (B, F, T, C) and (B, F, T, C) where
21 | B is batch size,
22 | F is the number of features,
23 | T is sequence length,
24 | C is input dimensionality.
25 | - Output: (B, F, T, C) where
26 | B is batch size,
27 | F is the number of features,
28 | T is sequence length,
29 | C is input dimensionality.
30 | """
31 |
32 | def __init__(
33 | self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1
34 | ):
35 | """
36 | Initializes FusionLayer with input dimension, kernel size, stride, and padding.
37 | """
38 | super().__init__()
39 | self.conv = nn.Conv2d(
40 | input_dim * 2,
41 | input_dim * 2,
42 | kernel_size=(kernel_size, 1),
43 | stride=(stride, 1),
44 | padding=(padding, 0),
45 | )
46 | self.activation = nn.GLU()
47 |
48 | def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
49 | """
50 | Performs forward pass through the FusionLayer.
51 |
52 | Args:
53 | - x1 (torch.Tensor): First input tensor of shape (B, F, T, C).
54 | - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C).
55 |
56 | Returns:
57 | - torch.Tensor: Output tensor of shape (B, F, T, C).
58 | """
59 | x = x1 + x2
60 | x = x.repeat(1, 1, 1, 2)
61 | x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
62 | x = self.activation(x)
63 | return x
64 |
65 |
66 | class Upsample(nn.Module):
67 | """
68 | Upsample class implements a module for upsampling input tensors using transposed 2D convolution.
69 |
70 | Args:
71 | - input_dim (int): Dimensionality of the input channels.
72 | - output_dim (int): Dimensionality of the output channels.
73 | - stride (int): Stride value for the transposed convolution operation.
74 | - output_padding (int): Output padding value for the transposed convolution operation.
75 |
76 | Shapes:
77 | - Input: (B, C_in, F, T) where
78 | B is batch size,
79 | C_in is the number of input channels,
80 | F is the frequency dimension,
81 | T is the time dimension.
82 | - Output: (B, C_out, F * stride + output_padding, T) where
83 | B is batch size,
84 | C_out is the number of output channels,
85 | F * stride + output_padding is the upsampled frequency dimension.
86 | """
87 |
88 | def __init__(
89 | self, input_dim: int, output_dim: int, stride: int, output_padding: int
90 | ):
91 | """
92 | Initializes Upsample with input dimension, output dimension, stride, and output padding.
93 | """
94 | super().__init__()
95 | self.conv = nn.ConvTranspose2d(
96 | input_dim, output_dim, 1, (stride, 1), output_padding=(output_padding, 0)
97 | )
98 |
99 | def forward(self, x: torch.Tensor) -> torch.Tensor:
100 | """
101 | Performs forward pass through the Upsample module.
102 |
103 | Args:
104 | - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).
105 |
106 | Returns:
107 | - torch.Tensor: Output tensor of shape (B, C_out, F * stride + output_padding, T).
108 | """
109 | return self.conv(x)
110 |
111 |
112 | class SULayer(nn.Module):
113 | """
114 | SULayer class implements a subband upsampling layer using transposed convolution.
115 |
116 | Args:
117 | - input_dim (int): Dimensionality of the input channels.
118 | - output_dim (int): Dimensionality of the output channels.
119 | - upsample_stride (int): Stride value for the upsampling operation.
120 | - subband_shape (int): Shape of the subband.
121 | - sd_interval (Tuple[int, int]): Start and end indices of the subband interval.
122 |
123 | Shapes:
124 | - Input: (B, F, T, C) where
125 | B is batch size,
126 | F is the number of features,
127 | T is sequence length,
128 | C is input dimensionality.
129 | - Output: (B, F, T, C) where
130 | B is batch size,
131 | F is the number of features,
132 | T is sequence length,
133 | C is input dimensionality.
134 | """
135 |
136 | def __init__(
137 | self,
138 | input_dim: int,
139 | output_dim: int,
140 | upsample_stride: int,
141 | subband_shape: int,
142 | sd_interval: Tuple[int, int],
143 | ):
144 | """
145 | Initializes SULayer with input dimension, output dimension, upsample stride, subband shape, and subband interval.
146 | """
147 | super().__init__()
148 | sd_shape = sd_interval[1] - sd_interval[0]
149 | upsample_output_padding = get_convtranspose_output_padding(
150 | input_shape=sd_shape, output_shape=subband_shape, stride=upsample_stride
151 | )
152 | self.upsample = Upsample(
153 | input_dim=input_dim,
154 | output_dim=output_dim,
155 | stride=upsample_stride,
156 | output_padding=upsample_output_padding,
157 | )
158 | self.sd_interval = sd_interval
159 |
160 | def forward(self, x: torch.Tensor) -> torch.Tensor:
161 | """
162 | Performs forward pass through the SULayer.
163 |
164 | Args:
165 | - x (torch.Tensor): Input tensor of shape (B, F, T, C).
166 |
167 | Returns:
168 | - torch.Tensor: Output tensor of shape (B, F, T, C).
169 | """
170 | x = x[:, self.sd_interval[0] : self.sd_interval[1]]
171 | x = x.permute(0, 3, 1, 2)
172 | x = self.upsample(x)
173 | x = x.permute(0, 2, 3, 1)
174 | return x
175 |
176 |
177 | class SUBlock(nn.Module):
178 | """
179 | SUBlock class implements a block with fusion layer and subband upsampling layers.
180 |
181 | Args:
182 | - input_dim (int): Dimensionality of the input channels.
183 | - output_dim (int): Dimensionality of the output channels.
184 | - upsample_strides (List[int]): List of stride values for the upsampling operations.
185 | - subband_shapes (List[int]): List of shapes for the subbands.
186 | - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition.
187 |
188 | Shapes:
189 | - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where
190 | B is batch size,
191 | Fi-1 is the number of input subbands,
192 | T is sequence length,
193 | Ci-1 is the number of input channels.
194 | - Output: (B, Fi, T, Ci) where
195 | B is batch size,
196 | Fi is the number of output subbands,
197 | T is sequence length,
198 | Ci is the number of output channels.
199 | """
200 |
201 | def __init__(
202 | self,
203 | input_dim: int,
204 | output_dim: int,
205 | upsample_strides: List[int],
206 | subband_shapes: List[int],
207 | sd_intervals: List[Tuple[int, int]],
208 | ):
209 | """
210 | Initializes SUBlock with input dimension, output dimension,
211 | upsample strides, subband shapes, and subband intervals.
212 | """
213 | super().__init__()
214 | self.fusion_layer = FusionLayer(input_dim=input_dim)
215 | self.su_layers = nn.ModuleList(
216 | SULayer(
217 | input_dim=input_dim,
218 | output_dim=output_dim,
219 | upsample_stride=uss,
220 | subband_shape=sbs,
221 | sd_interval=sdi,
222 | )
223 |             for uss, sbs, sdi in zip(
224 |                 upsample_strides, subband_shapes, sd_intervals
225 |             )
226 | )
227 |
228 | def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
229 | """
230 | Performs forward pass through the SUBlock.
231 |
232 | Args:
233 | - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1).
234 | - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1).
235 |
236 | Returns:
237 | - torch.Tensor: Output tensor of shape (B, Fi, T, Ci).
238 | """
239 | x = self.fusion_layer(x, x_skip)
240 | x = torch.concat([layer(x) for layer in self.su_layers], dim=1)
241 | return x
242 |
--------------------------------------------------------------------------------
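A quick way to see what `get_convtranspose_output_padding` buys `SULayer`: with a 1×1 kernel, a transposed convolution of stride `s` maps a width-`n` subband to `(n - 1) * s + 1 + output_padding`, so the computed padding is exactly what lets the decoder restore the width the encoder produced. A minimal sketch with illustrative numbers (not taken from any config), assuming it is run from `src/` so that `model.utils` is importable:

```python
import torch
import torch.nn as nn

from model.utils import get_convtranspose_output_padding

# Illustrative subband: 90 bins wide before the encoder, downsampled with
# stride 4 to (90 - 1) // 4 + 1 == 23 bins.
subband_shape, downsampled, stride = 90, 23, 4

output_padding = get_convtranspose_output_padding(
    input_shape=downsampled, output_shape=subband_shape, stride=stride
)
upsample = nn.ConvTranspose2d(64, 32, 1, (stride, 1), output_padding=(output_padding, 0))

x = torch.rand(1, 64, downsampled, 100)  # (B, C_in, F, T), same layout as Upsample
print(upsample(x).shape)                 # torch.Size([1, 32, 90, 100]), width restored
```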
/src/model/scnet.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from model.modules import DualPathRNN, SDBlock, SUBlock
7 | from model.utils import compute_sd_layer_shapes, compute_gcr
8 |
9 |
10 | class SCNet(nn.Module):
11 | """
12 | SCNet class implements a source separation network,
13 |     which explicitly splits the spectrogram of the mixture into several subbands
14 |     and introduces a sparsity-based encoder to model different frequency bands.
15 |
16 | Paper: "SCNET: SPARSE COMPRESSION NETWORK FOR MUSIC SOURCE SEPARATION"
17 | Authors: Weinan Tong, Jiaxu Zhu et al.
18 | Link: https://arxiv.org/abs/2401.13276.pdf
19 |
20 | Args:
21 |     - n_fft (int): FFT size, which determines the frequency dimension of the input.
22 |     - dims (List[int]): List of channel dimensions for each block.
23 | - bandsplit_ratios (List[float]): List of ratios for splitting the frequency bands.
24 | - downsample_strides (List[int]): List of stride values for downsampling in each block.
25 | - n_conv_modules (List[int]): List specifying the number of convolutional modules in each block.
26 | - n_rnn_layers (int): Number of recurrent layers in the dual path RNN.
27 | - rnn_hidden_dim (int): Dimensionality of the hidden state in the dual path RNN.
28 | - n_sources (int, optional): Number of sources to be separated. Default is 4.
29 |
30 | Shapes:
31 | - Input: (B, F, T, C) where
32 | B is batch size,
33 | F is the number of features,
34 | T is sequence length,
35 | C is input dimensionality.
36 | - Output: (B, F, T, C, S) where
37 | B is batch size,
38 | F is the number of features,
39 | T is sequence length,
40 | C is input dimensionality,
41 | S is the number of sources.
42 | """
43 |
44 | def __init__(
45 | self,
46 | n_fft: int,
47 | dims: List[int],
48 | bandsplit_ratios: List[float],
49 | downsample_strides: List[int],
50 | n_conv_modules: List[int],
51 | n_rnn_layers: int,
52 | rnn_hidden_dim: int,
53 | n_sources: int = 4,
54 | ):
55 | """
56 | Initializes SCNet with input parameters.
57 | """
58 | super().__init__()
59 | self.assert_input_data(
60 | bandsplit_ratios,
61 | downsample_strides,
62 | n_conv_modules,
63 | )
64 |
65 | n_blocks = len(dims) - 1
66 | n_freq_bins = n_fft // 2 + 1
67 | subband_shapes, sd_intervals = compute_sd_layer_shapes(
68 | input_shape=n_freq_bins,
69 | bandsplit_ratios=bandsplit_ratios,
70 | downsample_strides=downsample_strides,
71 | n_layers=n_blocks,
72 | )
73 | self.sd_blocks = nn.ModuleList(
74 | SDBlock(
75 | input_dim=dims[i],
76 | output_dim=dims[i + 1],
77 | bandsplit_ratios=bandsplit_ratios,
78 | downsample_strides=downsample_strides,
79 | n_conv_modules=n_conv_modules,
80 | )
81 | for i in range(n_blocks)
82 | )
83 | self.dualpath_blocks = DualPathRNN(
84 | n_layers=n_rnn_layers,
85 | input_dim=dims[-1],
86 | hidden_dim=rnn_hidden_dim,
87 | )
88 | self.su_blocks = nn.ModuleList(
89 | SUBlock(
90 | input_dim=dims[i + 1],
91 | output_dim=dims[i] if i != 0 else dims[i] * n_sources,
92 | subband_shapes=subband_shapes[i],
93 | sd_intervals=sd_intervals[i],
94 | upsample_strides=downsample_strides,
95 | )
96 | for i in reversed(range(n_blocks))
97 | )
98 | self.gcr = compute_gcr(subband_shapes)
99 |
100 | @staticmethod
101 | def assert_input_data(*args):
102 | """
103 |         Asserts that the input configuration lists have equal lengths.
104 | """
105 | for arg1 in args:
106 | for arg2 in args:
107 | if len(arg1) != len(arg2):
108 | raise ValueError(
109 | f"Shapes of input features {arg1} and {arg2} are not equal."
110 | )
111 |
112 | def forward(self, x: torch.Tensor) -> torch.Tensor:
113 | """
114 | Performs forward pass through the SCNet.
115 |
116 | Args:
117 | - x (torch.Tensor): Input tensor of shape (B, F, T, C).
118 |
119 | Returns:
120 | - torch.Tensor: Output tensor of shape (B, F, T, C, S).
121 | """
122 | B, F, T, C = x.shape
123 |
124 | # encoder part
125 | x_skips = []
126 | for sd_block in self.sd_blocks:
127 | x, x_skip = sd_block(x)
128 | x_skips.append(x_skip)
129 |
130 | # separation part
131 | x = self.dualpath_blocks(x)
132 |
133 | # decoder part
134 | for su_block, x_skip in zip(self.su_blocks, reversed(x_skips)):
135 | x = su_block(x, x_skip)
136 |
137 | # split into N sources
138 | x = x.reshape(B, F, T, C, -1)
139 |
140 | return x
141 |
142 | def count_parameters(self):
143 | """
144 | Counts the total number of parameters in the SCNet.
145 | """
146 | return sum(p.numel() for p in self.parameters())
147 |
148 |
149 | if __name__ == "__main__":
150 | net_params = {
151 | "n_fft": 4096,
152 | "dims": [4, 32, 64, 128],
153 | "bandsplit_ratios": [0.175, 0.392, 0.433],
154 | "downsample_strides": [1, 4, 16],
155 | "n_conv_modules": [3, 2, 1],
156 | "n_rnn_layers": 6,
157 | "rnn_hidden_dim": 128,
158 | "n_sources": 4,
159 | }
160 | device = "cpu"
161 | B, F, T, C = 4, 2049, 474, 4
162 |
163 | net = SCNet(**net_params).to(device)
164 | _ = net.eval()
165 |
166 | test_input = torch.rand(B, F, T, C).to(device)
167 | out = net(test_input)
168 |
--------------------------------------------------------------------------------
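As a sanity check on the numbers in the `__main__` block: `n_fft = 4096` gives `4096 // 2 + 1 = 2049` frequency bins, and the four input channels correspond to stereo audio with real and imaginary STFT components stacked (that stacking is done by `Separator.apply_net` in `separator.py` below). An illustrative continuation of the demo above, stating the shapes this configuration should produce:

```python
# Continuing the __main__ demo above (expected values, not asserted by the source).
assert F == net_params["n_fft"] // 2 + 1    # 2049 frequency bins
print(out.shape)                            # (B, F, T, C, S) == (4, 2049, 474, 4, 4)
print(f"{net.count_parameters() / 1e6:.2f}M parameters")
```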
/src/model/separator.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from hydra.utils import instantiate
8 | from omegaconf import DictConfig
9 | from tqdm import tqdm
10 |
11 |
12 | class Separator(nn.Module):
13 | """
14 | Neural Network Separator.
15 |
16 | This class implements a neural network-based separator for audio source separation.
17 |
18 | Args:
19 | - return_spec (bool): Whether to return the spectrogram of separated sources along with audio.
20 | - batch_size (int): Batch size of chunks to process at the same time.
21 | - sample_rate (int): Sample rate of the audio.
22 | - window_size (float): Size of the sliding window in seconds.
23 | - step_size (float): Step size of the sliding window in seconds.
24 | - stft (DictConfig): Configuration for Short-Time Fourier Transform (STFT).
25 | - istft (DictConfig): Configuration for Inverse Short-Time Fourier Transform (ISTFT).
26 | - net (DictConfig): Configuration for the neural network architecture.
27 | """
28 |
29 | def __init__(
30 | self,
31 | return_spec: bool,
32 | batch_size: int,
33 | sample_rate: int,
34 | window_size: float,
35 | step_size: float,
36 | stft: DictConfig,
37 | istft: DictConfig,
38 | net: DictConfig,
39 | use_progress_bar: bool = False,
40 | ) -> None:
41 | """
42 | Initialize Separator.
43 | """
44 | super().__init__()
45 |
46 | assert step_size < window_size, "Step size should be smaller than window size."
47 |
48 | self.return_spec = return_spec
49 |
50 | self.bs = batch_size
51 | self.sr = sample_rate
52 | self.ws = int(window_size * sample_rate)
53 | self.ss = int(step_size * sample_rate)
54 | self.ps = self.ws - self.ss
55 |
56 | self.stft = instantiate(stft)
57 | self.net = instantiate(net)
58 | self.istft = instantiate(istft)
59 |
60 | self.use_progress_bar = use_progress_bar
61 |
62 | @classmethod
63 | def load_from_checkpoint(cls, path: str, **overrides) -> nn.Module:
64 | """
65 | Initializes Separator from Lightning checkpoint.
66 |
67 | Args:
68 | - path (str): Checkpoint path.
69 | """
70 | # load checkpoint
71 |         assert Path(path).suffix == ".ckpt", "Checkpoint is not in preferred format."
72 | ckpt = torch.load(path)
73 | params = ckpt["hyper_parameters"].separator
74 | state_dict = ckpt["state_dict"]
75 |
76 | # initialize separator
77 | params = {**params, **overrides}
78 | model = cls(**params)
79 |
80 | # delete lightning-module related prefix from weight keys
81 | state_dict = {k.replace("sep.", ""): state_dict[k] for k in state_dict}
82 | # load state dict
83 | _ = model.load_state_dict(state_dict)
84 | return model
85 |
86 | def pad(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
87 | """
88 | Pad input tensor to fit STFT requirements.
89 |
90 | Args:
91 | - x (torch.Tensor): Input tensor.
92 |
93 | Returns:
94 | - x (torch.Tensor): Padded input tensor.
95 | - pad_size (int): Size of padding.
96 | """
97 | pad_size = self.stft.hop_length - x.shape[-1] % self.stft.hop_length
98 | x = F.pad(x, (0, pad_size))
99 | return x, pad_size
100 |
101 | def apply_stft(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
102 | """
103 | Apply Short-Time Fourier Transform (STFT) to input tensor.
104 |
105 | Args:
106 | - x (torch.Tensor): Input tensor.
107 |
108 | Returns:
109 | - x (torch.Tensor): Transformed tensor.
110 | - pad_size (int): Size of padding applied.
111 | """
112 | x, pad_size = self.pad(x)
113 | x = self.stft(x)
114 | x = torch.view_as_real(x)
115 | return x, pad_size
116 |
117 | def apply_istft(
118 | self, x: torch.Tensor, pad_size: Optional[int] = None
119 | ) -> torch.Tensor:
120 | """
121 | Apply Inverse Short-Time Fourier Transform (ISTFT) to input tensor.
122 |
123 | Args:
124 | - x (torch.Tensor): Input tensor.
125 | - pad_size (int): Size of padding applied.
126 |
127 | Returns:
128 | - x (torch.Tensor): Inverse transformed tensor.
129 | """
130 | x = torch.view_as_complex(x)
131 | x = self.istft(x)
132 | if pad_size is not None:
133 | x = x[..., :-pad_size]
134 | return x
135 |
136 | def apply_net(self, x: torch.Tensor) -> torch.Tensor:
137 | """
138 | Apply neural network to input tensor.
139 |
140 | Args:
141 | - x (torch.Tensor): Input tensor.
142 |
143 | Returns:
144 | - x (torch.Tensor): Transformed tensor.
145 |
146 | Shapes:
147 | - Input: (B, Ch, Fr, T, Co) where
148 | B is batch size,
149 | Ch is the number of channels,
150 | Fr is the number of frequencies,
151 | T is sequence length,
152 | Co is real/imag part.
153 | - Output: (B, S, Ch, Fr, T, Co) where
154 | B is batch size,
155 | S is number of separated sources,
156 | Ch is the number of channels,
157 | Fr is the number of frequencies,
158 | T is sequence length,
159 | Co is real/imag part.
160 | """
161 | B, Ch, Fr, T, Co = x.shape
162 | x = x.permute(0, 2, 3, 1, 4).reshape(B, Fr, T, Ch * Co)
163 |
164 | x = self.net(x)
165 |
166 | S = x.shape[-1]
167 | x = x.reshape(B, Fr, T, Ch, Co, S).permute(0, 5, 3, 1, 2, 4).contiguous()
168 | return x
169 |
170 | def forward(
171 | self, wav_mixture: torch.Tensor
172 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
173 | """
174 | Forward pass of the separator.
175 |
176 | Args:
177 | - wav_mixture (torch.Tensor): Input mixture tensor.
178 |
179 | Returns:
180 |         - Tuple[torch.Tensor, Optional[torch.Tensor]]: Separated waveforms and, if return_spec is True, their spectrograms (otherwise None).
181 |
182 | Shapes:
183 | input (B, Ch, T)
184 | -> stft (B, Ch, F, T, Co)
185 | -> net (B, S, Ch, F, T, Co)
186 | -> istft (B, S, Ch, T)
187 | """
188 | spec_mixture, pad_size = self.apply_stft(wav_mixture)
189 |
190 | spec_sources = self.apply_net(spec_mixture)
191 |
192 | wav_sources = self.apply_istft(spec_sources, pad_size)
193 |
194 | if self.return_spec:
195 | return wav_sources, spec_sources
196 | return wav_sources, None
197 |
198 | def pad_whole(self, y: torch.Tensor) -> Tuple[torch.Tensor, int]:
199 | """
200 | Pad the input tensor before overlap-add.
201 |
202 | Args:
203 | - y (torch.Tensor): Input tensor.
204 |
205 | Returns:
206 |         - Tuple[torch.Tensor, int]: Padded input tensor and the amount of padding added at the end.
207 | """
208 | padding_add = self.ss - (y.shape[-1] + self.ps * 2 - self.ws) % self.ss
209 | y = F.pad(y, (self.ps, self.ps + padding_add), "constant")
210 | return y, padding_add
211 |
212 | def unpad_whole(self, y: torch.Tensor, padding_add: int) -> torch.Tensor:
213 | """
214 | Unpad the input tensor after overlap-add.
215 |
216 | Args:
217 | - y (torch.Tensor): Input tensor.
218 | - padding_add (int): Size of padding applied.
219 |
220 | Returns:
221 | - y (torch.Tensor): Unpadded input tensor.
222 | """
223 | return y[..., self.ps : -(self.ps + padding_add)]
224 |
225 | def unfold(self, y: torch.Tensor) -> torch.Tensor:
226 | """
227 | Unfold the input tensor before applying the model.
228 |
229 | Args:
230 | - y (torch.Tensor): Input tensor.
231 |
232 | Returns:
233 | - y (torch.Tensor): Unfolded input tensor.
234 | """
235 | y = y.unfold(-1, self.ws, self.ss).permute(1, 0, 2)
236 | return y
237 |
238 | def fold(self, y_chunks: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
239 | """
240 | Perform overlap-add operation.
241 |
242 | Args:
243 | - y_chunks (torch.Tensor): Segmented chunks of input tensor.
244 | - y (torch.Tensor): Original input tensor.
245 |
246 | Returns:
247 | - y_out (torch.Tensor): Overlap-added output tensor.
248 | """
249 | n_chunks, n_sources, *_ = y_chunks.shape
250 | y_out = torch.zeros_like(y).unsqueeze(0).repeat(n_sources, 1, 1)
251 | start = 0
252 | for i in range(n_chunks):
253 | y_out[..., start : start + self.ws] += y_chunks[i]
254 | start += self.ss
255 | return y_out
256 |
257 | def forward_batches(self, y_chunks: torch.Tensor) -> torch.Tensor:
258 | """
259 | Forward pass for batches of input chunks.
260 |
261 | Args:
262 | - y_chunks (torch.Tensor): Input tensor chunks.
263 |
264 | Returns:
265 | - y_chunks (torch.Tensor): Processed output tensor chunks.
266 | """
267 | norm_value = self.ws / self.ss
268 | chunks = list(range(0, y_chunks.shape[0], self.bs))
269 | if self.use_progress_bar:
270 | chunks = tqdm(chunks)
271 | y_chunks = torch.cat(
272 | [
273 | self(y_chunks[start : start + self.bs])[0] / norm_value
274 | for start in chunks
275 | ]
276 | )
277 | return y_chunks
278 |
279 | @torch.no_grad()
280 | def separate(self, y: torch.Tensor) -> torch.Tensor:
281 | """
282 | Perform source separation on the input tensor.
283 |
284 | Args:
285 | - y (torch.Tensor): Input tensor.
286 |
287 | Returns:
288 | - y (torch.Tensor): Separated source tensor.
289 | """
290 | y, padding_add = self.pad_whole(y)
291 | y_chunks = self.unfold(y)
292 |
293 | y_chunks = self.forward_batches(y_chunks)
294 |
295 | y = self.fold(y_chunks, y)
296 | y = self.unpad_whole(y, padding_add)
297 |
298 | return y
299 |
--------------------------------------------------------------------------------
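Outside of `inference.py`, the whole chunked pipeline is driven by `separate`: the waveform is padded, unfolded into `window_size`-second chunks every `step_size` seconds, processed in batches of `batch_size`, and recombined by overlap-add with the `ws / ss` normalization from `forward_batches`. A minimal usage sketch, assuming a hypothetical checkpoint path, `torchaudio` for audio I/O, and `src/` as the working directory:

```python
import torchaudio

from model.separator import Separator

# Hypothetical checkpoint; load_from_checkpoint strips the "sep." prefix that
# LightningWrapper adds to the weight keys and applies keyword overrides.
separator = Separator.load_from_checkpoint("logs/scnet.ckpt", use_progress_bar=True)
separator.eval()

wav, sr = torchaudio.load("mixture.wav")  # (Ch, T); expected stereo at separator.sr
sources = separator.separate(wav)         # (S, Ch, T), one waveform per separated source
print(sources.shape)
```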
/src/model/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union
2 |
3 | import torch
4 |
5 |
6 | def create_intervals(
7 | splits: List[Union[float, int]]
8 | ) -> List[Union[Tuple[float, float], Tuple[int, int]]]:
9 | """
10 | Create intervals based on splits provided.
11 |
12 | Args:
13 | - splits (List[Union[float, int]]): List of floats or integers representing splits.
14 |
15 | Returns:
16 | - List[Union[Tuple[float, float], Tuple[int, int]]]: List of tuples representing intervals.
17 | """
18 | start = 0
19 | return [(start, start := start + split) for split in splits]
20 |
21 |
22 | def get_conv_output_shape(
23 | input_shape: int,
24 | kernel_size: int = 1,
25 | padding: int = 0,
26 | dilation: int = 1,
27 | stride: int = 1,
28 | ) -> int:
29 | """
30 | Compute the output shape of a convolutional layer.
31 |
32 | Args:
33 | - input_shape (int): Input shape.
34 | - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
35 | - padding (int, optional): Padding size. Default is 0.
36 | - dilation (int, optional): Dilation factor. Default is 1.
37 | - stride (int, optional): Stride value. Default is 1.
38 |
39 | Returns:
40 | - int: Output shape.
41 | """
42 | return int(
43 | (input_shape + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
44 | )
45 |
46 |
47 | def get_convtranspose_output_padding(
48 | input_shape: int,
49 | output_shape: int,
50 | kernel_size: int = 1,
51 | padding: int = 0,
52 | dilation: int = 1,
53 | stride: int = 1,
54 | ) -> int:
55 | """
56 | Compute the output padding for a convolution transpose operation.
57 |
58 | Args:
59 | - input_shape (int): Input shape.
60 | - output_shape (int): Desired output shape.
61 | - kernel_size (int, optional): Kernel size of the convolution. Default is 1.
62 | - padding (int, optional): Padding size. Default is 0.
63 | - dilation (int, optional): Dilation factor. Default is 1.
64 | - stride (int, optional): Stride value. Default is 1.
65 |
66 | Returns:
67 | - int: Output padding.
68 | """
69 | return (
70 | output_shape
71 | - (input_shape - 1) * stride
72 | + 2 * padding
73 | - dilation * (kernel_size - 1)
74 | - 1
75 | )
76 |
77 |
78 | def compute_sd_layer_shapes(
79 | input_shape: int,
80 | bandsplit_ratios: List[float],
81 | downsample_strides: List[int],
82 | n_layers: int,
83 | ) -> Tuple[List[List[int]], List[List[Tuple[int, int]]]]:
84 | """
85 | Compute the shapes for the subband layers.
86 |
87 | Args:
88 | - input_shape (int): Input shape.
89 | - bandsplit_ratios (List[float]): Ratios for splitting the frequency bands.
90 | - downsample_strides (List[int]): Strides for downsampling in each layer.
91 | - n_layers (int): Number of layers.
92 |
93 | Returns:
94 |     - Tuple[List[List[int]], List[List[Tuple[int, int]]]]: Subband shapes and the intervals of the downsampled subbands for each layer.
95 | """
96 | bandsplit_shapes_list = []
97 | conv2d_shapes_list = []
98 | for _ in range(n_layers):
99 | bandsplit_intervals = create_intervals(bandsplit_ratios)
100 | bandsplit_shapes = [
101 | int(right * input_shape) - int(left * input_shape)
102 | for left, right in bandsplit_intervals
103 | ]
104 | conv2d_shapes = [
105 | get_conv_output_shape(bs, stride=ds)
106 | for bs, ds in zip(bandsplit_shapes, downsample_strides)
107 | ]
108 | input_shape = sum(conv2d_shapes)
109 | bandsplit_shapes_list.append(bandsplit_shapes)
110 | conv2d_shapes_list.append(create_intervals(conv2d_shapes))
111 |
112 | return bandsplit_shapes_list, conv2d_shapes_list
113 |
114 |
115 | def compute_gcr(subband_shapes: List[List[int]]) -> float:
116 | """
117 | Compute the global compression ratio.
118 |
119 | Args:
120 | - subband_shapes (List[List[int]]): List of subband shapes.
121 |
122 | Returns:
123 | - float: Global compression ratio.
124 | """
125 | t = torch.Tensor(subband_shapes)
126 | gcr = torch.stack(
127 | [(1 - t[i + 1] / t[i]).mean() for i in range(0, len(t) - 1)]
128 | ).mean()
129 | return float(gcr)
130 |
--------------------------------------------------------------------------------
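These helpers carry all of the shape bookkeeping that `SCNet` depends on: `create_intervals` turns the band-split ratios into cumulative intervals, `get_conv_output_shape` predicts each downsampled subband width, and `get_convtranspose_output_padding` inverts that prediction for the decoder. A small worked example using the same ratios and strides as the `scnet.py` demo (illustrative only, run from `src/`):

```python
from model.utils import (
    create_intervals,
    get_conv_output_shape,
    get_convtranspose_output_padding,
)

n_freq_bins = 4096 // 2 + 1          # 2049
ratios, strides = [0.175, 0.392, 0.433], [1, 4, 16]

intervals = create_intervals(ratios)  # roughly [(0, 0.175), (0.175, 0.567), (0.567, 1.0)]
widths = [int(r * n_freq_bins) - int(l * n_freq_bins) for l, r in intervals]
downsampled = [get_conv_output_shape(w, stride=s) for w, s in zip(widths, strides)]
paddings = [
    get_convtranspose_output_padding(input_shape=d, output_shape=w, stride=s)
    for d, w, s in zip(downsampled, widths, strides)
]
print(widths, downsampled, paddings)  # [358, 803, 888] [358, 201, 56] [0, 2, 7]
```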
/src/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import traceback
3 | from shutil import rmtree
4 |
5 | import hydra
6 | from hydra.utils import instantiate
7 | from omegaconf import DictConfig
8 | from lightning.pytorch import seed_everything
9 | from torch.utils.data import DataLoader
10 |
11 | from utils.lightning import LightningWrapper
12 |
13 | log = logging.getLogger(__name__)
14 |
15 |
16 | @hydra.main(version_base=None, config_path="conf", config_name="train")
17 | def train(cfg: DictConfig) -> None:
18 | if cfg.get("seed"):
19 | log.info(f"Seed everything with <{cfg.seed}>...")
20 | seed_everything(cfg.seed, workers=True)
21 |
22 | log.info("Initializing DataLoaders...")
23 | dataset = instantiate(cfg.dataset)
24 | if cfg.use_validation:
25 | train_dataset, val_dataset = dataset.get_train_val_split(**cfg.train_val_split)
26 | train_dataloader = DataLoader(train_dataset, **cfg.loader.train)
27 | val_dataloader = DataLoader(val_dataset, **cfg.loader.validation)
28 | else:
29 | train_dataloader = DataLoader(dataset, **cfg.loader.train)
30 | val_dataloader = None
31 |
32 | log.info("Initializing LightningWrapper...")
33 | lt_wrapper = LightningWrapper(cfg)
34 |
35 | log.info("Initializing training utilities...")
36 | logger = instantiate(cfg.logger)
37 | callbacks = list(instantiate(cfg.callbacks).values())
38 |
39 | log.info("Initializing trainer...")
40 | trainer = instantiate(
41 | cfg.trainer,
42 | logger=logger,
43 | callbacks=callbacks,
44 | )
45 |
46 | log.info("Starting training...")
47 | try:
48 | trainer.fit(lt_wrapper, train_dataloader, val_dataloader)
49 |     except Exception:
50 | log.error(f"Finished with error:\n{traceback.format_exc()}")
51 |
52 |     # clean up if it was a test run
53 | if cfg.trainer.fast_dev_run:
54 | rmtree(cfg.output_dir)
55 |
56 | log.info("Training finished!")
57 |
58 |
59 | if __name__ == "__main__":
60 | train()
61 |
--------------------------------------------------------------------------------
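The entry point is a plain Hydra application, so any value defined in `conf/train.yaml` and its config groups can be overridden from the command line. One invocation worth knowing, because the script deletes its own output directory after such a run (the exact group and key names live in the YAML files under `src/conf`):

```bash
cd src
python train.py trainer.fast_dev_run=True   # quick smoke test; the run's output dir is removed afterwards
```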
/src/utils/lightning.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 | from hydra.utils import instantiate
6 | from lightning.pytorch import LightningModule
7 | from lightning.pytorch.utilities import grad_norm
8 | from omegaconf import DictConfig
9 |
10 | from model.separator import Separator
11 |
12 |
13 | class LightningWrapper(LightningModule):
14 | """
15 | This class serves as a LightningModule for training and evaluating the source separation model.
16 |
17 | Args:
18 | - cfg (DictConfig): Configuration object containing wrapper settings.
19 | """
20 |
21 | def __init__(self, cfg: DictConfig) -> None:
22 | """
23 | Initializes the LightningWrapper.
24 | """
25 | super().__init__()
26 |
27 | self.cfg = cfg
28 |
29 | self.sep = Separator(**cfg.separator)
30 |
31 | self.loss = instantiate(cfg.loss)
32 | self.optimizer = instantiate(cfg.optimizer, params=self.sep.parameters())
33 | self.scheduler = instantiate(cfg.scheduler) if cfg.get("scheduler") else None
34 | self.metrics = nn.ModuleDict(instantiate(cfg.metrics))
35 |
36 | self.save_hyperparameters(cfg)
37 |
38 | def training_step(
39 | self, batch: Dict[str, torch.Tensor], batch_idx: int
40 | ) -> torch.Tensor:
41 | loss = self.step(batch, mode="train")
42 | return loss
43 |
44 | def validation_step(
45 | self, batch: Dict[str, torch.Tensor], batch_idx: int
46 | ) -> torch.Tensor:
47 | loss = self.step(batch, mode="val")
48 | return loss
49 |
50 | def step(self, batch: Dict[str, torch.Tensor], mode: str = "train") -> torch.Tensor:
51 | """
52 | Common step function for training and validation.
53 |
54 | Args:
55 | - batch (torch.Tensor): Batch of input data.
56 | - mode (str): Mode of operation. Defaults to 'train'.
57 |
58 | Returns:
59 | - torch.Tensor: Loss tensor.
60 | """
61 | wav_mix, wav_src = batch["mixture"], batch["sources"]
62 | wav_src_hat, spec_src_hat = self.sep(wav_mix)
63 | spec_src, _ = self.sep.apply_stft(wav_src)
64 |
65 | loss = self.loss(spec_src_hat, spec_src)
66 |
67 | self.log(f"{mode}/loss", loss.detach(), prog_bar=True)
68 |
69 | if mode == "val":
70 | metrics = self.compute_metrics(wav_src_hat, wav_src)
71 | self.log_dict(metrics)
72 | return loss
73 |
74 | @torch.no_grad()
75 | def compute_metrics(
76 | self, preds: torch.Tensor, target: torch.Tensor
77 | ) -> Dict[str, torch.Tensor]:
78 | """
79 | Computes metrics for evaluation.
80 |
81 | Args:
82 | - preds (torch.Tensor): Predicted sources tensor.
83 | - target (torch.Tensor): Target sources tensor.
84 |
85 | Returns:
86 | - Dict[str, torch.Tensor]: Dictionary containing computed metrics.
87 | """
88 | metrics = {}
89 | for key in self.metrics:
90 | for i, source in enumerate(self.cfg.dataset.sources):
91 | metrics[f"val/{key}_{source}"] = self.metrics[key](
92 | preds[:, i], target[:, i]
93 | )
94 | metrics[f"val/{key}"] = self.metrics[key](preds, target)
95 | return metrics
96 |
97 | def on_before_optimizer_step(self, *args, **kwargs) -> None:
98 | norms = grad_norm(self, norm_type=2)
99 | norms = dict(filter(lambda elem: "_total" in elem[0], norms.items()))
100 | self.log_dict(norms)
101 | return
102 |
103 | def configure_optimizers(self):
104 | return [self.optimizer]
105 |
106 | def get_metrics(self, *args, **kwargs):
107 | items = super().get_metrics()
108 | items.pop("v_num", None)
109 | return items
110 |
--------------------------------------------------------------------------------
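The wrapper never validates batch shapes; the contract is implied by `Separator.forward` (mixtures of shape `(B, Ch, T)`) and `compute_metrics` (stacked per-source targets). A hypothetical batch that satisfies it, with shapes inferred from the code rather than enforced anywhere:

```python
import torch

batch = {
    "mixture": torch.rand(4, 2, 441_000),     # (B, Ch, T): stereo mixtures, e.g. 10 s at 44.1 kHz
    "sources": torch.rand(4, 4, 2, 441_000),  # (B, S, Ch, T): targets for the four sources
}
# LightningWrapper.step computes the loss on spectrograms (sep(mixture) vs.
# sep.apply_stft(sources)) and, in validation, per-source metrics on waveforms.
```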
/src/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class RMSELoss(nn.Module):
6 | """
7 | Root Mean Squared Error Loss for Complex Inputs.
8 |
9 | Args:
10 | - eps (float): A small value to prevent division by zero.
11 | """
12 |
13 | def __init__(self, eps: float = 1e-6):
14 | """
15 | Initialize RMSELoss.
16 | """
17 | super().__init__()
18 | self.mse = nn.MSELoss()
19 | self.eps = eps
20 |
21 | def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
22 | """
23 | Compute the forward pass of RMSELoss.
24 |
25 | Args:
26 | - pred (torch.Tensor): Predicted values, a tensor of shape (batch_size, ..., 2),
27 | where last dimension is real/imaginary part.
28 | - target (torch.Tensor): Target values, a tensor of shape (batch_size, ..., 2),
29 | where last dimension is real/imaginary part.
30 |
31 | Returns:
32 | - loss (torch.Tensor): Computed RMSE loss.
33 | """
34 | loss = torch.sqrt(
35 | self.mse(pred[..., 0], target[..., 0])
36 | + self.mse(pred[..., 1], target[..., 1])
37 | + self.eps
38 | )
39 | return loss
40 |
--------------------------------------------------------------------------------
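Because the separator returns spectrograms through `torch.view_as_real`, the loss expects the trailing dimension to hold the real and imaginary parts; it sums the two MSEs before taking the square root. A minimal sketch with random complex spectrograms (shapes illustrative, run from `src/`):

```python
import torch

from utils.loss import RMSELoss

loss_fn = RMSELoss()
pred = torch.view_as_real(torch.randn(1, 4, 2, 2049, 87, dtype=torch.complex64))
target = torch.view_as_real(torch.randn(1, 4, 2, 2049, 87, dtype=torch.complex64))
print(loss_fn(pred, target))  # scalar: sqrt(mse_real + mse_imag + eps)
```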
/src/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import torch
4 | from torchmetrics.audio import SignalDistortionRatio
5 |
6 |
7 | def global_signal_distortion_ratio(
8 | preds: torch.Tensor, target: torch.Tensor, epsilon: float = 1e-7
9 | ) -> torch.Tensor:
10 | """
11 | Calculates the Global Signal Distortion Ratio (GSDR) between predicted and target signals.
12 |
13 | Args:
14 | - preds (torch.Tensor): Predicted signal tensor.
15 | - target (torch.Tensor): Target signal tensor.
16 | - epsilon (float, optional): Small value to avoid division by zero. Defaults to 1e-7.
17 |
18 | Returns:
19 | - torch.Tensor: Mean Global Signal Distortion Ratio (GSDR) over the batch.
20 | """
21 | num = torch.sum(torch.square(target), dim=(-2, -1)) + epsilon
22 | den = torch.sum(torch.square(target - preds), dim=(-2, -1)) + epsilon
23 |     gsdr = 10 * torch.log10(num / den)
24 |     return gsdr.mean()
25 |
26 |
27 | class GlobalSignalDistortionRatio(SignalDistortionRatio):
28 | """
29 | Computes the Global Signal Distortion Ratio (GSDR) metric for audio signals
30 |     as first described in the Sound Demixing Challenge.
31 |
32 | This metric calculates the ratio between the energy of the original signal
33 | and the energy of the difference between the original and the predicted signal,
34 | measured in decibels (dB).
35 |
36 | Paper: https://arxiv.org/pdf/2308.06979.pdf
37 | """
38 |
39 | def __init__(
40 | self,
41 | epsilon: float = 1e-7,
42 | **kwargs: Any,
43 | ) -> None:
44 | """
45 | Initializes the GlobalSignalDistortionRatio metric.
46 | """
47 | super().__init__(**kwargs)
48 |
49 | self.epsilon = epsilon
50 |
51 | self.add_state("sum_sdr", default=torch.tensor(0.0), dist_reduce_fx="sum")
52 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
53 |
54 | def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
55 | """Update state with predictions and targets."""
56 | sdr_batch = global_signal_distortion_ratio(preds, target, self.epsilon)
57 |
58 | self.sum_sdr += sdr_batch.sum()
59 | self.total += sdr_batch.numel()
60 |
--------------------------------------------------------------------------------
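GSDR is simply the energy ratio between the target and the residual in decibels, computed over whole clips rather than framewise. A quick illustrative check (values approximate, run from `src/`):

```python
import torch

from utils.metrics import global_signal_distortion_ratio

target = torch.randn(2, 2, 44_100)                     # (B, Ch, T)
noisy = target + 0.01 * torch.randn_like(target)       # residual roughly 40 dB below the signal
print(global_signal_distortion_ratio(noisy, target))   # roughly 40 dB
print(global_signal_distortion_ratio(target, target))  # very large, limited only by epsilon
```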