├── .gitignore ├── 01_musdb_pre_processing.py ├── 02_make_timit_phoneme_vocabulary.py ├── 03_train_BL.py ├── 04_train_informed_models.py ├── 05_train_OA.py ├── 06_copy_configs.py ├── 07_eval_alignment.py ├── 08_eval_separation.py ├── LICENSE ├── README.md ├── configs └── .gitignore ├── data ├── __init__.py ├── timit_musdb_test.py ├── timit_musdb_train.py └── timit_musdb_val.py ├── evaluation └── .gitignore ├── models ├── InformedSeparatorWithAttention.py ├── InformedSeparatorWithPerfectAttention.py ├── InformedSeparatorWithSplitAttention.py └── __init__.py ├── requirements.txt ├── sacred_experiment_logs └── .gitignore ├── tensorboard └── .gitignore ├── trained_models └── .gitignore └── utils ├── __init__.py ├── build_models.py ├── data_set_utls.py └── fct.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | # DotEnv configuration 60 | .env 61 | 62 | # Database 63 | *.db 64 | *.rdb 65 | 66 | # Pycharm 67 | .idea 68 | 69 | # VS Code 70 | .vscode/ 71 | 72 | # Spyder 73 | .spyproject/ 74 | 75 | # Jupyter NB Checkpoints 76 | .ipynb_checkpoints/ 77 | 78 | # Mac OS-specific storage files 79 | .DS_Store 80 | 81 | # vim 82 | *.swp 83 | *.swo 84 | 85 | # Mypy cache 86 | .mypy_cache/ 87 | 88 | # files I do not want to put to GitHub 89 | job_eval.sh 90 | job_eval_align.sh 91 | job_text.sh 92 | job_wst.sh 93 | job.sh 94 | my_pc_job.sh 95 | run_mfa.sh 96 | 97 | data/*.pickle 98 | tensorboard/* 99 | !tensorboard/.gitignore 100 | 101 | cluster_logs/* 102 | 103 | 104 | -------------------------------------------------------------------------------- /01_musdb_pre_processing.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import pickle 4 | 5 | import librosa as lb 6 | import numpy as np 7 | import musdb 8 | import yaml 9 | 10 | 11 | # ignore warning about unsafe loaders in pyYAML 5.1 (used in musdb) 12 | # https://github.com/yaml/pyyaml/wiki/PyYAML-yaml.load(input)-Deprecation 13 | yaml.warnings({'YAMLLoadWarning': False}) 14 | 15 | 16 | def musdb_pre_processing(path_to_musdb, path_to_save_data, target_sr, 17 | frame_length): 18 | """ 19 | This function splits all MUSDB tracks in frames of a given length, downsamples them to a given sampling rate, 20 | converts them to mono and saves each frame as .npy-file. It randomly splits the training partition into a training 21 | (80 tracks) and a validation (20 tracks) set. 
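Parameters: path_to_musdb -- path to the MUSDB18 root directory; path_to_save_data -- directory in which the 'train', 'val' and 'test' sub-folders are created; target_sr -- target sampling rate in Hz; frame_length -- frame length in samples. A pickled list of the saved file names is written to each sub-folder.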
22 | """ 23 | 24 | path_to_save_train_set = os.path.join(path_to_save_data, 'train') 25 | path_to_save_val_set = os.path.join(path_to_save_data, 'val') 26 | path_to_save_test_set = os.path.join(path_to_save_data, 'test') 27 | 28 | if not os.path.exists(path_to_save_data): 29 | os.makedirs(path_to_save_data) 30 | if not os.path.exists(path_to_save_train_set): 31 | os.makedirs(path_to_save_train_set) 32 | if not os.path.exists(path_to_save_val_set): 33 | os.makedirs(path_to_save_val_set) 34 | if not os.path.exists(path_to_save_test_set): 35 | os.makedirs(path_to_save_test_set) 36 | 37 | # load the musdb train and test partition with the parser musdb (https://github.com/sigsep/sigsep-mus-db) 38 | musdb_corpus = musdb.DB(root_dir=path_to_musdb) 39 | training_tracks = musdb_corpus.load_mus_tracks(subsets=['train']) 40 | test_tracks = musdb_corpus.load_mus_tracks(subsets=['test']) 41 | 42 | # randomly select 20 tracks from the training partition that will be the validation set 43 | all_idx = list(np.arange(0, 100)) 44 | random.seed(1) 45 | val_idx = random.sample(population=all_idx, k=20) # track indices of validation set tracks 46 | train_idx = [idx for idx in all_idx if idx not in val_idx] # track indices of training set tracks 47 | 48 | # process and save training set 49 | train_file_list = [] 50 | for idx in train_idx: 51 | 52 | track = training_tracks[idx] 53 | 54 | track_name = track.name.split('-') 55 | track_name = track_name[0][0:6] + "_" + track_name[1][1:6] 56 | track_name = track_name.replace(" ", "_") 57 | 58 | track_audio = track.targets['accompaniment'].audio 59 | track_audio_mono = lb.to_mono(track_audio.T) 60 | track_audio_mono_resampled = lb.core.resample(track_audio_mono, track.rate, target_sr) 61 | 62 | frames = lb.util.frame(y=track_audio_mono_resampled, frame_length=frame_length, hop_length=frame_length) 63 | number_of_frames = frames.shape[1] 64 | 65 | for n in range(number_of_frames): 66 | file_name = track_name + '_{}.npy'.format(n) 67 | 68 | np.save(os.path.join(path_to_save_train_set, file_name), frames[:, n]) 69 | train_file_list.append(file_name) 70 | 71 | pickle_out = open(os.path.join(path_to_save_train_set, "train_file_list.pickle"), "wb") 72 | pickle.dump(train_file_list, pickle_out) 73 | pickle_out.close() 74 | 75 | # process and save validation set 76 | val_file_list = [] 77 | for idx in val_idx: 78 | 79 | track = training_tracks[idx] 80 | 81 | track_name = track.name.split('-') 82 | track_name = track_name[0][0:6] + "_" + track_name[1][1:6] 83 | track_name = track_name.replace(" ", "_") 84 | 85 | track_audio = track.targets['accompaniment'].audio 86 | track_audio_mono = lb.to_mono(track_audio.T) 87 | track_audio_mono_resampled = lb.core.resample(track_audio_mono, track.rate, target_sr) 88 | 89 | frames = lb.util.frame(y=track_audio_mono_resampled, frame_length=frame_length, hop_length=frame_length) 90 | number_of_frames = frames.shape[1] 91 | 92 | for n in range(number_of_frames): 93 | file_name = track_name + '_{}.npy'.format(n) 94 | 95 | np.save(os.path.join(path_to_save_val_set, file_name), frames[:, n]) 96 | val_file_list.append(file_name) 97 | 98 | pickle_out = open(os.path.join(path_to_save_val_set, "val_file_list.pickle"), "wb") 99 | pickle.dump(val_file_list, pickle_out) 100 | pickle_out.close() 101 | 102 | # process and save test set 103 | test_file_list = [] 104 | for idx in range(50): 105 | 106 | track = test_tracks[idx] 107 | 108 | track_name = track.name.split('-') 109 | track_name = track_name[0][0:6] + "_" + track_name[1][1:6] 110 | track_name = 
track_name.replace(" ", "_") 111 | 112 | track_audio = track.targets['accompaniment'].audio 113 | track_audio_mono = lb.to_mono(track_audio.T) 114 | track_audio_mono_resampled = lb.core.resample(track_audio_mono, track.rate, target_sr) 115 | 116 | frames = lb.util.frame(y=track_audio_mono_resampled, frame_length=frame_length, hop_length=frame_length) 117 | number_of_frames = frames.shape[1] 118 | 119 | for n in range(number_of_frames): 120 | file_name = track_name + '_{}.npy'.format(n) 121 | 122 | np.save(os.path.join(path_to_save_test_set, file_name), frames[:, n]) 123 | test_file_list.append(file_name) 124 | 125 | pickle_out = open(os.path.join(path_to_save_test_set, "test_file_list.pickle"), "wb") 126 | pickle.dump(test_file_list, pickle_out) 127 | pickle_out.close() 128 | 129 | 130 | if __name__ == '__main__': 131 | 132 | path_to_musdb = '../Datasets/MUSDB18' 133 | path_to_save_data = '../Datasets/MUSDB_accompaniments' 134 | 135 | target_sr = 16000 136 | frame_length = 131584 137 | 138 | musdb_pre_processing(path_to_musdb, path_to_save_data, target_sr=target_sr, frame_length=frame_length) 139 | -------------------------------------------------------------------------------- /02_make_timit_phoneme_vocabulary.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | 4 | import timit_utils as tu 5 | import numpy as np 6 | 7 | 8 | timit_corpus = tu.Corpus('../Datasets/TIMIT/TIMIT/TIMIT') 9 | timit_training_set = timit_corpus.train 10 | 11 | 12 | def get_timit_train_sentence(idx): 13 | person_idx = int(np.floor(idx / 10)) 14 | person = timit_training_set.person_by_index(person_idx) 15 | sentence_idx = idx % 10 16 | sentence = person.sentence_by_index(sentence_idx) 17 | audio = sentence.raw_audio 18 | phonemes = sentence.phones_df.index.values 19 | return audio, phonemes 20 | 21 | 22 | def make_timit_vocabulary(path_to_save_files): 23 | 24 | # 0: (padding token), 1: (silence token), 2: (noise token, indicates noise in clean speech recordings) 25 | vocabulary = ['', '', ''] 26 | 27 | for idx in range(4620): 28 | 29 | audio, phonemes = get_timit_train_sentence(idx) 30 | 31 | for token in phonemes: 32 | if token not in vocabulary: 33 | vocabulary.append(token) 34 | 35 | # dictionary to translate between token and index representation of phonemes 36 | phoneme2idx = {p: int(idx) for (idx, p) in enumerate(vocabulary)} 37 | idx2phoneme = {idx: p for (idx, p) in enumerate(vocabulary)} 38 | 39 | pickle_out = open(os.path.join(path_to_save_files, "timit_vocabulary.pickle"), "wb") 40 | pickle.dump(vocabulary, pickle_out) 41 | pickle_out.close() 42 | 43 | pickle_out = open(os.path.join(path_to_save_files, "phoneme2idx.pickle"), "wb") 44 | pickle.dump(phoneme2idx, pickle_out) 45 | pickle_out.close() 46 | 47 | pickle_out = open(os.path.join(path_to_save_files, "idx2phoneme.pickle"), "wb") 48 | pickle.dump(idx2phoneme, pickle_out) 49 | pickle_out.close() 50 | 51 | print('Vocabulary: ', vocabulary) 52 | print('Vocabulary size: ', len(vocabulary)) 53 | print('phoneme2idx: ', phoneme2idx) 54 | print('idx2phoneme: ', idx2phoneme) 55 | 56 | 57 | if __name__ == '__main__': 58 | 59 | path_to_save_files = 'data/' 60 | 61 | make_timit_vocabulary(path_to_save_files) 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /03_train_BL.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch.utils.data 
import DataLoader 7 | from torchvision import transforms 8 | import numpy as np 9 | from tensorboardX import SummaryWriter 10 | from sacred import Experiment 11 | from sacred.observers import FileStorageObserver 12 | 13 | import utils.data_set_utls as utls 14 | from utils import fct 15 | from utils import build_models 16 | 17 | ex = Experiment('tisms') 18 | ex.observers.append(FileStorageObserver.create('sacred_experiment_logs')) 19 | 20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | print('Device:', device) 22 | 23 | torch.manual_seed(0) 24 | torch.cuda.manual_seed(0) 25 | 26 | @ex.config 27 | def configuration(): 28 | 29 | tag = 'BL' 30 | model = 'InformedSeparatorWithAttention' 31 | data_set = 'timit_musdb' 32 | 33 | side_info_type = 'ones' 34 | 35 | seed = 1 36 | 37 | snr_train = 'random' 38 | snr_val = -5 39 | fft_len = 512 40 | hop_len = 256 41 | window = 'hamming' 42 | 43 | batch_size_train = 32 44 | batch_size_val = 40 45 | epochs = 3000 46 | lr_switch = 200 47 | 48 | text_feature_size = 1 49 | mix_encoder_layers = 2 50 | side_info_encoder_layers = 1 51 | target_decoder_layers = 2 52 | 53 | optimizer_name = 'Adam' 54 | learning_rate = 0.0001 55 | weight_decay = 0 56 | 57 | models_directory = 'trained_models' 58 | comment = 'baseline' 59 | 60 | 61 | @ex.capture 62 | def make_model(model, fft_len, text_feature_size, mix_encoder_layers, 63 | side_info_encoder_layers, target_decoder_layers): 64 | 65 | mix_features_size = int(fft_len / 2 + 1) 66 | 67 | if model == 'InformedSeparatorWithAttention': 68 | separator = build_models.make_informed_separator_with_attention(mix_features_size, text_feature_size, 69 | mix_encoder_layers, side_info_encoder_layers, 70 | target_decoder_layers) 71 | 72 | return separator 73 | 74 | 75 | @ex.capture 76 | def experiment_tag(tag): 77 | return tag 78 | 79 | @ex.capture 80 | def make_experiment_dir(models_directory, tag): 81 | experiment_dir = os.path.join(models_directory, tag) 82 | if not os.path.exists(experiment_dir): 83 | os.mkdir(experiment_dir) 84 | return experiment_dir 85 | 86 | 87 | @ex.capture 88 | def make_data_sets(data_set, snr_train, snr_val, fft_len, hop_len, window, batch_size_train, batch_size_val): 89 | 90 | if data_set == 'timit_musdb': 91 | import data.timit_musdb_train as train_set 92 | import data.timit_musdb_val as val_set 93 | 94 | training_set = train_set.Train(transform=transforms.Compose([utls.MixSNR(snr_train), 95 | utls.StftOnFly(fft_len=fft_len, 96 | hop_len=hop_len, 97 | window=window), 98 | utls.NormalizeWithOwnMax()])) 99 | 100 | validation_set = val_set.Val(transform=transforms.Compose([utls.MixSNR(snr_val), 101 | utls.StftOnFly(fft_len=fft_len, hop_len=hop_len, 102 | window=window), 103 | utls.NormalizeWithOwnMax()])) 104 | 105 | 106 | dataloader_train = DataLoader(training_set, batch_size=batch_size_train, shuffle=True, num_workers=4, 107 | worker_init_fn=utls.worker_init_fn, collate_fn=utls.collate_with_phonemes) 108 | 109 | dataloader_val = DataLoader(validation_set, batch_size=batch_size_val, shuffle=True, num_workers=4, 110 | worker_init_fn=utls.worker_init_fn, collate_fn=utls.collate_with_phonemes) 111 | 112 | return dataloader_train, dataloader_val 113 | 114 | 115 | @ex.capture 116 | def make_optimizer(model_to_train, learning_rate, weight_decay): 117 | optimizer = torch.optim.Adam(model_to_train.parameters(), lr=learning_rate, weight_decay=weight_decay) 118 | return optimizer 119 | 120 | 121 | @ex.capture 122 | def config2main(epochs, lr_switch): 123 | return epochs, lr_switch 124 | 
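# This script is a sacred experiment, so any value defined in configuration() can be
# overridden from the command line via sacred's "with" syntax, for example
# (hypothetical values):
#     python 03_train_BL.py with tag=BL_run2 learning_rate=0.0001 epochs=1000
# Running "python 03_train_BL.py print_config" prints the resulting configuration
# without starting the training.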
125 | 126 | @ex.automain 127 | def train_model(): 128 | 129 | epochs, lr_switch = config2main() 130 | 131 | tag = experiment_tag() 132 | 133 | writer = SummaryWriter(logdir=os.path.join('tensorboard', tag)) 134 | 135 | experiment_dir = make_experiment_dir() 136 | 137 | dataloader_train, dataloader_val = make_data_sets() 138 | 139 | model_to_train = make_model() 140 | 141 | model_to_train.to(device) 142 | 143 | optimizer = make_optimizer(model_to_train) 144 | 145 | def factor_fn(epoch): 146 | if epoch < lr_switch: 147 | factor = 0.1 148 | else: 149 | factor = 1 150 | return factor 151 | 152 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, factor_fn, last_epoch=-1) 153 | 154 | loss_fn = nn.L1Loss(reduction='sum') 155 | 156 | best_val_cost = 1000000000 157 | counter = 0 158 | 159 | for i in range(epochs): 160 | cost = 0. 161 | val_cost = 0. 162 | np.random.seed(1) 163 | 164 | # training loop 165 | for i_batch, sample_batched in enumerate(dataloader_train): 166 | fake_side_info = fct.make_fake_side_info(sample_batched['target']) 167 | 168 | batch_cost = fct.train_with_attention(model_to_train, loss_fn, optimizer, 169 | sample_batched['mix'].to(device), fake_side_info.to(device), 170 | sample_batched['target'].to(device)) 171 | 172 | cost += batch_cost 173 | 174 | writer.add_scalar('Train_cost', cost, i + 1) 175 | 176 | # validation loop 177 | for i_batchV, sample_batchedV in enumerate(dataloader_val): 178 | fake_side_infoV = fct.make_fake_side_info(sample_batchedV['target']) 179 | 180 | prediction, _ = fct.predict_with_attention(model_to_train.to(device), sample_batchedV['mix'].to(device), 181 | fake_side_infoV.to(device)) 182 | val_loss = loss_fn(prediction, sample_batchedV['target'].to(device)) 183 | val_cost += val_loss.item() 184 | 185 | print("Epoch: {}, Training cost: {} Validation cost: {}".format(i + 1, cost, val_cost)) 186 | writer.add_scalar('Validation_cost', val_cost, i + 1) 187 | 188 | scheduler.step() 189 | 190 | if val_cost < best_val_cost: 191 | best_val_cost = val_cost 192 | counter = 0 193 | 194 | print('Epoch: ', i + 1, 'val cost: ', best_val_cost) 195 | 196 | torch.save({ 197 | 'experiment_tag': tag, 198 | 'epoch': i, 199 | 'model_state_dict': model_to_train.state_dict(), 200 | 'optimizer_state_dict': optimizer.state_dict(), 201 | 'train_cost': cost, 202 | }, os.path.join(experiment_dir, 'model_best_val_cost.pt')) 203 | 204 | print("Model has been saved!") 205 | 206 | elif val_cost > best_val_cost: 207 | counter += 1 208 | 209 | if counter > 200: 210 | 211 | print("No improvement of validation cost for 200 epochs") 212 | 213 | break 214 | 215 | torch.save({ 216 | 'experiment_tag': tag, 217 | 'epoch': i, 218 | 'model_state_dict': model_to_train.state_dict(), 219 | 'optimizer_state_dict': optimizer.state_dict(), 220 | 'train_cost': cost, 221 | }, os.path.join(experiment_dir, 'model_last_train_epoch.pt')) 222 | 223 | -------------------------------------------------------------------------------- /04_train_informed_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms 8 | import numpy as np 9 | from tensorboardX import SummaryWriter 10 | from sacred import Experiment 11 | from sacred.observers import FileStorageObserver 12 | 13 | import utils.data_set_utls as utls 14 | from utils import fct 15 | from utils import build_models 16 | 17 | 18 | ex = Experiment('tisms') 19 | 
ex.observers.append(FileStorageObserver.create('sacred_experiment_logs')) 20 | 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | print('Device:', device) 23 | 24 | torch.manual_seed(0) 25 | torch.cuda.manual_seed(0) 26 | 27 | 28 | @ex.config 29 | def configuration(): 30 | 31 | tag = 'test' 32 | model = 'InformedSeparatorWithAttention' 33 | data_set = 'timit_musdb' 34 | 35 | seed = 1 36 | 37 | snr_train = 'random' 38 | snr_val = -5 39 | fft_len = 512 40 | hop_len = 256 41 | window = 'hamming' 42 | 43 | loss_function = 'L1' 44 | batch_size_train = 32 45 | batch_size_val = 40 46 | epochs = 3000 47 | lr_switch1 = 0 48 | lr_switch2 = 200 49 | 50 | text_feature_size = 63 51 | vocabulary_size = 63 52 | mix_encoder_layers = 2 53 | side_info_encoder_layers = 1 54 | side_info_encoder_bidirectional = True 55 | target_decoder_layers = 2 56 | 57 | optimizer_name = 'Adam' 58 | learning_rate = 0.0001 59 | weight_decay = 0 60 | 61 | models_directory = 'trained_models' 62 | 63 | comment = 'informed model' 64 | 65 | 66 | @ex.capture 67 | def make_model(model, fft_len, mix_encoder_layers, side_info_encoder_layers, target_decoder_layers, 68 | side_info_encoder_bidirectional): 69 | 70 | mix_features_size = int(fft_len / 2 + 1) 71 | 72 | text_feature_size = 63 73 | 74 | if model == 'InformedSeparatorWithAttention': 75 | separator = build_models.make_informed_separator_with_attention(mix_features_size, 76 | text_feature_size, 77 | mix_encoder_layers, 78 | side_info_encoder_layers, 79 | target_decoder_layers, 80 | side_info_encoder_bidirectional) 81 | 82 | elif model == 'InformedSeparatorWithSplitAttention': 83 | separator = build_models.make_informed_separator_with_split_attention(mix_features_size, 84 | text_feature_size, 85 | mix_encoder_layers, 86 | side_info_encoder_layers, 87 | target_decoder_layers, 88 | side_info_encoder_bidirectional) 89 | 90 | return separator 91 | 92 | 93 | @ex.capture 94 | def experiment_tag(tag): 95 | return tag 96 | 97 | 98 | @ex.capture 99 | def make_experiment_dir(models_directory, tag): 100 | experiment_dir = os.path.join(models_directory, tag) 101 | if not os.path.exists(experiment_dir): 102 | os.mkdir(experiment_dir) 103 | return experiment_dir 104 | 105 | 106 | @ex.capture 107 | def make_data_sets(data_set, snr_train, snr_val, fft_len, hop_len, window, batch_size_train, batch_size_val): 108 | 109 | if data_set == 'timit_musdb': 110 | import data.timit_musdb_train as train_set 111 | import data.timit_musdb_val as val_set 112 | 113 | training_set = train_set.Train(transform=transforms.Compose([utls.MixSNR(snr_train), 114 | utls.StftOnFly(fft_len=fft_len, 115 | hop_len=hop_len, 116 | window=window), 117 | utls.NormalizeWithOwnMax()])) 118 | 119 | timit_musdb_val = val_set.Val(transform=transforms.Compose([utls.MixSNR(snr_val), 120 | utls.StftOnFly(fft_len=fft_len, hop_len=hop_len, 121 | window=window), 122 | utls.NormalizeWithOwnMax()])) 123 | 124 | dataloader_train = DataLoader(training_set, batch_size=batch_size_train, shuffle=True, num_workers=4, 125 | worker_init_fn=utls.worker_init_fn, collate_fn=utls.collate_with_phonemes) 126 | 127 | dataloader_val = DataLoader(timit_musdb_val, batch_size=batch_size_val, shuffle=True, num_workers=4, 128 | worker_init_fn=utls.worker_init_fn, collate_fn=utls.collate_with_phonemes) 129 | 130 | return dataloader_train, dataloader_val 131 | 132 | 133 | @ex.capture 134 | def idx2onehot(phonemes_batched): 135 | 136 | """ 137 | 138 | Parameters 139 | ---------- 140 | phonemes_batched: sequence of phoneme indices, 
shape: (batch size, sequence length) 141 | vocabulary_size: int 142 | 143 | Returns 144 | ------- 145 | 146 | """ 147 | vocabulary_size = 63 148 | 149 | batch_of_one_hot_sentences = [] 150 | for sentence_idx in range(phonemes_batched.size()[0]): 151 | sentence = phonemes_batched[sentence_idx, :] 152 | sentence_one_hot_encoded = utls.idx2one_hot(sentence, vocabulary_size) 153 | batch_of_one_hot_sentences.append(sentence_one_hot_encoded) 154 | batch_of_one_hot_sentences = torch.from_numpy(np.asarray(batch_of_one_hot_sentences)).type(torch.float32) 155 | return batch_of_one_hot_sentences 156 | 157 | 158 | @ex.capture 159 | def make_optimizer(model_to_train, learning_rate, weight_decay): 160 | optimizer = torch.optim.Adam(model_to_train.parameters(), lr=learning_rate, weight_decay=weight_decay) 161 | return optimizer 162 | 163 | 164 | @ex.capture 165 | def config2main(epochs, lr_switch1, lr_switch2, loss_function): 166 | return epochs, lr_switch1, lr_switch2, loss_function 167 | 168 | 169 | @ex.automain 170 | def train_model(): 171 | 172 | epochs, lr_switch1, lr_switch2, loss_function = config2main() 173 | 174 | tag = experiment_tag() 175 | 176 | writer = SummaryWriter(logdir=os.path.join('tensorboard', tag)) 177 | 178 | experiment_dir = make_experiment_dir() 179 | 180 | dataloader_train, dataloader_val = make_data_sets() 181 | 182 | model_to_train = make_model() 183 | 184 | model_to_train.to(device) 185 | 186 | optimizer = make_optimizer(model_to_train) 187 | 188 | def factor_fn(epoch): 189 | if epoch < lr_switch1: 190 | factor = 0.01 191 | elif epoch < lr_switch2: 192 | factor = 0.1 193 | else: 194 | factor = 1 195 | return factor 196 | 197 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, factor_fn, last_epoch=-1) 198 | 199 | if loss_function == 'L1': 200 | loss_fn = nn.L1Loss(reduction='sum') 201 | 202 | best_val_cost = 1000000000 203 | counter = 0 204 | 205 | for i in range(epochs): 206 | cost = 0. 207 | val_cost = 0. 
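# cost and val_cost accumulate the summed L1 losses over all training and validation
# batches of the current epoch; both totals are logged to tensorboard below.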
208 | np.random.seed(1) 209 | 210 | # training loop 211 | for i_batch, sample_batched in enumerate(dataloader_train): 212 | 213 | phonemes_idx = sample_batched['phonemes'] 214 | 215 | # output has shape (batch_size, sequence_len, vocabulary_size) 216 | one_hot_phonemes = idx2onehot(phonemes_idx) 217 | 218 | batch_cost = fct.train_with_attention(model_to_train, loss_fn, optimizer, 219 | sample_batched['mix'].to(device), one_hot_phonemes.to(device), 220 | sample_batched['target'].to(device)) 221 | 222 | cost += batch_cost 223 | 224 | writer.add_scalar('Train_cost', cost, i + 1) 225 | 226 | # validation loop 227 | for i_batchV, sample_batchedV in enumerate(dataloader_val): 228 | 229 | phonemes_idxV = sample_batchedV['phonemes'] 230 | 231 | # output has shape (batch_size, sequence_len, vocabulary_size) 232 | one_hot_phonemesV = idx2onehot(phonemes_idxV) 233 | 234 | with torch.no_grad(): 235 | 236 | prediction, alphas = fct.predict_with_attention(model_to_train.to(device), 237 | sample_batchedV['mix'].to(device), 238 | one_hot_phonemesV.to(device)) 239 | 240 | val_loss = loss_fn(prediction, sample_batchedV['target'].to(device)) 241 | 242 | val_cost += val_loss.item() 243 | 244 | print("Epoch: {}, Training cost: {} Validation cost: {}".format(i + 1, cost, val_cost)) 245 | writer.add_scalar('Validation_cost', val_cost, i + 1) 246 | 247 | scheduler.step() 248 | 249 | if i+1 % 100 == 0: 250 | torch.save({ 251 | 'experiment_tag': tag, 252 | 'epoch': i, 253 | 'model_state_dict': model_to_train.state_dict(), 254 | 'optimizer_state_dict': optimizer.state_dict(), 255 | 'train_cost': cost, 256 | }, os.path.join(experiment_dir, 'model_epoch_{}.pt'.format(i+1))) 257 | 258 | if val_cost < best_val_cost: 259 | best_val_cost = val_cost 260 | counter = 0 261 | 262 | print('Epoch: ', i + 1, 'val cost: ', best_val_cost) 263 | 264 | torch.save({ 265 | 'experiment_tag': tag, 266 | 'epoch': i, 267 | 'model_state_dict': model_to_train.state_dict(), 268 | 'optimizer_state_dict': optimizer.state_dict(), 269 | 'train_cost': cost, 270 | }, os.path.join(experiment_dir, 'model_best_val_cost.pt')) 271 | 272 | print("Model has been saved!") 273 | 274 | elif val_cost > best_val_cost: 275 | counter += 1 276 | 277 | if counter > 200: 278 | 279 | print("No improvement of validation cost for {} epochs".format(counter)) 280 | 281 | break 282 | 283 | torch.save({ 284 | 'experiment_tag': tag, 285 | 'epoch': i, 286 | 'model_state_dict': model_to_train.state_dict(), 287 | 'optimizer_state_dict': optimizer.state_dict(), 288 | 'train_cost': cost, 289 | }, os.path.join(experiment_dir, 'model_last_train_epoch.pt')) 290 | 291 | -------------------------------------------------------------------------------- /05_train_OA.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch.utils.data import DataLoader 7 | from torchvision import transforms 8 | import numpy as np 9 | from tensorboardX import SummaryWriter 10 | from sacred import Experiment 11 | from sacred.observers import FileStorageObserver 12 | 13 | import utils.data_set_utls as utls 14 | from utils import fct 15 | from utils import build_models 16 | 17 | 18 | ex = Experiment('tisms') 19 | ex.observers.append(FileStorageObserver.create('sacred_experiment_logs')) 20 | 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | print('Device:', device) 23 | 24 | torch.manual_seed(0) 25 | torch.cuda.manual_seed(0) 26 | 27 | 28 | @ex.config 29 | def 
configuration(): 30 | 31 | tag = 'OA' 32 | model = 'InformedSeparatorWithPerfectAttention' 33 | data_set = 'timit_musdb' 34 | 35 | seed = 1 36 | 37 | snr_train = 'random' 38 | snr_val = -5 39 | fft_len = 512 40 | hop_len = 256 41 | window = 'hamming' 42 | 43 | loss_function = 'L1' 44 | batch_size_train = 32 45 | batch_size_val = 40 46 | epochs = 3000 47 | lr_switch1 = 0 48 | lr_switch2 = 200 49 | 50 | text_feature_size = 63 51 | vocabulary_size = 63 52 | mix_encoder_layers = 2 53 | side_info_encoder_layers = 1 54 | side_info_encoder_bidirectional = True 55 | target_decoder_layers = 2 56 | 57 | perfect_alphas = True 58 | 59 | optimizer_name = 'Adam' 60 | learning_rate = 0.0001 61 | weight_decay = 0 62 | 63 | models_directory = 'trained_models' 64 | 65 | comment = 'optimal attention weights' 66 | 67 | 68 | @ex.capture 69 | def make_model(model, fft_len, text_feature_size, mix_encoder_layers, 70 | side_info_encoder_layers, target_decoder_layers, side_info_encoder_bidirectional=True): 71 | 72 | mix_features_size = int(fft_len / 2 + 1) 73 | 74 | if model == 'InformedSeparatorWithAttention': 75 | separator = build_models.make_informed_separator_with_attention(mix_features_size, 76 | text_feature_size, 77 | mix_encoder_layers, 78 | side_info_encoder_layers, 79 | target_decoder_layers, 80 | side_info_encoder_bidirectional) 81 | 82 | elif model == 'InformedSeparatorWithPerfectAttention': 83 | separator = build_models.make_informed_separator_with_perfect_attention(mix_features_size, 84 | text_feature_size, 85 | mix_encoder_layers, 86 | side_info_encoder_layers, 87 | target_decoder_layers, 88 | side_info_encoder_bidirectional) 89 | 90 | return separator 91 | 92 | 93 | @ex.capture 94 | def experiment_tag(tag): 95 | return tag 96 | 97 | 98 | @ex.capture 99 | def make_experiment_dir(models_directory, tag): 100 | experiment_dir = os.path.join(models_directory, tag) 101 | if not os.path.exists(experiment_dir): 102 | os.mkdir(experiment_dir) 103 | return experiment_dir 104 | 105 | 106 | @ex.capture 107 | def make_data_sets(data_set, snr_train, snr_val, fft_len, hop_len, window, batch_size_train, batch_size_val, perfect_alphas): 108 | 109 | if data_set == 'timit_musdb': 110 | import data.timit_musdb_train as train_set 111 | import data.timit_musdb_val as val_set 112 | 113 | if perfect_alphas: 114 | timit_musdb_train = train_set.Train(transform=transforms.Compose([utls.MixSNR(snr_train), 115 | utls.StftOnFly(fft_len=fft_len, 116 | hop_len=hop_len, 117 | window=window), 118 | utls.NormalizeWithOwnMax(), 119 | utls.MakePerfectAttentionWeights(fft_len, hop_len)])) 120 | 121 | timit_musdb_val = val_set.Val(transform=transforms.Compose([utls.MixSNR(snr_val), 122 | utls.StftOnFly(fft_len=fft_len, hop_len=hop_len, 123 | window=window), 124 | utls.NormalizeWithOwnMax(), 125 | utls.MakePerfectAttentionWeights(fft_len, hop_len)])) 126 | 127 | 128 | else: 129 | timit_musdb_train = train_set.Train(transform=transforms.Compose([utls.MixSNR(snr_train), 130 | utls.StftOnFly(fft_len=fft_len, 131 | hop_len=hop_len, 132 | window=window), 133 | utls.NormalizeWithOwnMax()])) 134 | 135 | timit_musdb_val = val_set.Val(transform=transforms.Compose([utls.MixSNR(snr_val), 136 | utls.StftOnFly(fft_len=fft_len, hop_len=hop_len, 137 | window=window), 138 | utls.NormalizeWithOwnMax()])) 139 | 140 | 141 | dataloader_train = DataLoader(timit_musdb_train, batch_size=batch_size_train, shuffle=True, num_workers=4, 142 | worker_init_fn=utls.worker_init_fn, collate_fn=utls.collate_with_phonemes) 143 | 144 | dataloader_val = 
DataLoader(timit_musdb_val, batch_size=batch_size_val, shuffle=True, num_workers=4, 145 | worker_init_fn=utls.worker_init_fn, collate_fn=utls.collate_with_phonemes) 146 | 147 | return dataloader_train, dataloader_val 148 | 149 | 150 | @ex.capture 151 | def idx2onehot(phonemes_batched, vocabulary_size): 152 | 153 | """ 154 | 155 | Parameters 156 | ---------- 157 | phonemes_batched: sequence of phoneme indices, shape: (batch size, sequence length) 158 | vocabulary_size: int 159 | 160 | Returns 161 | ------- 162 | 163 | """ 164 | batch_of_one_hot_sentences = [] 165 | for sentence_idx in range(phonemes_batched.size()[0]): 166 | sentence = phonemes_batched[sentence_idx, :] 167 | sentence_one_hot_encoded = utls.idx2one_hot(sentence, vocabulary_size) 168 | batch_of_one_hot_sentences.append(sentence_one_hot_encoded) 169 | batch_of_one_hot_sentences = torch.from_numpy(np.asarray(batch_of_one_hot_sentences)).type(torch.float32) 170 | return batch_of_one_hot_sentences 171 | 172 | 173 | @ex.capture 174 | def make_optimizer(model_to_train, learning_rate, weight_decay): 175 | optimizer = torch.optim.Adam(model_to_train.parameters(), lr=learning_rate, weight_decay=weight_decay) 176 | return optimizer 177 | 178 | 179 | @ex.capture 180 | def config2main(epochs, lr_switch1, lr_switch2, loss_function): 181 | return epochs, lr_switch1, lr_switch2, loss_function 182 | 183 | 184 | @ex.automain 185 | def train_model(): 186 | 187 | epochs, lr_switch1, lr_switch2, loss_function = config2main() 188 | 189 | tag = experiment_tag() 190 | 191 | writer = SummaryWriter(logdir=os.path.join('tensorboard', tag)) 192 | 193 | experiment_dir = make_experiment_dir() 194 | 195 | dataloader_train, dataloader_val = make_data_sets() 196 | 197 | model_to_train = make_model() 198 | 199 | optimizer = make_optimizer(model_to_train) 200 | 201 | def factor_fn(epoch): 202 | if epoch < lr_switch1: 203 | factor = 0.01 204 | elif epoch < lr_switch2: 205 | factor = 0.1 206 | else: 207 | factor = 1 208 | return factor 209 | 210 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, factor_fn, last_epoch=-1) 211 | 212 | loss_fn = nn.L1Loss(reduction='sum') 213 | 214 | best_val_cost = 1000000000 215 | 216 | counter = 0 217 | 218 | for i in range(epochs): 219 | cost = 0. 220 | val_cost = 0. 
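# In this experiment each sample also carries oracle attention weights
# ('perfect_alphas', produced by the MakePerfectAttentionWeights transform); they are
# passed to the model together with the one-hot phoneme sequence
# (cf. InformedSeparatorWithPerfectAttention and the tag 'OA', 'optimal attention weights').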
221 | np.random.seed(1) 222 | 223 | # training loop 224 | for i_batch, sample_batched in enumerate(dataloader_train): 225 | 226 | phonemes_idx = sample_batched['phonemes'] 227 | 228 | alphas = sample_batched['perfect_alphas'] 229 | 230 | # output has shape (batch_size, sequence_len, vocabulary_size) 231 | one_hot_phonemes = idx2onehot(phonemes_idx) 232 | 233 | batch_cost = fct.train_with_perfect_attention(model_to_train.to(device), loss_fn, optimizer, 234 | sample_batched['mix'].to(device), one_hot_phonemes.to(device), 235 | sample_batched['target'].to(device), alphas.to(device)) 236 | 237 | cost += batch_cost 238 | 239 | writer.add_scalar('Train_cost', cost, i + 1) 240 | 241 | # validation loop 242 | for i_batchV, sample_batchedV in enumerate(dataloader_val): 243 | 244 | phonemes_idxV = sample_batchedV['phonemes'] 245 | 246 | # output has shape (batch_size, sequence_len, vocabulary_size) 247 | one_hot_phonemesV = idx2onehot(phonemes_idxV) 248 | 249 | with torch.no_grad(): 250 | 251 | prediction, alphas = fct.predict_with_perfect_attention(model_to_train.to(device), 252 | sample_batchedV['mix'].to(device), 253 | one_hot_phonemesV.to(device), 254 | sample_batchedV['perfect_alphas'].to(device)) 255 | 256 | val_loss = loss_fn(prediction, sample_batchedV['target'].to(device)) 257 | 258 | val_cost += val_loss.item() 259 | 260 | print("Epoch: {}, Training cost: {} Validation cost: {}".format(i + 1, cost, val_cost)) 261 | writer.add_scalar('Validation_cost', val_cost, i + 1) 262 | 263 | scheduler.step() 264 | 265 | if i+1 % 100 == 0: 266 | torch.save({ 267 | 'experiment_tag': tag, 268 | 'epoch': i, 269 | 'model_state_dict': model_to_train.state_dict(), 270 | 'optimizer_state_dict': optimizer.state_dict(), 271 | 'train_cost': cost, 272 | }, os.path.join(experiment_dir, 'model_epoch_{}.pt'.format(i+1))) 273 | 274 | if val_cost < best_val_cost: 275 | best_val_cost = val_cost 276 | counter = 0 277 | 278 | print('Epoch: ', i + 1, 'val cost: ', best_val_cost) 279 | 280 | torch.save({ 281 | 'experiment_tag': tag, 282 | 'epoch': i, 283 | 'model_state_dict': model_to_train.state_dict(), 284 | 'optimizer_state_dict': optimizer.state_dict(), 285 | 'train_cost': cost, 286 | }, os.path.join(experiment_dir, 'model_best_val_cost.pt')) 287 | 288 | print("Model has been saved!") 289 | 290 | elif val_cost > best_val_cost: 291 | counter += 1 292 | 293 | if counter > 200: 294 | 295 | print("No improvement of validation cost for {} epochs".format(counter)) 296 | 297 | break 298 | 299 | torch.save({ 300 | 'experiment_tag': tag, 301 | 'epoch': i, 302 | 'model_state_dict': model_to_train.state_dict(), 303 | 'optimizer_state_dict': optimizer.state_dict(), 304 | 'train_cost': cost, 305 | }, os.path.join(experiment_dir, 'model_last_train_epoch.pt')) 306 | 307 | -------------------------------------------------------------------------------- /06_copy_configs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from shutil import copyfile 4 | 5 | def create_tag_named_folder(original_path, new_path, id): 6 | 7 | original_experiment_dir = os.path.join(original_path, '{}'.format(id)) 8 | 9 | config_dict = json.load(open(os.path.join(original_experiment_dir, 'config.json'.format(id)))) 10 | 11 | tag = config_dict['tag'] 12 | 13 | new_experiment_directory = os.path.join(new_path, tag) 14 | 15 | if not os.path.exists(new_experiment_directory): 16 | os.mkdir(new_experiment_directory) 17 | 18 | copyfile(os.path.join(original_experiment_dir, 'config.json'), 19 | 
os.path.join(new_experiment_directory, 'config.json')) 20 | copyfile(os.path.join(original_experiment_dir, 'cout.txt'), 21 | os.path.join(new_experiment_directory, 'cout.txt')) 22 | copyfile(os.path.join(original_experiment_dir, 'metrics.json'), 23 | os.path.join(new_experiment_directory, 'metrics.json')) 24 | copyfile(os.path.join(original_experiment_dir, 'run.json'), 25 | os.path.join(new_experiment_directory, 'run.json')) 26 | 27 | 28 | if __name__ == '__main__': 29 | 30 | original_path = 'sacred_experiment_logs' 31 | new_path = 'configs' 32 | 33 | exp = sorted(os.listdir(original_path)) 34 | exp.remove('_sources') 35 | exp.remove('.gitignore') 36 | 37 | for exp_id in exp: 38 | 39 | create_tag_named_folder(original_path, new_path, int(exp_id)) 40 | 41 | 42 | -------------------------------------------------------------------------------- /07_eval_alignment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | from torchvision import transforms 6 | import numpy as np 7 | from sacred import Experiment 8 | 9 | import utils.data_set_utls as utls 10 | from utils import build_models 11 | from utils import fct 12 | 13 | ex = Experiment('eval_alignment') 14 | 15 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | print('Device:', device) 17 | 18 | torch.manual_seed(0) 19 | torch.cuda.manual_seed(0) 20 | 21 | @ex.config 22 | def configuration(): 23 | 24 | tag = 'test' 25 | model_state_dict_name = 'model_best_val_cost.pt' 26 | 27 | eval_dir = 'evaluation' 28 | 29 | ex.add_config('configs/{}/config.json'.format(tag)) 30 | test_snr = -5 31 | 32 | 33 | @ex.capture 34 | def make_data_set(data_set, test_snr, fft_len, hop_len, window): 35 | 36 | if data_set == 'timit_musdb': 37 | import data.timit_musdb_test as test_set 38 | 39 | timit_musdb_test = test_set.Test(transform=transforms.Compose([utls.MixSNR(test_snr), 40 | utls.StftOnFly_testset(fft_len=fft_len, 41 | hop_len=hop_len, 42 | window=window), 43 | utls.NormalizeWithOwnMax()])) 44 | 45 | return timit_musdb_test 46 | 47 | 48 | @ex.capture 49 | def make_model(model, fft_len, text_feature_size, mix_encoder_layers, side_info_encoder_layers, target_decoder_layers, 50 | side_info_encoder_bidirectional): 51 | 52 | mix_features_size = int(fft_len / 2 + 1) 53 | 54 | if model == 'InformedSeparatorWithAttention': 55 | separator = build_models.make_informed_separator_with_attention(mix_features_size, 56 | text_feature_size, 57 | mix_encoder_layers, 58 | side_info_encoder_layers, 59 | target_decoder_layers, 60 | side_info_encoder_bidirectional) 61 | 62 | elif model == 'InformedSeparatorWithSplitAttention': 63 | separator = build_models.make_informed_separator_with_split_attention(mix_features_size, 64 | text_feature_size, 65 | mix_encoder_layers, 66 | side_info_encoder_layers, 67 | target_decoder_layers, 68 | side_info_encoder_bidirectional) 69 | 70 | return separator 71 | 72 | 73 | @ex.capture 74 | def load_state_dict(models_directory, tag, model_state_dict_name): 75 | checkpoint = torch.load(os.path.join(models_directory, tag, model_state_dict_name)) 76 | return checkpoint['model_state_dict'], checkpoint['experiment_tag'] 77 | 78 | 79 | @ex.capture 80 | def config2main(tag, test_snr, fft_len, hop_len, window, text_feature_size, vocabulary_size): 81 | 82 | return tag, test_snr, fft_len, hop_len, window, text_feature_size, vocabulary_size 83 | 84 | 85 | @ex.capture 86 | def make_eval_dir(eval_dir, tag, test_snr): 87 | model_eval_dir = 
os.path.join(eval_dir, tag + "_snr{}".format(test_snr)) 88 | if not os.path.exists(model_eval_dir): 89 | os.mkdir(model_eval_dir) 90 | return model_eval_dir 91 | 92 | 93 | @ex.automain 94 | def eval_model(): 95 | 96 | tag, test_snr, fft_len, hop_len, window, text_feature_size, vocabulary_size = config2main() 97 | 98 | test_set = make_data_set() 99 | 100 | model_to_evaluate = make_model() 101 | 102 | state_dict, training_tag = load_state_dict() 103 | 104 | model_to_evaluate.load_state_dict(state_dict) 105 | 106 | model_to_evaluate.to(device) 107 | 108 | ae_all_snippets = [] # alignment error 109 | 110 | num_predicted_phonemes = 0 111 | num_phonemes_in_10ms_window = 0 112 | num_phonemes_in_20ms_window = 0 113 | num_phonemes_in_25ms_window = 0 114 | num_phonemes_in_50ms_window = 0 115 | num_phonemes_in_75ms_window = 0 116 | num_phonemes_in_100ms_window = 0 117 | num_phonemes_in_200ms_window = 0 118 | 119 | for i in range(len(test_set)): 120 | sample = test_set[i] 121 | 122 | mix_spec = sample['mix'] 123 | phoneme_times = sample['phoneme_times'] # start of phonemes, last number=last phoneme's end (rel. speech_start) 124 | speech_start = sample['speech_start'] # start of the speech recording in the mix (!= first phoneme start) 125 | 126 | phonemes_idx = torch.from_numpy(sample['phonemes']) 127 | 128 | # output has shape (batch_size, sequence_len, vocabulary_size) 129 | side_info = torch.from_numpy(utls.idx2one_hot(phonemes_idx, vocabulary_size), ).type(torch.float32) 130 | side_info = torch.unsqueeze(side_info, dim=0) 131 | 132 | with torch.no_grad(): 133 | predicted_speech_spec, alphas = fct.predict_with_attention(model_to_evaluate, 134 | torch.unsqueeze(torch.from_numpy(mix_spec), dim=0).to(device), 135 | side_info=side_info.to(device)) 136 | 137 | alphas = alphas.detach().cpu().numpy()[0, :, :].T 138 | 139 | phoneme_idx_sequence, phoneme_onsets = fct.viterbi_alignment_from_attention(alphas, hop_len) 140 | phoneme_onsets = np.asarray(phoneme_onsets) + hop_len/32000 141 | phoneme_onsets_truth = np.asarray([(x+speech_start)/16000 for x in phoneme_times][:-1]) # delete end time of last phoneme 142 | 143 | number_of_phonemes = len(phoneme_onsets_truth) 144 | absolute_errors = abs(phoneme_onsets_truth - phoneme_onsets) 145 | absolute_error_snippet = np.mean(abs(phoneme_onsets_truth - phoneme_onsets)) 146 | 147 | # ae_all_snippets.append(absolute_error_snippet) 148 | ae_all_snippets.append(absolute_error_snippet) 149 | 150 | # compute % correct phonemes within a tolerance 151 | correct_phonemes_in_10ms_window = (absolute_errors < 0.01).sum() 152 | num_phonemes_in_10ms_window += correct_phonemes_in_10ms_window 153 | correct_phonemes_in_20ms_window = (absolute_errors < 0.02).sum() 154 | num_phonemes_in_20ms_window += correct_phonemes_in_20ms_window 155 | correct_phonemes_in_25ms_window = (absolute_errors < 0.025).sum() 156 | num_phonemes_in_25ms_window += correct_phonemes_in_25ms_window 157 | correct_phonemes_in_50ms_window = (absolute_errors < 0.05).sum() 158 | num_phonemes_in_50ms_window += correct_phonemes_in_50ms_window 159 | correct_phonemes_in_75ms_window = (absolute_errors < 0.075).sum() 160 | num_phonemes_in_75ms_window += correct_phonemes_in_75ms_window 161 | correct_phonemes_in_100ms_window = (absolute_errors < 0.1).sum() 162 | num_phonemes_in_100ms_window += correct_phonemes_in_100ms_window 163 | correct_phonemes_in_200ms_window = (absolute_errors < 0.2).sum() 164 | num_phonemes_in_200ms_window += correct_phonemes_in_200ms_window 165 | 166 | num_predicted_phonemes += number_of_phonemes 167 
| 168 | mean_abs_error_mean = np.mean(np.asarray(ae_all_snippets)) 169 | mean_abs_error_median = np.median(np.asarray(ae_all_snippets)) 170 | 171 | percent_correct_in_10ms_tolerance = num_phonemes_in_10ms_window / num_predicted_phonemes 172 | percent_correct_in_20ms_tolerance = num_phonemes_in_20ms_window / num_predicted_phonemes 173 | percent_correct_in_25ms_tolerance = num_phonemes_in_25ms_window / num_predicted_phonemes 174 | percent_correct_in_50ms_tolerance = num_phonemes_in_50ms_window / num_predicted_phonemes 175 | percent_correct_in_75ms_tolerance = num_phonemes_in_75ms_window / num_predicted_phonemes 176 | percent_correct_in_100ms_tolerance = num_phonemes_in_100ms_window / num_predicted_phonemes 177 | percent_correct_in_200ms_tolerance = num_phonemes_in_200ms_window / num_predicted_phonemes 178 | 179 | print("mean absolute error over test set: ", mean_abs_error_mean) 180 | print("median absolute error over test set: ", mean_abs_error_median) 181 | 182 | model_eval_dir = make_eval_dir() 183 | 184 | np.save(os.path.join(model_eval_dir, 'mean_abs_error_all_test_examples_snr{}.npy'.format(test_snr)), 185 | ae_all_snippets) 186 | 187 | eval_align_summary_dict = {'tag': tag, 'test_snr': test_snr, 'fft_len': fft_len, 188 | 'hop_len': hop_len, 'mean_abs_error_mean': mean_abs_error_mean, 189 | 'mean_abs_error_median': mean_abs_error_median, 190 | 'in_10ms_tol': percent_correct_in_10ms_tolerance, 191 | 'in_20ms_tol': percent_correct_in_20ms_tolerance, 192 | 'in_25ms_tol': percent_correct_in_25ms_tolerance, 193 | 'in_50ms_tol': percent_correct_in_50ms_tolerance, 194 | 'in_75ms_tol': percent_correct_in_75ms_tolerance, 195 | 'in_100ms_tol': percent_correct_in_100ms_tolerance, 196 | 'in_200ms_tol': percent_correct_in_200ms_tolerance 197 | } 198 | 199 | print(eval_align_summary_dict) 200 | 201 | with open(os.path.join(model_eval_dir, 'eval_align_summary_snr{}.json'.format(test_snr)), 'w') as outfile: 202 | json.dump(eval_align_summary_dict, outfile) 203 | -------------------------------------------------------------------------------- /08_eval_separation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import torch 5 | from torchvision import transforms 6 | import numpy as np 7 | import librosa as lb 8 | from sacred import Experiment 9 | 10 | import mir_eval as me 11 | from pystoi.stoi import stoi as eval_stoi 12 | from pypesq import pesq as eval_pesq 13 | 14 | import utils.data_set_utls as utls 15 | from utils import build_models 16 | from utils import fct 17 | 18 | ex = Experiment('eval_separation') 19 | 20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | print('Device:', device) 22 | 23 | torch.manual_seed(0) 24 | torch.cuda.manual_seed(0) 25 | 26 | 27 | @ex.config 28 | def configuration(): 29 | 30 | tag = 'tag to evaluate' 31 | seed = 1 32 | model_state_dict_name = 'model_best_val_cost.pt' 33 | 34 | side_info_type = 'phonemes' # 'phonemes' or 'ones' 35 | test_data_set = 'timit_musdb' 36 | 37 | side_info_encoder_bidirectional = True 38 | 39 | perfect_alphas = False 40 | vocabulary_size = None 41 | eval_dir = 'evaluation' 42 | 43 | ex.add_config('configs/{}/config.json'.format(tag)) 44 | 45 | test_snr = -5 46 | 47 | 48 | @ex.capture 49 | def make_data_set(test_data_set, test_snr, fft_len, hop_len, window, perfect_alphas): 50 | 51 | if test_data_set == 'timit_musdb': 52 | import data.timit_musdb_test as test_set 53 | 54 | if perfect_alphas: 55 | timit_musdb_test = 
test_set.Test(transform=transforms.Compose([utls.MixSNR(test_snr), 56 | utls.StftOnFly_testset(fft_len=fft_len, 57 | hop_len=hop_len, 58 | window=window), 59 | utls.NormalizeWithOwnMax(), 60 | utls.MakePerfectAttentionWeights(fft_len, 61 | hop_len)])) 62 | 63 | else: 64 | timit_musdb_test = test_set.Test(transform=transforms.Compose([utls.MixSNR(test_snr), 65 | utls.StftOnFly_testset(fft_len=fft_len, 66 | hop_len=hop_len, 67 | window=window), 68 | utls.NormalizeWithOwnMax()])) 69 | 70 | return timit_musdb_test 71 | 72 | 73 | 74 | @ex.capture 75 | def make_model(model, fft_len, text_feature_size, mix_encoder_layers, side_info_encoder_layers, 76 | target_decoder_layers, side_info_encoder_bidirectional): 77 | 78 | mix_features_size = int(fft_len / 2 + 1) 79 | 80 | if model == 'InformedSeparatorWithAttention': 81 | separator = build_models.make_informed_separator_with_attention(mix_features_size, 82 | text_feature_size, 83 | mix_encoder_layers, 84 | side_info_encoder_layers, 85 | target_decoder_layers, 86 | side_info_encoder_bidirectional) 87 | 88 | elif model == 'InformedSeparatorWithPerfectAttention': 89 | separator = build_models.make_informed_separator_with_perfect_attention(mix_features_size, 90 | text_feature_size, 91 | mix_encoder_layers, 92 | side_info_encoder_layers, 93 | target_decoder_layers, 94 | side_info_encoder_bidirectional) 95 | 96 | elif model == 'InformedSeparatorWithSplitAttention': 97 | separator = build_models.make_informed_separator_with_split_attention(mix_features_size, 98 | text_feature_size, 99 | mix_encoder_layers, 100 | side_info_encoder_layers, 101 | target_decoder_layers, 102 | side_info_encoder_bidirectional) 103 | 104 | return separator 105 | 106 | 107 | @ex.capture 108 | def load_state_dict(models_directory, tag, model_state_dict_name): 109 | checkpoint = torch.load(os.path.join(models_directory, tag, model_state_dict_name)) 110 | return checkpoint['model_state_dict'], checkpoint['experiment_tag'] 111 | 112 | @ex.capture 113 | def config2main(tag, test_snr, fft_len, hop_len, window, text_feature_size, side_info_type, vocabulary_size, perfect_alphas): 114 | 115 | return tag, test_snr, fft_len, hop_len, window, text_feature_size, side_info_type, vocabulary_size, perfect_alphas 116 | 117 | 118 | @ex.capture 119 | def make_eval_dir(eval_dir, tag, test_snr): 120 | model_eval_dir = os.path.join(eval_dir, tag + "_snr{}".format(test_snr)) 121 | if not os.path.exists(model_eval_dir): 122 | os.mkdir(model_eval_dir) 123 | return model_eval_dir 124 | 125 | 126 | @ex.automain 127 | def eval_model(): 128 | 129 | tag, test_snr, fft_len, hop_len, window, text_feature_size, side_info_type, vocabulary_size, perfect_alphas = config2main() 130 | 131 | test_set = make_data_set() 132 | 133 | model_to_evaluate = make_model() 134 | 135 | state_dict, training_tag = load_state_dict() 136 | 137 | model_to_evaluate.load_state_dict(state_dict) 138 | 139 | model_to_evaluate.to(device) 140 | 141 | print('state dict loaded') 142 | 143 | sdr_speech_all_snippets = [] 144 | sdr_music_all_snippets = [] 145 | sar_speech_all_snippets = [] 146 | sar_music_all_snippets = [] 147 | sir_speech_all_snippets = [] 148 | sir_music_all_snippets = [] 149 | 150 | stoi_speech_all_snippets = [] 151 | pesq_speech_all_snippets = [] 152 | 153 | pes_speech_all_snippets = [] 154 | eps_speech_all_snippets = [] 155 | 156 | for i in range(len(test_set)): 157 | sample = test_set[i] 158 | 159 | mix_spec = sample['mix'] 160 | mix_phase = sample['mix_phase'] 161 | 162 | mix_time_domain = np.expand_dims(sample['mix_time'], 
axis=0) 163 | true_speech_time_domain = np.expand_dims(sample['speech_time'], axis=0) 164 | true_music_time_domain = np.expand_dims(sample['music_time'], axis=0) 165 | 166 | mix_length = mix_spec.shape[0] 167 | 168 | if side_info_type == 'ones': 169 | side_info = torch.ones((1, mix_length, 1)) 170 | elif side_info_type == 'phonemes': 171 | phonemes_idx = torch.from_numpy(sample['phonemes']) 172 | 173 | # output has shape (batch_size, sequence_len, vocabulary_size) 174 | side_info = torch.from_numpy(utls.idx2one_hot(phonemes_idx, vocabulary_size)).type(torch.float32) 175 | side_info = torch.unsqueeze(side_info, dim=0) 176 | 177 | if perfect_alphas: 178 | with torch.no_grad(): 179 | predicted_speech_spec, alphas = fct.predict_with_perfect_attention(model_to_evaluate, 180 | torch.unsqueeze(torch.from_numpy( 181 | mix_spec), dim=0).to(device), 182 | side_info.to(device), 183 | torch.unsqueeze( 184 | torch.from_numpy( 185 | sample['perfect_alphas']) 186 | .type(torch.float32), dim=0) 187 | .to(device)) 188 | 189 | else: 190 | with torch.no_grad(): 191 | predicted_speech_spec, alphas = fct.predict_with_attention(model_to_evaluate, 192 | torch.unsqueeze(torch.from_numpy(mix_spec), 193 | dim=0).to(device), 194 | side_info=side_info.to(device)) 195 | 196 | complex_predicted_speech = predicted_speech_spec.detach().cpu().numpy() * np.exp(1j * mix_phase) 197 | 198 | predicted_speech_time_domain = lb.core.istft(complex_predicted_speech.T, hop_length=hop_len, win_length=fft_len, 199 | window=window, center=False) * 70 200 | 201 | pes, eps = fct.eval_source_separation_silent_parts( 202 | true_speech_time_domain.flatten(), predicted_speech_time_domain.flatten(), window_size=16000, 203 | hop_size=16000) 204 | 205 | if len(pes) != 0: 206 | pes_speech_all_snippets.append(np.mean(pes)) 207 | if len(eps) != 0: 208 | eps_speech_all_snippets.append(np.mean(eps)) 209 | 210 | if sum(predicted_speech_time_domain) == 0: 211 | print("all-zero prediction:", i) 212 | continue 213 | if sum(true_music_time_domain[0, :]) == 0: 214 | print("all-zero music snippet", i) 215 | continue 216 | 217 | # STOI implementation from https://github.com/mpariente/pystoi 218 | stoi = eval_stoi(true_speech_time_domain.flatten(), predicted_speech_time_domain, fs_sig=16000) 219 | 220 | stoi_speech_all_snippets.append(stoi) 221 | 222 | # PESQ implementation from of https://github.com/vBaiCai/python-pesq 223 | pesq = eval_pesq(true_speech_time_domain.flatten(), predicted_speech_time_domain, 16000) 224 | 225 | pesq_speech_all_snippets.append(pesq) 226 | 227 | predicted_speech_time_domain = np.expand_dims(predicted_speech_time_domain, axis=0) 228 | 229 | predicted_music_time_domain = mix_time_domain - predicted_speech_time_domain 230 | 231 | true_sources = np.concatenate((true_speech_time_domain, true_music_time_domain), axis=0) 232 | predicted_sources = np.concatenate((predicted_speech_time_domain, predicted_music_time_domain), axis=0) 233 | 234 | me.separation.validate(true_sources, predicted_sources) 235 | 236 | sdr, sir, sar, perm = me.separation.bss_eval_sources_framewise(true_sources, predicted_sources, 237 | window=1 * 16000, hop=1 * 16000) 238 | 239 | # evaluation metrics for the current test snippet 240 | sdr_speech = sdr[0] 241 | sdr_music = sdr[1] 242 | sir_speech = sir[0] 243 | sir_music = sir[1] 244 | sar_speech = sar[0] 245 | sar_music = sar[1] 246 | 247 | # compute median over evaluation frames for current test snippet, ignore nan values 248 | sdr_speech_median_snippet = np.median(sdr_speech[~np.isnan(sdr_speech)]) 249 | 
sdr_music_median_snippet = np.median(sdr_music[~np.isnan(sdr_music)]) 250 | sar_speech_median_snippet = np.median(sar_speech[~np.isnan(sar_speech)]) 251 | sar_music_median_snippet = np.median(sar_music[~np.isnan(sar_music)]) 252 | sir_speech_median_snippet = np.median(sir_speech[~np.isnan(sir_speech)]) 253 | sir_music_median_snippet = np.median(sir_music[~np.isnan(sir_music)]) 254 | 255 | # append median of current snippet to list 256 | sdr_speech_all_snippets.append(sdr_speech_median_snippet) 257 | sdr_music_all_snippets.append(sdr_music_median_snippet) 258 | sar_speech_all_snippets.append(sar_speech_median_snippet) 259 | sar_music_all_snippets.append(sar_music_median_snippet) 260 | sir_speech_all_snippets.append(sir_speech_median_snippet) 261 | sir_music_all_snippets.append(sir_music_median_snippet) 262 | 263 | model_eval_dir = make_eval_dir() 264 | 265 | np.save(os.path.join(model_eval_dir, 'sdr_speech.npy'), sdr_speech_all_snippets) 266 | np.save(os.path.join(model_eval_dir, 'sdr_music.npy'), sdr_music_all_snippets) 267 | np.save(os.path.join(model_eval_dir, 'sar_speech.npy'), sar_speech_all_snippets) 268 | np.save(os.path.join(model_eval_dir, 'sar_music.npy'), sar_music_all_snippets) 269 | np.save(os.path.join(model_eval_dir, 'sir_speech.npy'), sir_speech_all_snippets) 270 | np.save(os.path.join(model_eval_dir, 'sir_music.npy'), sir_music_all_snippets) 271 | np.save(os.path.join(model_eval_dir, 'stoi_speech.npy'), stoi_speech_all_snippets) 272 | np.save(os.path.join(model_eval_dir, 'pesq_speech.npy'), pesq_speech_all_snippets) 273 | np.save(os.path.join(model_eval_dir, 'pes_speech.npy'), pes_speech_all_snippets) 274 | np.save(os.path.join(model_eval_dir, 'eps_speech.npy'), eps_speech_all_snippets) 275 | 276 | eval_summary_dict = {'tag': tag, 'test_snr': test_snr, 'fft_len': fft_len, 'hop_len': hop_len, 277 | 'SDR speech mean': np.mean(np.asarray(sdr_speech_all_snippets)[~np.isnan(sdr_speech_all_snippets)]), 278 | 'SDR speech median': np.median(np.asarray(sdr_speech_all_snippets)[~np.isnan(sdr_speech_all_snippets)]), 279 | 'SDR music mean': np.mean(np.asarray(sdr_music_all_snippets)[~np.isnan(sdr_music_all_snippets)]), 280 | 'SDR music median': np.median(np.asarray(sdr_music_all_snippets)[~np.isnan(sdr_music_all_snippets)]), 281 | 'SAR speech mean': np.mean(np.asarray(sar_speech_all_snippets)[~np.isnan(sar_speech_all_snippets)]), 282 | 'SAR speech median': np.median(np.asarray(sar_speech_all_snippets)[~np.isnan(sar_speech_all_snippets)]), 283 | 'SAR music mean': np.mean(np.asarray(sar_music_all_snippets)[~np.isnan(sar_music_all_snippets)]), 284 | 'SAR music median': np.median(np.asarray(sar_music_all_snippets)[~np.isnan(sar_music_all_snippets)]), 285 | 'SIR speech mean': np.mean(np.asarray(sir_speech_all_snippets)[~np.isnan(sir_speech_all_snippets)]), 286 | 'SIR speech median': np.median(np.asarray(sir_speech_all_snippets)[~np.isnan(sir_speech_all_snippets)]), 287 | 'SIR music mean': np.mean(np.asarray(sir_music_all_snippets)[~np.isnan(sir_music_all_snippets)]), 288 | 'SIR music median': np.median(np.asarray(sir_music_all_snippets)[~np.isnan(sir_music_all_snippets)]), 289 | 'STOI speech mean': np.mean(np.asarray(stoi_speech_all_snippets)), 290 | 'STOI speech median': np.median(np.asarray(stoi_speech_all_snippets)), 291 | 'PESQ speech mean': np.mean(np.asarray(pesq_speech_all_snippets)), 292 | 'PESQ speech median': np.median(np.asarray(pesq_speech_all_snippets)), 293 | 'EPS speech mean': np.mean(np.asarray(eps_speech_all_snippets)[~np.isnan(eps_speech_all_snippets)]), 294 | 'EPS speech 
median': np.median(np.asarray(eps_speech_all_snippets)[~np.isnan(eps_speech_all_snippets)]), 295 | 'PES speech mean': np.mean(np.asarray(pes_speech_all_snippets)[~np.isnan(pes_speech_all_snippets)]), 296 | 'PES speech median': np.median(np.asarray(pes_speech_all_snippets)[~np.isnan(pes_speech_all_snippets)]) 297 | } 298 | 299 | with open(os.path.join(model_eval_dir, 'eval_summary.json'), 'w') as outfile: 300 | json.dump(eval_summary_dict, outfile) 301 | 302 | print(eval_summary_dict) 303 | 304 | 305 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kilian Schulze-Forster 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Joint phoneme alignment and text-informed speech separation on highly corrupted speech 2 | 3 | Here you find the code to reproduce the experiments of the paper "Joint phoneme alignment and text-informed speech separation on highly corrupted speech" by Kilian Schulze-Forster, Clement S. J. Doire, Gaël Richard, Roland Badeau. Accepted at *IEEE International Conference on Audio, Speech, and Signal Processing, 2020.* 4 | 5 | The paper and audio examples are available [here](https://schufo.github.io/publications/2020-ICASSP) 6 | 7 | ## Download 8 | Clone the repository to your machine: 9 |
10 | git clone https://github.com/schufo/tisms.git
11 | 
12 | 13 | Make sure that your working directory is *tisms/* for all steps described below. 14 | 15 | ## Virtual Environment 16 | The project was done in a conda environment with python 3.6. You can create one with the following command: 17 |
18 | conda create -n tisms_env python=3.6
19 | 
20 | 21 | Activate the environment: 22 |
23 | source activate tisms_env
24 | 
25 | 26 | Then install PyTorch. I was using version 1.1.0, but later versions should work as well (I did not test them, though). 27 |
28 | conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=9.0 -c pytorch
29 | 
30 | 31 | Then you can run the following command to install all other required packages with pip: 32 |
33 | pip install -r requirements.txt
34 | 
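As a quick, optional sanity check (this snippet is not part of the repository), you can verify that PyTorch was installed correctly and that CUDA is visible:

```python
# Optional sanity check of the environment -- not part of the repository.
import torch

print(torch.__version__)           # the experiments were run with 1.1.0
print(torch.cuda.is_available())   # True if the CUDA toolkit was picked up correctly
```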
35 | 36 | 37 | ## Data Preprocessing 38 | 39 | At the bottom of the script *01\_musdb\_pre\_processing.py*, enter the correct paths to your MUSDB dataset and to the directory where the preprocessed MUSDB data should be saved. Then, run the following two commands (a sketch of the path configuration follows them): 40 |
41 | python 01_musdb_pre_processing.py
42 | 
43 | python 02_make_timit_phoneme_vocabulary.py
44 | 
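The snippet below illustrates the kind of path configuration that is needed. The values are example assumptions for a hypothetical setup, and the variable names at the bottom of *01\_musdb\_pre\_processing.py* may differ from the ones sketched here; the second part uses the variable names as they appear at the top of the data set files.

```python
# Bottom of 01_musdb_pre_processing.py -- hypothetical example values, adjust to your setup.
path_to_musdb = '../Datasets/MUSDB18'                    # root directory of your MUSDB download
path_to_save_data = '../Datasets/MUSDB_accompaniments'   # where the pre-processed .npy frames go

# Top of data/timit_musdb_train.py, data/timit_musdb_val.py, data/timit_musdb_test.py:
import timit_utils as tu
timit_corpus = tu.Corpus('../Datasets/TIMIT/TIMIT/TIMIT')       # path to your TIMIT corpus
path_to_processed_musdb = '../Datasets/MUSDB_accompaniments'    # same directory as chosen above
```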
45 | 46 | In the folder 'data' you will find three Python files containing the data set classes for training, validation, and testing. Enter the correct paths to your TIMIT dataset and to the preprocessed MUSDB data at the top of all three files, as in the sketch above. 47 | 48 | 49 | ## Training 50 | 51 | To train the Baseline (BL) model, run the following from the command line: 52 |
53 | python 03_train_BL.py
54 | 
55 | 56 | With the following commands you can train the three versions of the text-informed models; the arguments after *with* are sacred configuration overrides (see the sketch after these commands): 57 |
58 | python 04_train_informed_models.py with 'tag="V1"'
59 | 
60 | python 04_train_informed_models.py with 'tag="V2"' 'side_info_encoder_bidirectional=False'
61 | 
62 | python 04_train_informed_models.py with 'tag="V3"' 'model="InformedSeparatorWithSplitAttention"'
63 | 
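The following is only a minimal sketch, assuming a much simpler experiment than the actual *04\_train\_informed\_models.py*, of how such `with` overrides reach the training code:

```python
# Minimal sketch of a sacred experiment -- NOT the actual 04_train_informed_models.py,
# which defines many more options; it only illustrates how `with` overrides work.
from sacred import Experiment

ex = Experiment('informed_separator')

@ex.config
def config():
    tag = 'V1'                                  # name under which configs and models are stored
    model = 'InformedSeparatorWithAttention'    # V3 swaps in InformedSeparatorWithSplitAttention
    side_info_encoder_bidirectional = True      # V2 sets this to False

@ex.automain
def train(tag, model, side_info_encoder_bidirectional):
    print(tag, model, side_info_encoder_bidirectional)
```

Running such a script with, for example, `with 'side_info_encoder_bidirectional=False'` simply replaces the corresponding default value before the main function is called.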
64 | 65 | To train the model with Optimal Attention (OA) weights, run: 66 |
67 | python 05_train_OA.py
68 | 
69 | 70 | For this project I tried the experiment tracking package "sacred". Since the evaluation scripts need to access the configuration files by the tag name assigned during training, you now need to run the following script, which copies the config files into a folder named after their tag (a rough sketch of this step follows the command): 71 |
72 | python 06_copy_configs.py
73 | 
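Roughly, this copy step could look like the hypothetical sketch below, assuming sacred's file storage observer wrote one *config.json* per run into *sacred_experiment_logs/*; the actual *06\_copy\_configs.py* may differ in its details.

```python
# Hypothetical sketch of the config-copying step -- the actual 06_copy_configs.py may differ.
import json
import os
import shutil

log_root = 'sacred_experiment_logs'
for run_id in os.listdir(log_root):
    config_file = os.path.join(log_root, run_id, 'config.json')
    if not os.path.isfile(config_file):
        continue  # skip entries that are not sacred run directories
    with open(config_file) as f:
        tag = json.load(f)['tag']                # tag assigned during training
    target_dir = os.path.join('configs', tag)
    os.makedirs(target_dir, exist_ok=True)
    shutil.copy(config_file, os.path.join(target_dir, 'config.json'))
```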
74 | 75 | 76 | ## Evaluation 77 | 78 | To evaluate the alignment provided by V1, V2, or V3, run the command below. To evaluate on clean speech, set the test SNR to 100; for corrupted speech, set it to -5. Set the tag parameter to the model you want to evaluate. 79 |
80 | python 07_eval_alignment.py with 'test_snr=100' 'tag="V1"'
81 | 
82 | 83 | 84 | To evaluate the separation quality in terms of SDR, SAR, SIR, STOI, and PESQ, as well as [PES and EPS](https://github.com/schufo/wiass), run the following script with the respective tags: 85 |
86 | python 08_eval_separation.py with 'tag="BL"'
87 | 
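The per-snippet scores and the summary written by *08\_eval\_separation.py* (described below) can be loaded for further analysis, for example as follows; the sub-directory name is an assumption, so check which directory was created inside *evaluation/* for your run.

```python
# Load the saved evaluation results -- the sub-directory name is an assumption.
import json
import os
import numpy as np

eval_dir = 'evaluation/BL'

sdr_speech = np.load(os.path.join(eval_dir, 'sdr_speech.npy'))
print('SDR speech median:', np.median(sdr_speech[~np.isnan(sdr_speech)]))

with open(os.path.join(eval_dir, 'eval_summary.json')) as f:
    print(json.load(f))
```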
88 | 89 | The evaluation scripts save json-files with evaluation summaries for a quick preview in the *evaluation* folder. For more advanced analysis of the results, numpy-files with all scores for every test example and metric are also saved in the *evaluation* folder. 90 | 91 | ## Acknowledgment 92 | This project has received funding from the European Union’s Horizon 2020 research and innovation programme under the Marie Skłodowsa-Curie grant agreement No. 765068. 93 | 94 | ## Copyright notice 95 | Copyright 2020 Kilian Schulze-Forster of Télécom Paris, Institut Polytechnique de Paris. 96 | All rights reserved. -------------------------------------------------------------------------------- /configs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schufo/tisms/51510ee933dbd4a4a3a07537e288c9c63e1961fc/data/__init__.py -------------------------------------------------------------------------------- /data/timit_musdb_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import timit_utils as tu 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | 10 | corpus = tu.Corpus('../Datasets/TIMIT/TIMIT/TIMIT') 11 | path_to_processed_musdb = '../Datasets/MUSDB_accompaniments' 12 | 13 | timit_test_set = corpus.test 14 | 15 | torch.manual_seed(0) 16 | torch.cuda.manual_seed(0) 17 | 18 | 19 | def get_timit_test_sentence(idx): 20 | 21 | person_idx = int(np.floor(idx / 8)) 22 | person = timit_test_set.person_by_index(person_idx) 23 | sentence_idx = (idx % 8) + 2 # to ignore sentences 0 and 1 (SA1 and SA2), because they are also in training set 24 | sentence = person.sentence_by_index(sentence_idx) 25 | audio = sentence.raw_audio 26 | phonemes = sentence.phones_df.index.values 27 | words = sentence.words_df.index.values 28 | 29 | # the array 'phoneme_times' contains the start values of the phonemes. 30 | # The last number is the end value of the last phoneme ! 
31 | phoneme_times = sentence.phones_df['start'].values 32 | phoneme_times = np.append(phoneme_times, sentence.phones_df['end'].values[-1]) 33 | return audio, phonemes, phoneme_times, words 34 | 35 | 36 | class Test(Dataset): 37 | 38 | def __init__(self, transform=None): 39 | # timit related 40 | pickle_in = open('data/phoneme2idx.pickle', 'rb') 41 | self.phoneme2idx = pickle.load(pickle_in) 42 | 43 | # musdb related 44 | self.musdb_test_path = os.path.join(path_to_processed_musdb, 'test') 45 | 46 | pickle_in = open(os.path.join(path_to_processed_musdb, 'test/test_file_list.pickle'), 'rb') 47 | self.test_file_list = pickle.load(pickle_in) 48 | 49 | # make list of shuffled musdb indices to randomly assign a musdb frame to each timit utterance 50 | musdb_indices_1 = list(np.arange(0, 1487)) 51 | np.random.seed(1) 52 | np.random.shuffle(musdb_indices_1) 53 | self.musdb_shuffled_idx = [] 54 | self.musdb_shuffled_idx.extend(musdb_indices_1) 55 | self.transform = transform 56 | 57 | def __len__(self): 58 | return 1344 # number of utterances in TIMIT test partition without sentences SA1 and SA2 59 | 60 | def __getitem__(self, idx): 61 | 62 | speech, phonemes, phoneme_times, words = get_timit_test_sentence(idx) 63 | musdb_accompaniment = np.load(os.path.join(self.musdb_test_path, 64 | self.test_file_list[self.musdb_shuffled_idx[idx]])) 65 | 66 | # pad the speech signal to same length as music 67 | speech_len = len(speech) 68 | music_len = len(musdb_accompaniment) 69 | padding_at_start = np.random.randint(0, music_len - speech_len) 70 | 71 | padding_at_end = music_len - padding_at_start - speech_len 72 | speech_padded = np.pad(array=speech, pad_width=(padding_at_start, padding_at_end), 73 | mode='constant', constant_values=0) 74 | 75 | phoneme_int = np.array([self.phoneme2idx[p] for p in phonemes]) 76 | 77 | # add a silence token (idx=1) to start and end of phoneme sequence 78 | phoneme_int = np.pad(phoneme_int, (1, 1), mode='constant', constant_values=1) 79 | 80 | sample = {'speech': speech_padded, 'music': musdb_accompaniment, 'speech_start': padding_at_start, 81 | 'speech_len': speech_len, 'phonemes': phoneme_int, 'phoneme_times': phoneme_times} 82 | 83 | if self.transform: 84 | sample = self.transform(sample) 85 | 86 | return sample 87 | -------------------------------------------------------------------------------- /data/timit_musdb_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import timit_utils as tu 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | timit_corpus = tu.Corpus('../Datasets/TIMIT/TIMIT/TIMIT') 10 | path_to_processed_musdb = '../Datasets/MUSDB_accompaniments' 11 | 12 | timit_training_set = timit_corpus.train 13 | 14 | torch.manual_seed(0) 15 | torch.cuda.manual_seed(0) 16 | 17 | 18 | def get_timit_train_sentence(idx): 19 | # the training set for this project comprises the first 4320 sentences of the TIMIT training partition 20 | # the persons are not sorted by dialect regions when accessed with .person_by_index, which ensures that all 21 | # dialect regions are represented in both the training and validation set 22 | person_idx = int(np.floor(idx / 10)) 23 | person = timit_training_set.person_by_index(person_idx) 24 | sentence_idx = idx % 10 25 | sentence = person.sentence_by_index(sentence_idx) 26 | audio = sentence.raw_audio 27 | phonemes = sentence.phones_df.index.values 28 | words = sentence.words_df.index.values 29 | 30 | # the array 'phoneme_times' 
contains the start values of the phonemes. 31 | # The last number is the end value of the last phoneme ! 32 | phoneme_times = sentence.phones_df['start'].values 33 | phoneme_times = np.append(phoneme_times, sentence.phones_df['end'].values[-1]) 34 | 35 | return audio, phonemes, phoneme_times, words 36 | 37 | 38 | class Train(Dataset): 39 | 40 | def __init__(self, transform=None): 41 | 42 | # timit related 43 | pickle_in = open('data/phoneme2idx.pickle', 'rb') 44 | self.phoneme2idx = pickle.load(pickle_in) 45 | 46 | # musdb related 47 | self.musdb_train_path = os.path.join(path_to_processed_musdb, 'train') 48 | # load pickle file made by pre-processing script 49 | pickle_in = open(os.path.join(path_to_processed_musdb, 'train/train_file_list.pickle'), 'rb') 50 | self.train_file_list = pickle.load(pickle_in) 51 | 52 | # make list of shuffled musdb indices to randomly assign a musdb frame to each timit utterance 53 | musdb_indices_1 = list(np.arange(0, 2259)) 54 | musdb_indices_2 = list(np.arange(0, 2259)) 55 | np.random.seed(1) 56 | np.random.shuffle(musdb_indices_1) 57 | np.random.shuffle(musdb_indices_2) 58 | self.musdb_shuffled_idx = [] 59 | self.musdb_shuffled_idx.extend(musdb_indices_1) 60 | self.musdb_shuffled_idx.extend(musdb_indices_2) 61 | self.transform = transform 62 | 63 | def __len__(self): 64 | return 4320 # number of TIMIT utterances assigned to training set 65 | 66 | def __getitem__(self, idx): 67 | 68 | speech, phonemes, phoneme_times, words = get_timit_train_sentence(idx) 69 | musdb_accompaniment = np.load(os.path.join(self.musdb_train_path, 70 | self.train_file_list[self.musdb_shuffled_idx[idx]])) 71 | 72 | # pad the speech signal to same length as music 73 | speech_len = len(speech) 74 | music_len = len(musdb_accompaniment) 75 | padding_at_start = np.random.randint(0, music_len - speech_len) 76 | padding_at_end = music_len - padding_at_start - speech_len 77 | speech_padded = np.pad(array=speech, pad_width=(padding_at_start, padding_at_end), 78 | mode='constant', constant_values=0) 79 | 80 | phoneme_int = np.array([self.phoneme2idx[p] for p in phonemes]) 81 | 82 | 83 | # add a silence token (idx=1) to start and end of phoneme sequence 84 | phoneme_int = np.pad(phoneme_int, (1, 1), mode='constant', constant_values=1) 85 | 86 | sample = {'speech': speech_padded, 'music': musdb_accompaniment, 'speech_start': padding_at_start, 87 | 'speech_len': speech_len, 'phonemes': phoneme_int, 'phoneme_times': phoneme_times, 88 | 'perfect_alphas': None} 89 | 90 | if self.transform: 91 | sample = self.transform(sample) 92 | 93 | return sample 94 | -------------------------------------------------------------------------------- /data/timit_musdb_val.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import timit_utils as tu 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | timit_corpus = tu.Corpus('../Datasets/TIMIT/TIMIT/TIMIT') 10 | path_to_processed_musdb = '../Datasets/MUSDB_accompaniments' 11 | 12 | timit_training_set = timit_corpus.train 13 | 14 | torch.manual_seed(0) 15 | torch.cuda.manual_seed(0) 16 | 17 | 18 | def get_timit_val_sentence(idx): 19 | # the validation set for this project comprises the last 300 sentences of the TIMIT training partition minus 20 | # the first two sentences per speaker (SA1, SA2) resulting in 240 utterance in total. 
21 | # the persons are not sorted by dialect regions when accessed with .person_by_index, which ensures that all 22 | # dialect regions are represented in both the training and validation set 23 | person_idx = int(np.floor(idx / 8)) + 432 24 | person = timit_training_set.person_by_index(person_idx) 25 | sentence_idx = (idx % 8) + 2 # to ignore sentences 0 and 1 (SA1 and SA2), because they are also in training set 26 | sentence = person.sentence_by_index(sentence_idx) 27 | audio = sentence.raw_audio 28 | phonemes = sentence.phones_df.index.values 29 | words = sentence.words_df.index.values 30 | 31 | # the array 'phoneme_times' contains the start values of the phonemes. 32 | # The last number is the end value of the last phoneme ! 33 | phoneme_times = sentence.phones_df['start'].values 34 | phoneme_times = np.append(phoneme_times, sentence.phones_df['end'].values[-1]) 35 | 36 | return audio, phonemes, phoneme_times, words 37 | 38 | 39 | class Val(Dataset): 40 | 41 | def __init__(self, transform=None): 42 | 43 | # timit related 44 | pickle_in = open('data/phoneme2idx.pickle', 'rb') 45 | self.phoneme2idx = pickle.load(pickle_in) 46 | 47 | # musdb related 48 | self.musdb_val_path = os.path.join(path_to_processed_musdb, 'val') 49 | pickle_in = open(os.path.join(path_to_processed_musdb, 'val/val_file_list.pickle'), 'rb') 50 | self.val_file_list = pickle.load(pickle_in) 51 | 52 | # make list of shuffled musdb indices to randomly assign a musdb frame to each timit utterance 53 | musdb_indices_1 = list(np.arange(0, 474)) 54 | np.random.seed(1) 55 | np.random.shuffle(musdb_indices_1) 56 | self.musdb_shuffled_idx = [] 57 | self.musdb_shuffled_idx.extend(musdb_indices_1) 58 | self.transform = transform 59 | 60 | def __len__(self): 61 | return 474 # number of MUSDB snippets in val set 62 | 63 | def __getitem__(self, idx): 64 | 65 | speech, phonemes, phoneme_times, words = get_timit_val_sentence(int(np.floor(idx / 2))) 66 | musdb_accompaniment = np.load(os.path.join(self.musdb_val_path, 67 | self.val_file_list[self.musdb_shuffled_idx[idx]])) 68 | 69 | # pad the speech signal to same length as music 70 | speech_len = len(speech) 71 | music_len = len(musdb_accompaniment) 72 | padding_at_start = min(abs(music_len - speech_len - idx*200), music_len - speech_len - idx*10) 73 | padding_at_end = music_len - padding_at_start - speech_len 74 | speech_padded = np.pad(array=speech, pad_width=(padding_at_start, padding_at_end), 75 | mode='constant', constant_values=0) 76 | 77 | phoneme_int = np.array([self.phoneme2idx[p] for p in phonemes]) 78 | 79 | # add a silence token (idx=1) to start and end of phoneme sequence 80 | phoneme_int = np.pad(phoneme_int, (1, 1), mode='constant', constant_values=1) 81 | 82 | sample = {'speech': speech_padded, 'music': musdb_accompaniment, 'speech_start': padding_at_start, 83 | 'speech_len': speech_len, 'phonemes': phoneme_int, 'phoneme_times': phoneme_times} 84 | 85 | if self.transform: 86 | sample = self.transform(sample) 87 | 88 | return sample 89 | -------------------------------------------------------------------------------- /evaluation/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore -------------------------------------------------------------------------------- /models/InformedSeparatorWithAttention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019 Kilian Schulze-Forster 3 | 4 | 5 | This is a PyTorch implementation of the audio source 
separation model proposed in the paper "Weakly Informed Audio 6 | Source Separation" by Kilian Schulze-Forster, Clement Doire, Gaël Richard, Roland Badeau. 7 | 8 | The following Python packages are required: 9 | 10 | numpy==1.15.4 11 | torch==1.0.1.post2 12 | 13 | To train the model, you can create an instance of the class InformedSeparatorWithAttention. 14 | 15 | In the experiments, the model was used with the following parameters: 16 | 17 | separator = InformedSeparatorWithAttention(mix_features=513, 18 | mix_encoding_size=513, 19 | mix_encoder_layers=2, 20 | side_info_features=1, 21 | side_info_encoding_size=513, 22 | side_info_encoder_layers=2, 23 | connector_output_size=513, 24 | target_decoding_hidden_size=513, 25 | target_decoding_features=513, 26 | target_decoder_layers=2) 27 | 28 | """ 29 | 30 | import torch 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | import numpy as np 34 | 35 | 36 | class InformedSeparatorWithAttention(nn.Module): 37 | 38 | def __init__(self, mix_features, mix_encoding_size, mix_encoder_layers, side_info_features, 39 | side_info_encoding_size, side_info_encoder_layers, connector_output_size, target_decoding_hidden_size, 40 | target_decoding_features, target_decoder_layers, side_info_encoder_bidirectional=True): 41 | 42 | """ 43 | :param mix_features: number of features of the mixture representation, F 44 | :param mix_encoding_size: number of features of the mixture encoding, E 45 | :param mix_encoder_layers: number of layers of the mixture encoder 46 | :param side_info_features: number of features of the side information representation, D 47 | :param side_info_encoding_size: number of features of the side information encoding, J 48 | :param side_info_encoder_layers: number of layers of the side information encoder 49 | :param connector_output_size: number of features of the first layer in the target source decoder 50 | :param target_decoding_hidden_size: number of features of the target source hidden representation q^{(2)} 51 | :param target_decoding_features: number of features of the target source decoding, F 52 | :param target_decoder_layers: number of LSTM layers in the target source decoder 53 | """ 54 | 55 | super(InformedSeparatorWithAttention, self).__init__() 56 | 57 | self.mix_encoder = MixEncoder(mix_features, mix_encoding_size, mix_encoder_layers) 58 | 59 | self.side_info_encoder = SideInfoEncoder(side_info_features, side_info_encoding_size, side_info_encoder_layers, 60 | side_info_encoder_bidirectional) 61 | 62 | if side_info_encoder_bidirectional: 63 | 64 | self.attention = AttentionMechanism(2 * side_info_encoding_size, 2 * mix_encoding_size) 65 | 66 | self.connection = ConnectionLayer(2 * side_info_encoding_size + 2 * mix_encoding_size, connector_output_size) 67 | 68 | self.target_decoder = TargetDecoder(connector_output_size, target_decoding_hidden_size, target_decoding_features, 69 | target_decoder_layers) 70 | else: 71 | self.attention = AttentionMechanism(side_info_encoding_size, 2 * mix_encoding_size) 72 | 73 | self.connection = ConnectionLayer(side_info_encoding_size + 2 * mix_encoding_size, 74 | connector_output_size) 75 | 76 | self.target_decoder = TargetDecoder(connector_output_size, target_decoding_hidden_size, 77 | target_decoding_features, 78 | target_decoder_layers) 79 | 80 | self.side_info_encoding = None 81 | self.mix_encoding = None 82 | self.combined_hidden_representation = None 83 | 84 | def forward(self, mix_input, side_info): 85 | """ 86 | :param mix_input: mixture representation, shape: (batch_size, 
N, F) 87 | :param side_info: side information representation, shape: (batch_size, M, D) 88 | :return: target_prediction: prediction of the target source magnitude spectrogram, shape: (batch_size, N, F) 89 | alphas: attention weights, shape: (batch_size, N, M) 90 | """ 91 | 92 | self.mix_encoding = self.mix_encoder(mix_input) 93 | 94 | self.side_info_encoding = self.side_info_encoder(side_info) 95 | 96 | context_vector, alphas = self.attention(self.side_info_encoding, self.mix_encoding) 97 | 98 | self.combined_hidden_representation = self.connection(context_vector, self.mix_encoding) 99 | 100 | target_prediction = self.target_decoder(self.combined_hidden_representation) 101 | 102 | return target_prediction, alphas 103 | 104 | 105 | class MixEncoder(nn.Module): 106 | 107 | def __init__(self, input_size, hidden_size, layers): 108 | """ 109 | :param input_size: number of features of the mixture representation, F 110 | :param hidden_size: number of features of the mixture encoding, E 111 | :param layers: number of LSTM layers 112 | """ 113 | 114 | super(MixEncoder, self).__init__() 115 | self.LSTM = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=layers, 116 | bidirectional=True, batch_first=True) 117 | 118 | def forward(self, mix_input): 119 | """ 120 | :param mix_input: mixture representation, shape: (batch_size, N, F) 121 | :return: mix_encoding: shape: (batch_size, N, 2*E) 122 | """ 123 | 124 | mix_encoding, h_n_c_n = self.LSTM(mix_input) 125 | return mix_encoding 126 | 127 | 128 | class SideInfoEncoder(nn.Module): 129 | 130 | def __init__(self, input_size, hidden_size, layers, bidirectional): 131 | """ 132 | :param input_size: number of features of the side information representation, D 133 | :param hidden_size: desired number of features of the side information encoding, J 134 | :param layers: number of LSTM layers 135 | :param bidirectional: boolean, default True 136 | """ 137 | 138 | super(SideInfoEncoder, self).__init__() 139 | self.LSTM = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=layers, 140 | bidirectional=bidirectional, batch_first=True) 141 | 142 | def forward(self, inputs): 143 | """ 144 | :param inputs: side information representation, shape: (batch_size, M, D) 145 | :return: side info encoding, shape (batch_size, M, 2*J) 146 | """ 147 | 148 | # encoding has shape (batch_size, sequence_len, 2*hidden_size) 149 | 150 | encoding, h_n_c_n = self.LSTM(inputs) 151 | 152 | return encoding 153 | 154 | 155 | class AttentionMechanism(nn.Module): 156 | 157 | def __init__(self, side_info_encoding_size, mix_encoding_size): 158 | """ 159 | :param side_info_encoding_size: number of features of the side information encoding, 2*J 160 | :param mix_encoding_size: number of features of the mixture encoding, 2*E 161 | 162 | """ 163 | super(AttentionMechanism, self).__init__() 164 | self.side_info_encoding_size = side_info_encoding_size 165 | self.mix_encoding_size = mix_encoding_size 166 | 167 | # make weight matrix Ws and initialize it 168 | w_s_init = torch.empty(self.mix_encoding_size, self.side_info_encoding_size) 169 | k = np.sqrt(1 / self.side_info_encoding_size) 170 | nn.init.uniform_(w_s_init, -k, k) 171 | self.w_s = nn.Parameter(w_s_init, requires_grad=True) 172 | 173 | def forward(self, side_info_encoding, mix_encoding): 174 | 175 | """ 176 | :param side_info_encoding: output of side information encoder, shape; (batch_size, M, 2*J) 177 | :param mix_encoding: output of the mixture encoder, shape: (batch_size, N, 2*E) 178 | :return: context: matrix 
containing context vectors for each time step of the mixture encoding, 179 | shape: (batch_size, N, 2*J) 180 | alphas: matrix of attention weights, shape: (batch_size, N, M) 181 | """ 182 | 183 | batch_size = mix_encoding.size()[0] 184 | 185 | current_device = side_info_encoding.device 186 | 187 | # compute score = g_n * W_s * h_m in two steps (equation 3 in the paper) 188 | side_info_transformed = torch.bmm(self.w_s.expand(batch_size, -1, -1).to(current_device), 189 | torch.transpose(side_info_encoding, 1, 2)) 190 | 191 | scores = torch.bmm(mix_encoding, side_info_transformed) 192 | 193 | # compute the attention weights of all side information steps for all time steps of the target source decoder 194 | alphas = F.softmax(scores, dim=2) # shape: (batch_size, N, M) 195 | 196 | # compute context vector for each time step of target source decoder 197 | context = torch.bmm(torch.transpose(side_info_encoding, 1, 2), torch.transpose(alphas, 1, 2)) 198 | 199 | # make shape: (batch_size, N, 2*J) 200 | context = torch.transpose(context, 1, 2) 201 | 202 | return context, alphas 203 | 204 | 205 | class ConnectionLayer(nn.Module): 206 | """ 207 | This layer is part of the target source decoder in the architecture description in the paper (equation 1) 208 | """ 209 | 210 | def __init__(self, input_size, output_size): 211 | """ 212 | :param input_size: sum of features of the context vector and mixture encoding (2*J + 2*E) 213 | :param output_size: desired number of features of the hidden representation q^{(1)} 214 | """ 215 | 216 | super(ConnectionLayer, self).__init__() 217 | self.fc = nn.Linear(input_size, output_size) 218 | self.Tanh = nn.Tanh() 219 | 220 | def forward(self, context, mix_encoding): 221 | """ 222 | :param context: context vector from attention mechanism, shape: (batch_size, N, 2*J) 223 | :param mix_encoding: output of mixture encoder, shape: (batch_size, N, 2*E) 224 | :return: output (hidden representation q^{(1)}), shape: (batch_size, N, output_size) 225 | """ 226 | 227 | concat = torch.cat((context, mix_encoding), dim=2) 228 | output = self.Tanh(self.fc(concat)) 229 | return output 230 | 231 | 232 | class TargetDecoder(nn.Module): 233 | 234 | def __init__(self, input_size, hidden_size, output_size, layers): 235 | """ 236 | :param input_size: number of features of the connection layer output 237 | :param hidden_size: desired number of features of the hidden representation q^{(2)} 238 | :param output_size: number of features of target source estimation, F 239 | :param layers: number of LSTM layers 240 | """ 241 | 242 | super(TargetDecoder, self).__init__() 243 | 244 | self.LSTM = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=layers, 245 | bidirectional=True, batch_first=True) 246 | self.fc = nn.Linear(2 * hidden_size, output_size) 247 | self.ReLU = nn.ReLU() 248 | 249 | def forward(self, inputs): 250 | """ 251 | :param inputs: hidden representation q^{(1)}, shape: (batch_size, N, number_of_features) 252 | :return: output: prediction of the target source magnitude spectrogram, shape: (batch_size, N, F) 253 | """ 254 | 255 | lstm_out, h_n_c_n = self.LSTM(inputs) 256 | output = self.ReLU(self.fc(lstm_out)) 257 | return output 258 | -------------------------------------------------------------------------------- /models/InformedSeparatorWithPerfectAttention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019 Kilian Schulze-Forster 3 | 4 | 5 | This is a PyTorch implementation of the audio source separation 
model proposed in the paper "Weakly Informed Audio 6 | Source Separation" by Kilian Schulze-Forster, Clement Doire, Gaël Richard, Roland Badeau. 7 | 8 | The following Python packages are required: 9 | 10 | numpy==1.15.4 11 | torch==1.0.1.post2 12 | 13 | To train the model, you can create an instance of the class InformedSeparatorWithAttention. 14 | 15 | In the experiments, the model was used with the following parameters: 16 | 17 | separator = InformedSeparatorWithAttention(mix_features=513, 18 | mix_encoding_size=513, 19 | mix_encoder_layers=2, 20 | side_info_features=1, 21 | side_info_encoding_size=513, 22 | side_info_encoder_layers=2, 23 | connector_output_size=513, 24 | target_decoding_hidden_size=513, 25 | target_decoding_features=513, 26 | target_decoder_layers=2) 27 | 28 | """ 29 | 30 | import torch 31 | import torch.nn as nn 32 | import numpy as np 33 | 34 | 35 | class InformedSeparatorWithPerfectAttention(nn.Module): 36 | 37 | def __init__(self, mix_features, mix_encoding_size, mix_encoder_layers, side_info_features, 38 | side_info_encoding_size, side_info_encoder_layers, connector_output_size, target_decoding_hidden_size, 39 | target_decoding_features, target_decoder_layers, side_info_encoder_bidirectional=True): 40 | 41 | """ 42 | :param mix_features: number of features of the mixture representation, F 43 | :param mix_encoding_size: number of features of the mixture encoding, E 44 | :param mix_encoder_layers: number of layers of the mixture encoder 45 | :param side_info_features: number of features of the side information representation, D 46 | :param side_info_encoding_size: number of features of the side information encoding, J 47 | :param side_info_encoder_layers: number of layers of the side information encoder 48 | :param connector_output_size: number of features of the first layer in the target source decoder 49 | :param target_decoding_hidden_size: number of features of the target source hidden representation q^{(2)} 50 | :param target_decoding_features: number of features of the target source decoding, F 51 | :param target_decoder_layers: number of LSTM layers in the target source decoder 52 | """ 53 | 54 | super(InformedSeparatorWithPerfectAttention, self).__init__() 55 | 56 | self.mix_encoder = MixEncoder(mix_features, mix_encoding_size, mix_encoder_layers) 57 | 58 | self.side_info_encoder = SideInfoEncoder(side_info_features, side_info_encoding_size, side_info_encoder_layers, 59 | side_info_encoder_bidirectional) 60 | 61 | if side_info_encoder_bidirectional: 62 | 63 | self.attention = FakeAttentionMechanism(2 * side_info_encoding_size, 2 * mix_encoding_size) 64 | 65 | self.connection = ConnectionLayer(2 * side_info_encoding_size + 2 * mix_encoding_size, connector_output_size) 66 | 67 | self.target_decoder = TargetDecoder(connector_output_size, target_decoding_hidden_size, target_decoding_features, 68 | target_decoder_layers) 69 | else: 70 | self.attention = FakeAttentionMechanism(side_info_encoding_size, 2 * mix_encoding_size) 71 | 72 | self.connection = ConnectionLayer(side_info_encoding_size + 2 * mix_encoding_size, 73 | connector_output_size) 74 | 75 | self.target_decoder = TargetDecoder(connector_output_size, target_decoding_hidden_size, 76 | target_decoding_features, 77 | target_decoder_layers) 78 | 79 | self.side_info_encoding = None 80 | self.mix_encoding = None 81 | self.combined_hidden_representation = None 82 | 83 | def forward(self, mix_input, side_info, alphas): 84 | """ 85 | :param mix_input: mixture representation, shape: (batch_size, N, F) 86 | :param 
side_info: side information representation, shape: (batch_size, M, D) 87 | :return: target_prediction: prediction of the target source magnitude spectrogram, shape: (batch_size, N, F) 88 | alphas: attention weights, shape: (batch_size, N, M) 89 | """ 90 | 91 | self.mix_encoding = self.mix_encoder(mix_input) 92 | 93 | self.side_info_encoding = self.side_info_encoder(side_info) 94 | 95 | context_vector, alphas = self.attention(self.side_info_encoding, self.mix_encoding, alphas) 96 | 97 | self.combined_hidden_representation = self.connection(context_vector, self.mix_encoding) 98 | 99 | target_prediction = self.target_decoder(self.combined_hidden_representation) 100 | 101 | return target_prediction, alphas 102 | 103 | 104 | class MixEncoder(nn.Module): 105 | 106 | def __init__(self, input_size, hidden_size, layers): 107 | """ 108 | :param input_size: number of features of the mixture representation, F 109 | :param hidden_size: number of features of the mixture encoding, E 110 | :param layers: number of LSTM layers 111 | """ 112 | 113 | super(MixEncoder, self).__init__() 114 | self.LSTM = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=layers, 115 | bidirectional=True, batch_first=True) 116 | 117 | def forward(self, mix_input): 118 | """ 119 | :param mix_input: mixture representation, shape: (batch_size, N, F) 120 | :return: mix_encoding: shape: (batch_size, N, 2*E) 121 | """ 122 | 123 | mix_encoding, h_n_c_n = self.LSTM(mix_input) 124 | return mix_encoding 125 | 126 | 127 | class SideInfoEncoder(nn.Module): 128 | 129 | def __init__(self, input_size, hidden_size, layers, bidirectional): 130 | """ 131 | :param input_size: number of features of the side information representation, D 132 | :param hidden_size: desired number of features of the side information encoding, J 133 | :param layers: number of LSTM layers 134 | :param bidirectional: boolean, default True 135 | """ 136 | 137 | super(SideInfoEncoder, self).__init__() 138 | self.LSTM = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=layers, 139 | bidirectional=bidirectional, batch_first=True) 140 | 141 | def forward(self, inputs): 142 | """ 143 | :param inputs: side information representation, shape: (batch_size, M, D) 144 | :return: side info encoding, shape (batch_size, M, 2*J) 145 | """ 146 | 147 | # encoding has shape (batch_size, sequence_len, 2*hidden_size) 148 | 149 | encoding, h_n_c_n = self.LSTM(inputs) 150 | 151 | return encoding 152 | 153 | 154 | class FakeAttentionMechanism(nn.Module): 155 | 156 | def __init__(self, side_info_encoding_size, mix_encoding_size): 157 | """ 158 | :param side_info_encoding_size: number of features of the side information encoding, 2*J 159 | :param mix_encoding_size: number of features of the mixture encoding, 2*E 160 | 161 | """ 162 | super(FakeAttentionMechanism, self).__init__() 163 | self.side_info_encoding_size = side_info_encoding_size 164 | self.mix_encoding_size = mix_encoding_size 165 | 166 | # make weight matrix Ws and initialize it 167 | w_s_init = torch.empty(self.mix_encoding_size, self.side_info_encoding_size) 168 | k = np.sqrt(1 / self.side_info_encoding_size) 169 | nn.init.uniform_(w_s_init, -k, k) 170 | self.w_s = nn.Parameter(w_s_init, requires_grad=True) 171 | 172 | def forward(self, side_info_encoding, mix_encoding, alphas): 173 | 174 | """ 175 | :param side_info_encoding: output of side information encoder, shape; (batch_size, M, 2*J) 176 | :param mix_encoding: output of the mixture encoder, shape: (batch_size, N, 2*E) 177 | :param alphas: 
perfect attention weights, shape: (batch_size, N, M) 178 | :return: context: matrix containing context vectors for each time step of the mixture encoding, 179 | shape: (batch_size, N, 2*J) 180 | alphas: matrix of attention weights, shape: (batch_size, N, M) 181 | """ 182 | 183 | batch_size = mix_encoding.size()[0] 184 | 185 | current_device = side_info_encoding.device 186 | 187 | # compute context vector for each time step of target source decoder 188 | context = torch.bmm(torch.transpose(side_info_encoding, 1, 2), torch.transpose(alphas, 1, 2)) 189 | 190 | # make shape: (batch_size, N, 2*J) 191 | context = torch.transpose(context, 1, 2) 192 | 193 | return context, alphas 194 | 195 | 196 | class ConnectionLayer(nn.Module): 197 | """ 198 | This layer is part of the target source decoder in the architecture description in the paper (equation 1) 199 | """ 200 | 201 | def __init__(self, input_size, output_size): 202 | """ 203 | :param input_size: sum of features of the context vector and mixture encoding (2*J + 2*E) 204 | :param output_size: desired number of features of the hidden representation q^{(1)} 205 | """ 206 | 207 | super(ConnectionLayer, self).__init__() 208 | self.fc = nn.Linear(input_size, output_size) 209 | self.Tanh = nn.Tanh() 210 | 211 | def forward(self, context, mix_encoding): 212 | """ 213 | :param context: context vector from attention mechanism, shape: (batch_size, N, 2*J) 214 | :param mix_encoding: output of mixture encoder, shape: (batch_size, N, 2*E) 215 | :return: output (hidden representation q^{(1)}), shape: (batch_size, N, output_size) 216 | """ 217 | 218 | concat = torch.cat((context, mix_encoding), dim=2) 219 | output = self.Tanh(self.fc(concat)) 220 | return output 221 | 222 | 223 | class TargetDecoder(nn.Module): 224 | 225 | def __init__(self, input_size, hidden_size, output_size, layers): 226 | """ 227 | :param input_size: number of features of the connection layer output 228 | :param hidden_size: desired number of features of the hidden representation q^{(2)} 229 | :param output_size: number of features of target source estimation, F 230 | :param layers: number of LSTM layers 231 | """ 232 | 233 | super(TargetDecoder, self).__init__() 234 | 235 | self.LSTM = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=layers, 236 | bidirectional=True, batch_first=True) 237 | self.fc = nn.Linear(2 * hidden_size, output_size) 238 | self.ReLU = nn.ReLU() 239 | 240 | def forward(self, inputs): 241 | """ 242 | :param inputs: hidden representation q^{(1)}, shape: (batch_size, N, number_of_features) 243 | :return: output: prediction of the target source magnitude spectrogram, shape: (batch_size, N, F) 244 | """ 245 | 246 | lstm_out, h_n_c_n = self.LSTM(inputs) 247 | output = self.ReLU(self.fc(lstm_out)) 248 | return output 249 | -------------------------------------------------------------------------------- /models/InformedSeparatorWithSplitAttention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019 Kilian Schulze-Forster 3 | 4 | 5 | This is a PyTorch implementation of the audio source separation model proposed in the paper "Weakly Informed Audio 6 | Source Separation" by Kilian Schulze-Forster, Clement Doire, Gaël Richard, Roland Badeau. 7 | 8 | The following Python packages are required: 9 | 10 | numpy==1.15.4 11 | torch==1.0.1.post2 12 | 13 | To train the model, you can create an instance of the class InformedSeparatorWithAttention. 
14 | 15 | In the experiments, the model was used with the following parameters: 16 | 17 | separator = InformedSeparatorWithAttention(mix_features=513, 18 | mix_encoding_size=513, 19 | mix_encoder_layers=2, 20 | side_info_features=1, 21 | side_info_encoding_size=513, 22 | side_info_encoder_layers=2, 23 | connector_output_size=513, 24 | target_decoding_hidden_size=513, 25 | target_decoding_features=513, 26 | target_decoder_layers=2) 27 | 28 | """ 29 | 30 | import torch 31 | import torch.nn as nn 32 | import torch.nn.functional as F 33 | import numpy as np 34 | 35 | 36 | class InformedSeparatorWithSplitAttention(nn.Module): 37 | 38 | def __init__(self, mix_features, mix_encoding_size, mix_encoder_layers, side_info_features, 39 | side_info_encoding_size, side_info_encoder_layers, connector_output_size, target_decoding_hidden_size, 40 | target_decoding_features, target_decoder_layers, side_info_encoder_bidirectional=True): 41 | 42 | """ 43 | :param mix_features: number of features of the mixture representation, F 44 | :param mix_encoding_size: number of features of the mixture encoding, E 45 | :param mix_encoder_layers: number of layers of the mixture encoder 46 | :param side_info_features: number of features of the side information representation, D 47 | :param side_info_encoding_size: number of features of the side information encoding, J 48 | :param side_info_encoder_layers: number of layers of the side information encoder 49 | :param connector_output_size: number of features of the first layer in the target source decoder 50 | :param target_decoding_hidden_size: number of features of the target source hidden representation q^{(2)} 51 | :param target_decoding_features: number of features of the target source decoding, F 52 | :param target_decoder_layers: number of LSTM layers in the target source decoder 53 | """ 54 | 55 | super(InformedSeparatorWithSplitAttention, self).__init__() 56 | 57 | self.mix_encoder = MixEncoder(mix_features, mix_encoding_size, mix_encoder_layers) 58 | 59 | self.side_info_encoder = SideInfoEncoder(side_info_features, side_info_encoding_size, side_info_encoder_layers, 60 | side_info_encoder_bidirectional) 61 | 62 | if side_info_encoder_bidirectional: 63 | 64 | self.attention = AttentionMechanism(2 * side_info_encoding_size, 2 * mix_encoding_size) 65 | 66 | self.connection = ConnectionLayer(2 * side_info_encoding_size + 2 * mix_encoding_size, connector_output_size) 67 | 68 | self.target_decoder = TargetDecoder(connector_output_size, target_decoding_hidden_size, target_decoding_features, 69 | target_decoder_layers) 70 | else: 71 | self.attention = AttentionMechanism(side_info_encoding_size, 2 * mix_encoding_size) 72 | 73 | self.connection = ConnectionLayer(side_info_encoding_size + 2 * mix_encoding_size, 74 | connector_output_size) 75 | 76 | self.target_decoder = TargetDecoder(connector_output_size, target_decoding_hidden_size, 77 | target_decoding_features, 78 | target_decoder_layers) 79 | 80 | self.side_info_encoding = None 81 | self.mix_encoding = None 82 | self.combined_hidden_representation = None 83 | 84 | def forward(self, mix_input, side_info): 85 | """ 86 | :param mix_input: mixture representation, shape: (batch_size, N, F) 87 | :param side_info: side information representation, shape: (batch_size, M, D) 88 | :return: target_prediction: prediction of the target source magnitude spectrogram, shape: (batch_size, N, F) 89 | alphas: attention weights, shape: (batch_size, N, M) 90 | """ 91 | 92 | self.mix_encoding = self.mix_encoder(mix_input) 93 | 94 | 
self.side_info_encoding = self.side_info_encoder(side_info) 95 | 96 | context_vector, alphas = self.attention(self.side_info_encoding, self.mix_encoding) 97 | 98 | self.combined_hidden_representation = self.connection(context_vector, self.mix_encoding) 99 | 100 | target_prediction = self.target_decoder(self.combined_hidden_representation) 101 | 102 | return target_prediction, alphas 103 | 104 | 105 | class MixEncoder(nn.Module): 106 | 107 | def __init__(self, input_size, hidden_size, layers): 108 | """ 109 | :param input_size: number of features of the mixture representation, F 110 | :param hidden_size: number of features of the mixture encoding, E 111 | :param layers: number of LSTM layers 112 | """ 113 | 114 | super(MixEncoder, self).__init__() 115 | self.LSTM = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=layers, 116 | bidirectional=True, batch_first=True) 117 | 118 | def forward(self, mix_input): 119 | """ 120 | :param mix_input: mixture representation, shape: (batch_size, N, F) 121 | :return: mix_encoding: shape: (batch_size, N, 2*E) 122 | """ 123 | 124 | mix_encoding, h_n_c_n = self.LSTM(mix_input) 125 | return mix_encoding 126 | 127 | 128 | class SideInfoEncoder(nn.Module): 129 | 130 | def __init__(self, input_size, hidden_size, layers, bidirectional): 131 | """ 132 | :param input_size: number of features of the side information representation, D 133 | :param hidden_size: desired number of features of the side information encoding, J 134 | :param layers: number of LSTM layers 135 | :param bidirectional: boolean, default True 136 | """ 137 | 138 | super(SideInfoEncoder, self).__init__() 139 | self.LSTM = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=layers, 140 | bidirectional=bidirectional, batch_first=True) 141 | 142 | def forward(self, inputs): 143 | """ 144 | :param inputs: side information representation, shape: (batch_size, M, D) 145 | :return: side info encoding, shape (batch_size, M, 2*J) 146 | """ 147 | 148 | # encoding has shape (batch_size, sequence_len, 2*hidden_size) 149 | 150 | encoding, h_n_c_n = self.LSTM(inputs) 151 | 152 | return encoding 153 | 154 | 155 | class AttentionMechanism(nn.Module): 156 | 157 | def __init__(self, side_info_encoding_size, mix_encoding_size): 158 | """ 159 | :param side_info_encoding_size: number of features of the side information encoding, 2*J 160 | :param mix_encoding_size: number of features of the mixture encoding, 2*E 161 | 162 | """ 163 | super(AttentionMechanism, self).__init__() 164 | self.side_info_encoding_size = side_info_encoding_size 165 | self.mix_encoding_size = mix_encoding_size 166 | 167 | self.fc = nn.Linear(side_info_encoding_size, side_info_encoding_size) 168 | 169 | # make weight matrix Ws and initialize it 170 | w_s_init = torch.empty(self.mix_encoding_size, self.side_info_encoding_size) 171 | k = np.sqrt(1 / self.side_info_encoding_size) 172 | nn.init.uniform_(w_s_init, -k, k) 173 | self.w_s = nn.Parameter(w_s_init, requires_grad=True) 174 | 175 | def forward(self, side_info_encoding, mix_encoding): 176 | 177 | """ 178 | :param side_info_encoding: output of side information encoder, shape; (batch_size, M, 2*J) 179 | :param mix_encoding: output of the mixture encoder, shape: (batch_size, N, 2*E) 180 | :return: context: matrix containing context vectors for each time step of the mixture encoding, 181 | shape: (batch_size, N, 2*J) 182 | alphas: matrix of attention weights, shape: (batch_size, N, M) 183 | """ 184 | 185 | batch_size = mix_encoding.size()[0] 186 | 187 | 
current_device = side_info_encoding.device 188 | 189 | # compute score = g_n * W_s * h_m in two steps (equation 3 in the paper) 190 | intermediate_score = torch.bmm(self.w_s.expand(batch_size, -1, -1).to(current_device), 191 | torch.transpose(side_info_encoding, 1, 2)) 192 | 193 | scores = torch.bmm(mix_encoding, intermediate_score) 194 | 195 | # compute the attention weights of all side information steps for all time steps of the target source decoder 196 | alphas = F.softmax(scores, dim=2) # shape: (batch_size, N, M) 197 | 198 | side_info_encoding_transformed = self.fc(side_info_encoding) 199 | 200 | # compute context vector for each time step of target source decoder 201 | context = torch.bmm(torch.transpose(side_info_encoding_transformed, 1, 2), torch.transpose(alphas, 1, 2)) 202 | 203 | # make shape: (batch_size, N, 2*J) 204 | context = torch.transpose(context, 1, 2) 205 | 206 | return context, alphas 207 | 208 | 209 | class ConnectionLayer(nn.Module): 210 | """ 211 | This layer is part of the target source decoder in the architecture description in the paper (equation 1) 212 | """ 213 | 214 | def __init__(self, input_size, output_size): 215 | """ 216 | :param input_size: sum of features of the context vector and mixture encoding (2*J + 2*E) 217 | :param output_size: desired number of features of the hidden representation q^{(1)} 218 | """ 219 | 220 | super(ConnectionLayer, self).__init__() 221 | self.fc = nn.Linear(input_size, output_size) 222 | self.Tanh = nn.Tanh() 223 | 224 | def forward(self, context, mix_encoding): 225 | """ 226 | :param context: context vector from attention mechanism, shape: (batch_size, N, 2*J) 227 | :param mix_encoding: output of mixture encoder, shape: (batch_size, N, 2*E) 228 | :return: output (hidden representation q^{(1)}), shape: (batch_size, N, output_size) 229 | """ 230 | 231 | concat = torch.cat((context, mix_encoding), dim=2) 232 | output = self.Tanh(self.fc(concat)) 233 | return output 234 | 235 | 236 | class TargetDecoder(nn.Module): 237 | 238 | def __init__(self, input_size, hidden_size, output_size, layers): 239 | """ 240 | :param input_size: number of features of the connection layer output 241 | :param hidden_size: desired number of features of the hidden representation q^{(2)} 242 | :param output_size: number of features of target source estimation, F 243 | :param layers: number of LSTM layers 244 | """ 245 | 246 | super(TargetDecoder, self).__init__() 247 | 248 | self.LSTM = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=layers, 249 | bidirectional=True, batch_first=True) 250 | self.fc = nn.Linear(2 * hidden_size, output_size) 251 | self.ReLU = nn.ReLU() 252 | 253 | def forward(self, inputs): 254 | """ 255 | :param inputs: hidden representation q^{(1)}, shape: (batch_size, N, number_of_features) 256 | :return: output: prediction of the target source magnitude spectrogram, shape: (batch_size, N, F) 257 | """ 258 | 259 | lstm_out, h_n_c_n = self.LSTM(inputs) 260 | output = self.ReLU(self.fc(lstm_out)) 261 | return output 262 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schufo/tisms/51510ee933dbd4a4a3a07537e288c9c63e1961fc/models/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sacred==0.7.5 2 | 
pymongo==3.8.0 3 | mir_eval==0.5 4 | pystoi==0.2.2 5 | pypesq==1.2.4 6 | timit_utils==0.9.0 7 | tensorboardX==1.7 8 | librosa==0.6.3 9 | musdb==0.2.3 10 | -------------------------------------------------------------------------------- /sacred_experiment_logs/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore -------------------------------------------------------------------------------- /tensorboard/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore -------------------------------------------------------------------------------- /trained_models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/schufo/tisms/51510ee933dbd4a4a3a07537e288c9c63e1961fc/utils/__init__.py -------------------------------------------------------------------------------- /utils/build_models.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def make_informed_separator_with_attention(mix_features_size, text_feature_size, mix_encoder_layers, 4 | side_info_encoder_layers, target_decoder_layers, side_info_encoder_bidirectional=True): 5 | 6 | from models import InformedSeparatorWithAttention as Model 7 | 8 | network = Model.InformedSeparatorWithAttention(mix_features=mix_features_size, 9 | mix_encoding_size=mix_features_size, 10 | mix_encoder_layers=mix_encoder_layers, 11 | side_info_features=text_feature_size, 12 | side_info_encoding_size=mix_features_size, 13 | side_info_encoder_layers=side_info_encoder_layers, 14 | connector_output_size=mix_features_size, 15 | target_decoding_hidden_size=mix_features_size, 16 | target_decoding_features=mix_features_size, 17 | target_decoder_layers=target_decoder_layers, 18 | side_info_encoder_bidirectional=side_info_encoder_bidirectional) 19 | return network 20 | 21 | 22 | def make_informed_separator_with_perfect_attention(mix_features_size, text_feature_size, mix_encoder_layers, 23 | side_info_encoder_layers, target_decoder_layers, side_info_encoder_bidirectional): 24 | 25 | from models import InformedSeparatorWithPerfectAttention as Model 26 | 27 | network = Model.InformedSeparatorWithPerfectAttention(mix_features=mix_features_size, 28 | mix_encoding_size=mix_features_size, 29 | mix_encoder_layers=mix_encoder_layers, 30 | side_info_features=text_feature_size, 31 | side_info_encoding_size=mix_features_size, 32 | side_info_encoder_layers=side_info_encoder_layers, 33 | connector_output_size=mix_features_size, 34 | target_decoding_hidden_size=mix_features_size, 35 | target_decoding_features=mix_features_size, 36 | target_decoder_layers=target_decoder_layers, 37 | side_info_encoder_bidirectional=side_info_encoder_bidirectional) 38 | return network 39 | 40 | 41 | 42 | def make_informed_separator_with_split_attention(mix_features_size, text_feature_size, mix_encoder_layers, 43 | side_info_encoder_layers, target_decoder_layers, side_info_encoder_bidirectional=True): 44 | 45 | from models import InformedSeparatorWithSplitAttention as Model 46 | 47 | network = Model.InformedSeparatorWithSplitAttention(mix_features=mix_features_size, 48 | mix_encoding_size=mix_features_size, 49 | mix_encoder_layers=mix_encoder_layers, 50 | 
side_info_features=text_feature_size, 51 | side_info_encoding_size=mix_features_size, 52 | side_info_encoder_layers=side_info_encoder_layers, 53 | connector_output_size=mix_features_size, 54 | target_decoding_hidden_size=mix_features_size, 55 | target_decoding_features=mix_features_size, 56 | target_decoder_layers=target_decoder_layers, 57 | side_info_encoder_bidirectional=side_info_encoder_bidirectional) 58 | return network 59 | -------------------------------------------------------------------------------- /utils/data_set_utls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import librosa as lb 4 | from torch.nn.utils.rnn import pad_sequence 5 | 6 | 7 | # PyTorch related utilities 8 | def worker_init_fn(worker_id): 9 | np.random.seed(np.random.get_state()[1][0] + worker_id) 10 | 11 | 12 | def collate_with_phonemes(sample_list): 13 | 14 | batch_size = len(sample_list) 15 | 16 | # make list of phonemes, mix, speech, music of the batch 17 | list_of_phoneme_sequences = [torch.from_numpy(sample_list[n]['phonemes']) for n in range(batch_size)] 18 | list_of_mix_specs = [torch.from_numpy(sample_list[n]['mix']) for n in range(batch_size)] 19 | list_of_speech_specs = [torch.from_numpy(sample_list[n]['target']) for n in range(batch_size)] 20 | list_of_music_specs = [torch.from_numpy(sample_list[n]['music']) for n in range(batch_size)] 21 | list_of_perfect_alphas = [torch.from_numpy(sample_list[n]['perfect_alphas'].T) for n in range(batch_size)] 22 | 23 | 24 | # pad phonemes to length of longest phoneme sequence in batch and stack them along dim=0 25 | phonemes_batched = pad_sequence(list_of_phoneme_sequences, batch_first=True, padding_value=0).type(torch.float32) 26 | 27 | alphas_batched = pad_sequence(list_of_perfect_alphas, batch_first=True, padding_value=0).type(torch.float32) 28 | alphas_batched = torch.transpose(alphas_batched, 1, 2) 29 | 30 | # stack other elements in batch that have the same size across individual samples 31 | mix_batched = torch.stack(list_of_mix_specs, dim=0) 32 | speech_batched = torch.stack(list_of_speech_specs, dim=0) 33 | music_batched = torch.stack(list_of_music_specs, dim=0) 34 | 35 | samples_batched = {'target': speech_batched, 'music': music_batched, 'mix': mix_batched, 36 | 'phonemes': phonemes_batched, 'perfect_alphas': alphas_batched} 37 | 38 | return samples_batched 39 | 40 | # transformations that can be applied when creating an instance of a data set 41 | 42 | class MixSNR(object): 43 | """ 44 | The energy ratio of speech and music is measured over samples where the speech is active only 45 | """ 46 | 47 | def __init__(self, desired_snr): 48 | self.snr_desired = desired_snr 49 | 50 | def __call__(self, sample): 51 | music = sample['music'] 52 | speech = sample['speech'] 53 | speech_len = sample['speech_len'] 54 | speech_start = sample['speech_start'] 55 | phonemes = sample['phonemes'] 56 | 57 | speech_energy = sum(speech ** 2) 58 | music_energy_at_speech_overlap = sum(music[speech_start: speech_start + speech_len] ** 2) 59 | 60 | target_snr = self.snr_desired 61 | 62 | if self.snr_desired == 'random': 63 | target_snr = np.random.uniform(-8, 0) 64 | 65 | if music_energy_at_speech_overlap > 0.1: 66 | snr_current = 10 * np.log10(speech_energy / music_energy_at_speech_overlap) 67 | snr_difference = target_snr - snr_current 68 | scaling = (10 ** (snr_difference / 10)) 69 | speech_scaled = speech * np.sqrt(scaling) 70 | mix = speech_scaled + music 71 | mix_max = abs(mix).max() 72 | 
mix = mix / mix_max 73 | speech_scaled = speech_scaled / mix_max 74 | music = music / mix_max 75 | else: 76 | mix = speech + music 77 | mix_max = abs(mix).max() 78 | mix = mix / mix_max 79 | speech_scaled = speech / mix_max 80 | music = music / mix_max 81 | 82 | sample = {'mix': mix, 'target': speech_scaled, 'music': music, 'phonemes': phonemes, 83 | 'speech_start': speech_start, 'speech_len': speech_len, 'phoneme_times': sample['phoneme_times'], 84 | 'perfect_alphas': np.ones((1, 1))} 85 | return sample 86 | 87 | 88 | class Stft_torch(object): 89 | 90 | def __init__(self, fft_len, hop_len, device): 91 | self.fft_len = fft_len 92 | self.hop_len = hop_len 93 | self.device = device 94 | 95 | def __call__(self, sample): 96 | 97 | mix = torch.from_numpy(sample['mix']).to(self.device) 98 | music = torch.from_numpy(sample['music']).to(self.device) 99 | speech = torch.from_numpy(sample['speech']).to(self.device) 100 | 101 | window = torch.hamming_window(self.fft_len, periodic=False).double().to(self.device) 102 | 103 | mix_stft = torch.stft(mix, n_fft=self.fft_len, hop_length=self.hop_len, win_length=self.fft_len, window=window, center=False) 104 | 105 | music_stft = torch.stft(music, n_fft=self.fft_len, hop_length=self.hop_len, win_length=self.fft_len, 106 | window=window, center=False) 107 | speech_stft = torch.stft(speech, n_fft=self.fft_len, hop_length=self.hop_len, win_length=self.fft_len, 108 | window=window, center=False) 109 | 110 | 111 | mix_mag_spec = torch.sqrt(mix_stft[:, :, 0]**2 + mix_stft[:, :, 1]**2) 112 | music_mag_spec = torch.sqrt(music_stft[:, :, 0]**2 + music_stft[:, :, 1]**2) 113 | speech_mag_spec = torch.sqrt(speech_stft[:, :, 0]**2 + speech_stft[:, :, 1]**2) 114 | 115 | print(mix_mag_spec.transpose(0, 1).size()) 116 | 117 | sample['target'] = speech_mag_spec.transpose(0, 1) 118 | sample['music'] = music_mag_spec.transpose(0, 1) 119 | sample['mix'] = mix_mag_spec.transpose(0, 1) 120 | return sample 121 | 122 | 123 | class StftOnFly(object): 124 | 125 | def __init__(self, fft_len, hop_len, window): 126 | self.fft_len = fft_len 127 | self.hop_len = hop_len 128 | self.window = window 129 | 130 | def __call__(self, sample): 131 | mix = sample['mix'] 132 | music = sample['music'] 133 | speech = sample['target'] 134 | 135 | mix_stft = lb.core.stft(mix, n_fft=self.fft_len, hop_length=self.hop_len, win_length=self.fft_len, 136 | window=self.window, center=False) 137 | music_stft = lb.core.stft(music, n_fft=self.fft_len, hop_length=self.hop_len, win_length=self.fft_len, 138 | window=self.window, center=False) 139 | speech_stft = lb.core.stft(speech, n_fft=self.fft_len, hop_length=self.hop_len, win_length=self.fft_len, 140 | window=self.window, center=False) 141 | 142 | mix_mag_spec = abs(mix_stft).T 143 | music_mag_spec = abs(music_stft).T 144 | speech_mag_spec = abs(speech_stft).T 145 | 146 | sample['target'] = speech_mag_spec 147 | sample['music'] = music_mag_spec 148 | sample['mix'] = mix_mag_spec 149 | return sample 150 | 151 | class StftOnFly_testset(object): 152 | 153 | def __init__(self, fft_len, hop_len, window): 154 | self.fft_len = fft_len 155 | self.hop_len = hop_len 156 | self.window = window 157 | 158 | def __call__(self, sample): 159 | 160 | # time domain signals 161 | mix = sample['mix'] 162 | music = sample['music'] 163 | speech = sample['target'] 164 | 165 | mix_stft = lb.core.stft(mix, n_fft=self.fft_len, hop_length=self.hop_len, win_length=self.fft_len, 166 | window=self.window, center=False) 167 | music_stft = lb.core.stft(music, n_fft=self.fft_len, 
hop_length=self.hop_len, win_length=self.fft_len, 168 | window=self.window, center=False) 169 | speech_stft = lb.core.stft(speech, n_fft=self.fft_len, hop_length=self.hop_len, win_length=self.fft_len, 170 | window=self.window, center=False) 171 | 172 | mix_mag_spec = abs(mix_stft).T 173 | music_mag_spec = abs(music_stft).T 174 | speech_mag_spec = abs(speech_stft).T 175 | 176 | mix_phase = np.angle(mix_stft).T 177 | 178 | sample['target'] = speech_mag_spec 179 | sample['music'] = music_mag_spec 180 | sample['mix'] = mix_mag_spec 181 | sample['mix_phase'] = mix_phase 182 | sample['mix_time'] = mix 183 | sample['music_time'] = music 184 | sample['speech_time'] = speech 185 | return sample 186 | 187 | class NormalizeWithOwnMax(object): 188 | 189 | def __call__(self, sample): 190 | mix_spec = sample['mix'] 191 | music_spec = sample['music'] 192 | speech_spec = sample['target'] 193 | 194 | mix_max = mix_spec.max() 195 | music_max = music_spec.max() 196 | speech_max = speech_spec.max() 197 | 198 | mix_norm_spec = mix_spec / mix_max 199 | if music_max > 0: 200 | music_norm_spec = music_spec / music_max 201 | else: 202 | music_norm_spec = music_spec 203 | speech_norm_spec = speech_spec / speech_max 204 | 205 | sample['target'] = speech_norm_spec 206 | sample['music'] = music_norm_spec 207 | sample['mix'] = mix_norm_spec 208 | return sample 209 | 210 | 211 | class MakePerfectAttentionWeights(object): 212 | 213 | def __init__(self, fft_len, hop_len): 214 | self.fft_len = fft_len 215 | self.hop_len = hop_len 216 | 217 | def __call__(self, sample): 218 | phonemes = sample['phonemes'] # sequence of phoneme indices 219 | phoneme_times = sample['phoneme_times'] # start of phonemes, last number=last phoneme's end (rel. speech_start) 220 | speech_start = sample['speech_start'] # start of the speech recording in the mix (!= first phoneme start) 221 | speech_len = sample['speech_len'] # length of the speech recording 222 | time_frames = sample['mix'].shape[0] # number of time frames in spectrograms 223 | 224 | sequence_of_phoneme_sample_idx = np.zeros((time_frames,), dtype=int) 225 | 226 | # assign phoneme indices to frames where they are active at least over half of the frame length 227 | for n in range(0, len(phoneme_times) - 1): 228 | phoneme_start_frame = int(np.floor((speech_start + phoneme_times[n]) / self.hop_len)) 229 | phoneme_end_frame = int(np.floor((speech_start + phoneme_times[n + 1]) / self.hop_len)) 230 | 231 | if phoneme_start_frame < phoneme_end_frame: 232 | sequence_of_phoneme_sample_idx[phoneme_start_frame: phoneme_end_frame] = n + 1 233 | # elif phoneme_start_frame == phoneme_end_frame: 234 | # pass 235 | 236 | # assign idx of last silence token to silent frames at end of speech frames 237 | sequence_of_phoneme_sample_idx[phoneme_end_frame:] = n + 2 238 | 239 | alphas = idx2one_hot(torch.from_numpy(sequence_of_phoneme_sample_idx), len(phonemes)) 240 | 241 | sample['perfect_alphas'] = alphas 242 | 243 | return sample 244 | 245 | 246 | class AlignPhonemes(object): 247 | 248 | def __init__(self, fft_len, hop_len): 249 | self.fft_len = fft_len 250 | self.hop_len = hop_len 251 | 252 | def __call__(self, sample): 253 | phonemes = sample['phonemes'] # sequence of phoneme indices 254 | phoneme_times = sample['phoneme_times'] # start of phonemes, last number=last phoneme's end (rel. 
speech_start) 255 | speech_start = sample['speech_start'] # start of the speech recording in the mix (!= first phoneme start) 256 | speech_len = sample['speech_len'] # length of the speech recording 257 | time_frames = sample['mix'].shape[0] # number of time frames in spectrograms 258 | 259 | aligned_phoneme_idx_sequence = np.ones((time_frames,), dtype=int) 260 | 261 | # assign noise token (index 2) to frames where speech recording does not contain voice yet 262 | speech_recording_start_frame = int(np.floor(speech_start / self.hop_len)) 263 | phoneme_start_frame = int(np.floor((speech_start + phoneme_times[0]) / self.hop_len)) 264 | if speech_recording_start_frame < phoneme_start_frame: 265 | aligned_phoneme_idx_sequence[speech_recording_start_frame: phoneme_start_frame] = 2 266 | 267 | # assign phoneme indices to frames where they are active at least over half of the frame length 268 | for n in range(0, len(phoneme_times) - 1): 269 | phoneme_start_frame = int(np.floor((speech_start + phoneme_times[n]) / self.hop_len)) 270 | phoneme_end_frame = int(np.floor((speech_start + phoneme_times[n + 1]) / self.hop_len)) 271 | 272 | if phoneme_start_frame < phoneme_end_frame: 273 | aligned_phoneme_idx_sequence[phoneme_start_frame: phoneme_end_frame] = phonemes[n + 1] 274 | 275 | # assign noise token (index 2) to frames where speech is not active but noise is still in ground truth 276 | last_phoneme_end_frame = int(np.floor((speech_start + phoneme_times[-1]) / self.hop_len)) 277 | speech_recording_end_frame = int(np.floor((speech_start + speech_len) / self.hop_len)) 278 | if last_phoneme_end_frame < speech_recording_end_frame: 279 | aligned_phoneme_idx_sequence[last_phoneme_end_frame: speech_recording_end_frame] = 2 280 | # elif last_phoneme_end_frame == speech_recording_end_frame: 281 | # pass 282 | 283 | sample['phonemes'] = aligned_phoneme_idx_sequence 284 | 285 | return sample 286 | 287 | def idx2one_hot(idx_sentence, vocabulary_size): 288 | sentence_one_hot_encoded = [] 289 | for idx in idx_sentence: 290 | phoneme = [0 for _ in range(vocabulary_size)] 291 | phoneme[idx.type(torch.int)] = 1 292 | sentence_one_hot_encoded.append(np.array(phoneme)) 293 | return np.array(sentence_one_hot_encoded) 294 | 295 | -------------------------------------------------------------------------------- /utils/fct.py: -------------------------------------------------------------------------------- 1 | """ 2 | basic functions needed to train and test deep learning models with PyTorch 3 | """ 4 | import torch 5 | import numpy as np 6 | import sys 7 | 8 | 9 | def make_fake_side_info(voice_spectro_tensor): 10 | voice_energy = torch.sum(voice_spectro_tensor, dim=2, keepdim=True) 11 | 12 | fake_side_info = torch.ones_like(voice_energy) 13 | 14 | return fake_side_info 15 | 16 | 17 | def viterbi_alignment_from_attention(attention_weights, hop_len): 18 | 19 | """ 20 | :param attention_weights: shape (M, N) 21 | :param hop_len: int 22 | :return: 23 | """ 24 | 25 | M = attention_weights.shape[0] 26 | N = attention_weights.shape[1] 27 | 28 | # transition probabilities are zero everywhere except when going back to the same state (m --> m) 29 | # or moving to next state (m --> m+1). 
First dimension (vertical) is the starting state, 30 | # second dimension (horizontally) is the arriving state 31 | 32 | # initialize transition probabilities to 0.5 for both allowed cases 33 | trans_p = np.zeros((M, M)) 34 | for m in range(M): 35 | trans_p[m, m] = 0.5 36 | if m < M - 1: 37 | trans_p[m, m+1] = 0.5 38 | 39 | # initialization 40 | delta = np.zeros((M, N)) # shape: (states, time_steps), contains delta_n(m) 41 | delta[0, 0] = 1 # delta_0(0) = 1 first state (silence token) must be active at first time step 42 | 43 | psi = np.zeros((M, N)) # state that is most likely predecessor of state m at time step n 44 | 45 | # recurrence 46 | for n in range(1, N): 47 | for m in range(M): 48 | 49 | delta_m_n_candidates = [] 50 | for m_candidate in range(M): 51 | delta_m_n_candidates.append(delta[m_candidate, n-1] * trans_p[m_candidate, m]) 52 | 53 | delta[m, n] = max(delta_m_n_candidates) * (attention_weights[m, n] * 2 + 1) 54 | 55 | psi[m, n] = np.argmax(delta_m_n_candidates) 56 | 57 | np.set_printoptions(threshold=sys.maxsize) 58 | 59 | optimal_state_sequence = np.zeros((1, N)) 60 | optimal_state_sequence[0, N-1] = int(M - 1) # force the last state (silent token) to be active at last time step 61 | 62 | for n in range(N-2, 0, -1): 63 | 64 | optimal_state_sequence[0, n] = (psi[int(optimal_state_sequence[0, n+1]), n+1]) 65 | 66 | # compute index of list elements whose right neighbor is different from itself 67 | last_idx_before_change = [i for i, (x, y) in enumerate(zip(optimal_state_sequence[0, :-1], optimal_state_sequence[0, 1:])) if x != y] 68 | 69 | # compute phoneme onset times from idx of last time frame previous phoneme 70 | phoneme_onsets_prediction = [(n + 1) * hop_len / 16000 for n in last_idx_before_change] 71 | 72 | phoneme_onsets_prediction = phoneme_onsets_prediction[:-1] # remove onset prediction of silence token 73 | 74 | # the optimal_state_sequence is a sequence of phoneme indices with length N 75 | return optimal_state_sequence.astype(int), phoneme_onsets_prediction 76 | 77 | 78 | def train_with_attention(model, loss_function, optimizer, mix_inputs, side_info, targets): 79 | 80 | model.train() 81 | optimizer.zero_grad() 82 | 83 | # Forward 84 | output_of_network, _ = model(mix_inputs, side_info) 85 | loss = loss_function(output_of_network, targets) 86 | 87 | # Backward 88 | loss.backward() 89 | 90 | # Update parameters 91 | optimizer.step() 92 | 93 | # return a number that represents the loss 94 | return loss.item() 95 | 96 | 97 | def train_with_perfect_attention(model, loss_function, optimizer, mix_inputs, side_info, targets, alphas): 98 | 99 | model.train() 100 | optimizer.zero_grad() 101 | 102 | # Forward 103 | output_of_network, _ = model(mix_inputs, side_info, alphas) 104 | loss = loss_function(output_of_network, targets) 105 | 106 | # Backward 107 | loss.backward() 108 | 109 | # Update parameters 110 | optimizer.step() 111 | 112 | # return a number that represents the loss 113 | return loss.item() 114 | 115 | 116 | def predict_with_attention(model, mix_input, side_info): 117 | 118 | model.eval() 119 | prediction, alphas = model(mix_input, side_info) 120 | return prediction, alphas 121 | 122 | 123 | def predict_with_perfect_attention(model, mix_input, side_info, alphas_in): 124 | 125 | model.eval() 126 | prediction, alphas = model(mix_input, side_info, alphas_in) 127 | return prediction, alphas 128 | 129 | 130 | def eval_source_separation_silent_parts(true_source, predicted_source, window_size, hop_size): 131 | 132 | num_eval_windows = int(np.ceil((len(true_source) - 
abs(hop_size - window_size)) / hop_size)) -1 133 | 134 | list_prediction_energy_at_true_silence = [] 135 | list_true_energy_at_predicted_silence = [] 136 | 137 | for ii in range(num_eval_windows): 138 | 139 | prediction_window = predicted_source[ii * window_size: ii * window_size + window_size] 140 | true_window = true_source[ii * window_size: ii * window_size + window_size] 141 | 142 | # compute predicted energy for silent true source (PESTS) 143 | if sum(abs(true_window)) == 0: 144 | prediction_energy_at_true_silence = 10 * np.log10(sum(prediction_window**2) + 10**(-12)) 145 | list_prediction_energy_at_true_silence.append(prediction_energy_at_true_silence) 146 | else: 147 | # list_prediction_energy_at_true_silence.append(np.nan) 148 | pass 149 | 150 | # compute energy of true source when silence (all zeros) is predicted and true source is not silent// 151 | # True Energy at Wrong Silence Prediction (TEWSP) 152 | if sum(abs(prediction_window)) == 0 and sum(abs(true_window)) != 0: 153 | true_source_energy_at_silent_prediction = 10 * np.log10(sum(true_window**2) + 10**(-12)) 154 | list_true_energy_at_predicted_silence.append(true_source_energy_at_silent_prediction) 155 | else: 156 | # list_true_energy_at_predicted_silence.append(np.nan) 157 | pass 158 | 159 | return np.asarray(list_prediction_energy_at_true_silence), np.asarray(list_true_energy_at_predicted_silence) 160 | --------------------------------------------------------------------------------
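
The three factory functions in utils/build_models.py differ only in which separator class they import; all of them tie the mix-encoding, side-info-encoding, connector and decoder sizes to the mixture feature size. Below is a minimal construction sketch; the feature sizes and layer counts are illustrative placeholders (513 bins would correspond to a 1024-point STFT, and the phoneme vocabulary size is assumed), not values taken from the repository's configs.

# Sketch: build the attention-based separator with placeholder sizes.
from utils import build_models

network = build_models.make_informed_separator_with_attention(
    mix_features_size=513,          # e.g. 1024-point STFT -> 513 frequency bins (illustrative)
    text_feature_size=63,           # assumed one-hot phoneme vocabulary size (placeholder)
    mix_encoder_layers=2,
    side_info_encoder_layers=2,
    target_decoder_layers=2,
    side_info_encoder_bidirectional=True,
)

print(network.__class__.__name__,
      sum(p.numel() for p in network.parameters()), 'parameters')

# The forward convention used throughout utils/fct.py is:
#   prediction, alphas = network(mix_batch, side_info_batch)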
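
The transformations in utils/data_set_utls.py are designed to be chained: MixSNR mixes speech and music at a desired SNR (measured only over the samples where speech is active), StftOnFly replaces the time-domain signals with magnitude spectrograms, NormalizeWithOwnMax rescales each spectrogram by its own maximum, and MakePerfectAttentionWeights turns the phoneme onset times into oracle attention weights. The following self-contained sketch runs the chain on synthetic audio; the signal lengths, phoneme indices and onset times are made up for illustration, with onsets given in samples relative to the start of the speech recording, as the frame computations in the transforms expect.

import numpy as np
from utils.data_set_utls import (MixSNR, StftOnFly, NormalizeWithOwnMax,
                                 MakePerfectAttentionWeights)

sr = 16000                    # sampling rate assumed by the alignment utilities
fft_len, hop_len = 512, 256   # illustrative STFT parameters

# synthetic 4 s music excerpt with a 2 s "speech" burst starting at 1 s
music = 0.1 * np.random.randn(4 * sr)
speech = np.zeros(4 * sr)
speech_start, speech_len = sr, 2 * sr
speech[speech_start:speech_start + speech_len] = 0.1 * np.random.randn(speech_len)

sample = {
    'music': music,
    'speech': speech,
    'speech_start': speech_start,
    'speech_len': speech_len,
    # toy phoneme index sequence (incl. leading/trailing silence tokens) and
    # phoneme onset times in samples relative to speech_start
    'phonemes': np.array([0, 5, 7, 9, 1]),
    'phoneme_times': np.array([0, sr // 2, sr, 3 * sr // 2]),
}

for transform in (MixSNR(desired_snr='random'),
                  StftOnFly(fft_len, hop_len, window='hamming'),
                  NormalizeWithOwnMax(),
                  MakePerfectAttentionWeights(fft_len, hop_len)):
    sample = transform(sample)

print(sample['mix'].shape)             # (time_frames, fft_len // 2 + 1)
print(sample['perfect_alphas'].shape)  # (time_frames, len(phonemes))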
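
collate_with_phonemes allows a PyTorch DataLoader to batch samples whose phoneme sequences (and hence oracle attention matrices) have different lengths: the phoneme sequences and attention weights are zero-padded to the longest sequence in the batch, while the spectrograms, which all share the same frame count, are simply stacked. A toy sketch with random data follows; all sizes are placeholders.

import numpy as np
from utils.data_set_utls import collate_with_phonemes

def toy_sample(num_phonemes, time_frames=100, freq_bins=257, vocab_size=10):
    """Random sample mimicking the keys produced by the data set transforms."""
    one_hot = np.zeros((num_phonemes, vocab_size), dtype=np.float32)
    one_hot[np.arange(num_phonemes), np.random.randint(0, vocab_size, num_phonemes)] = 1.0
    return {
        'phonemes': one_hot,
        'mix': np.random.rand(time_frames, freq_bins).astype(np.float32),
        'target': np.random.rand(time_frames, freq_bins).astype(np.float32),
        'music': np.random.rand(time_frames, freq_bins).astype(np.float32),
        'perfect_alphas': np.random.rand(time_frames, num_phonemes).astype(np.float32),
    }

batch = collate_with_phonemes([toy_sample(8), toy_sample(12)])
print(batch['phonemes'].shape)        # torch.Size([2, 12, 10]): padded to the longest sequence
print(batch['mix'].shape)             # torch.Size([2, 100, 257])
print(batch['perfect_alphas'].shape)  # torch.Size([2, 100, 12])

# With a Dataset that yields such samples, the function is passed to the loader as
# torch.utils.data.DataLoader(dataset, batch_size=..., collate_fn=collate_with_phonemes,
#                             worker_init_fn=worker_init_fn)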
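
viterbi_alignment_from_attention decodes a monotonic phoneme-to-frame alignment from an (M, N) attention matrix: only self-transitions and transitions to the next phoneme position get non-zero probability, the first and last positions (silence tokens) are forced onto the first and last frame, and the frame indices at which the decoded state changes are converted to onset times in seconds assuming 16 kHz audio. A quick sanity check with a block-structured toy attention matrix (the matrix sizes and hop length are arbitrary):

import numpy as np
from utils.fct import viterbi_alignment_from_attention

hop_len = 256   # illustrative STFT hop size; the function assumes a 16 kHz sampling rate

# toy attention: 5 phoneme positions (rows) over 50 time frames (columns), roughly monotonic
M, N = 5, 50
attention_weights = np.full((M, N), 1e-3)
for m in range(M):
    attention_weights[m, m * (N // M):(m + 1) * (N // M)] = 1.0

state_sequence, phoneme_onsets = viterbi_alignment_from_attention(attention_weights, hop_len)

print(state_sequence.shape)   # (1, 50): one phoneme position index per time frame
print(phoneme_onsets)         # predicted onset times in seconds (onset of the final silence token removed)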