├── .gitignore
├── README.md
├── environment.yml
├── notebooks
│   ├── AudioSamples-LibriMix-MixIT.ipynb
│   ├── AudioSamples-LibriMix-PIT-DM.ipynb
│   ├── AudioSamples-LibriMix.ipynb
│   ├── AudioSamples-REAL-M.ipynb
│   └── utils.py
└── src
    ├── experiment.py
    ├── irm.py
    ├── lib
    │   ├── __init__.py
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── collate_utils.py
    │   │   ├── dataloader_utils.py
    │   │   ├── realm.py
    │   │   └── sc09.py
    │   ├── losses.py
    │   ├── models.py
    │   ├── trainers.py
    │   ├── transforms.py
    │   └── utils.py
    ├── sc09mix.py
    ├── test.py
    └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 | .ipynb_checkpoints
5 | .idea/
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MixCycle: Unsupervised Speech Separation via Cyclic Mixture Permutation Invariant Training
2 | This repository contains the audio samples and the source code that accompany [the paper](https://arxiv.org/abs/2202.03875).
3 |
4 | ## Audio samples
5 | We provide audio samples to demonstrate the results of the MixCycle method on two different datasets: [LibriMix](https://nbviewer.org/github/ertug/MixCycle/blob/main/notebooks/AudioSamples-LibriMix.ipynb) and [REAL-M](https://nbviewer.org/github/ertug/MixCycle/blob/main/notebooks/AudioSamples-REAL-M.ipynb).
6 |
7 | Note that the provided [REAL-M](https://nbviewer.org/github/ertug/MixCycle/blob/main/notebooks/AudioSamples-REAL-M.ipynb) samples are the ones used in the informal listening test.
8 |
9 | We also provide audio samples from the baseline methods on LibriMix: [PIT-DM](https://nbviewer.org/github/ertug/MixCycle/blob/main/notebooks/AudioSamples-LibriMix-PIT-DM.ipynb) and [MixIT](https://nbviewer.org/github/ertug/MixCycle/blob/main/notebooks/AudioSamples-LibriMix-MixIT.ipynb).
10 |
11 | ## Source code
12 | We provide the source code under the `src` directory for reproducibility.
13 |
14 | ## Running the experiments
15 |
16 | ### Prepare the datasets
17 | - LibriMix: [GitHub](https://github.com/JorisCos/LibriMix)
18 | - REAL-M: [Download](https://sourceseparationresearch.com/static/REAL-M-v0.1.0.tar.gz)
19 |
20 | ### Create the environment
21 |
22 | Install [Anaconda](https://www.anaconda.com/products/individual) and run the following command:
23 | ```
24 | $ conda env create -f environment.yml
25 | ```
26 | See [more info](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html) on how to manage conda environments.
27 |
28 | ### Activate the environment
29 | ```
30 | $ conda activate mixcycle
31 | ```
32 |
33 | ### Run the experiments
34 | ```
35 | $ cd src
36 | $ python experiment.py --librimix-root ~/datasets/librimix --exp-root ~/experiments --run librimix_irm
37 | $ python experiment.py --librimix-root ~/datasets/librimix --exp-root ~/experiments --run librimix_5p
38 | $ python experiment.py --librimix-root ~/datasets/librimix --exp-root ~/experiments --run librimix_100p
39 | $ python experiment.py --librimix-root ~/datasets/librimix --realm-root ~/datasets/REAL-M-v0.1.0 --exp-root ~/experiments --run realm
40 | ```
41 |
42 | Optionally, you can monitor the training process with TensorBoard by running:
43 | ```
44 | $ tensorboard --logdir experiments
45 | ```
46 |
47 | ## Citation (BibTeX)
48 | If you find this repository useful, please cite our work:
49 |
50 | ```BibTeX
51 | @article{karamatli2022unsupervised,
52 | title={MixCycle: Unsupervised Speech Separation via Cyclic Mixture Permutation Invariant Training},
53 | author={Karamatl{\i}, Ertu{\u{g}} and K{\i}rb{\i}z, Serap},
54 | journal={IEEE Signal Processing Letters},
55 | volume={29},
56 | number={},
57 | pages={2637-2641},
58 | year={2022},
59 | doi={10.1109/LSP.2022.3232276}
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: mixcycle
2 | channels:
3 | - defaults
4 | dependencies:
5 | - python=3.9.7
6 | - pytorch::pytorch=1.12.1
7 | - pytorch::torchaudio=0.12.1
8 | - cudatoolkit=10.2
9 | - tqdm=4.62.3
10 | - conda-forge::tensorboard=2.7.0
11 | - conda-forge::jupyterlab=3.2.5
12 | - conda-forge::ipywidgets=7.6.5
13 |
--------------------------------------------------------------------------------
/notebooks/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from functools import partial
4 | from types import SimpleNamespace
5 |
6 | import torch
7 | from torchaudio.datasets.librimix import LibriMix
8 | import ipywidgets as widgets
9 | import IPython.display as ipd
10 |
11 | from train import BEST_CHECKPOINT_FILENAME, Training
12 | from lib.data.realm import RealM
13 | from lib.models import Model
14 | from lib.losses import negative_sisnri, invariant_loss
15 |
16 |
17 | SAMPLE_RATE = 8000
18 |
19 |
20 | def collate_samples(dataset_name, dataset_root, count):
21 | if dataset_name == 'librimix':
22 | dataset = LibriMix(
23 | root=dataset_root,
24 | subset='test',
25 | num_speakers=2,
26 | sample_rate=SAMPLE_RATE,
27 | task='sep_clean',
28 | )
29 | elif dataset_name == 'realm':
30 | dataset = RealM(
31 | root=dataset_root,
32 | partition='testing',
33 | )
34 | else:
35 | raise Exception(f'unknown dataset name: {dataset_name}')
36 |
37 | selected_idxs = torch.randperm(len(dataset))[:count]
38 | data_samples = []
39 | for idx in selected_idxs:
40 | _, x_true_wave, s_true_wave = dataset[idx]
41 | data_sample = SimpleNamespace(
42 | x_true_wave=x_true_wave,
43 | s_true_wave=torch.cat(s_true_wave)
44 | )
45 | data_samples.append(data_sample)
46 | return data_samples
47 |
48 |
49 | def evaluate_samples(train_results_root, dataset_name, model_name, data_samples, **kwargs):
50 | training = Training(
51 | train_results_root=train_results_root,
librimix_root=(dataset_name == 'librimix'),  # truthy placeholders: only used to select
realm_root=(dataset_name == 'realm'),        # the dataset; no path is read here
54 | model_name=model_name,
55 | **kwargs
56 | )
57 |
58 | with torch.inference_mode():
59 | model = Model.load(
60 | path=os.path.join(training.results_dir, BEST_CHECKPOINT_FILENAME),
61 | device='cpu'
62 | ).eval()
63 |
64 | separation_samples = []
65 | for data_sample in data_samples:
66 | s_pred_wave = model(data_sample.x_true_wave.unsqueeze(0))
67 | mixing_matrices = model.generate_mixing_matrices(
68 | num_targets=model.config.num_sources,
69 | max_sources=model.num_sources,
70 | num_mix=1,
71 | allow_empty=True
72 | )
73 | negative_sisnri_value, best_perm_idx = invariant_loss(
74 | true=data_sample.s_true_wave.unsqueeze(0),
75 | pred=s_pred_wave,
76 | mixing_matrices=mixing_matrices,
77 | loss_func=partial(negative_sisnri, x_true_wave=data_sample.x_true_wave.unsqueeze(0)),
78 | return_best_perm_idx=True
79 | )
80 | s_pred_wave_permuted = mixing_matrices[best_perm_idx].matmul(s_pred_wave)
81 |
82 | separation_samples.append(SimpleNamespace(
83 | x_true_wave=data_sample.x_true_wave,
84 | s_true_wave=data_sample.s_true_wave,
85 | s_pred_wave=s_pred_wave_permuted.squeeze(0),
86 | sisnri=-negative_sisnri_value.item()
87 | ))
88 | return separation_samples
89 |
90 |
91 | def _audio_widget(data):
92 | out = widgets.Output()
93 | with out:
94 | ipd.display(ipd.Audio(data, rate=SAMPLE_RATE))
95 | return out
96 |
97 |
98 | def _mixture_widget(html, sample):
99 | return widgets.VBox([
100 | widgets.HTML(html),
101 | _audio_widget(sample.x_true_wave)
102 | ])
103 |
104 |
105 | def show_samples_librimix(separation_samples):
106 |     def _source_widget(text, s_wave):
107 |         return widgets.VBox([
108 |             widgets.HTML('<b>{} 1</b>'.format(text)),
109 |             _audio_widget(s_wave[0]),
110 |             widgets.HTML('<b>{} 2</b>'.format(text)),
111 |             _audio_widget(s_wave[1]),
112 |         ])
113 | 
114 |     def sisnri(separation_sample):
115 |         html = '<b>Mean SI-SNRi:</b><br>' \
116 |                '<b>{:.1f}</b>' \
117 |                .format(separation_sample.sisnri)
118 |         return widgets.HTML(html)
119 | 
120 |     for idx, separation_sample in enumerate(separation_samples):
121 |         ipd.display(widgets.HBox([
122 |             _mixture_widget(html='<b>Mixture #{}</b>'.format(idx+1), sample=separation_sample),
123 |             _source_widget(text='Reference Source', s_wave=separation_sample.s_true_wave),
124 |             _source_widget(text='Estimated Source', s_wave=separation_sample.s_pred_wave),
125 |             sisnri(separation_sample),
126 |         ], layout=widgets.Layout(align_items='center', padding='3px', margin='5px', border='2px solid gray')))
127 |
128 |
129 | def show_samples_realm(separation_samples_librimix, separation_samples_realm):
130 |     def _source_widget(model_id, sample):
131 |         return widgets.VBox([
132 |             widgets.HTML('<b>Model {} / Est. Source 1</b>'.format(model_id)),
133 |             _audio_widget(sample.s_pred_wave[0]),
134 |             widgets.HTML('<b>Model {} / Est. Source 2</b>'.format(model_id)),
135 |             _audio_widget(sample.s_pred_wave[1]),
136 |         ])
137 | 
138 |     for idx, (sample_librimix, sample_realm) in enumerate(zip(separation_samples_librimix, separation_samples_realm)):
139 |         ipd.display(widgets.HBox([
140 |             _mixture_widget(html='<b>Mixture #{}</b>'.format(idx+1), sample=sample_librimix),
141 |             _source_widget(model_id='A', sample=sample_librimix),
142 |             _source_widget(model_id='B', sample=sample_realm)
143 |         ], layout=widgets.Layout(align_items='center', padding='3px', margin='5px', border='2px solid gray')))
144 |
145 |
146 | def load_file(path):
147 | try:
148 | data = torch.load(path)
149 | return data
150 | except FileNotFoundError as e:
151 | print(e, file=sys.stderr)
152 | return None
153 |
--------------------------------------------------------------------------------
/src/experiment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from types import SimpleNamespace
4 |
5 | import torch
6 |
7 | from lib.utils import CONFIG_FILENAME, configure_console_logger, configure_file_logger, get_logger
8 | from train import Training
9 | from test import Test
10 | from irm import IdealRatioMask
11 |
12 |
13 | class Experiment:
14 | def __init__(self, experiments_root, experiment_name, librimix_root=None, realm_root=None):
15 | self.config = SimpleNamespace()
16 | self.config.experiments_root = experiments_root
17 | self.config.datasets_root = os.path.join(experiments_root, 'datasets')
18 | self.config.checkpoints_root = os.path.join(experiments_root, experiment_name)
19 | self.config.librimix_root = librimix_root
20 | self.config.realm_root = realm_root
21 |
22 | if os.path.exists(self.config.checkpoints_root):
23 | raise Exception(f'checkpoints root "{self.config.checkpoints_root}" exists!')
24 | os.makedirs(self.config.checkpoints_root, exist_ok=True)
25 |
26 | configure_file_logger(self.config.checkpoints_root)
27 | self.logger = get_logger('exp', self.config.checkpoints_root)
28 | self.logger.info('config: %s', self.config)
29 | torch.save(self.config, os.path.join(self.config.checkpoints_root, CONFIG_FILENAME))
30 |
31 | def librimix_irm(self):
32 | IdealRatioMask(
33 | irm_results_root=self.config.checkpoints_root,
34 | librimix_root=self.config.librimix_root
35 | ).start()
36 |
37 | def model_comparison_librimix(self,
38 | model_names=None,
39 | mixcycle_init_epochs=50,
40 | train_subsample_ratio=1.,
41 | eval_epochs=1,
42 | run_id=None):
43 | all_model_names = ['mixcycle', 'pit', 'pit-dm', 'mixit', 'mixpit']
44 | for model_name in (all_model_names if model_names is None else model_names):
45 | self.logger.info('model_name=%s', model_name)
46 |
47 | training = Training(
48 | train_results_root=self.config.checkpoints_root,
49 | librimix_root=self.config.librimix_root,
50 | model_name=model_name,
51 | mixcycle_init_epochs=mixcycle_init_epochs,
52 | train_subsample_ratio=train_subsample_ratio,
53 | eval_epochs=eval_epochs,
54 | run_id=run_id
55 | )
56 | training.start()
57 |
58 | Test(
59 | train_results_dir=training.results_dir,
60 | librimix_root=self.config.librimix_root
61 | ).start()
62 |
63 | def librimix_5p(self):
64 | self.model_comparison_librimix(
65 | #model_names=['mixcycle'],
66 | mixcycle_init_epochs=250,
67 | train_subsample_ratio=0.05,
68 | eval_epochs=10
69 | )
70 |
71 | def librimix_100p(self):
72 | self.model_comparison_librimix(
73 | mixcycle_init_epochs=50,
74 | train_subsample_ratio=1.,
75 | eval_epochs=1
76 | )
77 |
78 | def realm(self):
79 | initialized_model = Training(
80 | train_results_root=os.path.join(self.config.experiments_root, 'librimix_100p'),
81 | librimix_root=True,
82 | model_name='mixcycle',
83 | mixcycle_init_epochs=50,
84 | train_subsample_ratio=1.0,
85 | eval_epochs=1
86 | )
87 |
88 | training = Training(
89 | train_results_root=self.config.checkpoints_root,
90 | realm_root=self.config.realm_root,
91 | model_name='mixcycle',
92 | model_load_path=initialized_model.results_dir,
93 | mixcycle_init_epochs=0,
94 | eval_method='blind',
95 | eval_blind_num_repeat=20,
96 | eval_epochs=40
97 | )
98 | training.start()
99 |
100 | Test(
101 | train_results_dir=training.results_dir,
102 | realm_root=self.config.realm_root,
103 | eval_blind_num_repeat=100
104 | ).start()
105 |
106 |
107 | if __name__ == '__main__':
108 | configure_console_logger()
109 |
110 | parser = argparse.ArgumentParser()
111 | parser.add_argument('--exp-root', type=str, required=True)
112 | parser.add_argument('--run', type=str, required=True)
113 | parser.add_argument('--librimix-root', type=str)
114 | parser.add_argument('--realm-root', type=str)
115 |
116 | args = parser.parse_args()
117 | experiment = Experiment(
118 | experiments_root=args.exp_root,
119 | experiment_name=args.run,
120 | librimix_root=args.librimix_root,
121 | realm_root=args.realm_root
122 | )
123 |
124 | try:
125 | func = getattr(experiment, args.run)
126 | except AttributeError:
127 | raise Exception('unknown experiment')
128 |
129 | func()
130 |
--------------------------------------------------------------------------------
/src/irm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from types import SimpleNamespace
4 |
5 | import torch
6 |
7 | from lib.utils import BEST_METRICS_FILENAME, configure_console_logger, get_logger, default, build_run_name, \
8 | ensure_clean_results_dir, soft_mask, MetricAccumulator, metrics_to_str
9 | from lib.data.dataloader_utils import get_dataset_specs, create_dataloader
10 | from lib.transforms import Transform
11 | from lib.losses import sisnri
12 |
13 |
14 | class IdealRatioMask:
15 | def __init__(self, irm_results_root, librimix_root=None, realm_root=None,
16 | stft_frame_size=None,
17 | stft_hop_size=None,
18 | device_name=None):
19 | args = locals()
20 | del args['self']
21 |
22 | self.dataset_name, self.dataset_root = get_dataset_specs(librimix_root, realm_root)
23 | run_name = build_run_name(
24 | args=args,
25 | prepend_items={'irm': self.dataset_name},
26 | exclude_keys=['irm_results_root', 'sc09mix_root', 'librimix_root']
27 | )
28 | self.results_dir = os.path.join(irm_results_root, run_name)
29 |
30 | self.config = SimpleNamespace()
31 | self.config.results_dir = self.results_dir
32 | self.config.librimix_root = librimix_root
33 | self.config.realm_root = realm_root
34 | self.config.stft_frame_size = default(stft_frame_size, 512)
35 | self.config.stft_hop_size = default(stft_hop_size, 128)
36 | self.config.device_name = default(device_name, 'cuda')
37 |
38 | self.logger = None
39 | self.config.device = torch.device(self.config.device_name)
40 |
41 | def start(self):
42 | ensure_clean_results_dir(self.config.results_dir)
43 | self.logger = get_logger('irm')
44 | self.logger.info('config: %s', self.config)
45 |
46 | dataloader = create_dataloader(
47 | dataset_name=self.dataset_name,
48 | dataset_root=self.dataset_root,
49 | partition='testing',
50 | batch_size=1,
51 | )
52 | self.logger.info('using %d samples for evaluation', len(dataloader.dataset))
53 |
54 | transform = Transform(
55 | stft_frame_size=self.config.stft_frame_size,
56 | stft_hop_size=self.config.stft_hop_size,
57 | device=self.config.device,
58 | )
59 |
60 | with torch.inference_mode():
61 | sisnri_accumulator = MetricAccumulator()
62 |
63 | for x_true_wave, s_true_wave in dataloader:
64 | x_true_wave = x_true_wave.to(self.config.device)
65 | s_true_wave = s_true_wave.to(self.config.device)
66 |
67 | x_true_mag, x_true_phase = transform.stft(x_true_wave)
68 | s_true_mag, _ = transform.stft(s_true_wave)
69 | s_pred_mag = soft_mask(s_true_mag, x_true_mag)
70 | s_pred_wave = transform.istft(
71 | mag=s_pred_mag,
72 | phase=x_true_phase,
73 | length=x_true_wave.size(-1)
74 | )
75 |
76 | batch_sisnri = sisnri(
77 | true_wave=s_true_wave,
78 | pred_wave=s_pred_wave,
79 | x_true_wave=x_true_wave,
80 | )
81 |
82 | sisnri_accumulator.store(batch_sisnri)
83 |
84 | std, mean = sisnri_accumulator.std_mean()
85 | metrics = {
86 | 'sisnri': mean.item(),
87 | 'sisnri_std': std.item()
88 | }
89 |
90 | metrics_str = metrics_to_str(metrics)
91 | self.logger.info('[IRM] %s', metrics_str)
92 |
93 | torch.save(metrics, os.path.join(self.config.results_dir, BEST_METRICS_FILENAME))
94 |
95 | self.logger.info('completed')
96 |
97 |
98 | if __name__ == '__main__':
99 | configure_console_logger()
100 |
101 | arg_parser = argparse.ArgumentParser()
102 | arg_parser.add_argument('--irm-results-root', type=str, required=True)
103 | arg_parser.add_argument('--librimix-root', type=str)
104 | arg_parser.add_argument('--realm-root', type=str)
105 | arg_parser.add_argument('--stft-frame-size', type=int)
106 | arg_parser.add_argument('--stft-hop-size', type=int)
107 | arg_parser.add_argument('--device-name', type=str)
108 |
109 | cmd_args = arg_parser.parse_args()
110 | IdealRatioMask(**vars(cmd_args)).start()
111 |
--------------------------------------------------------------------------------
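The IRM baseline above scores an oracle: each true source's magnitude spectrogram is combined with the mixture's via `soft_mask`, which lives in `lib/utils.py` and is not included in this dump. The sketch below is therefore an illustrative reimplementation of the ideal-ratio-mask idea on plain Python lists; the function name, signature, and `eps` guard are assumptions, not the repo's API:

```python
def soft_mask(source_mags, eps=1e-8):
    """Ideal ratio masks: each source's share of the total magnitude per bin."""
    totals = [sum(bins) + eps for bins in zip(*source_mags)]
    return [[m / t for m, t in zip(bins, totals)] for bins in source_mags]


# Two sources, three time-frequency bins each.
masks = soft_mask([[3.0, 0.0, 1.0],
                   [1.0, 2.0, 1.0]])
# The masks for each bin sum to (almost exactly) 1, so multiplying them
# with the mixture magnitude splits its energy between the sources.
print([round(a + b, 6) for a, b in zip(*masks)])  # → [1.0, 1.0, 1.0]
```

Because the masks are built from the true source magnitudes, the resulting SI-SNRi is an upper bound that the learned models are compared against.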
/src/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ertug/MixCycle/88d9c2d302c4bc62752781b9bbba06e0472d67f1/src/lib/__init__.py
--------------------------------------------------------------------------------
/src/lib/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ertug/MixCycle/88d9c2d302c4bc62752781b9bbba06e0472d67f1/src/lib/data/__init__.py
--------------------------------------------------------------------------------
/src/lib/data/collate_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 | from collections import namedtuple
3 |
4 | import torch
5 |
6 | #########################################################################################################
7 | # based on https://github.com/pytorch/audio/blob/main/examples/source_separation/utils/dataset/utils.py #
8 | #########################################################################################################
9 |
10 | SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
11 | Batch = namedtuple("Batch", ["mix", "src"])
12 |
13 |
14 | def _fix_num_frames(sample: SampleType, target_num_frames: int, sample_rate: int, random_start=False):
15 | """Ensure waveform has exact number of frames by slicing or padding"""
16 | mix = sample[1] # [1, time]
17 | src = torch.cat(sample[2], 0) # [num_sources, time]
18 |
19 | num_channels, num_frames = src.shape
20 | num_seconds = torch.div(num_frames, sample_rate, rounding_mode="floor")
21 | target_seconds = torch.div(target_num_frames, sample_rate, rounding_mode="floor")
22 | if num_frames >= target_num_frames:
23 | if random_start and num_frames > target_num_frames:
24 | start_frame = torch.randint(num_seconds - target_seconds + 1, [1]) * sample_rate
25 | mix = mix[:, start_frame:]
26 | src = src[:, start_frame:]
27 | mix = mix[:, :target_num_frames]
28 | src = src[:, :target_num_frames]
29 | mask = torch.ones_like(mix)
30 | else:
31 | num_padding = target_num_frames - num_frames
32 | pad = torch.zeros([1, num_padding], dtype=mix.dtype, device=mix.device)
33 | mix = torch.cat([mix, pad], 1)
34 | src = torch.cat([src, pad.expand(num_channels, -1)], 1)
35 | mask = torch.ones_like(mix)
36 | mask[..., num_frames:] = 0
37 | return mix, src, mask
38 |
39 |
40 | def collate_fn_wsj0mix_train(samples, sample_rate, duration):
41 | target_num_frames = int(duration * sample_rate)
42 |
43 | mixes, srcs = [], []
44 | for sample in samples:
45 | mix, src, _ = _fix_num_frames(sample, target_num_frames, sample_rate, random_start=True)
46 |
47 | mixes.append(mix)
48 | srcs.append(src)
49 |
50 | return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0))
51 |
52 |
53 | def collate_fn_wsj0mix_test(samples):
54 | assert len(samples) == 1
55 |
56 | mixes, srcs = [], []
57 | for sample in samples:
58 | mix = sample[1] # [1, time]
59 | src = torch.cat(sample[2], 0) # [num_sources, time]
60 |
61 | mixes.append(mix)
62 | srcs.append(src)
63 |
64 | return Batch(torch.stack(mixes, 0), torch.stack(srcs, 0))
65 |
--------------------------------------------------------------------------------
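`_fix_num_frames` above forces every clip to exactly `target_num_frames` samples, slicing a window from longer clips and zero-padding shorter ones, with a mask marking which samples are real. The slice-or-pad core, as a minimal sketch on 1-D Python lists (the helper name is illustrative):

```python
def fix_num_frames(wave, target):
    """Slice or zero-pad a 1-D signal to exactly `target` samples,
    returning the signal and a validity mask (1 = real sample, 0 = padding)."""
    if len(wave) >= target:
        return wave[:target], [1] * target
    pad = target - len(wave)
    return wave + [0] * pad, [1] * len(wave) + [0] * pad


wave = [3, 1, 4, 1, 5]
print(fix_num_frames(wave, 3))  # → ([3, 1, 4], [1, 1, 1])
print(fix_num_frames(wave, 7))  # → ([3, 1, 4, 1, 5, 0, 0], [1, 1, 1, 1, 1, 0, 0])
```

Equal lengths are what allow `collate_fn_wsj0mix_train` to `torch.stack` the clips into a single batch tensor.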
/src/lib/data/dataloader_utils.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | from torch.utils.data import DataLoader, Subset
5 | from torchaudio.datasets.librimix import LibriMix
6 |
7 | from lib.data.collate_utils import collate_fn_wsj0mix_train, collate_fn_wsj0mix_test
8 | from lib.data.realm import RealM
9 |
10 |
11 | SAMPLE_RATE = 8000
12 |
13 |
14 | def get_dataset_specs(librimix_root=None, realm_root=None):
15 | if librimix_root and realm_root:
16 | raise Exception('only one dataset root should be given')
17 | elif librimix_root:
18 | dataset_name = 'librimix'
19 | dataset_root = librimix_root
20 | elif realm_root:
21 | dataset_name = 'realm'
22 | dataset_root = realm_root
23 | else:
24 | raise Exception('at least one dataset root should be given')
25 |
26 | return dataset_name, dataset_root
27 |
28 |
29 | def create_dataloader(dataset_name, dataset_root, partition, batch_size, subsample_ratio=1.0, shuffle=None):
30 | assert partition in ('training', 'validation', 'testing')
31 |
32 | if partition in ('training', 'validation'):
33 | collate_fn = partial(collate_fn_wsj0mix_train, sample_rate=SAMPLE_RATE, duration=3)
34 | else:
35 | collate_fn = partial(collate_fn_wsj0mix_test)
36 |
37 | if dataset_name == 'librimix':
38 | subset_mapping = {
39 | 'training': 'train-360',
40 | 'validation': 'dev',
41 | 'testing': 'test',
42 | }
43 | dataset = LibriMix(
44 | root=dataset_root,
45 | subset=subset_mapping[partition],
46 | num_speakers=2,
47 | sample_rate=SAMPLE_RATE,
48 | task='sep_clean',
49 | )
50 | elif dataset_name == 'realm':
51 | dataset = RealM(
52 | root=dataset_root,
53 | partition=partition,
54 | )
55 | else:
56 | raise Exception(f'unknown dataset name: {dataset_name}')
57 |
58 | if subsample_ratio < 1.0:
59 | subsampled_dataset_size = int(len(dataset) // (1/subsample_ratio))
60 | indices = torch.randperm(len(dataset)).int()[:subsampled_dataset_size]
61 | dataset = Subset(dataset, indices)
62 |
63 | return DataLoader(
64 | dataset=dataset,
65 | batch_size=batch_size,
66 | shuffle=(partition == 'training' or shuffle),
67 | collate_fn=collate_fn,
68 | num_workers=4,
69 | drop_last=(partition == 'training'),
70 | pin_memory=True,
71 | )
72 |
73 |
74 | def double_mixture_generator(dataloader):
75 | iterator = iter(dataloader)
76 | while True:
77 | try:
78 | x_true_wave_1, _ = next(iterator)
79 | x_true_wave_2, _ = next(iterator)
80 | yield x_true_wave_1, x_true_wave_2
81 | except StopIteration:
82 | return
83 |
--------------------------------------------------------------------------------
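`double_mixture_generator` above consumes a dataloader two batches at a time, presumably to feed methods that train on pairs of mixtures, and silently drops a trailing unpaired batch when the inner `next` raises `StopIteration`. The same pairing pattern in isolation (helper name is illustrative):

```python
def pairwise_consume(iterable):
    """Yield consecutive non-overlapping pairs; a leftover item is dropped."""
    iterator = iter(iterable)
    while True:
        try:
            first = next(iterator)
            second = next(iterator)
            yield first, second
        except StopIteration:
            return


print(list(pairwise_consume(['m1', 'm2', 'm3', 'm4', 'm5'])))
# → [('m1', 'm2'), ('m3', 'm4')]  ('m5' has no partner and is discarded)
```

Catching `StopIteration` explicitly and returning, rather than letting it escape, is required inside generators since PEP 479.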
/src/lib/data/realm.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | from pathlib import Path
3 |
4 | from torch.utils.data import Dataset
5 | import torchaudio
6 |
7 |
8 | class RealM(Dataset):
9 | VALIDATION_PERCENTAGE = 20
10 | MAX_NUM_FILES = 2 ** 27 - 1 # ~134M
11 | SOURCE_SUFFIXES = ['source1hat.wav', 'source2hat.wav']
12 |
13 | def __init__(self, root, partition):
14 | self.mix_root = Path(root) / 'audio_files_converted_8000Hz'
15 | self.src_root = Path(root) / 'separations'
16 | self.mix_paths = []
17 |
18 | if partition == 'testing':
19 | partition = 'validation'
20 |
21 | for file_path in self.mix_root.glob('*/*.wav'):
22 | if self._assign_partition(file_path.name) == partition:
23 | self.mix_paths.append(file_path)
24 |
25 | self.mix_paths.sort()
26 |
27 | def __len__(self):
28 | return len(self.mix_paths)
29 |
30 | def __getitem__(self, idx):
31 | mix_path = self.mix_paths[idx]
32 | mix = self._load_audio(mix_path)
33 | srcs = [self._load_audio(self.src_root / (mix_path.name + s)) for s in self.SOURCE_SUFFIXES]
34 | return None, mix, srcs
35 |
36 | def _load_audio(self, file_path):
37 | wave, sample_rate = torchaudio.load(file_path)
38 | assert sample_rate == 8000
39 | return wave
40 |
41 | def _assign_partition(self, filename):
42 | filename_hashed = hashlib.sha1(filename.encode('ascii')).hexdigest()
43 | percentage_hash = ((int(filename_hashed, 16) %
44 | (self.MAX_NUM_FILES + 1)) *
45 | (100.0 / self.MAX_NUM_FILES))
46 | if filename.startswith('early'):
47 | return 'discarded'
48 | elif percentage_hash < self.VALIDATION_PERCENTAGE:
49 | return 'validation'
50 | else:
51 | return 'training'
52 |
--------------------------------------------------------------------------------
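`_assign_partition` above pins each file to a split by hashing its name, so the training/validation partition is stable across runs and machines; the same trick appears in `sc09.py` below, where it is credited to the dataset README. A self-contained sketch of the mechanism (function name and percentage are illustrative):

```python
import hashlib

MAX_NUM_FILES = 2 ** 27 - 1  # same modulus trick as in realm.py


def assign_partition(filename, validation_percentage=20):
    """Deterministically map a filename to a partition via its SHA-1 hash."""
    digest = hashlib.sha1(filename.encode('ascii')).hexdigest()
    # Reduce the hash to a pseudo-uniform value in [0, 100].
    percentage_hash = (int(digest, 16) % (MAX_NUM_FILES + 1)) * (100.0 / MAX_NUM_FILES)
    return 'validation' if percentage_hash < validation_percentage else 'training'


# The split is a pure function of the filename: it never changes between
# runs, and adding new files never reshuffles the old ones.
print(assign_partition('mix_0001.wav'))
```

This is why `RealM` can map the `testing` partition onto `validation` without storing any split files on disk.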
/src/lib/data/sc09.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import hashlib
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 | import librosa
8 |
9 | from lib.utils import CONFIG_FILENAME
10 |
11 |
12 | SAMPLE_FILENAME_FORMAT = '{:07d}.pth'
13 |
14 |
15 | class SC09(Dataset):
16 | PARTITIONS = ('training', 'validation', 'testing')
17 | CLASSES = ('zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine')
18 | ORIGINAL_SAMPLE_RATE = 16000
19 | VALIDATION_PERCENTAGE = 10
20 | TESTING_PERCENTAGE = 10
21 | TRAINING_PERCENTAGE = 100 - VALIDATION_PERCENTAGE - TESTING_PERCENTAGE
22 | MAX_NUM_WAVS_PER_CLASS = 2 ** 27 - 1 # ~134M
23 |
24 | def __init__(self, root, partition, classes=None, sample_rate=ORIGINAL_SAMPLE_RATE):
25 | if partition not in self.PARTITIONS:
26 | raise Exception(f'invalid partition: {partition}')
27 |
28 | self.sample_rate = sample_rate
29 | self.metadata = []
30 |
31 | for cls in (self.CLASSES if classes is None else classes):
32 | base = os.path.join(root, cls)
33 | filenames = sorted(os.listdir(base))
34 | for filename in filenames:
35 | if self._assign_partition(filename) == partition:
36 | path = os.path.join(base, filename)
37 | self.metadata.append((path, cls))
38 |
39 | def __len__(self):
40 | return len(self.metadata)
41 |
42 | def __getitem__(self, idx):
43 | path, label = self.metadata[idx]
44 | wave = torch.zeros(self.sample_rate)
45 | x, _ = librosa.core.load(path, sr=self.sample_rate)
46 |
47 | # center if shorter than one second
48 | centered_start = (self.sample_rate-x.shape[0])//2
49 | centered_end = centered_start + x.shape[0]
50 | wave[centered_start:centered_end] = torch.from_numpy(x)
51 |
52 | return wave, label
53 |
54 | def _assign_partition(self, filename):
55 | """ copied from the dataset README """
56 | base_name = os.path.basename(filename)
57 | hash_name = re.sub(r'_nohash_.*$', '', base_name)
58 | hash_name_hashed = hashlib.sha1(hash_name.encode('ascii')).hexdigest()
59 | percentage_hash = ((int(hash_name_hashed, 16) %
60 | (self.MAX_NUM_WAVS_PER_CLASS + 1)) *
61 | (100.0 / self.MAX_NUM_WAVS_PER_CLASS))
62 | if percentage_hash < self.VALIDATION_PERCENTAGE:
63 | result = 'validation'
64 | elif percentage_hash < (self.TESTING_PERCENTAGE + self.VALIDATION_PERCENTAGE):
65 | result = 'testing'
66 | else:
67 | result = 'training'
68 | return result
69 |
70 |
71 | class SC09Mix(Dataset):
72 | PARTITIONS = SC09.PARTITIONS
73 |
74 | def __init__(self, root, partition):
75 | assert partition in SC09Mix.PARTITIONS
76 |
77 | self.mix_dir = os.path.join(root, 'mix', partition)
78 | self.count = len(os.listdir(self.mix_dir))
79 | self.config = torch.load(os.path.join(root, CONFIG_FILENAME))
80 |
81 | def _load_sample(self, idx):
82 | return torch.load(os.path.join(self.mix_dir, SAMPLE_FILENAME_FORMAT.format(idx)))
83 |
84 | def _get_single_item(self, idx):
85 | sample = self._load_sample(idx)
86 |
87 | return sample['mixture'].unsqueeze(0), sample['sources']
88 |
89 | def __len__(self):
90 | return self.count
91 |
92 | def __getitem__(self, idx):
93 | return self._get_single_item(idx)
94 |
--------------------------------------------------------------------------------
/src/lib/losses.py:
--------------------------------------------------------------------------------
1 | from math import prod
2 |
3 | import torch
4 |
5 |
6 | def snr(true_wave, pred_wave, snr_max=None):
7 | true_wave_square_sum = true_wave.square().sum(-1)
8 | if snr_max is None:
9 | soft_threshold = 0
10 | else:
11 | threshold = 10 ** (-snr_max / 10)
12 | soft_threshold = threshold * true_wave_square_sum
13 |
14 | return 10 * torch.log10(true_wave_square_sum / ((true_wave - pred_wave).square().sum(-1) + soft_threshold))
15 |
16 |
17 | def sisnr(true_wave, pred_wave, eps=0.):
18 | true_wave = ((true_wave * pred_wave).sum(-1) / true_wave.square().sum(-1)).unsqueeze(-1) * true_wave
19 | return 10 * torch.log10(true_wave.square().sum(-1) / ((true_wave - pred_wave).square().sum(-1) + eps))
20 |
21 |
22 | def sisnri(true_wave, pred_wave, x_true_wave, eps=0.):
23 | return sisnr(true_wave=true_wave, pred_wave=pred_wave, eps=eps) - \
24 | sisnr(true_wave=true_wave, pred_wave=x_true_wave, eps=eps)
25 |
26 |
27 | def negative_snr(true_wave, pred_wave, snr_max=None):
28 | return -snr(true_wave, pred_wave, snr_max)
29 |
30 |
31 | def negative_sisnr(true_wave, pred_wave):
32 | return -sisnr(true_wave, pred_wave)
33 |
34 |
35 | def negative_sisnri(true_wave, pred_wave, x_true_wave, eps=0.):
36 | return -sisnri(true_wave, pred_wave, x_true_wave, eps=eps)
37 |
38 |
39 | def invariant_loss(true, pred, mixing_matrices, loss_func, return_best_perm_idx=False):
40 | pred_flat = pred.view(*pred.size()[:2], prod(pred.size()[2:]))
41 |
42 | batch_size = true.size(0)
43 | perm_size = mixing_matrices.size(0)
44 | loss_perms = torch.empty([batch_size, perm_size], device=true.device)
45 |
46 | for perm_idx in range(perm_size):
47 | pred_flat_mix = mixing_matrices[perm_idx].matmul(pred_flat)
48 | pred_mix = pred_flat_mix.view(*pred_flat_mix.size()[:2], *pred.size()[2:])
49 | loss_perms[:, perm_idx] = loss_func(true, pred_mix).mean(dim=1)
50 | _, best_perm_idx = loss_perms.min(dim=1)
51 |
52 | batch_loss = loss_perms[torch.arange(batch_size), best_perm_idx]
53 |
54 | if return_best_perm_idx:
55 | return batch_loss, best_perm_idx
56 | else:
57 | return batch_loss
58 |
--------------------------------------------------------------------------------
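In `sisnr` above, the reference is first rescaled by its projection onto the estimate, which makes the metric blind to any overall gain on the estimate. A pure-Python sketch of the same computation on plain lists (no torch), illustrating the scale invariance:

```python
import math


def sisnr(true, pred):
    """Scale-invariant SNR in dB, mirroring the projection step in losses.py."""
    dot = sum(t * p for t, p in zip(true, pred))
    true_energy = sum(t * t for t in true)
    # Rescale the reference so the metric ignores the estimate's gain.
    scale = dot / true_energy
    target = [scale * t for t in true]
    noise_energy = sum((t - p) ** 2 for t, p in zip(target, pred))
    return 10 * math.log10(sum(t * t for t in target) / noise_energy)


true = [1.0, -0.5, 0.25, 0.0]
pred = [0.9, -0.4, 0.3, 0.1]
a = sisnr(true, pred)
b = sisnr(true, [2 * p for p in pred])  # doubling the estimate's gain
print(round(a - b, 9))  # → 0.0  (the metric is unchanged)
```

`sisnri` then reports the improvement of this quantity over using the raw mixture as the estimate, which is the number shown next to each audio sample in the notebooks.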
/src/lib/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import itertools
3 |
4 | import torch
5 | from torch import nn
6 | from torchaudio.models.conv_tasnet import MaskGenerator
7 |
8 | from lib.utils import get_logger, soft_mask
9 | from lib.transforms import Transform
10 |
11 |
12 | class Model(nn.Module):
13 | MIXING_MATRICES_CACHE = {}
14 |
15 | def __init__(self, config, num_sources=None):
16 | super().__init__()
17 |
18 | self.args = locals()
19 | self.logger = get_logger('model')
20 | self.config = config
21 |
22 | self.num_sources = self.config.num_sources if num_sources is None else num_sources
23 |
24 | self.transform = Transform(
25 | stft_frame_size=self.config.stft_frame_size,
26 | stft_hop_size=self.config.stft_hop_size,
27 | device=self.config.device,
28 | )
29 |
30 | self.mask_generator = MaskGenerator(
31 | input_dim=self.config.num_frequency_bins,
32 | num_sources=self.num_sources,
33 | kernel_size=3,
34 | num_feats=128,
35 | num_hidden=512,
36 | num_layers=4,
37 | num_stacks=3,
38 | msk_activate='sigmoid',
39 | )
40 |
41 | @classmethod
42 | def load(cls, path, device=None):
43 | checkpoint = torch.load(path)
44 |         if 'device' in checkpoint:
45 |             del checkpoint['device']  # drop legacy 'device' key saved by older checkpoints
46 | if device:
47 | checkpoint['config'].device = device
48 | state_dict = checkpoint.pop('state_dict')
49 |
50 | instance = cls(**checkpoint)
51 | instance.load_state_dict(state_dict)
52 | instance.to(instance.config.device)
53 |
54 | return instance
55 |
56 | def save(self, path):
57 | checkpoint = self.args.copy()
58 | del checkpoint['self']
59 | del checkpoint['__class__']
60 | checkpoint['config'] = self.config
61 | checkpoint['state_dict'] = self.state_dict()
62 |
63 | os.makedirs(os.path.dirname(path), exist_ok=True)
64 | torch.save(checkpoint, path)
65 |
66 | def forward(self, x_true_wave):
67 | x_true_mag, x_true_phase = self.transform.stft(x_true_wave)
68 | m_pred_mag = self.mask_generator(x_true_mag.squeeze(1))
69 | s_pred_mag = soft_mask(m_pred_mag, x_true_mag)
70 | s_pred_wave = self.transform.istft(
71 | mag=s_pred_mag,
72 | phase=x_true_phase,
73 | length=x_true_wave.size(-1)
74 | )
75 | return s_pred_wave
76 |
77 | def generate_mixing_matrices(self, num_targets, max_sources, num_mix=None, allow_empty=False):
78 | parameters = locals()
79 | del parameters['self']
80 |
81 | def do_generate():
82 | output_perms = itertools.product([0, 1], repeat=max_sources)
83 | if num_mix is not None:
84 | output_perms = [perm for perm in output_perms if sum(perm) == num_mix]
85 | target_perms = list(itertools.product(output_perms, repeat=num_targets))
86 | perm_list = []
87 | for target_perm in target_perms:
88 | perm_sum = torch.tensor(target_perm).sum(dim=0)
89 |                 if (perm_sum <= 1).all() if allow_empty else (perm_sum == 1).all():  # each source assigned to exactly one target (at most one if allow_empty)
90 | perm_list.append(target_perm)
91 | self.logger.info('mixing matrices are generated with %d permutations for parameters %s',
92 | len(perm_list), parameters)
93 | return torch.tensor(perm_list).float().to(self.config.device)
94 |
95 |         cache_key = '_'.join(str(v) for v in parameters.values())
96 | try:
97 | r = Model.MIXING_MATRICES_CACHE[cache_key]
98 | except KeyError:
99 | r = Model.MIXING_MATRICES_CACHE[cache_key] = do_generate()
100 | return r
101 |
--------------------------------------------------------------------------------
/src/lib/trainers.py:
--------------------------------------------------------------------------------
1 | import random
2 | from math import prod
3 | from functools import partial
4 |
5 | import torch
6 |
7 | from lib.losses import negative_snr, negative_sisnri, invariant_loss
8 | from lib.models import Model
9 | from lib.utils import EPS, shuffle_sources, MetricAccumulator
10 | from lib.data.dataloader_utils import double_mixture_generator
11 |
12 |
13 | class Trainer:
14 | def __init__(self, config=None, model=None):
15 | self.loss_accumulator = MetricAccumulator()
16 | self.sisnri_accumulator = MetricAccumulator()
17 | if model:
18 | self.model = model
19 | if config:
20 | self.model.config = config
21 | else:
22 | self.model = Model(config=config)
23 | self.model.to(self.model.config.device)
24 |
25 | self.optimizer = torch.optim.Adam([
26 | {'params': self.model.parameters(), 'lr': self.model.config.lr},
27 | ])
28 |
29 | def train(self, dataloader):
30 | self.model.train()
31 | self.loss_accumulator.reset()
32 |
33 | def step(self, batch_loss):
34 | self.optimizer.zero_grad(set_to_none=True)
35 | batch_loss.sum().backward()
36 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.model.config.grad_clip)
37 | self.optimizer.step()
38 | self.loss_accumulator.store(batch_loss)
39 |
40 | def validate(self, dataloader):
41 | self.model.eval()
42 | self.sisnri_accumulator.reset()
43 |
44 | for x_true_wave, s_true_wave in dataloader:
45 | x_true_wave = x_true_wave.to(self.model.config.device)
46 | s_true_wave = s_true_wave.to(self.model.config.device)
47 |
48 | s_pred_wave = self.model(x_true_wave)
49 |
50 | mixing_matrices = self.model.generate_mixing_matrices(
51 | num_targets=self.model.config.num_sources,
52 | max_sources=self.model.num_sources,
53 | num_mix=1,
54 | allow_empty=True
55 | )
56 | batch_sisnri = -invariant_loss(
57 | true=s_true_wave,
58 | pred=s_pred_wave,
59 | mixing_matrices=mixing_matrices,
60 | loss_func=partial(negative_sisnri, x_true_wave=x_true_wave)
61 | )
62 | self.sisnri_accumulator.store(batch_sisnri)
63 |
64 | std, mean = self.sisnri_accumulator.std_mean()
65 | return {
66 | 'sisnri': mean.item(),
67 | 'sisnri_std': std.item()
68 | }
69 |
70 | def get_loss(self):
71 | _, mean = self.loss_accumulator.std_mean()
72 | return mean.item()
73 |
74 | def get_model(self):
75 | return self.model
76 |
77 |
78 | class PermutationInvariantTrainer(Trainer):
79 | def train(self, dataloader):
80 | super().train(dataloader)
81 |
82 | for x_true_wave, s_true_wave in dataloader:
83 | x_true_wave = x_true_wave.to(self.model.config.device)
84 | s_true_wave = s_true_wave.to(self.model.config.device)
85 |
86 | s_pred_wave = self.model(x_true_wave)
87 |
88 | mixing_matrices = self.model.generate_mixing_matrices(
89 | num_targets=self.model.config.num_sources,
90 | max_sources=self.model.config.num_sources,
91 | num_mix=1
92 | )
93 | batch_loss = invariant_loss(
94 | true=s_true_wave,
95 | pred=s_pred_wave,
96 | mixing_matrices=mixing_matrices,
97 | loss_func=partial(negative_snr, snr_max=self.model.config.snr_max),
98 | )
99 | self.step(batch_loss)
100 |             yield 1  # one dataloader batch consumed this step
101 |
102 |
103 | class PermutationInvariantDynamicMixingTrainer(Trainer):
104 | def train(self, dataloader):
105 | super().train(dataloader)
106 |
107 | for _, s_true_wave in dataloader:
108 | s_true_wave = s_true_wave.to(self.model.config.device)
109 |
110 | shuffle_idx = torch.randperm(prod(s_true_wave.size()[:2]))
111 | s_true_wave_shuffled = shuffle_sources(s_true_wave, self.model.config.num_sources, shuffle_idx)
112 |             x_true_wave_shuffled = s_true_wave_shuffled.sum(dim=1, keepdim=True)
113 |
114 | s_pred_wave_shuffled = self.model(x_true_wave_shuffled)
115 |
116 | mixing_matrices = self.model.generate_mixing_matrices(
117 | num_targets=self.model.config.num_sources,
118 | max_sources=self.model.config.num_sources,
119 | num_mix=1
120 | )
121 | batch_loss = invariant_loss(
122 | true=s_true_wave_shuffled,
123 | pred=s_pred_wave_shuffled,
124 | mixing_matrices=mixing_matrices,
125 | loss_func=partial(negative_snr, snr_max=self.model.config.snr_max),
126 | )
127 | self.step(batch_loss)
128 | yield 1
129 |
130 |
131 | class MixtureInvariantTrainer(Trainer):
132 | def __init__(self, config=None, model=None):
133 | if not model:
134 | model = Model(config=config, num_sources=config.num_sources * 2)
135 |
136 | super().__init__(config, model)
137 |
138 | self.sisnri_mixit_oracle_accumulator = MetricAccumulator()
139 |
140 | def train(self, dataloader):
141 | super().train(dataloader)
142 |
143 | for x_true_wave_1, x_true_wave_2 in double_mixture_generator(dataloader):
144 | x_true_wave_1 = x_true_wave_1.to(self.model.config.device)
145 | x_true_wave_2 = x_true_wave_2.to(self.model.config.device)
146 |
147 | x_true_wave_double = torch.cat([x_true_wave_1, x_true_wave_2], dim=1)
148 | x_true_wave_mom = x_true_wave_double.sum(dim=1, keepdim=True)
149 |
150 | s_pred_wave = self.model(x_true_wave_mom)
151 |
152 | mixing_matrices = self.model.generate_mixing_matrices(
153 | num_targets=2,
154 | max_sources=self.model.num_sources
155 | )
156 | batch_loss = invariant_loss(
157 | true=x_true_wave_double,
158 | pred=s_pred_wave,
159 | mixing_matrices=mixing_matrices,
160 | loss_func=partial(negative_snr, snr_max=self.model.config.snr_max),
161 | )
162 | self.step(batch_loss)
163 |             yield 2  # two dataloader batches consumed (mixture of mixtures)
164 |
165 | def validate(self, dataloader):
166 | metrics = super().validate(dataloader)
167 |
168 | self.sisnri_mixit_oracle_accumulator.reset()
169 |
170 | for x_true_wave, s_true_wave in dataloader:
171 | x_true_wave = x_true_wave.to(self.model.config.device)
172 | s_true_wave = s_true_wave.to(self.model.config.device)
173 |
174 | s_pred_wave = self.model(x_true_wave)
175 |
176 | mixing_matrices = self.model.generate_mixing_matrices(
177 | num_targets=self.model.config.num_sources,
178 | max_sources=self.model.num_sources
179 | )
180 | batch_sisnri = -invariant_loss(
181 | true=s_true_wave,
182 | pred=s_pred_wave,
183 | mixing_matrices=mixing_matrices,
184 | loss_func=partial(negative_sisnri, x_true_wave=x_true_wave, eps=EPS),
185 | )
186 |
187 | self.sisnri_mixit_oracle_accumulator.store(batch_sisnri)
188 |
189 | std, mean = self.sisnri_mixit_oracle_accumulator.std_mean()
190 | metrics['sisnri_mixit_oracle'] = mean.item()
191 | metrics['sisnri_mixit_oracle_std'] = std.item()
192 | return metrics
193 |
194 |
195 | class MixturePermutationInvariantTrainer(Trainer):
196 | def train(self, dataloader):
197 | super().train(dataloader)
198 |
199 | for x_true_wave_1, x_true_wave_2 in double_mixture_generator(dataloader):
200 | x_true_wave_1 = x_true_wave_1.to(self.model.config.device)
201 | x_true_wave_2 = x_true_wave_2.to(self.model.config.device)
202 |
203 | x_true_wave_double = torch.cat([x_true_wave_1, x_true_wave_2], dim=1)
204 | x_true_wave_mom = x_true_wave_double.sum(dim=1, keepdim=True)
205 |
206 | s_pred_wave = self.model(x_true_wave_mom)
207 |
208 | mixing_matrices = self.model.generate_mixing_matrices(
209 | num_targets=self.model.config.num_sources,
210 | max_sources=self.model.config.num_sources,
211 | num_mix=1
212 | )
213 | batch_loss = invariant_loss(
214 | true=x_true_wave_double,
215 | pred=s_pred_wave,
216 | mixing_matrices=mixing_matrices,
217 | loss_func=partial(negative_snr, snr_max=self.model.config.snr_max),
218 | )
219 | self.step(batch_loss)
220 | yield 2
221 |
222 |
223 | class MixCycleTrainer(Trainer):
224 | def __init__(self, config=None, model=None):
225 | super().__init__(config, model)
226 |
227 | self.sisnri_blind_accumulator = MetricAccumulator()
228 | self.mixpit_trainer = MixturePermutationInvariantTrainer(config, self.model)
229 | self.epochs = 0
230 | self.mixcycle_steps = 0
231 | self.model_copy = None
232 |
233 | def train(self, dataloader):
234 | super().train(dataloader)
235 |
236 | if self.epochs < self.model.config.mixcycle_init_epochs:
237 | generator = self.mixpit_trainer.train(dataloader)
238 | else:
239 | self.mixpit_trainer = None
240 | generator = self._mixcycle_train(dataloader)
241 |
242 | self.epochs += 1
243 | return generator
244 |
245 | def validate(self, dataloader):
246 | if self.mixpit_trainer:
247 | metrics = self.mixpit_trainer.validate(dataloader)
248 | else:
249 | metrics = super().validate(dataloader)
250 |
251 | if self.model.config.eval_method == 'blind':
252 | metrics.update(self.validate_blind(dataloader))
253 | return metrics
254 |
255 | def validate_blind(self, dataloader):
256 | self.sisnri_blind_accumulator.reset()
257 |
258 | for _ in range(self.model.config.eval_blind_num_repeat):
259 | for x_true_wave, _ in dataloader:
260 | x_true_wave = x_true_wave.to(self.model.config.device)
261 |
262 | x_pred_wave_shuffled, s_pred_wave_shuffled = self._teacher(x_true_wave)
263 |
264 | s_pred_wave = self.model(x_pred_wave_shuffled)
265 |
266 | mixing_matrices = self.model.generate_mixing_matrices(
267 | num_targets=self.model.config.num_sources,
268 | max_sources=self.model.config.num_sources,
269 | num_mix=1
270 | )
271 | batch_sisnri = -invariant_loss(
272 | true=s_pred_wave_shuffled,
273 | pred=s_pred_wave,
274 | mixing_matrices=mixing_matrices,
275 | loss_func=partial(negative_sisnri, x_true_wave=x_pred_wave_shuffled, eps=EPS),
276 | )
277 |
278 | self.sisnri_blind_accumulator.store(batch_sisnri)
279 |
280 | std, mean = self.sisnri_blind_accumulator.std_mean()
281 | return {
282 | 'sisnri_blind': mean.item(),
283 | 'sisnri_blind_std': std.item()
284 | }
285 |
286 | def get_loss(self):
287 | if self.mixpit_trainer:
288 | return self.mixpit_trainer.get_loss()
289 | else:
290 | return super().get_loss()
291 |
292 | def get_model(self):
293 | if self.mixpit_trainer:
294 | return self.mixpit_trainer.get_model()
295 | else:
296 | return super().get_model()
297 |
298 | def _mixcycle_train(self, dataloader):
299 | for x_true_wave, _ in dataloader:
300 | with torch.no_grad():
301 | x_true_wave = x_true_wave.to(self.model.config.device)
302 |
303 | x_pred_wave_shuffled, s_pred_wave_shuffled = self._teacher(x_true_wave)
304 |
305 | s_pred_wave = self.model(x_pred_wave_shuffled)
306 |
307 | mixing_matrices = self.model.generate_mixing_matrices(
308 | num_targets=self.model.config.num_sources,
309 | max_sources=self.model.config.num_sources,
310 | num_mix=1
311 | )
312 | batch_loss = invariant_loss(
313 | true=s_pred_wave_shuffled,
314 | pred=s_pred_wave,
315 | mixing_matrices=mixing_matrices,
316 | loss_func=partial(negative_snr, snr_max=self.model.config.snr_max),
317 | )
318 | self.step(batch_loss)
319 | self.mixcycle_steps += 1
320 | yield 1
321 |
322 | def _teacher(self, x_true_wave):
323 | s_pred_wave = self.model(x_true_wave)
324 |
325 | shuffle_idx = list(range(prod(s_pred_wave.size()[:2])))
326 | for i in range(0, len(shuffle_idx), 2):
327 | # randomly swap source estimates of mixture
328 | if random.randrange(2) == 0:
329 | shuffle_idx[i], shuffle_idx[i + 1] = shuffle_idx[i + 1], shuffle_idx[i]
330 |
331 | for i in range(0, len(shuffle_idx), 4):
332 | # swap source estimates across two mixtures
333 | shuffle_idx[i], shuffle_idx[i + 2] = shuffle_idx[i + 2], shuffle_idx[i]
334 |
335 | s_pred_wave_shuffled = shuffle_sources(s_pred_wave, self.model.config.num_sources, shuffle_idx)
336 |         x_pred_wave_shuffled = s_pred_wave_shuffled.sum(dim=1, keepdim=True)
337 |
338 | return x_pred_wave_shuffled, s_pred_wave_shuffled
339 |
340 |
341 | TRAINER_MAPPING = {
342 | 'pit': PermutationInvariantTrainer,
343 | 'pit-dm': PermutationInvariantDynamicMixingTrainer,
344 | 'mixit': MixtureInvariantTrainer,
345 | 'mixpit': MixturePermutationInvariantTrainer,
346 | 'mixcycle': MixCycleTrainer,
347 | }
348 |
349 |
350 | def get_trainer(model_name):
351 | try:
352 | return TRAINER_MAPPING[model_name]
353 | except KeyError:
354 | raise Exception(f'unknown model name: {model_name}')
355 |
--------------------------------------------------------------------------------
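The index permutation built by `MixCycleTrainer._teacher` can be sketched without tensors. The helper name `teacher_shuffle_indices` below is hypothetical; it assumes 2 sources per mixture and an even number of mixtures, matching the trainer's batches:

```python
import random

def teacher_shuffle_indices(num_mixtures, rng=random):
    # flat index over (mixture, source) pairs, 2 sources per mixture
    idx = list(range(num_mixtures * 2))
    # randomly swap the two source estimates within each mixture
    for i in range(0, len(idx), 2):
        if rng.randrange(2) == 0:
            idx[i], idx[i + 1] = idx[i + 1], idx[i]
    # swap source estimates across each pair of mixtures, so every recombined
    # pseudo-mixture contains estimates drawn from two different originals
    for i in range(0, len(idx), 4):
        idx[i], idx[i + 2] = idx[i + 2], idx[i]
    return idx

out = teacher_shuffle_indices(2)
# regardless of the random within-mixture swaps, the first pseudo-mixture
# always holds one estimate from each of the two original mixtures
```

Applying this index list with `shuffle_sources` and summing over the source dimension yields the shuffled pseudo-mixtures that the student network is then trained to separate.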
/src/lib/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from lib.utils import flatten_sources, unflatten_sources
4 |
5 |
6 | class Transform:
7 | def __init__(self, stft_frame_size, stft_hop_size, device):
8 | self.stft_frame_size = stft_frame_size
9 | self.stft_hop_size = stft_hop_size
10 |
11 | self.hann_window = torch.hann_window(
12 | self.stft_frame_size,
13 | periodic=True,
14 | device=device
15 | )
16 |
17 | def stft(self, wave):
18 | wave_flat = flatten_sources(wave)
19 | complex_flat = torch.stft(
20 | wave_flat,
21 | n_fft=self.stft_frame_size,
22 | hop_length=self.stft_hop_size,
23 | window=self.hann_window,
24 | return_complex=True
25 | )
26 | complex = unflatten_sources(complex_flat, num_sources=wave.size(1))
27 | mag, phase = complex.abs(), complex.angle()
28 | return mag, phase
29 |
30 | def istft(self, mag, phase, length):
31 | complex = torch.complex(
32 | real=mag * phase.cos(),
33 | imag=mag * phase.sin()
34 | )
35 | complex_flat = flatten_sources(complex)
36 | wave_flat = torch.istft(
37 | complex_flat,
38 | n_fft=self.stft_frame_size,
39 | hop_length=self.stft_hop_size,
40 | window=self.hann_window,
41 | length=length
42 | )
43 | return unflatten_sources(wave_flat, num_sources=mag.size(1))
44 |
--------------------------------------------------------------------------------
/src/lib/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import logging
4 | import random
5 |
6 | import torch
7 |
8 |
9 | CONFIG_FILENAME = 'config.pth'
10 | BEST_CHECKPOINT_FILENAME = 'best_checkpoint.pth'
11 | BEST_METRICS_FILENAME = 'best_metrics.pth'
12 | METRICS_HISTORY_FILENAME = 'metrics_history.pth'
13 | EPS = 1e-8
14 |
15 |
16 | class ResultsDirExistsError(Exception):
17 | pass
18 |
19 |
20 | def configure_console_logger():
21 | logger = logging.getLogger()
22 | console = logging.StreamHandler(sys.stdout)
23 | console.setFormatter(logging.Formatter('%(asctime)s %(name)-6s %(levelname)-6s %(message)s', '%y-%m-%d %H:%M:%S'))
24 | logger.addHandler(console)
25 |
26 |
27 | def configure_file_logger(work_dir):
28 | logger = logging.getLogger()
29 | file = logging.FileHandler(os.path.join(work_dir, 'root.log'))
30 | file.setFormatter(logging.Formatter('%(asctime)s %(name)-6s %(levelname)-6s %(message)s'))
31 | logger.addHandler(file)
32 |
33 |
34 | def get_logger(name, work_dir=None):
35 | logger = logging.getLogger(name)
36 | logger.setLevel(logging.DEBUG)
37 | if work_dir:
38 | file = logging.FileHandler(os.path.join(work_dir, name + '.log'))
39 | file.setLevel(logging.DEBUG)
40 | file.setFormatter(logging.Formatter('%(asctime)s %(name)-6s %(levelname)-6s %(message)s'))
41 | logger.addHandler(file)
42 | return logger
43 |
44 |
45 | def default(input_value, default_value):
46 | return default_value if input_value is None else input_value
47 |
48 |
49 | def build_run_name(args, prepend_items=None, exclude_keys=None):
50 | if prepend_items is None:
51 | prepend_items = {}
52 | if exclude_keys is None:
53 | exclude_keys = []
54 |
55 | run_name_dict = {}
56 |     # prepend_items defaults to {} above, so it is never None here
57 |     run_name_dict.update(prepend_items)
58 | for k, v in args.items():
59 | if k not in prepend_items.keys() and k not in exclude_keys and v is not None:
60 | if isinstance(v, list):
61 | v = '_'.join(str(item) for item in v)
62 | run_name_dict[k] = v
63 | return '.'.join(f'{k}_{v}' for k, v in run_name_dict.items())
64 |
65 |
66 | def ensure_clean_results_dir(results_dir):
67 | if os.path.exists(results_dir):
68 | raise ResultsDirExistsError(f'results dir "{results_dir}" exists, remove it and run the script again.')
69 | os.makedirs(results_dir, exist_ok=True)
70 |
71 |
72 | def setup_determinism(seed):
73 | if seed is None:
74 | torch.use_deterministic_algorithms(False)
75 | torch.backends.cudnn.benchmark = True
76 | else:
77 | random.seed(seed)
78 | torch.manual_seed(seed)
79 | torch.use_deterministic_algorithms(True)
80 | torch.backends.cudnn.benchmark = False
81 |
82 |
83 | def flatten_sources(batch):
84 | batch_size = batch.size(0)
85 | num_sources = batch.size(1)
86 | return batch.reshape(batch_size * num_sources, *batch.size()[2:])
87 |
88 |
89 | def unflatten_sources(batch, num_sources):
90 | batch_size = batch.size(0) // num_sources
91 | return batch.view(batch_size, num_sources, *batch.size()[1:])
92 |
93 |
94 | def shuffle_sources(batch, num_sources, shuffle_idx):
95 | batch_flat = flatten_sources(batch)
96 | return unflatten_sources(batch_flat[shuffle_idx], num_sources)
97 |
98 |
99 | def total_num_params(params):
100 | return sum(param.numel() for param in params)
101 |
102 |
103 | def soft_mask(m_pred_mag, x_true_mag):
104 | return (m_pred_mag / (m_pred_mag.sum(dim=1, keepdim=True) + EPS)) * x_true_mag
105 |
106 |
107 | def metrics_to_str(metrics):
108 | return ' '.join('{}={:.3f}'.format(k, v) for k, v in metrics.items())
109 |
110 |
111 | class MetricAccumulator:
112 | def __init__(self):
113 | self.accumulator = None
114 | self.reset()
115 |
116 | def store(self, batch):
117 | self.accumulator.append(batch.detach())
118 |
119 | def reset(self):
120 | self.accumulator = []
121 |
122 | def std_mean(self):
123 | with torch.inference_mode():
124 | epoch_values = torch.cat(self.accumulator)
125 | return torch.std_mean(epoch_values, unbiased=True)
126 |
--------------------------------------------------------------------------------
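The reshape helpers at the end of `utils.py` are easiest to see with nested lists. This is an illustrative list-based sketch of the same semantics (the torch versions operate on tensor dims 0 and 1 instead):

```python
def flatten_sources(batch):
    # (batch, sources, ...) -> (batch * sources, ...)
    return [src for sample in batch for src in sample]

def unflatten_sources(flat, num_sources):
    # inverse of flatten_sources: regroup consecutive runs of num_sources
    return [flat[i:i + num_sources] for i in range(0, len(flat), num_sources)]

batch = [['a0', 'a1'], ['b0', 'b1']]
flat = flatten_sources(batch)            # ['a0', 'a1', 'b0', 'b1']
restored = unflatten_sources(flat, 2)    # round-trips back to batch
# shuffle_sources applies an index permutation in the flat view:
shuffled = unflatten_sources([flat[i] for i in [1, 0, 3, 2]], 2)
```

Working in the flat view is what lets a single flat index list (as built by `_teacher`) permute sources both within and across mixtures at once.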
/src/sc09mix.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from types import SimpleNamespace
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from tqdm import trange
8 |
9 | from lib.utils import CONFIG_FILENAME, configure_console_logger, get_logger, default, build_run_name, \
10 | ensure_clean_results_dir, setup_determinism
11 | from lib.data.sc09 import SAMPLE_FILENAME_FORMAT, SC09
12 |
13 |
14 | class Mixer:
15 | def __init__(self, sc_dir, mix_results_root,
16 | num_train_samples=None,
17 | num_test_samples=None,
18 | num_components=None,
19 | classes=None,
20 | sample_rate=None,
21 | normalization=None,
22 | seed=None):
23 | args = locals()
24 | del args['self']
25 | run_name = build_run_name(
26 | args=args,
27 | prepend_items={'dataset': 'sc09mix'},
28 | exclude_keys=['sc_dir', 'mix_results_root']
29 | )
30 | self.results_dir = os.path.join(mix_results_root, run_name)
31 |
32 | self.config = SimpleNamespace()
33 | self.config.results_dir = self.results_dir
34 | self.config.sc_dir = sc_dir
35 | self.config.mix_results_root = mix_results_root
36 | self.config.num_train_samples = default(num_train_samples, 15000)
37 | self.config.num_test_samples = default(num_test_samples, 5000)
38 | self.config.num_components = default(num_components, 2)
39 | self.config.classes = default(classes, list(range(len(SC09.CLASSES))))
40 | self.config.sample_rate = default(sample_rate, 8000)
41 | self.config.normalization = default(normalization, 'standardize')
42 | self.config.seed = default(seed, None)
43 | self.config.sample_length = self.config.sample_rate # 1 second recordings
44 |
45 | def start(self):
46 | ensure_clean_results_dir(self.config.results_dir)
47 | setup_determinism(self.config.seed)
48 |
49 | self.logger = get_logger('sc09mix', self.config.results_dir)
50 | self.logger.info('config: %s', self.config)
51 | torch.save(self.config, os.path.join(self.config.results_dir, CONFIG_FILENAME))
52 |
53 | for partition in SC09.PARTITIONS:
54 | num_samples = self.config.num_train_samples if partition == 'training' else self.config.num_test_samples
55 | self.logger.info('generating the %s partition...', partition)
56 |
57 | mix_dir = os.path.join(self.config.results_dir, 'mix', partition)
58 | os.makedirs(mix_dir, exist_ok=True)
59 |
60 | infinite_dataloader = self._create_infinite_dataloader(partition)
61 |
62 | for sample_idx in trange(num_samples):
63 | mixture = self._generate_mixture(infinite_dataloader)
64 | torch.save(mixture, os.path.join(mix_dir, SAMPLE_FILENAME_FORMAT.format(sample_idx)))
65 |
66 | self.logger.info('completed')
67 |
68 | def _generate_mixture(self, dataloader):
69 | mixture = torch.zeros(self.config.sample_length)
70 | sources = torch.zeros(self.config.num_components, self.config.sample_length)
71 | for i in range(self.config.num_components):
72 | source, _ = next(dataloader)
73 | source = source.view(-1)
74 |
75 | if self.config.normalization == 'none':
76 | pass
77 | elif self.config.normalization == 'rms':
78 | source = source / source.square().mean().sqrt()
79 | elif self.config.normalization == 'standardize':
80 | source = (source - source.mean()) / source.std()
81 | else:
82 | raise Exception(f'unknown normalization: {self.config.normalization}')
83 |
84 | mixture += source
85 | sources[i] = source
86 |
87 | return {
88 | 'mixture': mixture,
89 | 'sources': sources,
90 | }
91 |
92 | def _create_infinite_dataloader(self, partition):
93 | dataset = SC09(
94 | root=self.config.sc_dir,
95 | partition=partition,
96 | classes=[SC09.CLASSES[class_idx] for class_idx in self.config.classes],
97 | sample_rate=self.config.sample_rate
98 | )
99 | dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
100 | while True:
101 | for sample in dataloader:
102 | yield sample
103 |
104 |
105 | if __name__ == '__main__':
106 | configure_console_logger()
107 |
108 | arg_parser = argparse.ArgumentParser()
109 | arg_parser.add_argument('--sc-dir', type=str, required=True)
110 | arg_parser.add_argument('--mix-results-root', type=str, required=True)
111 |     arg_parser.add_argument('--num-train-samples', type=int)
112 |     arg_parser.add_argument('--num-test-samples', type=int)
113 | arg_parser.add_argument('--num-components', type=int)
114 | arg_parser.add_argument('--classes', type=str, help='comma separated')
115 | arg_parser.add_argument('--sample-rate', type=int)
116 | arg_parser.add_argument('--normalization', choices=['none', 'rms', 'standardize'])
117 | arg_parser.add_argument('--seed', type=int)
118 |
119 | cmd_args = arg_parser.parse_args()
120 | if cmd_args.classes:
121 | cmd_args.classes = [int(cls) for cls in cmd_args.classes.split(',')]
122 |
123 | Mixer(**vars(cmd_args)).start()
124 |
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from types import SimpleNamespace
4 |
5 | import torch
6 |
7 | from lib.utils import CONFIG_FILENAME, BEST_CHECKPOINT_FILENAME, BEST_METRICS_FILENAME, configure_console_logger, \
8 | get_logger, default, build_run_name, ensure_clean_results_dir, setup_determinism, metrics_to_str
9 | from lib.data.dataloader_utils import get_dataset_specs, create_dataloader
10 | from lib.models import Model
11 | from lib.trainers import get_trainer
12 |
13 |
14 | class Test:
15 | def __init__(self, train_results_dir, librimix_root=None, realm_root=None,
16 | eval_method=None,
17 | eval_blind_num_repeat=None,
18 | seed=None,
19 | device_name=None):
20 | args = locals()
21 | del args['self']
22 |
23 | self.dataset_name, self.dataset_root = get_dataset_specs(librimix_root, realm_root)
24 | run_name = build_run_name(
25 | args=args,
26 | prepend_items={'test': self.dataset_name},
27 | exclude_keys=['train_results_dir', 'librimix_root', 'realm_root']
28 | )
29 | self.results_dir = os.path.join(train_results_dir, 'test', run_name)
30 |
31 | self.config = SimpleNamespace()
32 | self.config.results_dir = self.results_dir
33 | self.config.train_results_dir = train_results_dir
34 | self.config.librimix_root = librimix_root
35 | self.config.realm_root = realm_root
36 | self.config.eval_method = default(eval_method, None)
37 | self.config.eval_blind_num_repeat = default(eval_blind_num_repeat, None)
38 | self.config.seed = default(seed, None)
39 | self.config.device_name = default(device_name, 'cuda')
40 |
41 | self.logger = None
42 | self.config.device = torch.device(self.config.device_name)
43 |
44 | def start(self):
45 | ensure_clean_results_dir(self.config.results_dir)
46 | setup_determinism(self.config.seed)
47 | self.logger = get_logger('test', self.config.results_dir)
48 | self.logger.info('config: %s', self.config)
49 | torch.save(self.config, os.path.join(self.config.results_dir, CONFIG_FILENAME))
50 |
51 | model = Model.load(
52 | path=os.path.join(self.config.train_results_dir, BEST_CHECKPOINT_FILENAME),
53 | device=self.config.device
54 | ).eval()
55 |
56 | if self.config.eval_method:
57 | model.config.eval_method = self.config.eval_method
58 |
59 | if self.config.eval_blind_num_repeat:
60 | model.config.eval_blind_num_repeat = self.config.eval_blind_num_repeat
61 |
62 | if model.config.eval_method == 'blind':
63 | model_name = 'mixcycle'
64 | partition = 'validation'
65 | batch_size = 128
66 | shuffle = True
67 | elif model.config.eval_method == 'reference-valid':
68 | model_name = model.config.model_name
69 | partition = 'validation'
70 | batch_size = 128
71 | shuffle = False
72 | else:
73 | model_name = model.config.model_name
74 | partition = 'testing'
75 | batch_size = 1
76 | shuffle = False
77 |
78 | dataloader = create_dataloader(
79 | dataset_name=self.dataset_name,
80 | dataset_root=self.dataset_root,
81 | partition=partition,
82 | batch_size=batch_size,
83 | shuffle=shuffle
84 | )
85 | self.logger.info('using %d samples for evaluation', len(dataloader.dataset))
86 |
87 | trainer = get_trainer(model_name)(model=model)
88 | with torch.inference_mode():
89 | metrics = trainer.validate(dataloader)
90 |
91 | self.logger.info('[TEST] %s', metrics_to_str(metrics))
92 |
93 | torch.save(metrics, os.path.join(self.config.results_dir, BEST_METRICS_FILENAME))
94 |
95 | self.logger.info('completed')
96 |
97 |
98 | if __name__ == '__main__':
99 | configure_console_logger()
100 |
101 | arg_parser = argparse.ArgumentParser()
102 | arg_parser.add_argument('--train-results-dir', type=str, required=True)
103 | arg_parser.add_argument('--librimix-root', type=str)
104 | arg_parser.add_argument('--realm-root', type=str)
105 | arg_parser.add_argument('--eval-method', choices=['reference', 'reference-valid', 'blind'])
106 | arg_parser.add_argument('--eval-blind-num-repeat', type=int)
107 | arg_parser.add_argument('--seed', type=int)
108 | arg_parser.add_argument('--device-name', type=str)
109 |
110 | cmd_args = arg_parser.parse_args()
111 | Test(**vars(cmd_args)).start()
112 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from time import time
4 | from types import SimpleNamespace
5 |
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 | from tqdm import tqdm
9 |
10 | from lib.utils import CONFIG_FILENAME, BEST_CHECKPOINT_FILENAME, METRICS_HISTORY_FILENAME, configure_console_logger, \
11 | default, build_run_name, ensure_clean_results_dir, setup_determinism, get_logger, total_num_params
12 | from lib.data.dataloader_utils import get_dataset_specs, create_dataloader
13 | from lib.trainers import get_trainer
14 | from lib.models import Model
15 |
16 |
17 | class Training:
18 | def __init__(self, train_results_root, librimix_root=None, realm_root=None,
19 | stft_frame_size=None,
20 | stft_hop_size=None,
21 | model_name=None,
22 | model_load_path=None,
23 | mixcycle_init_epochs=None,
24 | snr_max=None,
25 | train_batch_size=None,
26 | valid_batch_size=None,
27 | lr=None,
28 | grad_clip=None,
29 | train_subsample_ratio=None,
30 | valid_subsample_ratio=None,
31 | eval_method=None,
32 | eval_blind_num_repeat=None,
33 | eval_epochs=None,
34 | patience=None,
35 | min_epochs=None,
36 | seed=None,
37 | run_id=None,
38 | device_name=None):
39 | args = locals()
40 | del args['self']
41 |
42 | self.dataset_name, self.dataset_root = get_dataset_specs(librimix_root, realm_root)
43 | run_name = build_run_name(
44 | args=args,
45 | prepend_items={'train': self.dataset_name},
46 | exclude_keys=['train_results_root', 'realm_root', 'librimix_root', 'model_load_path']
47 | )
48 | self.results_dir = os.path.join(train_results_root, run_name)
49 |
50 | self.config = SimpleNamespace()
51 | self.config.results_dir = self.results_dir
52 | self.config.stft_frame_size = default(stft_frame_size, 512)
53 | self.config.stft_hop_size = default(stft_hop_size, 128)
54 | self.config.model_name = default(model_name, 'mixcycle')
55 | self.config.model_load_path = default(model_load_path, None)
56 | self.config.mixcycle_init_epochs = default(mixcycle_init_epochs, 50)
57 | self.config.snr_max = default(snr_max, 30.0)
58 | self.config.train_batch_size = default(train_batch_size, 128)
59 | self.config.valid_batch_size = default(valid_batch_size, 128)
60 | self.config.lr = default(lr, 0.001)
61 | self.config.grad_clip = default(grad_clip, 5.0)
62 | self.config.train_subsample_ratio = default(train_subsample_ratio, 1.0)
63 | self.config.valid_subsample_ratio = default(valid_subsample_ratio, 1.0)
64 | self.config.eval_method = default(eval_method, 'reference')
65 | self.config.eval_blind_num_repeat = default(eval_blind_num_repeat, 1)
66 | self.config.eval_epochs = default(eval_epochs, 1)
67 | self.config.patience = default(patience, 50)
68 | self.config.min_epochs = default(min_epochs, None)
69 | self.config.seed = default(seed, None)
70 | self.config.device_name = default(device_name, 'cuda')
71 |
72 | def start(self):
73 | ensure_clean_results_dir(self.config.results_dir)
74 |
75 | self.logger = get_logger('train', self.config.results_dir)
76 | self.logger.info('config: %s', self.config)
77 | torch.save(self.config, os.path.join(self.config.results_dir, CONFIG_FILENAME))
78 |
79 | setup_determinism(self.config.seed)
80 |
81 | self.config.device = torch.device(self.config.device_name)
82 | self.tensorboard = SummaryWriter(os.path.join(self.config.results_dir, 'tb'))
83 |
84 | self.cur_epoch = 0
85 | self.cur_step = 0
86 | self.cur_patience = self.config.patience
87 | self.last_best = {}
88 | self.metrics_history = []
89 |
90 | self.train_dataloader = create_dataloader(
91 | dataset_name=self.dataset_name,
92 | dataset_root=self.dataset_root,
93 | partition='training',
94 | batch_size=self.config.train_batch_size,
95 | subsample_ratio=self.config.train_subsample_ratio,
96 | )
97 | self.valid_dataloader = create_dataloader(
98 | dataset_name=self.dataset_name,
99 | dataset_root=self.dataset_root,
100 | partition='validation',
101 | batch_size=self.config.valid_batch_size,
102 | subsample_ratio=self.config.valid_subsample_ratio,
103 | )
104 |
105 | _, s_true_wave = next(iter(self.train_dataloader))
106 | self.config.num_sources = s_true_wave.size(1)
107 | self.config.sample_length = s_true_wave.size(2)
108 | self.config.num_batches = len(self.train_dataloader)
109 | self.config.num_frequency_bins = 1 + self.config.stft_frame_size // 2
110 |
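The `num_frequency_bins` formula above follows from the length of a real-valued FFT: an N-sample analysis frame produces `1 + N // 2` non-redundant frequency bins. A quick standalone check with NumPy (illustrative only, not part of the repo):

```python
import numpy as np

frame_size = 512  # the default stft_frame_size above
frame = np.random.randn(frame_size)

# rfft keeps only the non-negative frequencies of a real signal,
# i.e. 1 + N // 2 bins.
spectrum = np.fft.rfft(frame)
assert len(spectrum) == 1 + frame_size // 2  # 257 bins for a 512-sample frame
```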
111 | if self.config.model_load_path:
112 | model = Model.load(
113 | path=os.path.join(self.config.model_load_path, BEST_CHECKPOINT_FILENAME),
114 | device=self.config.device
115 | ).train()
116 | else:
117 | model = None
118 |
119 | trainer_cls = get_trainer(self.config.model_name)
120 | self.trainer = trainer_cls(config=self.config, model=model)
121 |
122 | self.logger.info('model parameter count: %d', total_num_params(self.trainer.get_model().parameters()))
123 | self._train_loop()
124 | self.logger.info('completed')
125 |
126 | def _train_loop(self):
127 | while self.cur_patience > 0:
128 | for _ in range(self.config.eval_epochs):
129 | train_start = time()
130 | steps = self.trainer.train(self.train_dataloader)
131 |                 with tqdm(total=self.config.num_batches, leave=False) as progress_bar:
132 | for increment in steps:
133 | self.cur_step += 1
134 | progress_bar.update(increment)
135 | self.cur_epoch += 1
136 |
137 | metrics = {
138 | 'process': {
139 | 'epoch': self.cur_epoch,
140 | 'step': self.cur_step,
141 | 'train_elapsed': time() - train_start,
142 | 'train_loss': self.trainer.get_loss()
143 | },
144 | 'validation': {},
145 | }
146 |
147 | validate_start = time()
148 | with torch.inference_mode():
149 |             metrics['validation'] = self.trainer.validate(self.valid_dataloader)
150 | metrics['process']['validate_elapsed'] = time() - validate_start
151 |
152 | self._update_best(metrics)
153 | self._update_patience(metrics)
154 | self._report(metrics)
155 |
156 | def _update_best(self, metrics):
157 | metrics['best'] = {}
158 | for key, value in metrics['validation'].items():
159 | if key not in self.last_best or value > self.last_best[key]:
160 | self.last_best[key] = value
161 | metrics['best'][key] = True
162 | else:
163 | metrics['best'][key] = False
164 | return metrics
165 |
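The bookkeeping in `_update_best` can be exercised in isolation. A standalone sketch of the same higher-is-better logic (the free-function form and its name are hypothetical):

```python
def update_best(last_best, validation):
    # Flag each validation metric as "best" when it exceeds the running
    # maximum seen so far; a metric's first observation always counts as best.
    best = {}
    for key, value in validation.items():
        if key not in last_best or value > last_best[key]:
            last_best[key] = value
            best[key] = True
        else:
            best[key] = False
    return best

history = {}
assert update_best(history, {'sisnri': 10.0}) == {'sisnri': True}
assert update_best(history, {'sisnri': 9.5}) == {'sisnri': False}
assert update_best(history, {'sisnri': 10.5}) == {'sisnri': True}
```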
166 | def _update_patience(self, metrics):
167 |         below_min_epochs = self.config.min_epochs is not None and self.cur_epoch < self.config.min_epochs
168 |
169 |         if self.config.eval_method == 'reference':
170 |             main_metric_name = 'sisnri'
171 |         elif self.config.eval_method == 'blind':
172 |             main_metric_name = 'sisnri_blind'
173 |         else:
174 |             raise ValueError(f'unknown eval_method: {self.config.eval_method}')
175 |
176 |         if below_min_epochs or metrics['best'][main_metric_name] or metrics['best'].get('sisnri_mixit_oracle'):
177 | self.cur_patience = self.config.patience
178 | self.trainer.get_model().save(os.path.join(self.config.results_dir, BEST_CHECKPOINT_FILENAME))
179 | else:
180 | self.cur_patience -= 1
181 | metrics['process']['patience'] = self.cur_patience
182 |
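Together, `_train_loop` and `_update_patience` implement patience-based early stopping: patience resets whenever the main metric improves, and training ends once it counts down to zero. A toy standalone illustration of that semantics (the helper below is hypothetical, not from the repo):

```python
def epochs_until_stop(improved_flags, patience):
    # Simulate the patience counter: reset on improvement, decrement
    # otherwise, and stop once it hits zero.
    remaining = patience
    epochs = 0
    for improved in improved_flags:
        epochs += 1
        remaining = patience if improved else remaining - 1
        if remaining == 0:
            break
    return epochs

assert epochs_until_stop([True, False, False], patience=2) == 3
assert epochs_until_stop([False, False], patience=2) == 2
assert epochs_until_stop([True, True, True], patience=1) == 3
```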
183 | def _report(self, metrics):
184 | log_line = ''
185 | for group in ['process', 'validation']:
186 | for key, value in metrics[group].items():
187 | format_spec = '{}={:.3f}' if isinstance(value, float) else '{}={}'
188 | log_line += format_spec.format(key, value)
189 |
190 | if group == 'validation' and metrics['best'][key]:
191 | log_line += '*'
192 | else:
193 | log_line += ' '
194 |
195 | if group == 'validation':
196 | log_line += ' '
197 |
198 | self.tensorboard.add_scalar(
199 | tag='{}/{}'.format(group, key),
200 | scalar_value=value,
201 | global_step=metrics['process']['step']
202 | )
203 | if group == 'process':
204 | log_line += '| '
205 |
206 | self.logger.info(log_line)
207 |
208 | self.tensorboard.flush()
209 |
210 | metrics['time'] = time()
211 | self.metrics_history.append(metrics)
212 | metrics_history_path = os.path.join(self.config.results_dir, METRICS_HISTORY_FILENAME)
213 | os.makedirs(os.path.dirname(metrics_history_path), exist_ok=True)
214 | torch.save(self.metrics_history, metrics_history_path)
215 |
216 |
217 | if __name__ == '__main__':
218 | configure_console_logger()
219 |
220 | arg_parser = argparse.ArgumentParser()
221 | arg_parser.add_argument('--train-results-root', type=str, required=True)
222 | arg_parser.add_argument('--librimix-root', type=str)
223 | arg_parser.add_argument('--realm-root', type=str)
224 | arg_parser.add_argument('--stft-frame-size', type=int)
225 | arg_parser.add_argument('--stft-hop-size', type=int)
226 | arg_parser.add_argument('--model-name', choices=['pit', 'pit-dm', 'mixit', 'mixpit', 'mixcycle'])
227 | arg_parser.add_argument('--mixcycle-init-epochs', type=int)
228 | arg_parser.add_argument('--snr-max', type=float)
229 | arg_parser.add_argument('--train-batch-size', type=int)
230 | arg_parser.add_argument('--valid-batch-size', type=int)
231 | arg_parser.add_argument('--lr', type=float)
232 | arg_parser.add_argument('--grad-clip', type=float)
233 | arg_parser.add_argument('--train-subsample-ratio', type=float)
234 | arg_parser.add_argument('--valid-subsample-ratio', type=float)
235 | arg_parser.add_argument('--eval-method', choices=['reference', 'blind'])
236 | arg_parser.add_argument('--eval-epochs', type=int)
237 | arg_parser.add_argument('--patience', type=int)
238 | arg_parser.add_argument('--seed', type=int)
239 | arg_parser.add_argument('--run-id', type=str)
240 | arg_parser.add_argument('--device-name', type=str)
241 |
242 | cmd_args = arg_parser.parse_args()
243 | Training(**vars(cmd_args)).start()
244 |
--------------------------------------------------------------------------------