├── .gitignore ├── README.md ├── environment.yml ├── notebooks ├── AudioSamples-LibriMix-MixIT.ipynb ├── AudioSamples-LibriMix-PIT-DM.ipynb ├── AudioSamples-LibriMix.ipynb ├── AudioSamples-REAL-M.ipynb └── utils.py └── src ├── experiment.py ├── irm.py ├── lib ├── __init__.py ├── data │ ├── __init__.py │ ├── collate_utils.py │ ├── dataloader_utils.py │ ├── realm.py │ └── sc09.py ├── losses.py ├── models.py ├── trainers.py ├── transforms.py └── utils.py ├── sc09mix.py ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | .ipynb_checkpoints 5 | .idea/ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MixCycle: Unsupervised Speech Separation via Cyclic Mixture Permutation Invariant Training 2 | This repository contains the audio samples and the source code that accompany [the paper](https://arxiv.org/abs/2202.03875). 3 | 4 | ## Audio samples 5 | We provide audio samples to demonstrate the results of the MixCycle method on two different datasets: [LibriMix](https://nbviewer.org/github/ertug/MixCycle/blob/main/notebooks/AudioSamples-LibriMix.ipynb) and [REAL-M](https://nbviewer.org/github/ertug/MixCycle/blob/main/notebooks/AudioSamples-REAL-M.ipynb). 6 | 7 | Also note that the provided [REAL-M](https://nbviewer.org/github/ertug/MixCycle/blob/main/notebooks/AudioSamples-REAL-M.ipynb) samples were used in the informal listening test. 8 | 9 | We also provide audio samples from the baseline methods on LibriMix: [PIT-DM](https://nbviewer.org/github/ertug/MixCycle/blob/main/notebooks/AudioSamples-LibriMix-PIT-DM.ipynb) and [MixIT](https://nbviewer.org/github/ertug/MixCycle/blob/main/notebooks/AudioSamples-LibriMix-MixIT.ipynb). 10 | 11 | ## Source code 12 | We provide the source code under the `src` directory for reproducibility. 
13 | 14 | ## Running the experiments 15 | 16 | ### Prepare the datasets 17 | - LibriMix: [GitHub](https://github.com/JorisCos/LibriMix) 18 | - REAL-M: [Download](https://sourceseparationresearch.com/static/REAL-M-v0.1.0.tar.gz) 19 | 20 | ### Create the environment 21 | 22 | Install [Anaconda](https://www.anaconda.com/products/individual) and run the following command: 23 | ``` 24 | $ conda env create -f environment.yml 25 | ``` 26 | See [more info](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) on how to manage conda environments. 27 | 28 | ### Activate the environment 29 | ``` 30 | $ conda activate mixcycle 31 | ``` 32 | 33 | ### Run the experiments 34 | ``` 35 | $ cd src 36 | $ python experiment.py --librimix-root ~/datasets/librimix --exp-root ~/experiments --run librimix_irm 37 | $ python experiment.py --librimix-root ~/datasets/librimix --exp-root ~/experiments --run librimix_5p 38 | $ python experiment.py --librimix-root ~/datasets/librimix --exp-root ~/experiments --run librimix_100p 39 | $ python experiment.py --librimix-root ~/datasets/librimix --realm-root ~/datasets/REAL-M-v0.1.0 --exp-root ~/experiments --run realm 40 | ``` 41 | 42 | Optionally, you can monitor the training process with TensorBoard by running: 43 | ``` 44 | $ tensorboard --logdir experiments 45 | ``` 46 | 47 | ## Citation (BibTeX) 48 | If you find this repository useful, please cite our work: 49 | 50 | ```BibTeX 51 | @article{karamatli2022unsupervised, 52 | title={MixCycle: Unsupervised Speech Separation via Cyclic Mixture Permutation Invariant Training}, 53 | author={Karamatl{\i}, Ertu{\u{g}} and K{\i}rb{\i}z, Serap}, 54 | journal={IEEE Signal Processing Letters}, 55 | volume={29}, 56 | number={}, 57 | pages={2637-2641}, 58 | year={2022}, 59 | doi={10.1109/LSP.2022.3232276} 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /environment.yml: 
-------------------------------------------------------------------------------- 1 | name: mixcycle 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.9.7 6 | - pytorch::pytorch=1.12.1 7 | - pytorch::torchaudio=0.12.1 8 | - cudatoolkit=10.2 9 | - tqdm=4.62.3 10 | - conda-forge::tensorboard=2.7.0 11 | - conda-forge::jupyterlab=3.2.5 12 | - conda-forge::ipywidgets=7.6.5 13 | -------------------------------------------------------------------------------- /notebooks/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from functools import partial 4 | from types import SimpleNamespace 5 | 6 | import torch 7 | from torchaudio.datasets.librimix import LibriMix 8 | import ipywidgets as widgets 9 | import IPython.display as ipd 10 | 11 | from train import BEST_CHECKPOINT_FILENAME, Training 12 | from lib.data.realm import RealM 13 | from lib.models import Model 14 | from lib.losses import negative_sisnri, invariant_loss 15 | 16 | 17 | SAMPLE_RATE = 8000 18 | 19 | 20 | def collate_samples(dataset_name, dataset_root, count): 21 | if dataset_name == 'librimix': 22 | dataset = LibriMix( 23 | root=dataset_root, 24 | subset='test', 25 | num_speakers=2, 26 | sample_rate=SAMPLE_RATE, 27 | task='sep_clean', 28 | ) 29 | elif dataset_name == 'realm': 30 | dataset = RealM( 31 | root=dataset_root, 32 | partition='testing', 33 | ) 34 | else: 35 | raise Exception(f'unknown dataset name: {dataset_name}') 36 | 37 | selected_idxs = torch.randperm(len(dataset))[:count] 38 | data_samples = [] 39 | for idx in selected_idxs: 40 | _, x_true_wave, s_true_wave = dataset[idx] 41 | data_sample = SimpleNamespace( 42 | x_true_wave=x_true_wave, 43 | s_true_wave=torch.cat(s_true_wave) 44 | ) 45 | data_samples.append(data_sample) 46 | return data_samples 47 | 48 | 49 | def evaluate_samples(train_results_root, dataset_name, model_name, data_samples, **kwargs): 50 | training = Training( 51 | 
train_results_root=train_results_root, 52 | librimix_root=(dataset_name == 'librimix'), 53 | realm_root=(dataset_name == 'realm'), 54 | model_name=model_name, 55 | **kwargs 56 | ) 57 | 58 | with torch.inference_mode(): 59 | model = Model.load( 60 | path=os.path.join(training.results_dir, BEST_CHECKPOINT_FILENAME), 61 | device='cpu' 62 | ).eval() 63 | 64 | separation_samples = [] 65 | for data_sample in data_samples: 66 | s_pred_wave = model(data_sample.x_true_wave.unsqueeze(0)) 67 | mixing_matrices = model.generate_mixing_matrices( 68 | num_targets=model.config.num_sources, 69 | max_sources=model.num_sources, 70 | num_mix=1, 71 | allow_empty=True 72 | ) 73 | negative_sisnri_value, best_perm_idx = invariant_loss( 74 | true=data_sample.s_true_wave.unsqueeze(0), 75 | pred=s_pred_wave, 76 | mixing_matrices=mixing_matrices, 77 | loss_func=partial(negative_sisnri, x_true_wave=data_sample.x_true_wave.unsqueeze(0)), 78 | return_best_perm_idx=True 79 | ) 80 | s_pred_wave_permuted = mixing_matrices[best_perm_idx].matmul(s_pred_wave) 81 | 82 | separation_samples.append(SimpleNamespace( 83 | x_true_wave=data_sample.x_true_wave, 84 | s_true_wave=data_sample.s_true_wave, 85 | s_pred_wave=s_pred_wave_permuted.squeeze(0), 86 | sisnri=-negative_sisnri_value.item() 87 | )) 88 | return separation_samples 89 | 90 | 91 | def _audio_widget(data): 92 | out = widgets.Output() 93 | with out: 94 | ipd.display(ipd.Audio(data, rate=SAMPLE_RATE)) 95 | return out 96 | 97 | 98 | def _mixture_widget(html, sample): 99 | return widgets.VBox([ 100 | widgets.HTML(html), 101 | _audio_widget(sample.x_true_wave) 102 | ]) 103 | 104 | 105 | def show_samples_librimix(separation_samples): 106 | def _source_widget(text, s_wave): 107 | return widgets.VBox([ 108 | widgets.HTML('

{} 1

'.format(text)), 109 | _audio_widget(s_wave[0]), 110 | widgets.HTML('

{} 2

'.format(text)), 111 | _audio_widget(s_wave[1]), 112 | ]) 113 | 114 | def sisnri(separation_sample): 115 | html = '
Mean SI-SNRi:
' \ 116 | '
{:.1f}
' \ 117 | .format(separation_sample.sisnri) 118 | return widgets.HTML(html) 119 | 120 | for idx, separation_sample in enumerate(separation_samples): 121 | ipd.display(widgets.HBox([ 122 | _mixture_widget(html='

Mixture #{}

'.format(idx+1), sample=separation_sample), 123 | _source_widget(text='Reference Source', s_wave=separation_sample.s_true_wave), 124 | _source_widget(text='Estimated Source', s_wave=separation_sample.s_pred_wave), 125 | sisnri(separation_sample), 126 | ], layout=widgets.Layout(align_items='center', padding='3px', margin='5px', border='2px solid gray'))) 127 | 128 | 129 | def show_samples_realm(separation_samples_librimix, separation_samples_realm): 130 | def _source_widget(model_id, sample): 131 | return widgets.VBox([ 132 | widgets.HTML('

Model {} / Est. Source 1

'.format(model_id)), 133 | _audio_widget(sample.s_pred_wave[0]), 134 | widgets.HTML('

Model {} / Est. Source 2

'.format(model_id)), 135 | _audio_widget(sample.s_pred_wave[1]), 136 | ]) 137 | 138 | for idx, (sample_librimix, sample_realm) in enumerate(zip(separation_samples_librimix, separation_samples_realm)): 139 | ipd.display(widgets.HBox([ 140 | _mixture_widget(html='

Mixture #{}

'.format(idx+1), sample=sample_librimix), 141 | _source_widget(model_id='A', sample=sample_librimix), 142 | _source_widget(model_id='B', sample=sample_realm) 143 | ], layout=widgets.Layout(align_items='center', padding='3px', margin='5px', border='2px solid gray'))) 144 | 145 | 146 | def load_file(path): 147 | try: 148 | data = torch.load(path) 149 | return data 150 | except FileNotFoundError as e: 151 | print(e, file=sys.stderr) 152 | return None 153 | -------------------------------------------------------------------------------- /src/experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from types import SimpleNamespace 4 | 5 | import torch 6 | 7 | from lib.utils import CONFIG_FILENAME, configure_console_logger, configure_file_logger, get_logger 8 | from train import Training 9 | from test import Test 10 | from irm import IdealRatioMask 11 | 12 | 13 | class Experiment: 14 | def __init__(self, experiments_root, experiment_name, librimix_root=None, realm_root=None): 15 | self.config = SimpleNamespace() 16 | self.config.experiments_root = experiments_root 17 | self.config.datasets_root = os.path.join(experiments_root, 'datasets') 18 | self.config.checkpoints_root = os.path.join(experiments_root, experiment_name) 19 | self.config.librimix_root = librimix_root 20 | self.config.realm_root = realm_root 21 | 22 | if os.path.exists(self.config.checkpoints_root): 23 | raise Exception(f'checkpoints root "{self.config.checkpoints_root}" exists!') 24 | os.makedirs(self.config.checkpoints_root, exist_ok=True) 25 | 26 | configure_file_logger(self.config.checkpoints_root) 27 | self.logger = get_logger('exp', self.config.checkpoints_root) 28 | self.logger.info('config: %s', self.config) 29 | torch.save(self.config, os.path.join(self.config.checkpoints_root, CONFIG_FILENAME)) 30 | 31 | def librimix_irm(self): 32 | IdealRatioMask( 33 | irm_results_root=self.config.checkpoints_root, 34 | 
librimix_root=self.config.librimix_root 35 | ).start() 36 | 37 | def model_comparison_librimix(self, 38 | model_names=None, 39 | mixcycle_init_epochs=50, 40 | train_subsample_ratio=1., 41 | eval_epochs=1, 42 | run_id=None): 43 | all_model_names = ['mixcycle', 'pit', 'pit-dm', 'mixit', 'mixpit'] 44 | for model_name in (all_model_names if model_names is None else model_names): 45 | self.logger.info('model_name=%s', model_name) 46 | 47 | training = Training( 48 | train_results_root=self.config.checkpoints_root, 49 | librimix_root=self.config.librimix_root, 50 | model_name=model_name, 51 | mixcycle_init_epochs=mixcycle_init_epochs, 52 | train_subsample_ratio=train_subsample_ratio, 53 | eval_epochs=eval_epochs, 54 | run_id=run_id 55 | ) 56 | training.start() 57 | 58 | Test( 59 | train_results_dir=training.results_dir, 60 | librimix_root=self.config.librimix_root 61 | ).start() 62 | 63 | def librimix_5p(self): 64 | self.model_comparison_librimix( 65 | #model_names=['mixcycle'], 66 | mixcycle_init_epochs=250, 67 | train_subsample_ratio=0.05, 68 | eval_epochs=10 69 | ) 70 | 71 | def librimix_100p(self): 72 | self.model_comparison_librimix( 73 | mixcycle_init_epochs=50, 74 | train_subsample_ratio=1., 75 | eval_epochs=1 76 | ) 77 | 78 | def realm(self): 79 | initialized_model = Training( 80 | train_results_root=os.path.join(self.config.experiments_root, 'librimix_100p'), 81 | librimix_root=True, 82 | model_name='mixcycle', 83 | mixcycle_init_epochs=50, 84 | train_subsample_ratio=1.0, 85 | eval_epochs=1 86 | ) 87 | 88 | training = Training( 89 | train_results_root=self.config.checkpoints_root, 90 | realm_root=self.config.realm_root, 91 | model_name='mixcycle', 92 | model_load_path=initialized_model.results_dir, 93 | mixcycle_init_epochs=0, 94 | eval_method='blind', 95 | eval_blind_num_repeat=20, 96 | eval_epochs=40 97 | ) 98 | training.start() 99 | 100 | Test( 101 | train_results_dir=training.results_dir, 102 | realm_root=self.config.realm_root, 103 | 
eval_blind_num_repeat=100 104 | ).start() 105 | 106 | 107 | if __name__ == '__main__': 108 | configure_console_logger() 109 | 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument('--exp-root', type=str, required=True) 112 | parser.add_argument('--run', type=str, required=True) 113 | parser.add_argument('--librimix-root', type=str) 114 | parser.add_argument('--realm-root', type=str) 115 | 116 | args = parser.parse_args() 117 | experiment = Experiment( 118 | experiments_root=args.exp_root, 119 | experiment_name=args.run, 120 | librimix_root=args.librimix_root, 121 | realm_root=args.realm_root 122 | ) 123 | 124 | try: 125 | func = getattr(experiment, args.run) 126 | except AttributeError: 127 | raise Exception('unknown experiment') 128 | 129 | func() 130 | -------------------------------------------------------------------------------- /src/irm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from types import SimpleNamespace 4 | 5 | import torch 6 | 7 | from lib.utils import BEST_METRICS_FILENAME, configure_console_logger, get_logger, default, build_run_name, \ 8 | ensure_clean_results_dir, soft_mask, MetricAccumulator, metrics_to_str 9 | from lib.data.dataloader_utils import get_dataset_specs, create_dataloader 10 | from lib.transforms import Transform 11 | from lib.losses import sisnri 12 | 13 | 14 | class IdealRatioMask: 15 | def __init__(self, irm_results_root, librimix_root=None, realm_root=None, 16 | stft_frame_size=None, 17 | stft_hop_size=None, 18 | device_name=None): 19 | args = locals() 20 | del args['self'] 21 | 22 | self.dataset_name, self.dataset_root = get_dataset_specs(librimix_root, realm_root) 23 | run_name = build_run_name( 24 | args=args, 25 | prepend_items={'irm': self.dataset_name}, 26 | exclude_keys=['irm_results_root', 'sc09mix_root', 'librimix_root'] 27 | ) 28 | self.results_dir = os.path.join(irm_results_root, run_name) 29 | 30 | self.config = 
SimpleNamespace() 31 | self.config.results_dir = self.results_dir 32 | self.config.librimix_root = librimix_root 33 | self.config.realm_root = realm_root 34 | self.config.stft_frame_size = default(stft_frame_size, 512) 35 | self.config.stft_hop_size = default(stft_hop_size, 128) 36 | self.config.device_name = default(device_name, 'cuda') 37 | 38 | self.logger = None 39 | self.config.device = torch.device(self.config.device_name) 40 | 41 | def start(self): 42 | ensure_clean_results_dir(self.config.results_dir) 43 | self.logger = get_logger('irm') 44 | self.logger.info('config: %s', self.config) 45 | 46 | dataloader = create_dataloader( 47 | dataset_name=self.dataset_name, 48 | dataset_root=self.dataset_root, 49 | partition='testing', 50 | batch_size=1, 51 | ) 52 | self.logger.info('using %d samples for evaluation', len(dataloader.dataset)) 53 | 54 | transform = Transform( 55 | stft_frame_size=self.config.stft_frame_size, 56 | stft_hop_size=self.config.stft_hop_size, 57 | device=self.config.device, 58 | ) 59 | 60 | with torch.inference_mode(): 61 | sisnri_accumulator = MetricAccumulator() 62 | 63 | for x_true_wave, s_true_wave in dataloader: 64 | x_true_wave = x_true_wave.to(self.config.device) 65 | s_true_wave = s_true_wave.to(self.config.device) 66 | 67 | x_true_mag, x_true_phase = transform.stft(x_true_wave) 68 | s_true_mag, _ = transform.stft(s_true_wave) 69 | s_pred_mag = soft_mask(s_true_mag, x_true_mag) 70 | s_pred_wave = transform.istft( 71 | mag=s_pred_mag, 72 | phase=x_true_phase, 73 | length=x_true_wave.size(-1) 74 | ) 75 | 76 | batch_sisnri = sisnri( 77 | true_wave=s_true_wave, 78 | pred_wave=s_pred_wave, 79 | x_true_wave=x_true_wave, 80 | ) 81 | 82 | sisnri_accumulator.store(batch_sisnri) 83 | 84 | std, mean = sisnri_accumulator.std_mean() 85 | metrics = { 86 | 'sisnri': mean.item(), 87 | 'sisnri_std': std.item() 88 | } 89 | 90 | metrics_str = metrics_to_str(metrics) 91 | self.logger.info('[IRM] %s', metrics_str) 92 | 93 | torch.save(metrics, 
os.path.join(self.config.results_dir, BEST_METRICS_FILENAME)) 94 | 95 | self.logger.info('completed') 96 | 97 | 98 | if __name__ == '__main__': 99 | configure_console_logger() 100 | 101 | arg_parser = argparse.ArgumentParser() 102 | arg_parser.add_argument('--irm-results-root', type=str, required=True) 103 | arg_parser.add_argument('--librimix-root', type=str) 104 | arg_parser.add_argument('--realm-root', type=str) 105 | arg_parser.add_argument('--stft-frame-size', type=int) 106 | arg_parser.add_argument('--stft-hop-size', type=int) 107 | arg_parser.add_argument('--device-name', type=str) 108 | 109 | cmd_args = arg_parser.parse_args() 110 | IdealRatioMask(**vars(cmd_args)).start() 111 | -------------------------------------------------------------------------------- /src/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ertug/MixCycle/88d9c2d302c4bc62752781b9bbba06e0472d67f1/src/lib/__init__.py -------------------------------------------------------------------------------- /src/lib/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ertug/MixCycle/88d9c2d302c4bc62752781b9bbba06e0472d67f1/src/lib/data/__init__.py -------------------------------------------------------------------------------- /src/lib/data/collate_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | from collections import namedtuple 3 | 4 | import torch 5 | 6 | ######################################################################################################### 7 | # based on https://github.com/pytorch/audio/blob/main/examples/source_separation/utils/dataset/utils.py # 8 | ######################################################################################################### 9 | 10 | SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]] 11 | Batch = 
namedtuple("Batch", ["mix", "src"]) 12 | 13 | 14 | def _fix_num_frames(sample: SampleType, target_num_frames: int, sample_rate: int, random_start=False): 15 | """Ensure waveform has exact number of frames by slicing or padding""" 16 | mix = sample[1] # [1, time] 17 | src = torch.cat(sample[2], 0) # [num_sources, time] 18 | 19 | num_channels, num_frames = src.shape 20 | num_seconds = torch.div(num_frames, sample_rate, rounding_mode="floor") 21 | target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode="floor") 22 | if num_frames >= target_num_frames: 23 | if random_start and num_frames > target_num_frames: 24 | start_frame = torch.randint(num_seconds - target_seconds + 1, [1]) * sample_rate 25 | mix = mix[:, start_frame:] 26 | src = src[:, start_frame:] 27 | mix = mix[:, :target_num_frames] 28 | src = src[:, :target_num_frames] 29 | mask = torch.ones_like(mix) 30 | else: 31 | num_padding = target_num_frames - num_frames 32 | pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device) 33 | mix = torch.cat([mix, pad], 1) 34 | src = torch.cat([src, pad.expand(num_channels, -1)], 1) 35 | mask = torch.ones_like(mix) 36 | mask[..., num_frames:] = 0 37 | return mix, src, mask 38 | 39 | 40 | def collate_fn_wsj0mix_train(samples, sample_rate, duration): 41 | target_num_frames = int(duration * sample_rate) 42 | 43 | mixes, srcs = [], [] 44 | for sample in samples: 45 | mix, src, _ = _fix_num_frames(sample, target_num_frames, sample_rate, random_start=True) 46 | 47 | mixes.append(mix) 48 | srcs.append(src) 49 | 50 | return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0)) 51 | 52 | 53 | def collate_fn_wsj0mix_test(samples): 54 | assert len(samples) == 1 55 | 56 | mixes, srcs = [], [] 57 | for sample in samples: 58 | mix = sample[1] # [1, time] 59 | src = torch.cat(sample[2], 0) # [num_sources, time] 60 | 61 | mixes.append(mix) 62 | srcs.append(src) 63 | 64 | return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0)) 65 | 
# ---------------------------------------------------------------------------
# src/lib/data/dataloader_utils.py
# ---------------------------------------------------------------------------
from functools import partial

import torch
from torch.utils.data import DataLoader, Subset
from torchaudio.datasets.librimix import LibriMix

from lib.data.collate_utils import collate_fn_wsj0mix_train, collate_fn_wsj0mix_test
from lib.data.realm import RealM


SAMPLE_RATE = 8000


def get_dataset_specs(librimix_root=None, realm_root=None):
    """Resolve which dataset to use from the mutually exclusive root arguments.

    Returns:
        (dataset_name, dataset_root) with dataset_name in {'librimix', 'realm'}.

    Raises:
        Exception: if both roots or neither root is given.
    """
    if librimix_root and realm_root:
        raise Exception('only one dataset root should be given')
    elif librimix_root:
        dataset_name = 'librimix'
        dataset_root = librimix_root
    elif realm_root:
        dataset_name = 'realm'
        dataset_root = realm_root
    else:
        raise Exception('at least one dataset root should be given')

    return dataset_name, dataset_root


def create_dataloader(dataset_name, dataset_root, partition, batch_size, subsample_ratio=1.0, shuffle=None):
    """Build a DataLoader over LibriMix or REAL-M.

    Args:
        dataset_name: 'librimix' or 'realm'.
        dataset_root: filesystem root of the chosen dataset.
        partition: one of 'training', 'validation', 'testing'.
        batch_size: DataLoader batch size (testing expects 1 — see collate fn).
        subsample_ratio: keep only this random fraction of the dataset.
        shuffle: force shuffling; the training partition is always shuffled.
    """
    assert partition in ('training', 'validation', 'testing')

    if partition in ('training', 'validation'):
        # fixed-length 3 s crops so samples can be batched
        collate_fn = partial(collate_fn_wsj0mix_train, sample_rate=SAMPLE_RATE, duration=3)
    else:
        # full-length utterances; the test collate fn asserts batch_size == 1
        collate_fn = collate_fn_wsj0mix_test

    if dataset_name == 'librimix':
        subset_mapping = {
            'training': 'train-360',
            'validation': 'dev',
            'testing': 'test',
        }
        dataset = LibriMix(
            root=dataset_root,
            subset=subset_mapping[partition],
            num_speakers=2,
            sample_rate=SAMPLE_RATE,
            task='sep_clean',
        )
    elif dataset_name == 'realm':
        dataset = RealM(
            root=dataset_root,
            partition=partition,
        )
    else:
        raise Exception(f'unknown dataset name: {dataset_name}')

    if subsample_ratio < 1.0:
        # `// (1/ratio)` rather than `* ratio` keeps the historical rounding
        subsampled_dataset_size = int(len(dataset) // (1 / subsample_ratio))
        # keep default int64 indices: a .int() cast risks truncation on very
        # large datasets and int64 is the conventional index dtype for Subset
        indices = torch.randperm(len(dataset))[:subsampled_dataset_size]
        dataset = Subset(dataset, indices)

    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=(partition == 'training' or shuffle),
        collate_fn=collate_fn,
        num_workers=4,
        drop_last=(partition == 'training'),
        pin_memory=True,
    )


def double_mixture_generator(dataloader):
    """Yield pairs of mixture batches, consuming the loader two batches at a time.

    Stops when the loader is exhausted; a final odd batch is dropped.
    """
    iterator = iter(dataloader)
    while True:
        try:
            x_true_wave_1, _ = next(iterator)
            x_true_wave_2, _ = next(iterator)
            yield x_true_wave_1, x_true_wave_2
        except StopIteration:
            return


# ---------------------------------------------------------------------------
# src/lib/data/realm.py
# ---------------------------------------------------------------------------
import hashlib
from pathlib import Path

from torch.utils.data import Dataset
import torchaudio


class RealM(Dataset):
    """REAL-M dataset: real 8 kHz mixtures plus pre-computed source estimates."""

    VALIDATION_PERCENTAGE = 20
    MAX_NUM_FILES = 2 ** 27 - 1  # ~134M
    SOURCE_SUFFIXES = ['source1hat.wav', 'source2hat.wav']

    def __init__(self, root, partition):
        self.mix_root = Path(root) / 'audio_files_converted_8000Hz'
        self.src_root = Path(root) / 'separations'
        self.mix_paths = []

        # REAL-M has no dedicated test split; reuse the validation partition
        if partition == 'testing':
            partition = 'validation'

        for file_path in self.mix_root.glob('*/*.wav'):
            if self._assign_partition(file_path.name) == partition:
                self.mix_paths.append(file_path)

        # sort for a deterministic sample order across runs
        self.mix_paths.sort()

    def __len__(self):
        return len(self.mix_paths)

    def __getitem__(self, idx):
        # mirrors the (sample_rate, mixture, sources) tuple layout of
        # wsj0mix-style datasets; the first element is unused here
        mix_path = self.mix_paths[idx]
        mix = self._load_audio(mix_path)
        srcs = [self._load_audio(self.src_root / (mix_path.name + s)) for s in self.SOURCE_SUFFIXES]
        return None, mix, srcs

    def _load_audio(self, file_path):
        wave, sample_rate = torchaudio.load(file_path)
        assert sample_rate == 8000
        return wave

    def _assign_partition(self, filename):
        """Deterministically hash a filename into a partition (stable across runs)."""
        filename_hashed = hashlib.sha1(filename.encode('ascii')).hexdigest()
        percentage_hash = ((int(filename_hashed, 16) %
                            (self.MAX_NUM_FILES + 1)) *
                           (100.0 / self.MAX_NUM_FILES))
        if filename.startswith('early'):
            # 'early*' recordings are excluded from every partition
            return 'discarded'
        elif percentage_hash < self.VALIDATION_PERCENTAGE:
            return 'validation'
        else:
            return 'training'


# ---------------------------------------------------------------------------
# src/lib/data/sc09.py
# ---------------------------------------------------------------------------
import os
import re
import hashlib

import torch
from torch.utils.data import Dataset
import librosa

from lib.utils import CONFIG_FILENAME


SAMPLE_FILENAME_FORMAT = '{:07d}.pth'


class SC09(Dataset):
    """Spoken-digit ('zero'..'nine') subset of the Speech Commands dataset."""

    PARTITIONS = ('training', 'validation', 'testing')
    CLASSES = ('zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine')
    ORIGINAL_SAMPLE_RATE = 16000
    VALIDATION_PERCENTAGE = 10
    TESTING_PERCENTAGE = 10
    TRAINING_PERCENTAGE = 100 - VALIDATION_PERCENTAGE - TESTING_PERCENTAGE
    MAX_NUM_WAVS_PER_CLASS = 2 ** 27 - 1  # ~134M

    def __init__(self, root, partition, classes=None, sample_rate=ORIGINAL_SAMPLE_RATE):
        if partition not in self.PARTITIONS:
            raise Exception(f'invalid partition: {partition}')

        self.sample_rate = sample_rate
        self.metadata = []

        for cls in (self.CLASSES if classes is None else classes):
            base = os.path.join(root, cls)
            filenames = sorted(os.listdir(base))
            for filename in filenames:
                if self._assign_partition(filename) == partition:
                    path = os.path.join(base, filename)
                    self.metadata.append((path, cls))

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """Return a one-second waveform (resampled, centered) and its label."""
        path, label = self.metadata[idx]
        wave = torch.zeros(self.sample_rate)
        x, _ = librosa.core.load(path, sr=self.sample_rate)

        # center-crop clips longer than one second; without this guard the
        # centering arithmetic below produced a negative start index and
        # corrupted the slice assignment for over-long clips
        if x.shape[0] > self.sample_rate:
            crop_start = (x.shape[0] - self.sample_rate) // 2
            x = x[crop_start:crop_start + self.sample_rate]

        # center if shorter than one second
        centered_start = (self.sample_rate - x.shape[0]) // 2
        centered_end = centered_start + x.shape[0]
        wave[centered_start:centered_end] = torch.from_numpy(x)

        return wave, label

    def _assign_partition(self, filename):
        """copied from the dataset README"""
        base_name = os.path.basename(filename)
        # ignore the '_nohash_' take suffix so all takes of one speaker/word
        # land in the same partition
        hash_name = re.sub(r'_nohash_.*$', '', base_name)
        hash_name_hashed = hashlib.sha1(hash_name.encode('ascii')).hexdigest()
        percentage_hash = ((int(hash_name_hashed, 16) %
                            (self.MAX_NUM_WAVS_PER_CLASS + 1)) *
                           (100.0 / self.MAX_NUM_WAVS_PER_CLASS))
        if percentage_hash < self.VALIDATION_PERCENTAGE:
            result = 'validation'
        elif percentage_hash < (self.TESTING_PERCENTAGE + self.VALIDATION_PERCENTAGE):
            result = 'testing'
        else:
            result = 'training'
        return result


class SC09Mix(Dataset):
    """Pre-generated mixtures of SC09 digits stored as per-sample .pth files."""

    PARTITIONS = SC09.PARTITIONS

    def __init__(self, root, partition):
        assert partition in SC09Mix.PARTITIONS

        self.mix_dir = os.path.join(root, 'mix', partition)
        self.count = len(os.listdir(self.mix_dir))
        self.config = torch.load(os.path.join(root, CONFIG_FILENAME))

    def _load_sample(self, idx):
        return torch.load(os.path.join(self.mix_dir, SAMPLE_FILENAME_FORMAT.format(idx)))

    def _get_single_item(self, idx):
        sample = self._load_sample(idx)

        # mixture gets a leading channel dim to match the [1, time] convention
        return sample['mixture'].unsqueeze(0), sample['sources']

    def __len__(self):
        return self.count

    def __getitem__(self, idx):
        return self._get_single_item(idx)


# ---------------------------------------------------------------------------
# src/lib/losses.py
# ---------------------------------------------------------------------------
from math import prod

import torch


def snr(true_wave, pred_wave, snr_max=None):
    """Signal-to-noise ratio in dB over the last axis.

    When snr_max is given, the denominator is soft-thresholded so the SNR
    saturates near snr_max dB instead of diverging for near-perfect estimates.
    """
    true_wave_square_sum = true_wave.square().sum(-1)
    if snr_max is None:
        soft_threshold = 0
    else:
        threshold = 10 ** (-snr_max / 10)
        soft_threshold = threshold * true_wave_square_sum

    return 10 * torch.log10(true_wave_square_sum / ((true_wave - pred_wave).square().sum(-1) + soft_threshold))


def sisnr(true_wave, pred_wave, eps=0.):
    """Scale-invariant SNR in dB over the last axis."""
    # project the reference onto the prediction's gain so the metric is
    # invariant to rescaling of the estimate
    true_wave = ((true_wave * pred_wave).sum(-1) / true_wave.square().sum(-1)).unsqueeze(-1) * true_wave
    return 10 * torch.log10(true_wave.square().sum(-1) / ((true_wave - pred_wave).square().sum(-1) + eps))


def sisnri(true_wave, pred_wave, x_true_wave, eps=0.):
    """SI-SNR improvement: SI-SNR of the estimate minus SI-SNR of the mixture."""
    return sisnr(true_wave=true_wave, pred_wave=pred_wave, eps=eps) - \
        sisnr(true_wave=true_wave, pred_wave=x_true_wave, eps=eps)


def negative_snr(true_wave, pred_wave, snr_max=None):
    return -snr(true_wave, pred_wave, snr_max)


def negative_sisnr(true_wave, pred_wave):
    return -sisnr(true_wave, pred_wave)


def negative_sisnri(true_wave, pred_wave, x_true_wave, eps=0.):
    return -sisnri(true_wave, pred_wave, x_true_wave, eps=eps)


def invariant_loss(true, pred, mixing_matrices, loss_func, return_best_perm_idx=False):
    """Mixing-invariant loss (generalized PIT).

    Evaluates loss_func(true, M @ pred) for every candidate mixing matrix M
    and keeps the per-item minimum over all candidates.

    Args:
        true: reference tensor [batch, num_targets, ...].
        pred: estimate tensor [batch, num_sources, ...].
        mixing_matrices: [num_perms, num_targets, num_sources] candidates.
        loss_func: per-source loss; averaged over the target axis.
        return_best_perm_idx: also return the winning matrix index per item.
    """
    # flatten trailing dims so the mixing matrices can matmul the estimates
    pred_flat = pred.view(*pred.size()[:2], prod(pred.size()[2:]))

    batch_size = true.size(0)
    perm_size = mixing_matrices.size(0)
    loss_perms = torch.empty([batch_size, perm_size], device=true.device)

    for perm_idx in range(perm_size):
        pred_flat_mix = mixing_matrices[perm_idx].matmul(pred_flat)
        pred_mix = pred_flat_mix.view(*pred_flat_mix.size()[:2], *pred.size()[2:])
        loss_perms[:, perm_idx] = loss_func(true, pred_mix).mean(dim=1)
    _, best_perm_idx = loss_perms.min(dim=1)

    batch_loss = loss_perms[torch.arange(batch_size), best_perm_idx]

    if return_best_perm_idx:
        return batch_loss, best_perm_idx
    else:
        return batch_loss
class Model(nn.Module):
    """Mask-based source-separation model operating in the STFT domain.

    A ConvTasNet-style ``MaskGenerator`` predicts one magnitude mask per
    output source; waveforms are resynthesized with the mixture phase.
    """

    # Shared across instances: mixing-matrix generation is a pure function of
    # its parameters, so results can be cached process-wide.
    MIXING_MATRICES_CACHE = {}

    def __init__(self, config, num_sources=None):
        super().__init__()

        # Snapshot of the constructor arguments; save() persists these so
        # load() can rebuild an identical instance.
        self.args = locals()
        self.logger = get_logger('model')
        self.config = config

        # The number of output sources may be overridden (e.g. MixIT trains
        # with twice the dataset's source count).
        self.num_sources = self.config.num_sources if num_sources is None else num_sources

        self.transform = Transform(
            stft_frame_size=self.config.stft_frame_size,
            stft_hop_size=self.config.stft_hop_size,
            device=self.config.device,
        )

        self.mask_generator = MaskGenerator(
            input_dim=self.config.num_frequency_bins,
            num_sources=self.num_sources,
            kernel_size=3,
            num_feats=128,
            num_hidden=512,
            num_layers=4,
            num_stacks=3,
            msk_activate='sigmoid',
        )

    @classmethod
    def load(cls, path, device=None):
        """Rebuild a model from a checkpoint written by save().

        FIX: deserialize onto CPU first (``map_location='cpu'``) so that
        checkpoints saved on a CUDA machine can be loaded on CPU-only hosts;
        the module is moved to the target device afterwards.

        Args:
            path: checkpoint file produced by Model.save().
            device: optional override for the device stored in the
                checkpoint's config.
        """
        checkpoint = torch.load(path, map_location='cpu')
        if 'device' in checkpoint:
            del checkpoint['device']  # legacy key from old checkpoints
        if device:
            checkpoint['config'].device = device
        state_dict = checkpoint.pop('state_dict')

        instance = cls(**checkpoint)
        instance.load_state_dict(state_dict)
        instance.to(instance.config.device)

        return instance

    def save(self, path):
        """Persist constructor arguments plus weights for later load()."""
        checkpoint = self.args.copy()
        del checkpoint['self']
        del checkpoint['__class__']  # zero-arg super() injects this into locals()
        checkpoint['config'] = self.config
        checkpoint['state_dict'] = self.state_dict()

        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(checkpoint, path)

    def forward(self, x_true_wave):
        """Separate a batch of mixture waveforms into per-source estimates."""
        x_true_mag, x_true_phase = self.transform.stft(x_true_wave)
        m_pred_mag = self.mask_generator(x_true_mag.squeeze(1))
        # Masks are normalized across sources and applied to the mixture
        # magnitude; the mixture phase is reused for reconstruction.
        s_pred_mag = soft_mask(m_pred_mag, x_true_mag)
        s_pred_wave = self.transform.istft(
            mag=s_pred_mag,
            phase=x_true_phase,
            length=x_true_wave.size(-1)
        )
        return s_pred_wave

    def generate_mixing_matrices(self, num_targets, max_sources, num_mix=None, allow_empty=False):
        """Enumerate binary matrices assigning model outputs to targets.

        Each matrix has shape (num_targets, max_sources); entry (t, s) == 1
        assigns output source s to target t.  Every output is assigned to
        exactly one target (at most one when ``allow_empty`` is True), and
        ``num_mix`` restricts how many outputs each target may receive.
        Results are cached per parameter combination.

        Returns:
            Float tensor of shape (num_matrices, num_targets, max_sources)
            on ``self.config.device``.
        """
        parameters = locals()
        del parameters['self']

        def do_generate():
            output_perms = itertools.product([0, 1], repeat=max_sources)
            if num_mix is not None:
                output_perms = [perm for perm in output_perms if sum(perm) == num_mix]
            target_perms = list(itertools.product(output_perms, repeat=num_targets))
            perm_list = []
            for target_perm in target_perms:
                perm_sum = torch.tensor(target_perm).sum(dim=0)
                if (perm_sum <= 1).all() if allow_empty else (perm_sum == 1).all():
                    perm_list.append(target_perm)
            self.logger.info('mixing matrices are generated with %d permutations for parameters %s',
                             len(perm_list), parameters)
            return torch.tensor(perm_list).float().to(self.config.device)

        # Parameter order is fixed, so joining values alone is collision-free.
        cache_key = '_'.join(str(v) for v in parameters.values())
        try:
            r = Model.MIXING_MATRICES_CACHE[cache_key]
        except KeyError:
            r = Model.MIXING_MATRICES_CACHE[cache_key] = do_generate()
        return r
model=None): 15 | self.loss_accumulator = MetricAccumulator() 16 | self.sisnri_accumulator = MetricAccumulator() 17 | if model: 18 | self.model = model 19 | if config: 20 | self.model.config = config 21 | else: 22 | self.model = Model(config=config) 23 | self.model.to(self.model.config.device) 24 | 25 | self.optimizer = torch.optim.Adam([ 26 | {'params': self.model.parameters(), 'lr': self.model.config.lr}, 27 | ]) 28 | 29 | def train(self, dataloader): 30 | self.model.train() 31 | self.loss_accumulator.reset() 32 | 33 | def step(self, batch_loss): 34 | self.optimizer.zero_grad(set_to_none=True) 35 | batch_loss.sum().backward() 36 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.model.config.grad_clip) 37 | self.optimizer.step() 38 | self.loss_accumulator.store(batch_loss) 39 | 40 | def validate(self, dataloader): 41 | self.model.eval() 42 | self.sisnri_accumulator.reset() 43 | 44 | for x_true_wave, s_true_wave in dataloader: 45 | x_true_wave = x_true_wave.to(self.model.config.device) 46 | s_true_wave = s_true_wave.to(self.model.config.device) 47 | 48 | s_pred_wave = self.model(x_true_wave) 49 | 50 | mixing_matrices = self.model.generate_mixing_matrices( 51 | num_targets=self.model.config.num_sources, 52 | max_sources=self.model.num_sources, 53 | num_mix=1, 54 | allow_empty=True 55 | ) 56 | batch_sisnri = -invariant_loss( 57 | true=s_true_wave, 58 | pred=s_pred_wave, 59 | mixing_matrices=mixing_matrices, 60 | loss_func=partial(negative_sisnri, x_true_wave=x_true_wave) 61 | ) 62 | self.sisnri_accumulator.store(batch_sisnri) 63 | 64 | std, mean = self.sisnri_accumulator.std_mean() 65 | return { 66 | 'sisnri': mean.item(), 67 | 'sisnri_std': std.item() 68 | } 69 | 70 | def get_loss(self): 71 | _, mean = self.loss_accumulator.std_mean() 72 | return mean.item() 73 | 74 | def get_model(self): 75 | return self.model 76 | 77 | 78 | class PermutationInvariantTrainer(Trainer): 79 | def train(self, dataloader): 80 | super().train(dataloader) 81 | 82 | 
class PermutationInvariantDynamicMixingTrainer(Trainer):
    """PIT with dynamic mixing: reference sources are randomly remixed
    across the batch at every step before standard PIT training."""

    def train(self, dataloader):
        super().train(dataloader)

        num_sources = self.model.config.num_sources
        for _, references in dataloader:
            references = references.to(self.model.config.device)

            # Randomly permute all sources across the batch and re-sum them
            # into fresh synthetic mixtures.
            permutation = torch.randperm(prod(references.size()[:2]))
            shuffled_sources = shuffle_sources(references, num_sources, permutation)
            shuffled_mixture = shuffled_sources.sum(dim=1, keepdims=True)

            estimates = self.model(shuffled_mixture)

            assignments = self.model.generate_mixing_matrices(
                num_targets=num_sources,
                max_sources=num_sources,
                num_mix=1
            )
            batch_loss = invariant_loss(
                true=shuffled_sources,
                pred=estimates,
                mixing_matrices=assignments,
                loss_func=partial(negative_snr, snr_max=self.model.config.snr_max),
            )
            self.step(batch_loss)
            # One mixture consumed per optimization step.
            yield 1
    def validate(self, dataloader):
        """Validate a MixIT model: standard SI-SNRi plus an oracle variant.

        Extends Trainer.validate() with 'sisnri_mixit_oracle', computed by
        grouping the model's outputs onto the reference sources — an upper
        bound that requires access to the references.
        """
        metrics = super().validate(dataloader)

        self.sisnri_mixit_oracle_accumulator.reset()

        for x_true_wave, s_true_wave in dataloader:
            x_true_wave = x_true_wave.to(self.model.config.device)
            s_true_wave = s_true_wave.to(self.model.config.device)

            s_pred_wave = self.model(x_true_wave)

            # No num_mix constraint: any subset of the model's outputs may be
            # summed into each reference target.
            mixing_matrices = self.model.generate_mixing_matrices(
                num_targets=self.model.config.num_sources,
                max_sources=self.model.num_sources
            )
            # invariant_loss presumably returns the best loss over all
            # assignments (see lib.losses); negated to obtain SI-SNRi.
            batch_sisnri = -invariant_loss(
                true=s_true_wave,
                pred=s_pred_wave,
                mixing_matrices=mixing_matrices,
                loss_func=partial(negative_sisnri, x_true_wave=x_true_wave, eps=EPS),
            )

            self.sisnri_mixit_oracle_accumulator.store(batch_sisnri)

        std, mean = self.sisnri_mixit_oracle_accumulator.std_mean()
        metrics['sisnri_mixit_oracle'] = mean.item()
        metrics['sisnri_mixit_oracle_std'] = std.item()
        return metrics
super().train(dataloader) 198 | 199 | for x_true_wave_1, x_true_wave_2 in double_mixture_generator(dataloader): 200 | x_true_wave_1 = x_true_wave_1.to(self.model.config.device) 201 | x_true_wave_2 = x_true_wave_2.to(self.model.config.device) 202 | 203 | x_true_wave_double = torch.cat([x_true_wave_1, x_true_wave_2], dim=1) 204 | x_true_wave_mom = x_true_wave_double.sum(dim=1, keepdim=True) 205 | 206 | s_pred_wave = self.model(x_true_wave_mom) 207 | 208 | mixing_matrices = self.model.generate_mixing_matrices( 209 | num_targets=self.model.config.num_sources, 210 | max_sources=self.model.config.num_sources, 211 | num_mix=1 212 | ) 213 | batch_loss = invariant_loss( 214 | true=x_true_wave_double, 215 | pred=s_pred_wave, 216 | mixing_matrices=mixing_matrices, 217 | loss_func=partial(negative_snr, snr_max=self.model.config.snr_max), 218 | ) 219 | self.step(batch_loss) 220 | yield 2 221 | 222 | 223 | class MixCycleTrainer(Trainer): 224 | def __init__(self, config=None, model=None): 225 | super().__init__(config, model) 226 | 227 | self.sisnri_blind_accumulator = MetricAccumulator() 228 | self.mixpit_trainer = MixturePermutationInvariantTrainer(config, self.model) 229 | self.epochs = 0 230 | self.mixcycle_steps = 0 231 | self.model_copy = None 232 | 233 | def train(self, dataloader): 234 | super().train(dataloader) 235 | 236 | if self.epochs < self.model.config.mixcycle_init_epochs: 237 | generator = self.mixpit_trainer.train(dataloader) 238 | else: 239 | self.mixpit_trainer = None 240 | generator = self._mixcycle_train(dataloader) 241 | 242 | self.epochs += 1 243 | return generator 244 | 245 | def validate(self, dataloader): 246 | if self.mixpit_trainer: 247 | metrics = self.mixpit_trainer.validate(dataloader) 248 | else: 249 | metrics = super().validate(dataloader) 250 | 251 | if self.model.config.eval_method == 'blind': 252 | metrics.update(self.validate_blind(dataloader)) 253 | return metrics 254 | 255 | def validate_blind(self, dataloader): 256 | 
self.sisnri_blind_accumulator.reset() 257 | 258 | for _ in range(self.model.config.eval_blind_num_repeat): 259 | for x_true_wave, _ in dataloader: 260 | x_true_wave = x_true_wave.to(self.model.config.device) 261 | 262 | x_pred_wave_shuffled, s_pred_wave_shuffled = self._teacher(x_true_wave) 263 | 264 | s_pred_wave = self.model(x_pred_wave_shuffled) 265 | 266 | mixing_matrices = self.model.generate_mixing_matrices( 267 | num_targets=self.model.config.num_sources, 268 | max_sources=self.model.config.num_sources, 269 | num_mix=1 270 | ) 271 | batch_sisnri = -invariant_loss( 272 | true=s_pred_wave_shuffled, 273 | pred=s_pred_wave, 274 | mixing_matrices=mixing_matrices, 275 | loss_func=partial(negative_sisnri, x_true_wave=x_pred_wave_shuffled, eps=EPS), 276 | ) 277 | 278 | self.sisnri_blind_accumulator.store(batch_sisnri) 279 | 280 | std, mean = self.sisnri_blind_accumulator.std_mean() 281 | return { 282 | 'sisnri_blind': mean.item(), 283 | 'sisnri_blind_std': std.item() 284 | } 285 | 286 | def get_loss(self): 287 | if self.mixpit_trainer: 288 | return self.mixpit_trainer.get_loss() 289 | else: 290 | return super().get_loss() 291 | 292 | def get_model(self): 293 | if self.mixpit_trainer: 294 | return self.mixpit_trainer.get_model() 295 | else: 296 | return super().get_model() 297 | 298 | def _mixcycle_train(self, dataloader): 299 | for x_true_wave, _ in dataloader: 300 | with torch.no_grad(): 301 | x_true_wave = x_true_wave.to(self.model.config.device) 302 | 303 | x_pred_wave_shuffled, s_pred_wave_shuffled = self._teacher(x_true_wave) 304 | 305 | s_pred_wave = self.model(x_pred_wave_shuffled) 306 | 307 | mixing_matrices = self.model.generate_mixing_matrices( 308 | num_targets=self.model.config.num_sources, 309 | max_sources=self.model.config.num_sources, 310 | num_mix=1 311 | ) 312 | batch_loss = invariant_loss( 313 | true=s_pred_wave_shuffled, 314 | pred=s_pred_wave, 315 | mixing_matrices=mixing_matrices, 316 | loss_func=partial(negative_snr, 
    def _teacher(self, x_true_wave):
        """Produce pseudo-references for MixCycle self-training.

        Separates the input mixtures with the current model, randomly swaps
        source estimates within and across mixture pairs, and re-sums the
        shuffled estimates into new synthetic mixtures.

        Returns:
            (x_pred_wave_shuffled, s_pred_wave_shuffled): synthetic mixtures
            and the shuffled source estimates that compose them.
        """
        s_pred_wave = self.model(x_true_wave)

        # Flat index over (batch, source): entry b * num_sources + s.
        shuffle_idx = list(range(prod(s_pred_wave.size()[:2])))
        for i in range(0, len(shuffle_idx), 2):
            # randomly swap source estimates of mixture
            # NOTE(review): the stride of 2 assumes num_sources == 2 — confirm.
            if random.randrange(2) == 0:
                shuffle_idx[i], shuffle_idx[i + 1] = shuffle_idx[i + 1], shuffle_idx[i]

        for i in range(0, len(shuffle_idx), 4):
            # swap source estimates across two mixtures
            # NOTE(review): indexing i + 2 assumes the flat length is a
            # multiple of 4 (even batch size with 2 sources) — confirm.
            shuffle_idx[i], shuffle_idx[i + 2] = shuffle_idx[i + 2], shuffle_idx[i]

        s_pred_wave_shuffled = shuffle_sources(s_pred_wave, self.model.config.num_sources, shuffle_idx)
        # Summing over the source axis (kept as a singleton) yields a valid
        # single-channel mixture batch.
        x_pred_wave_shuffled = s_pred_wave_shuffled.sum(dim=1, keepdims=True)

        return x_pred_wave_shuffled, s_pred_wave_shuffled
def configure_file_logger(work_dir):
    """Attach a 'root.log' file handler (with timestamps) to the root logger."""
    handler = logging.FileHandler(os.path.join(work_dir, 'root.log'))
    fmt = logging.Formatter('%(asctime)s %(name)-6s %(levelname)-6s %(message)s')
    handler.setFormatter(fmt)
    logging.getLogger().addHandler(handler)
def build_run_name(args, prepend_items=None, exclude_keys=None):
    """Build a dotted run name like 'k1_v1.k2_v2' from an argument dict.

    Args:
        args: mapping of argument name -> value; None values are skipped
            and list values are flattened with underscores.
        prepend_items: entries placed first; they also shadow same-named
            entries in ``args``.
        exclude_keys: argument names to omit entirely.

    Returns:
        str: '.'-joined 'key_value' segments in insertion order.
    """
    prepend_items = {} if prepend_items is None else prepend_items
    exclude_keys = [] if exclude_keys is None else exclude_keys

    # Prepended entries come first and take precedence over args.
    # (The original re-checked `prepend_items is not None` here, which is
    # always true after the defaulting above.)
    run_name_dict = dict(prepend_items)
    for key, value in args.items():
        if key in prepend_items or key in exclude_keys or value is None:
            continue
        if isinstance(value, list):
            value = '_'.join(str(item) for item in value)
        run_name_dict[key] = value
    return '.'.join(f'{k}_{v}' for k, v in run_name_dict.items())
class MetricAccumulator:
    """Collects per-batch metric tensors and reports epoch-level statistics."""

    def __init__(self):
        # List of detached per-batch tensors; (re)initialized by reset().
        self.accumulator = None
        self.reset()

    def store(self, batch):
        """Record one batch of metric values, detached from the graph."""
        self.accumulator.append(batch.detach())

    def reset(self):
        """Drop all stored batches (call at the start of each epoch)."""
        self.accumulator = []

    def std_mean(self):
        """Return (std, mean) over every stored value (unbiased std)."""
        with torch.inference_mode():
            return torch.std_mean(torch.cat(self.accumulator), unbiased=True)
    def start(self):
        """Generate and persist all mixture partitions for this run.

        Creates the results directory, saves the config, then for each SC09
        partition draws single-source clips from an infinite dataloader,
        mixes them, and writes one .pth file per mixture sample.
        """
        ensure_clean_results_dir(self.config.results_dir)
        setup_determinism(self.config.seed)

        # The logger is created here rather than in __init__ because it
        # writes into the results dir, which must not exist before this point.
        self.logger = get_logger('sc09mix', self.config.results_dir)
        self.logger.info('config: %s', self.config)
        torch.save(self.config, os.path.join(self.config.results_dir, CONFIG_FILENAME))

        for partition in SC09.PARTITIONS:
            # Only the training partition uses num_train_samples; every other
            # partition uses num_test_samples.
            num_samples = self.config.num_train_samples if partition == 'training' else self.config.num_test_samples
            self.logger.info('generating the %s partition...', partition)

            mix_dir = os.path.join(self.config.results_dir, 'mix', partition)
            os.makedirs(mix_dir, exist_ok=True)

            infinite_dataloader = self._create_infinite_dataloader(partition)

            for sample_idx in trange(num_samples):
                mixture = self._generate_mixture(infinite_dataloader)
                torch.save(mixture, os.path.join(mix_dir, SAMPLE_FILENAME_FORMAT.format(sample_idx)))

        self.logger.info('completed')
if __name__ == '__main__':
    configure_console_logger()

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--sc-dir', type=str, required=True)
    arg_parser.add_argument('--mix-results-root', type=str, required=True)
    # BUG FIX: the previous --partition/--num-samples options had no matching
    # parameters on Mixer.__init__ (which takes num_train_samples and
    # num_test_samples), so Mixer(**vars(cmd_args)) always raised TypeError.
    # Expose the constructor's actual parameters instead; Mixer generates
    # every partition in a single run (see Mixer.start()).
    arg_parser.add_argument('--num-train-samples', type=int)
    arg_parser.add_argument('--num-test-samples', type=int)
    arg_parser.add_argument('--num-components', type=int)
    arg_parser.add_argument('--classes', type=str, help='comma separated')
    arg_parser.add_argument('--sample-rate', type=int)
    arg_parser.add_argument('--normalization', choices=['none', 'rms', 'standardize'])
    arg_parser.add_argument('--seed', type=int)

    cmd_args = arg_parser.parse_args()
    if cmd_args.classes:
        # '--classes 1,2,3' -> [1, 2, 3] (indices into SC09.CLASSES)
        cmd_args.classes = [int(cls) for cls in cmd_args.classes.split(',')]

    Mixer(**vars(cmd_args)).start()
import SimpleNamespace  # (continuation of 'from types import SimpleNamespace', split at the chunk boundary)
4 | 
5 | import torch
6 | 
7 | from lib.utils import CONFIG_FILENAME, BEST_CHECKPOINT_FILENAME, BEST_METRICS_FILENAME, configure_console_logger, \
8 |     get_logger, default, build_run_name, ensure_clean_results_dir, setup_determinism, metrics_to_str
9 | from lib.data.dataloader_utils import get_dataset_specs, create_dataloader
10 | from lib.models import Model
11 | from lib.trainers import get_trainer
12 | 
13 | 
14 | class Test:  # evaluates a trained checkpoint on a chosen dataset partition and persists the metrics
15 |     def __init__(self, train_results_dir, librimix_root=None, realm_root=None,
16 |                  eval_method=None,
17 |                  eval_blind_num_repeat=None,
18 |                  seed=None,
19 |                  device_name=None):
20 |         args = locals()  # snapshot of the ctor arguments, used to derive the run name
21 |         del args['self']
22 | 
23 |         self.dataset_name, self.dataset_root = get_dataset_specs(librimix_root, realm_root)  # presumably picks whichever dataset root was provided — verify in dataloader_utils
24 |         run_name = build_run_name(
25 |             args=args,
26 |             prepend_items={'test': self.dataset_name},
27 |             exclude_keys=['train_results_dir', 'librimix_root', 'realm_root']  # filesystem paths are excluded from the run name
28 |         )
29 |         self.results_dir = os.path.join(train_results_dir, 'test', run_name)  # test results nest under the training run's directory
30 | 
31 |         self.config = SimpleNamespace()
32 |         self.config.results_dir = self.results_dir
33 |         self.config.train_results_dir = train_results_dir
34 |         self.config.librimix_root = librimix_root
35 |         self.config.realm_root = realm_root
36 |         self.config.eval_method = default(eval_method, None)  # None -> keep the eval method stored in the checkpoint (see start())
37 |         self.config.eval_blind_num_repeat = default(eval_blind_num_repeat, None)
38 |         self.config.seed = default(seed, None)
39 |         self.config.device_name = default(device_name, 'cuda')
40 | 
41 |         self.logger = None  # created in start() once the results dir exists
42 |         self.config.device = torch.device(self.config.device_name)
43 | 
44 |     def start(self):  # runs the full evaluation: load checkpoint, validate, save metrics
45 |         ensure_clean_results_dir(self.config.results_dir)
46 |         setup_determinism(self.config.seed)
47 |         self.logger = get_logger('test', self.config.results_dir)
48 |         self.logger.info('config: %s', self.config)
49 |         torch.save(self.config, os.path.join(self.config.results_dir, CONFIG_FILENAME))
50 | 
51 |         model = Model.load(
52 |             path=os.path.join(self.config.train_results_dir,
BEST_CHECKPOINT_FILENAME),  # (continuation: loads the best checkpoint saved during training)
53 |             device=self.config.device
54 |         ).eval()  # switch the model to evaluation mode
55 | 
56 |         if self.config.eval_method:  # CLI override of the eval method stored in the checkpoint
57 |             model.config.eval_method = self.config.eval_method
58 | 
59 |         if self.config.eval_blind_num_repeat:
60 |             model.config.eval_blind_num_repeat = self.config.eval_blind_num_repeat
61 | 
62 |         if model.config.eval_method == 'blind':  # NOTE(review): blind eval appears to run a MixCycle trainer on shuffled validation batches — confirm in trainers.py
63 |             model_name = 'mixcycle'
64 |             partition = 'validation'
65 |             batch_size = 128
66 |             shuffle = True
67 |         elif model.config.eval_method == 'reference-valid':  # reference metrics on the validation partition
68 |             model_name = model.config.model_name
69 |             partition = 'validation'
70 |             batch_size = 128
71 |             shuffle = False
72 |         else:  # 'reference': test-partition evaluation, one sample at a time
73 |             model_name = model.config.model_name
74 |             partition = 'testing'
75 |             batch_size = 1
76 |             shuffle = False
77 | 
78 |         dataloader = create_dataloader(
79 |             dataset_name=self.dataset_name,
80 |             dataset_root=self.dataset_root,
81 |             partition=partition,
82 |             batch_size=batch_size,
83 |             shuffle=shuffle
84 |         )
85 |         self.logger.info('using %d samples for evaluation', len(dataloader.dataset))
86 | 
87 |         trainer = get_trainer(model_name)(model=model)  # trainer class is selected by model name
88 |         with torch.inference_mode():  # no autograd bookkeeping during validation
89 |             metrics = trainer.validate(dataloader)
90 | 
91 |         self.logger.info('[TEST] %s', metrics_to_str(metrics))
92 | 
93 |         torch.save(metrics, os.path.join(self.config.results_dir, BEST_METRICS_FILENAME))
94 | 
95 |         self.logger.info('completed')
96 | 
97 | 
98 | if __name__ == '__main__':  # CLI entry point for a standalone evaluation run
99 |     configure_console_logger()
100 | 
101 |     arg_parser = argparse.ArgumentParser()
102 |     arg_parser.add_argument('--train-results-dir', type=str, required=True)
103 |     arg_parser.add_argument('--librimix-root', type=str)
104 |     arg_parser.add_argument('--realm-root', type=str)
105 |     arg_parser.add_argument('--eval-method', choices=['reference', 'reference-valid', 'blind'])
106 |     arg_parser.add_argument('--eval-blind-num-repeat', type=int)
107 |     arg_parser.add_argument('--seed', type=int)
108 |     arg_parser.add_argument('--device-name', type=str)
109 | 
110 |     cmd_args = arg_parser.parse_args()
111 | 
Test(**vars(cmd_args)).start()  # (continuation: run the evaluation with the parsed CLI arguments)
112 | 
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from time import time
4 | from types import SimpleNamespace
5 | 
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 | from tqdm import tqdm
9 | 
10 | from lib.utils import CONFIG_FILENAME, BEST_CHECKPOINT_FILENAME, METRICS_HISTORY_FILENAME, configure_console_logger, \
11 |     default, build_run_name, ensure_clean_results_dir, setup_determinism, get_logger, total_num_params
12 | from lib.data.dataloader_utils import get_dataset_specs, create_dataloader
13 | from lib.trainers import get_trainer
14 | from lib.models import Model
15 | 
16 | 
17 | class Training:  # drives a full training run: data, model, train/validate loop, early stopping
18 |     def __init__(self, train_results_root, librimix_root=None, realm_root=None,
19 |                  stft_frame_size=None,
20 |                  stft_hop_size=None,
21 |                  model_name=None,
22 |                  model_load_path=None,
23 |                  mixcycle_init_epochs=None,
24 |                  snr_max=None,
25 |                  train_batch_size=None,
26 |                  valid_batch_size=None,
27 |                  lr=None,
28 |                  grad_clip=None,
29 |                  train_subsample_ratio=None,
30 |                  valid_subsample_ratio=None,
31 |                  eval_method=None,
32 |                  eval_blind_num_repeat=None,
33 |                  eval_epochs=None,
34 |                  patience=None,
35 |                  min_epochs=None,
36 |                  seed=None,
37 |                  run_id=None,
38 |                  device_name=None):
39 |         args = locals()  # snapshot of the ctor arguments, used to derive the run name
40 |         del args['self']
41 | 
42 |         self.dataset_name, self.dataset_root = get_dataset_specs(librimix_root, realm_root)
43 |         run_name = build_run_name(
44 |             args=args,
45 |             prepend_items={'train': self.dataset_name},
46 |             exclude_keys=['train_results_root', 'realm_root', 'librimix_root', 'model_load_path']  # paths are excluded from the run name
47 |         )
48 |         self.results_dir = os.path.join(train_results_root, run_name)
49 | 
50 |         self.config = SimpleNamespace()
51 |         self.config.results_dir = self.results_dir
52 |         self.config.stft_frame_size = default(stft_frame_size, 512)
53 |         self.config.stft_hop_size = 
default(stft_hop_size, 128)  # (continuation: default STFT hop is 128 samples)
54 |         self.config.model_name = default(model_name, 'mixcycle')
55 |         self.config.model_load_path = default(model_load_path, None)
56 |         self.config.mixcycle_init_epochs = default(mixcycle_init_epochs, 50)
57 |         self.config.snr_max = default(snr_max, 30.0)
58 |         self.config.train_batch_size = default(train_batch_size, 128)
59 |         self.config.valid_batch_size = default(valid_batch_size, 128)
60 |         self.config.lr = default(lr, 0.001)
61 |         self.config.grad_clip = default(grad_clip, 5.0)
62 |         self.config.train_subsample_ratio = default(train_subsample_ratio, 1.0)  # 1.0 = use the full training partition
63 |         self.config.valid_subsample_ratio = default(valid_subsample_ratio, 1.0)
64 |         self.config.eval_method = default(eval_method, 'reference')
65 |         self.config.eval_blind_num_repeat = default(eval_blind_num_repeat, 1)
66 |         self.config.eval_epochs = default(eval_epochs, 1)  # train this many epochs between validation rounds
67 |         self.config.patience = default(patience, 50)  # validation rounds without improvement before stopping
68 |         self.config.min_epochs = default(min_epochs, None)
69 |         self.config.seed = default(seed, None)
70 |         self.config.device_name = default(device_name, 'cuda')
71 | 
72 |     def start(self):  # sets up logging, data and model, then runs the training loop
73 |         ensure_clean_results_dir(self.config.results_dir)
74 | 
75 |         self.logger = get_logger('train', self.config.results_dir)
76 |         self.logger.info('config: %s', self.config)
77 |         torch.save(self.config, os.path.join(self.config.results_dir, CONFIG_FILENAME))
78 | 
79 |         setup_determinism(self.config.seed)
80 | 
81 |         self.config.device = torch.device(self.config.device_name)
82 |         self.tensorboard = SummaryWriter(os.path.join(self.config.results_dir, 'tb'))
83 | 
84 |         self.cur_epoch = 0
85 |         self.cur_step = 0
86 |         self.cur_patience = self.config.patience
87 |         self.last_best = {}  # best value seen so far, per validation metric
88 |         self.metrics_history = []
89 | 
90 |         self.train_dataloader = create_dataloader(
91 |             dataset_name=self.dataset_name,
92 |             dataset_root=self.dataset_root,
93 |             partition='training',
94 |             batch_size=self.config.train_batch_size,
95 |             subsample_ratio=self.config.train_subsample_ratio,
96 |         )
97 |         self.valid_dataloader = 
create_dataloader(  # (continuation: validation dataloader, mirroring the training one)
98 |             dataset_name=self.dataset_name,
99 |             dataset_root=self.dataset_root,
100 |             partition='validation',
101 |             batch_size=self.config.valid_batch_size,
102 |             subsample_ratio=self.config.valid_subsample_ratio,
103 |         )
104 | 
105 |         _, s_true_wave = next(iter(self.train_dataloader))  # probe one batch to infer data dimensions
106 |         self.config.num_sources = s_true_wave.size(1)  # assumes the source tensor is (batch, sources, samples) — TODO confirm
107 |         self.config.sample_length = s_true_wave.size(2)
108 |         self.config.num_batches = len(self.train_dataloader)
109 |         self.config.num_frequency_bins = 1 + self.config.stft_frame_size // 2  # one-sided STFT bin count
110 | 
111 |         if self.config.model_load_path:  # warm start from a previous run's best checkpoint
112 |             model = Model.load(
113 |                 path=os.path.join(self.config.model_load_path, BEST_CHECKPOINT_FILENAME),
114 |                 device=self.config.device
115 |             ).train()
116 |         else:
117 |             model = None  # presumably the trainer builds a fresh model — verify in trainers.py
118 | 
119 |         trainer_cls = get_trainer(self.config.model_name)
120 |         self.trainer = trainer_cls(config=self.config, model=model)
121 | 
122 |         self.logger.info('model parameter count: %d', total_num_params(self.trainer.get_model().parameters()))
123 |         self._train_loop()
124 |         self.logger.info('completed')
125 | 
126 |     def _train_loop(self):  # alternate training and validation until patience runs out
127 |         while self.cur_patience > 0:
128 |             for _ in range(self.config.eval_epochs):  # several training epochs per validation round
129 |                 train_start = time()
130 |                 steps = self.trainer.train(self.train_dataloader)
131 |                 with tqdm(steps, total=self.config.num_batches, leave=False) as progress_bar:  # NOTE(review): the raw `steps` iterator is consumed below; the bar advances only via update()
132 |                     for increment in steps:
133 |                         self.cur_step += 1
134 |                         progress_bar.update(increment)
135 |                 self.cur_epoch += 1
136 | 
137 |             metrics = {
138 |                 'process': {
139 |                     'epoch': self.cur_epoch,
140 |                     'step': self.cur_step,
141 |                     'train_elapsed': time() - train_start,  # NOTE(review): train_start resets each epoch, so this covers only the last of eval_epochs epochs
142 |                     'train_loss': self.trainer.get_loss()
143 |                 },
144 |                 'validation': {},
145 |             }
146 | 
147 |             validate_start = time()
148 |             with torch.inference_mode():  # no autograd bookkeeping during validation
149 |                 metrics.update({'validation': self.trainer.validate(self.valid_dataloader)})
150 |             metrics['process']['validate_elapsed'] = time() - validate_start
151 | 
152 |             self._update_best(metrics)  # marks which validation metrics improved this round
153 | 
self._update_patience(metrics)  # (continuation: also checkpoints the model when the main metric improves)
154 |             self._report(metrics)
155 | 
156 |     def _update_best(self, metrics):  # records per-metric bests; assumes every validation metric is higher-is-better — TODO confirm
157 |         metrics['best'] = {}
158 |         for key, value in metrics['validation'].items():
159 |             if key not in self.last_best or value > self.last_best[key]:
160 |                 self.last_best[key] = value
161 |                 metrics['best'][key] = True
162 |             else:
163 |                 metrics['best'][key] = False
164 |         return metrics  # metrics is also mutated in place; callers may ignore the return value
165 | 
166 |     def _update_patience(self, metrics):
167 |         min_epoch = self.config.min_epochs and self.config.min_epochs > self.cur_epoch  # truthy while still below the minimum epoch count
168 | 
169 |         if self.config.eval_method == 'reference':
170 |             main_metric_name = 'sisnri'
171 |         elif self.config.eval_method == 'blind':
172 |             main_metric_name = 'sisnri_blind'
173 |         else:
174 |             raise Exception(f'unknown eval_method: {self.config.eval_method}')
175 | 
176 |         if min_epoch or metrics['best'][main_metric_name] or metrics['best'].get('sisnri_mixit_oracle'):  # improvement (or the min-epoch grace period) resets patience and saves a checkpoint
177 |             self.cur_patience = self.config.patience
178 |             self.trainer.get_model().save(os.path.join(self.config.results_dir, BEST_CHECKPOINT_FILENAME))
179 |         else:
180 |             self.cur_patience -= 1
181 |         metrics['process']['patience'] = self.cur_patience
182 | 
183 |     def _report(self, metrics):  # emits one log line per round plus TensorBoard scalars
184 |         log_line = ''
185 |         for group in ['process', 'validation']:
186 |             for key, value in metrics[group].items():
187 |                 format_spec = '{}={:.3f}' if isinstance(value, float) else '{}={}'
188 |                 log_line += format_spec.format(key, value)
189 | 
190 |                 if group == 'validation' and metrics['best'][key]:
191 |                     log_line += '*'  # star marks a new best value
192 |                 else:
193 |                     log_line += ' '
194 | 
195 |                 if group == 'validation':
196 |                     log_line += ' '
197 | 
198 |                 self.tensorboard.add_scalar(
199 |                     tag='{}/{}'.format(group, key),
200 |                     scalar_value=value,
201 |                     global_step=metrics['process']['step']
202 |                 )
203 |             if group == 'process':
204 |                 log_line += '| '  # NOTE(review): indentation reconstructed — the '|' separator is assumed to be appended once, after the process metrics
205 | 
206 |         self.logger.info(log_line)
207 | 
208 |         self.tensorboard.flush()
209 | 
210 |         metrics['time'] = time()
211 |         self.metrics_history.append(metrics)
212 |         metrics_history_path = 
os.path.join(self.config.results_dir, METRICS_HISTORY_FILENAME)  # (continuation of the metrics-history path expression)
213 |         os.makedirs(os.path.dirname(metrics_history_path), exist_ok=True)
214 |         torch.save(self.metrics_history, metrics_history_path)  # the full history is rewritten each round
215 | 
216 | 
217 | if __name__ == '__main__':  # CLI entry point for a training run
218 |     configure_console_logger()
219 | 
220 |     arg_parser = argparse.ArgumentParser()  # NOTE(review): Training also accepts model_load_path/eval_blind_num_repeat/min_epochs with no CLI flags — presumably set programmatically (see experiment.py)
221 |     arg_parser.add_argument('--train-results-root', type=str, required=True)
222 |     arg_parser.add_argument('--librimix-root', type=str)
223 |     arg_parser.add_argument('--realm-root', type=str)
224 |     arg_parser.add_argument('--stft-frame-size', type=int)
225 |     arg_parser.add_argument('--stft-hop-size', type=int)
226 |     arg_parser.add_argument('--model-name', choices=['pit', 'pit-dm', 'mixit', 'mixpit', 'mixcycle'])
227 |     arg_parser.add_argument('--mixcycle-init-epochs', type=int)
228 |     arg_parser.add_argument('--snr-max', type=float)
229 |     arg_parser.add_argument('--train-batch-size', type=int)
230 |     arg_parser.add_argument('--valid-batch-size', type=int)
231 |     arg_parser.add_argument('--lr', type=float)
232 |     arg_parser.add_argument('--grad-clip', type=float)
233 |     arg_parser.add_argument('--train-subsample-ratio', type=float)
234 |     arg_parser.add_argument('--valid-subsample-ratio', type=float)
235 |     arg_parser.add_argument('--eval-method', choices=['reference', 'blind'])
236 |     arg_parser.add_argument('--eval-epochs', type=int)
237 |     arg_parser.add_argument('--patience', type=int)
238 |     arg_parser.add_argument('--seed', type=int)
239 |     arg_parser.add_argument('--run-id', type=str)
240 |     arg_parser.add_argument('--device-name', type=str)
241 | 
242 |     cmd_args = arg_parser.parse_args()
243 |     Training(**vars(cmd_args)).start()
244 | 
--------------------------------------------------------------------------------