├── .gitignore
├── README.md
├── anomaly
│   ├── baseline.py
│   ├── baseline.yaml
│   ├── baseline_mix.py
│   ├── baseline_shared.py
│   ├── baseline_src_xumx_original.py
│   ├── model.py
│   ├── requirements.txt
│   ├── run_mix.sh
│   ├── run_seed.sh
│   ├── run_sep.sh
│   ├── run_single.sh
│   ├── ssad_masked_ae.py
│   └── utils.py
├── asteroid
│   ├── __init__.py
│   ├── binarize.py
│   ├── complex_nn.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── mimii_dataset.py
│   │   ├── mimii_single_dataset.py
│   │   ├── mimii_slider_dataset.py
│   │   ├── mimii_valve_dataset.py
│   │   └── musdb18_dataset.py
│   ├── dsp
│   │   ├── __init__.py
│   │   ├── beamforming.py
│   │   ├── consistency.py
│   │   ├── deltas.py
│   │   ├── normalization.py
│   │   ├── overlap_add.py
│   │   ├── spatial.py
│   │   └── vad.py
│   ├── engine
│   │   ├── __init__.py
│   │   ├── optimizers.py
│   │   ├── schedulers.py
│   │   └── system.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── bark_matrix_16k.mat
│   │   ├── bark_matrix_8k.mat
│   │   ├── cluster.py
│   │   ├── mixit_wrapper.py
│   │   ├── mse.py
│   │   ├── multi_scale_spectral.py
│   │   ├── pit_wrapper.py
│   │   ├── pmsqe.py
│   │   ├── sdr.py
│   │   ├── sinkpit_wrapper.py
│   │   ├── soft_f1.py
│   │   └── stoi.py
│   ├── masknn
│   │   ├── __init__.py
│   │   ├── _dccrn_architectures.py
│   │   ├── _dcunet_architectures.py
│   │   ├── _local.py
│   │   ├── activations.py
│   │   ├── attention.py
│   │   ├── base.py
│   │   ├── convolutional.py
│   │   ├── norms.py
│   │   ├── recurrent.py
│   │   └── tac.py
│   ├── metrics.py
│   ├── models
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base_models.py
│   │   ├── publisher.py
│   │   ├── x_umx.py
│   │   └── x_umx_control.py
│   ├── scripts
│   │   ├── __init__.py
│   │   ├── asteroid_cli.py
│   │   └── asteroid_versions.py
│   ├── separate.py
│   └── utils
│       ├── __init__.py
│       ├── deprecation_utils.py
│       ├── generic_utils.py
│       ├── hub_utils.py
│       ├── parser_utils.py
│       ├── test_utils.py
│       └── torch_utils.py
├── environment.yml
├── informed-X-UMX
│   ├── local
│   │   ├── conf_base.yml
│   │   ├── conf_informed.yml
│   │   ├── conf_pit.yml
│   │   └── dataloader.py
│   ├── loss.py
│   ├── requirements.txt
│   └── train.py
├── requirements.txt
├── requirements
│   ├── dev.txt
│   └── install.txt
└── setup.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Project specific
2 | uploader_info.yml
3 | egs/**/exp
4 | # Where we'll save pretrained models git folders.
5 | pretrained_models
6 | 
7 | **/wandb
8 | 
9 | anomaly/model/**
10 | anomaly/pickle/**
11 | anomaly/result/**
12 | anomaly/result**
13 | 
14 | .vscode*
15 | 
16 | *.pkl
17 | 
18 | # Byte-compiled / optimized / DLL files
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 | 
23 | # Distribution / packaging
24 | .Python
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 | MANIFEST
41 | 
42 | .idea/
43 | 
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 | 
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 | 
54 | # Jupyter Notebook
55 | .ipynb_checkpoints
56 | 
57 | # pyenv
58 | .python-version
59 | 
60 | # Environments
61 | .env
62 | .venv
63 | env/
64 | venv/
65 | ENV/
66 | env.bak/
67 | venv.bak/
68 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | # SSAD: source separation followed by anomaly detection
3 | 
4 | This is the official repository for our paper, "Activity-informed Industrial Audio Anomaly Detection via Source Separation".
5 | 
6 | If you use this repository, please cite our paper:
7 | ```
8 | @inproceedings{kim2023activity,
9 |   title={Activity-informed Industrial Audio Anomaly Detection via Source Separation},
10 |   author={Jaechang Kim and Yunjoo Lee and Hyun Mi Cho and Dong Woo Kim and Chi Hoon Song and Jungseul Ok},
11 |   booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing},
12 |   year={2023}
13 | }
14 | ```
15 | 
16 | 
17 | # Environment Setting
18 | ```bash
19 | conda env create -n asteroid
20 | conda activate asteroid
21 | conda install python=3.7
22 | conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia
23 | 
24 | # for asteroid
25 | pip install -r requirements/dev.txt
26 | pip install -e .
27 | pip install torchmetrics==0.6.0
28 | 
29 | # for anomaly detection
30 | pip install -r anomaly/requirements.txt
31 | 
32 | ```
33 | 
34 | # Training Source Separation Models
35 | To train X-UMX, edit the configuration file according to the type of data and whether the control signal is used.
36 | ## 1. First change the configuration
37 | ```bash
38 | cd informed-X-UMX
39 | vi local/conf_???.yml
40 | ```
41 | 
42 | Edit `local/conf_base.yml` for the X-UMX baseline and `local/conf_informed.yml` for the informed source separation model.
43 | 
44 | * data:train_dir -> MIMII dataset directory
45 | * data:output -> directory where checkpoint and log files will be saved
46 | * data:machine_type -> machine types to use
47 | * data:sources -> machine ids to use
48 | * model:pretrained -> pretrained model path
49 | 
50 | ## 2. Train the model by running
51 | ```bash
52 | cd informed-X-UMX
53 | python train.py --conf local/conf_base.yml
54 | ```
55 | This runs train.py with the given configuration file.
56 | 
57 | 
58 | # Anomaly Detection Models
59 | 
60 | Edit `anomaly/baseline.yaml`
61 | 
62 | * base_directory -> MIMII dataset path
63 | 
64 | ## Oracle baseline
65 | 
66 | Edit `anomaly/baseline.py`
67 | 
68 | * Check the datapath near line 196
69 |   * dirs = sorted(glob.glob(os.path.abspath("{base}/6dB/valve/id_00".format(base=param["base_directory"]))))
70 |   * Choose which machines (types, ids) to use
71 | 
72 | ```bash
73 | cd anomaly
74 | python baseline.py
75 | ```
76 | 
77 | 
78 | ## Mixture baseline
79 | 
80 | Edit `anomaly/baseline_mix.py`
81 | 
82 | * Check the datapath near line 228
83 |   * dirs = sorted(glob.glob(os.path.abspath("{base}/6dB/valve/id_00".format(base=param["base_directory"]))))
84 |   * Choose which machines (types, ids) to use
85 | * Check machine_types near line 42
86 |   * Those machine types will be used to make a mixture
87 | 
88 | ```bash
89 | cd anomaly
90 | python baseline_mix.py
91 | ```
92 | 
93 | ## SSAD (Proposed Method)
94 | 
95 | 
96 | Edit `anomaly/baseline_src_xumx_original.py`
97 | 
98 | * Check the datapath near line 318
99 |   * dirs = sorted(glob.glob(os.path.abspath("{base}/6dB/valve/id_00".format(base=param["base_directory"]))))
100 |   * Choose which machines (types, ids) to use
101 | * Check the trained separation model path near line 363
102 | * Check the conf near line 43
103 |   * S1, S2 -> machine ids
104 |   * FILE -> AE model path (to save)
105 | 
106 | 
107 | ```bash
108 | cd anomaly
109 | python baseline_src_xumx_original.py
110 | ```
111 | 
112 | # Acknowledgement
113 | 
114 | This repository is based on
115 | 
116 | * https://github.com/asteroid-team/asteroid
117 | * https://github.com/MIMII-hitachi/mimii_baseline
118 | 
119 | 
--------------------------------------------------------------------------------
/anomaly/baseline.yaml:
-------------------------------------------------------------------------------- 1 | base_directory : /dev/shm/mimii 2 | pickle_directory: ./pickle 3 | model_directory: ./model 4 | result_directory: ./result 5 | result_file: slider_id00_02_original_seed3.yaml 6 | seed: 3 7 | 8 | feature: 9 | n_mels: 64 10 | frames : 5 11 | n_fft: 1024 12 | hop_length: 512 13 | power: 2.0 14 | 15 | fit: 16 | compile: 17 | optimizer : adam 18 | loss : mean_squared_error 19 | epochs : 1 20 | batch_size : 512 21 | shuffle : True 22 | validation_split : 0.1 23 | verbose : 1 -------------------------------------------------------------------------------- /anomaly/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from asteroid.models import XUMXControl 4 | 5 | from utils import bandwidth_to_max_bin 6 | 7 | ######################################################################## 8 | # model 9 | ######################################################################## 10 | 11 | class TorchModel(nn.Module): 12 | def __init__(self, dim_input): 13 | super(TorchModel,self).__init__() 14 | self.ff = nn.Sequential( 15 | nn.Linear(dim_input, 64), 16 | nn.ReLU(), 17 | nn.Linear(64, 64), 18 | nn.ReLU(), 19 | nn.Linear(64, 8), 20 | nn.ReLU(), 21 | nn.Linear(8, 64), 22 | nn.ReLU(), 23 | nn.Linear(64, 64), 24 | nn.ReLU(), 25 | nn.Linear(64, dim_input), 26 | ) 27 | 28 | def forward(self, x): 29 | x = self.ff(x) 30 | return x 31 | 32 | 33 | class TorchConvModel(nn.Module): 34 | def __init__(self): 35 | super(TorchConvModel,self).__init__() 36 | self.ff = nn.Sequential( 37 | nn.Conv2d(1, 4, kernel_size=5), 38 | nn.ReLU(), 39 | nn.Conv2d(4, 32, kernel_size=3), 40 | nn.ReLU(), 41 | nn.Conv2d(32, 64, kernel_size=3), 42 | nn.ReLU(), 43 | nn.ConvTranspose2d(64, 32, kernel_size=3), 44 | nn.ReLU(), 45 | nn.ConvTranspose2d(32, 4, kernel_size=3), 46 | nn.ReLU(), 47 | nn.ConvTranspose2d(4, 1, kernel_size=5), 48 | ) 49 | 50 | def forward(self, x): 51 | assert len(x.shape) == 3 52 | #[B, T, F] 53 | x = self.ff(x.unsqueeze(1)).squeeze(1) 54 | return x 55 | 56 | ######################################################################## 57 | 58 | class XUMXSystem(torch.nn.Module): 59 | def __init__(self): 60 | super().__init__() 61 | self.model = None 62 | 63 | 64 | def xumx_model(path): 65 | 66 | x_unmix = XUMXControl( 67 | window_length=4096, 68 | input_mean=None, 69 | input_scale=None, 70 | nb_channels=2, 71 | hidden_size=512, 72 | in_chan=4096, 73 | n_hop=1024, 74 | sources=['s1', 's2'], 75 | max_bin=bandwidth_to_max_bin(16000, 4096, 16000), 76 | bidirectional=True, 77 | sample_rate=16000, 78 | spec_power=1, 79 | return_time_signals=True, 80 | ) 81 | 82 | conf = torch.load(path, map_location="cpu") 83 | 84 | system = XUMXSystem() 85 | system.model = x_unmix 86 | 87 | system.load_state_dict(conf['state_dict'], strict=False) 88 | 89 | return system.model 90 | -------------------------------------------------------------------------------- /anomaly/requirements.txt: -------------------------------------------------------------------------------- 1 | # Minimal package to use MIMII dataset baseline. 
2 | matplotlib 3 | numpy 4 | PyYAML 5 | scikit-learn 6 | librosa 7 | audioread 8 | setuptools 9 | fast_bss_eval 10 | -------------------------------------------------------------------------------- /anomaly/run_mix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | SEEDS=(1 2 3 4 5 6 7 8 9 10) 5 | MACHINES=("valve" "slider") 6 | GPU=6 7 | 8 | export CUDA_VISIBLE_DEVICES=$GPU 9 | PYTHON="python" 10 | PYTHON_FILE="baseline_mix.py" 11 | 12 | for seed in ${SEEDS[@]}; do 13 | for MACHINE in ${MACHINES[@]}; do 14 | sed -i "s@^seed.*@seed: ${seed}@g" baseline.yaml 15 | 16 | RESULT_DIR="result_1022_dilate_label" 17 | mkdir -p ${RESULT_DIR} 18 | sed -i "s@^result_directory.*@result_directory: ${RESULT_DIR}@g" baseline.yaml 19 | 20 | RESULT_NAME="mixture_baseline_${MACHINE}_seed${seed}.yaml" 21 | sed -i "s@^result_file.*@result_file: ${RESULT_NAME}@g" baseline.yaml 22 | 23 | sed -i "s@^MACHINE =.*@MACHINE = '${MACHINE}'@g" ${PYTHON_FILE} 24 | sed -i "s@^S1 =.*@S1 = 'id_00'@g" ${PYTHON_FILE} 25 | sed -i "s@^S2 =.*@S2 = 'id_02'@g" ${PYTHON_FILE} 26 | 27 | $PYTHON ${PYTHON_FILE} 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /anomaly/run_seed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SEEDS=(1 2 3 4 5 6 7 8 9 10) 4 | 5 | for seed in ${SEEDS[@]}; do 6 | RESULT_NAME="slider_id00_02_original_seed${seed}.yaml" 7 | sed -i "s/^seed.*/seed: ${seed}/g" baseline.yaml 8 | sed -i "s/^result_file.*/result_file: ${RESULT_NAME}/g" baseline.yaml 9 | python baseline_src_xumx_original.py 10 | done 11 | -------------------------------------------------------------------------------- /anomaly/run_sep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | SEEDS=(1 2 3 4 5 6 7 8 9 10) 5 | MACHINES=("valve" "slider") 6 | GPU=5 7 | 8 | export CUDA_VISIBLE_DEVICES=$GPU 9 | PYTHON="python" 10 | PYTHON_FILE="baseline_src_xumx_original.py" 11 | 12 | for seed in ${SEEDS[@]}; do 13 | for MACHINE in ${MACHINES[@]}; do 14 | sed -i "s@^seed.*@seed: ${seed}@g" baseline.yaml 15 | 16 | RESULT_DIR="result_1022_dilate_label" 17 | mkdir -p ${RESULT_DIR} 18 | sed -i "s@^result_directory.*@result_directory: ${RESULT_DIR}@g" baseline.yaml 19 | 20 | RESULT_NAME="ssad_${MACHINE}_seed${seed}.yaml" 21 | sed -i "s@^result_file.*@result_file: ${RESULT_NAME}@g" baseline.yaml 22 | 23 | sed -i "s@^MACHINE =.*@MACHINE = '${MACHINE}'@g" ${PYTHON_FILE} 24 | sed -i "s@^S1 =.*@S1 = 'id_00'@g" ${PYTHON_FILE} 25 | sed -i "s@^S2 =.*@S2 = 'id_02'@g" ${PYTHON_FILE} 26 | 27 | $PYTHON ${PYTHON_FILE} 28 | done 29 | done 30 | -------------------------------------------------------------------------------- /anomaly/run_single.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | SEEDS=(1 2 3 4 5 6 7 8 9 10) 5 | MACHINES=("valve" "slider") 6 | GPU=7 7 | 8 | export CUDA_VISIBLE_DEVICES=$GPU 9 | PYTHON="python" 10 | PYTHON_FILE="baseline.py" 11 | 12 | for seed in ${SEEDS[@]}; do 13 | for MACHINE in ${MACHINES[@]}; do 14 | sed -i "s@^seed.*@seed: ${seed}@g" baseline.yaml 15 | 16 | RESULT_DIR="result_1022_dilate_label" 17 | mkdir -p ${RESULT_DIR} 18 | sed -i "s@^result_directory.*@result_directory: ${RESULT_DIR}@g" baseline.yaml 19 | 20 | RESULT_NAME="oracle_baseline_${MACHINE}_seed${seed}.yaml" 21 | sed -i "s@^result_file.*@result_file: 
${RESULT_NAME}@g" baseline.yaml 22 | 23 | sed -i "s@^MACHINE =.*@MACHINE = '${MACHINE}'@g" ${PYTHON_FILE} 24 | 25 | ${PYTHON} ${PYTHON_FILE} 26 | done 27 | done 28 | -------------------------------------------------------------------------------- /asteroid/__init__.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | #from .models import ConvTasNet, DCCRNet, DCUNet, DPRNNTasNet, DPTNet, LSTMTasNet, DeMask 4 | from .utils import deprecation_utils, torch_utils # noqa 5 | 6 | project_root = str(pathlib.Path(__file__).expanduser().absolute().parent.parent) 7 | __version__ = "0.6.0dev" 8 | 9 | 10 | def show_available_models(): 11 | from .utils.hub_utils import MODELS_URLS_HASHTABLE 12 | 13 | print(" \n".join(list(MODELS_URLS_HASHTABLE.keys()))) 14 | 15 | 16 | def available_models(): 17 | from .utils.hub_utils import MODELS_URLS_HASHTABLE 18 | 19 | return MODELS_URLS_HASHTABLE 20 | 21 | 22 | __all__ = [ 23 | "show_available_models", 24 | ] 25 | -------------------------------------------------------------------------------- /asteroid/binarize.py: -------------------------------------------------------------------------------- 1 | from itertools import groupby 2 | import torch 3 | 4 | 5 | class Binarize(torch.nn.Module): 6 | """This module transform a sequence of real numbers between 0 and 1 to a sequence of 0 or 1. 7 | The logic for transformation is based on thresholding and avoids jumping from 0 to 1 inadvertently. 8 | 9 | Example: 10 | 11 | >>> binarizer = Binarize(threshold=0.5, stability=3, sample_rate=1) 12 | >>> inputs=torch.Tensor([0.1, 0.6, 0.2, 0.6, 0.1, 0.1, 0.1, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.1]) 13 | >>> # |------------------|-------------|----------------------------|----| 14 | >>> # unstable stable stable irregularity 15 | >>> result = binarizer(inputs.unsqueeze(0).unsqueeze(0)) 16 | >>> print(result) 17 | tensor([[[0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1.]]]) 18 | 19 | """ 20 | 21 | def __init__(self, threshold=0.5, stability=0.1, sample_rate=8000): 22 | """ 23 | 24 | Args: 25 | threshold (float): if x > threshold 0 else 1 26 | stability (float): Minimum number of seconds of 0 (or 1) required to change from 1 (or 0) to 0 (or 1) 27 | sample_rate (int): The sample rate of the wave form 28 | """ 29 | super().__init__() 30 | self.threshold = threshold 31 | self.stability = stability 32 | self.sample_rate = sample_rate 33 | 34 | def forward(self, x): 35 | active = x > self.threshold 36 | active = active.squeeze(1).tolist() 37 | pairs = count_same_pair(active) 38 | active = transform_to_binary_sequence(pairs, self.stability, self.sample_rate) 39 | return active 40 | 41 | 42 | def count_same_pair(nums): 43 | """Transform a list of 0 and 1 in a list of (value, num_consecutive_occurences). 44 | 45 | Args: 46 | nums (list): List of list containing the binary sequences. 47 | 48 | Returns: 49 | List of values and number consecutive occurences. 
50 | 51 | Example: 52 | >>> nums = [[0,0,1,0]] 53 | >>> result = count_same_pair(nums) 54 | >>> print(result) 55 | [[[0, 2], [1, 1], [0, 1]]] 56 | 57 | """ 58 | result = [] 59 | for num in nums: 60 | result.append([[i, sum(1 for _ in group)] for i, group in groupby(num)]) 61 | return result 62 | 63 | 64 | def transform_to_binary_sequence(pairs, stability, sample_rate): 65 | """Transforms list of value and consecutive occurrences into a binary sequence with respect to stability 66 | 67 | Args: 68 | pairs (List): List of list of value and consecutive occurrences 69 | stability (Float): Minimal number of seconds to change from 0 to 1 or 1 to 0. 70 | sample_rate (int): The sample rate of the waveform. 71 | 72 | Returns: 73 | Torch.tensor : The binary sequences. 74 | """ 75 | 76 | batch_active = [] 77 | for pair in pairs: 78 | active = [] 79 | # Check for fully silent or fully voice sequence 80 | active, check = check_silence_or_voice(active, pair) 81 | if check: 82 | return active 83 | # Counter for every set of (value, num_consecutive_occ) 84 | i = 0 85 | # Do until every every sets as been used i.e until we have same sequence length as input length 86 | while i < len(pair): 87 | value, num_consecutive_occurrences = pair[i] 88 | # Counter for active set of (value, num_consecutive_occ) i.e (1,num_consecutive_occ) 89 | actived = 0 90 | # Counter for inactive set of (value, num_consecutive_occ) i.e (0,num_consecutive_occ) 91 | not_actived = 0 92 | # num_consecutive_occ < int(stability * sample_rate) need to resolve instability 93 | if num_consecutive_occurrences < int(stability * sample_rate): 94 | # Resolve instability 95 | active, i = resolve_instability( 96 | i, pair, stability, sample_rate, actived, not_actived, active 97 | ) 98 | # num_consecutive_occ > int(stability * sample_rate) we can already choose 99 | else: 100 | if value: 101 | active.append(torch.ones(pair[i][1])) 102 | else: 103 | active.append(torch.zeros(pair[i][1])) 104 | i += 1 105 | # Stack sequence to return a batch shaped tensor 106 | batch_active.append(torch.hstack(active)) 107 | batch_active = torch.vstack(batch_active).unsqueeze(1) 108 | return batch_active 109 | 110 | 111 | def check_silence_or_voice(active, pair): 112 | """Check if sequence is fully silence or fully voice. 113 | 114 | Args: 115 | active (List) : List containing the binary sequence 116 | pair: (List): list of value and consecutive occurrences 117 | 118 | """ 119 | value, num_consecutive_occurrences = pair[0] 120 | check = False 121 | if len(pair) == 1: 122 | check = True 123 | if value: 124 | active = torch.ones(num_consecutive_occurrences) 125 | else: 126 | active = torch.zeros(num_consecutive_occurrences) 127 | return active, check 128 | 129 | 130 | def resolve_instability(i, pair, stability, sample_rate, actived, not_actived, active): 131 | """Resolve stability issue in input list of value and num_consecutive_occ 132 | 133 | Args: 134 | i (int): The index of the considered pair of value and num_consecutive_occ. 135 | pair (list) : Value and num_consecutive_occ. 136 | stability (float): Minimal number of seconds to change from 0 to 1 or 1 to 0. 137 | sample_rate (int): The sample rate of the waveform. 138 | actived (int) : Number of occurrences of the value 1. 139 | not_actived (int): Number of occurrences of the value 0. 140 | active (list) : The binary sequence. 141 | 142 | Returns: 143 | active (list) : The binary sequence. 144 | i (int): The index of the considered pair of value and num_consecutive_occ. 
145 | """ 146 | # Until we find stability count the number of samples active and inactive 147 | while i < len(pair) and pair[i][1] < int(stability * sample_rate): 148 | value, num_consecutive_occurrences = pair[i] 149 | if value: 150 | actived += num_consecutive_occurrences 151 | i += 1 152 | else: 153 | not_actived += num_consecutive_occurrences 154 | i += 1 155 | # If the length of unstable samples is smaller than the stability criteria and we are already in a state 156 | # then keep this state. 157 | if actived + not_actived < int(stability * sample_rate) and len(active) > 0: 158 | # Last value 159 | if active[-1][0] == 1: 160 | active.append(torch.ones(actived + not_actived)) 161 | else: 162 | active.append(torch.zeros(actived + not_actived)) 163 | # If the length of unstable samples is smaller than the stability criteria and but we have no state yet 164 | # then consider it silent 165 | elif actived + not_actived < int(stability * sample_rate) and len(active) == 0: 166 | active.append(torch.zeros(actived + not_actived)) 167 | # If the length of unstable samples is greater than the stability criteria then compare number of active 168 | # and inactive samples and choose. 169 | else: 170 | if actived > not_actived: 171 | active.append(torch.ones(actived + not_actived)) 172 | else: 173 | active.append(torch.zeros(actived + not_actived)) 174 | 175 | return active, i 176 | -------------------------------------------------------------------------------- /asteroid/complex_nn.py: -------------------------------------------------------------------------------- 1 | """Complex building blocks that work with PyTorch native (!) complex tensors, i.e. 2 | dtypes complex64/complex128, or tensors for which `.is_complex()` returns True. 3 | 4 | Note that Asteroid code has two other representations of complex numbers: 5 | 6 | - Torchaudio representation [..., 2] where [..., 0] and [..., 1] are real and 7 | imaginary components, respectively 8 | - Asteroid style representation, identical to the Torchaudio representation, but 9 | with the last dimension concatenated: tensor([r1, r2, ..., rn, i1, i2, ..., in]). 10 | The concatenated (2 * n) dimension may be at an arbitrary position, i.e. the tensor 11 | is of shape [..., 2 * n, ...]. See `asteroid_filterbanks.transforms` for details. 12 | """ 13 | import functools 14 | import torch 15 | from asteroid_filterbanks import transforms 16 | from torch import nn 17 | 18 | 19 | # Alias to denote PyTorch native complex tensor (complex64/complex128). 20 | # `.is_complex()` returns True on these tensors. 21 | ComplexTensor = torch.Tensor 22 | 23 | 24 | def is_torch_complex(x): 25 | return x.is_complex() 26 | 27 | 28 | def torch_complex_from_magphase(mag, phase): 29 | return torch.view_as_complex( 30 | torch.stack((mag * torch.cos(phase), mag * torch.sin(phase)), dim=-1) 31 | ) 32 | 33 | 34 | def torch_complex_from_reim(re, im): 35 | return torch.view_as_complex(torch.stack([re, im], dim=-1)) 36 | 37 | 38 | def on_reim(f): 39 | """Make a complex-valued function callable from a real-valued one by applying it to 40 | the real and imaginary components independently. 41 | 42 | Return: 43 | cf(x), complex version of `f`: A function that applies `f` to the real and 44 | imaginary components of `x` and returns the result as PyTorch complex tensor. 
45 | """ 46 | 47 | @functools.wraps(f) 48 | def cf(x): 49 | return torch_complex_from_reim(f(x.real), f(x.imag)) 50 | 51 | # functools.wraps keeps the original name of `f`, which might be confusing, 52 | # since we are creating a new function that behaves differently. 53 | # Both __name__ and __qualname__ are used by printing code. 54 | cf.__name__ == f"{f.__name__} (complex)" 55 | cf.__qualname__ == f"{f.__qualname__} (complex)" 56 | return cf 57 | 58 | 59 | class OnReIm(nn.Module): 60 | """Like `on_reim`, but for stateful modules. 61 | 62 | Args: 63 | module_cls (callable): A class or function that returns a Torch module/functional. 64 | Called 2x with *args, **kwargs, to construct the real and imaginary component modules. 65 | """ 66 | 67 | def __init__(self, module_cls, *args, **kwargs): 68 | super().__init__() 69 | self.re_module = module_cls(*args, **kwargs) 70 | self.im_module = module_cls(*args, **kwargs) 71 | 72 | def forward(self, x): 73 | return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag)) 74 | 75 | 76 | class ComplexMultiplicationWrapper(nn.Module): 77 | """Make a complex-valued module `F` from a real-valued module `f` by applying 78 | complex multiplication rules: 79 | 80 | F(a + i b) = f1(a) - f1(b) + i (f2(b) + f2(a)) 81 | 82 | where `f1`, `f2` are instances of `f` that do *not* share weights. 83 | 84 | Args: 85 | module_cls (callable): A class or function that returns a Torch module/functional. 86 | Constructor of `f` in the formula above. Called 2x with `*args`, `**kwargs`, 87 | to construct the real and imaginary component modules. 88 | """ 89 | 90 | def __init__(self, module_cls, *args, **kwargs): 91 | super().__init__() 92 | self.re_module = module_cls(*args, **kwargs) 93 | self.im_module = module_cls(*args, **kwargs) 94 | 95 | def forward(self, x: ComplexTensor) -> ComplexTensor: 96 | return torch_complex_from_reim( 97 | self.re_module(x.real) - self.im_module(x.imag), 98 | self.re_module(x.imag) + self.im_module(x.real), 99 | ) 100 | 101 | 102 | class ComplexSingleRNN(nn.Module): 103 | """Module for a complex RNN block. 104 | 105 | This is similar to :cls:`asteroid.masknn.recurrent.SingleRNN` but uses complex 106 | multiplication as described in [1]. Arguments are identical to those of `SingleRNN`, 107 | except for `dropout`, which is not yet supported. 108 | 109 | Args: 110 | rnn_type (str): Select from ``'RNN'``, ``'LSTM'``, ``'GRU'``. Can 111 | also be passed in lowercase letters. 112 | input_size (int): Dimension of the input feature. The input should have 113 | shape [batch, seq_len, input_size]. 114 | hidden_size (int): Dimension of the hidden state. 115 | n_layers (int, optional): Number of layers used in RNN. Default is 1. 116 | bidirectional (bool, optional): Whether the RNN layers are 117 | bidirectional. Default is ``False``. 118 | dropout: Not yet supported. 119 | 120 | References 121 | [1] : "DCCRN: Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement", 122 | Yanxin Hu et al. 
https://arxiv.org/abs/2008.00264 123 | """ 124 | 125 | def __init__( 126 | self, rnn_type, input_size, hidden_size, n_layers=1, dropout=0, bidirectional=False 127 | ): 128 | assert not (dropout and n_layers > 1), "Dropout is not yet supported for complex RNN" 129 | super().__init__() 130 | from .masknn.recurrent import SingleRNN # Avoid circual import 131 | 132 | kwargs = { 133 | "rnn_type": rnn_type, 134 | "hidden_size": hidden_size, 135 | "n_layers": 1, 136 | "dropout": 0, 137 | "bidirectional": bidirectional, 138 | } 139 | first_rnn = ComplexMultiplicationWrapper(SingleRNN, input_size=input_size, **kwargs) 140 | self.rnns = torch.nn.ModuleList([first_rnn]) 141 | for _ in range(n_layers - 1): 142 | self.rnns.append( 143 | ComplexMultiplicationWrapper( 144 | SingleRNN, input_size=first_rnn.re_module.output_size, **kwargs 145 | ) 146 | ) 147 | 148 | @property 149 | def output_size(self): 150 | return self.rnns[-1].re_module.output_size 151 | 152 | def forward(self, x: ComplexTensor) -> ComplexTensor: 153 | """Input shape [batch, seq, feats]""" 154 | for rnn in self.rnns: 155 | x = rnn(x) 156 | return x 157 | 158 | 159 | ComplexConv2d = functools.partial(ComplexMultiplicationWrapper, nn.Conv2d) 160 | ComplexConvTranspose2d = functools.partial(ComplexMultiplicationWrapper, nn.ConvTranspose2d) 161 | 162 | 163 | class BoundComplexMask(nn.Module): 164 | """Module version of `bound_complex_mask`""" 165 | 166 | def __init__(self, bound_type): 167 | super().__init__() 168 | self.bound_type = bound_type 169 | 170 | def forward(self, mask: ComplexTensor): 171 | return bound_complex_mask(mask, self.bound_type) 172 | 173 | 174 | def bound_complex_mask(mask: ComplexTensor, bound_type="tanh"): 175 | r"""Bound a complex mask, as proposed in [1], section 3.2. 176 | 177 | Valid bound types, for a complex mask :math:`M = |M| ⋅ e^{i φ(M)}`: 178 | 179 | - Unbounded ("UBD"): :math:`M_{\mathrm{UBD}} = M` 180 | - Sigmoid ("BDSS"): :math:`M_{\mathrm{BDSS}} = σ(|M|) e^{i σ(φ(M))}` 181 | - Tanh ("BDT"): :math:`M_{\mathrm{BDT}} = \mathrm{tanh}(|M|) e^{i φ(M)}` 182 | 183 | Args: 184 | bound_type (str or None): The type of bound to use, either of 185 | "tanh"/"bdt" (default), "sigmoid"/"bdss" or None/"bdt". 186 | 187 | References 188 | [1] : "Phase-aware Speech Enhancement with Deep Complex U-Net", 189 | Hyeong-Seok Choi et al. 
https://arxiv.org/abs/1903.03107 190 | """ 191 | if bound_type in {"BDSS", "sigmoid"}: 192 | return on_reim(torch.sigmoid)(mask) 193 | elif bound_type in {"BDT", "tanh", "UBD", None}: 194 | mask_mag, mask_phase = transforms.magphase(transforms.from_torch_complex(mask)) 195 | if bound_type in {"BDT", "tanh"}: 196 | mask_mag_bounded = torch.tanh(mask_mag) 197 | else: 198 | mask_mag_bounded = mask_mag 199 | return torch_complex_from_magphase(mask_mag_bounded, mask_phase) 200 | else: 201 | raise ValueError(f"Unknown mask bound {bound_type}") 202 | -------------------------------------------------------------------------------- /asteroid/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .mimii_dataset import MIMIIDataset 2 | from .mimii_single_dataset import MIMIISingleDataset 3 | from .mimii_valve_dataset import MIMIIValveDataset 4 | from .mimii_slider_dataset import MIMIISliderDataset 5 | __all__ = [ 6 | "MIMIIDataset", 7 | "MIMIISingleDataset", 8 | "MIMIIValveDataset", 9 | "MIMIISliderDataset", 10 | ] 11 | -------------------------------------------------------------------------------- /asteroid/data/mimii_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch.utils.data 3 | import random 4 | import torch 5 | import tqdm 6 | import soundfile as sf 7 | 8 | from torchaudio import transforms 9 | import librosa 10 | 11 | class MIMIIDataset(torch.utils.data.Dataset): 12 | """MUSDB18 music separation dataset 13 | 14 | Folder Structure: 15 | >>> #0dB/fan/id_00/normal/00000000.wav ---------| 16 | >>> #0dB/fan/id_02/normal/00000000.wav ---------| 17 | >>> #0dB/pump/id_00/normal/00000000.wav ---------| 18 | 19 | Args: 20 | root (str): Root path of dataset 21 | sources (:obj:`list` of :obj:`str`, optional): List of source names 22 | that composes the mixture. 23 | Defaults to MUSDB18 4 stem scenario: `vocals`, `drums`, `bass`, `other`. 24 | targets (list or None, optional): List of source names to be used as 25 | targets. If None, a dict with the 4 stems is returned. 26 | If e.g [`vocals`, `drums`], a tensor with stacked `vocals` and 27 | `drums` is returned instead of a dict. Defaults to None. 28 | suffix (str, optional): Filename suffix, defaults to `.wav`. 29 | split (str, optional): Dataset subfolder, defaults to `train`. 30 | subset (:obj:`list` of :obj:`str`, optional): Selects a specific of 31 | list of tracks to be loaded, defaults to `None` (loads all tracks). 32 | segment (float, optional): Duration of segments in seconds, 33 | defaults to ``None`` which loads the full-length audio tracks. 34 | samples_per_track (int, optional): 35 | Number of samples yielded from each track, can be used to increase 36 | dataset size, defaults to `1`. 37 | random_segments (boolean, optional): Enables random offset for track segments. 38 | random_track_mix boolean: enables mixing of random sources from 39 | different tracks to assemble mix. 40 | source_augmentations (:obj:`list` of :obj:`callable`): list of augmentation 41 | function names, defaults to no-op augmentations (input = output) 42 | sample_rate (int, optional): Samplerate of files in dataset. 43 | 44 | Attributes: 45 | root (str): Root path of dataset 46 | sources (:obj:`list` of :obj:`str`, optional): List of source names. 47 | Defaults to MUSDB18 4 stem scenario: `vocals`, `drums`, `bass`, `other`. 48 | suffix (str, optional): Filename suffix, defaults to `.wav`. 
49 | split (str, optional): Dataset subfolder, defaults to `train`. 50 | subset (:obj:`list` of :obj:`str`, optional): Selects a specific of 51 | list of tracks to be loaded, defaults to `None` (loads all tracks). 52 | segment (float, optional): Duration of segments in seconds, 53 | defaults to ``None`` which loads the full-length audio tracks. 54 | samples_per_track (int, optional): 55 | Number of samples yielded from each track, can be used to increase 56 | dataset size, defaults to `1`. 57 | random_segments (boolean, optional): Enables random offset for track segments. 58 | random_track_mix boolean: enables mixing of random sources from 59 | different tracks to assemble mix. 60 | source_augmentations (:obj:`list` of :obj:`callable`): list of augmentation 61 | function names, defaults to no-op augmentations (input = output) 62 | sample_rate (int, optional): Samplerate of files in dataset. 63 | tracks (:obj:`list` of :obj:`Dict`): List of track metadata 64 | 65 | References 66 | "The 2018 Signal Separation Evaluation Campaign" Stoter et al. 2018. 67 | """ 68 | 69 | dataset_name = "MIMII" 70 | 71 | def __init__( 72 | self, 73 | root, 74 | sources=["fan", "pump", "slider", "valve"], 75 | targets=None, 76 | suffix=".wav", 77 | split="0dB", 78 | subset=None, 79 | segment=None, 80 | samples_per_track=2, 81 | random_segments=False, 82 | random_track_mix=False, 83 | source_augmentations=lambda audio: audio, 84 | sample_rate=16000, 85 | normal=True, 86 | use_control=False, 87 | ): 88 | 89 | self.root = Path(root).expanduser() 90 | self.split = split 91 | self.sample_rate = sample_rate 92 | self.segment = segment 93 | self.random_track_mix = random_track_mix 94 | self.random_segments = random_segments 95 | self.source_augmentations = source_augmentations 96 | self.sources = sources 97 | self.targets = targets 98 | self.suffix = suffix 99 | self.subset = subset 100 | self.samples_per_track = samples_per_track 101 | self.normal = normal 102 | self.tracks = list(self.get_tracks()) 103 | if not self.tracks: 104 | raise RuntimeError("No tracks found.") 105 | self.use_control = use_control 106 | self.normal = True 107 | 108 | def __getitem__(self, index): 109 | 110 | audio_sources = {} 111 | active_label_sources = {} 112 | 113 | # get track_id 114 | track_id = index // self.samples_per_track 115 | if self.random_segments: 116 | start = random.uniform(0, self.tracks[track_id]["min_duration"] - self.segment) 117 | else: 118 | start = 0 119 | 120 | # load sources 121 | for i, source in enumerate(self.sources): 122 | # optionally select a random track for each source 123 | if self.random_track_mix: 124 | # load a different track 125 | track_id = random.choice(range(len(self.tracks))) 126 | if self.random_segments: 127 | start = random.uniform(0, self.tracks[track_id]["min_duration"] - self.segment) 128 | 129 | # loads the full track duration 130 | start_sample = int(start * self.sample_rate) 131 | # check if dur is none 132 | if self.segment: 133 | # stop in soundfile is calc in samples, not seconds 134 | stop_sample = start_sample + int(self.segment * self.sample_rate) 135 | else: 136 | # set to None for reading complete file 137 | stop_sample = None 138 | 139 | # load actual audio 140 | audio, _ = sf.read( 141 | Path(self.tracks[track_id]["source_paths"][i]), 142 | always_2d=True, 143 | start=start_sample, 144 | stop=stop_sample, 145 | ) 146 | # convert to torch tensor 147 | audio = torch.tensor(audio.T, dtype=torch.float)[:, :] 148 | # apply source-wise augmentations 149 | audio = 
self.source_augmentations(audio) 150 | 151 | #apply mask 152 | audio_len = audio.shape[1] 153 | mask_len = random.randrange(audio_len//2) 154 | start_point = random.randrange(0, audio_len - mask_len) 155 | torch.clamp_(audio[:, start_point:start_point + mask_len], min=-0.001, max=0.001) 156 | audio_sources[source] = audio 157 | 158 | if self.use_control: 159 | active_label = torch.ones_like(audio) 160 | active_label[:, start_point:start_point + mask_len] = 0 161 | active_label_sources[source] = active_label 162 | # [channel, time] 163 | 164 | # #make mean label 165 | # dur_chunk = audio_len//1000 166 | # audio_threashold = 0.01 167 | # active_label = torch.empty((2, dur_chunk)) 168 | # num_chunk = audio_len//dur_chunk 169 | # if audio_len % dur_chunk != 0: 170 | # num_chunk = num_chunk + 1 171 | 172 | # for i in range(num_chunk): 173 | # chunk = audio[:, i*dur_chunk:(i+1)*dur_chunk] 174 | # chunk_label = torch.empty_like(chunk[:2, :]) 175 | # for j in range(2): 176 | # if torch.mean(torch.abs(chunk[j, :])) < audio_threashold: 177 | # chunk_label[j, :] = torch.zeros_like(chunk[j, :]) 178 | # else: 179 | # chunk_label[j, :] = torch.ones_like(chunk[j, :]) 180 | # active_label = torch.cat((active_label, chunk_label), dim = 1) 181 | # active_label_sources[source] = active_label[:, dur_chunk:] 182 | 183 | 184 | 185 | # apply linear mix over source index=0 186 | # make mixture for i-th channel and use 0-th chnnel as gt 187 | audioes = torch.stack([audio_sources[src] for src in self.targets]) 188 | audio_mix = torch.stack([audioes[i, 2 * i : 2 * i + 2, :] for i in range(len(self.sources))]).sum(0) 189 | 190 | if self.targets: 191 | audio_sources = audioes[:, 0:2, :] 192 | for i in range(len(self.sources)): 193 | audio_sources[i, :, :] = audioes[i, 2 * i : 2 * i + 2, :] 194 | 195 | if self.use_control: 196 | active_labels = torch.stack([active_label_sources[src] for src in self.targets]) 197 | # [source, channel, time] 198 | if self.targets: 199 | active_labels = active_labels[:, 0:2, :] 200 | 201 | return audio_mix, audio_sources, active_labels 202 | 203 | 204 | return audio_mix, audio_sources 205 | 206 | def __len__(self): 207 | return len(self.tracks) * self.samples_per_track 208 | 209 | def get_tracks(self): 210 | """Loads input and output tracks""" 211 | ids = ["id_00", "id_02", "id_04"] 212 | p = Path(self.root, self.split) 213 | pp = [] 214 | for id in ids: 215 | pp.extend(p.glob(f'fan/{id}/{"normal" if self.normal else "abnormal"}/*.wav')) 216 | 217 | for track_path in tqdm.tqdm(pp): 218 | # print(track_path) 219 | if self.subset and track_path.stem not in self.subset: 220 | # skip this track 221 | continue 222 | 223 | source_paths = [Path(str(track_path).replace(self.sources[0], s)) for s in self.sources] 224 | if not all(sp.exists() for sp in source_paths): 225 | print("Exclude track due to non-existing source", track_path) 226 | continue 227 | 228 | # get metadata 229 | infos = list(map(sf.info, source_paths)) 230 | if not all(i.samplerate == self.sample_rate for i in infos): 231 | print("Exclude track due to different sample rate ", track_path) 232 | continue 233 | 234 | if self.segment is not None: 235 | # get minimum duration of track 236 | min_duration = min(i.duration for i in infos) 237 | if min_duration > self.segment: 238 | yield ({"path": track_path, "min_duration": min_duration, "source_paths": source_paths}) 239 | else: 240 | yield ({"path": track_path, "min_duration": None, "source_paths": source_paths}) 241 | 
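
A minimal usage sketch of `MIMIIDataset` (editor's note: the root path and arguments are illustrative, assuming the 8-channel MIMII recordings and the folder layout shown in the class docstring; `use_control=True` additionally returns per-source activity labels, and `targets` must be set for the stacked output):

```python
from asteroid.data import MIMIIDataset

dataset = MIMIIDataset(
    root="/dev/shm/mimii",            # hypothetical dataset root
    split="0dB",
    sources=["fan", "pump", "slider", "valve"],
    targets=["fan", "pump", "slider", "valve"],
    use_control=True,
)
mix, sources, labels = dataset[0]
# mix: [2, time] stereo mixture; sources: [4, 2, time]; labels: [4, 2, time]
```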
-------------------------------------------------------------------------------- /asteroid/data/mimii_single_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch.utils.data 3 | import random 4 | import torch 5 | import tqdm 6 | import soundfile as sf 7 | 8 | from torchaudio import transforms 9 | import librosa 10 | from itertools import product 11 | import numpy as np 12 | 13 | class MIMIISingleDataset(torch.utils.data.Dataset): 14 | 15 | dataset_name = "MIMII" 16 | 17 | def __init__( 18 | self, 19 | root, 20 | sources=["id_00", "id_02"], 21 | targets=None, 22 | suffix=".wav", 23 | split="0dB", 24 | subset=None, 25 | segment=None, 26 | samples_per_track=2, 27 | random_segments=False, 28 | random_track_mix=False, 29 | source_augmentations=lambda audio: audio, 30 | sample_rate=16000, 31 | normal=True, 32 | use_control=False, 33 | task_random = True, 34 | source_random = False, 35 | machine_type = None, 36 | control_type = 'rms', 37 | ): 38 | 39 | self.root = Path(root).expanduser() 40 | self.split = split 41 | self.sample_rate = sample_rate 42 | self.segment = segment 43 | self.random_track_mix = random_track_mix 44 | self.random_segments = random_segments 45 | self.source_augmentations = source_augmentations 46 | self.sources = sources 47 | self.targets = targets 48 | self.suffix = suffix 49 | self.subset = subset 50 | self.samples_per_track = samples_per_track 51 | self.normal = normal 52 | self.tracks = list(self.get_tracks()) 53 | if not self.tracks: 54 | raise RuntimeError("No tracks found.") 55 | self.use_control = use_control 56 | self.normal = True 57 | self.task_random = task_random 58 | self.source_random = source_random 59 | self.machine_type = machine_type 60 | self.control_type = control_type 61 | 62 | def __getitem__(self, index): 63 | 64 | audio_sources = {} 65 | active_label_sources = {} 66 | 67 | # get track_id 68 | track_id = index // self.samples_per_track 69 | if self.random_segments: 70 | start = random.uniform(0, self.tracks[track_id]["min_duration"] - self.segment) 71 | else: 72 | start = 0 73 | 74 | # load sources 75 | for i, source in enumerate(self.sources): 76 | # optionally select a random track for each source 77 | if self.random_track_mix: 78 | # load a different track 79 | track_id = random.choice(range(len(self.tracks))) 80 | if self.random_segments: 81 | start = random.uniform(0, self.tracks[track_id]["min_duration"] - self.segment) 82 | 83 | # loads the full track duration 84 | start_sample = int(start * self.sample_rate) 85 | # check if dur is none 86 | if self.segment: 87 | # stop in soundfile is calc in samples, not seconds 88 | stop_sample = start_sample + int(self.segment * self.sample_rate) 89 | else: 90 | # set to None for reading complete file 91 | stop_sample = None 92 | 93 | # load actual audio 94 | np_audio, _ = sf.read( 95 | Path(self.tracks[track_id]["source_paths"][i]), 96 | always_2d=True, 97 | start=start_sample, 98 | stop=stop_sample, 99 | ) 100 | # convert to torch tensor 101 | audio = torch.tensor(np_audio.T, dtype=torch.float)[:, :] 102 | # apply source-wise augmentations 103 | audio = self.source_augmentations(audio) 104 | 105 | # apply mask 106 | audio_len = audio.shape[1] 107 | mask_len = random.randrange(int(audio_len * 0.8)) 108 | if i == 0: 109 | start_point = 0 110 | else: 111 | start_point = audio_len - mask_len 112 | torch.clamp_(audio[:, start_point:start_point + mask_len], min=-0.01, max=0.01) 113 | audio_sources[source] = audio 114 | 115 | # 
make control signal 116 | if self.use_control: 117 | if self.control_type =='mfcc': 118 | mfcc = transforms.MFCC(log_mels=True, melkwargs={"n_mels":64}) 119 | n_mfcc = 8 120 | features = mfcc(audio)[0, :n_mfcc, :] 121 | features = features.unsqueeze(2).expand(-1, -1, 200).reshape(n_mfcc, -1)[:, :160000] 122 | label = features 123 | active_label_sources[source] = label 124 | 125 | elif self.control_type == 'rms': 126 | rms_fig = librosa.feature.rms(np.transpose(np_audio)) #[1, 313] 127 | rms_tensor = torch.tensor(rms_fig).reshape(1, -1, 1) 128 | rms_trim = rms_tensor.expand(-1, -1, 512).reshape(1, -1)[:, :160000] 129 | 130 | if self.machine_type == 'slider': 131 | min_threshold = (torch.max(rms_trim) + torch.min(rms_trim))/2 132 | elif self.machine_type == 'valve': 133 | k = int(audio.shape[1]*0.8) 134 | min_threshold, _ = torch.kthvalue(rms_trim, k) 135 | 136 | label = (rms_trim > min_threshold).type(torch.float) 137 | label = label.expand(audio.shape[0], -1) 138 | active_label_sources[source] = label 139 | 140 | 141 | # make mixture 142 | target_tmp = self.targets 143 | if self.task_random: 144 | targets = target_tmp.copy() 145 | random.shuffle(targets) 146 | else: 147 | targets = target_tmp 148 | audioes = torch.stack([audio_sources[src] for src in targets]) 149 | audio_mix = torch.stack([audioes[i, 0:2, :] for i in range(len(self.sources))]).sum(0) 150 | 151 | if targets: 152 | audio_sources = audioes[:, 0:2, :] 153 | if self.use_control: 154 | active_labels = torch.stack([active_label_sources[src] for src in targets]) 155 | # [source, channel, time] 156 | return audio_mix, audio_sources, active_labels 157 | 158 | return audio_mix, audio_sources 159 | 160 | def __len__(self): 161 | return len(self.tracks) * self.samples_per_track 162 | 163 | def get_tracks(self): 164 | """Loads input and output tracks""" 165 | p = Path(self.root, self.split) 166 | pp = [] 167 | pp.extend(p.glob(f'{self.machine_type}/{self.sources[0]}/{"normal" if self.normal else "abnormal"}/*.wav')) 168 | 169 | for track_path in tqdm.tqdm(pp): 170 | if self.subset and track_path.stem not in self.subset: 171 | # skip this track 172 | continue 173 | 174 | source_paths = [Path(str(track_path).replace(self.sources[0], s)) for s in self.sources] 175 | if not all(sp.exists() for sp in source_paths): 176 | print("Exclude track due to non-existing source", track_path) 177 | continue 178 | 179 | # get metadata 180 | infos = list(map(sf.info, source_paths)) 181 | if not all(i.samplerate == self.sample_rate for i in infos): 182 | print("Exclude track due to different sample rate ", track_path) 183 | continue 184 | 185 | if self.segment is not None: 186 | # get minimum duration of track 187 | min_duration = min(i.duration for i in infos) 188 | if min_duration > self.segment: 189 | yield ({"path": track_path, "min_duration": min_duration, "source_paths": source_paths}) 190 | else: 191 | yield ({"path": track_path, "min_duration": None, "source_paths": source_paths}) 192 | -------------------------------------------------------------------------------- /asteroid/data/mimii_slider_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch.utils.data 3 | import random 4 | import torch 5 | import tqdm 6 | import soundfile as sf 7 | 8 | from torchaudio import transforms 9 | import librosa 10 | from itertools import product 11 | import numpy as np 12 | import scipy 13 | 14 | from .mimii_valve_dataset import MIMIIValveDataset 15 | 16 | class 
MIMIISliderDataset(MIMIIValveDataset): 17 | 18 | dataset_name = "MIMII" 19 | 20 | def __init__( 21 | self, 22 | root, 23 | sources=["id_00", "id_02", "id_04", "id_06"], 24 | targets=None, 25 | suffix=".wav", 26 | split="0dB", 27 | subset=None, 28 | segment=None, 29 | samples_per_track=2, 30 | random_segments=False, 31 | random_track_mix=False, 32 | source_augmentations=lambda audio: audio, 33 | sample_rate=16000, 34 | normal=True, 35 | use_control=False, 36 | task_random=False, 37 | source_random=False, 38 | num_src_in_mix=2, 39 | ): 40 | 41 | super().__init__(root, 42 | sources=sources, 43 | targets=targets, 44 | suffix=suffix, 45 | split=split, 46 | subset=subset, 47 | segment=segment, 48 | samples_per_track=samples_per_track, 49 | random_segments=random_segments, 50 | random_track_mix=random_track_mix, 51 | source_augmentations=source_augmentations, 52 | sample_rate=sample_rate, 53 | normal=normal, 54 | use_control=use_control, 55 | task_random=task_random, 56 | source_random=source_random, 57 | num_src_in_mix=num_src_in_mix, 58 | machine_type_dir="slider" 59 | ) 60 | 61 | 62 | def generate_label(self, audio): 63 | # np, [1, 313] 64 | channels = audio.shape[0] 65 | rms_fig = librosa.feature.rms(y=audio.numpy()) 66 | #[c, 1, 313] 67 | rms_tensor = torch.tensor(rms_fig).permute(0, 2, 1) 68 | # [channel, time, 1] 69 | rms_trim = rms_tensor.expand(-1, -1, 512).reshape(channels, -1)[:, :160000] 70 | # [channel, time] 71 | 72 | min_threshold = (torch.max(rms_trim) + torch.min(rms_trim))/2 73 | 74 | label = (rms_trim > min_threshold).type(torch.float) 75 | label = torch.Tensor(scipy.ndimage.binary_dilation(label.numpy(), iterations=3)).type(torch.float) 76 | #[channel, time] 77 | return label 78 | -------------------------------------------------------------------------------- /asteroid/data/musdb18_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch.utils.data 3 | import random 4 | import torch 5 | import tqdm 6 | import soundfile as sf 7 | 8 | 9 | class MUSDB18Dataset(torch.utils.data.Dataset): 10 | """MUSDB18 music separation dataset 11 | 12 | The dataset consists of 150 full lengths music tracks (~10h duration) of 13 | different genres along with their isolated stems: 14 | `drums`, `bass`, `vocals` and `others`. 15 | 16 | Out-of-the-box, asteroid does only support MUSDB18-HQ which comes as 17 | uncompressed WAV files. To use the MUSDB18, please convert it to WAV first: 18 | 19 | - MUSDB18 HQ: https://zenodo.org/record/3338373 20 | - MUSDB18 https://zenodo.org/record/1117372 21 | 22 | .. note:: 23 | The datasets are hosted on Zenodo and require that users 24 | request access, since the tracks can only be used for academic purposes. 25 | We manually check this requests. 26 | 27 | This dataset asssumes music tracks in (sub)folders where each folder 28 | has a fixed number of sources (defaults to 4). For each track, a list 29 | of `sources` and a common `suffix` can be specified. 30 | A linear mix is performed on the fly by summing up the sources 31 | 32 | Due to the fact that all tracks comprise the exact same set 33 | of sources, random track mixing can be used can be used, 34 | where sources from different tracks are mixed together. 
35 | 36 | Folder Structure: 37 | >>> #train/1/vocals.wav ---------| 38 | >>> #train/1/drums.wav ----------+--> input (mix), output[target] 39 | >>> #train/1/bass.wav -----------| 40 | >>> #train/1/other.wav ---------/ 41 | 42 | Args: 43 | root (str): Root path of dataset 44 | sources (:obj:`list` of :obj:`str`, optional): List of source names 45 | that composes the mixture. 46 | Defaults to MUSDB18 4 stem scenario: `vocals`, `drums`, `bass`, `other`. 47 | targets (list or None, optional): List of source names to be used as 48 | targets. If None, a dict with the 4 stems is returned. 49 | If e.g [`vocals`, `drums`], a tensor with stacked `vocals` and 50 | `drums` is returned instead of a dict. Defaults to None. 51 | suffix (str, optional): Filename suffix, defaults to `.wav`. 52 | split (str, optional): Dataset subfolder, defaults to `train`. 53 | subset (:obj:`list` of :obj:`str`, optional): Selects a specific of 54 | list of tracks to be loaded, defaults to `None` (loads all tracks). 55 | segment (float, optional): Duration of segments in seconds, 56 | defaults to ``None`` which loads the full-length audio tracks. 57 | samples_per_track (int, optional): 58 | Number of samples yielded from each track, can be used to increase 59 | dataset size, defaults to `1`. 60 | random_segments (boolean, optional): Enables random offset for track segments. 61 | random_track_mix boolean: enables mixing of random sources from 62 | different tracks to assemble mix. 63 | source_augmentations (:obj:`list` of :obj:`callable`): list of augmentation 64 | function names, defaults to no-op augmentations (input = output) 65 | sample_rate (int, optional): Samplerate of files in dataset. 66 | 67 | Attributes: 68 | root (str): Root path of dataset 69 | sources (:obj:`list` of :obj:`str`, optional): List of source names. 70 | Defaults to MUSDB18 4 stem scenario: `vocals`, `drums`, `bass`, `other`. 71 | suffix (str, optional): Filename suffix, defaults to `.wav`. 72 | split (str, optional): Dataset subfolder, defaults to `train`. 73 | subset (:obj:`list` of :obj:`str`, optional): Selects a specific of 74 | list of tracks to be loaded, defaults to `None` (loads all tracks). 75 | segment (float, optional): Duration of segments in seconds, 76 | defaults to ``None`` which loads the full-length audio tracks. 77 | samples_per_track (int, optional): 78 | Number of samples yielded from each track, can be used to increase 79 | dataset size, defaults to `1`. 80 | random_segments (boolean, optional): Enables random offset for track segments. 81 | random_track_mix boolean: enables mixing of random sources from 82 | different tracks to assemble mix. 83 | source_augmentations (:obj:`list` of :obj:`callable`): list of augmentation 84 | function names, defaults to no-op augmentations (input = output) 85 | sample_rate (int, optional): Samplerate of files in dataset. 86 | tracks (:obj:`list` of :obj:`Dict`): List of track metadata 87 | 88 | References 89 | "The 2018 Signal Separation Evaluation Campaign" Stoter et al. 2018. 
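
    Example (an illustrative usage sketch; the root path is hypothetical):
        >>> dataset = MUSDB18Dataset(root="~/MUSDB18-HQ", targets=["vocals", "drums"],
        ...                          segment=5.0, random_segments=True)
        >>> mix, sources = dataset[0]
        >>> # mix: (2, n_samples) stereo mixture; sources: (2, 2, n_samples)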
90 | """ 91 | 92 | dataset_name = "MUSDB18" 93 | 94 | def __init__( 95 | self, 96 | root, 97 | sources=["vocals", "bass", "drums", "other"], 98 | targets=None, 99 | suffix=".wav", 100 | split="train", 101 | subset=None, 102 | segment=None, 103 | samples_per_track=1, 104 | random_segments=False, 105 | random_track_mix=False, 106 | source_augmentations=lambda audio: audio, 107 | sample_rate=44100, 108 | ): 109 | 110 | self.root = Path(root).expanduser() 111 | self.split = split 112 | self.sample_rate = sample_rate 113 | self.segment = segment 114 | self.random_track_mix = random_track_mix 115 | self.random_segments = random_segments 116 | self.source_augmentations = source_augmentations 117 | self.sources = sources 118 | self.targets = targets 119 | self.suffix = suffix 120 | self.subset = subset 121 | self.samples_per_track = samples_per_track 122 | self.tracks = list(self.get_tracks()) 123 | if not self.tracks: 124 | raise RuntimeError("No tracks found.") 125 | 126 | def __getitem__(self, index): 127 | # assemble the mixture of target and interferers 128 | audio_sources = {} 129 | 130 | # get track_id 131 | track_id = index // self.samples_per_track 132 | if self.random_segments: 133 | start = random.uniform(0, self.tracks[track_id]["min_duration"] - self.segment) 134 | else: 135 | start = 0 136 | 137 | # load sources 138 | for source in self.sources: 139 | # optionally select a random track for each source 140 | if self.random_track_mix: 141 | # load a different track 142 | track_id = random.choice(range(len(self.tracks))) 143 | if self.random_segments: 144 | start = random.uniform(0, self.tracks[track_id]["min_duration"] - self.segment) 145 | 146 | # loads the full track duration 147 | start_sample = int(start * self.sample_rate) 148 | # check if dur is none 149 | if self.segment: 150 | # stop in soundfile is calc in samples, not seconds 151 | stop_sample = start_sample + int(self.segment * self.sample_rate) 152 | else: 153 | # set to None for reading complete file 154 | stop_sample = None 155 | 156 | # load actual audio 157 | audio, _ = sf.read( 158 | Path(self.tracks[track_id]["path"] / source).with_suffix(self.suffix), 159 | always_2d=True, 160 | start=start_sample, 161 | stop=stop_sample, 162 | ) 163 | # convert to torch tensor 164 | audio = torch.tensor(audio.T, dtype=torch.float) 165 | # apply source-wise augmentations 166 | audio = self.source_augmentations(audio) 167 | audio_sources[source] = audio 168 | 169 | # apply linear mix over source index=0 170 | audio_mix = torch.stack(list(audio_sources.values())).sum(0) 171 | if self.targets: 172 | audio_sources = torch.stack( 173 | [wav for src, wav in audio_sources.items() if src in self.targets], dim=0 174 | ) 175 | return audio_mix, audio_sources 176 | 177 | def __len__(self): 178 | return len(self.tracks) * self.samples_per_track 179 | 180 | def get_tracks(self): 181 | """Loads input and output tracks""" 182 | p = Path(self.root, self.split) 183 | for track_path in tqdm.tqdm(p.iterdir()): 184 | if track_path.is_dir(): 185 | if self.subset and track_path.stem not in self.subset: 186 | # skip this track 187 | continue 188 | 189 | source_paths = [track_path / (s + self.suffix) for s in self.sources] 190 | if not all(sp.exists() for sp in source_paths): 191 | print("Exclude track due to non-existing source", track_path) 192 | continue 193 | 194 | # get metadata 195 | infos = list(map(sf.info, source_paths)) 196 | if not all(i.samplerate == self.sample_rate for i in infos): 197 | print("Exclude track due to different sample rate ", 
track_path) 198 | continue 199 | 200 | if self.segment is not None: 201 | # get minimum duration of track 202 | min_duration = min(i.duration for i in infos) 203 | if min_duration > self.segment: 204 | yield ({"path": track_path, "min_duration": min_duration}) 205 | else: 206 | yield ({"path": track_path, "min_duration": None}) 207 | 208 | def get_infos(self): 209 | """Get dataset infos (for publishing models). 210 | 211 | Returns: 212 | dict, dataset infos with keys `dataset`, `task` and `licences`. 213 | """ 214 | infos = dict() 215 | infos["dataset"] = self.dataset_name 216 | infos["task"] = "enhancement" 217 | infos["licenses"] = [musdb_license] 218 | return infos 219 | 220 | 221 | musdb_license = dict() 222 | -------------------------------------------------------------------------------- /asteroid/dsp/__init__.py: -------------------------------------------------------------------------------- 1 | from .consistency import mixture_consistency 2 | from .overlap_add import LambdaOverlapAdd, DualPathProcessing 3 | from .beamforming import ( 4 | SCM, 5 | Beamformer, 6 | RTFMVDRBeamformer, 7 | SoudenMVDRBeamformer, 8 | SDWMWFBeamformer, 9 | GEVBeamformer, 10 | ) 11 | 12 | __all__ = [ 13 | "mixture_consistency", 14 | "LambdaOverlapAdd", 15 | "DualPathProcessing", 16 | ] 17 | -------------------------------------------------------------------------------- /asteroid/dsp/consistency.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, List 3 | 4 | 5 | def mixture_consistency( 6 | mixture: torch.Tensor, 7 | est_sources: torch.Tensor, 8 | src_weights: Optional[torch.Tensor] = None, 9 | dim: int = 1, 10 | ) -> torch.Tensor: 11 | """Applies mixture consistency to a tensor of estimated sources. 12 | 13 | Args: 14 | mixture (torch.Tensor): Mixture waveform or TF representation. 15 | est_sources (torch.Tensor): Estimated sources waveforms or TF representations. 16 | src_weights (torch.Tensor): Consistency weight for each source. 17 | Shape needs to be broadcastable to `est_source`. 18 | We make sure that the weights sum up to 1 along dim `dim`. 19 | If `src_weights` is None, compute them based on relative power. 20 | dim (int): Axis which contains the sources in `est_sources`. 21 | 22 | Returns 23 | torch.Tensor with same shape as `est_sources`, after applying mixture 24 | consistency. 25 | 26 | Examples 27 | >>> # Works on waveforms 28 | >>> mix = torch.randn(10, 16000) 29 | >>> est_sources = torch.randn(10, 2, 16000) 30 | >>> new_est_sources = mixture_consistency(mix, est_sources, dim=1) 31 | >>> # Also works on spectrograms 32 | >>> mix = torch.randn(10, 514, 400) 33 | >>> est_sources = torch.randn(10, 2, 514, 400) 34 | >>> new_est_sources = mixture_consistency(mix, est_sources, dim=1) 35 | 36 | .. note:: 37 | This method can be used only in 'complete' separation tasks, otherwise 38 | the residual error will contain unwanted sources. For example, this 39 | won't work with the task `"sep_noisy"` from WHAM. 40 | 41 | References 42 | Scott Wisdom et al. "Differentiable consistency constraints for improved 43 | deep speech enhancement", ICASSP 2019. 44 | """ 45 | # If the source weights are not specified, the weights are the relative 46 | # power of each source to the sum. w_i = P_i / (P_all), P for power. 
47 |     if src_weights is None:
48 |         all_dims: List[int] = torch.arange(est_sources.ndim).tolist()
49 |         all_dims.pop(dim)  # Remove source axis
50 |         all_dims.pop(0)  # Remove batch axis
51 |         src_weights = torch.mean(est_sources ** 2, dim=all_dims, keepdim=True)
52 |     # Make sure that the weights sum up to 1
53 |     norm_weights = torch.sum(src_weights, dim=dim, keepdim=True) + 1e-8
54 |     src_weights = src_weights / norm_weights
55 | 
56 |     # Compute residual mix - sum(est_sources)
57 |     if mixture.ndim == est_sources.ndim - 1:
58 |         # mixture (batch, *), est_sources (batch, n_src, *)
59 |         residual = (mixture - est_sources.sum(dim=dim)).unsqueeze(dim)
60 |     elif mixture.ndim == est_sources.ndim:
61 |         # mixture (batch, 1, *), est_sources (batch, n_src, *)
62 |         residual = mixture - est_sources.sum(dim=dim, keepdim=True)
63 |     else:
64 |         n, m = est_sources.ndim, mixture.ndim
65 |         raise RuntimeError(
66 |             f"The size of the mixture tensor should match the "
67 |             f"size of the est_sources tensor. Expected mixture "
68 |             f"tensor to have {n} or {n-1} dimensions, found {m}."
69 |         )
70 |     # Add the weighted residual back to the estimated sources
71 |     new_sources = est_sources + src_weights * residual
72 |     return new_sources
73 | 
--------------------------------------------------------------------------------
/asteroid/dsp/deltas.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def compute_delta(feats: torch.Tensor, dim: int = -1) -> torch.Tensor:
5 |     """Compute delta coefficients of a tensor.
6 | 
7 |     Args:
8 |         feats: Input features to compute deltas with.
9 |         dim: feature dimension in the feats tensor.
10 | 
11 |     Returns:
12 |         Tensor: Tensor of deltas.
13 | 
14 |     Examples
15 |         >>> import torch
16 |         >>> phase = torch.randn(2, 257, 100)
17 |         >>> # Compute instantaneous frequency
18 |         >>> inst_freq = compute_delta(phase, dim=-1)
19 |         >>> # Or group delay
20 |         >>> group_delay = compute_delta(phase, dim=-2)
21 |     """
22 |     if dim != -1:
23 |         return compute_delta(feats.transpose(-1, dim), dim=-1).transpose(-1, dim)
24 |     # First frame has nothing. Then each frame is the diff with the previous one.
25 |     delta = feats.new_zeros(feats.shape)
26 |     delta[..., 1:] = feats[..., 1:] - feats[..., :-1]
27 |     return delta
28 | 
29 | 
30 | def concat_deltas(feats: torch.Tensor, order: int = 1, dim: int = -1) -> torch.Tensor:
31 |     """Concatenate delta coefficients of a tensor to itself.
32 | 
33 |     Args:
34 |         feats: Input features to compute deltas with.
35 |         order: Order of the delta, e.g. with order==2, compute delta of delta
36 |             as well.
37 |         dim: feature dimension in the feats tensor.
38 | 
39 |     Returns:
40 |         Tensor: Concatenation of the features, the deltas and subsequent deltas.
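            The size of dimension ``dim`` grows by a factor of ``order + 1``
            (e.g. the 100 frames in the example below become 300 with ``order=2``).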
41 | 
42 |     Examples
43 |         >>> import torch
44 |         >>> phase = torch.randn(2, 257, 100)
45 |         >>> # Compute second order instantaneous frequency
46 |         >>> phase_and_inst_freq = concat_deltas(phase, order=2, dim=-1)
47 |         >>> # Or group delay
48 |         >>> phase_and_group_delay = concat_deltas(phase, order=2, dim=-2)
49 |     """
50 |     all_feats = [feats]
51 |     for _ in range(order):
52 |         all_feats.append(compute_delta(all_feats[-1], dim=dim))
53 |     return torch.cat(all_feats, dim=dim)
54 | 
--------------------------------------------------------------------------------
/asteroid/dsp/normalization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def normalize_estimates(est_np, mix_np):
5 |     """Normalizes estimates according to the mixture maximum amplitude
6 | 
7 |     Args:
8 |         est_np (np.array): Estimates with shape (n_src, time).
9 |         mix_np (np.array): One mixture with shape (time, ).
10 | 
11 |     """
12 |     mix_max = np.max(np.abs(mix_np))
13 |     return np.stack([est * mix_max / np.max(np.abs(est)) for est in est_np])
14 | 
--------------------------------------------------------------------------------
/asteroid/dsp/spatial.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | 
5 | def xcorr(inp, ref, normalized=True, eps=1e-8):
6 |     r"""Multi-channel cross correlation.
7 | 
8 |     The two signals can have different lengths, but the reference signal should not be longer than the input signal.
9 | 
10 |     .. note:: The cross correlation is computed between each pair of microphone channels and not
11 |         between all possible pairs e.g. if both input and ref have shape ``(1, 2, 100)``
12 |         the output will be ``(1, 2, 1)``; the first element is the xcorr between
13 |         the first mic channel of input and the first mic channel of ref.
14 |         If either input or ref has only one channel, e.g. input: (1, 3, 100) and ref: ``(1, 1, 100)``,
15 |         then output will be ``(1, 3, 1)`` as ref will be broadcasted to have same shape as input.
16 | 
17 |     Args:
18 |         inp (:class:`torch.Tensor`): multi-channel input signal. Shape: :math:`(batch, mic\_channels, seq\_len)`.
19 |         ref (:class:`torch.Tensor`): multi-channel reference signal. Shape: :math:`(batch, mic\_channels, seq\_len)`.
20 |         normalized (bool, optional): whether to normalize the cross-correlation with the l2 norm of input signals.
21 |         eps (float, optional): machine epsilon used for numerical stabilization when normalization is used.
22 | 
23 |     Returns:
24 |         out (:class:`torch.Tensor`): cross correlation between the two multi-channel signals.
25 |             Shape: :math:`(batch, mic\_channels, seq\_len\_input - seq\_len\_ref + 1)`.
26 | 
27 |     """
28 |     # inp: batch, nmics2, seq_len2 || ref: batch, nmics1, seq_len1
29 |     assert inp.size(0) == ref.size(0), "ref and inp signals should have same batch size."
30 |     assert inp.size(2) >= ref.size(2), "Input signal should be at least as long as the ref signal."
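    # Layout note: both tensors are permuted below to (mic_channels, batch, seq_len)
    # so that a single grouped conv1d (groups=inp_mics * bsz) correlates each
    # (mic, batch) pair independently.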
31 | 32 | inp = inp.permute(1, 0, 2).contiguous() 33 | ref = ref.permute(1, 0, 2).contiguous() 34 | bsz = inp.size(1) 35 | inp_mics = inp.size(0) 36 | 37 | if ref.size(0) > inp.size(0): 38 | inp = inp.expand(ref.size(0), inp.size(1), inp.size(2)).contiguous() # nmic2, L, seg1 39 | inp_mics = ref.size(0) 40 | elif ref.size(0) < inp.size(0): 41 | ref = ref.expand(inp.size(0), ref.size(1), ref.size(2)).contiguous() # nmic1, L, seg2 42 | # cosine similarity 43 | out = F.conv1d( 44 | inp.view(1, -1, inp.size(2)), ref.view(-1, 1, ref.size(2)), groups=inp_mics * bsz 45 | ) # 1, inp_mics*L, seg1-seg2+1 46 | 47 | # L2 norms 48 | if normalized: 49 | inp_norm = F.conv1d( 50 | inp.view(1, -1, inp.size(2)).pow(2), 51 | torch.ones(inp.size(0) * inp.size(1), 1, ref.size(2)).type(inp.type()), 52 | groups=inp_mics * bsz, 53 | ) # 1, inp_mics*L, seg1-seg2+1 54 | inp_norm = inp_norm.sqrt() + eps 55 | ref_norm = ref.norm(2, dim=2).view(1, -1, 1) + eps # 1, inp_mics*L, 1 56 | out = out / (inp_norm * ref_norm) 57 | return out.view(inp_mics, bsz, -1).permute(1, 0, 2).contiguous() 58 | -------------------------------------------------------------------------------- /asteroid/dsp/vad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..utils.torch_utils import script_if_tracing 3 | 4 | 5 | @script_if_tracing 6 | def ebased_vad(mag_spec, th_db: int = 40): 7 | """Compute energy-based VAD from a magnitude spectrogram (or equivalent). 8 | 9 | Args: 10 | mag_spec (torch.Tensor): the spectrogram to perform VAD on. 11 | Expected shape (batch, *, freq, time). 12 | The VAD mask will be computed independently for all the leading 13 | dimensions until the last two. Independent of the ordering of the 14 | last two dimensions. 15 | th_db (int): The threshold in dB from which a TF-bin is considered 16 | silent. 17 | 18 | Returns: 19 | :class:`torch.BoolTensor`, the VAD mask. 20 | 21 | 22 | Examples 23 | >>> import torch 24 | >>> mag_spec = torch.abs(torch.randn(10, 2, 65, 16)) 25 | >>> batch_src_mask = ebased_vad(mag_spec) 26 | """ 27 | log_mag = 20 * torch.log10(mag_spec) 28 | # Compute VAD for each utterance in a batch independently. 
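    # For mag_spec of shape (batch, src, freq, time), `to_view` below is
    # [batch, src, 1, -1]: the max is taken jointly over the (freq, time) bins
    # of each utterance, and the threshold is relative to that per-utterance max.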
29 |     to_view = list(mag_spec.shape[:-2]) + [1, -1]
30 |     max_log_mag = torch.max(log_mag.view(to_view), -1, keepdim=True)[0]
31 |     return log_mag > (max_log_mag - th_db)
32 | 
--------------------------------------------------------------------------------
/asteroid/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .system import System
2 | from .optimizers import make_optimizer
3 | 
4 | __all__ = ["System", "make_optimizer"]
5 | 
--------------------------------------------------------------------------------
/asteroid/engine/optimizers.py:
--------------------------------------------------------------------------------
1 | from torch.optim.optimizer import Optimizer
2 | from torch.optim import Adam, RMSprop, SGD, Adadelta, Adagrad, Adamax, AdamW, ASGD
3 | from torch_optimizer import (
4 |     AccSGD,
5 |     AdaBound,
6 |     AdaMod,
7 |     DiffGrad,
8 |     Lamb,
9 |     NovoGrad,
10 |     PID,
11 |     QHAdam,
12 |     QHM,
13 |     RAdam,
14 |     SGDW,
15 |     Yogi,
16 |     Ranger,
17 |     RangerQH,
18 |     RangerVA,
19 | )
20 | 
21 | 
22 | __all__ = [
23 |     "AccSGD",
24 |     "AdaBound",
25 |     "AdaMod",
26 |     "DiffGrad",
27 |     "Lamb",
28 |     "NovoGrad",
29 |     "PID",
30 |     "QHAdam",
31 |     "QHM",
32 |     "RAdam",
33 |     "SGDW",
34 |     "Yogi",
35 |     "Ranger",
36 |     "RangerQH",
37 |     "RangerVA",
38 |     "Adam",
39 |     "RMSprop",
40 |     "SGD",
41 |     "Adadelta",
42 |     "Adagrad",
43 |     "Adamax",
44 |     "AdamW",
45 |     "ASGD",
46 |     "make_optimizer",
47 |     "get",
48 | ]
49 | 
50 | 
51 | def make_optimizer(params, optimizer="adam", **kwargs):
52 |     """
53 | 
54 |     Args:
55 |         params (iterable): Output of `nn.Module.parameters()`.
56 |         optimizer (str or :class:`torch.optim.Optimizer`): Identifier understood
57 |             by :func:`~.get`.
58 |         **kwargs (dict): keyword arguments for the optimizer.
59 | 
60 |     Returns:
61 |         torch.optim.Optimizer
62 |     Examples
63 |         >>> from torch import nn
64 |         >>> model = nn.Sequential(nn.Linear(10, 10))
65 |         >>> optimizer = make_optimizer(model.parameters(), optimizer='sgd',
66 |         >>>                            lr=1e-3)
67 |     """
68 |     return get(optimizer)(params, **kwargs)
69 | 
70 | 
71 | def register_optimizer(custom_opt):
72 |     """Register a custom opt, gettable with `optimizers.get`.
73 | 
74 |     Args:
75 |         custom_opt: Custom optimizer to register.
76 | 
77 |     """
78 |     if custom_opt.__name__ in globals().keys() or custom_opt.__name__.lower() in globals().keys():
79 |         raise ValueError(f"Optimizer {custom_opt.__name__} already exists. Choose another name.")
80 |     globals().update({custom_opt.__name__: custom_opt})
81 | 
82 | 
83 | def get(identifier):
84 |     """Returns an optimizer function from a string. Returns its input if it
85 |     is callable (already a :class:`torch.optim.Optimizer` for example).
86 | 
87 |     Args:
88 |         identifier (str or Callable): the optimizer identifier.
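            For example, ``get("adam")`` returns :class:`torch.optim.Adam`; the
            lookup is case-insensitive.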
89 | 
90 |     Returns:
91 |         :class:`torch.optim.Optimizer` (raises a ``ValueError`` if the identifier cannot be interpreted)
92 |     """
93 |     if isinstance(identifier, Optimizer):
94 |         return identifier
95 |     elif isinstance(identifier, str):
96 |         to_get = {k.lower(): v for k, v in globals().items()}
97 |         cls = to_get.get(identifier.lower())
98 |         if cls is None:
99 |             raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
100 |         return cls
101 |     raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
102 | 
--------------------------------------------------------------------------------
/asteroid/engine/schedulers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 | import pytorch_lightning as pl
4 | 
5 | from ..losses import SinkPITLossWrapper
6 | 
7 | 
8 | class BaseScheduler(object):
9 |     """Base class for the step-wise scheduler logic.
10 | 
11 |     Args:
12 |         optimizer (Optimizer): Optimizer instance to apply lr schedule on.
13 | 
14 |     Subclass this and overwrite ``_get_lr`` to write your own step-wise scheduler.
15 |     """
16 | 
17 |     def __init__(self, optimizer):
18 |         self.optimizer = optimizer
19 |         self.step_num = 0
20 | 
21 |     def zero_grad(self):
22 |         self.optimizer.zero_grad()
23 | 
24 |     def _get_lr(self):
25 |         raise NotImplementedError
26 | 
27 |     def _set_lr(self, lr):
28 |         for param_group in self.optimizer.param_groups:
29 |             param_group["lr"] = lr
30 | 
31 |     def step(self):
32 |         """Update step-wise learning rate before optimizer.step."""
33 |         self.step_num += 1
34 |         lr = self._get_lr()
35 |         self._set_lr(lr)
36 | 
37 |     def load_state_dict(self, state_dict):
38 |         self.__dict__.update(state_dict)
39 | 
40 |     def state_dict(self):
41 |         return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
42 | 
43 |     def as_tensor(self, start=0, stop=100_000):
44 |         """Returns the scheduler values from start to stop."""
45 |         lr_list = []
46 |         for _ in range(start, stop):
47 |             self.step_num += 1
48 |             lr_list.append(self._get_lr())
49 |         self.step_num = 0
50 |         return torch.tensor(lr_list)
51 | 
52 |     def plot(self, start=0, stop=100_000):  # noqa
53 |         """Plot the scheduler values from start to stop."""
54 |         import matplotlib.pyplot as plt
55 | 
56 |         all_lr = self.as_tensor(start=start, stop=stop)
57 |         plt.plot(all_lr.numpy())
58 |         plt.show()
59 | 
60 | 
61 | class NoamScheduler(BaseScheduler):
62 |     r"""The Noam learning rate scheduler, originally used in conjunction with
63 |     the Adam optimizer in [1].
64 | 
65 |     Args:
66 |         optimizer (Optimizer): Optimizer instance to apply lr schedule on.
67 |         d_model(int): The number of units in the layer output.
68 |         warmup_steps (int): The number of steps in the warmup stage of training.
69 |         scale (float): A fixed coefficient for rescaling the final learning rate.
70 | 
71 |     Schedule:
72 |         The Noam scheduler increases the learning rate linearly for the first
73 |         ``warmup_steps`` steps, and decreases it thereafter proportionally to the
74 |         inverse square root of the step number:
75 |         :math:`lr = scale\_factor * ( model\_dim^{-0.5} * adj\_step )`
76 |         :math:`adj\_step = min(step\_num^{-0.5}, step\_num * warmup\_steps^{-1.5})`
77 | 
78 |     References
79 |         [1] Vaswani et al. (2017) "Attention is all you need".
31st
80 |         Conference on Neural Information Processing Systems
81 |     """
82 | 
83 |     def __init__(self, optimizer, d_model, warmup_steps, scale=1.0):
84 |         super().__init__(optimizer)
85 |         self.d_model = d_model
86 |         self.scale = scale
87 |         self.warmup_steps = warmup_steps
88 | 
89 |     def _get_lr(self):
90 |         lr = (
91 |             self.scale
92 |             * self.d_model ** (-0.5)
93 |             * min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5))
94 |         )
95 |         return lr
96 | 
97 | 
98 | class DPTNetScheduler(BaseScheduler):
99 |     """Dual Path Transformer Scheduler used in [1]
100 | 
101 |     Args:
102 |         optimizer (Optimizer): Optimizer instance to apply lr schedule on.
103 |         steps_per_epoch (int): Number of steps per epoch.
104 |         d_model(int): The number of units in the layer output.
105 |         warmup_steps (int): The number of steps in the warmup stage of training.
106 |         noam_scale (float): Linear increase rate in first phase.
107 |         exp_max (float): Max learning rate in second phase.
108 |         exp_base (float): Exp learning rate base in second phase.
109 | 
110 |     Schedule:
111 |         This scheduler increases the learning rate linearly for the first
112 |         ``warmup_steps``, and then decays it by a factor of ``exp_base`` (0.98) every two epochs.
113 | 
114 |     References
115 |         [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct Context-
116 |         Aware Modeling for End-to-End Monaural Speech Separation" Interspeech 2020.
117 |     """
118 | 
119 |     def __init__(
120 |         self,
121 |         optimizer,
122 |         steps_per_epoch,
123 |         d_model,
124 |         warmup_steps=4000,
125 |         noam_scale=1.0,
126 |         exp_max=0.0004,
127 |         exp_base=0.98,
128 |     ):
129 |         super().__init__(optimizer)
130 |         self.noam_scale = noam_scale
131 |         self.d_model = d_model
132 |         self.warmup_steps = warmup_steps
133 |         self.exp_max = exp_max
134 |         self.exp_base = exp_base
135 |         self.steps_per_epoch = steps_per_epoch
136 |         self.epoch = 0
137 | 
138 |     def _get_lr(self):
139 |         if self.step_num % self.steps_per_epoch == 0:
140 |             self.epoch += 1
141 | 
142 |         if self.step_num > self.warmup_steps:
143 |             # exp decaying
144 |             lr = self.exp_max * (self.exp_base ** ((self.epoch - 1) // 2))
145 |         else:
146 |             # noam
147 |             lr = (
148 |                 self.noam_scale
149 |                 * self.d_model ** (-0.5)
150 |                 * min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5))
151 |             )
152 |         return lr
153 | 
154 | 
155 | def sinkpit_default_beta_schedule(epoch):
156 |     return min([1.02 ** epoch, 10])
157 | 
158 | 
159 | class SinkPITBetaScheduler(pl.callbacks.Callback):
160 |     r"""Scheduler of the beta value of SinkPITLossWrapper
161 |     This module is used as a Callback function of `pytorch_lightning.Trainer`.
162 | 
163 |     Args:
164 |         cooling_schedule (callable) : A callable that takes a parameter `epoch` (int)
165 |             and returns the value of `beta` (float).
166 | 
167 |         The default function is ``sinkpit_default_beta_schedule``: :math:`\beta = min(1.02^{epoch}, 10)`
168 | 
169 |     Example
170 |         >>> from pytorch_lightning import Trainer
171 |         >>> from asteroid.engine.schedulers import SinkPITBetaScheduler
172 |         >>> # Default scheduling function
173 |         >>> sinkpit_beta_schedule = SinkPITBetaScheduler()
174 |         >>> trainer = Trainer(callbacks=[sinkpit_beta_schedule])
175 |         >>> # User-defined schedule
176 |         >>> sinkpit_beta_schedule = SinkPITBetaScheduler(lambda ep: 1. if ep < 10 else 100.)
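        >>> # beta is then updated at the start of every epoch (see on_epoch_start)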
177 | >>> trainer = Trainer(callbacks=[sinkpit_beta_schedule]) 178 | """ 179 | 180 | def __init__(self, cooling_schedule=sinkpit_default_beta_schedule): 181 | self.cooling_schedule = cooling_schedule 182 | 183 | def on_epoch_start(self, trainer, pl_module): 184 | assert isinstance(pl_module.loss_func, SinkPITLossWrapper) 185 | assert trainer.current_epoch == pl_module.current_epoch # same 186 | epoch = pl_module.current_epoch 187 | # step = pl_module.global_step 188 | beta = self.cooling_schedule(epoch) 189 | pl_module.loss_func.beta = beta 190 | 191 | 192 | # Backward compat 193 | _BaseScheduler = BaseScheduler 194 | -------------------------------------------------------------------------------- /asteroid/engine/system.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | from torch.optim.lr_scheduler import ReduceLROnPlateau 4 | 5 | from ..utils import flatten_dict 6 | 7 | 8 | class System(pl.LightningModule): 9 | """Base class for deep learning systems. 10 | Contains a model, an optimizer, a loss function, training and validation 11 | dataloaders and learning rate scheduler. 12 | 13 | Note that by default, any PyTorch-Lightning hooks are *not* passed to the model. 14 | If you want to use Lightning hooks, add the hooks to a subclass:: 15 | 16 | class MySystem(System): 17 | def on_train_batch_start(self, batch, batch_idx, dataloader_idx): 18 | return self.model.on_train_batch_start(batch, batch_idx, dataloader_idx) 19 | 20 | Args: 21 | model (torch.nn.Module): Instance of model. 22 | optimizer (torch.optim.Optimizer): Instance or list of optimizers. 23 | loss_func (callable): Loss function with signature 24 | (est_targets, targets). 25 | train_loader (torch.utils.data.DataLoader): Training dataloader. 26 | val_loader (torch.utils.data.DataLoader): Validation dataloader. 27 | scheduler (torch.optim.lr_scheduler._LRScheduler): Instance, or list 28 | of learning rate schedulers. Also supports dict or list of dict as 29 | ``{"interval": "step", "scheduler": sched}`` where ``interval=="step"`` 30 | for step-wise schedulers and ``interval=="epoch"`` for classical ones. 31 | config: Anything to be saved with the checkpoints during training. 32 | The config dictionary to re-instantiate the run for example. 33 | 34 | .. note:: By default, ``training_step`` (used by ``pytorch-lightning`` in the 35 | training loop) and ``validation_step`` (used for the validation loop) 36 | share ``common_step``. If you want different behavior for the training 37 | loop and the validation loop, overwrite both ``training_step`` and 38 | ``validation_step`` instead. 
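
    A minimal usage sketch (with hypothetical ``model``, ``opt``, ``loss_func``
    and ``loader`` objects, named here for illustration only)::

        >>> system = System(model, opt, loss_func, train_loader=loader)
        >>> trainer = pl.Trainer(max_epochs=1)
        >>> trainer.fit(system)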
39 | 
40 |     For more info on its methods, properties and hooks, have a look at lightning's docs:
41 |     https://pytorch-lightning.readthedocs.io/en/stable/lightning_module.html#lightningmodule-api
42 |     """
43 | 
44 |     default_monitor: str = "val_loss"
45 | 
46 |     def __init__(
47 |         self,
48 |         model,
49 |         optimizer,
50 |         loss_func,
51 |         train_loader,
52 |         val_loader=None,
53 |         scheduler=None,
54 |         config=None,
55 |     ):
56 |         super().__init__()
57 |         self.model = model
58 |         self.optimizer = optimizer
59 |         self.loss_func = loss_func
60 |         self.train_loader = train_loader
61 |         self.val_loader = val_loader
62 |         self.scheduler = scheduler
63 |         self.config = {} if config is None else config
64 |         # Save lightning's AttributeDict under self.hparams
65 |         self.save_hyperparameters(self.config_to_hparams(self.config))
66 | 
67 |     def forward(self, *args, **kwargs):
68 |         """Applies forward pass of the model.
69 | 
70 |         Returns:
71 |             :class:`torch.Tensor`
72 |         """
73 |         return self.model(*args, **kwargs)
74 | 
75 |     def common_step(self, batch, batch_nb, train=True):
76 |         """Common forward step between training and validation.
77 | 
78 |         The function of this method is to unpack the data given by the loader,
79 |         forward the batch through the model and compute the loss.
80 |         Pytorch-lightning handles all the rest.
81 | 
82 |         Args:
83 |             batch: the object returned by the loader (a list of torch.Tensor
84 |                 in most cases) but can be something else.
85 |             batch_nb (int): The number of the batch in the epoch.
86 |             train (bool): Whether in training mode. Needed only if the training
87 |                 and validation steps are fundamentally different, otherwise,
88 |                 pytorch-lightning handles the usual differences.
89 | 
90 |         Returns:
91 |             :class:`torch.Tensor` : The loss value on this batch.
92 | 
93 |         .. note::
94 |             This is typically the method to overwrite when subclassing
95 |             ``System``. If the training and validation steps are somehow
96 |             different (except for ``loss.backward()`` and ``optimizer.step()``),
97 |             the argument ``train`` can be used to switch behavior.
98 |             Otherwise, ``training_step`` and ``validation_step`` can be overwritten.
99 |         """
100 |         inputs, targets = batch
101 |         est_targets = self(inputs)
102 |         loss = self.loss_func(est_targets, targets)
103 |         return loss
104 | 
105 |     def training_step(self, batch, batch_nb):
106 |         """Pass data through the model and compute the loss.
107 | 
108 |         Backprop is **not** performed here (PyTorch-Lightning does it for you).
109 | 
110 |         Args:
111 |             batch: the object returned by the loader (a list of torch.Tensor
112 |                 in most cases) but can be something else.
113 |             batch_nb (int): The number of the batch in the epoch.
114 | 
115 |         Returns:
116 |             torch.Tensor, the value of the loss.
117 |         """
118 |         loss = self.common_step(batch, batch_nb, train=True)
119 |         self.log("loss", loss, logger=True)
120 |         return loss
121 | 
122 |     def validation_step(self, batch, batch_nb):
123 |         """Need to overwrite PL validation_step to do validation.
124 | 
125 |         Args:
126 |             batch: the object returned by the loader (a list of torch.Tensor
127 |                 in most cases) but can be something else.
128 |             batch_nb (int): The number of the batch in the epoch.
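
        Returns:
            None. The validation loss is logged via ``self.log("val_loss", ...)``
            rather than returned.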
129 |         """
130 |         loss = self.common_step(batch, batch_nb, train=False)
131 |         self.log("val_loss", loss, on_epoch=True, prog_bar=True)
132 | 
133 |     def on_validation_epoch_end(self):
134 |         """Log hp_metric to tensorboard for hparams selection."""
135 |         hp_metric = self.trainer.callback_metrics.get("val_loss", None)
136 |         if hp_metric is not None:
137 |             self.trainer.logger.log_metrics({"hp_metric": hp_metric}, step=self.trainer.global_step)
138 | 
139 |     def configure_optimizers(self):
140 |         """Initialize optimizers, batch-wise and epoch-wise schedulers."""
141 |         if self.scheduler is None:
142 |             return self.optimizer
143 | 
144 |         if not isinstance(self.scheduler, (list, tuple)):
145 |             self.scheduler = [self.scheduler]  # support multiple schedulers
146 | 
147 |         epoch_schedulers = []
148 |         for sched in self.scheduler:
149 |             if not isinstance(sched, dict):
150 |                 if isinstance(sched, ReduceLROnPlateau):
151 |                     sched = {"scheduler": sched, "monitor": self.default_monitor}
152 |                 epoch_schedulers.append(sched)
153 |             else:
154 |                 sched.setdefault("monitor", self.default_monitor)
155 |                 sched.setdefault("frequency", 1)
156 |                 # Backward compat
157 |                 if sched["interval"] == "batch":
158 |                     sched["interval"] = "step"
159 |                 assert sched["interval"] in [
160 |                     "epoch",
161 |                     "step",
162 |                 ], "Scheduler interval should be either step or epoch"
163 |                 epoch_schedulers.append(sched)
164 |         return [self.optimizer], epoch_schedulers
165 | 
166 |     def train_dataloader(self):
167 |         """Training dataloader"""
168 |         return self.train_loader
169 | 
170 |     def val_dataloader(self):
171 |         """Validation dataloader"""
172 |         return self.val_loader
173 | 
174 |     def on_save_checkpoint(self, checkpoint):
175 |         """Overwrite if you want to save more things in the checkpoint."""
176 |         checkpoint["training_config"] = self.config
177 |         return checkpoint
178 | 
179 |     @staticmethod
180 |     def config_to_hparams(dic):
181 |         """Sanitizes the config dict to be handled correctly by torch
182 |         SummaryWriter. It flattens the config dict, converts ``None`` to
183 |         ``"None"`` and any list or tuple into torch.Tensors.
184 | 
185 |         Args:
186 |             dic (dict): Dictionary to be transformed.
187 | 
188 |         Returns:
189 |             dict: Transformed dictionary.
190 |         """
191 |         dic = flatten_dict(dic)
192 |         for k, v in dic.items():
193 |             if v is None:
194 |                 dic[k] = str(v)
195 |             elif isinstance(v, (list, tuple)):
196 |                 dic[k] = torch.tensor(v)
197 |         return dic
198 | 
--------------------------------------------------------------------------------
/asteroid/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .pit_wrapper import PITLossWrapper
2 | from .mixit_wrapper import MixITLossWrapper
3 | from .sinkpit_wrapper import SinkPITLossWrapper
4 | from .sdr import PairwiseNegSDR
5 | from .sdr import pairwise_neg_sisdr, singlesrc_neg_sisdr, multisrc_neg_sisdr
6 | from .sdr import pairwise_neg_sdsdr, singlesrc_neg_sdsdr, multisrc_neg_sdsdr
7 | from .sdr import pairwise_neg_snr, singlesrc_neg_snr, multisrc_neg_snr
8 | from .mse import pairwise_mse, singlesrc_mse, multisrc_mse
9 | from .cluster import deep_clustering_loss
10 | from .pmsqe import SingleSrcPMSQE
11 | from .multi_scale_spectral import SingleSrcMultiScaleSpectral
12 | 
13 | try:
14 |     from .stoi import NegSTOILoss as SingleSrcNegSTOI
15 | except ModuleNotFoundError:
16 |     # torch_stoi is installed with asteroid, but we drop the dependency for TorchHub.
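    # The stub below makes calling `SingleSrcNegSTOI(...)` raise a clear error
    # at call time, instead of failing at import time when `torch_stoi` is missing.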
17 |     def f():
18 |         raise ModuleNotFoundError("No module named 'torch_stoi'")
19 | 
20 |     SingleSrcNegSTOI = lambda *a, **kw: f()
21 | 
22 | 
23 | __all__ = [
24 |     "PITLossWrapper",
25 |     "MixITLossWrapper",
26 |     "SinkPITLossWrapper",
27 |     "PairwiseNegSDR",
28 |     "singlesrc_neg_sisdr",
29 |     "pairwise_neg_sisdr",
30 |     "multisrc_neg_sisdr",
31 |     "pairwise_neg_sdsdr",
32 |     "singlesrc_neg_sdsdr",
33 |     "multisrc_neg_sdsdr",
34 |     "pairwise_neg_snr",
35 |     "singlesrc_neg_snr",
36 |     "multisrc_neg_snr",
37 |     "pairwise_mse",
38 |     "singlesrc_mse",
39 |     "multisrc_mse",
40 |     "deep_clustering_loss",
41 |     "SingleSrcPMSQE",
42 |     "SingleSrcNegSTOI",
43 |     "SingleSrcMultiScaleSpectral",
44 | ]
45 | 
--------------------------------------------------------------------------------
/asteroid/losses/bark_matrix_16k.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-postech/SSAD/642ee27f5e0c10ec9e3b643b49b12ac94843f998/asteroid/losses/bark_matrix_16k.mat
--------------------------------------------------------------------------------
/asteroid/losses/bark_matrix_8k.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ml-postech/SSAD/642ee27f5e0c10ec9e3b643b49b12ac94843f998/asteroid/losses/bark_matrix_8k.mat
--------------------------------------------------------------------------------
/asteroid/losses/cluster.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def deep_clustering_loss(embedding, tgt_index, binary_mask=None):
5 |     r"""Compute the deep clustering loss defined in [1].
6 | 
7 |     Args:
8 |         embedding (torch.Tensor): Estimated embeddings.
9 |             Expected shape :math:`(batch, frequency * frame, embedding\_dim)`.
10 |         tgt_index (torch.Tensor): Dominating source index in each TF bin.
11 |             Expected shape: :math:`(batch, frequency, frame)`.
12 |         binary_mask (torch.Tensor): VAD in TF plane. Bool or Float.
13 |             See asteroid.dsp.vad.ebased_vad.
14 | 
15 |     Returns:
16 |         `torch.Tensor`. Deep clustering loss for every batch sample.
17 | 
18 |     Examples
19 |         >>> import torch
20 |         >>> from asteroid.losses.cluster import deep_clustering_loss
21 |         >>> spk_cnt = 3
22 |         >>> embedding = torch.randn(10, 5*400, 20)
23 |         >>> targets = torch.LongTensor(10, 400, 5).random_(0, spk_cnt)
24 |         >>> loss = deep_clustering_loss(embedding, targets)
25 | 
26 |     Reference
27 |         [1] Zhong-Qiu Wang, Jonathan Le Roux, John R. Hershey
28 |         "ALTERNATIVE OBJECTIVE FUNCTIONS FOR DEEP CLUSTERING"
29 | 
30 |     .. note::
31 |         Be careful in viewing the embedding tensors. The target indices
32 |         ``tgt_index`` are of shape :math:`(batch, freq, frames)`. Even if the embedding
33 |         is of shape :math:`(batch, freq * frames, emb)`, the underlying view should be
34 |         :math:`(batch, freq, frames, emb)` and not :math:`(batch, frames, freq, emb)`.
35 |     """
36 |     spk_cnt = len(tgt_index.unique())
37 | 
38 |     batch, bins, frames = tgt_index.shape
39 |     if binary_mask is None:
40 |         binary_mask = torch.ones(batch, bins * frames, 1)
41 |     binary_mask = binary_mask.float()  # If boolean mask, make it float.
42 |     if len(binary_mask.shape) == 3:
43 |         binary_mask = binary_mask.view(batch, bins * frames, 1)
44 |     # Move the (float) mask to the same device as the targets.
45 |     binary_mask = binary_mask.to(tgt_index.device)
46 | 
47 |     # Fill in one-hot vector for each TF bin
48 |     tgt_embedding = torch.zeros(batch, bins * frames, spk_cnt, device=tgt_index.device)
49 |     tgt_embedding.scatter_(2, tgt_index.view(batch, bins * frames, 1), 1)
50 | 
51 |     # Compute VAD-weighted DC loss
52 |     tgt_embedding = tgt_embedding * binary_mask
53 |     embedding = embedding * binary_mask
54 |     est_proj = torch.einsum("ijk,ijl->ikl", embedding, embedding)
55 |     true_proj = torch.einsum("ijk,ijl->ikl", tgt_embedding, tgt_embedding)
56 |     true_est_proj = torch.einsum("ijk,ijl->ikl", embedding, tgt_embedding)
57 |     # Equation (1) in [1]
58 |     cost = batch_matrix_norm(est_proj) + batch_matrix_norm(true_proj)
59 |     cost = cost - 2 * batch_matrix_norm(true_est_proj)
60 |     # Divide by number of active bins, for each element in batch
61 |     return cost / torch.sum(binary_mask, dim=[1, 2])
62 | 
63 | 
64 | def batch_matrix_norm(matrix, norm_order=2):
65 |     """Compute the `norm_order` matrix norm, raised to the power `norm_order`, for each batch element.
66 | 
67 |     Args:
68 |         matrix (torch.Tensor): Expected shape [batch, *]
69 |         norm_order (int): Norm order.
70 | 
71 |     Returns:
72 |         torch.Tensor, the norm of each matrix in the batch, shape [batch].
73 |     """
74 |     keep_batch = list(range(1, matrix.ndim))
75 |     return torch.norm(matrix, p=norm_order, dim=keep_batch) ** norm_order
76 | 
--------------------------------------------------------------------------------
/asteroid/losses/mixit_wrapper.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from itertools import combinations
3 | import torch
4 | from torch import nn
5 | 
6 | 
7 | class MixITLossWrapper(nn.Module):
8 |     r"""Mixture invariant loss wrapper.
9 | 
10 |     Args:
11 |         loss_func: function with signature (est_targets, targets, **kwargs).
12 |         generalized (bool): Determines how MixIT is applied. If False,
13 |             apply MixIT for any number of mixtures as soon as they contain
14 |             the same number of sources (:meth:`~MixITLossWrapper.best_part_mixit`.)
15 |             If True (default), apply MixIT for two mixtures, but those mixtures do not
16 |             necessarily have to contain the same number of sources.
17 |             See :meth:`~MixITLossWrapper.best_part_mixit_generalized`.
18 | 
19 |     For each of these modes, the best partition and reordering will be
20 |     automatically computed.
21 | 
22 |     Examples:
23 |         >>> import torch
24 |         >>> from asteroid.losses import multisrc_mse
25 |         >>> mixtures = torch.randn(10, 2, 16000)
26 |         >>> est_sources = torch.randn(10, 4, 16000)
27 |         >>> # Compute MixIT loss based on pairwise losses
28 |         >>> loss_func = MixITLossWrapper(multisrc_mse)
29 |         >>> loss_val = loss_func(est_sources, mixtures)
30 | 
31 |     References
32 |         [1] Scott Wisdom et al. "Unsupervised sound separation using
33 |         mixtures of mixtures." arXiv:2006.12701 (2020)
34 |     """
35 | 
36 |     def __init__(self, loss_func, generalized=True):
37 |         super().__init__()
38 |         self.loss_func = loss_func
39 |         self.generalized = generalized
40 | 
41 |     def forward(self, est_targets, targets, return_est=False, **kwargs):
42 |         r"""Find the best partition and return the loss.
43 | 
44 |         Args:
45 |             est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, *)`.
46 |                 The batch of target estimates.
47 |             targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
48 |                 The batch of training targets
49 |             return_est: Boolean. Whether to return the estimated mixtures
50 |                 estimates (To compute metrics or to save example).
51 |             **kwargs: additional keyword argument that will be passed to the
52 |                 loss function.
53 | 54 | Returns: 55 | - Best partition loss for each batch sample, average over 56 | the batch. torch.Tensor(loss_value) 57 | - The estimated mixtures (estimated sources summed according to the partition) 58 | if return_est is True. torch.Tensor of shape :math:`(batch, nmix, ...)`. 59 | """ 60 | # Check input dimensions 61 | assert est_targets.shape[0] == targets.shape[0] 62 | assert est_targets.shape[2] == targets.shape[2] 63 | 64 | if not self.generalized: 65 | min_loss, min_loss_idx, parts = self.best_part_mixit( 66 | self.loss_func, est_targets, targets, **kwargs 67 | ) 68 | else: 69 | min_loss, min_loss_idx, parts = self.best_part_mixit_generalized( 70 | self.loss_func, est_targets, targets, **kwargs 71 | ) 72 | # Take the mean over the batch 73 | mean_loss = torch.mean(min_loss) 74 | if not return_est: 75 | return mean_loss 76 | # Order and sum on the best partition to get the estimated mixtures 77 | reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts) 78 | return mean_loss, reordered 79 | 80 | @staticmethod 81 | def best_part_mixit(loss_func, est_targets, targets, **kwargs): 82 | r"""Find best partition of the estimated sources that gives the minimum 83 | loss for the MixIT training paradigm in [1]. Valid for any number of 84 | mixtures as soon as they contain the same number of sources. 85 | 86 | Args: 87 | loss_func: function with signature ``(est_targets, targets, **kwargs)`` 88 | The loss function to get batch losses from. 89 | est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. 90 | The batch of target estimates. 91 | targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`. 92 | The batch of training targets (mixtures). 93 | **kwargs: additional keyword argument that will be passed to the 94 | loss function. 95 | 96 | Returns: 97 | - :class:`torch.Tensor`: 98 | The loss corresponding to the best permutation of size (batch,). 99 | 100 | - :class:`torch.LongTensor`: 101 | The indices of the best partition. 102 | 103 | - :class:`list`: 104 | list of the possible partitions of the sources. 105 | 106 | """ 107 | nmix = targets.shape[1] 108 | nsrc = est_targets.shape[1] 109 | if nsrc % nmix != 0: 110 | raise ValueError("The mixtures are assumed to contain the same number of sources") 111 | nsrcmix = nsrc // nmix 112 | 113 | # Generate all unique partitions of size k from a list lst of 114 | # length n, where l = n // k is the number of parts. The total 115 | # number of such partitions is: NPK(n,k) = n! / ((k!)^l * l!) 116 | # Algorithm recursively distributes items over parts 117 | def parts_mixit(lst, k, l): 118 | if l == 0: 119 | yield [] 120 | else: 121 | for c in combinations(lst, k): 122 | rest = [x for x in lst if x not in c] 123 | for r in parts_mixit(rest, k, l - 1): 124 | yield [list(c), *r] 125 | 126 | # Generate all the possible partitions 127 | parts = list(parts_mixit(range(nsrc), nsrcmix, nmix)) 128 | # Compute the loss corresponding to each partition 129 | loss_set = MixITLossWrapper.loss_set_from_parts( 130 | loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs 131 | ) 132 | # Indexes and values of min losses for each batch element 133 | min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True) 134 | return min_loss, min_loss_indexes, parts 135 | 136 | @staticmethod 137 | def best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs): 138 | r"""Find best partition of the estimated sources that gives the minimum 139 | loss for the MixIT training paradigm in [1]. 
Valid only for two mixtures,
140 |         but those mixtures do not necessarily have to contain the same number of
141 |         sources, e.g. the case where one mixture is silent is allowed.
142 | 
143 |         Args:
144 |             loss_func: function with signature ``(est_targets, targets, **kwargs)``
145 |                 The loss function to get batch losses from.
146 |             est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
147 |                 The batch of target estimates.
148 |             targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
149 |                 The batch of training targets (mixtures).
150 |             **kwargs: additional keyword argument that will be passed to the
151 |                 loss function.
152 | 
153 |         Returns:
154 |             - :class:`torch.Tensor`:
155 |               The loss corresponding to the best permutation of size (batch,).
156 | 
157 |             - :class:`torch.LongTensor`:
158 |               The indexes of the best permutations.
159 | 
160 |             - :class:`list`:
161 |               list of the possible partitions of the sources.
162 |         """
163 |         nmix = targets.shape[1]  # number of mixtures
164 |         nsrc = est_targets.shape[1]  # number of estimated sources
165 |         if nmix != 2:
166 |             raise ValueError("Works only with two mixtures")
167 | 
168 |         # Generate all unique partitions of any size from a list lst of
169 |         # length n. Algorithm recursively distributes items over parts
170 |         def parts_mixit_gen(lst):
171 |             partitions = []
172 |             for k in range(len(lst) + 1):
173 |                 for c in combinations(lst, k):
174 |                     rest = [x for x in lst if x not in c]
175 |                     partitions.append([list(c), rest])
176 |             return partitions
177 | 
178 |         # Generate all the possible partitions
179 |         parts = parts_mixit_gen(range(nsrc))
180 |         # Compute the loss corresponding to each partition
181 |         loss_set = MixITLossWrapper.loss_set_from_parts(
182 |             loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
183 |         )
184 |         # Indexes and values of min losses for each batch element
185 |         min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
186 |         return min_loss, min_loss_indexes, parts
187 | 
188 |     @staticmethod
189 |     def loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs):
190 |         """Common loop shared by both best_part_mixit variants."""
191 |         loss_set = []
192 |         for partition in parts:
193 |             # sum the sources according to the given partition
194 |             est_mixes = torch.stack([est_targets[:, idx, :].sum(1) for idx in partition], dim=1)
195 |             # get loss for the given partition
196 |             loss_set.append(loss_func(est_mixes, targets, **kwargs)[:, None])
197 |         loss_set = torch.cat(loss_set, dim=1)
198 |         return loss_set
199 | 
200 |     @staticmethod
201 |     def reorder_source(est_targets, targets, min_loss_idx, parts):
202 |         """Reorder sources according to the best partition.
203 | 
204 |         Args:
205 |             est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
206 |                 The batch of target estimates.
207 |             targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
208 |                 The batch of training targets.
209 |             min_loss_idx: torch.LongTensor. The indexes of the best permutations.
210 |             parts: list of the possible partitions of the sources.
211 | 
212 |         Returns:
213 |             :class:`torch.Tensor`: Reordered sources of shape :math:`(batch, nmix, time)`.
214 | 215 | """ 216 | # For each batch there is a different min_loss_idx 217 | ordered = torch.zeros_like(targets) 218 | for b, idx in enumerate(min_loss_idx): 219 | right_partition = parts[idx] 220 | # Sum the estimated sources to get the estimated mixtures 221 | ordered[b, :, :] = torch.stack( 222 | [est_targets[b, idx, :][None, :, :].sum(1) for idx in right_partition], dim=1 223 | ) 224 | 225 | return ordered 226 | -------------------------------------------------------------------------------- /asteroid/losses/mse.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.loss import _Loss 2 | 3 | 4 | class PairwiseMSE(_Loss): 5 | r"""Measure pairwise mean square error on a batch. 6 | 7 | Shape: 8 | - est_targets : :math:`(batch, nsrc, ...)`. 9 | - targets: :math:`(batch, nsrc, ...)`. 10 | 11 | Returns: 12 | :class:`torch.Tensor`: with shape :math:`(batch, nsrc, nsrc)` 13 | 14 | Examples 15 | >>> import torch 16 | >>> from asteroid.losses import PITLossWrapper 17 | >>> targets = torch.randn(10, 2, 32000) 18 | >>> est_targets = torch.randn(10, 2, 32000) 19 | >>> loss_func = PITLossWrapper(PairwiseMSE(), pit_from='pairwise') 20 | >>> loss = loss_func(est_targets, targets) 21 | """ 22 | 23 | def forward(self, est_targets, targets): 24 | if targets.size() != est_targets.size() or targets.ndim < 3: 25 | raise TypeError( 26 | f"Inputs must be of shape [batch, n_src, *], got {targets.size()} and {est_targets.size()} instead" 27 | ) 28 | targets = targets.unsqueeze(1) 29 | est_targets = est_targets.unsqueeze(2) 30 | pw_loss = (targets - est_targets) ** 2 31 | # Need to return [batch, nsrc, nsrc] 32 | mean_over = list(range(3, pw_loss.ndim)) 33 | return pw_loss.mean(dim=mean_over) 34 | 35 | 36 | class SingleSrcMSE(_Loss): 37 | r"""Measure mean square error on a batch. 38 | Supports both tensors with and without source axis. 39 | 40 | Shape: 41 | - est_targets: :math:`(batch, ...)`. 42 | - targets: :math:`(batch, ...)`. 43 | 44 | Returns: 45 | :class:`torch.Tensor`: with shape :math:`(batch)` 46 | 47 | Examples 48 | >>> import torch 49 | >>> from asteroid.losses import PITLossWrapper 50 | >>> targets = torch.randn(10, 2, 32000) 51 | >>> est_targets = torch.randn(10, 2, 32000) 52 | >>> # singlesrc_mse / multisrc_mse support both 'pw_pt' and 'perm_avg'. 
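        >>> # ('pw_pt' builds the pairwise loss matrix from a per-source loss,
        >>> #  while 'perm_avg' evaluates each full permutation directly.)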
53 |         >>> loss_func = PITLossWrapper(singlesrc_mse, pit_from='pw_pt')
54 |         >>> loss = loss_func(est_targets, targets)
55 |     """
56 | 
57 |     def forward(self, est_targets, targets):
58 |         if targets.size() != est_targets.size() or targets.ndim < 2:
59 |             raise TypeError(
60 |                 f"Inputs must be of shape [batch, *], got {targets.size()} and {est_targets.size()} instead"
61 |             )
62 |         loss = (targets - est_targets) ** 2
63 |         mean_over = list(range(1, loss.ndim))
64 |         return loss.mean(dim=mean_over)
65 | 
66 | 
67 | # aliases
68 | MultiSrcMSE = SingleSrcMSE
69 | pairwise_mse = PairwiseMSE()
70 | singlesrc_mse = SingleSrcMSE()
71 | multisrc_mse = MultiSrcMSE()
72 | 
--------------------------------------------------------------------------------
/asteroid/losses/multi_scale_spectral.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.modules.loss import _Loss
4 | from asteroid_filterbanks import STFTFB, Encoder
5 | from asteroid_filterbanks.transforms import mag
6 | 
7 | 
8 | class SingleSrcMultiScaleSpectral(_Loss):
9 |     r"""Measure multi-scale spectral loss as described in [1]
10 | 
11 |     Args:
12 |         n_filters (list): list containing the number of filters desired for
13 |             each STFT
14 |         windows_size (list): list containing the size of the window desired for
15 |             each STFT
16 |         hops_size (list): list containing the size of the hop desired for
17 |             each STFT
18 | 
19 |     Shape:
20 |         - est_targets : :math:`(batch, time)`.
21 |         - targets: :math:`(batch, time)`.
22 | 
23 |     Returns:
24 |         :class:`torch.Tensor`: with shape [batch]
25 | 
26 |     Examples
27 |         >>> import torch
28 |         >>> targets = torch.randn(10, 32000)
29 |         >>> est_targets = torch.randn(10, 32000)
30 |         >>> # Using it by itself on a pair of source/estimate
31 |         >>> loss_func = SingleSrcMultiScaleSpectral()
32 |         >>> loss = loss_func(est_targets, targets)
33 | 
34 |         >>> import torch
35 |         >>> from asteroid.losses import PITLossWrapper
36 |         >>> targets = torch.randn(10, 2, 32000)
37 |         >>> est_targets = torch.randn(10, 2, 32000)
38 |         >>> # Using it with PITLossWrapper with sets of source/estimates
39 |         >>> loss_func = PITLossWrapper(SingleSrcMultiScaleSpectral(),
40 |         >>>                            pit_from='pw_pt')
41 |         >>> loss = loss_func(est_targets, targets)
42 | 
43 |     References
44 |         [1] Jesse Engel and Lamtharn (Hanoi) Hantrakul and Chenjie Gu and
45 |         Adam Roberts "DDSP: Differentiable Digital Signal Processing" ICLR 2020.
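
    .. note:: ``alpha`` (undocumented above) weights the log-magnitude term
        relative to the linear-magnitude term in ``compute_spectral_loss``.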
46 |     """
47 | 
48 |     def __init__(self, n_filters=None, windows_size=None, hops_size=None, alpha=1.0):
49 |         super().__init__()
50 | 
51 |         if windows_size is None:
52 |             windows_size = [2048, 1024, 512, 256, 128, 64, 32]
53 |         if n_filters is None:
54 |             n_filters = [2048, 1024, 512, 256, 128, 64, 32]
55 |         if hops_size is None:
56 |             hops_size = [1024, 512, 256, 128, 64, 32, 16]
57 | 
58 |         self.windows_size = windows_size
59 |         self.n_filters = n_filters
60 |         self.hops_size = hops_size
61 |         self.alpha = alpha
62 | 
63 |         self.encoders = nn.ModuleList(
64 |             Encoder(STFTFB(n_filters[i], windows_size[i], hops_size[i]))
65 |             for i in range(len(self.n_filters))
66 |         )
67 | 
68 |     def forward(self, est_target, target):
69 |         batch_size = est_target.shape[0]
70 |         est_target = est_target.unsqueeze(1)
71 |         target = target.unsqueeze(1)
72 | 
73 |         loss = torch.zeros(batch_size, device=est_target.device)
74 |         for encoder in self.encoders:
75 |             loss += self.compute_spectral_loss(encoder, est_target, target)
76 |         return loss
77 | 
78 |     def compute_spectral_loss(self, encoder, est_target, target, EPS=1e-8):
79 |         batch_size = est_target.shape[0]
80 |         spect_est_target = mag(encoder(est_target)).view(batch_size, -1)
81 |         spect_target = mag(encoder(target)).view(batch_size, -1)
82 |         linear_loss = self.norm1(spect_est_target - spect_target)
83 |         log_loss = self.norm1(torch.log(spect_est_target + EPS) - torch.log(spect_target + EPS))
84 |         return linear_loss + self.alpha * log_loss
85 | 
86 |     @staticmethod
87 |     def norm1(a):
88 |         return torch.norm(a, p=1, dim=1)
89 | 
--------------------------------------------------------------------------------
/asteroid/losses/sinkpit_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from . import PITLossWrapper
5 | 
6 | 
7 | class SinkPITLossWrapper(nn.Module):
8 |     r"""Permutation invariant loss wrapper.
9 | 
10 |     Args:
11 |         loss_func: function with signature (est_targets, targets, **kwargs).
12 |         n_iter (int): number of the Sinkhorn iteration (default = 200).
13 |             Supposed to be an even number.
14 |         hungarian_validation (boolean) : Whether to use the Hungarian algorithm
15 |             for the validation. (default = True)
16 | 
17 |     ``loss_func`` computes pairwise losses and returns a torch.Tensor of shape
18 |     :math:`(batch, n\_src, n\_src)`. Each element :math:`(batch, i, j)` corresponds to
19 |     the loss between :math:`targets[:, i]` and :math:`est\_targets[:, j]`.
20 |     It evaluates an approximate value of the PIT loss using Sinkhorn's iterative algorithm.
21 |     See :meth:`~PITLossWrapper.best_softperm_sinkhorn` and http://arxiv.org/abs/2010.11871
22 | 
23 |     Examples
24 |         >>> import torch
25 |         >>> import pytorch_lightning as pl
26 |         >>> from asteroid.losses import pairwise_neg_sisdr
27 |         >>> sources = torch.randn(10, 3, 16000)
28 |         >>> est_sources = torch.randn(10, 3, 16000)
29 |         >>> # Compute SinkPIT loss based on pairwise losses
30 |         >>> loss_func = SinkPITLossWrapper(pairwise_neg_sisdr)
31 |         >>> loss_val = loss_func(est_sources, sources)
32 |         >>> # A fixed temperature parameter `beta` (=10) is used
33 |         >>> # unless a cooling callback is set. The value can be
34 |         >>> # dynamically changed using a cooling callback module as follows.
35 | >>> model = NeuralNetworkModel() 36 | >>> optimizer = optim.Adam(model.parameters(), lr=1e-3) 37 | >>> dataset = YourDataset() 38 | >>> loader = data.DataLoader(dataset, batch_size=16) 39 | >>> system = System( 40 | >>> model, 41 | >>> optimizer, 42 | >>> loss_func=SinkPITLossWrapper(pairwise_neg_sisdr), 43 | >>> train_loader=loader, 44 | >>> val_loader=loader, 45 | >>> ) 46 | >>> 47 | >>> trainer = pl.Trainer( 48 | >>> max_epochs=100, 49 | >>> callbacks=[SinkPITBetaScheduler(lambda epoch : 1.02 ** epoch)], 50 | >>> ) 51 | >>> 52 | >>> trainer.fit(system) 53 | """ 54 | 55 | def __init__(self, loss_func, n_iter=200, hungarian_validation=True): 56 | super().__init__() 57 | self.loss_func = loss_func 58 | self._beta = 10 59 | self.n_iter = n_iter 60 | self.hungarian_validation = hungarian_validation 61 | 62 | @property 63 | def beta(self): 64 | return self._beta 65 | 66 | @beta.setter 67 | def beta(self, beta): 68 | assert beta > 0 69 | self._beta = beta 70 | 71 | def forward(self, est_targets, targets, return_est=False, **kwargs): 72 | """Evaluate the loss using Sinkhorn's algorithm. 73 | 74 | Args: 75 | est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. 76 | The batch of target estimates. 77 | targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`. 78 | The batch of training targets 79 | return_est: Boolean. Whether to return the reordered targets 80 | estimates (To compute metrics or to save example). 81 | **kwargs: additional keyword argument that will be passed to the 82 | loss function. 83 | 84 | Returns: 85 | - Best permutation loss for each batch sample, average over 86 | the batch. torch.Tensor(loss_value) 87 | - The reordered targets estimates if return_est is True. 88 | torch.Tensor of shape :math:`(batch, nsrc, ...)`. 89 | """ 90 | n_src = targets.shape[1] 91 | assert n_src < 100, f"Expected source axis along dim 1, found {n_src}" 92 | 93 | # Evaluate the loss using Sinkhorn's iterative algorithm 94 | pw_losses = self.loss_func(est_targets, targets, **kwargs) 95 | 96 | assert pw_losses.ndim == 3, ( 97 | "Something went wrong with the loss " "function, please read the docs." 98 | ) 99 | assert pw_losses.shape[0] == targets.shape[0], "PIT loss needs same batch dim as input" 100 | 101 | if not return_est: 102 | if self.training or not self.hungarian_validation: 103 | # Train or sinkhorn validation 104 | min_loss, soft_perm = self.best_softperm_sinkhorn( 105 | pw_losses, self._beta, self.n_iter 106 | ) 107 | mean_loss = torch.mean(min_loss) 108 | return mean_loss 109 | else: 110 | # Reorder the output by using the Hungarian algorithm below 111 | min_loss, batch_indices = PITLossWrapper.find_best_perm(pw_losses) 112 | mean_loss = torch.mean(min_loss) 113 | return mean_loss 114 | else: 115 | # Test -> reorder the output by using the Hungarian algorithm below 116 | min_loss, batch_indices = PITLossWrapper.find_best_perm(pw_losses) 117 | mean_loss = torch.mean(min_loss) 118 | reordered = PITLossWrapper.reorder_source(est_targets, batch_indices) 119 | return mean_loss, reordered 120 | 121 | @staticmethod 122 | def best_softperm_sinkhorn(pair_wise_losses, beta=10, n_iter=200): 123 | r"""Compute an approximate PIT loss using Sinkhorn's algorithm. 124 | See http://arxiv.org/abs/2010.11871 125 | 126 | Args: 127 | pair_wise_losses (:class:`torch.Tensor`): 128 | Tensor of shape :math:`(batch, n_src, n_src)`. Pairwise losses. 129 | beta (float) : Inverse temperature parameter. (default = 10) 130 | n_iter (int) : Number of iteration. Even number. 
(default = 200)
131 | 
132 |         Returns:
133 |             - :class:`torch.Tensor`:
134 |               The loss corresponding to the best permutation of size (batch,).
135 | 
136 |             - :class:`torch.Tensor`:
137 |               A soft permutation matrix.
138 |         """
139 |         C = pair_wise_losses.transpose(-1, -2)
140 |         n_src = C.shape[-1]
141 |         # initial values
142 |         Z = -beta * C
143 |         for it in range(n_iter // 2):
144 |             Z = Z - torch.logsumexp(Z, axis=1, keepdim=True)
145 |             Z = Z - torch.logsumexp(Z, axis=2, keepdim=True)
146 |         min_loss = torch.einsum("bij,bij->b", C + Z / beta, torch.exp(Z))
147 |         min_loss = min_loss / n_src
148 |         return min_loss, torch.exp(Z)
149 | 
--------------------------------------------------------------------------------
/asteroid/losses/soft_f1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules.loss import _Loss
3 | 
4 | 
5 | class F1_loss(_Loss):
6 |     """Differentiable soft F1 loss: returns ``1 - F1`` computed from real-valued estimates."""
7 | 
8 |     def __init__(self, eps=1e-10):
9 |         super().__init__()
10 |         self.eps = eps
11 | 
12 |     def forward(self, estimates, targets):
13 |         tp = (targets * estimates).sum()
14 |         fp = ((1 - targets) * estimates).sum()
15 |         fn = (targets * (1 - estimates)).sum()
16 | 
17 |         precision = tp / (tp + fp + self.eps)
18 |         recall = tp / (tp + fn + self.eps)
19 | 
20 |         f1 = 2 * (precision * recall) / (precision + recall + self.eps)
21 |         return 1 - f1.mean()
22 | 
--------------------------------------------------------------------------------
/asteroid/losses/stoi.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | with warnings.catch_warnings():
4 |     warnings.simplefilter("ignore")
5 |     from torch_stoi import NegSTOILoss as _NegSTOILoss
6 | 
7 | 
8 | class NegSTOILoss(_NegSTOILoss):
9 |     r"""Negated Short Term Objective Intelligibility (STOI) metric, to be used
10 |     as a loss function.
11 |     Inspired from [1, 2, 3] but not the same.
12 | 
13 |     Args:
14 |         sample_rate (int): sample rate of the audio files
15 |         use_vad (bool): Whether to use simple VAD (see Notes)
16 |         extended (bool): Whether to compute extended version [3].
17 | 
18 |     Shapes:
19 |         - :math:`(time,) -> (1, )`
20 |         - :math:`(batch, time) -> (batch, )`
21 |         - :math:`(batch, n\_src, time) -> (batch, n\_src)`
22 | 
23 |     Returns:
24 |         torch.Tensor of shape (batch, *, ), only the time dimension has
25 |         been reduced.
26 | 
27 |     .. warning::
28 |         This function cannot be used to compute the "real" STOI metric as
29 |         we applied some changes to speed-up loss computation. See Notes section.
30 | 
31 |     .. note::
32 |         In the NumPy version, some kind of simple VAD was used to remove the
33 |         silent frames before chunking the signal into short-term envelope
34 |         vectors. We don't do the same here because removing frames in a
35 |         batch is cumbersome and inefficient.
36 |         If `use_vad` is set to True, instead we detect the silent frames and
37 |         keep a mask tensor. At the end, the normalized correlation of
38 |         short-term envelope vectors is masked using this mask (unfolded) and
39 |         the mean is computed taking the mask values into account.
40 | 41 | Examples 42 | >>> import torch 43 | >>> from asteroid.losses import PITLossWrapper 44 | >>> targets = torch.randn(10, 2, 32000) 45 | >>> est_targets = torch.randn(10, 2, 32000) 46 | >>> loss_func = PITLossWrapper(NegSTOILoss(sample_rate=8000), pit_from='pw_pt') 47 | >>> loss = loss_func(est_targets, targets) 48 | 49 | References 50 | [1] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time 51 | Objective Intelligibility Measure for Time-Frequency Weighted Noisy 52 | Speech', ICASSP 2010, Texas, Dallas. 53 | 54 | [2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for 55 | Intelligibility Prediction of Time-Frequency Weighted Noisy Speech', 56 | IEEE Transactions on Audio, Speech, and Language Processing, 2011. 57 | 58 | [3] Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the 59 | Intelligibility of Speech Masked by Modulated Noise Maskers', 60 | IEEE Transactions on Audio, Speech and Language Processing, 2016. 61 | """ 62 | 63 | def __init__(self, *args, **kwargs): 64 | super().__init__(*args, **kwargs) 65 | -------------------------------------------------------------------------------- /asteroid/masknn/__init__.py: -------------------------------------------------------------------------------- 1 | from .convolutional import TDConvNet, TDConvNetpp, SuDORMRF, SuDORMRFImproved 2 | from .recurrent import DPRNN, LSTMMasker 3 | from .attention import DPTransformer 4 | 5 | __all__ = [ 6 | "TDConvNet", 7 | "DPRNN", 8 | "DPTransformer", 9 | "LSTMMasker", 10 | "SuDORMRF", 11 | "SuDORMRFImproved", 12 | ] 13 | -------------------------------------------------------------------------------- /asteroid/masknn/_dccrn_architectures.py: -------------------------------------------------------------------------------- 1 | # fmt: off 2 | DCCRN_ARCHITECTURES = { 3 | "DCCRN-CL": ( 4 | # Encoders: 5 | # (in_chan, out_chan, kernel_size, stride, padding) 6 | ( 7 | ( 1, 16, (5, 2), (2, 1), (2, 0)), 8 | ( 16, 32, (5, 2), (2, 1), (2, 0)), 9 | ( 32, 64, (5, 2), (2, 1), (2, 0)), 10 | ( 64, 128, (5, 2), (2, 1), (2, 0)), 11 | (128, 128, (5, 2), (2, 1), (2, 0)), 12 | (128, 128, (5, 2), (2, 1), (2, 0)), 13 | ), 14 | # Decoders: 15 | # (in_chan, out_chan, kernel_size, stride, padding, output_padding) 16 | ( 17 | (256, 128, (5, 2), (2, 1), (2, 0), (1, 0)), 18 | (256, 128, (5, 2), (2, 1), (2, 0), (1, 0)), 19 | (256, 64, (5, 2), (2, 1), (2, 0), (1, 0)), 20 | (128, 32, (5, 2), (2, 1), (2, 0), (1, 0)), 21 | ( 64, 16, (5, 2), (2, 1), (2, 0), (1, 0)), 22 | ( 32, 1, (5, 2), (2, 1), (2, 0), (1, 0)), 23 | ), 24 | ), 25 | "mini": ( 26 | # This is a dummy architecture used for Asteroid unit tests. 
27 | 28 | # Encoders: 29 | # (in_chan, out_chan, kernel_size, stride, padding) 30 | ( 31 | (1, 4, (5, 2), (2, 1), (2, 0)), 32 | (4, 8, (5, 2), (2, 1), (2, 0)), 33 | ), 34 | # Decoders: 35 | # (in_chan, out_chan, kernel_size, stride, padding, output_padding) 36 | ( 37 | (16, 4, (5, 2), (2, 1), (2, 0), (1, 0)), 38 | ( 8, 1, (5, 2), (2, 1), (2, 0), (1, 0)), 39 | ), 40 | ), 41 | } 42 | -------------------------------------------------------------------------------- /asteroid/masknn/_dcunet_architectures.py: -------------------------------------------------------------------------------- 1 | from ..utils.generic_utils import unet_decoder_args 2 | 3 | 4 | def make_unet_encoder_decoder_args(encoder_args, decoder_args): 5 | encoder_args = tuple( 6 | ( 7 | in_chan, 8 | out_chan, 9 | kernel_size, 10 | stride, 11 | tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding, 12 | ) 13 | for in_chan, out_chan, kernel_size, stride, padding in encoder_args 14 | ) 15 | 16 | if decoder_args == "auto": 17 | decoder_args = unet_decoder_args( 18 | encoder_args, 19 | skip_connections=True, 20 | ) 21 | else: 22 | decoder_args = tuple( 23 | ( 24 | in_chan, 25 | out_chan, 26 | kernel_size, 27 | stride, 28 | tuple([n // 2 for n in kernel_size]) if padding == "auto" else padding, 29 | output_padding, 30 | ) 31 | for in_chan, out_chan, kernel_size, stride, padding, output_padding in decoder_args 32 | ) 33 | 34 | return encoder_args, decoder_args 35 | 36 | 37 | # fmt: off 38 | 39 | DCUNET_ARCHITECTURES = { 40 | "DCUNet-10": make_unet_encoder_decoder_args( 41 | # Encoders: 42 | # (in_chan, out_chan, kernel_size, stride, padding) 43 | ( 44 | ( 1, 32, (7, 5), (2, 2), "auto"), 45 | (32, 64, (7, 5), (2, 2), "auto"), 46 | (64, 64, (5, 3), (2, 2), "auto"), 47 | (64, 64, (5, 3), (2, 2), "auto"), 48 | (64, 64, (5, 3), (2, 1), "auto"), 49 | ), 50 | # Decoders: automatic inverse 51 | "auto", 52 | ), 53 | "DCUNet-16": make_unet_encoder_decoder_args( 54 | # Encoders: 55 | # (in_chan, out_chan, kernel_size, stride, padding) 56 | ( 57 | ( 1, 32, (7, 5), (2, 2), "auto"), 58 | (32, 32, (7, 5), (2, 1), "auto"), 59 | (32, 64, (7, 5), (2, 2), "auto"), 60 | (64, 64, (5, 3), (2, 1), "auto"), 61 | (64, 64, (5, 3), (2, 2), "auto"), 62 | (64, 64, (5, 3), (2, 1), "auto"), 63 | (64, 64, (5, 3), (2, 2), "auto"), 64 | (64, 64, (5, 3), (2, 1), "auto"), 65 | ), 66 | # Decoders: automatic inverse 67 | "auto", 68 | ), 69 | "DCUNet-20": make_unet_encoder_decoder_args( 70 | # Encoders: 71 | # (in_chan, out_chan, kernel_size, stride, padding) 72 | ( 73 | ( 1, 32, (7, 1), (1, 1), "auto"), 74 | (32, 32, (1, 7), (1, 1), "auto"), 75 | (32, 64, (7, 5), (2, 2), "auto"), 76 | (64, 64, (7, 5), (2, 1), "auto"), 77 | (64, 64, (5, 3), (2, 2), "auto"), 78 | (64, 64, (5, 3), (2, 1), "auto"), 79 | (64, 64, (5, 3), (2, 2), "auto"), 80 | (64, 64, (5, 3), (2, 1), "auto"), 81 | (64, 64, (5, 3), (2, 2), "auto"), 82 | (64, 90, (5, 3), (2, 1), "auto"), 83 | ), 84 | # Decoders: automatic inverse 85 | "auto", 86 | ), 87 | "Large-DCUNet-20": make_unet_encoder_decoder_args( 88 | # Encoders: 89 | # (in_chan, out_chan, kernel_size, stride, padding) 90 | ( 91 | ( 1, 45, (7, 1), (1, 1), "auto"), 92 | (45, 45, (1, 7), (1, 1), "auto"), 93 | (45, 90, (7, 5), (2, 2), "auto"), 94 | (90, 90, (7, 5), (2, 1), "auto"), 95 | (90, 90, (5, 3), (2, 2), "auto"), 96 | (90, 90, (5, 3), (2, 1), "auto"), 97 | (90, 90, (5, 3), (2, 2), "auto"), 98 | (90, 90, (5, 3), (2, 1), "auto"), 99 | (90, 90, (5, 3), (2, 2), "auto"), 100 | (90, 128, (5, 3), (2, 1), "auto"), 101 | ), 102 | # Decoders: 
103 | # (in_chan, out_chan, kernel_size, stride, padding, output_padding) 104 | ( 105 | (128, 90, (5, 3), (2, 1), "auto", (0, 0)), 106 | (180, 90, (5, 3), (2, 2), "auto", (0, 0)), 107 | (180, 90, (5, 3), (2, 1), "auto", (0, 0)), 108 | (180, 90, (5, 3), (2, 2), "auto", (0, 0)), 109 | (180, 90, (5, 3), (2, 1), "auto", (0, 0)), 110 | (180, 90, (5, 3), (2, 2), "auto", (0, 0)), 111 | (180, 90, (7, 5), (2, 1), "auto", (0, 0)), 112 | (180, 90, (7, 5), (2, 2), "auto", (0, 0)), 113 | (135, 90, (1, 7), (1, 1), "auto", (0, 0)), 114 | (135, 1, (7, 1), (1, 1), "auto", (0, 0)), 115 | ), 116 | ), 117 | "mini": make_unet_encoder_decoder_args( 118 | # This is a dummy architecture used for Asteroid unit tests. 119 | 120 | # Encoders: 121 | # (in_chan, out_chan, kernel_size, stride, padding) 122 | ( 123 | (1, 4, (7, 5), (2, 2), "auto"), 124 | (4, 8, (7, 5), (2, 2), "auto"), 125 | (8, 16, (5, 3), (2, 2), "auto"), 126 | ), 127 | # Decoders: automatic inverse 128 | "auto", 129 | ), 130 | } 131 | -------------------------------------------------------------------------------- /asteroid/masknn/_local.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from .norms import GlobLN 3 | 4 | 5 | class _ConvNormAct(nn.Module): 6 | """Convolution layer with normalization and a PReLU activation. 7 | 8 | See license and copyright notices here 9 | https://github.com/etzinis/sudo_rm_rf#copyright-and-license 10 | https://github.com/etzinis/sudo_rm_rf/blob/master/LICENSE 11 | 12 | Args 13 | nIn: number of input channels 14 | nOut: number of output channels 15 | kSize: kernel size 16 | stride: stride rate for down-sampling. Default is 1 17 | """ 18 | 19 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1, use_globln=False): 20 | 21 | super().__init__() 22 | padding = int((kSize - 1) / 2) 23 | self.conv = nn.Conv1d( 24 | nIn, nOut, kSize, stride=stride, padding=padding, bias=True, groups=groups 25 | ) 26 | if use_globln: 27 | self.norm = GlobLN(nOut) 28 | self.act = nn.PReLU() 29 | else: 30 | self.norm = nn.GroupNorm(1, nOut, eps=1e-08) 31 | self.act = nn.PReLU(nOut) 32 | 33 | def forward(self, inp): 34 | output = self.conv(inp) 35 | output = self.norm(output) 36 | return self.act(output) 37 | 38 | 39 | class _ConvNorm(nn.Module): 40 | """Convolution layer with normalization without activation. 41 | 42 | See license and copyright notices here 43 | https://github.com/etzinis/sudo_rm_rf#copyright-and-license 44 | https://github.com/etzinis/sudo_rm_rf/blob/master/LICENSE 45 | 46 | 47 | Args: 48 | nIn: number of input channels 49 | nOut: number of output channels 50 | kSize: kernel size 51 | stride: stride rate for down-sampling. Default is 1 52 | """ 53 | 54 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1): 55 | 56 | super().__init__() 57 | padding = int((kSize - 1) / 2) 58 | self.conv = nn.Conv1d( 59 | nIn, nOut, kSize, stride=stride, padding=padding, bias=True, groups=groups 60 | ) 61 | self.norm = nn.GroupNorm(1, nOut, eps=1e-08) 62 | 63 | def forward(self, inp): 64 | output = self.conv(inp) 65 | return self.norm(output) 66 | 67 | 68 | class _NormAct(nn.Module): 69 | """Normalization and PReLU activation. 
70 | 71 | See license and copyright notices here 72 | https://github.com/etzinis/sudo_rm_rf#copyright-and-license 73 | https://github.com/etzinis/sudo_rm_rf/blob/master/LICENSE 74 | 75 | Args: 76 | nOut: number of output channels 77 | """ 78 | 79 | def __init__(self, nOut, use_globln=False): 80 | super().__init__() 81 | if use_globln: 82 | self.norm = GlobLN(nOut) 83 | else: 84 | self.norm = nn.GroupNorm(1, nOut, eps=1e-08) 85 | self.act = nn.PReLU(nOut) 86 | 87 | def forward(self, inp): 88 | output = self.norm(inp) 89 | return self.act(output) 90 | 91 | 92 | class _DilatedConvNorm(nn.Module): 93 | """Dilated convolution with normalized output. 94 | 95 | See license and copyright notices here 96 | https://github.com/etzinis/sudo_rm_rf#copyright-and-license 97 | https://github.com/etzinis/sudo_rm_rf/blob/master/LICENSE 98 | 99 | Args: 100 | nIn: number of input channels 101 | nOut: number of output channels 102 | kSize: kernel size 103 | stride: optional stride rate for down-sampling 104 | d: optional dilation rate 105 | """ 106 | 107 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1, use_globln=False): 108 | super().__init__() 109 | self.conv = nn.Conv1d( 110 | nIn, 111 | nOut, 112 | kSize, 113 | stride=stride, 114 | dilation=d, 115 | padding=((kSize - 1) // 2) * d, 116 | groups=groups, 117 | ) 118 | if use_globln: 119 | self.norm = GlobLN(nOut) 120 | else: 121 | self.norm = nn.GroupNorm(1, nOut, eps=1e-08) 122 | 123 | def forward(self, inp): 124 | output = self.conv(inp) 125 | return self.norm(output) 126 | -------------------------------------------------------------------------------- /asteroid/masknn/activations.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | from torch import nn 4 | from .. import complex_nn 5 | 6 | 7 | class Swish(nn.Module): 8 | def __init__(self): 9 | super(Swish, self).__init__() 10 | 11 | def forward(self, x): 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | def linear(): 16 | return nn.Identity() 17 | 18 | 19 | def relu(): 20 | return nn.ReLU() 21 | 22 | 23 | def prelu(): 24 | return nn.PReLU() 25 | 26 | 27 | def leaky_relu(): 28 | return nn.LeakyReLU() 29 | 30 | 31 | def sigmoid(): 32 | return nn.Sigmoid() 33 | 34 | 35 | def softmax(dim=None): 36 | return nn.Softmax(dim=dim) 37 | 38 | 39 | def tanh(): 40 | return nn.Tanh() 41 | 42 | 43 | def gelu(): 44 | return nn.GELU() 45 | 46 | 47 | def swish(): 48 | return Swish() 49 | 50 | 51 | def register_activation(custom_act): 52 | """Register a custom activation, gettable with `activation.get`. 53 | 54 | Args: 55 | custom_act: Custom activation function to register. 56 | 57 | """ 58 | if custom_act.__name__ in globals().keys() or custom_act.__name__.lower() in globals().keys(): 59 | raise ValueError(f"Activation {custom_act.__name__} already exists. Choose another name.") 60 | globals().update({custom_act.__name__: custom_act}) 61 | 62 | 63 | def get(identifier): 64 | """Returns an activation function from a string. Returns its input if it 65 | is callable (already an activation for example). 66 | 67 | Args: 68 | identifier (str or Callable or None): the activation identifier. 
69 | 70 | Returns: 71 | :class:`nn.Module` or None 72 | """ 73 | if identifier is None: 74 | return None 75 | elif callable(identifier): 76 | return identifier 77 | elif isinstance(identifier, str): 78 | cls = globals().get(identifier) 79 | if cls is None: 80 | raise ValueError("Could not interpret activation identifier: " + str(identifier)) 81 | return cls 82 | else: 83 | raise ValueError("Could not interpret activation identifier: " + str(identifier)) 84 | 85 | 86 | def get_complex(identifier): 87 | """Like `.get` but returns a complex activation created with `asteroid.complex_nn.OnReIm`.""" 88 | activation = get(identifier) 89 | if activation is None: 90 | return None 91 | else: 92 | return partial(complex_nn.OnReIm, activation) 93 | -------------------------------------------------------------------------------- /asteroid/masknn/attention.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import warnings 3 | 4 | import torch.nn as nn 5 | from torch.nn.modules.activation import MultiheadAttention 6 | from . import activations, norms 7 | import torch 8 | from ..utils import has_arg 9 | from ..dsp.overlap_add import DualPathProcessing 10 | 11 | 12 | class ImprovedTransformedLayer(nn.Module): 13 | """ 14 | Improved Transformer module as used in [1]. 15 | It is Multi-Head self-attention followed by LSTM, activation and linear projection layer. 16 | 17 | Args: 18 | embed_dim (int): Number of input channels. 19 | n_heads (int): Number of attention heads. 20 | dim_ff (int): Number of neurons in the RNNs cell state. 21 | Defaults to 256. RNN here replaces standard FF linear layer in plain Transformer. 22 | dropout (float, optional): Dropout ratio, must be in [0,1]. 23 | activation (str, optional): activation function applied at the output of RNN. 24 | bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN 25 | (Intra-Chunk is always bidirectional). 26 | norm (str, optional): Type of normalization to use. 27 | 28 | References 29 | [1] Chen, Jingjing, Qirong Mao, and Dong Liu. "Dual-Path Transformer 30 | Network: Direct Context-Aware Modeling for End-to-End Monaural Speech Separation." 31 | arXiv (2020). 32 | """ 33 | 34 | def __init__( 35 | self, 36 | embed_dim, 37 | n_heads, 38 | dim_ff, 39 | dropout=0.0, 40 | activation="relu", 41 | bidirectional=True, 42 | norm="gLN", 43 | ): 44 | super(ImprovedTransformedLayer, self).__init__() 45 | 46 | self.mha = MultiheadAttention(embed_dim, n_heads, dropout=dropout) 47 | self.dropout = nn.Dropout(dropout) 48 | self.recurrent = nn.LSTM(embed_dim, dim_ff, bidirectional=bidirectional, batch_first=True) 49 | ff_inner_dim = 2 * dim_ff if bidirectional else dim_ff 50 | self.linear = nn.Linear(ff_inner_dim, embed_dim) 51 | self.activation = activations.get(activation)() 52 | self.norm_mha = norms.get(norm)(embed_dim) 53 | self.norm_ff = norms.get(norm)(embed_dim) 54 | 55 | def forward(self, x): 56 | tomha = x.permute(2, 0, 1) 57 | # x is batch, channels, seq_len 58 | # mha is seq_len, batch, channels 59 | # self-attention is applied 60 | out = self.mha(tomha, tomha, tomha)[0] 61 | x = self.dropout(out.permute(1, 2, 0)) + x 62 | x = self.norm_mha(x) 63 | 64 | # lstm is applied 65 | out = self.linear(self.dropout(self.activation(self.recurrent(x.transpose(1, -1))[0]))) 66 | x = self.dropout(out.transpose(1, -1)) + x 67 | return self.norm_ff(x) 68 | 69 | 70 | class DPTransformer(nn.Module): 71 | """Dual-path Transformer introduced in [1]. 
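    The input is unfolded into overlapping chunks, after which intra-chunk and
    inter-chunk transformer layers are alternated ``n_repeats`` times; the chunks
    are then merged back with overlap-add and passed through a gated output layer
    to produce one mask per source.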
72 | 73 | Args: 74 | in_chan (int): Number of input filters. 75 | n_src (int): Number of masks to estimate. 76 | n_heads (int): Number of attention heads. 77 | ff_hid (int): Number of neurons in the RNNs cell state. 78 | Defaults to 256. 79 | chunk_size (int): window size of overlap and add processing. 80 | Defaults to 100. 81 | hop_size (int or None): hop size (stride) of overlap and add processing. 82 | Default to `chunk_size // 2` (50% overlap). 83 | n_repeats (int): Number of repeats. Defaults to 6. 84 | norm_type (str, optional): Type of normalization to use. 85 | ff_activation (str, optional): activation function applied at the output of RNN. 86 | mask_act (str, optional): Which non-linear function to generate mask. 87 | bidirectional (bool, optional): True for bidirectional Inter-Chunk RNN 88 | (Intra-Chunk is always bidirectional). 89 | dropout (float, optional): Dropout ratio, must be in [0,1]. 90 | 91 | References 92 | [1] Chen, Jingjing, Qirong Mao, and Dong Liu. "Dual-Path Transformer 93 | Network: Direct Context-Aware Modeling for End-to-End Monaural Speech Separation." 94 | arXiv (2020). 95 | """ 96 | 97 | def __init__( 98 | self, 99 | in_chan, 100 | n_src, 101 | n_heads=4, 102 | ff_hid=256, 103 | chunk_size=100, 104 | hop_size=None, 105 | n_repeats=6, 106 | norm_type="gLN", 107 | ff_activation="relu", 108 | mask_act="relu", 109 | bidirectional=True, 110 | dropout=0, 111 | ): 112 | super(DPTransformer, self).__init__() 113 | self.in_chan = in_chan 114 | self.n_src = n_src 115 | self.n_heads = n_heads 116 | self.ff_hid = ff_hid 117 | self.chunk_size = chunk_size 118 | hop_size = hop_size if hop_size is not None else chunk_size // 2 119 | self.hop_size = hop_size 120 | self.n_repeats = n_repeats 121 | self.n_src = n_src 122 | self.norm_type = norm_type 123 | self.ff_activation = ff_activation 124 | self.mask_act = mask_act 125 | self.bidirectional = bidirectional 126 | self.dropout = dropout 127 | 128 | self.mha_in_dim = ceil(self.in_chan / self.n_heads) * self.n_heads 129 | if self.in_chan % self.n_heads != 0: 130 | warnings.warn( 131 | f"DPTransformer input dim ({self.in_chan}) is not a multiple of the number of " 132 | f"heads ({self.n_heads}). Adding extra linear layer at input to accomodate " 133 | f"(size [{self.in_chan} x {self.mha_in_dim}])" 134 | ) 135 | self.input_layer = nn.Linear(self.in_chan, self.mha_in_dim) 136 | else: 137 | self.input_layer = None 138 | 139 | self.in_norm = norms.get(norm_type)(self.mha_in_dim) 140 | self.ola = DualPathProcessing(self.chunk_size, self.hop_size) 141 | 142 | # Succession of DPRNNBlocks. 143 | self.layers = nn.ModuleList([]) 144 | for x in range(self.n_repeats): 145 | self.layers.append( 146 | nn.ModuleList( 147 | [ 148 | ImprovedTransformedLayer( 149 | self.mha_in_dim, 150 | self.n_heads, 151 | self.ff_hid, 152 | self.dropout, 153 | self.ff_activation, 154 | True, 155 | self.norm_type, 156 | ), 157 | ImprovedTransformedLayer( 158 | self.mha_in_dim, 159 | self.n_heads, 160 | self.ff_hid, 161 | self.dropout, 162 | self.ff_activation, 163 | self.bidirectional, 164 | self.norm_type, 165 | ), 166 | ] 167 | ) 168 | ) 169 | net_out_conv = nn.Conv2d(self.mha_in_dim, n_src * self.in_chan, 1) 170 | self.first_out = nn.Sequential(nn.PReLU(), net_out_conv) 171 | # Gating and masking in 2D space (after fold) 172 | self.net_out = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Tanh()) 173 | self.net_gate = nn.Sequential(nn.Conv1d(self.in_chan, self.in_chan, 1), nn.Sigmoid()) 174 | 175 | # Get activation function. 
176 | mask_nl_class = activations.get(mask_act) 177 | # For softmax, feed the source dimension. 178 | if has_arg(mask_nl_class, "dim"): 179 | self.output_act = mask_nl_class(dim=1) 180 | else: 181 | self.output_act = mask_nl_class() 182 | 183 | def forward(self, mixture_w): 184 | r"""Forward. 185 | 186 | Args: 187 | mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$ 188 | 189 | Returns: 190 | :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$ 191 | """ 192 | if self.input_layer is not None: 193 | mixture_w = self.input_layer(mixture_w.transpose(1, 2)).transpose(1, 2) 194 | mixture_w = self.in_norm(mixture_w) # [batch, bn_chan, n_frames] 195 | n_orig_frames = mixture_w.shape[-1] 196 | 197 | mixture_w = self.ola.unfold(mixture_w) 198 | batch, n_filters, self.chunk_size, n_chunks = mixture_w.size() 199 | 200 | for layer_idx in range(len(self.layers)): 201 | intra, inter = self.layers[layer_idx] 202 | mixture_w = self.ola.intra_process(mixture_w, intra) 203 | mixture_w = self.ola.inter_process(mixture_w, inter) 204 | 205 | output = self.first_out(mixture_w) 206 | output = output.reshape(batch * self.n_src, self.in_chan, self.chunk_size, n_chunks) 207 | output = self.ola.fold(output, output_size=n_orig_frames) 208 | 209 | output = self.net_out(output) * self.net_gate(output) 210 | # Compute mask 211 | output = output.reshape(batch, self.n_src, self.in_chan, -1) 212 | est_mask = self.output_act(output) 213 | return est_mask 214 | 215 | def get_config(self): 216 | config = { 217 | "in_chan": self.in_chan, 218 | "ff_hid": self.ff_hid, 219 | "n_heads": self.n_heads, 220 | "chunk_size": self.chunk_size, 221 | "hop_size": self.hop_size, 222 | "n_repeats": self.n_repeats, 223 | "n_src": self.n_src, 224 | "norm_type": self.norm_type, 225 | "ff_activation": self.ff_activation, 226 | "mask_act": self.mask_act, 227 | "bidirectional": self.bidirectional, 228 | "dropout": self.dropout, 229 | } 230 | return config 231 | -------------------------------------------------------------------------------- /asteroid/masknn/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .. import complex_nn 5 | 6 | 7 | def _none_sequential(*args): 8 | return torch.nn.Sequential(*[x for x in args if x is not None]) 9 | 10 | 11 | class BaseUNet(torch.nn.Module): 12 | """Base class for u-nets with skip connections between encoders and decoders. 13 | 14 | (For u-nets without skip connections, simply use a `nn.Sequential`.) 15 | 16 | Args: 17 | encoders (List[torch.nn.Module] of length `N`): List of encoders 18 | decoders (List[torch.nn.Module] of length `N - 1`): List of decoders 19 | output_layer (Optional[torch.nn.Module], optional): 20 | Layer after last decoder. 
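    Example (an illustrative sketch; the toy 1x1 convolutions here are
    placeholders, any encoder/decoder modules with compatible channel counts
    would do)::

        >>> import torch
        >>> from torch import nn
        >>> encoders = [nn.Conv1d(1, 4, 1), nn.Conv1d(4, 8, 1)]
        >>> decoders = [nn.Conv1d(8, 4, 1)]
        >>> unet = BaseUNet(encoders, decoders)
        >>> unet(torch.randn(2, 1, 16)).shape  # skip cat doubles 4 -> 8 channels
        torch.Size([2, 8, 16])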
21 | """ 22 | 23 | def __init__( 24 | self, 25 | encoders, 26 | decoders, 27 | *, 28 | output_layer=None, 29 | ): 30 | assert len(encoders) == len(decoders) + 1 31 | 32 | super().__init__() 33 | 34 | self.encoders = torch.nn.ModuleList(encoders) 35 | self.decoders = torch.nn.ModuleList(decoders) 36 | self.output_layer = output_layer or torch.nn.Identity() 37 | 38 | def forward(self, x): 39 | enc_outs = [] 40 | for idx, enc in enumerate(self.encoders): 41 | x = enc(x) 42 | enc_outs.append(x) 43 | for idx, (enc_out, dec) in enumerate(zip(reversed(enc_outs[:-1]), self.decoders)): 44 | x = dec(x) 45 | x = torch.cat([x, enc_out], dim=1) 46 | return self.output_layer(x) 47 | 48 | 49 | class BaseDCUMaskNet(BaseUNet): 50 | """Base class for DCU-style mask nets. Used for DCUMaskNet and DCCRMaskNet. 51 | 52 | The preferred way to instantiate this class is to use the ``default_architecture()`` 53 | classmethod. 54 | 55 | Args: 56 | encoders (List[torch.nn.Module]): List of encoders 57 | decoders (List[torch.nn.Module]): List of decoders 58 | output_layer (Optional[torch.nn.Module], optional): 59 | Layer after last decoder, before mask application. 60 | mask_bound (Optional[str], optional): Type of mask bound to use, as defined in [1]. 61 | Valid values are "tanh" ("BDT mask"), "sigmoid" ("BDSS mask"), None (unbounded mask). 62 | 63 | References 64 | - [1] : "Phase-aware Speech Enhancement with Deep Complex U-Net", 65 | Hyeong-Seok Choi et al. https://arxiv.org/abs/1903.03107 66 | """ 67 | 68 | _architectures = NotImplemented 69 | 70 | @classmethod 71 | def default_architecture(cls, architecture: str, n_src=1, **kwargs): 72 | """Create a masknet instance from a predefined, named architecture. 73 | 74 | Args: 75 | architecture (str): Name of predefined architecture. Valid values 76 | are dependent on the concrete subclass of ``BaseDCUMaskNet``. 77 | n_src (int, optional): Number of sources 78 | kwargs (optional): Passed to ``__init__``. 79 | """ 80 | encoders, decoders = cls._architectures[architecture] 81 | # Fix n_src in last decoder 82 | in_chan, _ignored_out_chan, *rest = decoders[-1] 83 | decoders = (*decoders[:-1], (in_chan, n_src, *rest)) 84 | return cls(encoders, decoders, **kwargs) 85 | 86 | def __init__(self, encoders, decoders, output_layer=None, mask_bound="tanh", **kwargs): 87 | self.mask_bound = mask_bound 88 | super().__init__( 89 | encoders=encoders, 90 | decoders=decoders, 91 | output_layer=_none_sequential( 92 | output_layer, 93 | complex_nn.BoundComplexMask(mask_bound), 94 | ), 95 | **kwargs, 96 | ) 97 | 98 | def forward(self, x): 99 | fixed_x = self.fix_input_dims(x) 100 | out = super().forward(fixed_x.unsqueeze(1)) 101 | out = self.fix_output_dims(out, x) 102 | return out 103 | 104 | def fix_input_dims(self, x): 105 | """Overwrite this in subclasses to implement input dimension checks.""" 106 | return x 107 | 108 | def fix_output_dims(self, y, x): 109 | """Overwrite this in subclasses to implement output dimension checks. 110 | y is the output and x was the input (passed to use the shape).""" 111 | return y 112 | -------------------------------------------------------------------------------- /asteroid/masknn/norms.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch 3 | from torch import nn 4 | from torch.nn.modules.batchnorm import _BatchNorm 5 | from typing import List 6 | 7 | from .. 
import complex_nn 8 | from ..utils.torch_utils import script_if_tracing 9 | 10 | EPS = 1e-8 11 | 12 | 13 | def z_norm(x, dims: List[int], eps: float = 1e-8): 14 |     mean = x.mean(dim=dims, keepdim=True) 15 |     var2 = torch.var(x, dim=dims, keepdim=True, unbiased=False) 16 |     value = (x - mean) / torch.sqrt((var2 + eps)) 17 |     return value 18 | 19 | 20 | @script_if_tracing 21 | def _glob_norm(x, eps: float = 1e-8): 22 |     dims: List[int] = torch.arange(1, len(x.shape)).tolist() 23 |     return z_norm(x, dims, eps) 24 | 25 | 26 | @script_if_tracing 27 | def _feat_glob_norm(x, eps: float = 1e-8): 28 |     dims: List[int] = torch.arange(2, len(x.shape)).tolist() 29 |     return z_norm(x, dims, eps) 30 | 31 | 32 | class _LayerNorm(nn.Module): 33 |     """Layer Normalization base class.""" 34 | 35 |     def __init__(self, channel_size): 36 |         super(_LayerNorm, self).__init__() 37 |         self.channel_size = channel_size 38 |         self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True) 39 |         self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True) 40 | 41 |     def apply_gain_and_bias(self, normed_x): 42 |         """Assumes input of size `[batch, channel, *]`.""" 43 |         return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1) 44 | 45 | 46 | class GlobLN(_LayerNorm): 47 |     """Global Layer Normalization (globLN).""" 48 | 49 |     def forward(self, x, EPS: float = 1e-8): 50 |         """Applies forward pass. 51 | 52 |         Works for any input size > 2D. 53 | 54 |         Args: 55 |             x (:class:`torch.Tensor`): Shape `[batch, chan, *]` 56 | 57 |         Returns: 58 |             :class:`torch.Tensor`: gLN_x `[batch, chan, *]` 59 |         """ 60 |         value = _glob_norm(x, eps=EPS) 61 |         return self.apply_gain_and_bias(value) 62 | 63 | 64 | class ChanLN(_LayerNorm): 65 |     """Channel-wise Layer Normalization (chanLN).""" 66 | 67 |     def forward(self, x, EPS: float = 1e-8): 68 |         """Applies forward pass. 69 | 70 |         Works for any input size > 2D. 71 | 72 |         Args: 73 |             x (:class:`torch.Tensor`): `[batch, chan, *]` 74 | 75 |         Returns: 76 |             :class:`torch.Tensor`: chanLN_x `[batch, chan, *]` 77 |         """ 78 |         mean = torch.mean(x, dim=1, keepdim=True) 79 |         var = torch.var(x, dim=1, keepdim=True, unbiased=False) 80 |         return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt()) 81 | 82 | 83 | class CumLN(_LayerNorm): 84 |     """Cumulative Global layer normalization (cumLN).""" 85 | 86 |     def forward(self, x, EPS: float = 1e-8): 87 |         """ 88 | 89 |         Args: 90 |             x (:class:`torch.Tensor`): Shape `[batch, channels, length]` 91 |         Returns: 92 |             :class:`torch.Tensor`: cumLN_x `[batch, channels, length]` 93 |         """ 94 |         batch, chan, spec_len = x.size() 95 |         cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=-1) 96 |         cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=-1) 97 |         cnt = torch.arange( 98 |             start=chan, end=chan * (spec_len + 1), step=chan, dtype=x.dtype, device=x.device 99 |         ).view(1, 1, -1) 100 |         cum_mean = cum_sum / cnt 101 |         cum_var = cum_pow_sum / cnt - cum_mean.pow(2)  # E[x**2] - E[x]**2 102 |         return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt()) 103 | 104 | 105 | class FeatsGlobLN(_LayerNorm): 106 |     """Feature-wise global Layer Normalization (FeatsGlobLN). 107 |     Applies normalization over frames for each channel.""" 108 | 109 |     def forward(self, x, EPS: float = 1e-8): 110 |         """Applies forward pass. 111 | 112 |         Works for any input size > 2D.
113 | 114 | Args: 115 | x (:class:`torch.Tensor`): `[batch, chan, time]` 116 | 117 | Returns: 118 | :class:`torch.Tensor`: chanLN_x `[batch, chan, time]` 119 | """ 120 | value = _feat_glob_norm(x, eps=EPS) 121 | return self.apply_gain_and_bias(value) 122 | 123 | 124 | class BatchNorm(_BatchNorm): 125 | """Wrapper class for pytorch BatchNorm1D and BatchNorm2D""" 126 | 127 | def _check_input_dim(self, input): 128 | if input.dim() < 2 or input.dim() > 4: 129 | raise ValueError("expected 4D or 3D input (got {}D input)".format(input.dim())) 130 | 131 | 132 | # Aliases. 133 | gLN = GlobLN 134 | fgLN = FeatsGlobLN 135 | cLN = ChanLN 136 | cgLN = CumLN 137 | bN = BatchNorm 138 | 139 | 140 | def register_norm(custom_norm): 141 | """Register a custom norm, gettable with `norms.get`. 142 | 143 | Args: 144 | custom_norm: Custom norm to register. 145 | 146 | """ 147 | if custom_norm.__name__ in globals().keys() or custom_norm.__name__.lower() in globals().keys(): 148 | raise ValueError(f"Norm {custom_norm.__name__} already exists. Choose another name.") 149 | globals().update({custom_norm.__name__: custom_norm}) 150 | 151 | 152 | def get(identifier): 153 | """Returns a norm class from a string. Returns its input if it 154 | is callable (already a :class:`._LayerNorm` for example). 155 | 156 | Args: 157 | identifier (str or Callable or None): the norm identifier. 158 | 159 | Returns: 160 | :class:`._LayerNorm` or None 161 | """ 162 | if identifier is None: 163 | return None 164 | elif callable(identifier): 165 | return identifier 166 | elif isinstance(identifier, str): 167 | cls = globals().get(identifier) 168 | if cls is None: 169 | raise ValueError("Could not interpret normalization identifier: " + str(identifier)) 170 | return cls 171 | else: 172 | raise ValueError("Could not interpret normalization identifier: " + str(identifier)) 173 | 174 | 175 | def get_complex(identifier): 176 | """Like `.get` but returns a complex norm created with `asteroid.complex_nn.OnReIm`.""" 177 | norm = get(identifier) 178 | if norm is None: 179 | return None 180 | else: 181 | return partial(complex_nn.OnReIm, norm) 182 | -------------------------------------------------------------------------------- /asteroid/masknn/tac.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . import activations, norms 4 | 5 | 6 | class TAC(nn.Module): 7 | """Transform-Average-Concatenate inter-microphone-channel permutation invariant communication block [1]. 8 | 9 | Args: 10 | input_dim (int): Number of features of input representation. 11 | hidden_dim (int, optional): size of hidden layers in TAC operations. 12 | activation (str, optional): type of activation used. See asteroid.masknn.activations. 13 | norm_type (str, optional): type of normalization layer used. See asteroid.masknn.norms. 14 | 15 | .. note:: Supports inputs of shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)` 16 | as in FasNet-TAC. The operations are applied for each element in ``chunk_size`` and ``n_chunks``. 17 | Output is of same shape as input. 18 | 19 | References 20 | [1] : Luo, Yi, et al. "End-to-end microphone permutation and number invariant multi-channel 21 | speech separation." ICASSP 2020. 
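    Example (illustrative, shapes only; a batch of dual-path features from a
    4-mic array)::

        >>> import torch
        >>> tac = TAC(input_dim=64)
        >>> feats = torch.randn(2, 4, 64, 50, 10)  # (batch, mics, features, chunk_size, n_chunks)
        >>> tac(feats).shape
        torch.Size([2, 4, 64, 50, 10])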
22 | """ 23 | 24 | def __init__(self, input_dim, hidden_dim=384, activation="prelu", norm_type="gLN"): 25 | super().__init__() 26 | self.hidden_dim = hidden_dim 27 | self.input_tf = nn.Sequential( 28 | nn.Linear(input_dim, hidden_dim), activations.get(activation)() 29 | ) 30 | self.avg_tf = nn.Sequential( 31 | nn.Linear(hidden_dim, hidden_dim), activations.get(activation)() 32 | ) 33 | self.concat_tf = nn.Sequential( 34 | nn.Linear(2 * hidden_dim, input_dim), activations.get(activation)() 35 | ) 36 | self.norm = norms.get(norm_type)(input_dim) 37 | 38 | def forward(self, x, valid_mics=None): 39 | """ 40 | Args: 41 | x: (:class:`torch.Tensor`): Input multi-channel DPRNN features. 42 | Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`. 43 | valid_mics: (:class:`torch.LongTensor`): tensor containing effective number of microphones on each batch. 44 | Batches can be composed of examples coming from arrays with a different 45 | number of microphones and thus the ``mic_channels`` dimension is padded. 46 | E.g. torch.tensor([4, 3]) means first example has 4 channels and the second 3. 47 | Shape: :math`(batch)`. 48 | 49 | Returns: 50 | output (:class:`torch.Tensor`): features for each mic_channel after TAC inter-channel processing. 51 | Shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`. 52 | """ 53 | # Input is 5D because it is multi-channel DPRNN. DPRNN single channel is 4D. 54 | batch_size, nmics, channels, chunk_size, n_chunks = x.size() 55 | if valid_mics is None: 56 | valid_mics = torch.LongTensor([nmics] * batch_size) 57 | # First operation: transform the input for each frame and independently on each mic channel. 58 | output = self.input_tf( 59 | x.permute(0, 3, 4, 1, 2).reshape(batch_size * nmics * chunk_size * n_chunks, channels) 60 | ).reshape(batch_size, chunk_size, n_chunks, nmics, self.hidden_dim) 61 | 62 | # Mean pooling across channels 63 | if valid_mics.max() == 0: 64 | # Fixed geometry array 65 | mics_mean = output.mean(1) 66 | else: 67 | # Only consider valid channels in each batch element: each example can have different number of microphones. 
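            # E.g. with valid_mics = torch.tensor([4, 3]), the second example is
            # averaged over its first 3 (non-padded) channels only.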
68 |             mics_mean = [ 69 |                 output[b, :, :, : valid_mics[b]].mean(2).unsqueeze(0) for b in range(batch_size) 70 |             ]  # 1, dim1*dim2, H 71 |             mics_mean = torch.cat(mics_mean, 0)  # B*dim1*dim2, H 72 | 73 |         # The average is processed by a non-linear transform 74 |         mics_mean = self.avg_tf( 75 |             mics_mean.reshape(batch_size * chunk_size * n_chunks, self.hidden_dim) 76 |         ) 77 |         mics_mean = ( 78 |             mics_mean.reshape(batch_size, chunk_size, n_chunks, self.hidden_dim) 79 |             .unsqueeze(3) 80 |             .expand_as(output) 81 |         ) 82 | 83 |         # Concatenate the transformed average in each channel with the original feats and 84 |         # project back to same number of features 85 |         output = torch.cat([output, mics_mean], -1) 86 |         output = self.concat_tf( 87 |             output.reshape(batch_size * chunk_size * n_chunks * nmics, -1) 88 |         ).reshape(batch_size, chunk_size, n_chunks, nmics, -1) 89 |         output = self.norm( 90 |             output.permute(0, 3, 4, 1, 2).reshape(batch_size * nmics, -1, chunk_size, n_chunks) 91 |         ).reshape(batch_size, nmics, -1, chunk_size, n_chunks) 92 | 93 |         output += x 94 |         return output 95 | -------------------------------------------------------------------------------- /asteroid/models/README.md: -------------------------------------------------------------------------------- 1 | ### Publishing models 2 | 3 | - First, create an account on [Zenodo](https://zenodo.org/) 4 | (you can log in with GitHub directly) 5 | - Then [create an access token](https://zenodo.org/account/settings/applications/tokens/new/), 6 | as we'll need one to upload anything. 7 | 8 | -------------------------------------------------------------------------------- /asteroid/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Models 2 | from .base_models import BaseModel 3 | from .x_umx import XUMX 4 | from .x_umx_control import XUMXControl 5 | from .x_umx_control_mfcc import XUMXControlMFCC 6 | 7 | # Sharing-related 8 | from .publisher import save_publishable, upload_publishable 9 | 10 | __all__ = [ 11 |     "XUMX", 12 |     "XUMXControl", 13 |     "XUMXControlMFCC", 14 |     "save_publishable", 15 |     "upload_publishable", 16 | ] 17 | 18 | 19 | def register_model(custom_model): 20 |     """Register a custom model, gettable with `models.get`. 21 | 22 |     Args: 23 |         custom_model: Custom model to register. 24 | 25 |     """ 26 |     if ( 27 |         custom_model.__name__ in globals().keys() 28 |         or custom_model.__name__.lower() in globals().keys() 29 |     ): 30 |         raise ValueError(f"Model {custom_model.__name__} already exists. Choose another name.") 31 |     globals().update({custom_model.__name__: custom_model}) 32 | 33 | 34 | def get(identifier): 35 |     """Returns a model class from a string (case-insensitive). 36 | 37 |     Args: 38 |         identifier (str): the model name.
39 | 40 |     Returns: 41 |         :class:`torch.nn.Module` 42 |     """ 43 |     if isinstance(identifier, str): 44 |         to_get = {k.lower(): v for k, v in globals().items()} 45 |         cls = to_get.get(identifier.lower()) 46 |         if cls is None: 47 |             raise ValueError(f"Could not interpret model name: {str(identifier)}") 48 |         return cls 49 |     raise ValueError(f"Could not interpret model name: {str(identifier)}") 50 | -------------------------------------------------------------------------------- /asteroid/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-postech/SSAD/642ee27f5e0c10ec9e3b643b49b12ac94843f998/asteroid/scripts/__init__.py -------------------------------------------------------------------------------- /asteroid/scripts/asteroid_cli.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import yaml 5 | import itertools 6 | import glob 7 | import warnings 8 | from typing import List 9 | 10 | import asteroid 11 | from asteroid.separate import separate 12 | from asteroid.dsp import LambdaOverlapAdd 13 | from asteroid.models.publisher import upload_publishable 14 | from asteroid.models.base_models import BaseModel 15 | 16 | 17 | SUPPORTED_EXTENSIONS = [ 18 |     ".wav", 19 |     ".flac", 20 |     ".ogg", 21 | ] 22 | 23 | 24 | def validate_window_length(n): 25 |     try: 26 |         n = int(n) 27 |     except ValueError: 28 |         raise argparse.ArgumentTypeError("Must be integer") 29 |     if n < 10: 30 |         # Note: This doesn't allow for hop < 10. 31 |         raise argparse.ArgumentTypeError("Must be given in samples, not seconds") 32 |     return n 33 | 34 | 35 | def upload(): 36 |     """CLI function to upload pretrained models.""" 37 |     parser = argparse.ArgumentParser() 38 |     parser.add_argument("publish_dir", type=str, help="Path to the publish dir.") 39 |     parser.add_argument( 40 |         "--uploader", default=None, type=str, help="Name of the uploader. Ex: `Manuel Pariente`" 41 |     ) 42 |     parser.add_argument( 43 |         "--affiliation", default=None, type=str, help="Affiliation of the uploader. Ex: `INRIA`" 44 |     ) 45 |     parser.add_argument("--git_username", default=None, type=str, help="Username in GitHub") 46 |     parser.add_argument( 47 |         "--token", default=None, type=str, help="Access token for Zenodo (or sandbox)" 48 |     ) 49 |     parser.add_argument( 50 |         "--force_publish", 51 |         default=False, 52 |         action="store_true", 53 |         help="Whether to publish without asking for confirmation", 54 |     ) 55 |     parser.add_argument( 56 |         "--use_sandbox", default=False, action="store_true", help="Whether to use Zenodo sandbox."
57 |     ) 58 |     args = parser.parse_args() 59 |     args_as_dict = dict(vars(args)) 60 |     # Load uploader info if present 61 |     info_file = os.path.join(asteroid.project_root, "uploader_info.yml") 62 |     if os.path.exists(info_file): 63 |         uploader_info = yaml.safe_load(open(info_file, "r")) 64 |         # Replace fields that were not specified (CLI dominates) 65 |         for k, v in uploader_info.items(): 66 |             if args_as_dict[k] == parser.get_default(k): 67 |                 args_as_dict[k] = v 68 | 69 |     upload_publishable(**args_as_dict) 70 |     # Suggest creating uploader_info.yml 71 |     if not os.path.exists(info_file): 72 |         example = """ 73 | ```asteroid/uploader_info.yml 74 | uploader: Manuel Pariente 75 | affiliation: Universite Lorraine, CNRS, Inria, LORIA, France 76 | git_username: mpariente 77 | token: XXX 78 | ``` 79 | """ 80 |         print( 81 |             "You can create an `uploader_info.yml` file in the `Asteroid` root " 82 |             f"to stop passing your name, affiliation etc. to the CLI. " 83 |             f"Here is an example {example}" 84 |         ) 85 |     print( 86 |         "Thanks a lot for sharing your model! Don't forget to create " 87 |         "a model card in the repo! " 88 |     ) 89 | 90 | 91 | def infer(argv=None): 92 |     """CLI function to run pretrained model inference on wav files.""" 93 |     parser = argparse.ArgumentParser() 94 |     parser.add_argument("url_or_path", type=str, help="Path to the pretrained model.") 95 |     parser.add_argument( 96 |         "--files", 97 |         default=None, 98 |         required=True, 99 |         type=str, 100 |         help="Path to the wav files to separate. Also supports list of filenames, " 101 |         "directory names and globs.", 102 |         nargs="+", 103 |     ) 104 |     parser.add_argument( 105 |         "-f", 106 |         "--force-overwrite", 107 |         action="store_true", 108 |         help="Whether to overwrite output wav files.", 109 |     ) 110 |     parser.add_argument( 111 |         "-r", 112 |         "--resample", 113 |         action="store_true", 114 |         help="Whether to resample wrong sample rate input files.", 115 |     ) 116 |     parser.add_argument( 117 |         "-w", 118 |         "--ola-window", 119 |         type=validate_window_length, 120 |         default=None, 121 |         help="Overlap-add window to use. If not set (default), overlap-add is not used.", 122 |     ) 123 |     parser.add_argument( 124 |         "--ola-hop", 125 |         type=validate_window_length, 126 |         default=None, 127 |         help="Overlap-add hop length in samples. Defaults to ola-window // 2. Only used if --ola-window is set.", 128 |     ) 129 |     parser.add_argument( 130 |         "--ola-window-type", 131 |         type=str, 132 |         default="hanning", 133 |         help="Type of overlap-add window to use. Only used if --ola-window is set.", 134 |     ) 135 |     parser.add_argument( 136 |         "--ola-no-reorder", 137 |         action="store_true", 138 |         help="Disable automatic reordering of overlap-add chunks. See asteroid.dsp.LambdaOverlapAdd for details. " 139 |         "Only used if --ola-window is set.", 140 |     ) 141 |     parser.add_argument( 142 |         "-o", "--output-dir", default=None, type=str, help="Output directory to save files." 143 |     ) 144 |     parser.add_argument( 145 |         "-d", 146 |         "--device", 147 |         default=None, 148 |         type=str, 149 |         help="Device to run the model on, e.g. 'cuda:0'. "
150 | "Defaults to 'cuda' if CUDA is available, else 'cpu'.", 151 | ) 152 | args = parser.parse_args(argv) 153 | 154 | if args.device is None: 155 | device = "cuda" if torch.cuda.is_available() else "cpu" 156 | else: 157 | device = args.device 158 | 159 | model = BaseModel.from_pretrained(pretrained_model_conf_or_path=args.url_or_path) 160 | if args.ola_window is not None: 161 | model = LambdaOverlapAdd( 162 | model, 163 | n_src=None, 164 | window_size=args.ola_window, 165 | hop_size=args.ola_hop, 166 | window=args.ola_window_type, 167 | reorder_chunks=not args.ola_no_reorder, 168 | ) 169 | model = model.to(device) 170 | 171 | file_list = _process_files_as_list(args.files) 172 | for f in file_list: 173 | separate( 174 | model, 175 | f, 176 | force_overwrite=args.force_overwrite, 177 | output_dir=args.output_dir, 178 | resample=args.resample, 179 | ) 180 | 181 | 182 | def register_sample_rate(): 183 | """CLI to register sample rate to an Asteroid model saved without `sample_rate`, before 0.4.0.""" 184 | 185 | def _register_sample_rate(filename, sample_rate): 186 | import torch 187 | 188 | conf = torch.load(filename, map_location="cpu") 189 | conf["model_args"]["sample_rate"] = sample_rate 190 | torch.save(conf, filename) 191 | 192 | parser = argparse.ArgumentParser() 193 | parser.add_argument("filename", type=str, help="Model file to edit.") 194 | parser.add_argument("sample_rate", type=float, help="Sampling rate to add to the model.") 195 | args = parser.parse_args() 196 | 197 | _register_sample_rate(filename=args.filename, sample_rate=args.sample_rate) 198 | 199 | 200 | def _process_files_as_list(files_str: List[str]) -> List[str]: 201 | """Support filename, folder name, and globs. Returns list of filenames.""" 202 | all_files = [] 203 | for f in files_str: 204 | # Existing file 205 | if os.path.isfile(f): 206 | all_files.append(f) 207 | # Glob folder and append. 
208 | elif os.path.isdir(f): 209 | all_files.extend(glob_dir(f)) 210 | else: 211 | local_list = glob.glob(f) 212 | if not local_list: 213 | warnings.warn(f"Could find any file that matched {f}", UserWarning) 214 | all_files.extend(local_list) 215 | return all_files 216 | 217 | 218 | def glob_dir(d): 219 | """Return all filenames in directory that match the supported extensions.""" 220 | return list( 221 | itertools.chain( 222 | *[ 223 | glob.glob(os.path.join(d, "**/*" + ext), recursive=True) 224 | for ext in SUPPORTED_EXTENSIONS 225 | ] 226 | ) 227 | ) 228 | -------------------------------------------------------------------------------- /asteroid/scripts/asteroid_versions.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pathlib 3 | import subprocess 4 | import torch 5 | import pytorch_lightning as pl 6 | import asteroid 7 | 8 | 9 | def print_versions(): 10 | """CLI function to get info about the Asteroid and dependency versions.""" 11 | for k, v in asteroid_versions().items(): 12 | print(f"{k:20s}{v}") 13 | 14 | 15 | def asteroid_versions(): 16 | return { 17 | "Asteroid": asteroid_version(), 18 | "PyTorch": pytorch_version(), 19 | "PyTorch-Lightning": pytorch_lightning_version(), 20 | } 21 | 22 | 23 | def pytorch_version(): 24 | return torch.__version__ 25 | 26 | 27 | def pytorch_lightning_version(): 28 | return pl.__version__ 29 | 30 | 31 | def asteroid_version(): 32 | asteroid_root = pathlib.Path(__file__).parent.parent.parent 33 | if asteroid_root.joinpath(".git").exists(): 34 | return f"{asteroid.__version__}, Git checkout {get_git_version(asteroid_root)}" 35 | else: 36 | return asteroid.__version__ 37 | 38 | 39 | def get_git_version(root): 40 | def _git(*cmd): 41 | return subprocess.check_output(["git", *cmd], cwd=root).strip().decode("ascii", "ignore") 42 | 43 | try: 44 | commit = _git("rev-parse", "HEAD") 45 | branch = _git("rev-parse", "--symbolic-full-name", "--abbrev-ref", "HEAD") 46 | dirty = _git("status", "--porcelain") 47 | except Exception as err: 48 | print(f"Failed to get Git checkout info: {err}", file=sys.stderr) 49 | return "" 50 | s = commit[:12] 51 | if branch: 52 | s += f" ({branch})" 53 | if dirty: 54 | s += f", dirty tree" 55 | return s 56 | -------------------------------------------------------------------------------- /asteroid/separate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import torch 4 | import numpy as np 5 | import soundfile as sf 6 | from typing import Optional 7 | 8 | try: 9 | from typing import Protocol 10 | except ImportError: # noqa 11 | # Python < 3.8 12 | class Protocol: 13 | pass 14 | 15 | 16 | from .dsp.overlap_add import LambdaOverlapAdd 17 | from .utils import get_device 18 | 19 | 20 | class Separatable(Protocol): 21 | """Things that are separatable.""" 22 | 23 | in_channels: Optional[int] 24 | 25 | def forward_wav(self, wav: torch.Tensor, **kwargs) -> torch.Tensor: 26 | """ 27 | Args: 28 | wav (torch.Tensor): waveform tensor. 29 | Shape: 1D, 2D or 3D tensor, time last. 30 | **kwargs: Keyword arguments from `separate`. 31 | 32 | Returns: 33 | torch.Tensor: the estimated sources. 34 | Shape: [batch, n_src, time] or [n_src, time] if the input `wav` 35 | did not have a batch dim. 36 | """ 37 | ... 38 | 39 | @property 40 | def sample_rate(self) -> float: 41 | """Operating sample rate of the model (float).""" 42 | ... 
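# A minimal sketch of a conforming model (illustrative only, not part of the
# library): any ``torch.nn.Module`` exposing ``in_channels``, ``sample_rate``
# and ``forward_wav`` can be passed to ``separate``::
#
#     class IdentitySeparator(torch.nn.Module):
#         in_channels = 1
#         sample_rate = 8000.0
#
#         def forward_wav(self, wav, **kwargs):
#             # (batch, 1, time) -> (batch, 2, time): two identical "sources".
#             return wav.repeat_interleave(2, dim=-2)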
43 | 44 | 45 | def separate( 46 | model: Separatable, wav, output_dir=None, force_overwrite=False, resample=False, **kwargs 47 | ): 48 | """Infer separated sources from input waveforms. 49 | Also supports filenames. 50 | 51 | Args: 52 | model (Separatable, for example asteroid.models.BaseModel): Model to use. 53 | wav (Union[torch.Tensor, numpy.ndarray, str]): waveform array/tensor. 54 | Shape: 1D, 2D or 3D tensor, time last. 55 | output_dir (str): path to save all the wav files. If None, 56 | estimated sources will be saved next to the original ones. 57 | force_overwrite (bool): whether to overwrite existing files 58 | (when separating from file). 59 | resample (bool): Whether to resample input files with wrong sample rate 60 | (when separating from file). 61 | **kwargs: keyword arguments to be passed to `forward_wav`. 62 | 63 | Returns: 64 | Union[torch.Tensor, numpy.ndarray, None], the estimated sources. 65 | (batch, n_src, time) or (n_src, time) w/o batch dim. 66 | 67 | .. note:: 68 | `separate` calls `model.forward_wav` which calls `forward` by default. 69 | For models whose `forward` doesn't have waveform tensors as input/ouput, 70 | overwrite their `forward_wav` method to separate from waveform to waveform. 71 | """ 72 | if isinstance(wav, str): 73 | file_separate( 74 | model, 75 | wav, 76 | output_dir=output_dir, 77 | force_overwrite=force_overwrite, 78 | resample=resample, 79 | **kwargs, 80 | ) 81 | elif isinstance(wav, np.ndarray): 82 | return numpy_separate(model, wav, **kwargs) 83 | elif isinstance(wav, torch.Tensor): 84 | return torch_separate(model, wav, **kwargs) 85 | else: 86 | raise ValueError( 87 | f"Only support filenames, numpy arrays and torch tensors, received {type(wav)}" 88 | ) 89 | 90 | 91 | @torch.no_grad() 92 | def torch_separate(model: Separatable, wav: torch.Tensor, **kwargs) -> torch.Tensor: 93 | """Core logic of `separate`.""" 94 | if model.in_channels is not None and wav.shape[-2] != model.in_channels: 95 | raise RuntimeError( 96 | f"Model supports {model.in_channels}-channel inputs but found audio with {wav.shape[-2]} channels." 97 | f"Please match the number of channels." 98 | ) 99 | # Handle device placement 100 | input_device = get_device(wav, default="cpu") 101 | model_device = get_device(model, default="cpu") 102 | wav = wav.to(model_device) 103 | # Forward 104 | separate_func = getattr(model, "forward_wav", model) 105 | out_wavs = separate_func(wav, **kwargs) 106 | 107 | # FIXME: for now this is the best we can do. 108 | out_wavs *= wav.abs().sum() / (out_wavs.abs().sum()) 109 | 110 | # Back to input device (and numpy if necessary) 111 | out_wavs = out_wavs.to(input_device) 112 | return out_wavs 113 | 114 | 115 | def numpy_separate(model: Separatable, wav: np.ndarray, **kwargs) -> np.ndarray: 116 | """Numpy interface to `separate`.""" 117 | wav = torch.from_numpy(wav) 118 | out_wavs = torch_separate(model, wav, **kwargs) 119 | out_wavs = out_wavs.data.numpy() 120 | return out_wavs 121 | 122 | 123 | def file_separate( 124 | model: Separatable, 125 | filename: str, 126 | output_dir=None, 127 | force_overwrite=False, 128 | resample=False, 129 | **kwargs, 130 | ) -> None: 131 | """Filename interface to `separate`.""" 132 | 133 | if not hasattr(model, "sample_rate"): 134 | raise TypeError( 135 | f"This function requires your model ({type(model).__name__}) to have a " 136 | "'sample_rate' attribute. See `BaseModel.sample_rate` for details." 137 | ) 138 | 139 | # Estimates will be saved as filename_est1.wav etc... 
140 | base, _ = os.path.splitext(filename) 141 | if output_dir is not None: 142 | base = os.path.join(output_dir, os.path.basename(base)) 143 | save_name_template = base + "_est{}.wav" 144 | 145 | # Bail out early if an estimate file already exists and we shall not overwrite. 146 | est1_filename = save_name_template.format(1) 147 | if os.path.isfile(est1_filename) and not force_overwrite: 148 | warnings.warn( 149 | f"File {est1_filename} already exists, pass `force_overwrite=True` to overwrite it", 150 | UserWarning, 151 | ) 152 | return 153 | 154 | # SoundFile wav shape: [time, n_chan] 155 | wav, fs = _load_audio(filename) 156 | if wav.shape[-1] > 1: 157 | warnings.warn( 158 | f"Received multichannel signal with {wav.shape[-1]} signals, " 159 | f"using the first channel only." 160 | ) 161 | # FIXME: support only single-channel files for now. 162 | if resample: 163 | wav = _resample(wav[:, 0], orig_sr=fs, target_sr=int(model.sample_rate))[:, None] 164 | elif fs != model.sample_rate: 165 | raise RuntimeError( 166 | f"Received a signal with a sampling rate of {fs}Hz for a model " 167 | f"of {model.sample_rate}Hz. You can pass `resample=True` to resample automatically." 168 | ) 169 | # Pass wav as [batch, n_chan, time]; here: [1, chan, time] 170 | wav = wav.T[None] 171 | (est_srcs,) = numpy_separate(model, wav, **kwargs) 172 | # Resample to original sr 173 | est_srcs = [ 174 | _resample(est_src, orig_sr=int(model.sample_rate), target_sr=fs) for est_src in est_srcs 175 | ] 176 | 177 | # Save wav files to filename_est1.wav etc... 178 | for src_idx, est_src in enumerate(est_srcs, 1): 179 | sf.write(save_name_template.format(src_idx), est_src, fs) 180 | 181 | 182 | def _resample(wav, orig_sr, target_sr, _resamplers={}): 183 | from julius import ResampleFrac 184 | 185 | if orig_sr == target_sr: 186 | return wav 187 | 188 | # Cache ResampleFrac instance to speed up resampling if we're repeatedly 189 | # resampling between the same two sample rates. 190 | try: 191 | resampler = _resamplers[(orig_sr, target_sr)] 192 | except KeyError: 193 | resampler = _resamplers[(orig_sr, target_sr)] = ResampleFrac(orig_sr, target_sr) 194 | 195 | return resampler(torch.from_numpy(wav)).numpy() 196 | 197 | 198 | def _load_audio(filename): 199 | try: 200 | return sf.read(filename, dtype="float32", always_2d=True) 201 | except Exception as sf_err: 202 | # If soundfile fails to load the file, try with librosa next, which uses 203 | # the 'audioread' library to support a wide range of audio formats. 204 | # We try with soundfile first because librosa takes a long time to import. 205 | try: 206 | import librosa 207 | except ModuleNotFoundError: 208 | raise RuntimeError( 209 | f"Could not load file {filename!r} with soundfile. " 210 | "Install 'librosa' to be able to load more file types." 
211 | ) from sf_err 212 | 213 | wav, sr = librosa.load(filename, dtype="float32", sr=None) 214 | # Always return wav of shape [time, n_chan] 215 | if wav.ndim == 1: 216 | return wav[:, None], sr 217 | else: 218 | return wav.T, sr 219 | -------------------------------------------------------------------------------- /asteroid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .generic_utils import ( 2 | average_arrays_in_dic, 3 | flatten_dict, 4 | get_wav_random_start_stop, 5 | has_arg, 6 | unet_decoder_args, 7 | ) 8 | from .parser_utils import ( 9 | prepare_parser_from_dict, 10 | parse_args_as_dict, 11 | str_int_float, 12 | str2bool, 13 | str2bool_arg, 14 | isfloat, 15 | isint, 16 | ) 17 | from .torch_utils import tensors_to_device, to_cuda, get_device 18 | 19 | # The functions above were all in asteroid/utils.py before refactoring into 20 | # asteroid/utils/*_utils.py files. They are imported for backward compatibility. 21 | 22 | __all__ = [ 23 | "prepare_parser_from_dict", 24 | "parse_args_as_dict", 25 | "str_int_float", 26 | "str2bool", 27 | "str2bool_arg", 28 | "isfloat", 29 | "isint", 30 | "tensors_to_device", 31 | "to_cuda", 32 | "get_device", 33 | "has_arg", 34 | "flatten_dict", 35 | "average_arrays_in_dic", 36 | "get_wav_random_start_stop", 37 | "unet_decoder_args", 38 | ] 39 | -------------------------------------------------------------------------------- /asteroid/utils/deprecation_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import inspect 3 | from functools import wraps 4 | 5 | 6 | class VisibleDeprecationWarning(UserWarning): 7 | """Visible deprecation warning. 8 | 9 | By default, python will not show deprecation warnings, so this class 10 | can be used when a very visible warning is helpful, for example because 11 | the usage is most likely a user bug. 12 | 13 | """ 14 | 15 | # Taken from numpy 16 | 17 | 18 | def mark_deprecated(message, version=None): 19 | """Decorator to add deprecation message. 20 | 21 | Args: 22 | message: Migration steps to be given to users. 23 | """ 24 | 25 | def decorator(func): 26 | @wraps(func) 27 | def wrapped(*args, **kwargs): 28 | from_what = "a future release" if version is None else f"asteroid v{version}" 29 | warn_message = ( 30 | f"{func.__module__}.{func.__name__} has been deprecated " 31 | f"and will be removed from {from_what}. " 32 | f"{message}" 33 | ) 34 | warnings.warn(warn_message, VisibleDeprecationWarning, stacklevel=2) 35 | return func(*args, **kwargs) 36 | 37 | return wrapped 38 | 39 | return decorator 40 | 41 | 42 | def is_overridden(method_name, obj, parent=None) -> bool: 43 | """Check if `method_name` from parent is overridden in `obj`. 44 | 45 | Args: 46 | method_name (str): Name of the method. 47 | obj: Instance or class that potentially overrode the method. 48 | parent: parent class with which to compare. If None, traverse the MRO 49 | for the first parent that has the method. 50 | 51 | Raises RuntimeError if `parent` is not a parent class and if `parent` 52 | doesn't have the method. Or, if `parent` was None, that none of the 53 | potential parents had the method. 
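    Example (illustrative)::

        >>> class Base:
        ...     def fit(self):
        ...         pass
        >>> class Child(Base):
        ...     def fit(self):
        ...         pass
        >>> is_overridden("fit", Child())
        True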
54 |     """ 55 | 56 |     def get_mro(cls): 57 |         try: 58 |             return inspect.getmro(cls) 59 |         except AttributeError: 60 |             return inspect.getmro(cls.__class__) 61 | 62 |     def first_parent_with_method(fn, mro_list): 63 |         for cls in mro_list[::-1]: 64 |             if hasattr(cls, fn): 65 |                 return cls 66 |         return None 67 | 68 |     if not hasattr(obj, method_name): 69 |         return False 70 | 71 |     try: 72 |         instance_attr = getattr(obj, method_name) 73 |     except AttributeError: 74 |         return False 75 |     if not instance_attr: 76 |         return False 77 |     mro = get_mro(obj)[1:]  # All parent classes in order, self excluded 78 |     parent = parent if parent is not None else first_parent_with_method(method_name, mro) 79 | 80 |     if parent not in mro: 81 |         raise RuntimeError(f"`{obj}` has no parent that defined method `{method_name}`.") 82 | 83 |     if not hasattr(parent, method_name): 84 |         raise RuntimeError(f"Parent `{parent}` does not have method `{method_name}`") 85 | 86 |     super_attr = getattr(parent, method_name) 87 |     return instance_attr.__code__ is not super_attr.__code__ 88 | -------------------------------------------------------------------------------- /asteroid/utils/generic_utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from collections.abc import MutableMapping 3 | import numpy as np 4 | 5 | 6 | def has_arg(fn, name): 7 |     """Checks if a callable accepts a given keyword argument. 8 | 9 |     Args: 10 |         fn (callable): Callable to inspect. 11 |         name (str): Check if ``fn`` can be called with ``name`` as a keyword 12 |             argument. 13 | 14 |     Returns: 15 |         bool: whether ``fn`` accepts a ``name`` keyword argument. 16 |     """ 17 |     signature = inspect.signature(fn) 18 |     parameter = signature.parameters.get(name) 19 |     if parameter is None: 20 |         return False 21 |     return parameter.kind in ( 22 |         inspect.Parameter.POSITIONAL_OR_KEYWORD, 23 |         inspect.Parameter.KEYWORD_ONLY, 24 |     ) 25 | 26 | 27 | def flatten_dict(d, parent_key="", sep="_"): 28 |     """Flattens a dictionary into a single-level dictionary while preserving 29 |     parent keys. Taken from 30 |     `SO `_ 31 | 32 |     Args: 33 |         d (MutableMapping): Dictionary to be flattened. 34 |         parent_key (str): String to use as a prefix to all subsequent keys. 35 |         sep (str): String to use as a separator between two key levels. 36 | 37 |     Returns: 38 |         dict: Single-level dictionary, flattened. 39 |     """ 40 |     items = [] 41 |     for k, v in d.items(): 42 |         new_key = parent_key + sep + k if parent_key else k 43 |         if isinstance(v, MutableMapping): 44 |             items.extend(flatten_dict(v, new_key, sep=sep).items()) 45 |         else: 46 |             items.append((new_key, v)) 47 |     return dict(items) 48 | 49 | 50 | def average_arrays_in_dic(dic): 51 |     """Take average of numpy arrays in a dictionary. 52 | 53 |     Args: 54 |         dic (dict): Input dictionary to take average from 55 | 56 |     Returns: 57 |         dict: New dictionary with array averaged. 58 | 59 |     """ 60 |     # Copy dic first 61 |     dic = dict(dic) 62 |     for k, v in dic.items(): 63 |         if isinstance(v, np.ndarray): 64 |             dic[k] = float(v.mean()) 65 |     return dic 66 | 67 | 68 | def get_wav_random_start_stop(signal_len, desired_len=4 * 8000): 69 |     """Get indexes for a chunk of signal of a given length. 70 | 71 |     Args: 72 |         signal_len (int): length of the signal to trim. 73 |         desired_len (int): the length of [start:stop] 74 | 75 |     Returns: 76 |         tuple: random start integer, stop integer.
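    Example (illustrative)::

        >>> start, stop = get_wav_random_start_stop(50000, desired_len=32000)
        >>> (stop - start <= 32000) and (0 <= start <= stop <= 50000)
        True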
77 | """ 78 | if desired_len is None: 79 | return 0, signal_len 80 | rand_start = np.random.randint(0, max(1, signal_len - desired_len)) 81 | stop = min(signal_len, rand_start + desired_len) 82 | return rand_start, stop 83 | 84 | 85 | def unet_decoder_args(encoders, *, skip_connections): 86 | """Get list of decoder arguments for upsampling (right) side of a symmetric u-net, 87 | given the arguments used to construct the encoder. 88 | 89 | Args: 90 | encoders (tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding)): 91 | List of arguments used to construct the encoders 92 | skip_connections (bool): Whether to include skip connections in the 93 | calculation of decoder input channels. 94 | 95 | Return: 96 | tuple of length `N` of tuples of (in_chan, out_chan, kernel_size, stride, padding): 97 | Arguments to be used to construct decoders 98 | """ 99 | decoder_args = [] 100 | for enc_in_chan, enc_out_chan, enc_kernel_size, enc_stride, enc_padding in reversed(encoders): 101 | if skip_connections and decoder_args: 102 | skip_in_chan = enc_out_chan 103 | else: 104 | skip_in_chan = 0 105 | decoder_args.append( 106 | (enc_out_chan + skip_in_chan, enc_in_chan, enc_kernel_size, enc_stride, enc_padding) 107 | ) 108 | return tuple(decoder_args) 109 | -------------------------------------------------------------------------------- /asteroid/utils/hub_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from hashlib import sha256 4 | from typing import Union, Dict, List 5 | 6 | import requests 7 | from torch import hub 8 | import huggingface_hub 9 | 10 | 11 | CACHE_DIR = os.getenv( 12 | "ASTEROID_CACHE", 13 | os.path.expanduser("~/.cache/torch/asteroid"), 14 | ) 15 | MODELS_URLS_HASHTABLE = { 16 | "mpariente/ConvTasNet_WHAM!_sepclean": "https://zenodo.org/record/3862942/files/model.pth?download=1", 17 | "mpariente/DPRNNTasNet_WHAM!_sepclean": "https://zenodo.org/record/3873670/files/model.pth?download=1", 18 | "mpariente/DPRNNTasNet(ks=16)_WHAM!_sepclean": "https://zenodo.org/record/3903795/files/model.pth?download=1", 19 | "Cosentino/ConvTasNet_LibriMix_sep_clean": "https://zenodo.org/record/3873572/files/model.pth?download=1", 20 | "Cosentino/ConvTasNet_LibriMix_sep_noisy": "https://zenodo.org/record/3874420/files/model.pth?download=1", 21 | "brijmohan/ConvTasNet_Libri1Mix_enhsingle": "https://zenodo.org/record/3970768/files/model.pth?download=1", 22 | "groadabike/ConvTasNet_DAMP-VSEP_enhboth": "https://zenodo.org/record/3994193/files/model.pth?download=1", 23 | "popcornell/DeMask_Surgical_mask_speech_enhancement_v1": "https://zenodo.org/record/3997047/files/model.pth?download=1", 24 | "popcornell/DPRNNTasNet_WHAM_enhancesingle": "https://zenodo.org/record/3998647/files/model.pth?download=1", 25 | "tmirzaev-dotcom/ConvTasNet_Libri3Mix_sepnoisy": "https://zenodo.org/record/4020529/files/model.pth?download=1", 26 | "mhu-coder/ConvTasNet_Libri1Mix_enhsingle": "https://zenodo.org/record/4301955/files/model.pth?download=1", 27 | "r-sawata/XUMX_MUSDB18_music_separation": "https://zenodo.org/record/4704231/files/pretrained_xumx.pth?download=1", 28 | } 29 | 30 | SR_HASHTABLE = {k: 8000.0 if not "DeMask" in k else 16000.0 for k in MODELS_URLS_HASHTABLE} 31 | 32 | 33 | def cached_download(filename_or_url): 34 | """Download from URL and cache the result in ASTEROID_CACHE. 
35 | 36 | Args: 37 | filename_or_url (str): Name of a model as named on the Zenodo Community 38 | page (ex: ``"mpariente/ConvTasNet_WHAM!_sepclean"``), or model id from 39 | the Hugging Face model hub (ex: ``"julien-c/DPRNNTasNet-ks16_WHAM_sepclean"``), 40 | or a URL to a model file (ex: ``"https://zenodo.org/.../model.pth"``), or a filename 41 | that exists locally (ex: ``"local/tmp_model.pth"``) 42 | 43 | Returns: 44 | str, normalized path to the downloaded (or not) model 45 | """ 46 | from .. import __version__ as asteroid_version # Avoid circular imports 47 | 48 | if os.path.isfile(filename_or_url): 49 | return filename_or_url 50 | 51 | if filename_or_url.startswith(huggingface_hub.HUGGINGFACE_CO_URL_HOME): 52 | filename_or_url = filename_or_url[len(huggingface_hub.HUGGINGFACE_CO_URL_HOME) :] 53 | 54 | if filename_or_url.startswith(("http://", "https://")): 55 | url = filename_or_url 56 | elif filename_or_url in MODELS_URLS_HASHTABLE: 57 | url = MODELS_URLS_HASHTABLE[filename_or_url] 58 | else: 59 | # Finally, let's try to find it on Hugging Face model hub 60 | # e.g. julien-c/DPRNNTasNet-ks16_WHAM_sepclean is a valid model id 61 | # and julien-c/DPRNNTasNet-ks16_WHAM_sepclean@main supports specifying a commit/branch/tag. 62 | if "@" in filename_or_url: 63 | model_id = filename_or_url.split("@")[0] 64 | revision = filename_or_url.split("@")[1] 65 | else: 66 | model_id = filename_or_url 67 | revision = None 68 | url = huggingface_hub.hf_hub_url( 69 | model_id, filename=huggingface_hub.PYTORCH_WEIGHTS_NAME, revision=revision 70 | ) 71 | return huggingface_hub.cached_download( 72 | url, 73 | cache_dir=get_cache_dir(), 74 | library_name="asteroid", 75 | library_version=asteroid_version, 76 | ) 77 | cached_filename = url_to_filename(url) 78 | cached_dir = os.path.join(get_cache_dir(), cached_filename) 79 | cached_path = os.path.join(cached_dir, "model.pth") 80 | 81 | os.makedirs(cached_dir, exist_ok=True) 82 | if not os.path.isfile(cached_path): 83 | hub.download_url_to_file(url, cached_path) 84 | return cached_path 85 | # It was already downloaded 86 | print(f"Using cached model `{filename_or_url}`") 87 | return cached_path 88 | 89 | 90 | def url_to_filename(url): 91 | """Consistently convert ``url`` into a filename.""" 92 | _bytes = url.encode("utf-8") 93 | _hash = sha256(_bytes) 94 | filename = _hash.hexdigest() 95 | return filename 96 | 97 | 98 | def get_cache_dir(): 99 | os.makedirs(CACHE_DIR, exist_ok=True) 100 | return CACHE_DIR 101 | 102 | 103 | @lru_cache() 104 | def model_list( 105 | endpoint=huggingface_hub.HUGGINGFACE_CO_URL_HOME, name_only=False 106 | ) -> Union[str, List[Dict]]: 107 | """Get the public list of all the models on huggingface with an 'asteroid' tag.""" 108 | path = "{}api/models?full=true&filter=asteroid".format(endpoint) 109 | r = requests.get(path) 110 | r.raise_for_status() 111 | all_models = r.json() 112 | if name_only: 113 | return [x["modelId"] for x in all_models] 114 | return all_models 115 | -------------------------------------------------------------------------------- /asteroid/utils/parser_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def prepare_parser_from_dict(dic, parser=None): 5 | """Prepare an argparser from a dictionary. 6 | 7 | Args: 8 | dic (dict): Two-level config dictionary with unique bottom-level keys. 9 | parser (argparse.ArgumentParser, optional): If a parser already 10 | exists, add the keys from the dictionary on the top of it. 
11 | 
12 |     Returns:
13 |         argparse.ArgumentParser:
14 |             Parser instance with groups corresponding to the first level keys
15 |             and arguments corresponding to the second level keys with default
16 |             values given by the values.
17 |     """
18 | 
19 |     def standardized_entry_type(value):
20 |         """If the default value is None, replace NoneType by str_int_float.
21 |         If the default value is boolean, look for boolean strings."""
22 |         if value is None:
23 |             return str_int_float
24 |         if isinstance(str2bool(value), bool):
25 |             return str2bool_arg
26 |         return type(value)
27 | 
28 |     if parser is None:
29 |         parser = argparse.ArgumentParser()
30 |     for k in dic.keys():
31 |         group = parser.add_argument_group(k)
32 |         for kk in dic[k].keys():
33 |             entry_type = standardized_entry_type(dic[k][kk])
34 |             group.add_argument("--" + kk, default=dic[k][kk], type=entry_type)
35 |     return parser
36 | 
37 | 
38 | def str_int_float(value):
39 |     """Type to convert strings to int, float (in this order) if possible.
40 | 
41 |     Args:
42 |         value (str): Value to convert.
43 | 
44 |     Returns:
45 |         int, float, str: Converted value.
46 |     """
47 |     if isint(value):
48 |         return int(value)
49 |     if isfloat(value):
50 |         return float(value)
51 |     elif isinstance(value, str):
52 |         return value
53 | 
54 | 
55 | def str2bool(value):
56 |     """ Type to convert strings to Boolean (returns input if not boolean) """
57 |     if not isinstance(value, str):
58 |         return value
59 |     if value.lower() in ("yes", "true", "y", "1"):
60 |         return True
61 |     elif value.lower() in ("no", "false", "n", "0"):
62 |         return False
63 |     else:
64 |         return value
65 | 
66 | 
67 | def str2bool_arg(value):
68 |     """ Argparse type to convert strings to Boolean """
69 |     value = str2bool(value)
70 |     if isinstance(value, bool):
71 |         return value
72 |     raise argparse.ArgumentTypeError("Boolean value expected.")
73 | 
74 | 
75 | def isfloat(value):
76 |     """Computes whether `value` can be cast to a float.
77 | 
78 |     Args:
79 |         value (str): Value to check.
80 | 
81 |     Returns:
82 |         bool: Whether `value` can be cast to a float.
83 | 
84 |     """
85 |     try:
86 |         float(value)
87 |         return True
88 |     except ValueError:
89 |         return False
90 | 
91 | 
92 | def isint(value):
93 |     """Computes whether `value` can be cast to an int.
94 | 
95 |     Args:
96 |         value (str): Value to check.
97 | 
98 |     Returns:
99 |         bool: Whether `value` can be cast to an int.
100 | 
101 |     """
102 |     try:
103 |         int(value)
104 |         return True
105 |     except ValueError:
106 |         return False
107 | 
108 | 
109 | def parse_args_as_dict(parser, return_plain_args=False, args=None):
110 |     """Get a dict of dicts out of the `parser.parse_args()` output.
111 | 
112 |     Top-level keys corresponding to groups and bottom-level keys corresponding
113 |     to arguments. Under `'main_args'`, the arguments which don't belong to an
114 |     argparse group (i.e. main arguments defined before parsing from a dict) can
115 |     be found.
116 | 
117 |     Args:
118 |         parser (argparse.ArgumentParser): ArgumentParser instance containing
119 |             groups. Output of `prepare_parser_from_dict`.
120 |         return_plain_args (bool): Whether to also return the plain namespace
121 |             output of `parser.parse_args()` alongside the dictionary.
122 |         args (list): List of arguments as read from the command line.
123 |             Used for unit testing.
124 | 
125 |     Returns:
126 |         dict:
127 |             Dictionary of dictionaries containing the arguments. Optionally also
128 |             the direct output of `parser.parse_args()`.
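| 
|     Example (illustrative round trip with `prepare_parser_from_dict`)::
| 
|         >>> parser = prepare_parser_from_dict({"optim": {"lr": 0.001}})
|         >>> parse_args_as_dict(parser, args=["--lr", "0.01"])["optim"]["lr"]
|         0.01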
129 | """ 130 | args = parser.parse_args(args=args) 131 | args_dic = {} 132 | for group in parser._action_groups: 133 | group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions} 134 | args_dic[group.title] = group_dict 135 | args_dic["main_args"] = args_dic["optional arguments"] 136 | del args_dic["optional arguments"] 137 | if return_plain_args: 138 | return args_dic, args 139 | return args_dic 140 | -------------------------------------------------------------------------------- /asteroid/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | 4 | 5 | class DummyDataset(data.Dataset): 6 | def __init__(self): 7 | self.inp_dim = 10 8 | self.out_dim = 10 9 | 10 | def __len__(self): 11 | return 20 12 | 13 | def __getitem__(self, idx): 14 | return torch.randn(1, self.inp_dim), torch.randn(1, self.out_dim) 15 | 16 | 17 | class DummyWaveformDataset(data.Dataset): 18 | def __init__(self, total=12, n_src=3, len_wave=16000): 19 | self.inp_len_wave = len_wave 20 | self.out_len_wave = len_wave 21 | self.total = total 22 | self.inp_n_sig = 1 23 | self.out_n_sig = n_src 24 | 25 | def __len__(self): 26 | return self.total 27 | 28 | def __getitem__(self, idx): 29 | mixed = torch.randn(self.inp_n_sig, self.inp_len_wave) 30 | srcs = torch.randn(self.out_n_sig, self.out_len_wave) 31 | return mixed, srcs 32 | 33 | 34 | def torch_version_tuple(): 35 | version, *suffix = torch.__version__.split("+") 36 | return tuple(map(int, version.split("."))) + tuple(suffix) 37 | -------------------------------------------------------------------------------- /asteroid/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | from torch import nn 5 | from collections import OrderedDict 6 | 7 | 8 | def to_cuda(tensors): # pragma: no cover 9 | """Transfer tensor, dict or list of tensors to GPU. 10 | 11 | Args: 12 | tensors (:class:`torch.Tensor`, list or dict): May be a single, a 13 | list or a dictionary of tensors. 14 | 15 | Returns: 16 | :class:`torch.Tensor`: 17 | Same as input but transferred to cuda. Goes through lists and dicts 18 | and transfers the torch.Tensor to cuda. Leaves the rest untouched. 19 | """ 20 | if isinstance(tensors, torch.Tensor): 21 | return tensors.cuda() 22 | if isinstance(tensors, list): 23 | return [to_cuda(tens) for tens in tensors] 24 | if isinstance(tensors, dict): 25 | for key in tensors.keys(): 26 | tensors[key] = to_cuda(tensors[key]) 27 | return tensors 28 | raise TypeError( 29 | "tensors must be a tensor or a list or dict of tensors. " 30 | " Got tensors of type {}".format(type(tensors)) 31 | ) 32 | 33 | 34 | def tensors_to_device(tensors, device): 35 | """Transfer tensor, dict or list of tensors to device. 36 | 37 | Args: 38 | tensors (:class:`torch.Tensor`): May be a single, a list or a 39 | dictionary of tensors. 40 | device (:class: `torch.device`): the device where to place the tensors. 41 | 42 | Returns: 43 | Union [:class:`torch.Tensor`, list, tuple, dict]: 44 | Same as input but transferred to device. 45 | Goes through lists and dicts and transfers the torch.Tensor to 46 | device. Leaves the rest untouched. 
47 | """ 48 | if isinstance(tensors, torch.Tensor): 49 | return tensors.to(device) 50 | elif isinstance(tensors, (list, tuple)): 51 | return [tensors_to_device(tens, device) for tens in tensors] 52 | elif isinstance(tensors, dict): 53 | for key in tensors.keys(): 54 | tensors[key] = tensors_to_device(tensors[key], device) 55 | return tensors 56 | else: 57 | return tensors 58 | 59 | 60 | def get_device(tensor_or_module, default=None): 61 | """Get the device of a tensor or a module. 62 | 63 | Args: 64 | tensor_or_module (Union[torch.Tensor, torch.nn.Module]): 65 | The object to get the device from. Can be a ``torch.Tensor``, 66 | a ``torch.nn.Module``, or anything else that has a ``device`` attribute 67 | or a ``parameters() -> Iterator[torch.Tensor]`` method. 68 | default (Optional[Union[str, torch.device]]): If the device can not be 69 | determined, return this device instead. If ``None`` (the default), 70 | raise a ``TypeError`` instead. 71 | 72 | Returns: 73 | torch.device: The device that ``tensor_or_module`` is on. 74 | """ 75 | if hasattr(tensor_or_module, "device"): 76 | return tensor_or_module.device 77 | elif hasattr(tensor_or_module, "parameters"): 78 | return next(tensor_or_module.parameters()).device 79 | elif default is None: 80 | raise TypeError(f"Don't know how to get device of {type(tensor_or_module)} object") 81 | else: 82 | return torch.device(default) 83 | 84 | 85 | def is_tracing(): 86 | # Taken for pytorch for compat in 1.6.0 87 | """ 88 | Returns ``True`` in tracing (if a function is called during the tracing of 89 | code with ``torch.jit.trace``) and ``False`` otherwise. 90 | """ 91 | return torch._C._is_tracing() 92 | 93 | 94 | def script_if_tracing(fn): 95 | # Taken for pytorch for compat in 1.6.0 96 | """ 97 | Compiles ``fn`` when it is first called during tracing. ``torch.jit.script`` 98 | has a non-negligible start up time when it is first called due to 99 | lazy-initializations of many compiler builtins. Therefore you should not use 100 | it in library code. However, you may want to have parts of your library work 101 | in tracing even if they use control flow. In these cases, you should use 102 | ``@torch.jit.script_if_tracing`` to substitute for 103 | ``torch.jit.script``. 104 | 105 | Arguments: 106 | fn: A function to compile. 107 | 108 | Returns: 109 | If called during tracing, a :class:`ScriptFunction` created by ` 110 | `torch.jit.script`` is returned. Otherwise, the original function ``fn`` is returned. 111 | """ 112 | 113 | @functools.wraps(fn) 114 | def wrapper(*args, **kwargs): 115 | if not is_tracing(): 116 | # Not tracing, don't do anything 117 | return fn(*args, **kwargs) 118 | 119 | compiled_fn = torch.jit.script(wrapper.__original_fn) # type: ignore 120 | return compiled_fn(*args, **kwargs) 121 | 122 | wrapper.__original_fn = fn # type: ignore 123 | wrapper.__script_if_tracing_wrapper = True # type: ignore 124 | 125 | return wrapper 126 | 127 | 128 | @script_if_tracing 129 | def pad_x_to_y(x: torch.Tensor, y: torch.Tensor, axis: int = -1) -> torch.Tensor: 130 | """Right-pad or right-trim first argument to have same size as second argument 131 | 132 | Args: 133 | x (torch.Tensor): Tensor to be padded. 134 | y (torch.Tensor): Tensor to pad `x` to. 135 | axis (int): Axis to pad on. 136 | 137 | Returns: 138 | torch.Tensor, `x` padded to match `y`'s shape. 
139 | """ 140 | if axis != -1: 141 | raise NotImplementedError 142 | inp_len = y.shape[axis] 143 | output_len = x.shape[axis] 144 | return nn.functional.pad(x, [0, inp_len - output_len]) 145 | 146 | 147 | def load_state_dict_in(state_dict, model): 148 | """Strictly loads state_dict in model, or the next submodel. 149 | Useful to load standalone model after training it with System. 150 | 151 | Args: 152 | state_dict (OrderedDict): the state_dict to load. 153 | model (torch.nn.Module): the model to load it into 154 | 155 | Returns: 156 | torch.nn.Module: model with loaded weights. 157 | 158 | .. note:: Keys in a state_dict look like ``object1.object2.layer_name.weight.etc`` 159 | We first try to load the model in the classic way. 160 | If this fail we removes the first left part of the key to obtain 161 | ``object2.layer_name.weight.etc``. 162 | Blindly loading with ``strictly=False`` should be done with some logging 163 | of the missing keys in the state_dict and the model. 164 | 165 | """ 166 | try: 167 | # This can fail if the model was included into a bigger nn.Module 168 | # object. For example, into System. 169 | model.load_state_dict(state_dict, strict=True) 170 | except RuntimeError: 171 | # keys look like object1.object2.layer_name.weight.etc 172 | # The following will remove the first left part of the key to obtain 173 | # object2.layer_name.weight.etc. 174 | # Blindly loading with strictly=False should be done with some 175 | # new_state_dict of the missing keys in the state_dict and the model. 176 | new_state_dict = OrderedDict() 177 | for k, v in state_dict.items(): 178 | new_k = k[k.find(".") + 1 :] 179 | new_state_dict[new_k] = v 180 | model.load_state_dict(new_state_dict, strict=True) 181 | return model 182 | 183 | 184 | def are_models_equal(model1, model2): 185 | """Check for weights equality between models. 186 | 187 | Args: 188 | model1 (nn.Module): model instance to be compared. 189 | model2 (nn.Module): second model instance to be compared. 190 | 191 | Returns: 192 | bool: Whether all model weights are equal. 193 | """ 194 | for p1, p2 in zip(model1.parameters(), model2.parameters()): 195 | if p1.data.ne(p2.data).sum() > 0: 196 | return False 197 | return True 198 | 199 | 200 | @script_if_tracing 201 | def jitable_shape(tensor): 202 | """Gets shape of ``tensor`` as ``torch.Tensor`` type for jit compiler 203 | 204 | .. note:: 205 | Returning ``tensor.shape`` of ``tensor.size()`` directly is not torchscript 206 | compatible as return type would not be supported. 207 | 208 | Args: 209 | tensor (torch.Tensor): Tensor 210 | 211 | Returns: 212 | torch.Tensor: Shape of ``tensor`` 213 | """ 214 | return torch.tensor(tensor.shape) 215 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: asteroid 2 | channels: 3 | - anaconda 4 | - conda-forge 5 | dependencies: 6 | - python=3.8 7 | - Cython 8 | - pip: 9 | - -r file:requirements.txt 10 | - -e . 
11 | -------------------------------------------------------------------------------- /informed-X-UMX/local/conf_base.yml: -------------------------------------------------------------------------------- 1 | # Training config 2 | training: 3 | epochs: 1000 4 | batch_size: 32 5 | loss_combine_sources: yes 6 | loss_use_multidomain: yes 7 | loss_pit: no 8 | mix_coef: 10.0 9 | val_dur: 6.0 10 | # Optim config 11 | optim: 12 | optimizer: adam 13 | lr: 0.001 14 | patience: 1000 15 | lr_decay_patience: 40 16 | lr_decay_gamma: 0.5 17 | weight_decay: 0.00001 18 | # Data config 19 | data: 20 | train_dir: /dev/shm/mimii 21 | split: 6dB 22 | output: /hdd/hdd1/sss/xumx/0908_base_1 23 | sample_rate: 16000 24 | num_workers: 4 25 | seed: 42 26 | seq_dur: 6.0 27 | samples_per_track: 2 28 | source_augmentations: 29 | - gain 30 | - delay 31 | machine_type: valve 32 | sources: 33 | - id_00 34 | - id_02 35 | - id_04 36 | - id_06 37 | use_control: False 38 | control_type: rms 39 | task_random: True 40 | source_random: False 41 | num_src_in_mix: 2 42 | # Network config 43 | model: 44 | pretrained: null 45 | bidirectional: yes 46 | window_length: 4096 47 | in_chan: 4096 48 | nhop: 1024 49 | hidden_size: 512 50 | bandwidth: 16000 51 | nb_channels: 2 52 | spec_power: 1 53 | -------------------------------------------------------------------------------- /informed-X-UMX/local/conf_informed.yml: -------------------------------------------------------------------------------- 1 | # Training config 2 | training: 3 | epochs: 1000 4 | batch_size: 32 5 | loss_combine_sources: yes 6 | loss_use_multidomain: yes 7 | loss_pit: no 8 | mix_coef: 10.0 9 | val_dur: 6.0 10 | # Optim config 11 | optim: 12 | optimizer: adam 13 | lr: 0.001 14 | patience: 1000 15 | lr_decay_patience: 40 16 | lr_decay_gamma: 0.5 17 | weight_decay: 0.00001 18 | # Data config 19 | data: 20 | train_dir: /dev/shm/mimii 21 | split: 6dB 22 | output: /hdd/hdd1/sss/xumx/0908_informed_1 23 | sample_rate: 16000 24 | num_workers: 4 25 | seed: 42 26 | seq_dur: 6.0 27 | samples_per_track: 2 28 | source_augmentations: 29 | - gain 30 | - delay 31 | machine_type: valve 32 | sources: 33 | - id_00 34 | - id_02 35 | - id_04 36 | - id_06 37 | use_control: True 38 | control_type: rms 39 | task_random: False 40 | source_random: True 41 | num_src_in_mix: 3 42 | # Network config 43 | model: 44 | pretrained: null 45 | bidirectional: yes 46 | window_length: 4096 47 | in_chan: 4096 48 | nhop: 1024 49 | hidden_size: 512 50 | bandwidth: 16000 51 | nb_channels: 2 52 | spec_power: 1 53 | -------------------------------------------------------------------------------- /informed-X-UMX/local/conf_pit.yml: -------------------------------------------------------------------------------- 1 | # Training config 2 | training: 3 | epochs: 1000 4 | batch_size: 32 5 | loss_combine_sources: yes 6 | loss_use_multidomain: yes 7 | loss_pit: yes 8 | mix_coef: 10.0 9 | val_dur: 6.0 10 | # Optim config 11 | optim: 12 | optimizer: adam 13 | lr: 0.001 14 | patience: 1000 15 | lr_decay_patience: 40 16 | lr_decay_gamma: 0.5 17 | weight_decay: 0.00001 18 | # Data config 19 | data: 20 | train_dir: /dev/shm/mimii 21 | split: 6dB 22 | output: /hdd/hdd1/sss/xumx/0908_pit_1 23 | sample_rate: 16000 24 | num_workers: 4 25 | seed: 42 26 | seq_dur: 6.0 27 | samples_per_track: 2 28 | source_augmentations: 29 | - gain 30 | - delay 31 | machine_type: valve 32 | sources: 33 | - id_00 34 | - id_02 35 | - id_04 36 | - id_06 37 | use_control: False 38 | control_type: rms 39 | task_random: True 40 | source_random: False 41 
| num_src_in_mix: 2
42 | # Network config
43 | model:
44 |   pretrained: null
45 |   bidirectional: yes
46 |   window_length: 4096
47 |   in_chan: 4096
48 |   nhop: 1024
49 |   hidden_size: 512
50 |   bandwidth: 16000
51 |   nb_channels: 2
52 |   spec_power: 1
53 | 
--------------------------------------------------------------------------------
/informed-X-UMX/local/dataloader.py:
--------------------------------------------------------------------------------
1 | from asteroid.data import MIMIISliderDataset, MIMIIValveDataset
2 | import torch
3 | from pathlib import Path
4 | import numpy as np
5 | import random
6 | 
7 | train_tracks = [f"{n:0>3}" for n in range(10, 350)]
8 | 
9 | 
10 | 
11 | 
12 | def load_datasets(parser, args):
13 |     """Loads the specified dataset from commandline arguments
14 | 
15 |     Returns:
16 |         train_dataset, validation_dataset
17 |     """
18 | 
19 |     args = parser.parse_args()
20 | 
21 |     dataset_kwargs = {
22 |         "root": Path(args.train_dir),
23 |     }
24 | 
25 |     source_augmentations = Compose(
26 |         [globals()["_augment_" + aug] for aug in args.source_augmentations]
27 |     )
28 | 
29 |     if args.machine_type == 'valve':
30 |         Dataset = MIMIIValveDataset
31 |         validation_tracks = ["00000000", "00000001", "00000002", "00000003"]
32 |     elif args.machine_type == 'slider':
33 |         Dataset = MIMIISliderDataset
34 |         validation_tracks = ["00000000", "00000001", "00000002", "00000003"]
35 |     else:
36 |         raise ValueError(f"Unexpected machine type: {args.machine_type}")
37 | 
38 | 
39 |     train_dataset = Dataset(
40 |         split=args.split,
41 |         sources=args.sources,
42 |         targets=args.sources,
43 |         source_augmentations=source_augmentations,
44 |         random_track_mix=True,
45 |         segment=args.seq_dur,
46 |         random_segments=True,
47 |         sample_rate=args.sample_rate,
48 |         samples_per_track=args.samples_per_track,
49 |         use_control=args.use_control,
50 |         task_random=args.task_random,
51 |         source_random=args.source_random,
52 |         num_src_in_mix=args.num_src_in_mix,
53 |         **dataset_kwargs,
54 |     )
55 | 
56 |     train_dataset = filtering_out_valid(train_dataset, validation_tracks)
57 | 
58 |     valid_dataset = Dataset(
59 |         split=args.split,
60 |         subset=validation_tracks,
61 |         sources=args.sources,
62 |         targets=args.sources,
63 |         source_augmentations=source_augmentations,
64 |         segment=args.val_dur,
65 |         sample_rate=args.sample_rate,
66 |         use_control=args.use_control,
67 |         task_random=args.task_random,
68 |         source_random=args.source_random,
69 |         num_src_in_mix=args.num_src_in_mix,
70 |         **dataset_kwargs,
71 |     )
72 | 
73 |     return train_dataset, valid_dataset
74 | 
75 | 
76 | def filtering_out_valid(input_dataset, validation_tracks):
77 |     """Filtering out validation tracks from input dataset.
78 | 
79 |     Return:
80 |         input_dataset (w/o validation tracks)
81 |     """
82 |     input_dataset.tracks = [
83 |         tmp
84 |         for tmp in input_dataset.tracks
85 |         if not (str(tmp["path"]).split("/")[-1] in validation_tracks)
86 |     ]
87 |     return input_dataset
88 | 
89 | 
90 | class Compose(object):
91 |     """Composes several augmentation transforms.
92 |     Args:
93 |         augmentations: list of augmentations to compose.
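| 
|     Example (illustrative; uses the augmentation functions defined below
|     in this module)::
| 
|         >>> transforms = Compose([_augment_gain, _augment_delay])
|         >>> augmented = transforms(torch.zeros(2, 16000))
|         >>> augmented.shape
|         torch.Size([2, 16000])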
94 | """ 95 | 96 | def __init__(self, transforms): 97 | self.transforms = transforms 98 | 99 | def __call__(self, audio): 100 | for transform in self.transforms: 101 | audio = transform(audio) 102 | return audio 103 | 104 | 105 | def _augment_delay(audio, max=16000): 106 | """Applies a random gain to each source between `low` and `high`""" 107 | delay = random.randint(0, max) 108 | audio_len = audio.shape[1] 109 | delayed = torch.cat([torch.zeros_like(audio)[:, :delay], audio[:, :audio_len - delay]], dim=1) 110 | return delayed 111 | 112 | 113 | def _augment_gain(audio, low=0.25, high=1.25): 114 | """Applies a random gain to each source between `low` and `high`""" 115 | gain = low + torch.rand(1) * (high - low) 116 | return audio * gain 117 | 118 | 119 | def _augment_channelswap(audio): 120 | """Randomly swap channels of stereo sources""" 121 | if audio.shape[0] == 2 and torch.FloatTensor(1).uniform_() < 0.5: 122 | return torch.flip(audio, [0]) 123 | 124 | return audio 125 | -------------------------------------------------------------------------------- /informed-X-UMX/requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn>=0.22 2 | museval>=0.4.0 3 | norbert>=0.2.1 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements for using Asteroid. Using this file is equivalent to using 2 | # requirements/install.txt. Note that we cannot make this file a symlink to 3 | # requirements/install.txt because of how pip resolves relative paths with -r. 4 | 5 | -r requirements/install.txt 6 | -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | # Requirements for development on Asteroid and running tests 2 | -r ./install.txt 3 | pre-commit 4 | black==20.8b1 5 | pytest 6 | coverage 7 | codecov 8 | 9 | librosa 10 | museval 11 | wandb -------------------------------------------------------------------------------- /requirements/install.txt: -------------------------------------------------------------------------------- 1 | # Requirements for using Asteroid 2 | -r ./torchhub.txt 3 | PyYAML>=5.0 4 | pandas>=0.23.4 5 | pytorch-lightning>=1.0.1,<1.5.0 6 | torchaudio>=0.8.0 7 | pb_bss_eval>=0.0.2 8 | torch_stoi>=0.0.1 9 | torch_optimizer>=0.0.1a12,<0.2.0 10 | julius 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import os 3 | import re 4 | from setuptools import setup, find_packages 5 | 6 | 7 | with open("README.md", encoding="utf-8") as fh: 8 | long_description = fh.read() 9 | 10 | 11 | here = os.path.abspath(os.path.dirname(__file__)) 12 | 13 | 14 | def read(*parts): 15 | with codecs.open(os.path.join(here, *parts), "r") as fp: 16 | return fp.read() 17 | 18 | 19 | def find_version(*file_paths): 20 | version_file = read(*file_paths) 21 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 22 | if version_match: 23 | return version_match.group(1) 24 | raise RuntimeError("Unable to find version string.") 25 | 26 | 27 | setup( 28 | name="asteroid", 29 | version=find_version("asteroid", "__init__.py"), 30 | author="Manuel Pariente", 31 | author_email="manuel.pariente@loria.fr", 32 | 
url="https://github.com/asteroid-team/asteroid", 33 | description="PyTorch-based audio source separation toolkit", 34 | long_description=long_description, 35 | long_description_content_type="text/markdown", 36 | license="MIT", 37 | python_requires=">=3.6", 38 | install_requires=[ 39 | # From requirements/torchhub.txt 40 | "numpy>=1.16.4", 41 | "scipy>=1.1.0", 42 | "torch>=1.8.0", 43 | "asteroid-filterbanks>=0.4.0", 44 | "SoundFile>=0.10.2", 45 | "huggingface_hub>=0.0.2", 46 | # From requirements/install.txt 47 | "PyYAML>=5.0", 48 | "pandas>=0.23.4", 49 | "pytorch-lightning>=1.0.1,<1.5.0", 50 | "torchaudio>=0.5.0", 51 | "pb_bss_eval>=0.0.2", 52 | "torch_stoi>=0.1.2", 53 | "torch_optimizer>=0.0.1a12,<0.2.0", 54 | "julius", 55 | ], 56 | entry_points={ 57 | "console_scripts": [ 58 | "asteroid-upload=asteroid.scripts.asteroid_cli:upload", 59 | "asteroid-infer=asteroid.scripts.asteroid_cli:infer", 60 | "asteroid-register-sr=asteroid.scripts.asteroid_cli:register_sample_rate", 61 | "asteroid-versions=asteroid.scripts.asteroid_versions:print_versions", 62 | ] 63 | }, 64 | packages=find_packages(), 65 | include_package_data=True, 66 | classifiers=[ 67 | "Development Status :: 4 - Beta", 68 | "Programming Language :: Python :: 3", 69 | "Programming Language :: Python :: 3.6", 70 | "Programming Language :: Python :: 3.7", 71 | "Programming Language :: Python :: 3.8", 72 | "License :: OSI Approved :: MIT License", 73 | "Operating System :: OS Independent", 74 | ], 75 | ) --------------------------------------------------------------------------------