├── .gitignore ├── README.md ├── eval.sh ├── eval_general.py ├── local ├── ConvTasNet.yml ├── DPRNNTasNet.yml ├── DPTNet.yml ├── SepFormer2TasNet.yml ├── SepFormerTasNet.yml ├── librimix │ ├── create_local_metadata.py │ └── prepare_data.sh └── wham │ ├── convert_sphere2wav.sh │ ├── prepare_data.sh │ └── preprocess_wham.py ├── perm_general.py ├── prepare_librimix_data.sh ├── prepare_wham_data.sh ├── requirements.txt ├── run.sh ├── scripts ├── run_ConvTasNet_Libri2Mix_sep_clean_from_scratch.sh ├── run_ConvTasNet_Libri2Mix_sep_clean_multi_task.sh ├── run_ConvTasNet_Libri2Mix_sep_clean_pretrained.sh ├── run_ConvTasNet_wsj0-2mix_sep_clean_from_scratch.sh ├── run_ConvTasNet_wsj0-2mix_sep_clean_multi_task.sh ├── run_ConvTasNet_wsj0-2mix_sep_clean_pretrained.sh ├── run_DPRNN_Libri2Mix_enh_single_from_scratch.sh ├── run_DPRNN_wsj0_sep_clean_from_scratch.sh ├── run_DPRNN_wsj0_sep_clean_pretrained.sh ├── run_DPTNet_Libri2Mix_enh_single_from_scratch.sh ├── run_DPTNet_wsj0_sep_clean_from_scratch.sh └── run_DPTNet_wsj0_sep_clean_pretrained.sh ├── src ├── __init__.py ├── data │ ├── __init__.py │ └── utils.py ├── engine │ └── system.py ├── losses │ └── multi_task_wrapper.py ├── masknn │ ├── __init__.py │ └── attention.py └── models │ ├── __init__.py │ └── sepformer_tasnet.py ├── train_general.py └── utils ├── parse_options.sh └── prepare_python_env.sh /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | tools 3 | data 4 | pretrained 5 | exp 6 | logs 7 | 8 | wham_wav 9 | wsj0_wav 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SSL-pretraining-separation 2 | This is the official repository of [Stabilizing Label Assignment for Speech Separation by Self-supervised Pre-training 3 | ](https://arxiv.org/abs/2010.15366). 4 | 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervised-pre-training-reduces-label/speech-separation-on-libri2mix)](https://paperswithcode.com/sota/speech-separation-on-libri2mix?p=self-supervised-pre-training-reduces-label) 6 | 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervised-pre-training-reduces-label/speech-separation-on-wsj0-2mix)](https://paperswithcode.com/sota/speech-separation-on-wsj0-2mix?p=self-supervised-pre-training-reduces-label) 8 | ------------------------------------ 9 | Corpus Preprocessing 10 | ------------------------------------ 11 | ### WHAM! / WSJ0-mix 12 | - Prepare your WSJ0 corpus and place it under `./` 13 | - Run: 14 | ```bash 15 | bash prepare_wham_data.sh 16 | ``` 17 | 18 | ### Libri2Mix 19 | - Run: 20 | ```bash 21 | bash prepare_librimix_data.sh --n_src 2 22 | ``` 23 | 24 | ------------------------------------ 25 | Train 26 | ------------------------------------ 27 | Run `scripts/*.sh` to reproduce the experiments in the paper. 28 | 29 | ### Models 30 | * [x] ConvTasNet 31 | * [x] DPRNNTasNet 32 | * [x] DPTNet 33 | * [x] SepFormerTasNet (my implementation of [SepFormer](https://arxiv.org/pdf/2010.13154.pdf)) 34 | * [x] SepFormer2TasNet (my modification of [SepFormer](https://arxiv.org/pdf/2010.13154.pdf)) 35 | 36 | Note: our SepFormer does not include data augmentation or dynamic mixing, and thus does not perform as well as the official results.
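After training, a quick way to sanity-check a checkpoint is to load it with asteroid and separate a mixture directly. This is a minimal sketch only; the experiment directory and wav filename below are placeholders, not files shipped with this repo:

```python
import soundfile as sf
import torch
from asteroid import ConvTasNet

# Hypothetical paths; substitute your own exp/ directory and mixture file.
model = ConvTasNet.from_pretrained("exp/train_ConvTasNet_wsj0-mix_sep_clean_xxx/best_model.pth")
mix, rate = sf.read("mixture.wav", dtype="float32")
with torch.no_grad():
    est_sources = model(torch.from_numpy(mix)[None])  # (batch=1, n_src, time)
```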
37 | 38 | ------------------------------------ 39 | Reference 40 | ------------------------------------ 41 | The code was adapted from 42 | - [asteroid/egs/librimix/ConvTasNet/](https://github.com/asteroid-team/asteroid/tree/master/egs/librimix/ConvTasNet) 43 | - [asteroid/egs/wham/ConvTasNet/](https://github.com/asteroid-team/asteroid/tree/master/egs/wham/ConvTasNet). 44 | -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | #echo "exp/train_convtasnet_ConvTasNet_Libri2Mix_sep_clean/_ckpt_epoch_99.ckpt"; 2 | #python eval_general.py --use_gpu 1 \ 3 | #--corpus LibriMix --test_dir data/wav8k/min/test --task sep_clean \ 4 | #--exp_dir exp/train_convtasnet_ConvTasNet_Libri2Mix_sep_clean --ckpt_path _ckpt_epoch_99.ckpt --out_dir eval_99_tt; 5 | 6 | #echo "exp/train_convtasnet_ConvTasNet_Libri2Mix_sep_clean_pretrained/_ckpt_epoch_98.ckpt"; 7 | #python eval_general.py --use_gpu 1 \ 8 | #--corpus LibriMix --test_dir data/wav8k/min/test --task sep_clean \ 9 | #--exp_dir exp/train_convtasnet_ConvTasNet_Libri2Mix_sep_clean_pretrained --ckpt_path _ckpt_epoch_98.ckpt --out_dir eval_98_tt; 10 | 11 | #echo "exp/train_convtasnet_ConvTasNet_Libri2Mix_sep_clean_multi_task_enh360/_ckpt_epoch_97.ckpt"; 12 | #python eval_general.py --use_gpu 1 \ 13 | #--corpus LibriMix --test_dir data/wav8k/min/test --task sep_clean \ 14 | #--exp_dir exp/train_convtasnet_ConvTasNet_Libri2Mix_sep_clean_multi_task_enh360 --ckpt_path _ckpt_epoch_97.ckpt --out_dir eval_97_tt; 15 | 16 | #echo "exp/train_convtasnet_ConvTasNet_wsj0-2mix_sep_clean/_ckpt_epoch_98.ckpt"; 17 | #python eval_general.py --use_gpu 1 \ 18 | #--corpus wsj0-mix --test_dir wsj0-mix/2speakers/wav8k/min/tt --task sep_clean \ 19 | #--exp_dir exp/train_convtasnet_ConvTasNet_wsj0-2mix_sep_clean --ckpt_path _ckpt_epoch_98.ckpt --out_dir eval_98_tt; 20 | 21 | #echo "exp/train_convtasnet_ConvTasNet_wsj0-2mix_sep_clean_pretrained_old/_ckpt_epoch_99.ckpt"; 22 | #python eval_general.py --use_gpu 1 \ 23 | #--corpus wsj0-mix --test_dir wsj0-mix/2speakers/wav8k/min/tt --task sep_clean \ 24 | #--exp_dir exp/train_convtasnet_ConvTasNet_wsj0-2mix_sep_clean_pretrained_old --ckpt_path _ckpt_epoch_99.ckpt --out_dir eval_99_tt; 25 | 26 | #echo "exp/train_convtasnet_ConvTasNet_wsj0-2mix_sep_clean_multi_task_enh360/_ckpt_epoch_94.ckpt"; 27 | #python eval_general.py --use_gpu 1 \ 28 | #--corpus wsj0-mix --test_dir wsj0-mix/2speakers/wav8k/min/tt --task sep_clean \ 29 | #--exp_dir exp/train_convtasnet_ConvTasNet_wsj0-2mix_sep_clean_multi_task_enh360 --ckpt_path _ckpt_epoch_94.ckpt --out_dir eval_94_tt; 30 | 31 | #echo "exp/train_dprnn_DPRNNTasNet_wsj0-2mix_sep_clean_new_8gpu/_ckpt_epoch_171.ckpt"; 32 | #python eval_general.py --use_gpu 1 --model DPRNNTasNet \ 33 | #--corpus wsj0-mix --test_dir wsj0-mix/2speakers/wav8k/min/tt --task sep_clean \ 34 | #--exp_dir exp/train_dprnn_DPRNNTasNet_wsj0-2mix_sep_clean_new_8gpu --ckpt_path _ckpt_epoch_171.ckpt --out_dir eval_171_tt; 35 | 36 | #echo "exp/train_dprnn_DPRNNTasNet_wsj0-2mix_sep_clean_new_8gpu_pretrained/_ckpt_epoch_122.ckpt"; 37 | #python eval_general.py --use_gpu 1 --model DPRNNTasNet \ 38 | #--corpus wsj0-mix --test_dir wsj0-mix/2speakers/wav8k/min/tt --task sep_clean \ 39 | #--exp_dir exp/train_dprnn_DPRNNTasNet_wsj0-2mix_sep_clean_new_8gpu_pretrained --ckpt_path _ckpt_epoch_122.ckpt --out_dir eval_122_tt; 40 | 41 | #echo "exp/train_dptnet_DPTNet_wsj0-2mix_sep_clean/_ckpt_epoch_196.ckpt"; 42 | #python 
eval_general.py --use_gpu 1 --model DPTNet \ 43 | #--corpus wsj0-mix --test_dir wsj0-mix/2speakers/wav8k/min/tt --task sep_clean \ 44 | #--exp_dir exp/train_dptnet_DPTNet_wsj0-2mix_sep_clean --ckpt_path _ckpt_epoch_196.ckpt --out_dir eval_196_tt; 45 | #exit 46 | 47 | echo "exp/train_dptnet_DPTNet_wsj0-2mix_sep_clean_8gpu_pretrained/_ckpt_epoch_199.ckpt"; 48 | python eval_general.py --use_gpu 1 --model DPTNet \ 49 | --corpus wsj0-mix --test_dir wsj0-mix/2speakers/wav8k/min/tt --task sep_clean \ 50 | --exp_dir exp/train_dptnet_DPTNet_wsj0-2mix_sep_clean_8gpu_pretrained --ckpt_path _ckpt_epoch_199.ckpt --out_dir eval_199_tt; 51 | -------------------------------------------------------------------------------- /eval_general.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import soundfile as sf 4 | import torch 5 | import yaml 6 | import json 7 | import argparse 8 | import numpy as np 9 | import pandas as pd 10 | from tqdm import tqdm 11 | from pprint import pprint 12 | 13 | import asteroid 14 | from asteroid.metrics import get_metrics 15 | from asteroid.data.librimix_dataset import LibriMix 16 | from asteroid.data.wsj0_mix import Wsj0mixDataset 17 | from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr 18 | from asteroid.models import save_publishable 19 | from asteroid.utils import tensors_to_device 20 | 21 | from src.data import make_test_dataset 22 | from src.models import * 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--corpus", default="LibriMix", choices=["LibriMix", "wsj0-mix"]) 27 | parser.add_argument("--model", default="ConvTasNet", choices=["ConvTasNet", "DPRNNTasNet", "DPTNet"]) 28 | parser.add_argument("--test_dir", type=str, required=True, help="Test directory including the csv files") 29 | parser.add_argument("--task", type=str, default="sep_clean", choices=["sep_clean", "sep_noisy"]) 30 | parser.add_argument("--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution") 31 | parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root") 32 | parser.add_argument("--out_dir", type=str, default="results/best_model", help="Directory in exp_dir where the eval results will be stored") 33 | parser.add_argument("--n_save_ex", type=int, default=10, help="Number of audio examples to save, -1 means all") 34 | 35 | parser.add_argument("--ckpt_path", default="best_model.pth", help="Experiment checkpoint path") 36 | parser.add_argument("--publishable", action="store_true", help="Save publishable.") 37 | 38 | compute_metrics = ["si_sdr", "sdr", "sir", "sar", "stoi"] 39 | 40 | 41 | def main(conf): 42 | model_path = os.path.join(conf["exp_dir"], conf["ckpt_path"]) 43 | 44 | # all resulting files would be saved in eval_save_dir 45 | eval_save_dir = os.path.join(conf["exp_dir"], conf["out_dir"]) 46 | os.makedirs(eval_save_dir, exist_ok=True) 47 | 48 | if not os.path.exists(os.path.join(eval_save_dir, "final_metrics.json")): 49 | if conf["ckpt_path"] == "best_model.pth": 50 | # serialized checkpoint 51 | model = getattr(asteroid, conf["model"]).from_pretrained(model_path) 52 | else: 53 | # non-serialized checkpoint, _ckpt_epoch_{i}.ckpt, keys would start with 54 | # "model.", which need to be removed 55 | model = getattr(asteroid, conf["model"])(**conf["train_conf"]["filterbank"], **conf["train_conf"]["masknet"]) 56 | all_states = torch.load(model_path, map_location="cpu") 57 | state_dict = {k.split('.', 1)[1]: all_states["state_dict"][k] for k in 
all_states["state_dict"]} 58 | model.load_state_dict(state_dict) 59 | # model.load_state_dict(all_states["state_dict"], strict=False) 60 | 61 | # Handle device placement 62 | if conf["use_gpu"]: 63 | model.cuda() 64 | model_device = next(model.parameters()).device 65 | test_set = make_test_dataset( 66 | corpus=conf["corpus"], 67 | test_dir=conf["test_dir"], 68 | task=conf["task"], 69 | sample_rate=conf["sample_rate"], 70 | n_src=conf["train_conf"]["data"]["n_src"], 71 | ) 72 | # Used to reorder sources only 73 | loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx") 74 | 75 | # Randomly choose the indexes of sentences to save. 76 | ex_save_dir = os.path.join(eval_save_dir, "examples/") 77 | if conf["n_save_ex"] == -1: 78 | conf["n_save_ex"] = len(test_set) 79 | save_idx = random.sample(range(len(test_set)), conf["n_save_ex"]) 80 | 81 | series_list = [] 82 | torch.no_grad().__enter__() 83 | for idx in tqdm(range(len(test_set))): 84 | # Forward the network on the mixture. 85 | mix, sources = tensors_to_device(test_set[idx], device=model_device) 86 | est_sources = model(mix.unsqueeze(0)) 87 | 88 | # When running separation inference with a multi-task-trained model, 89 | # exclude the last (enhancement) channel. Does not affect single-task 90 | # models (from_scratch, pre+FT). 91 | est_sources = est_sources[:, :sources.shape[0]] 92 | 93 | loss, reordered_sources = loss_func(est_sources, sources[None], return_est=True) 94 | mix_np = mix.cpu().data.numpy() 95 | sources_np = sources.cpu().data.numpy() 96 | est_sources_np = reordered_sources.squeeze(0).cpu().data.numpy() 97 | # For each utterance, we get a dictionary with the mixture path, 98 | # the input and output metrics 99 | utt_metrics = get_metrics( 100 | mix_np, 101 | sources_np, 102 | est_sources_np, 103 | sample_rate=conf["sample_rate"], 104 | metrics_list=compute_metrics, 105 | ) 106 | if hasattr(test_set, "mixture_path"): 107 | utt_metrics["mix_path"] = test_set.mixture_path 108 | series_list.append(pd.Series(utt_metrics)) 109 | 110 | # Save some examples in a folder. Wav files and metrics as text. 111 | if idx in save_idx: 112 | local_save_dir = os.path.join(ex_save_dir, "ex_{}/".format(idx)) 113 | os.makedirs(local_save_dir, exist_ok=True) 114 | sf.write(local_save_dir + "mixture.wav", mix_np, conf["sample_rate"]) 115 | # Loop over the sources and estimates 116 | for src_idx, src in enumerate(sources_np): 117 | sf.write(local_save_dir + "s{}.wav".format(src_idx), src, conf["sample_rate"]) 118 | for src_idx, est_src in enumerate(est_sources_np): 119 | est_src *= np.max(np.abs(mix_np)) / np.max(np.abs(est_src)) 120 | sf.write( 121 | local_save_dir + "s{}_estimate.wav".format(src_idx), 122 | est_src, 123 | conf["sample_rate"], 124 | ) 125 | # Write local metrics to the example folder. 126 | with open(local_save_dir + "metrics.json", "w") as f: 127 | json.dump(utt_metrics, f, indent=0) 128 | 129 | # Save all metrics to the experiment folder. 
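# (Each "<metric>_imp" value reported below is the mean improvement over the
# input mixture, i.e. mean(metric - input_metric) across all test utterances.)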
130 | all_metrics_df = pd.DataFrame(series_list) 131 | all_metrics_df.to_csv(os.path.join(eval_save_dir, "all_metrics.csv")) 132 | 133 | # Print and save summary metrics 134 | final_results = {} 135 | for metric_name in compute_metrics: 136 | input_metric_name = "input_" + metric_name 137 | ldf = all_metrics_df[metric_name] - all_metrics_df[input_metric_name] 138 | final_results[metric_name] = all_metrics_df[metric_name].mean() 139 | final_results[metric_name + "_imp"] = ldf.mean() 140 | print("Overall metrics :") 141 | pprint(final_results) 142 | with open(os.path.join(eval_save_dir, "final_metrics.json"), "w") as f: 143 | json.dump(final_results, f, indent=0) 144 | else: 145 | with open(os.path.join(eval_save_dir, "final_metrics.json"), "r") as f: 146 | final_results = json.load(f) 147 | 148 | if conf["publishable"]: 149 | assert conf["ckpt_path"] == "best_model.pth" 150 | model_dict = torch.load(model_path, map_location="cpu") 151 | os.makedirs(os.path.join(conf["exp_dir"], "publish_dir"), exist_ok=True) 152 | publishable = save_publishable( 153 | os.path.join(conf["exp_dir"], "publish_dir"), 154 | model_dict, 155 | metrics=final_results, 156 | train_conf=train_conf, 157 | ) 158 | 159 | 160 | if __name__ == "__main__": 161 | args = parser.parse_args() 162 | arg_dic = dict(vars(args)) 163 | 164 | # Load training config 165 | conf_path = os.path.join(args.exp_dir, "conf.yml") 166 | with open(conf_path) as f: 167 | train_conf = yaml.safe_load(f) 168 | arg_dic["sample_rate"] = train_conf["data"]["sample_rate"] 169 | arg_dic["train_conf"] = train_conf 170 | 171 | if args.task != arg_dic["train_conf"]["data"]["task"]: 172 | print( 173 | "Warning : the task used to test is different than " 174 | "the one from training, be sure this is what you want." 175 | ) 176 | 177 | main(arg_dic) 178 | -------------------------------------------------------------------------------- /local/ConvTasNet.yml: -------------------------------------------------------------------------------- 1 | data: 2 | mode: min 3 | n_src: 2 4 | sample_rate: 8000 5 | segment: 3 6 | task: sep_clean 7 | train_dir: data/librimix/wav8k/min/train-100 8 | valid_dir: data/librimix/wav8k/min/dev 9 | filterbank: 10 | kernel_size: 16 11 | n_filters: 512 12 | stride: 8 13 | masknet: 14 | bn_chan: 128 15 | hid_chan: 512 16 | mask_act: relu 17 | n_blocks: 8 18 | n_repeats: 3 19 | skip_chan: 128 20 | optim: 21 | lr: 0.001 22 | optimizer: adam 23 | weight_decay: 0.0 24 | positional arguments: {} 25 | training: 26 | batch_size: 24 27 | early_stop: true 28 | epochs: 100 29 | half_lr: true 30 | num_workers: 8 31 | -------------------------------------------------------------------------------- /local/DPRNNTasNet.yml: -------------------------------------------------------------------------------- 1 | data: 2 | mode: min 3 | n_src: 2 4 | sample_rate: 8000 5 | segment: 2.0 6 | task: sep_clean 7 | train_dir: data/wham/wav8k/min/tr 8 | valid_dir: data/wham/wav8k/min/cv 9 | filterbank: 10 | kernel_size: 2 11 | n_filters: 64 12 | stride: 1 13 | masknet: 14 | bidirectional: true 15 | bn_chan: 128 16 | chunk_size: 250 17 | dropout: 0 18 | hid_size: 128 19 | hop_size: 125 20 | in_chan: 64 21 | mask_act: relu 22 | n_repeats: 6 23 | out_chan: 64 24 | optim: 25 | lr: 0.001 26 | optimizer: adam 27 | weight_decay: 1.0e-05 28 | positional arguments: {} 29 | training: 30 | batch_size: 24 31 | early_stop: true 32 | epochs: 800 33 | gradient_clipping: 5 34 | half_lr: true 35 | num_workers: 8 36 | 
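The `local/*.yml` files collect the same fields that end up in each experiment's `conf.yml` (which `eval_general.py` loads); the `filterbank` and `masknet` sections are unpacked directly into the asteroid model constructor, as in `eval_general.py`. A minimal sketch of that mapping, assuming `n_src` is taken from the `data` section as in the recipes:

```python
import yaml
from asteroid import ConvTasNet

with open("local/ConvTasNet.yml") as f:
    conf = yaml.safe_load(f)

# filterbank -> encoder/decoder settings, masknet -> TCN settings.
model = ConvTasNet(
    n_src=conf["data"]["n_src"],
    **conf["filterbank"],
    **conf["masknet"],
)
```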
-------------------------------------------------------------------------------- /local/DPTNet.yml: -------------------------------------------------------------------------------- 1 | data: 2 | mode: min 3 | n_src: 2 4 | sample_rate: 8000 5 | segment: 4.0 6 | task: sep_clean 7 | train_dir: data/wham/wav8k/min/tr 8 | valid_dir: data/wham/wav8k/min/cv 9 | filterbank: 10 | kernel_size: 2 11 | n_filters: 64 12 | stride: 1 13 | masknet: 14 | bidirectional: true 15 | chunk_size: 250 16 | dropout: 0 17 | ff_activation: relu 18 | ff_hid: 256 19 | hop_size: 125 20 | in_chan: 64 21 | mask_act: relu 22 | n_repeats: 6 23 | norm_type: gLN 24 | out_chan: 64 25 | optim: 26 | lr: 0.001 27 | optimizer: adam 28 | weight_decay: 1.0e-05 29 | positional arguments: {} 30 | scheduler: 31 | d_model: 64 32 | noam_scale: 0.2 33 | steps_per_epoch: 10000 34 | training: 35 | batch_size: 24 36 | early_stop: true 37 | epochs: 200 38 | gradient_clipping: 5 39 | half_lr: true 40 | num_workers: 8 41 | -------------------------------------------------------------------------------- /local/SepFormer2TasNet.yml: -------------------------------------------------------------------------------- 1 | data: 2 | mode: min 3 | n_src: 2 4 | sample_rate: 8000 5 | segment: 4.0 6 | task: sep_clean 7 | train_dir: data/wham/wav8k/min/tr 8 | valid_dir: data/wham/wav8k/min/cv 9 | filterbank: 10 | kernel_size: 16 11 | n_filters: 256 12 | stride: 8 13 | masknet: 14 | chunk_size: 200 15 | dropout: 0 16 | ff_activation: relu 17 | ff_hid: 2048 18 | hop_size: 100 19 | mask_act: relu 20 | n_repeats: 2 21 | k_repeats: 4 22 | norm_type: gLN 23 | n_heads: 16 24 | optim: 25 | lr: 0.001 26 | optimizer: adam 27 | weight_decay: 1.0e-05 28 | positional arguments: {} 29 | scheduler: 30 | d_model: 64 31 | noam_scale: 0.2 32 | steps_per_epoch: 10000 33 | training: 34 | batch_size: 24 35 | early_stop: true 36 | epochs: 200 37 | gradient_clipping: 5 38 | half_lr: true 39 | num_workers: 8 40 | -------------------------------------------------------------------------------- /local/SepFormerTasNet.yml: -------------------------------------------------------------------------------- 1 | data: 2 | mode: min 3 | n_src: 2 4 | sample_rate: 8000 5 | segment: 4.0 6 | task: sep_clean 7 | train_dir: data/wham/wav8k/min/tr 8 | valid_dir: data/wham/wav8k/min/cv 9 | filterbank: 10 | kernel_size: 16 11 | n_filters: 256 12 | stride: 8 13 | masknet: 14 | chunk_size: 200 15 | dropout: 0 16 | ff_activation: relu 17 | ff_hid: 2048 18 | hop_size: 100 19 | mask_act: relu 20 | n_repeats: 2 21 | k_repeats: 4 22 | norm_type: gLN 23 | n_heads: 16 24 | optim: 25 | lr: 0.001 26 | optimizer: adam 27 | weight_decay: 1.0e-05 28 | positional arguments: {} 29 | scheduler: 30 | d_model: 64 31 | noam_scale: 0.2 32 | steps_per_epoch: 10000 33 | training: 34 | batch_size: 24 35 | early_stop: true 36 | epochs: 200 37 | gradient_clipping: 5 38 | half_lr: true 39 | num_workers: 8 40 | -------------------------------------------------------------------------------- /local/librimix/create_local_metadata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | from glob import glob 5 | import pandas as pd 6 | 7 | # Command line arguments 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--librimix_dir", type=str, default=None, help="Path to librispeech root directory" 11 | ) 12 | parser.add_argument( 13 | "--metadata_old_root", type=str, default=None, help="Old root in metadata, specified to change to 
new root" 14 | ) 15 | 16 | 17 | def main(args): 18 | librimix_dir = args.librimix_dir 19 | metadata_old_root = args.metadata_old_root 20 | create_local_metadata(librimix_dir, metadata_old_root) 21 | 22 | 23 | def create_local_metadata(librimix_dir, metadata_old_root): 24 | 25 | corpus = librimix_dir.split("/")[-1] 26 | md_dirs = [f for f in glob(os.path.join(librimix_dir, "*/*/*")) if f.endswith("metadata")] 27 | for md_dir in md_dirs: 28 | md_files = [f for f in os.listdir(md_dir) if f.startswith("mix")] 29 | for md_file in md_files: 30 | print(md_dir, md_file) 31 | subset = md_file.split("_")[1] 32 | local_path = os.path.join( 33 | "data/librimix", os.path.relpath(md_dir, librimix_dir), subset 34 | ).replace("/metadata", "") 35 | os.makedirs(local_path, exist_ok=True) 36 | if metadata_old_root is None: 37 | shutil.copy(os.path.join(md_dir, md_file), local_path) 38 | else: 39 | data = pd.read_csv(os.path.join(md_dir, md_file)) 40 | for key in data.keys(): 41 | if "path" in key: 42 | data[key] = data[key].str.replace(metadata_old_root, librimix_dir) 43 | data.to_csv(os.path.join(local_path, md_file), index=0) 44 | 45 | 46 | if __name__ == "__main__": 47 | args = parser.parse_args() 48 | main(args) 49 | -------------------------------------------------------------------------------- /local/librimix/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Main storage directory. You'll need disk space to store LibriSpeech, WHAM noises 4 | # and LibriMix. This is about 472GB for Libri2Mix and 369GB for Libri3Mix 5 | storage_dir= 6 | n_src= 7 | python_path=python 8 | 9 | . ./utils/parse_options.sh 10 | 11 | current_dir=$(pwd) 12 | # Clone LibriMix repo 13 | git clone https://github.com/JorisCos/LibriMix 14 | 15 | # Run generation script 16 | # Modify generate_librimix.sh if you only want to generate a subset of LibriMix 17 | cd LibriMix 18 | . generate_librimix.sh $storage_dir 19 | 20 | cd $current_dir 21 | $python_path local/librimix/create_local_metadata.py --librimix_dir $storage_dir/Libri$n_src"Mix" 22 | -------------------------------------------------------------------------------- /local/wham/convert_sphere2wav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # MIT Copyright (c) 2018 Kaituo XU 3 | 4 | 5 | sphere_dir=tmp 6 | wav_dir=tmp 7 | 8 | . utils/parse_options.sh || exit 1; 9 | 10 | 11 | echo "Download sph2pipe_v2.5 into egs/tools" 12 | mkdir -p tools 13 | wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P tools 14 | cd tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd - 15 | 16 | echo "Convert sphere format to wav format" 17 | sph2pipe=tools/sph2pipe_v2.5/sph2pipe 18 | 19 | if [ ! -x $sph2pipe ]; then 20 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; 21 | exit 1; 22 | fi 23 | 24 | tmp=data/wham/local/ 25 | mkdir -p $tmp 26 | 27 | [ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list 28 | 29 | if [ ! -d $wav_dir ]; then 30 | while read line; do 31 | wav=`echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir/wsj0 -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'` 32 | echo $wav 33 | mkdir -p `dirname $wav` 34 | $sph2pipe -f wav $line > $wav 35 | done < $tmp/sph.list > $tmp/wav.list 36 | else 37 | echo "Do you already get wav files? 
if not, please remove $wav_dir" 38 | fi 39 | -------------------------------------------------------------------------------- /local/wham/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wav_dir=tmp 4 | out_dir=tmp 5 | python_path=python 6 | 7 | . utils/parse_options.sh 8 | 9 | ## Download WHAM noises 10 | mkdir -p $out_dir 11 | echo "Download WHAM noises into $out_dir" 12 | # If downloading stalls for more than 20s, relaunch from previous state. 13 | wget -c --tries=0 --read-timeout=20 https://storage.googleapis.com/whisper-public/wham_noise.zip -P $out_dir 14 | mkdir -p $out_dir/logs 15 | unzip $out_dir/wham_noise.zip -d $out_dir >> $out_dir/logs/unzip_wham.log 16 | 17 | echo "Download WHAM scripts into $out_dir" 18 | wget https://storage.googleapis.com/whisper-public/wham_scripts.tar.gz -P $out_dir 19 | tar -xzvf $out_dir/wham_scripts.tar.gz -C $out_dir 20 | mv $out_dir/wham_scripts.tar.gz $out_dir/wham_scripts 21 | 22 | wait 23 | 24 | echo "Run python scripts to create the WHAM mixtures" 25 | # Requires : Numpy, Scipy, Pandas, and Pysoundfile 26 | cd $out_dir/wham_scripts 27 | $python_path create_wham_from_scratch.py \ 28 | --wsj0-root $wav_dir \ 29 | --wham-noise-root $out_dir/wham_noise\ 30 | --output-dir $out_dir 31 | cd - 32 | -------------------------------------------------------------------------------- /local/wham/preprocess_wham.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import json 4 | import os 5 | import soundfile as sf 6 | from tqdm import tqdm 7 | 8 | 9 | def preprocess_one_dir(in_dir, out_dir, out_filename): 10 | """ Create .json file for one condition.""" 11 | print(in_dir) 12 | file_infos = [] 13 | in_dir = os.path.abspath(in_dir) 14 | wav_list = os.listdir(in_dir) 15 | wav_list.sort() 16 | for wav_file in tqdm(wav_list): 17 | if not wav_file.endswith(".wav"): 18 | continue 19 | wav_path = os.path.join(in_dir, wav_file) 20 | samples = sf.SoundFile(wav_path) 21 | file_infos.append((wav_path, len(samples))) 22 | if not os.path.exists(out_dir): 23 | os.makedirs(out_dir) 24 | with open(os.path.join(out_dir, out_filename + ".json"), "w") as f: 25 | json.dump(file_infos, f, indent=4) 26 | 27 | 28 | def preprocess(inp_args): 29 | """ Create .json files for all conditions.""" 30 | speaker_list = ["mix_both", "mix_clean", "mix_single", "s1", "s2", "noise"] 31 | for data_type in ["tr", "cv", "tt"]: 32 | for spk in speaker_list: 33 | preprocess_one_dir( 34 | os.path.join(inp_args.in_dir, data_type, spk), 35 | os.path.join(inp_args.out_dir, data_type), 36 | spk, 37 | ) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser("WHAM data preprocessing") 42 | parser.add_argument( 43 | "--in_dir", type=str, default=None, help="Directory path of wham including tr, cv and tt" 44 | ) 45 | parser.add_argument( 46 | "--out_dir", type=str, default=None, help="Directory path to put output files" 47 | ) 48 | args = parser.parse_args() 49 | print(args) 50 | preprocess(args) 51 | -------------------------------------------------------------------------------- /perm_general.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import soundfile as sf 4 | import torch 5 | import yaml 6 | import json 7 | import argparse 8 | import numpy as np 9 | import pandas as pd 10 | from tqdm import tqdm 11 | from pprint import pprint 12 | from itertools import permutations 13 | 14 
| import asteroid 15 | from asteroid.metrics import get_metrics 16 | from asteroid.data.librimix_dataset import LibriMix 17 | from asteroid.data.wsj0_mix import Wsj0mixDataset 18 | from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr 19 | from asteroid.models import save_publishable 20 | from asteroid.utils import tensors_to_device 21 | 22 | from utils import make_test_dataset 23 | 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--corpus", default="LibriMix", choices=["LibriMix", "wsj0-mix"]) 27 | parser.add_argument("--model", default="ConvTasNet", choices=["ConvTasNet", "DPRNNTasNet", "DPTNet"]) 28 | parser.add_argument("--test_dir", type=str, required=True, help="Test directory including the csv files") 29 | parser.add_argument("--task", type=str, default="sep_clean", choices=["sep_clean", "sep_noisy"]) 30 | parser.add_argument("--use_gpu", type=int, default=0, help="Whether to use the GPU for model execution") 31 | parser.add_argument("--exp_dir", default="exp/tmp", help="Experiment root") 32 | parser.add_argument("--out_dir", type=str, required=True, help="Directory in exp_dir where the eval results will be stored") 33 | 34 | parser.add_argument("--ckpt_path", default="best_model.pth", help="Experiment checkpoint path") 35 | 36 | compute_metrics = ["si_sdr", "sdr", "sir", "sar", "stoi"] 37 | 38 | 39 | def main(conf): 40 | perms = list(permutations(range(conf["train_conf"]["data"]["n_src"]))) 41 | 42 | model_path = os.path.join(conf["exp_dir"], conf["ckpt_path"]) 43 | if conf["ckpt_path"] == "best_model.pth": 44 | # serialized checkpoint 45 | model = getattr(asteroid, conf["model"]).from_pretrained(model_path) 46 | else: 47 | # non-serialized checkpoint, _ckpt_epoch_{i}.ckpt, keys would start with 48 | # "model.", which need to be removed 49 | model = getattr(asteroid, conf["model"])(**conf["train_conf"]["filterbank"], **conf["train_conf"]["masknet"]) 50 | all_states = torch.load(model_path, map_location="cpu") 51 | state_dict = {k.split('.', 1)[1]: all_states["state_dict"][k] for k in all_states["state_dict"]} 52 | model.load_state_dict(state_dict) 53 | # model.load_state_dict(all_states["state_dict"], strict=False) 54 | 55 | # Handle device placement 56 | if conf["use_gpu"]: 57 | model.cuda() 58 | model_device = next(model.parameters()).device 59 | test_set = make_test_dataset( 60 | corpus=conf["corpus"], 61 | test_dir=conf["test_dir"], 62 | task=conf["task"], 63 | sample_rate=conf["sample_rate"], 64 | n_src=conf["train_conf"]["data"]["n_src"], 65 | ) 66 | # Used to reorder sources only 67 | loss_func = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx") 68 | 69 | # all resulting files would be saved in eval_save_dir 70 | eval_save_dir = os.path.join(conf["exp_dir"], conf["out_dir"]) 71 | os.makedirs(eval_save_dir, exist_ok=True) 72 | 73 | series_list = [] 74 | torch.no_grad().__enter__() 75 | for idx in tqdm(range(len(test_set))): 76 | # Forward the network on the mixture. 77 | mix, sources = tensors_to_device(test_set[idx], device=model_device) 78 | est_sources = model(mix.unsqueeze(0)) 79 | 80 | # When inferencing separation for multi-task training, 81 | # exclude the last channel. Does not effect single-task training 82 | # models (from_scratch, pre+FT). 
83 | est_sources = est_sources[:, :sources.shape[0]] 84 | _, best_perm_idx = loss_func.find_best_perm(pairwise_neg_sisdr(est_sources, sources[None]), conf["train_conf"]["data"]["n_src"]) 85 | 86 | utt_metrics = {} 87 | if hasattr(test_set, "mixture_path"): 88 | utt_metrics["mix_path"] = test_set.mixture_path 89 | utt_metrics["best_perm_idx"] = ' '.join([str(pidx) for pidx in perms[best_perm_idx[0]]]) 90 | series_list.append(pd.Series(utt_metrics)) 91 | 92 | # Save all metrics to the experiment folder. 93 | all_metrics_df = pd.DataFrame(series_list) 94 | all_metrics_df.to_csv(os.path.join(eval_save_dir, "best_perms.csv")) 95 | 96 | 97 | if __name__ == "__main__": 98 | args = parser.parse_args() 99 | arg_dic = dict(vars(args)) 100 | 101 | # Load training config 102 | conf_path = os.path.join(args.exp_dir, "conf.yml") 103 | with open(conf_path) as f: 104 | train_conf = yaml.safe_load(f) 105 | arg_dic["sample_rate"] = train_conf["data"]["sample_rate"] 106 | arg_dic["train_conf"] = train_conf 107 | 108 | if args.task != arg_dic["train_conf"]["data"]["task"]: 109 | print( 110 | "Warning : the task used to test is different than " 111 | "the one from training, be sure this is what you want." 112 | ) 113 | 114 | main(arg_dic) 115 | -------------------------------------------------------------------------------- /prepare_librimix_data.sh: -------------------------------------------------------------------------------- 1 | local/librimix/prepare_data.sh -------------------------------------------------------------------------------- /prepare_wham_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exit on error 4 | set -e 5 | set -o pipefail 6 | 7 | # Main storage directory. You'll need disk space to dump the WHAM mixtures and the wsj0 wav 8 | # files if you start from sphere files. 9 | storage_dir=$PWD 10 | 11 | # If you start from the sphere files, specify the path to the directory and start from stage 0 12 | sphere_dir=$storage_dir/WSJ0 # Directory containing sphere files 13 | # If you already have wsj0 wav files, specify the path to the directory here and start from stage 1 14 | wsj0_wav_dir=$storage_dir/wsj0_wav 15 | # If you already have the WHAM mixtures, specify the path to the directory here and start from stage 2 16 | wham_wav_dir=$storage_dir/wham_wav 17 | # After running the recipe a first time, you can run it from stage 3 directly to train new models. 18 | 19 | # Path to the python you'll use for the experiment. Defaults to the current python 20 | # You can run ./utils/prepare_python_env.sh to create a suitable python environment, paste the output here. 21 | python_path=python 22 | 23 | # Example usage 24 | # ./prepare_wham_data.sh --stage 0 --sphere_dir --storage_dir $PWD 25 | 26 | # General 27 | stage=0 # Controls from which stage to start 28 | 29 | 30 | . utils/parse_options.sh 31 | 32 | if [[ $stage -le 0 ]]; then 33 | echo "WHAM Stage 0: Converting sphere files to wav files" 34 | . local/wham/convert_sphere2wav.sh --sphere_dir $sphere_dir --wav_dir $wsj0_wav_dir 35 | fi 36 | 37 | if [[ $stage -le 1 ]]; then 38 | echo "WHAM Stage 1: Generating 8k and 16k WHAM dataset" 39 | . 
local/wham/prepare_data.sh --wav_dir $wsj0_wav_dir --out_dir $wham_wav_dir --python_path $python_path 40 | fi 41 | 42 | if [[ $stage -le 2 ]]; then 43 | # Make json directories with min/max modes and sampling rates 44 | echo "WHAM Stage 2: Generating json files including wav path and duration" 45 | for sr_string in 8 16; do 46 | for mode_option in min max; do 47 | tmp_dumpdir=data/wham/wav${sr_string}k/$mode_option 48 | echo "Generating json files in $tmp_dumpdir" 49 | [[ ! -d $tmp_dumpdir ]] && mkdir -p $tmp_dumpdir 50 | local_wham_dir=$wham_wav_dir/wav${sr_string}k/$mode_option/ 51 | $python_path local/wham/preprocess_wham.py --in_dir $local_wham_dir --out_dir $tmp_dumpdir 52 | done 53 | done 54 | fi 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchtext==0.8.0 3 | torchaudio==0.7.0 4 | asteroid 5 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exit on error 4 | set -e 5 | set -o pipefail 6 | 7 | # Path to the python you'll use for the experiment. Defaults to the current python 8 | # You can run ./utils/prepare_python_env.sh to create a suitable python environment, paste the output here. 9 | python_path=python 10 | 11 | # Example usage 12 | # bash ./run.sh --id 0 13 | # bash ./run.sh --id 0 --strategy pretrained --load_path exp/xxx/model.pth 14 | # bash ./run.sh --id 0 --strategy multi_task --enh_set train-360 15 | 16 | # General 17 | stage=1 # Controls from which stage to start 18 | tag="" # Controls the directory name associated to the experiment 19 | # You can ask for several GPUs using id (passed to CUDA_VISIBLE_DEVICES) 20 | id=$CUDA_VISIBLE_DEVICES 21 | corpus=wsj0-mix # wsj0-mix or LibriMix 22 | model=ConvTasNet # The model class 23 | 24 | # Arguments for generating data. For more comments, see prepare_wham_data.sh. 25 | storage_dir= 26 | if [[ $corpus == "wsj0-mix" ]]; then 27 | wham_stage=0 28 | sphere_dir= # Directory containing sphere files 29 | wsj0_wav_dir= 30 | wham_wav_dir= 31 | fi 32 | 33 | # Data 34 | task=sep_clean # one of 'enh_single', 'enh_both', 'sep_clean', 'sep_noisy' 35 | sample_rate=8000 36 | mode=min 37 | n_src=2 38 | segment=4 39 | 40 | # Training config 41 | epochs=100 42 | batch_size=6 # batch size per gpu per step 43 | accumulate_grad_batches=1 # accumulate steps 44 | num_workers=8 45 | half_lr=yes 46 | early_stop=yes 47 | strategy=from_scratch 48 | load_path= 49 | enh_set=train-360 50 | resume=no 51 | comet=yes 52 | comet_exp_key= 53 | resume_ckpt= 54 | 55 | # Optim config 56 | optimizer=adam 57 | lr=0.001 58 | weight_decay=0. 59 | 60 | # Network config 61 | if [[ $model == "ConvTasNet" ]]; then 62 | n_blocks=8 63 | n_repeats=3 64 | mask_nonlinear=relu 65 | else 66 | # Add whatever config you want to modify here, and also modify $train_cmd 67 | # below. 68 | true 69 | fi 70 | 71 | # Data config 72 | train_set=train-100 73 | valid_set=dev 74 | test_set=test 75 | 76 | # Evaluation 77 | eval_use_gpu=1 78 | 79 | 80 | . 
utils/parse_options.sh 81 | 82 | 83 | sr_string=$(($sample_rate/1000)) 84 | suffix=wav${sr_string}k/$mode 85 | 86 | if [[ $corpus == "LibriMix" ]]; then 87 | dumpdir=data/librimix/$suffix # directory to put generated csv file 88 | train_dir=$dumpdir/$train_set 89 | valid_dir=$dumpdir/$valid_set 90 | test_dir=$dumpdir/$test_set 91 | elif [[ $corpus == "wsj0-mix" ]]; then 92 | dumpdir=data/wham/$suffix # directory to put generated json file 93 | train_dir=$dumpdir/tr 94 | valid_dir=$dumpdir/cv 95 | test_dir=$dumpdir/tt 96 | fi 97 | 98 | if [[ $stage -le 0 ]]; then 99 | echo "Stage 0: Generating $corpus dataset" 100 | if [[ $corpus == "LibriMix" ]]; then 101 | . prepare_librimix_data.sh --storage_dir $storage_dir --n_src $n_src 102 | elif [[ $corpus == "wsj0-mix" ]]; then 103 | . prepare_wham_data.sh --stage $wham_stage --storage_dir $storage_dir \ 104 | --sphere_dir $sphere_dir --wsj0_wav_dir $wsj0_wav_dir --wham_wav_dir $wham_wav_dir 105 | fi 106 | fi 107 | 108 | # Generate a random ID for the run if no tag is specified 109 | # May need a more recognizable automatic tag in the future 110 | uuid=$($python_path -c 'import uuid, sys; print(str(uuid.uuid4())[:8])') 111 | if [[ -z ${tag} ]]; then 112 | tag=${uuid} 113 | fi 114 | 115 | expdir=exp/train_${model}_${corpus}_${task}_${strategy}_${tag} 116 | mkdir -p $expdir && echo $uuid >> $expdir/run_uuid.txt 117 | echo "Results from the following experiment will be stored in $expdir" 118 | 119 | 120 | # Remove those you want to use from yaml instead of here 121 | train_cmd="--corpus $corpus --model $model \ 122 | --train_dir $train_dir --valid_dir $valid_dir \ 123 | --task $task --sample_rate $sample_rate --n_src $n_src --segment $segment \ 124 | --epochs $epochs --batch_size $batch_size --accumulate_grad_batches $accumulate_grad_batches \ 125 | --num_workers $num_workers --half_lr $half_lr --early_stop $early_stop \ 126 | --optimizer $optimizer --lr $lr --weight_decay $weight_decay" 127 | 128 | # Training config 129 | if [[ $strategy == "multi_task" && -n $enh_set ]]; then 130 | dumpdir=data/librimix/$suffix # directory to put generated csv file 131 | train_enh_dir=$dumpdir/$enh_set 132 | train_cmd="$train_cmd --strategy $strategy --train_enh_dir $train_enh_dir" 133 | elif [[ $strategy == "pretrained" && -n $load_path ]]; then 134 | train_cmd="$train_cmd --strategy $strategy --load_path $load_path" 135 | fi 136 | 137 | if [[ $comet == "yes" ]]; then 138 | train_cmd="$train_cmd --comet" 139 | fi 140 | if [[ $resume == "yes" ]]; then 141 | train_cmd="$train_cmd --resume" 142 | if [[ -n $resume_ckpt ]]; then 143 | train_cmd="$train_cmd --resume_ckpt $resume_ckpt" 144 | fi 145 | if [[ -n $comet_exp_key ]]; then 146 | train_cmd="$train_cmd --comet_exp_key $comet_exp_key" 147 | fi 148 | fi 149 | 150 | # Network config 151 | if [[ $model == "ConvTasNet" ]]; then 152 | train_cmd="$train_cmd --n_blocks $n_blocks --n_repeats $n_repeats --mask_act $mask_nonlinear" 153 | fi 154 | 155 | 156 | if [[ $stage -le 1 ]]; then 157 | echo "Stage 1: Training" 158 | mkdir -p logs 159 | CUDA_VISIBLE_DEVICES=$id $python_path train_general.py \ 160 | $train_cmd \ 161 | --exp_dir ${expdir}/ | tee logs/train_${tag}.log 162 | cp logs/train_${tag}.log $expdir/train.log 163 | 164 | # Get ready to publish 165 | # NOTE: Publishing from this repo is not recommended, as the recipe_name would 166 | # be confusing. If you wish to upload your pretrained models, please train 167 | # them directly from the official asteroid repo and follow the upload guidelines. 
168 | mkdir -p $expdir/publish_dir 169 | echo "SungFeng-Huang/SSL-pretraining-separation" > $expdir/publish_dir/recipe_name.txt 170 | fi 171 | 172 | if [[ $stage -le 2 ]]; then 173 | echo "Stage 2 : Evaluation" 174 | CUDA_VISIBLE_DEVICES=$id $python_path eval_general.py \ 175 | --corpus $corpus \ 176 | --model $model \ 177 | --test_dir $test_dir \ 178 | --task $task \ 179 | --use_gpu $eval_use_gpu \ 180 | --exp_dir ${expdir} \ 181 | --out_dir results/best_model \ 182 | --ckpt_path best_model.pth \ 183 | --publishable | tee logs/eval_${tag}.log 184 | cp logs/eval_${tag}.log $expdir/eval.log 185 | fi 186 | -------------------------------------------------------------------------------- /scripts/run_ConvTasNet_Libri2Mix_sep_clean_from_scratch.sh: -------------------------------------------------------------------------------- 1 | # 1 GPU (2080Ti), total batch size=24 2 | bash run.sh --id 0 --corpus LibriMix --batch_size 6 --accumulate_grad_batches 4 --segment 3 3 | -------------------------------------------------------------------------------- /scripts/run_ConvTasNet_Libri2Mix_sep_clean_multi_task.sh: -------------------------------------------------------------------------------- 1 | enh_set=train-360 2 | . utils/parse_options.sh 3 | 4 | # 1 GPU (2080Ti), total batch size=24 5 | bash run.sh --id 0 --corpus LibriMix --batch_size 6 --accumulate_grad_batches 4 --segment 3 --strategy multi_task --enh_set $enh_set 6 | -------------------------------------------------------------------------------- /scripts/run_ConvTasNet_Libri2Mix_sep_clean_pretrained.sh: -------------------------------------------------------------------------------- 1 | load_path=brijmohan/ConvTasNet_Libri1Mix_enhsingle 2 | . utils/parse_options.sh 3 | 4 | # 1 GPU (2080Ti), total batch size=24 5 | bash run.sh --id 0 --corpus LibriMix --batch_size 6 --accumulate_grad_batches 4 --segment 3 --strategy pretrained --load_path $load_path 6 | -------------------------------------------------------------------------------- /scripts/run_ConvTasNet_wsj0-2mix_sep_clean_from_scratch.sh: -------------------------------------------------------------------------------- 1 | # 1 GPU (2080Ti), total batch size=24 2 | bash run.sh --id 0 --batch_size 6 --accumulate_grad_batches 4 --segment 3 3 | -------------------------------------------------------------------------------- /scripts/run_ConvTasNet_wsj0-2mix_sep_clean_multi_task.sh: -------------------------------------------------------------------------------- 1 | enh_set=train-360 2 | . utils/parse_options.sh 3 | 4 | # 1 GPU (2080Ti), total batch size=24 5 | bash run.sh --id 0 --batch_size 6 --accumulate_grad_batches 4 --segment 3 --strategy multi_task --enh_set $enh_set 6 | -------------------------------------------------------------------------------- /scripts/run_ConvTasNet_wsj0-2mix_sep_clean_pretrained.sh: -------------------------------------------------------------------------------- 1 | load_path=brijmohan/ConvTasNet_Libri1Mix_enhsingle 2 | . 
utils/parse_options.sh 3 | 4 | # 1 GPU (2080Ti), total batch size=24 5 | bash run.sh --id 0 --batch_size 6 --accumulate_grad_batches 4 --segment 3 --strategy pretrained --load_path $load_path 6 | -------------------------------------------------------------------------------- /scripts/run_DPRNN_Libri2Mix_enh_single_from_scratch.sh: -------------------------------------------------------------------------------- 1 | # batch_size (per GPU) = 3 2 | # 8 GPU (V100) 3 | # accumulate_grad_batches = 1 4 | # total batch size = 3 * 8 * 1 = 24 5 | bash run.sh --id 0,1,2,3,4,5,6,7 --corpus LibriMix --model DPRNNTasNet --batch_size 3 --segment 2 --tag denoise --n_src 1 --task enh_single 6 | -------------------------------------------------------------------------------- /scripts/run_DPRNN_wsj0_sep_clean_from_scratch.sh: -------------------------------------------------------------------------------- 1 | # batch_size (per GPU) = 3 2 | # 8 GPU (V100) 3 | # accumulate_grad_batches = 1 4 | # total batch size = 3 * 8 * 1 = 24 5 | bash run.sh --id 0,1,2,3,4,5,6,7 --model DPRNNTasNet --batch_size 3 --segment 2 6 | -------------------------------------------------------------------------------- /scripts/run_DPRNN_wsj0_sep_clean_pretrained.sh: -------------------------------------------------------------------------------- 1 | load_path=exp/train_DPRNNTasNet_LibriMix_enh_single_from_scratch_denoise/best_model.pth 2 | . utils/parse_options.sh 3 | 4 | # batch_size (per GPU) = 3 5 | # 8 GPU (V100) 6 | # accumulate_grad_batches = 1 7 | # total batch size = 3 * 8 * 1 = 24 8 | bash run.sh --id 0,1,2,3,4,5,6,7 --model DPRNNTasNet --batch_size 3 --segment 2 --strategy pretrained --load_path $load_path 9 | -------------------------------------------------------------------------------- /scripts/run_DPTNet_Libri2Mix_enh_single_from_scratch.sh: -------------------------------------------------------------------------------- 1 | # batch_size (per GPU) = 1 2 | # 8 GPU (V100) 3 | # accumulate_grad_batches = 3 4 | # total batch size = 1 * 8 * 3 = 24 5 | bash run.sh --id 0,1,2,3,4,5,6,7 --corpus LibriMix --model DPTNet --batch_size 1 --accumulate_grad_batches 3 --tag denoise --n_src 1 --task enh_single 6 | -------------------------------------------------------------------------------- /scripts/run_DPTNet_wsj0_sep_clean_from_scratch.sh: -------------------------------------------------------------------------------- 1 | # batch_size (per GPU) = 1 2 | # 8 GPU (V100) 3 | # accumulate_grad_batches = 3 4 | # total batch size = 1 * 8 * 3 = 24 5 | bash run.sh --id 0,1,2,3,4,5,6,7 --model DPTNet --batch_size 1 --accumulate_grad_batches 3 6 | -------------------------------------------------------------------------------- /scripts/run_DPTNet_wsj0_sep_clean_pretrained.sh: -------------------------------------------------------------------------------- 1 | load_path=exp/train_DPTNet_LibriMix_enh_single_from_scratch_denoise/best_model.pth 2 | . 
utils/parse_options.sh 3 | 4 | # batch_size (per GPU) = 1 5 | # 8 GPU (V100) 6 | # accumulate_grad_batches = 3 7 | # total batch size = 1 * 8 * 3 = 24 8 | bash run.sh --id 0,1,2,3,4,5,6,7 --model DPTNet --batch_size 1 --accumulate_grad_batches 3 --strategy pretrained --load_path $load_path 9 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SungFeng-Huang/SSL-pretraining-separation/d7ec4cf6a99f33f38f50b09619b838f51ac456da/src/__init__.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from asteroid.data import LibriMix, WhamDataset 3 | 4 | from .utils import MultiTaskDataLoader 5 | 6 | 7 | def make_dataloaders(corpus, train_dir, val_dir, train_enh_dir=None, task="sep_clean", 8 | sample_rate=8000, n_src=2, segment=4.0, batch_size=4, num_workers=None,): 9 | if corpus == "LibriMix": 10 | train_set = LibriMix(csv_dir=train_dir, task=task, sample_rate=sample_rate, n_src=n_src, segment=segment,) 11 | val_set = LibriMix(csv_dir=val_dir, task=task, sample_rate=sample_rate, n_src=n_src, segment=segment,) 12 | elif corpus == "wsj0-mix": 13 | train_set = WhamDataset(json_dir=train_dir, task=task, sample_rate=sample_rate, nondefault_nsrc=n_src, segment=segment,) 14 | val_set = WhamDataset(json_dir=val_dir, task=task, sample_rate=sample_rate, nondefault_nsrc=n_src, segment=segment,) 15 | 16 | if train_enh_dir is None: 17 | train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size, num_workers=num_workers, drop_last=True,) 18 | else: 19 | train_enh_set = LibriMix(csv_dir=train_enh_dir, task="enh_single", sample_rate=sample_rate, n_src=1, segment=segment,) 20 | train_loader = MultiTaskDataLoader([train_set, train_enh_set], 21 | shuffle=True, batch_size=batch_size, drop_last=True, num_workers=num_workers,) 22 | val_loader = DataLoader(val_set, shuffle=True, batch_size=batch_size, num_workers=num_workers, drop_last=True,) 23 | 24 | infos = train_set.get_infos() 25 | # if train_enh_dir: 26 | # enh_infos = train_enh_set.get_infos() 27 | # for key in enh_infos: 28 | # infos["enh_"+key] = enh_infos[key] 29 | 30 | return train_loader, val_loader, infos 31 | 32 | 33 | def make_test_dataset(corpus, test_dir, task="sep_clean", sample_rate=8000, n_src=2): 34 | if corpus == "LibriMix": 35 | test_set = LibriMix(csv_dir=test_dir, task=task, sample_rate=sample_rate, n_src=n_src, segment=None,) 36 | elif corpus == "wsj0-mix": 37 | test_set = WhamDataset(json_dir=test_dir, task=task, sample_rate=sample_rate, nondefault_nsrc=n_src, segment=None,) 38 | return test_set 39 | -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader, ConcatDataset 2 | from torch.utils.data.sampler import RandomSampler, SequentialSampler, BatchSampler 3 | from torch.utils.data.distributed import DistributedSampler 4 | 5 | 6 | class MultiTaskBatchSampler(BatchSampler): 7 | 8 | def __init__(self, sampler, batch_size, drop_last, cum_thresholds): 9 | super().__init__(sampler, batch_size, drop_last) 10 | self.thresholds = cum_thresholds 11 | self.thres_ranges = list(zip(self.thresholds, self.thresholds[1:])) 12 | 
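# (Each (st, ed) pair in thres_ranges is the half-open index range that one
# task's dataset occupies inside the ConcatDataset: cum_thresholds of
# [0, N1, N1 + N2] give per-task ranges [(0, N1), (N1, N1 + N2)].)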
self.range_lens = [ed - st for st, ed in self.thres_ranges] 13 | 14 | def __iter__(self): 15 | batches = [[] for _ in self.thres_ranges] 16 | for idx in self.sampler: 17 | for range_idx, (st, ed) in enumerate(self.thres_ranges): 18 | if st <= idx < ed: 19 | batches[range_idx].append(idx) 20 | if len(batches[range_idx]) == self.batch_size: 21 | yield batches[range_idx] 22 | batches[range_idx] = [] 23 | for range_idx in range(len(self.thres_ranges)): 24 | if len(batches[range_idx]) > 0 and not self.drop_last: 25 | yield batches[range_idx] 26 | 27 | def __len__(self): 28 | if self.drop_last: 29 | return sum([range_len // self.batch_size for range_len in self.range_lens]) 30 | else: 31 | return sum([(range_len + self.batch_size - 1) // self.batch_size for range_len in self.range_lens]) 32 | 33 | 34 | class DistributedBatchSampler(BatchSampler): 35 | """ `BatchSampler` wrapper that splits each batch across multiple distributed workers. 36 | 37 | Args: 38 | batch_sampler (torch.utils.data.sampler.BatchSampler) 39 | num_replicas (int, optional): Number of processes participating in distributed training. 40 | rank (int, optional): Rank of the current process within num_replicas. 41 | 42 | Example: 43 | >>> from torch.utils.data.sampler import BatchSampler 44 | >>> from torch.utils.data.sampler import SequentialSampler 45 | >>> sampler = SequentialSampler(list(range(12))) 46 | >>> batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False) 47 | >>> 48 | >>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=0)) 49 | [[0, 2], [4, 6], [8, 10]] 50 | >>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=1)) 51 | [[1, 3], [5, 7], [9, 11]] 52 | 53 | Reference: 54 | torchnlp.samplers.distributed_batch_sampler 55 | """ 56 | 57 | def __init__(self, batch_sampler, **kwargs): 58 | self.batch_sampler = batch_sampler 59 | self.kwargs = kwargs 60 | 61 | def __iter__(self): 62 | for batch in self.batch_sampler: 63 | yield [batch[i] for i in DistributedSampler(batch, shuffle=False, **self.kwargs)]  # torch's DistributedSampler yields positions within `batch`; map them back to dataset indices (shuffle=False keeps the documented order) 64 | 65 | def __len__(self): 66 | return len(self.batch_sampler) 67 | 68 | 69 | def MultiTaskDataLoader(data_sources, shuffle, batch_size, drop_last, generator=None, **kwargs): 70 | dataset = ConcatDataset(data_sources) 71 | cum_thresholds = [0] 72 | for data_source in data_sources: 73 | cum_thresholds.append(cum_thresholds[-1] + len(data_source)) 74 | if shuffle: 75 | sampler = RandomSampler(dataset, generator=generator) 76 | else: 77 | sampler = SequentialSampler(dataset) 78 | batch_sampler = MultiTaskBatchSampler(sampler, batch_size=batch_size, drop_last=drop_last, cum_thresholds=cum_thresholds) 79 | batch_sampler = DistributedBatchSampler(batch_sampler) 80 | return DataLoader(dataset, batch_sampler=batch_sampler, **kwargs) 81 | -------------------------------------------------------------------------------- /src/engine/system.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from asteroid.engine.system import System 3 | from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr 4 | 5 | 6 | class GeneralSystem(System): 7 | def __init__( 8 | self, 9 | model, 10 | optimizer, 11 | loss_func, 12 | train_loader=None, 13 | val_loader=None, 14 | scheduler=None, 15 | config=None, 16 | ): 17 | super().__init__(model, optimizer, loss_func, train_loader, val_loader, scheduler, config) 18 | 19 | # Load from checkpoint if provided. 
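# (e.g. a Lightning key "model.masker.mask_net.1.weight" is renamed to
# "masker.mask_net.1.weight"; the n_src-dependent output head is then dropped
# so that a 1-src enhancement checkpoint can initialize an n_src separation
# model through the strict=False load below.)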
20 | if self.config["main_args"].get("load_path") is not None: 21 | all_states = torch.load(self.config["main_args"]["load_path"], map_location="cpu") 22 | assert "state_dict" in all_states 23 | 24 | # If the checkpoint is not the serialized "best_model.pth", its keys 25 | # start with "model.", which must be removed; otherwise none of the 26 | # parameters would be loaded. 27 | for key in list(all_states["state_dict"].keys()): 28 | if key.startswith("model"): 29 | print(f"key {key} changed to {key.split('.', 1)[1]}") 30 | all_states["state_dict"][key.split('.', 1)[1]] = all_states["state_dict"][key] 31 | del all_states["state_dict"][key] 32 | 33 | # For debugging, set strict=True to check whether only the following 34 | # parameters have different sizes (since n_src=1 for pre-training 35 | # and n_src=2 for fine-tuning): 36 | # for ConvTasNet: "masker.mask_net.1.*" 37 | # for DPRNNTasNet/DPTNet: "masker.first_out.1.*" 38 | if self.config["main_args"]["model"] == "ConvTasNet": 39 | print(f"key masker.mask_net.1.* removed") 40 | del all_states["state_dict"]["masker.mask_net.1.weight"] 41 | del all_states["state_dict"]["masker.mask_net.1.bias"] 42 | elif self.config["main_args"]["model"] in ["DPRNNTasNet", "DPTNet", "SepFormerTasNet", "SepFormer2TasNet"]: 43 | print(f"key masker.first_out.1.* removed") 44 | del all_states["state_dict"]["masker.first_out.1.weight"] 45 | del all_states["state_dict"]["masker.first_out.1.bias"] 46 | self.model.load_state_dict(all_states["state_dict"], strict=False) 47 | -------------------------------------------------------------------------------- /src/losses/multi_task_wrapper.py: -------------------------------------------------------------------------------- 1 | from asteroid.losses import PITLossWrapper 2 | 3 | 4 | class MultiTaskLossWrapper(PITLossWrapper): 5 | """ n_src separation + 1_src enhancement 6 | """ 7 | def __init__(self, loss_func, pit_from="pw_mtx", perm_reduce=None): 8 | super().__init__(loss_func, pit_from=pit_from, perm_reduce=perm_reduce) 9 | 10 | def forward(self, est_targets, targets, **kwargs): 11 | n_src = targets.shape[1] 12 | if n_src == 1: 13 | # est_targets = est_targets[:, -1].reshape(est_targets.size(0), 1, est_targets.size(2)) 14 | est_targets = est_targets[:, None, -1] 15 | return super().forward(est_targets, targets, **kwargs) 16 | else: 17 | assert est_targets.shape[1] == n_src + 1 18 | est_targets = est_targets[:, :-1] 19 | return super().forward(est_targets, targets, **kwargs) 20 | -------------------------------------------------------------------------------- /src/masknn/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import SepFormer, SepFormer2 2 | 3 | __all__ = [ 4 | "SepFormer", 5 | "SepFormer2", 6 | ] 7 | -------------------------------------------------------------------------------- /src/masknn/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | from math import ceil 3 | import warnings 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torch.nn.modules.activation import MultiheadAttention 10 | from asteroid.masknn import activations, norms 11 | from asteroid.utils import has_arg 12 | from asteroid.dsp.overlap_add import DualPathProcessing 13 | 14 | 15 | class PreLNTransformerLayer(nn.Module): 16 | """ 17 | Pre-LN Transformer layer. 18 | 19 | Args: 20 | embed_dim (int): Number of input channels. 21 | n_heads (int): Number of attention heads. 
22 |         dim_ff (int): Hidden size of the position-wise feed-forward network.
23 |             Defaults to 256. Unlike DPTNet, a plain two-layer feed-forward network is used here.
24 |         dropout (float, optional): Dropout ratio, must be in [0,1].
25 |         activation (str, optional): activation function used inside the feed-forward network.
26 |         norm (str, optional): Type of normalization to use.
27 | 
28 |     References
29 |         [1] Cem Subakan, Mirco Ravanelli, Samuele Cornell, Mirko Bronzi, and
30 |         Jianyuan Zhong. "Attention is All You Need in Speech Separation."
31 |         arXiv (2020).
32 |     """
33 | 
34 |     def __init__(
35 |         self,
36 |         embed_dim,
37 |         n_heads,
38 |         dim_ff,
39 |         dropout=0.0,
40 |         activation="relu",
41 |         norm="gLN",
42 |     ):
43 |         super(PreLNTransformerLayer, self).__init__()
44 | 
45 |         self.mha = MultiheadAttention(embed_dim, n_heads, dropout=dropout)
46 |         self.dropout = nn.Dropout(dropout)
47 |         self.linear1 = nn.Linear(embed_dim, dim_ff)
48 |         self.linear2 = nn.Linear(dim_ff, embed_dim)
49 |         self.activation = activations.get(activation)()
50 |         self.norm_mha = norms.get(norm)(embed_dim)
51 |         self.norm_ff = norms.get(norm)(embed_dim)
52 | 
53 |     def forward(self, x):
54 |         assert x.shape[0] != x.shape[1], "seq_len == channels would lead to wrong LN dimension"
55 |         tomha = self.norm_mha(x)
56 |         tomha = tomha.permute(2, 0, 1)
57 |         # x is batch, channels, seq_len
58 |         # mha is seq_len, batch, channels
59 |         # self-attention is applied
60 |         out = self.mha(tomha, tomha, tomha)[0]
61 |         x = self.dropout(out.permute(1, 2, 0)) + x
62 | 
63 |         # feed-forward is applied
64 |         toff = self.norm_ff(x)
65 |         out = self.linear2(self.dropout(self.activation(self.linear1(toff.transpose(1, -1)))))
66 |         x = self.dropout(out.transpose(1, -1)) + x
67 |         return x
68 | 
69 | class SepFormerLayer(PreLNTransformerLayer):
70 |     """
71 |     SepFormer layer. Only the forward pass differs (the feed-forward residual connects to the block input).
72 | 
73 |     Args:
74 |         embed_dim (int): Number of input channels.
75 |         n_heads (int): Number of attention heads.
76 |         dim_ff (int): Hidden size of the position-wise feed-forward network.
77 |             Defaults to 256.
78 |         dropout (float, optional): Dropout ratio, must be in [0,1].
79 |         activation (str, optional): activation function used inside the feed-forward network.
80 |         norm (str, optional): Type of normalization to use.
81 | 
82 |     References
83 |         [1] Cem Subakan, Mirco Ravanelli, Samuele Cornell, Mirko Bronzi, and
84 |         Jianyuan Zhong. "Attention is All You Need in Speech Separation."
85 |         arXiv (2020).
86 |     """
87 |     def __init__(
88 |         self,
89 |         embed_dim,
90 |         n_heads,
91 |         dim_ff,
92 |         dropout=0.0,
93 |         activation="relu",
94 |         norm="gLN",
95 |     ):
96 |         super().__init__(embed_dim, n_heads, dim_ff, dropout, activation, norm)
97 | 
98 |     def forward(self, x):
99 |         assert x.shape[0] != x.shape[1], "seq_len == channels would lead to wrong LN dimension"
100 |         tomha = self.norm_mha(x)
101 |         tomha = tomha.permute(2, 0, 1)
102 |         # x is batch, channels, seq_len
103 |         # mha is seq_len, batch, channels
104 |         # self-attention is applied
105 |         out = self.mha(tomha, tomha, tomha)[0]
106 |         x1 = self.dropout(out.permute(1, 2, 0)) + x
107 | 
108 |         # feed-forward is applied; note the residual connects to the block input x, not x1
109 |         toff = self.norm_ff(x1)
110 |         out = self.linear2(self.dropout(self.activation(self.linear1(toff.transpose(1, -1)))))
111 |         x2 = self.dropout(out.transpose(1, -1)) + x
112 |         return x2
113 | 
114 | 
115 | class SepFormer(nn.Module):
116 |     """SepFormer introduced in [1].
117 | 
118 |     Args:
119 |         in_chan (int): Number of input filters.
120 |         n_src (int): Number of masks to estimate.
121 |         n_heads (int): Number of attention heads.
122 |         ff_hid (int): Hidden size of the position-wise feed-forward network.
123 |             Defaults to 256.
124 |         chunk_size (int): window size of overlap and add processing.
125 |             Defaults to 100.
126 |         hop_size (int or None): hop size (stride) of overlap and add processing.
127 |             Defaults to `chunk_size // 2` (50% overlap).
128 |         n_repeats (int): Number of repeats. Defaults to 2.
129 |         norm_type (str, optional): Type of normalization to use.
130 |         ff_activation (str, optional): activation function used inside the feed-forward network.
131 |         mask_act (str, optional): Which non-linear function to generate mask.
132 |         dropout (float, optional): Dropout ratio, must be in [0,1].
133 | 
134 |     References
135 |         [1] Cem Subakan, Mirco Ravanelli, Samuele Cornell, Mirko Bronzi, and
136 |         Jianyuan Zhong. "Attention is All You Need in Speech Separation."
137 |         arXiv (2020).
138 |     """
139 | 
140 |     def __init__(
141 |         self,
142 |         in_chan,
143 |         n_src,
144 |         n_heads=4,
145 |         ff_hid=256,
146 |         chunk_size=100,
147 |         hop_size=None,
148 |         n_repeats=2,
149 |         k_repeats=4,
150 |         norm_type="gLN",
151 |         ff_activation="relu",
152 |         mask_act="relu",
153 |         dropout=0,
154 |     ):
155 |         super().__init__()
156 |         self.in_chan = in_chan
157 |         self.n_src = n_src
158 |         self.n_heads = n_heads
159 |         self.ff_hid = ff_hid
160 |         self.chunk_size = chunk_size
161 |         hop_size = hop_size if hop_size is not None else chunk_size // 2
162 |         self.hop_size = hop_size
163 |         self.n_repeats = n_repeats
164 |         self.k_repeats = k_repeats
165 |         self.n_src = n_src
166 |         self.norm_type = norm_type
167 |         self.ff_activation = ff_activation
168 |         self.mask_act = mask_act
169 |         self.dropout = dropout
170 | 
171 |         self.mha_in_dim = ceil(self.in_chan / self.n_heads) * self.n_heads
172 |         if self.in_chan % self.n_heads != 0:
173 |             warnings.warn(
174 |                 f"SepFormer input dim ({self.in_chan}) is not a multiple of the number of "
175 |                 f"heads ({self.n_heads}). Adding extra linear layer at input to accommodate "
176 |                 f"(size [{self.in_chan} x {self.mha_in_dim}])"
177 |             )
178 |             self.input_layer = nn.Linear(self.in_chan, self.mha_in_dim)
179 |         else:
180 |             self.input_layer = None
181 | 
182 |         self.in_norm = norms.get(norm_type)(self.mha_in_dim)
183 |         self.ola = DualPathProcessing(self.chunk_size, self.hop_size)
184 | 
185 |         # Succession of intra/inter Transformer blocks.
186 |         self.layers = nn.ModuleList([])
187 |         for x in range(self.n_repeats):
188 |             self.layers.append(
189 |                 nn.ModuleList(
190 |                     [
191 |                         nn.Sequential(*[
192 |                             PositionalEncoding(
193 |                                 self.mha_in_dim,
194 |                                 self.dropout
195 |                             ),
196 |                             *[
197 |                                 SepFormerLayer(
198 |                                     self.mha_in_dim,
199 |                                     self.n_heads,
200 |                                     self.ff_hid,
201 |                                     self.dropout,
202 |                                     self.ff_activation,
203 |                                     self.norm_type,
204 |                                 ) for _ in range(self.k_repeats)
205 |                             ]
206 |                         ]),
207 |                         nn.Sequential(*[
208 |                             PositionalEncoding(
209 |                                 self.mha_in_dim,
210 |                                 self.dropout
211 |                             ),
212 |                             *[
213 |                                 SepFormerLayer(
214 |                                     self.mha_in_dim,
215 |                                     self.n_heads,
216 |                                     self.ff_hid,
217 |                                     self.dropout,
218 |                                     self.ff_activation,
219 |                                     self.norm_type,
220 |                                 ) for _ in range(self.k_repeats)
221 |                             ]
222 |                         ]),
223 |                     ]
224 |                 )
225 |             )
226 |         net_out_conv = nn.Conv2d(self.mha_in_dim, n_src * self.in_chan, 1)
227 |         self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
228 |         # Masking in 2D space (after fold)
229 |         self.mask_net = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1),
230 |                                       nn.ReLU(),
231 |                                       nn.Conv1d(self.in_chan, self.in_chan, 1))
232 | 
233 |         # Get activation function.
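        # `activations.get` maps a name such as "relu" or "softmax" to the
        # corresponding nn.Module class. Softmax normalizes over one axis, so the
        # source axis (dim=1 of the (batch, n_src, ...) mask tensor) is passed
        # to it below.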
234 |         mask_nl_class = activations.get(mask_act)
235 |         # For softmax, feed the source dimension.
236 |         if has_arg(mask_nl_class, "dim"):
237 |             self.output_act = mask_nl_class(dim=1)
238 |         else:
239 |             self.output_act = mask_nl_class()
240 | 
241 |     def forward(self, mixture_w):
242 |         r"""Forward.
243 | 
244 |         Args:
245 |             mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$
246 | 
247 |         Returns:
248 |             :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
249 |         """
250 |         if self.input_layer is not None:
251 |             mixture_w = self.input_layer(mixture_w.transpose(1, 2)).transpose(1, 2)
252 |         mixture_w = self.in_norm(mixture_w)  # [batch, bn_chan, n_frames]
253 |         n_orig_frames = mixture_w.shape[-1]
254 | 
255 |         mixture_w = self.ola.unfold(mixture_w)
256 |         batch, n_filters, self.chunk_size, n_chunks = mixture_w.size()
257 | 
258 |         for layer_idx in range(len(self.layers)):
259 |             intra, inter = self.layers[layer_idx]
260 |             mixture_w = self.ola.intra_process(mixture_w, intra)
261 |             mixture_w = self.ola.inter_process(mixture_w, inter)
262 | 
263 |         output = self.first_out(mixture_w)
264 |         output = output.reshape(batch * self.n_src, self.in_chan, self.chunk_size, n_chunks)
265 |         output = self.ola.fold(output, output_size=n_orig_frames)
266 | 
267 |         output = self.mask_net(output)
268 |         # Compute mask
269 |         output = output.reshape(batch, self.n_src, self.in_chan, -1)
270 |         est_mask = self.output_act(output)
271 |         return est_mask
272 | 
273 |     def get_config(self):
274 |         config = {
275 |             "in_chan": self.in_chan,
276 |             "ff_hid": self.ff_hid,
277 |             "n_heads": self.n_heads,
278 |             "chunk_size": self.chunk_size,
279 |             "hop_size": self.hop_size,
280 |             "n_repeats": self.n_repeats,
281 |             "k_repeats": self.k_repeats,
282 |             "n_src": self.n_src,
283 |             "norm_type": self.norm_type,
284 |             "ff_activation": self.ff_activation,
285 |             "mask_act": self.mask_act,
286 |             "dropout": self.dropout,
287 |         }
288 |         return config
289 | 
290 | 
291 | class SepFormer2(nn.Module):
292 |     """Modified SepFormer introduced in [1].
293 | 
294 |     Args:
295 |         in_chan (int): Number of input filters.
296 |         n_src (int): Number of masks to estimate.
297 |         n_heads (int): Number of attention heads.
298 |         ff_hid (int): Hidden size of the position-wise feed-forward network.
299 |             Defaults to 256.
300 |         chunk_size (int): window size of overlap and add processing.
301 |             Defaults to 100.
302 |         hop_size (int or None): hop size (stride) of overlap and add processing.
303 |             Defaults to `chunk_size // 2` (50% overlap).
304 |         n_repeats (int): Number of repeats. Defaults to 2.
305 |         norm_type (str, optional): Type of normalization to use.
306 |         ff_activation (str, optional): activation function used inside the feed-forward network.
307 |         mask_act (str, optional): Which non-linear function to generate mask.
308 |         dropout (float, optional): Dropout ratio, must be in [0,1].
309 | 
310 |     References
311 |         [1] Cem Subakan, Mirco Ravanelli, Samuele Cornell, Mirko Bronzi, and
312 |         Jianyuan Zhong. "Attention is All You Need in Speech Separation."
313 |         arXiv (2020).
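
    Note:
        Compared to ``SepFormer``, this variant stacks plain
        ``PreLNTransformerLayer`` blocks (whose feed-forward residual starts
        from the attention output rather than from the block input) and
        replaces the convolutional mask head with a tanh/sigmoid gating pair.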
314 | """ 315 | 316 | def __init__( 317 | self, 318 | in_chan, 319 | n_src, 320 | n_heads=4, 321 | ff_hid=256, 322 | chunk_size=100, 323 | hop_size=None, 324 | n_repeats=2, 325 | k_repeats=4, 326 | norm_type="gLN", 327 | ff_activation="relu", 328 | mask_act="relu", 329 | dropout=0, 330 | ): 331 | super().__init__() 332 | self.in_chan = in_chan 333 | self.n_src = n_src 334 | self.n_heads = n_heads 335 | self.ff_hid = ff_hid 336 | self.chunk_size = chunk_size 337 | hop_size = hop_size if hop_size is not None else chunk_size // 2 338 | self.hop_size = hop_size 339 | self.n_repeats = n_repeats 340 | self.k_repeats = k_repeats 341 | self.n_src = n_src 342 | self.norm_type = norm_type 343 | self.ff_activation = ff_activation 344 | self.mask_act = mask_act 345 | self.dropout = dropout 346 | 347 | self.mha_in_dim = ceil(self.in_chan / self.n_heads) * self.n_heads 348 | if self.in_chan % self.n_heads != 0: 349 | warnings.warn( 350 | f"DPTransformer input dim ({self.in_chan}) is not a multiple of the number of " 351 | f"heads ({self.n_heads}). Adding extra linear layer at input to accomodate " 352 | f"(size [{self.in_chan} x {self.mha_in_dim}])" 353 | ) 354 | self.input_layer = nn.Linear(self.in_chan, self.mha_in_dim) 355 | else: 356 | self.input_layer = None 357 | 358 | self.in_norm = norms.get(norm_type)(self.mha_in_dim) 359 | self.ola = DualPathProcessing(self.chunk_size, self.hop_size) 360 | 361 | # Succession of DPRNNBlocks. 362 | self.layers = nn.ModuleList([]) 363 | for x in range(self.n_repeats): 364 | self.layers.append( 365 | nn.ModuleList( 366 | [ 367 | nn.Sequential(*[ 368 | PositionalEncoding( 369 | self.mha_in_dim, 370 | self.dropout 371 | ), 372 | *[ 373 | PreLNTransformerLayer( 374 | self.mha_in_dim, 375 | self.n_heads, 376 | self.ff_hid, 377 | self.dropout, 378 | self.ff_activation, 379 | self.norm_type, 380 | ) for _ in range(self.k_repeats) 381 | ] 382 | ]), 383 | nn.Sequential(*[ 384 | PositionalEncoding( 385 | self.mha_in_dim, 386 | self.dropout 387 | ), 388 | *[ 389 | PreLNTransformerLayer( 390 | self.mha_in_dim, 391 | self.n_heads, 392 | self.ff_hid, 393 | self.dropout, 394 | self.ff_activation, 395 | self.norm_type, 396 | ) for _ in range(self.k_repeats) 397 | ] 398 | ]), 399 | ] 400 | ) 401 | ) 402 | net_out_conv = nn.Conv2d(self.mha_in_dim, n_src * self.in_chan, 1) 403 | self.first_out = nn.Sequential(nn.PReLU(), net_out_conv) 404 | # Gating and masking in 2D space (after fold) 405 | self.net_out = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Tanh()) 406 | self.net_gate = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Sigmoid()) 407 | 408 | # Get activation function. 409 | mask_nl_class = activations.get(mask_act) 410 | # For softmax, feed the source dimension. 411 | if has_arg(mask_nl_class, "dim"): 412 | self.output_act = mask_nl_class(dim=1) 413 | else: 414 | self.output_act = mask_nl_class() 415 | 416 | def forward(self, mixture_w): 417 | r"""Forward. 
418 | 419 | Args: 420 | mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$ 421 | 422 | Returns: 423 | :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$ 424 | """ 425 | if self.input_layer is not None: 426 | mixture_w = self.input_layer(mixture_w.transpose(1, 2)).transpose(1, 2) 427 | mixture_w = self.in_norm(mixture_w) # [batch, bn_chan, n_frames] 428 | n_orig_frames = mixture_w.shape[-1] 429 | 430 | mixture_w = self.ola.unfold(mixture_w) 431 | batch, n_filters, self.chunk_size, n_chunks = mixture_w.size() 432 | 433 | for layer_idx in range(len(self.layers)): 434 | intra, inter = self.layers[layer_idx] 435 | mixture_w = self.ola.intra_process(mixture_w, intra) 436 | mixture_w = self.ola.inter_process(mixture_w, inter) 437 | 438 | output = self.first_out(mixture_w) 439 | output = output.reshape(batch * self.n_src, self.in_chan, self.chunk_size, n_chunks) 440 | output = self.ola.fold(output, output_size=n_orig_frames) 441 | 442 | output = self.net_out(output) * self.net_gate(output) 443 | # Compute mask 444 | output = output.reshape(batch, self.n_src, self.in_chan, -1) 445 | est_mask = self.output_act(output) 446 | return est_mask 447 | 448 | def get_config(self): 449 | config = { 450 | "in_chan": self.in_chan, 451 | "ff_hid": self.ff_hid, 452 | "n_heads": self.n_heads, 453 | "chunk_size": self.chunk_size, 454 | "hop_size": self.hop_size, 455 | "n_repeats": self.n_repeats, 456 | "k_repeats": self.k_repeats, 457 | "n_src": self.n_src, 458 | "norm_type": self.norm_type, 459 | "ff_activation": self.ff_activation, 460 | "mask_act": self.mask_act, 461 | "dropout": self.dropout, 462 | } 463 | return config 464 | 465 | 466 | class PositionalEncoding(nn.Module): 467 | 468 | def __init__(self, d_model, dropout=0.1, max_len=5000): 469 | super(PositionalEncoding, self).__init__() 470 | self.dropout = nn.Dropout(p=dropout) 471 | 472 | pe = torch.zeros(max_len, d_model) 473 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 474 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 475 | pe[:, 0::2] = torch.sin(position * div_term) 476 | pe[:, 1::2] = torch.cos(position * div_term) 477 | # pe = pe.unsqueeze(0).transpose(0, 1) # seq_len, batch, channels 478 | pe = pe.transpose(0, 1).unsqueeze(0) # batch, channels, seq_len 479 | self.register_buffer('pe', pe) 480 | 481 | def forward(self, x): 482 | # x is seq_len, batch, channels 483 | # x = x + self.pe[:x.size(0), :] 484 | 485 | # x is batch, channels, seq_len 486 | x = x + self.pe[:, :, :x.size(2)] 487 | return self.dropout(x) 488 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | import asteroid 2 | from .sepformer_tasnet import SepFormerTasNet, SepFormer2TasNet 3 | asteroid.models.register_model(SepFormerTasNet) 4 | asteroid.models.register_model(SepFormer2TasNet) 5 | 6 | __all__ = [ 7 | "SepFormerTasNet", 8 | "SepFormer2TasNet", 9 | ] 10 | -------------------------------------------------------------------------------- /src/models/sepformer_tasnet.py: -------------------------------------------------------------------------------- 1 | from asteroid_filterbanks import make_enc_dec 2 | from ..masknn.attention import SepFormer, SepFormer2 3 | from asteroid.models.base_models import BaseEncoderMaskerDecoder 4 | 5 | 6 | class SepFormerTasNet(BaseEncoderMaskerDecoder): 7 | """SepFormer separation 
model, as described in [1].
8 | 
9 |     Args:
10 |         n_src (int): Number of masks to estimate.
11 |         n_heads (int): Number of attention heads.
12 |             Defaults to 4.
13 |         ff_hid (int): Hidden size of the position-wise feed-forward network.
14 |             Defaults to 256.
15 |         ff_activation (str, optional): activation function used inside the
16 |             feed-forward network. Defaults to ``'relu'``.
17 |         chunk_size (int): window size of overlap and add processing.
18 |             Defaults to 100.
19 |         hop_size (int or None): hop size (stride) of overlap and add processing.
20 |             Defaults to `chunk_size // 2` (50% overlap).
21 |         n_repeats (int): Number of repeats. Defaults to 2.
22 |         k_repeats (int): Number of intra/inter repeats. Defaults to 4.
23 |         norm_type (str, optional): Type of normalization to use. To choose from
24 | 
25 |             - ``'gLN'``: global Layernorm
26 |             - ``'cLN'``: channelwise Layernorm
27 |         mask_act (str, optional): Which non-linear function to generate mask.
28 |         encoder_activation (str, optional): activation to apply after the
29 |             encoder, before the masker. Defaults to ``'relu'``. Passed to
30 |             the base encoder/masker/decoder model.
31 |         dropout (float, optional): Dropout ratio, must be in [0,1].
32 |         in_chan (int, optional): Number of input channels, should be equal to
33 |             n_filters.
34 |         fb_name (str, className): Filterbank family from which to make encoder
35 |             and decoder. To choose among [``'free'``, ``'analytic_free'``,
36 |             ``'param_sinc'``, ``'stft'``].
37 |         n_filters (int): Number of filters / Input dimension of the masker net.
38 |         kernel_size (int): Length of the filters.
39 |         stride (int, optional): Stride of the convolution.
40 |             If None (default), set to ``kernel_size // 2``.
41 |         sample_rate (float): Sampling rate of the model.
42 |         **fb_kwargs (dict): Additional kwargs to pass to the filterbank
43 |             creation.
44 | 
45 |     References
46 |         - [1] Cem Subakan, Mirco Ravanelli, Samuele Cornell, Mirko Bronzi, and
47 |         Jianyuan Zhong. "Attention is All You Need in Speech Separation."
48 |         arXiv (2020).
49 |     """
50 | 
51 |     def __init__(
52 |         self,
53 |         n_src,
54 |         n_heads=4,
55 |         ff_hid=256,
56 |         chunk_size=100,
57 |         hop_size=None,
58 |         n_repeats=2,
59 |         k_repeats=4,
60 |         norm_type="gLN",
61 |         ff_activation="relu",
62 |         encoder_activation="relu",
63 |         mask_act="relu",
64 |         dropout=0,
65 |         in_chan=None,
66 |         fb_name="free",
67 |         kernel_size=16,
68 |         n_filters=64,
69 |         stride=8,
70 |         sample_rate=8000,
71 |         **fb_kwargs,
72 |     ):
73 |         encoder, decoder = make_enc_dec(
74 |             fb_name,
75 |             kernel_size=kernel_size,
76 |             n_filters=n_filters,
77 |             stride=stride,
78 |             sample_rate=sample_rate,
79 |             **fb_kwargs,
80 |         )
81 |         n_feats = encoder.n_feats_out
82 |         if in_chan is not None:
83 |             assert in_chan == n_feats, (
84 |                 "Number of filterbank output channels"
85 |                 " and number of input channels should "
86 |                 "be the same. Received "
87 |                 f"{n_feats} and {in_chan}"
88 |             )
89 |         # Update in_chan
90 |         masker = SepFormer(
91 |             n_feats,
92 |             n_src,
93 |             n_heads=n_heads,
94 |             ff_hid=ff_hid,
95 |             ff_activation=ff_activation,
96 |             chunk_size=chunk_size,
97 |             hop_size=hop_size,
98 |             n_repeats=n_repeats,
99 |             k_repeats=k_repeats,
100 |             norm_type=norm_type,
101 |             mask_act=mask_act,
102 |             dropout=dropout,
103 |         )
104 |         super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
105 | 
106 | 
107 | class SepFormer2TasNet(BaseEncoderMaskerDecoder):
108 |     """Modified SepFormer (SepFormer2) separation model, based on [1].
109 | 
110 |     Args:
111 |         n_src (int): Number of masks to estimate.
112 |         n_heads (int): Number of attention heads.
113 |             Defaults to 4.
114 |         ff_hid (int): Hidden size of the position-wise feed-forward network.
115 |             Defaults to 256.
116 |         ff_activation (str, optional): activation function used inside the
117 |             feed-forward network. Defaults to ``'relu'``.
118 |         chunk_size (int): window size of overlap and add processing.
119 |             Defaults to 100.
120 |         hop_size (int or None): hop size (stride) of overlap and add processing.
121 |             Defaults to `chunk_size // 2` (50% overlap).
122 |         n_repeats (int): Number of repeats. Defaults to 2.
123 |         k_repeats (int): Number of intra/inter repeats. Defaults to 4.
124 |         norm_type (str, optional): Type of normalization to use. To choose from
125 | 
126 |             - ``'gLN'``: global Layernorm
127 |             - ``'cLN'``: channelwise Layernorm
128 |         mask_act (str, optional): Which non-linear function to generate mask.
129 |         encoder_activation (str, optional): activation to apply after the
130 |             encoder, before the masker. Defaults to ``'relu'``. Passed to
131 |             the base encoder/masker/decoder model.
132 |         dropout (float, optional): Dropout ratio, must be in [0,1].
133 |         in_chan (int, optional): Number of input channels, should be equal to
134 |             n_filters.
135 |         fb_name (str, className): Filterbank family from which to make encoder
136 |             and decoder. To choose among [``'free'``, ``'analytic_free'``,
137 |             ``'param_sinc'``, ``'stft'``].
138 |         n_filters (int): Number of filters / Input dimension of the masker net.
139 |         kernel_size (int): Length of the filters.
140 |         stride (int, optional): Stride of the convolution.
141 |             If None (default), set to ``kernel_size // 2``.
142 |         sample_rate (float): Sampling rate of the model.
143 |         **fb_kwargs (dict): Additional kwargs to pass to the filterbank
144 |             creation.
145 | 
146 |     References
147 |         - [1] Cem Subakan, Mirco Ravanelli, Samuele Cornell, Mirko Bronzi, and
148 |         Jianyuan Zhong. "Attention is All You Need in Speech Separation."
149 |         arXiv (2020).
150 |     """
151 | 
152 |     def __init__(
153 |         self,
154 |         n_src,
155 |         n_heads=4,
156 |         ff_hid=256,
157 |         chunk_size=100,
158 |         hop_size=None,
159 |         n_repeats=2,
160 |         k_repeats=4,
161 |         norm_type="gLN",
162 |         ff_activation="relu",
163 |         encoder_activation="relu",
164 |         mask_act="relu",
165 |         dropout=0,
166 |         in_chan=None,
167 |         fb_name="free",
168 |         kernel_size=16,
169 |         n_filters=64,
170 |         stride=8,
171 |         sample_rate=8000,
172 |         **fb_kwargs,
173 |     ):
174 |         encoder, decoder = make_enc_dec(
175 |             fb_name,
176 |             kernel_size=kernel_size,
177 |             n_filters=n_filters,
178 |             stride=stride,
179 |             sample_rate=sample_rate,
180 |             **fb_kwargs,
181 |         )
182 |         n_feats = encoder.n_feats_out
183 |         if in_chan is not None:
184 |             assert in_chan == n_feats, (
185 |                 "Number of filterbank output channels"
186 |                 " and number of input channels should "
187 |                 "be the same. 
Received " 188 | f"{n_feats} and {in_chan}" 189 | ) 190 | # Update in_chan 191 | masker = SepFormer2( 192 | n_feats, 193 | n_src, 194 | n_heads=n_heads, 195 | ff_hid=ff_hid, 196 | ff_activation=ff_activation, 197 | chunk_size=chunk_size, 198 | hop_size=hop_size, 199 | n_repeats=n_repeats, 200 | k_repeats=k_repeats, 201 | norm_type=norm_type, 202 | mask_act=mask_act, 203 | dropout=dropout, 204 | ) 205 | super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation) 206 | -------------------------------------------------------------------------------- /train_general.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import comet_ml 5 | 6 | import torch 7 | from torch.optim.lr_scheduler import ReduceLROnPlateau 8 | from torch.utils.data import DataLoader 9 | import pytorch_lightning as pl 10 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 11 | 12 | import asteroid 13 | from asteroid.engine.optimizers import make_optimizer 14 | from asteroid.engine.system import System 15 | from asteroid.engine.schedulers import DPTNetScheduler 16 | from asteroid.losses import PITLossWrapper, pairwise_neg_sisdr 17 | 18 | from src.data import make_dataloaders 19 | from src.engine.system import GeneralSystem 20 | from src.losses.multi_task_wrapper import MultiTaskLossWrapper 21 | from src.models import * 22 | pl.seed_everything(42) 23 | 24 | # Keys which are not in the conf.yml file can be added here. 25 | # In the hierarchical dictionary created when parsing, the key `key` can be 26 | # found at dic['main_args'][key] 27 | 28 | # By default train.py will use all available GPUs. The `id` option in run.sh 29 | # will limit the number of available GPUs for train.py . 
30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--corpus", default="LibriMix", choices=["LibriMix", "wsj0-mix"]) 32 | parser.add_argument("--model", default="ConvTasNet", choices=["ConvTasNet", "DPRNNTasNet", "DPTNet", "SepFormerTasNet", "SepFormer2TasNet"]) 33 | parser.add_argument("--strategy", default="from_scratch", choices=["from_scratch", "pretrained", "multi_task"]) 34 | parser.add_argument("--exp_dir", default="exp/tmp", help="Full path to save best validation model") 35 | parser.add_argument("--accumulate_grad_batches", type=int, default=1, help="Total batch size = batch_size * accumulate_grad_batches") 36 | parser.add_argument("--comet", action="store_true", help="Comet logger") 37 | parser.add_argument("--resume", action="store_true", help="Resume-training") 38 | 39 | known_args = parser.parse_known_args()[0] 40 | if known_args.strategy == "pretrained": 41 | parser.add_argument("--load_path", default=None, required=True, help="Checkpoint path to load for fine-tuning.") 42 | elif known_args.strategy == "multi_task": 43 | parser.add_argument("--train_enh_dir", default=None, required=True, help="Multi-task data dir.") 44 | 45 | if known_args.resume: 46 | parser.add_argument("--resume_ckpt", default="last.ckpt", help="Checkpoint path to load for resume-training") 47 | if known_args.comet: 48 | parser.add_argument("--comet_exp_key", default=None, required=True, help="Comet experiment key") 49 | 50 | 51 | def main(conf): 52 | train_enh_dir = conf["main_args"].get("train_enh_dir", None) 53 | resume_ckpt = conf["main_args"].get("resume_ckpt", None) 54 | 55 | train_loader, val_loader, train_set_infos = make_dataloaders( 56 | corpus=conf["main_args"]["corpus"], 57 | train_dir=conf["data"]["train_dir"], 58 | val_dir=conf["data"]["valid_dir"], 59 | train_enh_dir=train_enh_dir, 60 | task=conf["data"]["task"], 61 | sample_rate=conf["data"]["sample_rate"], 62 | n_src=conf["data"]["n_src"], 63 | segment=conf["data"]["segment"], 64 | batch_size=conf["training"]["batch_size"], 65 | num_workers=conf["training"]["num_workers"], 66 | ) 67 | 68 | conf["masknet"].update({"n_src": conf["data"]["n_src"]}) 69 | if conf["main_args"]["strategy"] == "multi_task": 70 | conf["masknet"].update({"n_src": conf["data"]["n_src"]+1}) 71 | 72 | model = getattr(asteroid.models, conf["main_args"]["model"])(**conf["filterbank"], **conf["masknet"]) 73 | 74 | optimizer = make_optimizer(model.parameters(), **conf["optim"]) 75 | 76 | # Define scheduler 77 | scheduler = None 78 | if conf["main_args"]["model"] in ["DPTNet", "SepFormerTasNet", "SepFormer2TasNet"]: 79 | steps_per_epoch = len(train_loader) // conf["main_args"]["accumulate_grad_batches"] 80 | conf["scheduler"]["steps_per_epoch"] = steps_per_epoch 81 | scheduler = { 82 | "scheduler": DPTNetScheduler( 83 | optimizer=optimizer, 84 | steps_per_epoch=steps_per_epoch, 85 | d_model=model.masker.mha_in_dim, 86 | ), 87 | "interval": "batch", 88 | } 89 | elif conf["training"]["half_lr"]: 90 | scheduler = ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5) 91 | 92 | # Just after instantiating, save the args. Easy loading in the future. 93 | exp_dir = conf["main_args"]["exp_dir"] 94 | os.makedirs(exp_dir, exist_ok=True) 95 | conf_path = os.path.join(exp_dir, "conf.yml") 96 | with open(conf_path, "w") as outfile: 97 | yaml.safe_dump(conf, outfile) 98 | 99 | # Define Loss function. 
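    # With the multi_task strategy the masker predicts n_src + 1 sources (see the
    # conf["masknet"] update above); MultiTaskLossWrapper routes 1-source
    # (enhancement) batches to the last output head and n_src-source (separation)
    # batches to the remaining heads, then computes the usual permutation-invariant
    # SI-SDR loss.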
100 |     pit_wrapper = MultiTaskLossWrapper if conf["main_args"]["strategy"] == "multi_task" else PITLossWrapper
101 |     loss_func = pit_wrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
102 |     system = GeneralSystem(
103 |         model=model,
104 |         optimizer=optimizer,
105 |         loss_func=loss_func,
106 |         train_loader=train_loader,
107 |         val_loader=val_loader,
108 |         scheduler=scheduler,
109 |         config=conf,
110 |     )
111 | 
112 |     # Define callbacks
113 |     callbacks = []
114 |     checkpoint_dir = os.path.join(exp_dir, "checkpoints/")
115 |     checkpoint = ModelCheckpoint(
116 |         dirpath=checkpoint_dir, filename='{epoch}-{step}', monitor="val_loss", mode="min",
117 |         save_top_k=conf["training"]["epochs"], save_last=True, verbose=True,
118 |     )
119 |     callbacks.append(checkpoint)
120 |     if conf["training"]["early_stop"]:
121 |         callbacks.append(EarlyStopping(monitor="val_loss", mode="min", patience=30, verbose=True))
122 | 
123 |     loggers = []
124 |     tb_logger = pl.loggers.TensorBoardLogger(
125 |         os.path.join(exp_dir, "tb_logs/"),
126 |     )
127 |     loggers.append(tb_logger)
128 |     if conf["main_args"]["comet"]:
129 |         comet_logger = pl.loggers.CometLogger(
130 |             save_dir=os.path.join(exp_dir, "comet_logs/"),
131 |             experiment_key=conf["main_args"].get("comet_exp_key", None),
132 |             log_code=True,
133 |             log_graph=True,
134 |             parse_args=True,
135 |             log_env_details=True,
136 |             log_git_metadata=True,
137 |             log_git_patch=True,
138 |             log_env_gpu=True,
139 |             log_env_cpu=True,
140 |             log_env_host=True,
141 |         )
142 |         comet_logger.log_hyperparams(conf)
143 |         loggers.append(comet_logger)
144 | 
145 | 
146 |     # Don't ask for GPUs if none are available.
147 |     gpus = -1 if torch.cuda.is_available() else None
148 |     distributed_backend = "ddp" if torch.cuda.is_available() else None  # Note: ddp is not recommended for multi-task training
149 | 
150 |     trainer = pl.Trainer(
151 |         max_epochs=conf["training"]["epochs"],
152 |         logger=loggers,
153 |         callbacks=callbacks,
154 |         # checkpoint_callback=checkpoint,
155 |         # early_stop_callback=callbacks[1],
156 |         default_root_dir=exp_dir,
157 |         gpus=gpus,
158 |         distributed_backend=distributed_backend,
159 |         limit_train_batches=1.0,  # Useful for fast experiments
160 |         # fast_dev_run=True,  # Useful for debugging
161 |         # overfit_batches=0.001,  # Useful for debugging
162 |         gradient_clip_val=5.0,
163 |         accumulate_grad_batches=conf["main_args"]["accumulate_grad_batches"],
164 |         resume_from_checkpoint=resume_ckpt,
165 |         deterministic=True,
166 |         replace_sampler_ddp=False if conf["main_args"]["strategy"] == "multi_task" else True,
167 |     )
168 |     trainer.fit(system)
169 | 
170 |     best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
171 |     with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
172 |         json.dump(best_k, f, indent=0)
173 | 
174 |     state_dict = torch.load(checkpoint.best_model_path)
175 |     system.load_state_dict(state_dict=state_dict["state_dict"])
176 |     system.cpu()
177 | 
178 |     to_save = system.model.serialize()
179 |     to_save.update(train_set_infos)
180 |     torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
181 | 
182 | 
183 | if __name__ == "__main__":
184 |     import yaml
185 |     from pprint import pprint
186 |     from asteroid.utils import prepare_parser_from_dict, parse_args_as_dict
187 | 
188 |     model_type = known_args.model
189 |     # We start with opening the config file conf.yml as a dictionary from
190 |     # which we can create parsers. Each top level key in the dictionary defined
191 |     # by the YAML file creates a group in the parser.
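    # For instance (illustrative values), a local/ConvTasNet.yml entry such as
    #   training:
    #     batch_size: 24
    # becomes a --batch_size command-line option and is read in main() as
    # conf["training"]["batch_size"].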
192 |     with open(f"local/{model_type}.yml") as f:
193 |         def_conf = yaml.safe_load(f)
194 |     parser = prepare_parser_from_dict(def_conf, parser=parser)
195 |     # Arguments are then parsed into a hierarchical dictionary (instead of
196 |     # flat, as returned by argparse) to facilitate calls to the different
197 |     # asteroid methods (see in main).
198 |     # plain_args is the direct output of parser.parse_args() and contains all
199 |     # the attributes in a non-hierarchical structure. It can be useful to
200 |     # have as well, so we include it here, though it is not used.
201 |     arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
202 |     pprint(arg_dic)
203 |     main(arg_dic)
204 | 
--------------------------------------------------------------------------------
/utils/parse_options.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Copyright 2012  Johns Hopkins University (Author: Daniel Povey);
4 | #                 Arnab Ghoshal, Karel Vesely
5 | 
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | #  http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
15 | # MERCHANTABLITY OR NON-INFRINGEMENT.
16 | # See the Apache 2 License for the specific language governing permissions and
17 | # limitations under the License.
18 | 
19 | 
20 | # Parse command-line options.
21 | # To be sourced by another script (as in ". parse_options.sh").
22 | # Option format is: --option-name arg
23 | # and shell variable "option_name" gets set to value "arg."
24 | # The exception is --help, which takes no arguments, but prints the
25 | # $help_message variable (if defined).
26 | 
27 | 
28 | ###
29 | ### The --config file options have lower priority than command-line
30 | ### options, so we need to import them first...
31 | ###
32 | 
33 | # Now import all the configs specified by command-line, in left-to-right order
34 | for ((argpos=1; argpos<$#; argpos++)); do
35 |   if [ "${!argpos}" == "--config" ]; then
36 |     argpos_plus1=$((argpos+1))
37 |     config=${!argpos_plus1}
38 |     [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
39 |     . $config  # source the config file.
40 |   fi
41 | done
42 | 
43 | 
44 | ###
45 | ### Now we process the command line options
46 | ###
47 | while true; do
48 |   [ -z "${1:-}" ] && break;  # break if there are no arguments
49 |   case "$1" in
50 |     # If the enclosing script is called with --help option, print the help
51 |     # message and exit. Scripts should put help messages in $help_message
52 |     --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
53 |       else printf "$help_message\n" 1>&2 ; fi;
54 |       exit 0 ;;
55 |     --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
56 |       exit 1 ;;
57 |     # If the first command-line argument begins with "--" (e.g. --foo-bar),
58 |     # then work out the variable name as $name, which will equal "foo_bar".
59 |     --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
60 |       # Next we test whether the variable in question is undefined -- if so it's
61 |       # an invalid option and we die. Note: $0 evaluates to the name of the
62 |       # enclosing script.
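      # Example: an option "--batch-size 8" yields name=batch_size; if no
      # variable batch_size was declared before sourcing this script, the test
      # below treats it as an invalid option and exits.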
63 |       # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
64 |       # is undefined. We then have to wrap this test inside "eval" because
65 |       # foo_bar is itself inside a variable ($name).
66 |       eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
67 | 
68 |       oldval="`eval echo \\$$name`";
69 |       # Work out whether we seem to be expecting a Boolean argument.
70 |       if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
71 |         was_bool=true;
72 |       else
73 |         was_bool=false;
74 |       fi
75 | 
76 |       # Set the variable to the right value-- the escaped quotes make it work if
77 |       # the option had spaces, like --cmd "queue.pl -sync y"
78 |       eval $name=\"$2\";
79 | 
80 |       # Check that Boolean-valued arguments are really Boolean.
81 |       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
82 |         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
83 |         exit 1;
84 |       fi
85 |       shift 2;
86 |       ;;
87 |     *) break;
88 |   esac
89 | done
90 | 
91 | 
92 | # Check for an empty argument to the --cmd option, which can easily occur as a
93 | # result of scripting errors.
94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
95 | 
96 | 
97 | true; # so this script returns exit code 0.
98 | 
--------------------------------------------------------------------------------
/utils/prepare_python_env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Usage: ./utils/prepare_python_env.sh --install_dir A --asteroid_root B --pip_requires C
3 | install_dir=~
4 | asteroid_root=../../../../
5 | pip_requires=../../../requirements.txt  # Expects a requirements.txt
6 | 
7 | . utils/parse_options.sh || exit 1
8 | 
9 | mkdir -p $install_dir
10 | cd $install_dir
11 | echo "Download and install latest version of miniconda3 into ${install_dir}"
12 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
13 | 
14 | bash Miniconda3-latest-Linux-x86_64.sh -b -p miniconda3
15 | pip_path=$PWD/miniconda3/bin/pip
16 | 
17 | rm Miniconda3-latest-Linux-x86_64.sh
18 | cd -
19 | 
20 | if [[ ! -z ${pip_requires} ]]; then
21 |   $pip_path install -r $pip_requires
22 | fi
23 | $pip_path install soundfile
24 | $pip_path install -e $asteroid_root
25 | #$pip_path install ${asteroid_root}/\[""evaluate""\]
26 | echo -e "\nAsteroid has been installed in editable mode. Feel free to apply your changes!"
27 | 
--------------------------------------------------------------------------------