├── src ├── __init__.py ├── models │ ├── __init__.py │ ├── modules.py │ └── mdxnet.py ├── utils │ ├── __init__.py │ ├── data_augmentation.py │ └── utils.py ├── callbacks │ ├── __init__.py │ ├── onnx_callback.py │ └── wandb_callbacks.py ├── datamodules │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ └── musdb.py │ └── musdb_datamodule.py ├── mdx_kit │ ├── evaluator │ │ ├── __init__.py │ │ ├── aicrowd_helpers.py │ │ └── music_demixing.py │ └── README.md ├── evaluation │ ├── separate.py │ └── eval.py └── train.py ├── tests ├── __init__.py ├── smoke │ ├── __init__.py │ ├── test_sweeps.py │ ├── test_wandb.py │ ├── test_mixed_precision.py │ └── test_commands.py ├── submit │ └── to_onnx.py ├── unit │ ├── __init__.py │ └── test_sth.py └── helpers │ ├── __init__.py │ ├── run_command.py │ ├── module_available.py │ └── runif.py ├── configs ├── callbacks │ ├── none.yaml │ ├── default.yaml │ └── wandb.yaml ├── logger │ ├── none.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── tensorboard.yaml │ ├── neptune.yaml │ └── wandb.yaml ├── experiment │ ├── multigpu_default.yaml │ ├── multigpu_bass.yaml │ ├── multigpu_drums.yaml │ ├── multigpu_other.yaml │ ├── multigpu_vocals.yaml │ ├── mixer.yaml │ └── debug.yaml ├── trainer │ ├── minimal.yaml │ └── default.yaml ├── hydra │ └── default.yaml ├── model │ ├── ConvTDFNet_bass.yaml │ ├── ConvTDFNet_drums.yaml │ ├── ConvTDFNet_vocals.yaml │ ├── Mixer.yaml │ └── ConvTDFNet_other.yaml ├── evaluation.yaml ├── datamodule │ └── musdb18_hq.yaml └── config.yaml ├── onnx_callback.png ├── val_loss_vocals.png ├── conda_env_gpu.yaml ├── .env.example ├── run_eval.py ├── setup.cfg ├── requirements.txt ├── .pre-commit-config.yaml ├── LICENSE ├── run.py ├── README.md ├── .gitignore └── README_SUBMISSION.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/smoke/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/submit/to_onnx.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/callbacks/none.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/logger/none.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/callbacks/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/mdx_kit/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/datamodules/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /onnx_callback.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuielab/mdx-net/HEAD/onnx_callback.png -------------------------------------------------------------------------------- /val_loss_vocals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kuielab/mdx-net/HEAD/val_loss_vocals.png -------------------------------------------------------------------------------- /src/mdx_kit/README.md: -------------------------------------------------------------------------------- 1 | [original code](https://github.com/AIcrowd/music-demixing-challenge-starter-kit) -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # CSV logger built into Lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "."
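# note: "." resolves against Hydra's run directory (set in configs/hydra/default.yaml),
# so these CSV logs end up under logs/runs/<target>/<date>/<time>/csv/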
6 | name: "csv/" 7 | version: null 8 | prefix: "" 9 | -------------------------------------------------------------------------------- /configs/experiment/multigpu_default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # base config shared by the multi-GPU experiments; run one that builds on it, e.g.: 4 | # python run.py experiment=multigpu_vocals 5 | 6 | defaults: 7 | - override /callbacks: wandb 8 | - override /logger: wandb -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - aim.yaml 5 | # - comet.yaml 6 | - csv.yaml 7 | # - mlflow.yaml 8 | # - neptune.yaml 9 | # - tensorboard.yaml 10 | - wandb.yaml 11 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: "default" 7 | version: null 8 | log_graph: False 9 | default_hp_metric: True 10 | prefix: "" 11 | -------------------------------------------------------------------------------- /conda_env_gpu.yaml: -------------------------------------------------------------------------------- 1 | name: mdx-net 2 | 3 | channels: 4 | - pytorch 5 | - conda-forge 6 | - anaconda 7 | 8 | dependencies: 9 | - python=3.8 10 | - pip 11 | - cudatoolkit 12 | - pytorch=1.8.1 13 | - torchvision=0.9.1 14 | - jupyter 15 | - librosa=0.8 16 | - ffmpeg 17 | - onnxruntime 18 | - nvidia-apex 19 | 20 | - pip: 21 | - -r requirements.txt 22 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project_name: your_name/template-tests 7 | close_after_fit: True 8 | offline_mode: False 9 | experiment_name: null 10 | experiment_id: null 11 | prefix: "" 12 | -------------------------------------------------------------------------------- /tests/helpers/run_command.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | import sh 5 | 6 | 7 | def run_command(command: List[str]): 8 | """Run `python <command>` via sh and fail the test if it exits with an error.""" 9 | msg = None 10 | try: 11 | sh.python(command) 12 | except sh.ErrorReturnCode as e: 13 | msg = e.stderr.decode() 14 | if msg: 15 | pytest.fail(msg=msg) 16 | -------------------------------------------------------------------------------- /configs/trainer/minimal.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | defaults: 4 | - default 5 | 6 | gpus: 4 7 | 8 | resume_from_checkpoint: 9 | auto_lr_find: False 10 | deterministic: True 11 | accelerator: dp 12 | sync_batchnorm: False 13 | 14 | max_epochs: 3000 15 | min_epochs: 1 16 | check_val_every_n_epoch: 10 17 | num_sanity_val_steps: 1 18 | 19 | precision: 16 20 | amp_backend: "native" 21 | amp_level: "O2"
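A minimal sketch of how a `_target_`-style config like the trainer config above becomes an actual `pytorch_lightning.Trainer`. This assumes `src.train.train` hands the composed config to `hydra.utils.instantiate`, the usual pattern in this template (the body of `train` is not shown here):

```python
# Illustrative sketch, not repo code: instantiate a Trainer from a `_target_` config.
import hydra.utils
from omegaconf import OmegaConf

trainer_cfg = OmegaConf.create(
    {
        "_target_": "pytorch_lightning.Trainer",
        "gpus": 0,        # overriding `gpus: 4` from minimal.yaml for a local dry run
        "max_epochs": 1,
        "precision": 32,
    }
)

# instantiate() imports the `_target_` class and calls it with the remaining
# keys as keyword arguments, returning a ready-to-use Trainer.
trainer = hydra.utils.instantiate(trainer_cfg)
print(type(trainer))
```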
-------------------------------------------------------------------------------- /configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # output paths for hydra logs 2 | run: 3 | dir: logs/runs/${datamodule.target_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 4 | 5 | sweep: 6 | dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S} 7 | subdir: ${hydra.job.num} 8 | 9 | # you can set environment variables here that are universal for all users 10 | # for system-specific variables (like data paths) it's better to use the .env file! 11 | job: 12 | env_set: 13 | EXAMPLE_VAR: "example_value" 14 | 15 | -------------------------------------------------------------------------------- /configs/experiment/multigpu_bass.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=multigpu_bass 5 | 6 | defaults: 7 | - multigpu_default 8 | 9 | seed: 2021 10 | 11 | logger: 12 | wandb: 13 | name: 'mdx_bass' 14 | 15 | trainer: 16 | gpus: '0,1,2' 17 | 18 | datamodule: 19 | batch_size: 8 20 | num_workers: 8 21 | pin_memory: False 22 | overlap: 8192 23 | 24 | callbacks: 25 | early_stopping: 26 | patience: 500 -------------------------------------------------------------------------------- /configs/experiment/multigpu_drums.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=multigpu_drums 5 | 6 | defaults: 7 | - multigpu_default 8 | 9 | seed: 2021 10 | 11 | logger: 12 | wandb: 13 | name: 'mdx_drums' 14 | 15 | trainer: 16 | gpus: '0,1' 17 | 18 | datamodule: 19 | batch_size: 8 20 | num_workers: 8 21 | pin_memory: False 22 | overlap: 2048 23 | 24 | callbacks: 25 | early_stopping: 26 | patience: 500 -------------------------------------------------------------------------------- /configs/experiment/multigpu_other.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=multigpu_other 5 | 6 | defaults: 7 | - multigpu_default 8 | 9 | seed: 2021 10 | 11 | logger: 12 | wandb: 13 | name: 'mdx_other' 14 | 15 | trainer: 16 | gpus: '0,1' 17 | 18 | datamodule: 19 | batch_size: 8 20 | num_workers: 8 21 | pin_memory: False 22 | overlap: 4096 23 | 24 | callbacks: 25 | early_stopping: 26 | patience: 500 -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 5 | project: mdx_${model.target_name} 6 | name: null 7 | save_dir: "." 8 | offline: False # set True to store all logs only locally 9 | id: null # pass correct id to resume experiment!
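# e.g., to resume logging into an existing run (hypothetical id, for illustration):
#   python run.py logger=wandb logger.wandb.id=1a2b3c4d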
10 | # entity: "" # set to name of your wandb team or just remove it 11 | log_model: False 12 | prefix: "" 13 | job_type: "train" 14 | group: "" 15 | tags: [] 16 | -------------------------------------------------------------------------------- /configs/model/ConvTDFNet_bass.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.mdxnet.ConvTDFNet 2 | 3 | target_name: 'bass' 4 | 5 | # absolute path to the model checkpoint you want to load (random init if null) 6 | ckpt: null 7 | 8 | # model 9 | num_blocks: 11 10 | l: 3 11 | g: 32 12 | k: 3 13 | bn: 8 14 | bias: False 15 | 16 | # stft 17 | n_fft: 16384 18 | dim_f: 2048 19 | dim_t: 256 20 | dim_c: 4 21 | hop_length: 1024 22 | 23 | overlap: 8192 24 | 25 | # optimizer 26 | lr: 0.0001 27 | optimizer: rmsprop -------------------------------------------------------------------------------- /configs/model/ConvTDFNet_drums.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.mdxnet.ConvTDFNet 2 | 3 | target_name: 'drums' 4 | 5 | # absolute path to the model checkpoint you want to load (random init if null) 6 | ckpt: null 7 | 8 | # model 9 | num_blocks: 11 10 | l: 3 11 | g: 32 12 | k: 3 13 | bn: 8 14 | bias: False 15 | 16 | # stft 17 | n_fft: 4096 18 | dim_f: 2048 19 | dim_t: 256 20 | dim_c: 4 21 | hop_length: 1024 22 | 23 | overlap: 2048 24 | 25 | # optimizer 26 | lr: 0.0001 27 | optimizer: rmsprop -------------------------------------------------------------------------------- /configs/model/ConvTDFNet_vocals.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.mdxnet.ConvTDFNet 2 | 3 | target_name: 'vocals' 4 | 5 | # absolute path to the model checkpoint you want to load (random init if null) 6 | ckpt: null 7 | 8 | # model 9 | num_blocks: 11 10 | l: 3 11 | g: 32 12 | k: 3 13 | bn: 8 14 | bias: False 15 | 16 | # stft 17 | n_fft: 6144 18 | dim_f: 2048 19 | dim_t: 256 20 | dim_c: 4 21 | hop_length: 1024 22 | 23 | overlap: 3072 24 | 25 | # optimizer 26 | lr: 0.0002 27 | optimizer: rmsprop -------------------------------------------------------------------------------- /configs/experiment/multigpu_vocals.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=multigpu_vocals 5 | 6 | defaults: 7 | - multigpu_default 8 | 9 | seed: 2021 10 | 11 | logger: 12 | wandb: 13 | name: 'mdx_vocals' 14 | 15 | trainer: 16 | gpus: '0,1,2,3' 17 | 18 | datamodule: 19 | batch_size: 16 20 | num_workers: 8 21 | pin_memory: False 22 | overlap: 3072 23 | 24 | callbacks: 25 | early_stopping: 26 | patience: 500 -------------------------------------------------------------------------------- /configs/experiment/mixer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=mixer 5 | 6 | defaults: 7 | - multigpu_default 8 | 9 | seed: 2021 10 | 11 | logger: 12 | wandb: 13 | name: 'mixer' 14 | 15 | trainer: 16 | gpus: '0' 17 | max_epochs: 10 18 | min_epochs: 1 19 | check_val_every_n_epoch: 1 20 | 21 | datamodule: 22 | batch_size: 16 23 | num_workers: 8 24 | pin_memory: False 25 | external_datasets: 26 | - test -------------------------------------------------------------------------------- /configs/model/Mixer.yaml:
-------------------------------------------------------------------------------- 1 | _target_: src.models.mdxnet.Mixer 2 | 3 | target_name: 'all' 4 | 5 | separator_configs: 6 | vocals: ${work_dir}/configs/model/ConvTDFNet_vocals.yaml 7 | drums: ${work_dir}/configs/model/ConvTDFNet_drums.yaml 8 | bass: ${work_dir}/configs/model/ConvTDFNet_bass.yaml 9 | other: ${work_dir}/configs/model/ConvTDFNet_other.yaml 10 | 11 | separator_ckpts: 12 | vocals: null 13 | drums: null 14 | bass: null 15 | other: null 16 | 17 | dim_t: 256 18 | hop_length: 1024 19 | overlap: 8192 20 | 21 | # optimizer 22 | lr: 0.003 23 | optimizer: rmsprop -------------------------------------------------------------------------------- /configs/model/ConvTDFNet_other.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.models.mdxnet.ConvTDFNet 2 | 3 | target_name: 'other' 4 | 5 | # absolute path to the model checkpoint you want to load (random init if null) 6 | ckpt: #'/home/ielab/PycharmProjects/mdx-net/logs/runs/other/2021-08-04/01-01-09/checkpoints/epoch=1529.ckpt' 7 | 8 | # model 9 | num_blocks: 11 10 | l: 3 11 | g: 32 12 | k: 3 13 | bn: 8 14 | bias: False 15 | 16 | # stft 17 | n_fft: 8192 18 | dim_f: 2048 19 | dim_t: 256 20 | dim_c: 4 21 | hop_length: 1024 22 | 23 | overlap: 4096 24 | 25 | # optimizer 26 | lr: 0.0001 27 | optimizer: rmsprop -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | # this is an example of the file that can be used for storing private and user-specific environment variables, like keys or system paths 2 | # create a file named .env (by default .env will be excluded from version control) 3 | # the variables declared in .env are loaded in run.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | wandb_api_key= [YOUR WANDB API KEY] # go to wandb.ai/settings and copy your key 7 | data_dir= [Your MUSDBHQ Data PATH] # your MUSDB18-HQ data directory; must be an absolute path
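# these variables are read back into the configs through Hydra's env resolver,
# e.g. configs/config.yaml declares: data_dir: ${oc.env:data_dir}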
8 | HYDRA_FULL_ERROR=1 -------------------------------------------------------------------------------- /configs/evaluation.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - model: ConvTDFNet_other 6 | - logger: wandb.yaml 7 | # enable color logging 8 | - override hydra/hydra_logging: colorlog 9 | - override hydra/job_logging: colorlog 10 | 11 | ckpt_path: 'other/2021-08-02/23-12-31/checkpoints/epoch=339.ckpt' 12 | 13 | split: 'valid' 14 | batch_size: 4 15 | device: 'cuda:2' 16 | 17 | data_dir: ${oc.env:data_dir} 18 | ckpt_dir: ${oc.env:ckpt_dir} 19 | wandb_api_key: ${oc.env:wandb_api_key} 20 | 21 | logger: 22 | wandb: 23 | project: mdx_eval_${split} 24 | name: ${ckpt_path} -------------------------------------------------------------------------------- /tests/unit/test_sth.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.runif import RunIf 4 | 5 | 6 | def test_something1(): 7 | """Some test description.""" 8 | assert True is True 9 | 10 | 11 | def test_something2(): 12 | """Some test description.""" 13 | assert 1 + 1 == 2 14 | 15 | 16 | @pytest.mark.parametrize("arg1", [0.5, 1.0, 2.0]) 17 | def test_something3(arg1: float): 18 | """Some test description.""" 19 | assert arg1 > 0 20 | 21 | 22 | # use RunIf to skip execution of some tests, e.g. when not on windows or when no gpus are available 23 | @RunIf(skip_windows=True, min_gpus=1) 24 | def test_something4(): 25 | """Some test description.""" 26 | assert True is True 27 | -------------------------------------------------------------------------------- /run_eval.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import hydra 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | dotenv.load_dotenv(override=True) 6 | 7 | 8 | @hydra.main(config_path="configs/", config_name="evaluation.yaml") 9 | def main(config: DictConfig): 10 | # Imports should be nested inside @hydra.main to optimize tab completion 11 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 12 | from src.evaluation.eval import evaluation 13 | from src.utils import utils 14 | 15 | # Pretty print config using Rich library 16 | if config.get("print_config"): 17 | utils.print_config(config, resolve=True) 18 | 19 | evaluation(config) 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | project_name = ... 3 | author = ... 4 | contact = ... 
5 | license_file = LICENSE 6 | description_file = README.md 7 | project_template = https://github.com/ashleve/lightning-hydra-template 8 | 9 | 10 | [isort] 11 | line_length = 99 12 | profile = black 13 | filter_files = True 14 | 15 | 16 | [flake8] 17 | max_line_length = 99 18 | show_source = True 19 | format = pylint 20 | ignore = 21 | F401 # Module imported but unused 22 | W504 # Line break occurred after a binary operator 23 | F841 # Local variable name is assigned to but never used 24 | exclude = 25 | .git 26 | __pycache__ 27 | data/* 28 | tests/* 29 | notebooks/* 30 | logs/* 31 | 32 | 33 | [tool:pytest] 34 | python_files = tests/* 35 | log_cli = True 36 | markers = 37 | slow 38 | addopts = 39 | --durations=0 40 | --strict-markers 41 | --doctest-modules 42 | filterwarnings = 43 | ignore::DeprecationWarning 44 | -------------------------------------------------------------------------------- /configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: "val/sdr" # name of the logged metric which determines when model is improving 4 | save_top_k: 5 # save k best models (determined by above metric) 5 | save_last: True # additionally always save model from last epoch 6 | mode: "max" # can be "max" or "min" 7 | verbose: False 8 | dirpath: "checkpoints/" 9 | filename: "{epoch:02d}" 10 | # 11 | early_stopping: 12 | _target_: pytorch_lightning.callbacks.EarlyStopping 13 | monitor: "val/sdr" # name of the logged metric which determines when model is improving 14 | patience: 300 # how many epochs of not improving until training stops 15 | mode: "max" # can be "max" or "min" 16 | min_delta: 0.05 # minimum change in the monitored metric needed to qualify as an improvement 17 | 18 | make_onnx: 19 | _target_: src.callbacks.onnx_callback.MakeONNXCallback 20 | dirpath: "onnx/" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch==1.8.1 3 | torchvision==0.9.1 4 | pytorch-lightning>=1.3.2 5 | torchmetrics>=0.3.2 6 | onnxruntime 7 | onnxruntime-gpu 8 | onnx 9 | 10 | # --------- hydra --------- # 11 | hydra-core==1.1.0.rc1 12 | hydra-colorlog==1.1.0.dev1 13 | hydra-optuna-sweeper==1.1.0.dev2 14 | # hydra-ax-sweeper==1.1.0 15 | # hydra-ray-launcher==0.1.2 16 | # hydra-submitit-launcher==1.1.0 17 | 18 | # --------- loggers --------- # 19 | wandb>=0.10.30 20 | # neptune-client 21 | # mlflow 22 | # comet-ml 23 | # torch_tb_profiler 24 | 25 | # --------- linters --------- # 26 | pre-commit 27 | black 28 | isort 29 | flake8 30 | 31 | # --------- others --------- # 32 | jupyterlab 33 | python-dotenv 34 | rich 35 | pytest 36 | sh 37 | scikit-learn 38 | seaborn 39 | pudb 40 | # dvc 41 | 42 | 43 | # --------- kuielab --------- # 44 | aicrowd_api 45 | coloredlogs 46 | loguru 47 | openunmix 48 | musdb 49 | asteroid>=0.5.0 50 | demucs 51 | -------------------------------------------------------------------------------- /configs/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | watch_model: 5 | _target_: src.callbacks.wandb_callbacks.WatchModel 6 | log: "all" 7 | log_freq: 100 8 | 9 | #upload_valid_track: 10 | # _target_: src.callbacks.wandb_callbacks.UploadValidTrack 11 | # crop: 3 12 | # upload_after_n_epoch: -1
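# note: the commented-out callbacks here and below are optional extras; uncomment
# a block to enable it, provided its `_target_` class exists in src/callbacks/wandb_callbacks.py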
13 | 14 | #upload_code_as_artifact: 15 | # _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact 16 | # code_dir: ${work_dir}/src 17 | # 18 | #upload_ckpts_as_artifact: 19 | # _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact 20 | # ckpt_dir: "checkpoints/" 21 | # upload_best_only: True 22 | # 23 | #log_f1_precision_recall_heatmap: 24 | # _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap 25 | # 26 | #log_confusion_matrix: 27 | # _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix 28 | # 29 | #log_image_predictions: 30 | # _target_: src.callbacks.wandb_callbacks.LogImagePredictions 31 | # num_samples: 8 32 | -------------------------------------------------------------------------------- /tests/helpers/module_available.py: -------------------------------------------------------------------------------- 1 | import platform 2 | from importlib.util import find_spec 3 | 4 | """ 5 | Adapted from: 6 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/imports.py 7 | """ 8 | 9 | 10 | def _module_available(module_path: str) -> bool: 11 | """Check if a path is available in your environment. 12 | 13 | >>> _module_available('os') 14 | True 15 | >>> _module_available('bla.bla') 16 | False 17 | 18 | """ 19 | try: 20 | return find_spec(module_path) is not None 21 | except AttributeError: 22 | # Python 3.6 23 | return False 24 | except ModuleNotFoundError: 25 | # Python 3.7+ 26 | return False 27 | 28 | 29 | _IS_WINDOWS = platform.system() == "Windows" 30 | _APEX_AVAILABLE = _module_available("apex.amp") 31 | _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available("deepspeed") 32 | _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") 33 | _RPC_AVAILABLE = not _IS_WINDOWS and _module_available("torch.distributed.rpc") 34 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.8 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v3.4.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | - id: check-yaml 12 | - id: check-added-large-files 13 | - id: debug-statements 14 | - id: detect-private-key 15 | 16 | # python code formatting 17 | - repo: https://github.com/psf/black 18 | rev: 20.8b1 19 | hooks: 20 | - id: black 21 | args: [--line-length, "99"] 22 | 23 | # python import sorting 24 | - repo: https://github.com/PyCQA/isort 25 | rev: 5.8.0 26 | hooks: 27 | - id: isort 28 | 29 | # yaml formatting 30 | - repo: https://github.com/pre-commit/mirrors-prettier 31 | rev: v2.3.0 32 | hooks: 33 | - id: prettier 34 | types: [yaml] 35 | 36 | # python code analysis 37 | - repo: https://github.com/PyCQA/flake8 38 | rev: 3.9.2 39 | hooks: 40 | - id: flake8 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 KINoAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of 
the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/smoke/test_sweeps.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | 5 | """ 6 | Use the following command to skip slow tests: 7 | pytest -k "not slow" 8 | """ 9 | 10 | 11 | @pytest.mark.slow 12 | def test_default_sweep(): 13 | """Test default Hydra sweeper.""" 14 | command = [ 15 | "run.py", 16 | "-m", 17 | "datamodule.batch_size=64,128", 18 | "model.lr=0.01,0.02", 19 | "trainer=default", 20 | "trainer.fast_dev_run=true", 21 | ] 22 | run_command(command) 23 | 24 | 25 | @pytest.mark.slow 26 | def test_optuna_sweep(): 27 | """Test Optuna sweeper.""" 28 | command = [ 29 | "run.py", 30 | "-m", 31 | "hparams_search=mnist_optuna", 32 | "trainer=default", 33 | "trainer.fast_dev_run=true", 34 | ] 35 | run_command(command) 36 | 37 | 38 | @pytest.mark.skip(reason="TODO: Add Ax sweep config.") 39 | @pytest.mark.slow 40 | def test_ax_sweep(): 41 | """Test Ax sweeper.""" 42 | command = ["run.py", "-m", "hparams_search=mnist_ax", "trainer.fast_dev_run=true"] 43 | run_command(command) 44 | -------------------------------------------------------------------------------- /tests/smoke/test_wandb.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | 5 | """ 6 | Use the following command to skip slow tests: 7 | pytest -k "not slow" 8 | """ 9 | 10 | 11 | # @pytest.mark.slow 12 | # def test_wandb_optuna_sweep(): 13 | # """Test wandb logging with Optuna sweep.""" 14 | # command = [ 15 | # "run.py", 16 | # "-m", 17 | # "hparams_search=mnist_optuna", 18 | # "trainer=default", 19 | # "trainer.max_epochs=10", 20 | # "trainer.limit_train_batches=20", 21 | # "logger=wandb", 22 | # "logger.wandb.project=template-tests", 23 | # "logger.wandb.group=Optuna_SimpleDenseNet_MNIST", 24 | # "hydra.sweeper.n_trials=5", 25 | # ] 26 | # run_command(command) 27 | 28 | 29 | # @pytest.mark.slow 30 | # def test_wandb_callbacks(): 31 | # """Test wandb callbacks.""" 32 | # command = [ 33 | # "run.py", 34 | # "trainer=default", 35 | # "trainer.max_epochs=3", 36 | # "logger=wandb", 37 | # "logger.wandb.project=template-tests", 38 | # "callbacks=wandb", 39 | # ] 40 | # run_command(command) 41 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | import hydra 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | # load environment variables from `.env` file if it exists 6 | # recursively searches for `.env` in all folders starting from 
work dir 7 | from pytorch_lightning.utilities import rank_zero_info 8 | 9 | dotenv.load_dotenv(override=True) 10 | 11 | 12 | @hydra.main(config_path="configs/", config_name="config.yaml") 13 | def main(config: DictConfig): 14 | 15 | # Imports should be nested inside @hydra.main to optimize tab completion 16 | # Read more here: https://github.com/facebookresearch/hydra/issues/934 17 | from src.train import train 18 | from src.utils import utils 19 | 20 | rank_zero_info(OmegaConf.to_yaml(config)) 21 | 22 | # A couple of optional utilities: 23 | # - disabling python warnings 24 | # - easier access to debug mode 25 | # - forcing debug friendly configuration 26 | # - forcing multi-gpu friendly configuration 27 | # You can safely get rid of this line if you don't want those 28 | utils.extras(config) 29 | 30 | # Pretty print config using Rich library 31 | if config.get("print_config"): 32 | utils.print_config(config, resolve=True) 33 | 34 | # Train model 35 | return train(config) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /tests/smoke/test_mixed_precision.py: -------------------------------------------------------------------------------- 1 | from tests.helpers.run_command import run_command 2 | from tests.helpers.runif import RunIf 3 | 4 | 5 | @RunIf(amp_apex=True) 6 | def test_apex_01(): 7 | """Test mixed-precision level 01.""" 8 | command = [ 9 | "run.py", 10 | "trainer=default", 11 | "trainer.max_epochs=1", 12 | "trainer.gpus=1", 13 | "trainer.amp_backend=apex", 14 | "trainer.amp_level=O1", 15 | "trainer.precision=16", 16 | ] 17 | run_command(command) 18 | 19 | 20 | @RunIf(amp_apex=True) 21 | def test_apex_02(): 22 | """Test mixed-precision level 02.""" 23 | command = [ 24 | "run.py", 25 | "trainer=default", 26 | "trainer.max_epochs=1", 27 | "trainer.gpus=1", 28 | "trainer.amp_backend=apex", 29 | "trainer.amp_level=O2", 30 | "trainer.precision=16", 31 | ] 32 | run_command(command) 33 | 34 | 35 | @RunIf(amp_apex=True) 36 | def test_apex_03(): 37 | """Test mixed-precision level 03.""" 38 | command = [ 39 | "run.py", 40 | "trainer=default", 41 | "trainer.max_epochs=1", 42 | "trainer.gpus=1", 43 | "trainer.amp_backend=apex", 44 | "trainer.amp_level=O3", 45 | "trainer.precision=16", 46 | ] 47 | run_command(command) 48 | -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | # default values for all trainer parameters 4 | checkpoint_callback: True 5 | default_root_dir: null 6 | gradient_clip_val: 0.0 7 | process_position: 0 8 | num_nodes: 1 9 | num_processes: 1 10 | gpus: null 11 | auto_select_gpus: False 12 | tpu_cores: null 13 | log_gpu_memory: null 14 | progress_bar_refresh_rate: 1 15 | overfit_batches: 0.0 16 | track_grad_norm: -1 17 | check_val_every_n_epoch: 1 18 | fast_dev_run: False 19 | accumulate_grad_batches: 1 20 | max_epochs: 1 21 | min_epochs: 1 22 | max_steps: null 23 | min_steps: null 24 | limit_train_batches: 1.0 25 | limit_val_batches: 1.0 26 | limit_test_batches: 1.0 27 | val_check_interval: 1.0 28 | flush_logs_every_n_steps: 100 29 | log_every_n_steps: 50 30 | accelerator: null 31 | sync_batchnorm: False 32 | precision: 32 33 | weights_summary: "top" 34 | weights_save_path: null 35 | num_sanity_val_steps: 0 36 | truncated_bptt_steps: null 37 | resume_from_checkpoint: null 38 | profiler: null 39 | 
benchmark: False 40 | deterministic: False 41 | reload_dataloaders_every_epoch: False 42 | auto_lr_find: False 43 | replace_sampler_ddp: True 44 | terminate_on_nan: False 45 | auto_scale_batch_size: False 46 | prepare_data_per_node: True 47 | plugins: null 48 | amp_backend: "native" 49 | amp_level: "O2" 50 | move_metrics_to_cpu: False 51 | -------------------------------------------------------------------------------- /configs/datamodule/musdb18_hq.yaml: -------------------------------------------------------------------------------- 1 | _target_: src.datamodules.musdb_datamodule.MusdbDataModule 2 | 3 | # data_dir is specified in config.yaml 4 | data_dir: ${data_dir} 5 | 6 | # chunk_size = (hop_length * (dim_t - 1) / sample_rate) secs 7 | sample_rate: 44100 8 | hop_length: ${model.hop_length} # stft hop_length 9 | dim_t: ${model.dim_t} # number of stft frames 10 | 11 | # number of overlapping wave samples between chunks when separating a whole track 12 | overlap: ${model.overlap} 13 | 14 | source_names: 15 | - bass 16 | - drums 17 | - other 18 | - vocals 19 | target_name: ${model.target_name} 20 | 21 | external_datasets: 22 | 23 | batch_size: 8 24 | num_workers: 0 25 | pin_memory: False 26 | 27 | aug_params: 28 | - 2 # maximum pitch shift in semitones (-x < shift param < x) 29 | - 20 # maximum time stretch percentage (-x < stretch param < x) 30 | 31 | validation_set: 32 | - Actions - One Minute Smile 33 | - Clara Berry And Wooldog - Waltz For My Victims 34 | - Johnny Lokke - Promises & Lies 35 | - Patrick Talbot - A Reason To Leave 36 | - Triviul - Angelsaint 37 | - Alexander Ross - Goodbye Bolero 38 | - Fergessen - Nos Palpitants 39 | - Leaf - Summerghost 40 | - Skelpolu - Human Mistakes 41 | - Young Griffo - Pennies 42 | - ANiMAL - Rockshow 43 | - James May - On The Line 44 | - Meaxic - Take A Step 45 | - Traffic Experiment - Sirens -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - trainer: minimal 6 | - datamodule: musdb18_hq 7 | - callbacks: default # set this to null if you don't want to use callbacks 8 | - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`) 9 | - model: null 10 | - experiment: null 11 | - hparams_search: null 12 | 13 | - hydra: default 14 | 15 | # enable color logging 16 | - override hydra/hydra_logging: colorlog 17 | - override hydra/job_logging: colorlog 18 | 19 | 20 | # path to original working directory 21 | # hydra hijacks working directory by changing it to the current log directory, 22 | # so it's useful to have this path as a special variable 23 | # learn more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 24 | work_dir: ${hydra:runtime.cwd} 25 | 26 | # path to folder with data 27 | data_dir: ${oc.env:data_dir} 28 | 29 | # use `python run.py debug=true` for easy debugging! 
30 | # this will run 1 train, val and test loop with only 1 batch 31 | # equivalent to running `python run.py trainer.fast_dev_run=true` 32 | # (this is placed here just for easier access from command line) 33 | debug: False 34 | 35 | # pretty print config at the start of the run using Rich library 36 | print_config: True 37 | 38 | # disable python warnings if they annoy you 39 | ignore_warnings: True 40 | 41 | wandb_api_key: ${oc.env:wandb_api_key} -------------------------------------------------------------------------------- /configs/experiment/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python run.py experiment=debug 5 | 6 | defaults: 7 | - override /trainer: minimal.yaml 8 | - override /callbacks: none.yaml 9 | - override /logger: none.yaml 10 | 11 | # Training Environment 12 | target_name: vocals 13 | batch_size: 4 14 | gpus: 1 15 | lr: 0.0003 16 | check_val_every_n_epoch: 1 17 | seed: 2021 18 | deterministic: True 19 | num_workers: 2 # num of cpus 20 | pin_memory: False # True is usually faster 21 | 22 | # TODO: check effects 23 | sync_batchnorm: False 24 | resume_from_checkpoint: 25 | 26 | # STFT 27 | n_fft: 6144 28 | dim_f: 2048 29 | dim_t: 128 30 | dim_c: 4 31 | hop_length: 1024 32 | 33 | # data augmentation 34 | external_datasets: 35 | augmentation: False 36 | 37 | model: 38 | # model 39 | _target_: src.models.mdxnet.ConvTDFNet 40 | target_name: ${target_name} 41 | 42 | # model configuration 43 | num_blocks: 3 44 | l: 5 45 | g: 8 46 | k: 3 47 | bn: 8 48 | bias: False 49 | 50 | # optimizer 51 | lr: ${lr} 52 | optimizer: rmsprop 53 | 54 | # stft 55 | n_fft: ${n_fft} 56 | dim_f: ${dim_f} 57 | dim_t: ${dim_t} 58 | dim_c: ${dim_c} 59 | hop_length: ${hop_length} 60 | 61 | trainer: 62 | max_epochs: 3 63 | min_epochs: 1 64 | sync_batchnorm: ${sync_batchnorm} 65 | precision: 16 66 | resume_from_checkpoint: ${resume_from_checkpoint} 67 | auto_lr_find: False 68 | amp_backend: "native" 69 | amp_level: "O2" 70 | 71 | deterministic: ${deterministic} 72 | gpus: 1 73 | check_val_every_n_epoch: ${check_val_every_n_epoch} 74 | 75 | # auto script 76 | datamodule: 77 | num_workers: ${num_workers} # num of cpus 78 | pin_memory: ${pin_memory} # True is usually faster 79 | external_datasets: ${external_datasets} 80 | augmentation: ${augmentation} 81 | batch_size: ${batch_size} 82 | target_name: ${target_name} 83 | n_fft: ${n_fft} 84 | hop_length: ${hop_length} 85 | dim_c: ${dim_c} 86 | dim_f: ${dim_f} 87 | dim_t: ${dim_t} 88 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KUIELab-MDX-Net 2 | 3 | - [presentation slide](https://ws-choi.github.io/personal/presentations/slide/2021-08-21-aicrowd) 4 | 5 | ## 0. Environment 6 | 7 | - Ubuntu 20.04 8 | - at least four CUDA-capable GPUs (each >= 2080 Ti) 9 | - 1.5 TB disk storage for data augmentation 10 | - wandb for logging 11 | 12 | Also, you ***must*** create a .env file by copying .env.example to set environment variables. 13 | 14 | ``` 15 | wandb_api_key=[Your Key] # "xxxxxxxxxxxxxxxxxxxxxxxx" 16 | data_dir=[Your Path] # "/home/ielab/repos/musdbHQ" 17 | ``` 18 | 19 | - about ```wandb_api_key``` 20 | - we currently only support wandb for logging.
21 | - for ```wandb_api_key```, visit [wandb](https://wandb.ai/site), go to ```settings```, and then copy your API key 22 | - about ```data_dir``` 23 | - the ***absolute*** path where datasets are stored 24 | 25 | ## 1. Installation 26 | 27 | ```bash 28 | conda env create -f conda_env_gpu.yaml -n mdx-net 29 | conda activate mdx-net 30 | pip install -r requirements.txt 31 | sudo apt-get install soundstretch 32 | ``` 33 | 34 | ## 2. Training & Submission 35 | 36 | - [Leaderboard_A](https://github.com/kuielab/mdx-net/tree/Leaderboard_A) 37 | - [Leaderboard_B](https://github.com/kuielab/mdx-net/tree/Leaderboard_B) 38 | 39 | ## 3. Leaderboard A vs Leaderboard B 40 | 41 | - The main difference between the branch [Leaderboard_A](https://github.com/kuielab/mdx-net/tree/Leaderboard_A) and [Leaderboard_B](https://github.com/kuielab/mdx-net/tree/Leaderboard_B) is whether the test set of Musdb18 is used for training. 42 | - Leaderboard A does not use the test set for training: https://github.com/kuielab/mdx-net/blob/Leaderboard_A/configs/experiment/multigpu_default.yaml 43 | - Leaderboard B uses the test set for training: https://github.com/kuielab/mdx-net/blob/b45eff172928dc9fc31852ee65072fb01f4c2d08/configs/experiment/multigpu_default.yaml#L16 44 | 45 | # ACKNOWLEDGEMENT 46 | 47 | - This repository is based on [Lightning-Hydra Template](https://github.com/ashleve/lightning-hydra-template) 48 | - Repository of [TFC-TDF-U-Net](https://github.com/ws-choi/ISMIR2020_U_Nets_SVS), our previous ISMIR 2020 paper 49 | - Also, facebook/[demucs](https://github.com/facebookresearch/demucs) 50 | -------------------------------------------------------------------------------- /src/callbacks/onnx_callback.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from typing import Dict, Any 3 | 4 | import torch 5 | from pytorch_lightning import Callback 6 | import pytorch_lightning as pl 7 | import inspect 8 | from src.models.mdxnet import AbstractMDXNet 9 | 10 | 11 | class MakeONNXCallback(Callback): 12 | """Export the model to ONNX every time a checkpoint is saved.""" 13 | 14 | def __init__(self, dirpath: str): 15 | self.dirpath = dirpath 16 | if not os.path.exists(self.dirpath): 17 | os.mkdir(self.dirpath) 18 | 19 | def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', 20 | checkpoint: Dict[str, Any]) -> dict: 21 | res = super().on_save_checkpoint(trainer, pl_module, checkpoint) 22 | 23 | var = inspect.signature(pl_module.__init__).parameters 24 | model = pl_module.__class__(**dict((name, pl_module.__dict__[name]) for name in var)) 25 | model.load_state_dict(pl_module.state_dict()) 26 | 27 | target_dir = '{}epoch_{}'.format(self.dirpath, pl_module.current_epoch) 28 | 29 | try: 30 | if not os.path.exists(target_dir): 31 | os.mkdir(target_dir) 32 | 33 | with torch.no_grad(): 34 | torch.onnx.export(model, 35 | torch.zeros(model.input_sample_shape), 36 | '{}/{}.onnx'.format(target_dir, model.target_name), 37 | export_params=True, # store the trained parameter weights inside the model file 38 | opset_version=13, # the ONNX version to export the model to 39 | do_constant_folding=True, # whether to execute constant folding for optimization 40 | input_names=['input'], # the model's input names 41 | output_names=['output'], # the model's output names 42 | dynamic_axes={'input': {0: 'batch_size'}, # variable length axes 43 | 'output': {0: 'batch_size'}}) 44 | except Exception: 45 | print('onnx export failed') 46 | finally: 47 | del model 48 | 49 | return
res 50 | -------------------------------------------------------------------------------- /src/models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TFC(nn.Module): 6 | def __init__(self, c, l, k): 7 | super(TFC, self).__init__() 8 | 9 | self.H = nn.ModuleList() 10 | for i in range(l): 11 | self.H.append( 12 | nn.Sequential( 13 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2), 14 | nn.BatchNorm2d(c), 15 | nn.ReLU(), 16 | ) 17 | ) 18 | 19 | def forward(self, x): 20 | for h in self.H: 21 | x = h(x) 22 | return x 23 | 24 | 25 | class DenseTFC(nn.Module): 26 | def __init__(self, c, l, k): 27 | super(DenseTFC, self).__init__() 28 | 29 | self.conv = nn.ModuleList() 30 | for i in range(l): 31 | self.conv.append( 32 | nn.Sequential( 33 | nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2), 34 | nn.BatchNorm2d(c), 35 | nn.ReLU(), 36 | ) 37 | ) 38 | 39 | def forward(self, x): 40 | for layer in self.conv[:-1]: 41 | x = torch.cat([layer(x), x], 1) 42 | return self.conv[-1](x) 43 | 44 | 45 | class TFC_TDF(nn.Module): 46 | def __init__(self, c, l, f, k, bn, dense=False, bias=True): 47 | 48 | super(TFC_TDF, self).__init__() 49 | 50 | self.use_tdf = bn is not None 51 | 52 | self.tfc = DenseTFC(c, l, k) if dense else TFC(c, l, k) 53 | 54 | if self.use_tdf: 55 | if bn == 0: 56 | self.tdf = nn.Sequential( 57 | nn.Linear(f, f, bias=bias), 58 | nn.BatchNorm2d(c), 59 | nn.ReLU() 60 | ) 61 | else: 62 | self.tdf = nn.Sequential( 63 | nn.Linear(f, f // bn, bias=bias), 64 | nn.BatchNorm2d(c), 65 | nn.ReLU(), 66 | nn.Linear(f // bn, f, bias=bias), 67 | nn.BatchNorm2d(c), 68 | nn.ReLU() 69 | ) 70 | 71 | def forward(self, x): 72 | x = self.tfc(x) 73 | return x + self.tdf(x) if self.use_tdf else x 74 | 75 | -------------------------------------------------------------------------------- /src/evaluation/separate.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from pathlib import Path 3 | 4 | import torch 5 | import numpy as np 6 | import onnxruntime as ort 7 | 8 | 9 | def separate_with_onnx(batch_size, model, onnx_path: Path, mix): 10 | n_sample = mix.shape[1] 11 | 12 | trim = model.n_fft // 2 13 | gen_size = model.sampling_size - 2 * trim 14 | pad = gen_size - n_sample % gen_size 15 | mix_p = np.concatenate((np.zeros((2, trim)), mix, np.zeros((2, pad)), np.zeros((2, trim))), 1) 16 | 17 | mix_waves = [] 18 | i = 0 19 | while i < n_sample + pad: 20 | waves = np.array(mix_p[:, i:i + model.sampling_size], dtype=np.float32) 21 | mix_waves.append(waves) 22 | i += gen_size 23 | mix_waves_batched = torch.tensor(mix_waves, dtype=torch.float32).split(batch_size) 24 | 25 | tar_signals = [] 26 | 27 | with torch.no_grad(): 28 | _ort = ort.InferenceSession(str(onnx_path)) 29 | for mix_waves in mix_waves_batched: 30 | tar_waves = model.istft(torch.tensor( 31 | _ort.run(None, {'input': model.stft(mix_waves).numpy()})[0] 32 | )) 33 | tar_signals.append(tar_waves[:, :, trim:-trim].transpose(0, 1).reshape(2, -1).numpy()) 34 | tar_signal = np.concatenate(tar_signals, axis=-1)[:, :-pad] 35 | 36 | return tar_signal 37 | 38 | 39 | def separate_with_ckpt(batch_size, model, ckpt_path: Path, mix, device): 40 | model = model.load_from_checkpoint(ckpt_path).to(device) 41 | true_samples = model.sampling_size - 2 * model.trim 42 | 43 | right_pad = true_samples + model.trim - ((mix.shape[-1]) % true_samples) 44 | 
mixture = np.concatenate((np.zeros((2, model.trim), dtype='float32'), 45 | mix, 46 | np.zeros((2, right_pad), dtype='float32')), 47 | 1) 48 | num_chunks = mixture.shape[-1] // true_samples 49 | mix_waves_batched = [mixture[:, i * true_samples: i * true_samples + model.sampling_size] for i in 50 | range(num_chunks)] 51 | mix_waves_batched = torch.tensor(mix_waves_batched, dtype=torch.float32).split(batch_size) 52 | 53 | target_wav_hats = [] 54 | 55 | with torch.no_grad(): 56 | model.eval() 57 | for mixture_wav in mix_waves_batched: 58 | mix_spec = model.stft(mixture_wav.to(device)) 59 | spec_hat = model(mix_spec) 60 | target_wav_hat = model.istft(spec_hat) 61 | target_wav_hat = target_wav_hat.cpu().detach().numpy() 62 | target_wav_hats.append(target_wav_hat) 63 | 64 | target_wav_hat = np.vstack(target_wav_hats)[:, :, model.trim:-model.trim] 65 | target_wav_hat = np.concatenate(target_wav_hat, axis=-1)[:, :mix.shape[-1]] 66 | return target_wav_hat 67 | -------------------------------------------------------------------------------- /src/evaluation/eval.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from pathlib import Path 3 | from typing import Optional, List 4 | 5 | import hydra 6 | import wandb 7 | from omegaconf import DictConfig 8 | from pytorch_lightning import LightningDataModule, LightningModule 9 | from pytorch_lightning.loggers import LightningLoggerBase, WandbLogger 10 | from tqdm import tqdm 11 | import numpy as np 12 | from src.callbacks.wandb_callbacks import get_wandb_logger 13 | from src.evaluation.separate import separate_with_onnx, separate_with_ckpt 14 | from src.utils import utils 15 | from src.utils.utils import load_wav, sdr 16 | 17 | log = utils.get_logger(__name__) 18 | 19 | 20 | def evaluation(config: DictConfig): 21 | 22 | assert config.split in ['train', 'valid', 'test'] 23 | 24 | data_dir = Path(config.get('data_dir')).joinpath(config['split']) 25 | assert data_dir.exists() 26 | 27 | # Init Lightning loggers 28 | loggers: List[LightningLoggerBase] = [] 29 | if "logger" in config: 30 | for _, lg_conf in config.logger.items(): 31 | if "_target_" in lg_conf: 32 | log.info(f"Instantiating logger <{lg_conf._target_}>") 33 | loggers.append(hydra.utils.instantiate(lg_conf)) 34 | 35 | if any([isinstance(l, WandbLogger) for l in loggers]): 36 | utils.wandb_login(key=config.wandb_api_key) 37 | 38 | model = hydra.utils.instantiate(config.model) 39 | target_name = model.target_name 40 | ckpt_path = Path(config.ckpt_dir).joinpath(config.ckpt_path) 41 | 42 | scores = [] 43 | num_tracks = len(listdir(data_dir)) 44 | for i, track in tqdm(enumerate(sorted(listdir(data_dir)))): 45 | track = data_dir.joinpath(track) 46 | mixture = load_wav(track.joinpath('mixture.wav')) 47 | target = load_wav(track.joinpath(target_name + '.wav')) 48 | #target_hat = {source: separate(config['batch_size'], models[source], onnxs[source], mixture) for source in sources} 49 | target_hat = separate_with_ckpt(config.batch_size, model, ckpt_path, mixture, config.device) 50 | score = sdr(target_hat, target) 51 | scores.append(score) 52 | 53 | for logger in loggers: 54 | logger.log_metrics({'sdr': score}, i) 55 | 56 | for wandb_logger in [logger for logger in loggers if isinstance(logger, WandbLogger)]: 57 | mid = mixture.shape[-1] // 2 58 | track = target_hat[:, mid - 44100 * 3:mid + 44100 * 3] 59 | wandb_logger.experiment.log( 60 | {f'track={i}_target={target_name}': [wandb.Audio(track.T, sample_rate=44100)]}) 61 | 62 | for logger in 
loggers: 63 | logger.log_metrics({'mean_sdr_' + target_name: sum(scores)/num_tracks}) 64 | logger.close() 65 | 66 | if any([isinstance(logger, WandbLogger) for logger in loggers]): 67 | wandb.finish() 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Lightning-Hydra-Template 143 | data/ 144 | logs/ 145 | wandb/ 146 | .env 147 | .autoenv 148 | 149 | onnx/* 150 | outputs/* 151 | -------------------------------------------------------------------------------- /tests/smoke/test_commands.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from tests.helpers.run_command import run_command 4 | from tests.helpers.runif import RunIf 5 | 6 | 7 | def test_fast_dev_run(): 8 | """Run 1 train, val, test batch.""" 9 | command = ["run.py", "trainer=default", "trainer.fast_dev_run=true"] 10 | run_command(command) 11 | 12 | 13 | def test_default_cpu(): 14 | """Test default configuration on CPU.""" 15 | command = ["run.py", "trainer.max_epochs=1", "trainer.gpus=0"] 16 | run_command(command) 17 | 18 | 19 | @RunIf(min_gpus=1) 20 | def test_default_gpu(): 21 | """Test default configuration on GPU.""" 22 | command = [ 23 | "run.py", 24 | "trainer.max_epochs=1", 25 | "trainer.gpus=1", 26 | "datamodule.pin_memory=True", 27 | ] 28 | run_command(command) 29 | 30 | 31 | @pytest.mark.slow 32 | def test_experiments(): 33 | """Train 1 epoch with all experiment configs.""" 34 | command = ["run.py", "-m", "experiment=glob(*)", "trainer.max_epochs=1"] 35 | run_command(command) 36 | 37 | 38 | def test_limit_batches(): 39 | """Train 1 epoch on 25% of data.""" 40 | command = [ 41 | "run.py", 42 | "trainer=default", 43 | "trainer.max_epochs=1", 44 | "trainer.limit_train_batches=0.25", 45 | "trainer.limit_val_batches=0.25", 46 | "trainer.limit_test_batches=0.25", 47 | ] 48 | run_command(command) 49 | 50 | 51 | def test_gradient_accumulation(): 52 | """Train 1 epoch with gradient accumulation.""" 53 | command = [ 54 | "run.py", 55 | "trainer=default", 56 | "trainer.max_epochs=1", 57 | "trainer.accumulate_grad_batches=10", 58 | ] 59 | run_command(command) 60 | 61 | 62 | def test_double_validation_loop(): 63 | """Train 1 epoch with validation loop twice per epoch.""" 64 | command = [ 65 | "run.py", 66 | "trainer=default", 67 | "trainer.max_epochs=1", 68 | "trainer.val_check_interval=0.5", 69 | ] 70 | run_command(command) 71 | 72 | 73 | def test_csv_logger(): 74 | """Train 5 epochs with 5 batches with CSVLogger.""" 75 | command = [ 76 | "run.py", 77 | "trainer=default", 78 | "trainer.max_epochs=5", 79 | "trainer.limit_train_batches=5", 80 | "logger=csv", 81 | ] 82 | run_command(command) 83 | 84 | 85 | def test_tensorboard_logger(): 86 | """Train 5 epochs with 5 batches with TensorboardLogger.""" 87 | command = [ 88 | "run.py", 89 | "trainer=default", 90 | "trainer.max_epochs=5", 91 | "trainer.limit_train_batches=5", 92 | "logger=tensorboard", 93 | ] 94 | run_command(command) 95 | 96 | 97 | def 
test_overfit_batches(): 98 | """Overfit to 10 batches over 10 epochs.""" 99 | command = [ 100 | "run.py", 101 | "trainer=default", 102 | "trainer.min_epochs=10", 103 | "trainer.max_epochs=10", 104 | "trainer.overfit_batches=10", 105 | ] 106 | run_command(command) 107 | -------------------------------------------------------------------------------- /src/mdx_kit/evaluator/aicrowd_helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import aicrowd_api 3 | import os 4 | 5 | ######################################################################## 6 | # Instantiate Event Notifier 7 | ######################################################################## 8 | aicrowd_events = aicrowd_api.events.AIcrowdEvents() 9 | 10 | 11 | def execution_start(): 12 | ######################################################################## 13 | # Register Evaluation Start event 14 | ######################################################################## 15 | aicrowd_events.register_event( 16 | event_type=aicrowd_events.AICROWD_EVENT_INFO, 17 | message="execution_started", 18 | payload={ 19 | "event_type": "airborne_detection:execution_started" 20 | } 21 | ) 22 | 23 | def execution_running(): 24 | ######################################################################## 25 | # Register Evaluation Running event 26 | ######################################################################## 27 | aicrowd_events.register_event( 28 | event_type=aicrowd_events.AICROWD_EVENT_INFO, 29 | message="execution_progress", 30 | payload={ 31 | "event_type": "airborne_detection:execution_progress", 32 | "progress": 0.0 33 | } 34 | ) 35 | 36 | 37 | def execution_progress(progress): 38 | ######################################################################## 39 | # Register Evaluation Progress event 40 | ######################################################################## 41 | aicrowd_events.register_event( 42 | event_type=aicrowd_events.AICROWD_EVENT_INFO, 43 | message="execution_progress", 44 | payload={ 45 | "event_type": "airborne_detection:execution_progress", 46 | "progress" : progress 47 | } 48 | ) 49 | 50 | def execution_success(): 51 | ######################################################################## 52 | # Register Evaluation Complete event 53 | ######################################################################## 54 | predictions_output_path = os.getenv("PREDICTIONS_OUTPUT_PATH", False) 55 | 56 | aicrowd_events.register_event( 57 | event_type=aicrowd_events.AICROWD_EVENT_SUCCESS, 58 | message="execution_success", 59 | payload={ 60 | "event_type": "airborne_detection:execution_success", 61 | "predictions_output_path" : predictions_output_path 62 | }, 63 | blocking=True 64 | ) 65 | 66 | def execution_error(error): 67 | ######################################################################## 68 | # Register Evaluation Error event 69 | ######################################################################## 70 | aicrowd_events.register_event( 71 | event_type=aicrowd_events.AICROWD_EVENT_ERROR, 72 | message="execution_error", 73 | payload={ #Arbitrary Payload 74 | "event_type": "airborne_detection:execution_error", 75 | "error" : error 76 | }, 77 | blocking=True 78 | ) 79 | 80 | def is_grading(): 81 | return os.getenv("AICROWD_IS_GRADING", False) 82 | -------------------------------------------------------------------------------- /tests/helpers/runif.py: -------------------------------------------------------------------------------- 1 | 
import sys 2 | from typing import Optional 3 | 4 | import pytest 5 | import torch 6 | from packaging.version import Version 7 | from pkg_resources import get_distribution 8 | 9 | """ 10 | Adapted from: 11 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py 12 | """ 13 | 14 | from tests.helpers.module_available import ( 15 | _APEX_AVAILABLE, 16 | _DEEPSPEED_AVAILABLE, 17 | _FAIRSCALE_AVAILABLE, 18 | _IS_WINDOWS, 19 | _RPC_AVAILABLE, 20 | ) 21 | 22 | 23 | class RunIf: 24 | """ 25 | RunIf wrapper for conditional skipping of tests. 26 | Fully compatible with `@pytest.mark`. 27 | 28 | Example: 29 | 30 | @RunIf(min_torch="1.8") 31 | @pytest.mark.parametrize("arg1", [1.0, 2.0]) 32 | def test_wrapper(arg1): 33 | assert arg1 > 0 34 | 35 | """ 36 | 37 | def __new__( 38 | self, 39 | min_gpus: int = 0, 40 | min_torch: Optional[str] = None, 41 | max_torch: Optional[str] = None, 42 | min_python: Optional[str] = None, 43 | amp_apex: bool = False, 44 | skip_windows: bool = False, 45 | rpc: bool = False, 46 | fairscale: bool = False, 47 | deepspeed: bool = False, 48 | **kwargs, 49 | ): 50 | """ 51 | Args: 52 | min_gpus: min number of gpus required to run test 53 | min_torch: minimum pytorch version to run test 54 | max_torch: maximum pytorch version to run test 55 | min_python: minimum python version required to run test 56 | amp_apex: NVIDIA Apex is installed 57 | skip_windows: skip test for Windows platform 58 | rpc: requires Remote Procedure Call (RPC) 59 | fairscale: if `fairscale` module is required to run the test 60 | deepspeed: if `deepspeed` module is required to run the test 61 | kwargs: native pytest.mark.skipif keyword arguments 62 | """ 63 | conditions = [] 64 | reasons = [] 65 | 66 | if min_gpus: 67 | conditions.append(torch.cuda.device_count() < min_gpus) 68 | reasons.append(f"GPUs>={min_gpus}") 69 | 70 | if min_torch: 71 | torch_version = get_distribution("torch").version 72 | conditions.append(Version(torch_version) < Version(min_torch)) 73 | reasons.append(f"torch>={min_torch}") 74 | 75 | if max_torch: 76 | torch_version = get_distribution("torch").version 77 | conditions.append(Version(torch_version) >= Version(max_torch)) 78 | reasons.append(f"torch<{max_torch}") 79 | 80 | if min_python: 81 | py_version = ( 82 | f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" 83 | ) 84 | conditions.append(Version(py_version) < Version(min_python)) 85 | reasons.append(f"python>={min_python}") 86 | 87 | if amp_apex: 88 | conditions.append(not _APEX_AVAILABLE) 89 | reasons.append("NVIDIA Apex") 90 | 91 | if skip_windows: 92 | conditions.append(_IS_WINDOWS) 93 | reasons.append("does not run on Windows") 94 | 95 | if rpc: 96 | conditions.append(not _RPC_AVAILABLE) 97 | reasons.append("RPC") 98 | 99 | if fairscale: 100 | conditions.append(not _FAIRSCALE_AVAILABLE) 101 | reasons.append("Fairscale") 102 | 103 | if deepspeed: 104 | conditions.append(not _DEEPSPEED_AVAILABLE) 105 | reasons.append("Deepspeed") 106 | 107 | reasons = [rs for cond, rs in zip(conditions, reasons) if cond] 108 | return pytest.mark.skipif( 109 | condition=any(conditions), 110 | reason=f"Requires: [{' + '.join(reasons)}]", 111 | **kwargs, 112 | ) 113 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import hydra 4 | import torch 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import 
( 7 | Callback, 8 | LightningDataModule, 9 | LightningModule, 10 | Trainer, 11 | seed_everything, 12 | ) 13 | from pytorch_lightning.loggers import LightningLoggerBase, WandbLogger 14 | 15 | from src.utils import utils 16 | 17 | log = utils.get_logger(__name__) 18 | 19 | 20 | def train(config: DictConfig) -> Optional[float]: 21 | """Contains training pipeline. 22 | Instantiates all PyTorch Lightning objects from config. 23 | 24 | Args: 25 | config (DictConfig): Configuration composed by Hydra. 26 | 27 | Returns: 28 | Optional[float]: Metric score for hyperparameter optimization. 29 | """ 30 | 31 | # Set seed for random number generators in pytorch, numpy and python.random 32 | try: 33 | if "seed" in config: 34 | seed_everything(config.seed, workers=True) 35 | else: 36 | raise ModuleNotFoundError 37 | 38 | except ModuleNotFoundError: 39 | print('[Error] seed should be fixed for reproducibility \n=> e.g. python run.py +seed=$SEED') 40 | exit(-1) 41 | 42 | # Init Lightning datamodule 43 | log.info(f"Instantiating datamodule <{config.datamodule._target_}>") 44 | datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) 45 | 46 | # Init Lightning model 47 | log.info(f"Instantiating model <{config.model._target_}>") 48 | model: LightningModule = hydra.utils.instantiate(config.model) 49 | 50 | # Init Lightning callbacks 51 | callbacks: List[Callback] = [] 52 | if "callbacks" in config: 53 | for _, cb_conf in config["callbacks"].items(): 54 | if "_target_" in cb_conf: 55 | log.info(f"Instantiating callback <{cb_conf._target_}>") 56 | callbacks.append(hydra.utils.instantiate(cb_conf)) 57 | 58 | # Init Lightning loggers 59 | logger: List[LightningLoggerBase] = [] 60 | if "logger" in config: 61 | for _, lg_conf in config["logger"].items(): 62 | if "_target_" in lg_conf: 63 | log.info(f"Instantiating logger <{lg_conf._target_}>") 64 | logger.append(hydra.utils.instantiate(lg_conf)) 65 | 66 | if any([isinstance(l, WandbLogger) for l in logger]): 67 | utils.wandb_login(key=config.wandb_api_key) 68 | 69 | # Init Lightning trainer 70 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 71 | trainer: Trainer = hydra.utils.instantiate( 72 | config.trainer, callbacks=callbacks, logger=logger, _convert_="partial" 73 | ) 74 | 75 | # Send some parameters from config to all lightning loggers 76 | log.info("Logging hyperparameters!") 77 | utils.log_hyperparameters( 78 | config=config, 79 | model=model, 80 | datamodule=datamodule, 81 | trainer=trainer, 82 | callbacks=callbacks, 83 | logger=logger, 84 | ) 85 | 86 | # Train the model 87 | log.info("Starting training!") 88 | trainer.fit(model=model, datamodule=datamodule) 89 | 90 | # Evaluate model on test set after training 91 | if not config.trainer.get("fast_dev_run"): 92 | log.info("Starting testing!") 93 | trainer.test() 94 | 95 | # Make sure everything closed properly 96 | log.info("Finalizing!") 97 | utils.finish( 98 | config=config, 99 | model=model, 100 | datamodule=datamodule, 101 | trainer=trainer, 102 | callbacks=callbacks, 103 | logger=logger, 104 | ) 105 | 106 | # Print path to best checkpoint 107 | log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}") 108 | 109 | # Return metric score for hyperparameter optimization 110 | optimized_metric = config.get("optimized_metric") 111 | if optimized_metric: 112 | return trainer.callback_metrics[optimized_metric] 113 | -------------------------------------------------------------------------------- /src/utils/data_augmentation.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import subprocess as sp 3 | import tempfile 4 | import warnings 5 | from argparse import ArgumentParser 6 | 7 | import numpy as np 8 | import soundfile as sf 9 | import torch 10 | from tqdm import tqdm 11 | 12 | warnings.simplefilter(action='ignore', category=Warning) 13 | source_names = ['vocals', 'drums', 'bass', 'other'] 14 | sample_rate = 44100 15 | 16 | def main(args): 17 | data_root = args.data_dir 18 | train = args.train 19 | test = args.test 20 | valid = args.valid 21 | 22 | musdb_train_path = data_root + 'train/' 23 | musdb_test_path = data_root + 'test/' 24 | musdb_valid_path = data_root + 'valid/' 25 | 26 | mix_name = 'mixture' 27 | 28 | P = [-3, -2, -1, 0, 1, 2, 3] # pitch shift amounts (in semitones) 29 | T = [-30, -20, -10, 0, 10, 20, 30] # time stretch amounts in percent (10 means tempo at 110%, i.e. 10% faster) 30 | 31 | for p in P: 32 | for t in T: 33 | if not (p==0 and t==0): 34 | if train: 35 | save_shifted_dataset(p, t, musdb_train_path) 36 | if valid: 37 | save_shifted_dataset(p, t, musdb_valid_path) 38 | if test: 39 | save_shifted_dataset(p, t, musdb_test_path) 40 | 41 | 42 | def shift(wav, pitch, tempo, voice=False, quick=False, samplerate=44100): 43 | def i16_pcm(wav): 44 | if wav.dtype == torch.int16: # wav is a torch tensor here, so compare against torch dtypes 45 | return wav 46 | return (wav * 2 ** 15).clamp_(-2 ** 15, 2 ** 15 - 1).short() 47 | 48 | def f32_pcm(wav): 49 | if wav.dtype == torch.float32: # np.float is deprecated and wav is a torch tensor 50 | return wav 51 | return wav.float() / 2 ** 15 52 | 53 | """ 54 | tempo is a relative delta in percentage, so tempo=10 means tempo at 110%! 55 | pitch is in semi tones. 56 | Requires `soundstretch` to be installed, see 57 | https://www.surina.net/soundtouch/soundstretch.html 58 | """ 59 | 60 | inputfile = tempfile.NamedTemporaryFile(suffix=".wav") 61 | outfile = tempfile.NamedTemporaryFile(suffix=".wav") 62 | 63 | sf.write(inputfile.name, data=i16_pcm(wav).t().numpy(), samplerate=samplerate, format='WAV') 64 | command = [ 65 | "soundstretch", 66 | inputfile.name, 67 | outfile.name, 68 | f"-pitch={pitch}", 69 | f"-tempo={tempo:.6f}", 70 | ] 71 | if quick: 72 | command += ["-quick"] 73 | if voice: 74 | command += ["-speech"] 75 | try: 76 | sp.run(command, capture_output=True, check=True) 77 | except sp.CalledProcessError as error: 78 | raise RuntimeError(f"Could not change bpm because {error.stderr.decode('utf-8')}") 79 | wav, sr = sf.read(outfile.name, dtype='float32') 80 | # wav = np.float32(wav) 81 | # wav = f32_pcm(torch.from_numpy(wav).t()) 82 | assert sr == samplerate 83 | return wav 84 | 85 | 86 | def save_shifted_dataset(delta_pitch, delta_tempo, data_path): 87 | out_path = data_path[:-1] + f'_p={delta_pitch}_t={delta_tempo}/' 88 | try: 89 | os.mkdir(out_path) 90 | except FileExistsError: 91 | pass 92 | track_names = list(filter(lambda x: os.path.isdir(f'{data_path}/{x}'), sorted(os.listdir(data_path)))) 93 | for track_name in tqdm(track_names): 94 | try: 95 | os.mkdir(f'{out_path}/{track_name}') 96 | except FileExistsError: 97 | pass 98 | for s_name in source_names: 99 | source = load_wav(f'{data_path}/{track_name}/{s_name}.wav') 100 | shifted = shift( 101 | torch.tensor(source), 102 | delta_pitch, 103 | delta_tempo, 104 | voice=s_name == 'vocals') 105 | sf.write(f'{out_path}/{track_name}/{s_name}.wav', shifted, samplerate=sample_rate, format='WAV') 106 | 107 | 108 | def load_wav(path, sr=None): 109 | return sf.read(path, samplerate=sr, dtype='float32')[0].T 110 | 111 | 112 | if __name__ == '__main__': 113 | parser = ArgumentParser() 114 | 
parser.add_argument('--data_dir', type=str) 115 | parser.add_argument('--train', type=lambda s: s.lower() == 'true', default=True) # argparse's type=bool would parse the string 'False' as True 116 | parser.add_argument('--valid', type=lambda s: s.lower() == 'true', default=False) 117 | parser.add_argument('--test', type=lambda s: s.lower() == 'true', default=False) 118 | 119 | main(parser.parse_args()) -------------------------------------------------------------------------------- /src/datamodules/musdb_datamodule.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import exists, join 3 | from pathlib import Path 4 | from typing import Optional, Tuple 5 | 6 | from pytorch_lightning import LightningDataModule 7 | from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split 8 | 9 | from src.datamodules.datasets.musdb import MusdbTrainDataset, MusdbValidDataset 10 | 11 | 12 | class MusdbDataModule(LightningDataModule): 13 | """ 14 | LightningDataModule for Musdb18-HQ dataset. 15 | A DataModule implements 5 key methods: 16 | - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode) 17 | - setup (things to do on every accelerator in distributed mode) 18 | - train_dataloader (the training dataloader) 19 | - val_dataloader (the validation dataloader(s)) 20 | - test_dataloader (the test dataloader(s)) 21 | This allows you to share a full dataset without explaining how to download, 22 | split, transform and process the data. 23 | Read the docs: 24 | https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html 25 | """ 26 | 27 | def __init__( 28 | self, 29 | data_dir: str, 30 | aug_params, 31 | target_name: str, 32 | overlap: int, 33 | hop_length: int, 34 | dim_t: int, 35 | sample_rate: int, 36 | batch_size: int, 37 | num_workers: int, 38 | pin_memory: bool, 39 | external_datasets, 40 | **kwargs, 41 | ): 42 | super().__init__() 43 | 44 | self.data_dir = Path(data_dir) 45 | self.target_name = target_name 46 | self.aug_params = aug_params 47 | self.external_datasets = external_datasets 48 | 49 | self.batch_size = batch_size 50 | self.num_workers = num_workers 51 | self.pin_memory = pin_memory 52 | 53 | # audio-related 54 | self.hop_length = hop_length 55 | self.sample_rate = sample_rate 56 | 57 | # derived 58 | self.chunk_size = hop_length * (dim_t - 1) 59 | self.overlap = overlap 60 | 61 | self.data_train: Optional[Dataset] = None 62 | self.data_val: Optional[Dataset] = None 63 | self.data_test: Optional[Dataset] = None 64 | 65 | trainset_path = self.data_dir.joinpath('train') 66 | validset_path = self.data_dir.joinpath('valid') 67 | 68 | # create validation split 69 | if not exists(validset_path): 70 | from shutil import move 71 | os.mkdir(validset_path) 72 | for track in kwargs['validation_set']: 73 | if trainset_path.joinpath(track).exists(): 74 | move(trainset_path.joinpath(track), validset_path.joinpath(track)) 75 | else: 76 | valid_files = os.listdir(validset_path) 77 | assert set(valid_files) == set(kwargs['validation_set']) 78 | 79 | def setup(self, stage: Optional[str] = None): 80 | """Load data. 
Set variables: self.data_train, self.data_val, self.data_test.""" 81 | self.data_train = MusdbTrainDataset(self.data_dir, 82 | self.chunk_size, 83 | self.target_name, 84 | self.aug_params, 85 | self.external_datasets) 86 | 87 | self.data_val = MusdbValidDataset(self.data_dir, 88 | self.chunk_size, 89 | self.target_name, 90 | self.overlap, 91 | self.batch_size) 92 | 93 | def train_dataloader(self): 94 | return DataLoader( 95 | dataset=self.data_train, 96 | batch_size=self.batch_size, 97 | num_workers=self.num_workers, 98 | pin_memory=self.pin_memory, 99 | shuffle=True, 100 | ) 101 | 102 | def val_dataloader(self): 103 | return DataLoader( 104 | dataset=self.data_val, 105 | batch_size=1, 106 | num_workers=self.num_workers, 107 | pin_memory=self.pin_memory, 108 | shuffle=False, 109 | ) -------------------------------------------------------------------------------- /src/datamodules/datasets/musdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABCMeta, ABC 3 | from pathlib import Path 4 | 5 | import soundfile 6 | from torch.utils.data import Dataset 7 | import torch 8 | import numpy as np 9 | import random 10 | from tqdm import tqdm 11 | 12 | from src.utils.utils import load_wav 13 | 14 | 15 | def check_target_name(target_name, source_names): 16 | try: 17 | assert target_name is not None 18 | except AssertionError: 19 | print('[ERROR] please specify the target name. ex) +datamodule.target_name="vocals"') 20 | exit(-1) 21 | try: 22 | assert target_name in source_names or target_name == 'all' 23 | except AssertionError: 24 | print('[ERROR] target name should be one of "bass", "drums", "other", "vocals", "all"') 25 | exit(-1) 26 | 27 | 28 | def check_sample_rate(sr, sample_track): 29 | try: 30 | sample_rate = soundfile.read(sample_track)[1] 31 | assert sample_rate == sr 32 | except AssertionError: 33 | sample_rate = soundfile.read(sample_track)[1] 34 | print('[ERROR] sampling rate mismatched') 35 | print('\t=> sr in Config file: {}, but sr of data: {}'.format(sr, sample_rate)) 36 | exit(-1) 37 | 38 | 39 | class MusdbDataset(Dataset): 40 | __metaclass__ = ABCMeta 41 | 42 | def __init__(self, data_dir, chunk_size): 43 | self.source_names = ['bass', 'drums', 'other', 'vocals'] 44 | self.chunk_size = chunk_size 45 | self.musdb_path = Path(data_dir) 46 | 47 | 48 | class MusdbTrainDataset(MusdbDataset): 49 | def __init__(self, data_dir, chunk_size, target_name, aug_params, external_datasets): 50 | super(MusdbTrainDataset, self).__init__(data_dir, chunk_size) 51 | 52 | self.target_name = target_name 53 | check_target_name(self.target_name, self.source_names) 54 | 55 | if not self.musdb_path.joinpath('metadata').exists(): 56 | os.mkdir(self.musdb_path.joinpath('metadata')) 57 | 58 | splits = ['train'] 59 | if external_datasets is not None: 60 | splits += external_datasets 61 | 62 | # collect paths for datasets and metadata (track names and duration) 63 | datasets, metadata_caches = [], [] 64 | raw_datasets = [] # un-augmented datasets 65 | for split in splits: 66 | raw_datasets.append(self.musdb_path.joinpath(split)) 67 | max_pitch, max_tempo = aug_params 68 | for p in range(-max_pitch, max_pitch+1): 69 | for t in range(-max_tempo, max_tempo+1, 10): 70 | aug_split = split if p==t==0 else split + f'_p={p}_t={t}' 71 | datasets.append(self.musdb_path.joinpath(aug_split)) 72 | metadata_caches.append(self.musdb_path.joinpath('metadata').joinpath(aug_split + '.pkl')) 73 | 74 | # collect all track names and their duration 75 | self.metadata = [] 
76 | raw_track_lengths = [] # for calculating epoch size 77 | for i, (dataset, metadata_cache) in enumerate(tqdm(zip(datasets, metadata_caches))): 78 | try: 79 | metadata = torch.load(metadata_cache) 80 | except FileNotFoundError: 81 | print('creating metadata for', dataset) 82 | metadata = [] 83 | for track_name in sorted(os.listdir(dataset)): 84 | track_path = dataset.joinpath(track_name) 85 | track_length = load_wav(track_path.joinpath('vocals.wav')).shape[-1] 86 | metadata.append((track_path, track_length)) 87 | torch.save(metadata, metadata_cache) 88 | 89 | self.metadata += metadata 90 | if dataset in raw_datasets: 91 | raw_track_lengths += [length for path, length in metadata] 92 | 93 | self.epoch_size = sum(raw_track_lengths) // self.chunk_size 94 | 95 | def __getitem__(self, _): 96 | sources = [] 97 | for source_name in self.source_names: 98 | track_path, track_length = random.choice(self.metadata) # random mixing between tracks 99 | source = load_wav(track_path.joinpath(source_name + '.wav'), 100 | track_length=track_length, chunk_size=self.chunk_size) 101 | sources.append(source) 102 | 103 | mix = sum(sources) 104 | 105 | if self.target_name == 'all': 106 | # Targets for models that separate all four sources (ex. Demucs). 107 | # This adds additional 'source' dimension => batch_shape=[batch, source, channel, time] 108 | target = sources 109 | else: 110 | target = sources[self.source_names.index(self.target_name)] 111 | 112 | return torch.tensor(mix), torch.tensor(target) 113 | 114 | def __len__(self): 115 | return self.epoch_size 116 | 117 | 118 | class MusdbValidDataset(MusdbDataset): 119 | 120 | def __init__(self, data_dir, chunk_size, target_name, overlap, batch_size): 121 | super(MusdbValidDataset, self).__init__(data_dir, chunk_size) 122 | 123 | self.target_name = target_name 124 | check_target_name(self.target_name, self.source_names) 125 | 126 | self.overlap = overlap 127 | self.batch_size = batch_size 128 | 129 | musdb_valid_path = self.musdb_path.joinpath('valid') 130 | self.track_paths = [musdb_valid_path.joinpath(track_name) 131 | for track_name in os.listdir(musdb_valid_path)] 132 | 133 | def __getitem__(self, index): 134 | mix = load_wav(self.track_paths[index].joinpath('mixture.wav')) 135 | 136 | if self.target_name == 'all': 137 | # Targets for models that separate all four sources (ex. Demucs). 
138 | # This adds additional 'source' dimension => batch_shape=[batch, source, channel, time] 139 | target = [load_wav(self.track_paths[index].joinpath(source_name + '.wav')) 140 | for source_name in self.source_names] 141 | else: 142 | target = load_wav(self.track_paths[index].joinpath(self.target_name + '.wav')) 143 | 144 | chunk_output_size = self.chunk_size - 2 * self.overlap 145 | left_pad = np.zeros([2, self.overlap]) 146 | right_pad = np.zeros([2, self.overlap + chunk_output_size - (mix.shape[-1] % chunk_output_size)]) 147 | mix_padded = np.concatenate([left_pad, mix, right_pad], 1) 148 | 149 | num_chunks = mix_padded.shape[-1] // chunk_output_size 150 | mix_chunks = [mix_padded[:, i * chunk_output_size: i * chunk_output_size + self.chunk_size] 151 | for i in range(num_chunks)] 152 | mix_chunk_batches = torch.tensor(mix_chunks, dtype=torch.float32).split(self.batch_size) 153 | return mix_chunk_batches, torch.tensor(target) 154 | 155 | def __len__(self): 156 | return len(self.track_paths) -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | import pytorch_lightning as pl 4 | import rich.syntax 5 | import rich.tree 6 | import wandb 7 | import numpy as np 8 | import torch 9 | import warnings 10 | import soundfile as sf 11 | from typing import List, Sequence 12 | from omegaconf import DictConfig, OmegaConf 13 | from pytorch_lightning.loggers.wandb import WandbLogger 14 | from pytorch_lightning.utilities import rank_zero_only 15 | 16 | 17 | def sdr(est, ref): 18 | ratio = np.sum(ref**2) / np.sum((ref-est)**2) 19 | return 10*np.log10(ratio + 1e-10) 20 | 21 | 22 | def load_wav(path, track_length=None, chunk_size=None): 23 | if track_length is None: 24 | return sf.read(path, dtype='float32')[0].T 25 | else: 26 | s = np.random.randint(track_length - chunk_size) 27 | return sf.read(path, dtype='float32', start=s, frames=chunk_size)[0].T 28 | 29 | 30 | def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: 31 | """Initializes multi-GPU-friendly python logger.""" 32 | 33 | logger = logging.getLogger(name) 34 | logger.setLevel(level) 35 | 36 | # this ensures all logging levels get marked with the rank zero decorator 37 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 38 | for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): 39 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 40 | 41 | return logger 42 | 43 | 44 | def extras(config: DictConfig) -> None: 45 | """A couple of optional utilities, controlled by main config file: 46 | - disabling warnings 47 | - easier access to debug mode 48 | - forcing debug friendly configuration 49 | - forcing multi-gpu friendly configuration 50 | 51 | Modifies DictConfig in place. 52 | 53 | Args: 54 | config (DictConfig): Configuration composed by Hydra. 55 | """ 56 | 57 | log = get_logger() 58 | 59 | # enable adding new keys to config 60 | OmegaConf.set_struct(config, False) 61 | 62 | # disable python warnings if 63 | if config.get("ignore_warnings"): 64 | log.info("Disabling python warnings! ") 65 | warnings.filterwarnings("ignore") 66 | 67 | # set if 68 | if config.get("debug"): 69 | log.info("Running in debug mode! 
") 70 | config.trainer.fast_dev_run = True 71 | 72 | # force debugger friendly configuration if 73 | if config.trainer.get("fast_dev_run"): 74 | log.info("Forcing debugger friendly configuration! ") 75 | # Debuggers don't like GPUs or multiprocessing 76 | if config.trainer.get("gpus"): 77 | config.trainer.num_valid_process = 0 78 | if config.datamodule.get("pin_memory"): 79 | config.datamodule.pin_memory = False 80 | if config.datamodule.get("num_workers"): 81 | config.datamodule.num_workers = 0 82 | 83 | # force multi-gpu friendly configuration if 84 | # accelerator = config.trainer.get("accelerator") 85 | # if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]: 86 | # log.info(f"Forcing ddp friendly configuration! ") 87 | # if config.datamodule.get("num_workers"): 88 | # config.datamodule.num_workers = 0 89 | # if config.datamodule.get("pin_memory"): 90 | # config.datamodule.pin_memory = False 91 | 92 | # disable adding new keys to config 93 | OmegaConf.set_struct(config, True) 94 | 95 | 96 | @rank_zero_only 97 | def print_config( 98 | config: DictConfig, 99 | fields: Sequence[str] = ( 100 | "trainer", 101 | "model", 102 | "datamodule", 103 | "callbacks", 104 | "logger", 105 | "seed", 106 | ), 107 | resolve: bool = True, 108 | ) -> None: 109 | """Prints content of DictConfig using Rich library and its tree structure. 110 | 111 | Args: 112 | config (DictConfig): Configuration composed by Hydra. 113 | fields (Sequence[str], optional): Determines which main fields from config will 114 | be printed and in what order. 115 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 116 | """ 117 | 118 | style = "dim" 119 | tree = rich.tree.Tree(":gear: CONFIG", style=style, guide_style=style) 120 | 121 | for field in fields: 122 | branch = tree.add(field, style=style, guide_style=style) 123 | 124 | config_section = config.get(field) 125 | branch_content = str(config_section) 126 | if isinstance(config_section, DictConfig): 127 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 128 | 129 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 130 | 131 | rich.print(tree) 132 | 133 | 134 | def empty(*args, **kwargs): 135 | pass 136 | 137 | 138 | @rank_zero_only 139 | def log_hyperparameters( 140 | config: DictConfig, 141 | model: pl.LightningModule, 142 | datamodule: pl.LightningDataModule, 143 | trainer: pl.Trainer, 144 | callbacks: List[pl.Callback], 145 | logger: List[pl.loggers.LightningLoggerBase], 146 | ) -> None: 147 | """This method controls which parameters from Hydra config are saved by Lightning loggers. 
148 | 149 | Additionally saves: 150 | - number of trainable model parameters 151 | """ 152 | 153 | hparams = {} 154 | 155 | # choose which parts of hydra config will be saved to loggers 156 | hparams["trainer"] = config["trainer"] 157 | hparams["model"] = config["model"] 158 | hparams["datamodule"] = config["datamodule"] 159 | if "callbacks" in config: 160 | hparams["callbacks"] = config["callbacks"] 161 | 162 | # save number of model parameters 163 | hparams["model/params_total"] = sum(p.numel() for p in model.parameters()) 164 | hparams["model/params_trainable"] = sum( 165 | p.numel() for p in model.parameters() if p.requires_grad 166 | ) 167 | hparams["model/params_not_trainable"] = sum( 168 | p.numel() for p in model.parameters() if not p.requires_grad 169 | ) 170 | 171 | # send hparams to all loggers 172 | trainer.logger.log_hyperparams(hparams) 173 | 174 | # disable logging any more hyperparameters for all loggers 175 | # this is just a trick to prevent trainer from logging hparams of model, 176 | # since we already did that above 177 | trainer.logger.log_hyperparams = empty 178 | 179 | 180 | def finish( 181 | config: DictConfig, 182 | model: pl.LightningModule, 183 | datamodule: pl.LightningDataModule, 184 | trainer: pl.Trainer, 185 | callbacks: List[pl.Callback], 186 | logger: List[pl.loggers.LightningLoggerBase], 187 | ) -> None: 188 | """Makes sure everything closed properly.""" 189 | 190 | # without this sweeps with wandb logger might crash! 191 | for lg in logger: 192 | if isinstance(lg, WandbLogger): 193 | wandb.finish() 194 | 195 | 196 | def wandb_login(key): 197 | wandb.login(key=key) -------------------------------------------------------------------------------- /README_SUBMISSION.md: -------------------------------------------------------------------------------- 1 | # Submission 2 | 3 | ## Submission Summary 4 | 5 | ### Leaderboard A 6 | * Submission ID: 151907 7 | * Submitter: kim_min_seok 8 | * Final rank: 2nd place on leaderboard A 9 | * Final scores on MDXDB21: 10 | 11 | | SDR_song | SDR_bass | SDR_drums | SDR_other | SDR_vocals | 12 | | :------: | :------: | :-------: | :-------: | :--------: | 13 | | 7.24 | 7.23 | 7.17 | 5.64 | 8.90 | 14 | 15 | ### Leaderboard B 16 | * Submission ID: 151249 17 | * Submitter: kim_min_seok 18 | * Final rank: 3rd place on leaderboard B 19 | * Final scores on MDXDB21: 20 | 21 | 22 | | SDR_song | SDR_bass | SDR_drums | SDR_other | SDR_vocals | 23 | | :------: | :------: | :-------: | :-------: | :--------: | 24 | | 7.37 | 7.50 | 7.55 | 5.53 | 8.90 | 25 | 26 | 27 | ## Model Summary 28 | 29 | * Data 30 | * We used the MusDB default 86/14 train and validation splits. 31 | * Augmentation 32 | * Random chunking and mixing sources from different tracks ([1]) 33 | * Pitch shift and time stretch ([2]) 34 | * Model 35 | * Blend[1] of two models: a modified version of TFC-TDF[3] and Demucs[4] 36 | * TFC-TDF 37 | * Models were trained separately for each source. 38 | * The input [frequency, time] dimensions are fixed to [2048, 256] for all sources 39 | * 256 frames = 6 seconds of audio (sample_rate=44100, hop_length=1024) 40 | * High frequencies were cut off from the mixture before being input to the networks, and the number of frequency bins to be discarded differs for each source (ex. drums have more high frequencies compared to bass, so cut off more when doing bass separation). In order to fit the frequency dimension of 2048, n_fft differs for each source (see the sketch below). 
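A minimal sketch of this spectrogram front-end (simplified: the repo's `AbstractMDXNet.stft` in `src/models/mdxnet.py` packs real/imaginary parts into channels rather than returning complex tensors, and the per-source `n_fft` values here are placeholders for the ones in `configs/model/ConvTDFNet_*.yaml`):

```
import torch

hop_length, dim_t, dim_f = 1024, 256, 2048
chunk_size = hop_length * (dim_t - 1)   # 261120 samples, ~6 s at 44.1 kHz

def spec_frontend(wave: torch.Tensor, n_fft: int) -> torch.Tensor:
    # wave: [batch, 2, chunk_size] -> complex spectrogram keeping only the lowest dim_f bins
    window = torch.hann_window(n_fft, periodic=True)
    spec = torch.stft(wave.reshape(-1, chunk_size), n_fft=n_fft, hop_length=hop_length,
                      window=window, center=True, return_complex=True)
    return spec[:, :dim_f]              # bins above dim_f (the high frequencies) are cut off

# e.g. a larger n_fft keeps finer frequency resolution inside the 2048 retained bins
bass_spec = spec_frontend(torch.zeros([4, 2, chunk_size]), n_fft=16384)  # placeholder n_fft
```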
41 | * We made the following modifications to the original TFC-TDF model: 42 | * No densely connected convolutional blocks 43 | * Multiplicative skip connections 44 | * Increased depth and number of hidden channels 45 | * After training the per-source models we trained an additional network (which we call the 'Mixer') on top of the model outputs, which takes all four estimated sources as input and outputs better estimated sources 46 | * We only tried a single 1x1 convolution layer for the Mixer (due to inference time limit), but still gained at least 0.1 SDR for every source on the MDX test set. 47 | * Mixer is trained without fine-tuning the separation models. 48 | * Demucs 49 | * We used the pretrained model with 64 initial hidden channels (not demucs48_hq) 50 | * overlap=0.5 and no shift trick 51 | * blending parameters (TFC-TDF : Demucs) => bass 5:5, drums 5:5, other 7:3, vocals 9:1 52 | 53 | [1] S. Uhlich et al., "Improving music source separation based on deep neural networks through data augmentation and network blending," 2017 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2017. 54 | 55 | [2] Cohen-Hadria, Alice, Axel Roebel, and Geoffroy Peeters. "Improving singing voice separation using Deep U-Net and Wave-U-Net with data augmentation." 2019 27th European Signal Processing Conference (EUSIPCO). IEEE, 2019. 56 | 57 | [3] Choi, Woosung, et al. "Investigating u-nets with various intermediate blocks for spectrogram-based singing voice separation." 21st International Society for Music Information Retrieval Conference, ISMIR. 2020. 58 | 59 | [4] Défossez, Alexandre, et al. "Music source separation in the waveform domain." arXiv preprint arXiv:1911.13254 (2019). 60 | 61 | 62 | # Reproduction 63 | 64 | ## How to reproduce the submission 65 | 66 | ***Note***: The inference time is very close to the time limit, so submissions may randomly fail. You might have to submit several times. 67 | 68 | - obtain the ```.onnx``` files and ```.pt``` file as described in the [following section](#how-to-reproduce-the-training) 69 | - follow these instructions to deploy the parameters 70 | ``` 71 | git clone https://github.com/kuielab/mdx-net-submission.git 72 | cd mdx-net-submission 73 | git checkout leaderboard_A 74 | git lfs install 75 | mv ${*.onnx} onnx/ 76 | mv ${*.pt} model/ 77 | ``` 78 | - or visit the following links that hold the pretrained ```.onnx``` files and ```.pt``` file 79 | - [Leaderboard A](https://github.com/kuielab/mdx-net-submission/tree/leaderboard_A) 80 | - [Leaderboard B](https://github.com/kuielab/mdx-net-submission/tree/leaderboard_B) 81 | 82 | - or visit the submitted repository 83 | - [Leaderboard A](https://gitlab.aicrowd.com/kim_min_seok/demix/tree/submission133) 84 | - [Leaderboard B](https://gitlab.aicrowd.com/kim_min_seok/demix/tree/submission106) 85 | 86 | 87 | ## How to reproduce the training 88 | 89 | ### 1. Data Preparation 90 | 91 | Pitch Shift and Time Stretch [2] 92 | - This could have been done on-the-fly along with chunking and mixing ([1]), but we preferred faster train steps over less disk usage. The following scripts are for saving augmented tracks to disk before training. 93 | 94 | - For Leaderboard A 95 | - run ```python src/utils/data_augmentation.py --data_dir ${your_musdb_path} --train True --valid False --test False``` 96 | - For Leaderboard B 97 | - run ```python src/utils/data_augmentation.py --data_dir ${your_musdb_path} --train True --valid True --test True``` 98 | 
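Under the hood, each augmented copy is produced by SoundTouch's `soundstretch` CLI, which `shift()` in `src/utils/data_augmentation.py` invokes once per (pitch, tempo) pair and source file. Roughly, one augmented stem corresponds to a call like the following (illustrative file names; the script writes results to `${split}_p=${p}_t=${t}/<track>/<source>.wav`):

```
# +2 semitones, tempo +10% (i.e. tempo at 110%), as built by shift()
soundstretch input.wav output.wav -pitch=2 -tempo=10.000000
```

99 | ### 2. 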
Phase 1 100 | 101 | - Train ```src.models.mdxnet.ConvTDFNet``` for each source. 102 | - vocals: ```python run.py experiment=multigpu_vocals model=ConvTDFNet_vocals``` 103 | - drums: ```python run.py experiment=multigpu_drums model=ConvTDFNet_drums``` 104 | - bass: ```python run.py experiment=multigpu_bass model=ConvTDFNet_bass``` 105 | - other: ```python run.py experiment=multigpu_other model=ConvTDFNet_other``` 106 | 107 | - Training each model takes at least 3 days, usually 4~5 days until early stopping, with the current configurations. 108 | 109 | - The default logging system is [wandb](https://www.wandb.com/) 110 | ![](val_loss_vocals.png) 111 | 112 | - Checkpoint result saving callbacks 113 | - We use [onnx](https://onnx.ai/) for faster inference to meet the time limit 114 | - see the [related issue](https://github.com/ws-choi/Conditioned-Source-Separation-LaSAFT/issues/20#issuecomment-840407759) 115 | - You don't have to manually convert ```.onnx``` files. Our code automatically generates ```.onnx``` whenever a new checkpoint is saved by [checkpoint callback](https://github.com/kuielab/mdx-net/blob/7c6f7daecde13c0e8ed97f308577f6690b0c31af/configs/callbacks/default.yaml#L2) 116 | ![](onnx_callback.png) 117 | - This function was implemented as a callback function 118 | - see [this](https://github.com/kuielab/mdx-net/blob/7c6f7daecde13c0e8ed97f308577f6690b0c31af/configs/callbacks/default.yaml#L18) 119 | - and [this](https://github.com/kuielab/mdx-net/blob/7c6f7daecde13c0e8ed97f308577f6690b0c31af/src/callbacks/onnx_callback.py#L11) 120 | 121 | #### The epoch of each checkpoint we used 122 | - Leaderboard A 123 | - vocals: 2360 epoch 124 | - bass: 1720 epoch 125 | - drums: 600 epoch 126 | - other: 1720 epoch 127 | 128 | - Leaderboard B 129 | - vocals: 1960 epoch 130 | - bass: 1200 epoch 131 | - drums: 940 epoch 132 | - other: 1660 epoch 133 | 134 | > note: the models were submitted before convergence, and the learning rate might not have been optimal either (ex. for 'other', the Leaderboard A score is higher) 135 | 136 | ### 3. Phase 2 (Optional) 137 | 138 | This phase **does not fine-tune** the pretrained separators from the previous phase. 139 | 140 | - Train Mixer 141 | - point to candidate checkpoints by setting the ```ckpt``` variable in the ```yaml``` config file. 142 | - train ```src.models.mdxnet.Mixer``` 143 | - save the ```.pt``` file, the only learnable parameters in ```Mixer``` (a minimal sketch of the mixing layer is shown below) 144 | 
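A minimal sketch of that mixing layer, mirroring `Mixer` in `src/models/mdxnet.py` (the `nn.Linear` there acts like a 1x1 convolution over time across the stacked waveforms; the reshape details below are illustrative assumptions, since `Mixer.forward` is not shown in this file dump):

```
import torch
import torch.nn as nn

dim_s = 4  # bass, drums, other, vocals
# (dim_s + 1) stereo waveforms in (four estimates plus the original mixture),
# dim_s stereo waveforms out -- exactly the layer Mixer instantiates
mixing_layer = nn.Linear((dim_s + 1) * 2, dim_s * 2, bias=False)

def mixer_forward(est: torch.Tensor, mix: torch.Tensor) -> torch.Tensor:
    # est: [batch, dim_s, 2, time] separator outputs; mix: [batch, 2, time] input mixture
    x = torch.cat([est, mix.unsqueeze(1)], dim=1)   # [batch, dim_s + 1, 2, time]
    b, s, c, t = x.shape
    x = x.reshape(b, s * c, t).transpose(-1, -2)    # [batch, time, (dim_s + 1) * 2]
    x = mixing_layer(x).transpose(-1, -2)           # [batch, dim_s * 2, time]
    return x.reshape(b, dim_s, c, t)                # refined four-source estimates
```

145 | 146 | # License 147 | 148 | [MIT License](LICENSE.MD) 149 | -------------------------------------------------------------------------------- /src/mdx_kit/evaluator/music_demixing.py: -------------------------------------------------------------------------------- 1 | ###################################################################################### 2 | ### This is a read-only file to allow participants to run their code locally. ### 3 | ### It will be overwritten during the evaluation. Please do not make any changes ### 4 | ### to this file. 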
### 5 | ###################################################################################### 6 | 7 | import traceback 8 | import os 9 | import signal 10 | from contextlib import contextmanager 11 | from os import listdir 12 | from os.path import isfile, join 13 | 14 | import soundfile as sf 15 | import numpy as np 16 | from evaluator import aicrowd_helpers 17 | 18 | 19 | class TimeoutException(Exception): pass 20 | 21 | 22 | @contextmanager 23 | def time_limit(seconds): 24 | def signal_handler(signum, frame): 25 | raise TimeoutException("Prediction timed out!") 26 | 27 | signal.signal(signal.SIGALRM, signal_handler) 28 | signal.alarm(seconds) 29 | try: 30 | yield 31 | finally: 32 | signal.alarm(0) 33 | 34 | 35 | class MusicDemixingPredictor: 36 | def __init__(self): 37 | self.test_data_path = os.getenv("TEST_DATASET_PATH", os.getcwd() + "/data/test/") 38 | self.results_data_path = os.getenv("RESULTS_DATASET_PATH", os.getcwd() + "/data/results/") 39 | self.inference_setup_timeout = int(os.getenv("INFERENCE_SETUP_TIMEOUT_SECONDS", "900")) 40 | self.inference_per_music_timeout = int(os.getenv("INFERENCE_PER_MUSIC_TIMEOUT_SECONDS", "240")) 41 | self.partial_run = os.getenv("PARTIAL_RUN_MUSIC_NAMES", None) 42 | self.results = [] 43 | self.current_music_name = None 44 | 45 | def get_all_music_names(self): 46 | valid_music_names = None 47 | if self.partial_run: 48 | valid_music_names = self.partial_run.split(',') 49 | music_names = [] 50 | for folder in listdir(self.test_data_path): 51 | if not isfile(join(self.test_data_path, folder)): 52 | if valid_music_names is None or folder in valid_music_names: 53 | music_names.append(folder) 54 | return music_names 55 | 56 | def get_music_folder_location(self, music_name): 57 | return join(self.test_data_path, music_name) 58 | 59 | def get_music_file_location(self, music_name, instrument=None): 60 | if instrument is None: 61 | instrument = "mixture" 62 | return join(self.test_data_path, music_name, instrument + ".wav") 63 | 64 | if not os.path.exists(self.results_data_path): 65 | os.makedirs(self.results_data_path) 66 | if not os.path.exists(join(self.results_data_path, music_name)): 67 | os.makedirs(join(self.results_data_path, music_name)) 68 | 69 | return join(self.results_data_path, music_name, instrument + ".wav") 70 | 71 | def scoring(self): 72 | """ 73 | Add scoring function in the starter kit for participant's reference 74 | """ 75 | def sdr(references, estimates): 76 | # compute SDR for one song 77 | delta = 1e-7 # avoid numerical errors 78 | num = np.sum(np.square(references), axis=(1, 2)) 79 | den = np.sum(np.square(references - estimates), axis=(1, 2)) 80 | num += delta 81 | den += delta 82 | return 10 * np.log10(num / den) 83 | 84 | music_names = self.get_all_music_names() 85 | instruments = ["bass", "drums", "other", "vocals"] 86 | scores = {} 87 | for music_name in music_names: 88 | print("Evaluating for: %s" % music_name) 89 | scores[music_name] = {} 90 | references = [] 91 | estimates = [] 92 | for instrument in instruments: 93 | reference_file = join(self.test_data_path, music_name, instrument + ".wav") 94 | estimate_file = self.get_music_file_location(music_name, instrument) 95 | reference, _ = sf.read(reference_file) 96 | estimate, _ = sf.read(estimate_file) 97 | references.append(reference) 98 | estimates.append(estimate) 99 | references = np.stack(references) 100 | estimates = np.stack(estimates) 101 | references = references.astype(np.float32) 102 | estimates = estimates.astype(np.float32) 103 | song_score = sdr(references, 
estimates).tolist() 104 | scores[music_name]["sdr_bass"] = song_score[0] 105 | scores[music_name]["sdr_drums"] = song_score[1] 106 | scores[music_name]["sdr_other"] = song_score[2] 107 | scores[music_name]["sdr_vocals"] = song_score[3] 108 | scores[music_name]["sdr"] = np.mean(song_score) 109 | return scores 110 | 111 | 112 | def evaluation(self): 113 | """ 114 | Admin function: Runs the whole evaluation 115 | """ 116 | aicrowd_helpers.execution_start() 117 | try: 118 | with time_limit(self.inference_setup_timeout): 119 | self.prediction_setup() 120 | except NotImplementedError: 121 | print("prediction_setup doesn't exist for this run, skipping...") 122 | 123 | aicrowd_helpers.execution_running() 124 | 125 | music_names = self.get_all_music_names() 126 | 127 | for music_name in music_names: 128 | with time_limit(self.inference_per_music_timeout): 129 | self.prediction(mixture_file_path=self.get_music_file_location(music_name), 130 | bass_file_path=self.get_music_file_location(music_name, "bass"), 131 | drums_file_path=self.get_music_file_location(music_name, "drums"), 132 | other_file_path=self.get_music_file_location(music_name, "other"), 133 | vocals_file_path=self.get_music_file_location(music_name, "vocals"), 134 | ) 135 | 136 | if not self.verify_results(music_name): 137 | raise Exception("verification failed, demixed files not found.") 138 | aicrowd_helpers.execution_success() 139 | 140 | def run(self): 141 | try: 142 | self.evaluation() 143 | except Exception as e: 144 | error = traceback.format_exc() 145 | print(error) 146 | aicrowd_helpers.execution_error(error) 147 | if not aicrowd_helpers.is_grading(): 148 | raise e 149 | 150 | def prediction_setup(self): 151 | """ 152 | You can do any preprocessing required for your codebase here: 153 | like loading your models into memory, etc. 154 | """ 155 | raise NotImplementedError 156 | 157 | def prediction(self, mixture_file_path, bass_file_path, drums_file_path, other_file_path, 158 | vocals_file_path): 159 | """ 160 | This function will be called for each music file during the evaluation. 161 | NOTE: In case you want to load your model, please do so in the `prediction_setup` function. 162 | """ 163 | raise NotImplementedError 164 | 165 | def verify_results(self, music_name): 166 | """ 167 | This function will be called to check that all the files exist and run any other verification needed. 
168 | (like length of the wav files) 169 | """ 170 | valid = True 171 | valid = valid and os.path.isfile(self.get_music_file_location(music_name, "vocals")) 172 | valid = valid and os.path.isfile(self.get_music_file_location(music_name, "bass")) 173 | valid = valid and os.path.isfile(self.get_music_file_location(music_name, "drums")) 174 | valid = valid and os.path.isfile(self.get_music_file_location(music_name, "other")) 175 | return valid 176 | -------------------------------------------------------------------------------- /src/callbacks/wandb_callbacks.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from typing import List, Optional, Any 4 | 5 | import matplotlib.pyplot as plt 6 | import seaborn as sn 7 | import torch 8 | import wandb 9 | from pytorch_lightning import Callback, Trainer 10 | from pytorch_lightning.loggers import LoggerCollection, WandbLogger 11 | from pytorch_lightning.utilities.types import STEP_OUTPUT 12 | from sklearn import metrics 13 | from sklearn.metrics import f1_score, precision_score, recall_score 14 | 15 | 16 | def get_wandb_logger(trainer: Trainer) -> WandbLogger: 17 | """Safely get Weights&Biases logger from Trainer.""" 18 | 19 | if isinstance(trainer.logger, WandbLogger): 20 | return trainer.logger 21 | 22 | if isinstance(trainer.logger, LoggerCollection): 23 | for logger in trainer.logger: 24 | if isinstance(logger, WandbLogger): 25 | return logger 26 | 27 | raise Exception( 28 | "You are using wandb related callback, but WandbLogger was not found for some reason..." 29 | ) 30 | 31 | 32 | class UploadValidTrack(Callback): 33 | def __init__(self, crop: int, upload_after_n_epoch: int): 34 | self.sample_length = crop * 44100 35 | self.upload_after_n_epoch = upload_after_n_epoch 36 | self.len_left_window = self.len_right_window = self.sample_length // 2 37 | 38 | def on_validation_batch_end( 39 | self, 40 | trainer: 'pl.Trainer', 41 | pl_module: 'pl.LightningModule', 42 | outputs: Optional[STEP_OUTPUT], 43 | batch: Any, 44 | batch_idx: int, 45 | dataloader_idx: int, 46 | ) -> None: 47 | if outputs is None: 48 | return 49 | track_id = outputs['track_id'] 50 | track = outputs['track'] 51 | 52 | logger = get_wandb_logger(trainer=trainer) 53 | experiment = logger.experiment 54 | if pl_module.current_epoch < self.upload_after_n_epoch: 55 | return None 56 | 57 | mid = track.shape[-1]//2 58 | track = track[:, mid-self.len_left_window:mid+self.len_right_window] 59 | 60 | experiment.log({'track={}_epoch={}'.format(track_id, pl_module.current_epoch): 61 | [wandb.Audio(track.T, sample_rate=44100)]}) 62 | 63 | 64 | class WatchModel(Callback): 65 | """Make wandb watch model at the beginning of the run.""" 66 | 67 | def __init__(self, log: str = "gradients", log_freq: int = 100): 68 | self.log = log 69 | self.log_freq = log_freq 70 | 71 | def on_train_start(self, trainer, pl_module): 72 | logger = get_wandb_logger(trainer=trainer) 73 | logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) 74 | 75 | 76 | class UploadCodeAsArtifact(Callback): 77 | """Upload all *.py files to wandb as an artifact, at the beginning of the run.""" 78 | 79 | def __init__(self, code_dir: str): 80 | self.code_dir = code_dir 81 | 82 | def on_train_start(self, trainer, pl_module): 83 | logger = get_wandb_logger(trainer=trainer) 84 | experiment = logger.experiment 85 | 86 | code = wandb.Artifact("project-source", type="code") 87 | for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True): 88 | 
code.add_file(path) 89 | 90 | experiment.use_artifact(code) 91 | 92 | 93 | class UploadCheckpointsAsArtifact(Callback): 94 | """Upload checkpoints to wandb as an artifact, at the end of run.""" 95 | 96 | def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): 97 | self.ckpt_dir = ckpt_dir 98 | self.upload_best_only = upload_best_only 99 | 100 | def on_train_end(self, trainer, pl_module): 101 | logger = get_wandb_logger(trainer=trainer) 102 | experiment = logger.experiment 103 | 104 | ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") 105 | 106 | if self.upload_best_only: 107 | ckpts.add_file(trainer.checkpoint_callback.best_model_path) 108 | else: 109 | for path in glob.glob(os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True): 110 | ckpts.add_file(path) 111 | 112 | experiment.use_artifact(ckpts) 113 | 114 | 115 | class LogConfusionMatrix(Callback): 116 | """Generate confusion matrix every epoch and send it to wandb. 117 | Expects validation step to return predictions and targets. 118 | """ 119 | 120 | def __init__(self): 121 | self.preds = [] 122 | self.targets = [] 123 | self.ready = True 124 | 125 | def on_sanity_check_start(self, trainer, pl_module) -> None: 126 | self.ready = False 127 | 128 | def on_sanity_check_end(self, trainer, pl_module): 129 | """Start executing this callback only after all validation sanity checks end.""" 130 | self.ready = True 131 | 132 | def on_validation_batch_end( 133 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 134 | ): 135 | """Gather data from single batch.""" 136 | if self.ready: 137 | self.preds.append(outputs["preds"]) 138 | self.targets.append(outputs["targets"]) 139 | 140 | def on_validation_epoch_end(self, trainer, pl_module): 141 | """Generate confusion matrix.""" 142 | if self.ready: 143 | logger = get_wandb_logger(trainer) 144 | experiment = logger.experiment 145 | 146 | preds = torch.cat(self.preds).cpu().numpy() 147 | targets = torch.cat(self.targets).cpu().numpy() 148 | 149 | confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds) 150 | 151 | # set figure size 152 | plt.figure(figsize=(14, 8)) 153 | 154 | # set labels size 155 | sn.set(font_scale=1.4) 156 | 157 | # set font size 158 | sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g") 159 | 160 | # names should be unique or else charts from different experiments in wandb will overlap 161 | experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False) 162 | 163 | # according to wandb docs this should also work but it crashes 164 | # experiment.log({f"confusion_matrix/{experiment.name}": plt}) 165 | 166 | # reset plot 167 | plt.clf() 168 | 169 | self.preds.clear() 170 | self.targets.clear() 171 | 172 | 173 | class LogF1PrecRecHeatmap(Callback): 174 | """Generate f1, precision, recall heatmap every epoch and send it to wandb. 175 | Expects validation step to return predictions and targets. 
176 | """ 177 | 178 | def __init__(self, class_names: List[str] = None): 179 | self.preds = [] 180 | self.targets = [] 181 | self.ready = True 182 | 183 | def on_sanity_check_start(self, trainer, pl_module): 184 | self.ready = False 185 | 186 | def on_sanity_check_end(self, trainer, pl_module): 187 | """Start executing this callback only after all validation sanity checks end.""" 188 | self.ready = True 189 | 190 | def on_validation_batch_end( 191 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 192 | ): 193 | """Gather data from single batch.""" 194 | if self.ready: 195 | self.preds.append(outputs["preds"]) 196 | self.targets.append(outputs["targets"]) 197 | 198 | def on_validation_epoch_end(self, trainer, pl_module): 199 | """Generate f1, precision and recall heatmap.""" 200 | if self.ready: 201 | logger = get_wandb_logger(trainer=trainer) 202 | experiment = logger.experiment 203 | 204 | preds = torch.cat(self.preds).cpu().numpy() 205 | targets = torch.cat(self.targets).cpu().numpy() 206 | f1 = f1_score(preds, targets, average=None) 207 | r = recall_score(preds, targets, average=None) 208 | p = precision_score(preds, targets, average=None) 209 | data = [f1, p, r] 210 | 211 | # set figure size 212 | plt.figure(figsize=(14, 3)) 213 | 214 | # set labels size 215 | sn.set(font_scale=1.2) 216 | 217 | # set font size 218 | sn.heatmap( 219 | data, 220 | annot=True, 221 | annot_kws={"size": 10}, 222 | fmt=".3f", 223 | yticklabels=["F1", "Precision", "Recall"], 224 | ) 225 | 226 | # names should be uniqe or else charts from different experiments in wandb will overlap 227 | experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False) 228 | 229 | # reset plot 230 | plt.clf() 231 | 232 | self.preds.clear() 233 | self.targets.clear() 234 | 235 | 236 | class LogImagePredictions(Callback): 237 | """Logs a validation batch and their predictions to wandb. 
238 | Example adapted from: 239 | https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY 240 | """ 241 | 242 | def __init__(self, num_samples: int = 8): 243 | super().__init__() 244 | self.num_samples = num_samples 245 | self.ready = True 246 | 247 | def on_sanity_check_start(self, trainer, pl_module): 248 | self.ready = False 249 | 250 | def on_sanity_check_end(self, trainer, pl_module): 251 | """Start executing this callback only after all validation sanity checks end.""" 252 | self.ready = True 253 | 254 | def on_validation_epoch_end(self, trainer, pl_module): 255 | if self.ready: 256 | logger = get_wandb_logger(trainer=trainer) 257 | experiment = logger.experiment 258 | 259 | # get a validation batch from the validation data loader 260 | val_samples = next(iter(trainer.datamodule.val_dataloader())) 261 | val_imgs, val_labels = val_samples 262 | 263 | # run the batch through the network 264 | val_imgs = val_imgs.to(device=pl_module.device) 265 | logits = pl_module(val_imgs) 266 | preds = torch.argmax(logits, axis=-1) 267 | 268 | # log the images as wandb Image 269 | experiment.log( 270 | { 271 | f"Images/{experiment.name}": [ 272 | wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") 273 | for x, pred, y in zip( 274 | val_imgs[: self.num_samples], 275 | preds[: self.num_samples], 276 | val_labels[: self.num_samples], 277 | ) 278 | ] 279 | } 280 | ) 281 | -------------------------------------------------------------------------------- /src/models/mdxnet.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | from omegaconf import OmegaConf 7 | from pytorch_lightning import LightningModule 8 | from pytorch_lightning.utilities.types import STEP_OUTPUT 9 | from torch.nn.functional import mse_loss 10 | 11 | from src.models.modules import TFC_TDF 12 | from src.utils.utils import sdr 13 | 14 | 15 | class AbstractMDXNet(LightningModule): 16 | __metaclass__ = ABCMeta 17 | 18 | def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap, ckpt): 19 | super().__init__() 20 | self.target_name = target_name 21 | self.lr = lr 22 | self.optimizer = optimizer 23 | self.dim_c = dim_c 24 | self.dim_f = dim_f 25 | self.dim_t = dim_t 26 | self.n_fft = n_fft 27 | self.n_bins = n_fft // 2 + 1 28 | self.hop_length = hop_length 29 | 30 | self.chunk_size = hop_length * (self.dim_t - 1) 31 | self.overlap = overlap 32 | self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False) 33 | self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False) 34 | self.input_sample_shape = (self.stft(torch.zeros([1, 2, self.chunk_size]))).shape 35 | 36 | if ckpt is not None: 37 | self.load_from_checkpoint(ckpt) 38 | 39 | def configure_optimizers(self): 40 | if self.optimizer == 'rmsprop': 41 | return torch.optim.RMSprop(self.parameters(), self.lr) 42 | 43 | def training_step(self, *args, **kwargs) -> STEP_OUTPUT: 44 | mix_wave, target_wave = args[0] 45 | mix_spec = self.stft(mix_wave) 46 | 47 | target_wave_hat = self.istft(self(mix_spec)) 48 | loss = mse_loss(target_wave_hat, target_wave) 49 | self.log("train/loss", loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True) 50 | 51 | return {"loss": loss} 52 | 53 | # Validation SDR is calculated on whole tracks and not chunks since 54 | # short inputs 
/src/models/mdxnet.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta
2 | from typing import Optional
3 | 
4 | import torch
5 | import torch.nn as nn
6 | from omegaconf import OmegaConf
7 | from pytorch_lightning import LightningModule
8 | from pytorch_lightning.utilities.types import STEP_OUTPUT
9 | from torch.nn.functional import mse_loss
10 | 
11 | from src.models.modules import TFC_TDF
12 | from src.utils.utils import sdr
13 | 
14 | 
15 | class AbstractMDXNet(LightningModule):
16 |     __metaclass__ = ABCMeta
17 | 
18 |     def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap, ckpt):
19 |         super().__init__()
20 |         self.target_name = target_name
21 |         self.lr = lr
22 |         self.optimizer = optimizer
23 |         self.dim_c = dim_c
24 |         self.dim_f = dim_f
25 |         self.dim_t = dim_t
26 |         self.n_fft = n_fft
27 |         self.n_bins = n_fft // 2 + 1
28 |         self.hop_length = hop_length
29 | 
30 |         self.chunk_size = hop_length * (self.dim_t - 1)
31 |         self.overlap = overlap
32 |         self.window = nn.Parameter(torch.hann_window(window_length=self.n_fft, periodic=True), requires_grad=False)
33 |         self.freq_pad = nn.Parameter(torch.zeros([1, dim_c, self.n_bins - self.dim_f, self.dim_t]), requires_grad=False)
34 |         self.input_sample_shape = self.stft(torch.zeros([1, 2, self.chunk_size])).shape
35 | 
36 |         if ckpt is not None:
37 |             self.load_from_checkpoint(ckpt)
38 | 
39 |     def configure_optimizers(self):
40 |         if self.optimizer == 'rmsprop':
41 |             return torch.optim.RMSprop(self.parameters(), self.lr)
42 | 
43 |     def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
44 |         mix_wave, target_wave = args[0]
45 |         mix_spec = self.stft(mix_wave)
46 | 
47 |         target_wave_hat = self.istft(self(mix_spec))
48 |         loss = mse_loss(target_wave_hat, target_wave)
49 |         self.log("train/loss", loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True)
50 | 
51 |         return {"loss": loss}
52 | 
53 |     # Validation SDR is calculated on whole tracks rather than chunks, since
54 |     # short inputs are likely to be silent (all-zero signals),
55 |     # which leads to very low SDR values regardless of the model.
56 |     # A natural procedure would be to split a track into chunk batches and
57 |     # load them on multiple GPUs, but aggregating the results proved too difficult.
58 |     # So instead we load one whole track on a single device (the data_loader batch_size should always be 1)
59 |     # and do all the batch splitting and aggregation on that device.
60 |     def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
61 |         mix_chunk_batches, target = args[0]
62 | 
63 |         # remove the data_loader batch dimension
64 |         mix_chunk_batches, target = [batch[0] for batch in mix_chunk_batches], target[0]
65 | 
66 |         # process the whole track in batches of chunks
67 |         target_hat_chunks = []
68 |         for batch in mix_chunk_batches:
69 |             mix_spec = self.stft(batch)
70 |             target_hat_chunks.append(self.istft(self(mix_spec))[..., self.overlap:-self.overlap])
71 |         target_hat_chunks = torch.cat(target_hat_chunks)
72 | 
73 |         # stitch the trimmed chunks back into one continuous stereo waveform
74 |         target_hat = target_hat_chunks.transpose(0, 1).reshape(2, -1)[..., :target.shape[-1]]
75 | 
76 |         score = sdr(target_hat.detach().cpu().numpy(), target.cpu().numpy())
77 |         self.log("val/sdr", score, sync_dist=True, on_step=False, on_epoch=True, logger=True)
78 | 
79 |         return {'loss': score}
80 | 
81 |     def stft(self, x):
82 |         x = x.reshape([-1, self.chunk_size])
83 |         x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, center=True)
84 |         x = x.permute([0, 3, 1, 2])
85 |         x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t])  # fold (stereo x re/im) into dim_c channels
86 |         return x[:, :, :self.dim_f]
87 | 
88 |     def istft(self, spec):
89 |         spec = torch.cat([spec, self.freq_pad.repeat([spec.shape[0], 1, 1, 1])], -2)  # pad the cropped bins back to n_bins
90 |         spec = spec.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
91 |         spec = spec.permute([0, 2, 3, 1])
92 |         spec = torch.istft(spec, n_fft=self.n_fft, hop_length=self.hop_length, window=self.window, center=True)
93 |         return spec.reshape([-1, 2, self.chunk_size])
94 | 
95 | 
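The validation comment above boils down to simple bookkeeping: every estimated chunk loses `overlap` samples on each side, and the trimmed cores are concatenated back into a full-length track. A standalone sketch of that arithmetic (the values and the padding scheme are illustrative assumptions, not the project's datamodule code):

```python
import torch
import torch.nn.functional as F

hop_length, dim_t, overlap = 1024, 256, 3072   # illustrative values
chunk_size = hop_length * (dim_t - 1)          # as in AbstractMDXNet
hop = chunk_size - 2 * overlap                 # stride between chunk starts

track = torch.randn(2, 44100 * 10)             # stereo, 10 s at 44.1 kHz

# split into overlapping chunks, zero-padding the track at both ends
padded = F.pad(track, (overlap, chunk_size))
chunks = [padded[:, s:s + chunk_size] for s in range(0, track.shape[-1] + overlap, hop)]

# pretend the separator is the identity; keep only the non-overlapped core of each chunk
cores = [c[..., overlap:-overlap] for c in chunks]

# re-assemble and trim to the original length
estimate = torch.cat(cores, dim=-1)[..., :track.shape[-1]]
assert estimate.shape == track.shape
```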
96 | class ConvTDFNet(AbstractMDXNet):
97 |     def __init__(self, target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length,
98 |                  num_blocks, l, g, k, bn, bias, overlap, ckpt):
99 | 
100 |         super(ConvTDFNet, self).__init__(
101 |             target_name, lr, optimizer, dim_c, dim_f, dim_t, n_fft, hop_length, overlap, ckpt)
102 |         self.save_hyperparameters()
103 | 
104 |         self.num_blocks = num_blocks
105 |         self.l = l
106 |         self.g = g
107 |         self.k = k
108 |         self.bn = bn
109 |         self.bias = bias
110 | 
111 |         self.n = num_blocks // 2
112 |         scale = (2, 2)
113 | 
114 |         self.first_conv = nn.Sequential(
115 |             nn.Conv2d(in_channels=self.dim_c, out_channels=g, kernel_size=(1, 1)),
116 |             nn.BatchNorm2d(g),
117 |             nn.ReLU(),
118 |         )
119 | 
120 |         f = self.dim_f
121 |         c = g
122 |         self.encoding_blocks = nn.ModuleList()
123 |         self.ds = nn.ModuleList()
124 |         for i in range(self.n):
125 |             self.encoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias))
126 |             self.ds.append(
127 |                 nn.Sequential(
128 |                     nn.Conv2d(in_channels=c, out_channels=c + g, kernel_size=scale, stride=scale),
129 |                     nn.BatchNorm2d(c + g),
130 |                     nn.ReLU()
131 |                 )
132 |             )
133 |             f = f // 2
134 |             c += g
135 | 
136 |         self.bottleneck_block = TFC_TDF(c, l, f, k, bn, bias=bias)
137 | 
138 |         self.decoding_blocks = nn.ModuleList()
139 |         self.us = nn.ModuleList()
140 |         for i in range(self.n):
141 |             self.us.append(
142 |                 nn.Sequential(
143 |                     nn.ConvTranspose2d(in_channels=c, out_channels=c - g, kernel_size=scale, stride=scale),
144 |                     nn.BatchNorm2d(c - g),
145 |                     nn.ReLU()
146 |                 )
147 |             )
148 |             f = f * 2
149 |             c -= g
150 | 
151 |             self.decoding_blocks.append(TFC_TDF(c, l, f, k, bn, bias=bias))
152 | 
153 |         self.final_conv = nn.Sequential(
154 |             nn.Conv2d(in_channels=c, out_channels=self.dim_c, kernel_size=(1, 1)),
155 |         )
156 | 
157 |     def forward(self, x):
158 | 
159 |         x = self.first_conv(x)  # 1x1 conv lifts dim_c input channels to g
160 | 
161 |         x = x.transpose(-1, -2)  # swap time and frequency axes
162 | 
163 |         ds_outputs = []
164 |         for i in range(self.n):
165 |             x = self.encoding_blocks[i](x)
166 |             ds_outputs.append(x)
167 |             x = self.ds[i](x)  # 2x2 strided conv: downsample time and frequency
168 | 
169 |         x = self.bottleneck_block(x)
170 | 
171 |         for i in range(self.n):
172 |             x = self.us[i](x)  # transposed conv: upsample time and frequency
173 |             x *= ds_outputs[-i - 1]  # multiplicative skip connection
174 |             x = self.decoding_blocks[i](x)
175 | 
176 |         x = x.transpose(-1, -2)
177 | 
178 |         x = self.final_conv(x)
179 | 
180 |         return x
181 | 
182 | 
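A hypothetical instantiation of ConvTDFNet, to show how stft, forward, and istft fit together. The hyperparameter values below are assumptions (the real ones live in configs/model/*.yaml), and the round trip presumes the torch version this repo pins (pre-2.0 torch.stft API) plus shape-preserving TFC_TDF blocks:

```python
import torch
from src.models.mdxnet import ConvTDFNet

# all values here are illustrative assumptions, not the project's configs
model = ConvTDFNet(
    target_name='vocals', lr=1e-4, optimizer='rmsprop',
    dim_c=4, dim_f=2048, dim_t=256, n_fft=4096, hop_length=1024,
    num_blocks=4, l=3, g=32, k=3, bn=8, bias=False,
    overlap=3072, ckpt=None,
)

wave = torch.zeros(1, 2, model.chunk_size)  # one stereo chunk
spec = model.stft(wave)                     # -> [1, 4, 2048, 256]
est = model.istft(model(spec))              # back to waveform
assert est.shape == wave.shape
```

Note that `dim_f` and `dim_t` must both be divisible by `2 ** (num_blocks // 2)`, since every encoder stage halves both axes with a (2, 2) strided conv.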
183 | class Mixer(LightningModule):
184 |     def __init__(self, separator_configs, separator_ckpts, lr, optimizer, dim_t, hop_length, overlap, target_name='all'):
185 |         super().__init__()
186 |         self.save_hyperparameters()
187 | 
188 |         # Load pretrained separators per source
189 |         self.separators = nn.ModuleDict()
190 |         for ckpt in separator_ckpts.values():
191 |             # if this fails, fill in valid ckpt paths in the given yaml for Mixer training
192 |             assert ckpt is not None
193 | 
194 |         for source in separator_configs.keys():
195 |             model_config = OmegaConf.load(separator_configs[source])
196 |             assert 'ConvTDFNet' in model_config._target_
197 |             separator = ConvTDFNet(**{key: model_config[key] for key in dict(model_config) if key != '_target_'})
198 |             separator.load_state_dict(torch.load(separator_ckpts[source], map_location='cpu')['state_dict'])  # load weights in place (load_from_checkpoint would return a new model)
199 |             self.separators[source] = separator
200 | 
201 |         # Freeze the pretrained separators
202 |         with torch.no_grad():
203 |             for param in self.separators.parameters():
204 |                 param.requires_grad = False
205 | 
206 |         self.lr = lr
207 |         self.optimizer = optimizer
208 | 
209 |         self.chunk_size = hop_length * (dim_t - 1)
210 |         self.overlap = overlap
211 |         self.dim_s = len(separator_configs)
212 |         self.mixing_layer = nn.Linear((self.dim_s + 1) * 2, self.dim_s * 2, bias=False)
213 | 
214 |     def configure_optimizers(self):
215 |         if self.optimizer == 'rmsprop':
216 |             return torch.optim.RMSprop(self.parameters(), self.lr)
217 | 
218 |     def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
219 |         mix_wave, target_waves = args[0]
220 | 
221 |         with torch.no_grad():
222 |             target_wave_hats = []
223 |             for source in ['bass', 'drums', 'other', 'vocals']:
224 |                 S = self.separators[source]
225 |                 target_wave_hat = S.istft(S(S.stft(mix_wave)))
226 |                 target_wave_hats.append(target_wave_hat)  # shape = [source, batch, channel, time]
227 | 
228 |             target_wave_hats = torch.stack(target_wave_hats).transpose(0, 1)
229 | 
230 |         mixer_output = self(torch.cat([target_wave_hats, mix_wave.unsqueeze(1)], 1))
231 | 
232 |         loss = mse_loss(mixer_output[..., self.overlap:-self.overlap],
233 |                         target_waves[..., self.overlap:-self.overlap])
234 |         self.log("train/loss", loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True)
235 | 
236 |         return {"loss": loss}
237 | 
238 |     # data_loader batch_size should always be 1
239 |     def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
240 |         mix_chunk_batches, target_waves = args[0]
241 | 
242 |         # remove the data_loader batch dimension
243 |         mix_chunk_batches, target_waves = [batch[0] for batch in mix_chunk_batches], target_waves[0]
244 | 
245 |         # process the whole track in batches of chunks
246 |         target_hat_chunks = []
247 |         for mix_wave in mix_chunk_batches:
248 |             target_wave_hats = []
249 |             for source in ['bass', 'drums', 'other', 'vocals']:
250 |                 S = self.separators[source]
251 |                 target_wave_hat = S.istft(S(S.stft(mix_wave)))
252 |                 target_wave_hats.append(target_wave_hat)  # shape = [source, batch, channel, time]
253 |             target_wave_hats = torch.stack(target_wave_hats).transpose(0, 1)
254 |             mixer_output = self(torch.cat([target_wave_hats, mix_wave.unsqueeze(1)], 1))
255 |             target_hat_chunks.append(mixer_output[..., self.overlap:-self.overlap])
256 | 
257 |         target_hat_chunks = torch.cat(target_hat_chunks)
258 | 
259 |         # stitch the trimmed chunks back into continuous per-source stereo waveforms
260 |         target_hat = target_hat_chunks.permute(1, 2, 0, 3).reshape(self.dim_s, 2, -1)[..., :target_waves.shape[-1]]
261 | 
262 |         score = sdr(target_hat.detach().cpu().numpy(), target_waves.cpu().numpy())
263 |         self.log("val/sdr", score, sync_dist=True, on_step=False, on_epoch=True, logger=True)
264 | 
265 |         return {'loss': score}
266 | 
267 |     def forward(self, x):
268 |         x = x.reshape(-1, (self.dim_s + 1) * 2, self.chunk_size).transpose(-1, -2)
269 |         x = self.mixing_layer(x)  # per-sample linear mix: (dim_s sources + mix) x stereo -> dim_s x stereo
270 |         return x.transpose(-1, -2).reshape(-1, self.dim_s, 2, self.chunk_size)
271 | 
--------------------------------------------------------------------------------
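The Mixer's only trainable part is that single bias-free linear layer, applied independently at every time sample: it re-weights the four estimated stereo sources plus the stereo mix (10 scalars per time step) into four refined stereo sources (8 scalars per time step). A standalone sketch of the same computation with illustrative shapes:

```python
import torch
import torch.nn as nn

dim_s, chunk_size = 4, 261120  # illustrative values
mixing_layer = nn.Linear((dim_s + 1) * 2, dim_s * 2, bias=False)

# [batch, sources + mix, channel, time]
x = torch.randn(1, dim_s + 1, 2, chunk_size)

# flatten (sources + mix) x stereo into 10 features per time step
x = x.reshape(-1, (dim_s + 1) * 2, chunk_size).transpose(-1, -2)  # [batch, time, 10]
y = mixing_layer(x)                                               # [batch, time, 8]
y = y.transpose(-1, -2).reshape(-1, dim_s, 2, chunk_size)

print(y.shape)  # torch.Size([1, 4, 2, 261120])
```

Because the layer is time-invariant and linear, it amounts to learning a fixed 8x10 mixing matrix that cleans up residual leakage between the frozen separators' outputs.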