├── README.md
└── src
    ├── config
    │   ├── jnf_config.yaml
    │   └── ssf_config.yaml
    ├── data
    │   ├── data_gen_fixed_pos.py
    │   ├── data_gen_var_pos.py
    │   ├── datamodule.py
    │   └── dataset.py
    ├── models
    │   ├── exp_enhancement.py
    │   ├── exp_jnf.py
    │   ├── exp_ssf.py
    │   └── models.py
    ├── scripts
    │   ├── train_jnf.py
    │   └── train_ssf.py
    └── utils
        └── log_images.py

/README.md:
--------------------------------------------------------------------------------
1 | # Deep Non-linear Filters for Multi-channel Speech Enhancement and Separation
2 | 
3 | This repository contains code for the papers
4 | 
5 | [1] Kristina Tesch, Nils-Hendrik Mohrmann, and Timo Gerkmann, "On the Role of Spatial, Spectral, and Temporal Processing for DNN-based Non-linear Multi-channel Speech Enhancement", Proceedings of Interspeech, pp. 2908-2912, 2022, [[arxiv]](https://arxiv.org/abs/2206.11181), [[audio examples]](https://www.inf.uni-hamburg.de/en/inst/ab/sp/publications/interspeech2022-deepmcfilter.html)
6 | 
7 | [2] Kristina Tesch and Timo Gerkmann, "Insights into Deep Non-linear Filters for Improved Multi-channel Speech Enhancement", IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 563-575, 2023, [[audio examples]](https://www.inf.uni-hamburg.de/en/inst/ab/sp/publications/tasl2022-deepmcfilter.html)
8 | 
9 | [3] Kristina Tesch and Timo Gerkmann, "Spatially Selective Deep Non-linear Filters for Speaker Extraction", accepted for ICASSP 2023, [[audio examples]](https://www.inf.uni-hamburg.de/en/inst/ab/sp/publications/icassp2023-spatiallyselective)
10 | 
11 | [4] Kristina Tesch and Timo Gerkmann, "Multi-channel Speech Separation Using Spatially Selective Deep Non-linear Filters", IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 32, pp. 542-553, 2024, [[audio examples]](https://www.inf.uni-hamburg.de/en/inst/ab/sp/publications/tasl2023-ssf-vs-ds.html)
12 | 
13 | Take a look at a video of our real-time multi-channel enhancement demo: [http://uhh.de/inf-sp-jnf-demo](http://uhh.de/inf-sp-jnf-demo)
14 | 
15 | ## Train JNF with a fixed look direction
16 | 
17 | 1. Prepare a dataset by running ```data_gen_fixed_pos.py```.
18 | 2. Prepare a config file. Examples can be found in the config folder.
19 | 3. Run the training script in the scripts folder (replace the path to your config file).
20 | 
21 | ## Train steerable JNF-SSF
22 | 
23 | 1. Prepare a dataset by running ```data_gen_var_pos.py```.
24 | 2. Prepare a config file. Examples can be found in the config folder.
25 | 3. Run the training script in the scripts folder (replace the path to your config file); the sketch below shows how the pieces fit together.
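
The two training scripts follow the same pattern: load one of the YAML files above, build the data module, the network, and the Lightning experiment from its sections, and hand everything to a PyTorch Lightning `Trainer`. The snippet below is a minimal sketch of that flow for the fixed-direction JNF model, meant for orientation only; `scripts/train_jnf.py` remains the reference entry point. It assumes the code is run from the `src` directory and that `FTJNF` accepts the keys of the `network` config section as keyword arguments (check `models.py` for the exact signature).

```python
import yaml
import pytorch_lightning as pl

from data.datamodule import HDF5DataModule
from models.models import FTJNF
from models.exp_jnf import JNFExp

# Load the training configuration (adapt the path to your own config file).
with open("config/jnf_config.yaml") as f:
    cfg = yaml.safe_load(f)

pl.seed_everything(cfg["seed"])

# The keys of the 'data' section match the HDF5DataModule constructor.
datamodule = HDF5DataModule(**cfg["data"])

# Assumption: FTJNF takes the 'network' section as keyword arguments.
model = FTJNF(**cfg["network"])

# The STFT parameters are taken from the 'data' section of the config.
experiment = JNFExp(model=model,
                    stft_length=cfg["data"]["stft_length_samples"],
                    stft_shift=cfg["data"]["stft_shift_samples"],
                    **cfg["experiment"])

trainer = pl.Trainer(default_root_dir=cfg["logging"]["ckpt_dir"],
                     **cfg["training"])
trainer.fit(experiment, datamodule=datamodule)
```

Training the steerable JNF-SSF model works analogously with `ssf_config.yaml`, `data_gen_var_pos.py`, and `train_ssf.py`. Note that the conditioning settings in `ssf_config.yaml` must stay consistent: `cond_arange_params: [-180, 180, 2]` enumerates `len(np.arange(-180, 180, 2)) = 180` candidate look directions, which is why `n_cond_emb_dim` is set to 180 (360 degrees divided by the 2-degree angle resolution).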
-------------------------------------------------------------------------------- /src/config/jnf_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | seed: 123 3 | 4 | data: 5 | n_channels: 3 6 | batch_size: 2 7 | prep_files: { 8 | data: ./data/prep/prep_mix_ch3_sp5_small.hdf5, 9 | meta: ./data/prep/prep_mix_meta_ch3_sp5_small.json, 10 | } 11 | dry_target: True 12 | snr_range: [] 13 | meta_frame_length: 48000 14 | stft_length_samples: 512 15 | stft_shift_samples: 256 16 | n_workers: 10 17 | 18 | network: 19 | n_channels: 3 20 | n_lstm_hidden1: 256 21 | n_lstm_hidden2: 128 22 | bidirectional: True 23 | freq_first: True 24 | output_type: 'CRM' 25 | output_activation: 'tanh' 26 | append_freq_idx: False 27 | permute_freqs: False 28 | dropout: 0 29 | 30 | experiment: 31 | learning_rate: 0.001 32 | weight_decay: 0 33 | loss_alpha: 10 34 | cirm_comp_K: 1 35 | cirm_comp_C: 1 36 | reference_channel: 0 37 | 38 | training: 39 | max_epochs: 250 40 | gradient_clip_val: 0 41 | gradient_clip_algorithm: value 42 | strategy: ddp 43 | accelerator: gpu 44 | devices: 1 45 | 46 | 47 | logging: 48 | tb_log_dir: "../logs/tb_logs" 49 | ckpt_dir: "../logs/ckpts" -------------------------------------------------------------------------------- /src/config/ssf_config.yaml: -------------------------------------------------------------------------------- 1 | 2 | seed: 123 3 | 4 | data: 5 | n_channels: 3 6 | batch_size: 6 7 | prep_files: { 8 | data: ./data/prep/conditional/prep_mix_ch3_sp5_var_target.hdf5, 9 | meta: ./data/prep/conditional/prep_mix_meta_ch3_sp5_var_target.json, 10 | } 11 | dry_target: True 12 | snr_range: [] 13 | meta_frame_length: 48000 14 | stft_length_samples: 512 15 | stft_shift_samples: 256 16 | n_workers: 10 17 | 18 | network: 19 | n_channels: 3 20 | n_lstm_hidden1: 256 21 | n_lstm_hidden2: 128 22 | n_cond_emb_dim: 180 # 360 / angle_resolution 23 | bidirectional: True 24 | causal: False 25 | output_type: 'CRM' 26 | output_activation: 'tanh' 27 | condition_nb_only: False 28 | condition_wb_only: True 29 | 30 | experiment: 31 | learning_rate: 0.001 32 | weight_decay: 0 33 | loss_alpha: 10 34 | cirm_comp_K: 1 35 | cirm_comp_C: 1 36 | reference_channel: 0 37 | n_cond_emb_dim: 180 # same as in network settings 38 | condition_enc_type: arange 39 | cond_arange_params: [-180, 180, 2] # has to match the embedding dim 40 | scheduler_type: 'MultiStepLR' 41 | scheduler_params: { 42 | milestones: [50,100,150,200,250,300,350,400], 43 | gamma: 0.75 44 | } 45 | loss_type: 'l1' 46 | 47 | training: 48 | max_epochs: 500 49 | gradient_clip_val: 0 50 | gradient_clip_algorithm: value 51 | strategy: ddp 52 | accelerator: gpu 53 | devices: 1 54 | 55 | 56 | logging: 57 | tb_log_dir: "../logs/tb_logs" 58 | ckpt_dir: "../logs/ckpts" -------------------------------------------------------------------------------- /src/data/data_gen_fixed_pos.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List 3 | import soundfile as sf 4 | from scipy.io import wavfile 5 | import matplotlib.pyplot as plt 6 | import pyroomacoustics as pra 7 | import glob 8 | import numpy as np 9 | import h5py 10 | import os 11 | import random 12 | random.seed(12345) 13 | """ 14 | Dataset generation for a speaker in a fixed location relative to the microphone array. 
This preprocessing script creates a HDF5 file with three datasets: 15 | - train 16 | - val 17 | - test 18 | 19 | Each dataset has the shape [NUM_SAMPLES, 3, CHANNELS, MAX_SAMPLES_PER_FILE]. In the second axis, we store in this order the 20 | - spatialized target signal (includes reverb) 21 | - the spatialized noise signal (sum of all interfering speakers) 22 | - the dry target signal (including the time-shift caused by the direct path) 23 | 24 | 25 | The code in this file was partially written by Nils Mohrmann. 26 | """ 27 | 28 | # WSJ0 dataset path 29 | WSJ0_PATH = "/path/to/wsj0/CSR-1-WSJ-0/WAV/wsj0" 30 | # Path where to save the simulated data 31 | SIM_DATA_PATH = "./prep/" 32 | 33 | 34 | class RoomSimulation: 35 | 36 | def __init__(self, channels): 37 | self.channels = channels 38 | 39 | def set_room_properties(self, rt: float, room_dim: np.ndarray): 40 | """ 41 | Recreate room with a new reverberation time, this deletes all sources and mics. 42 | :param rt: reverberation time 43 | :param room_dim: room dimension ([x,y,z]) 44 | :return: None 45 | """ 46 | self.rt60_tgt = rt 47 | if self.rt60_tgt > 0: 48 | e_absorption, max_order = pra.inverse_sabine( 49 | self.rt60_tgt, room_dim) 50 | 51 | self.room = pra.ShoeBox(room_dim, 52 | fs=16000, 53 | materials=pra.Material(e_absorption), 54 | max_order=max_order) 55 | else: 56 | e_absorption, max_order = pra.inverse_sabine(0.5, room_dim) 57 | self.room = pra.ShoeBox(room_dim, 58 | fs=16000, 59 | materials=pra.Material(e_absorption), 60 | max_order=0) 61 | 62 | def set_microphones(self, x: float, y: float, z: float, phi: float): 63 | """ 64 | Add microphone array at position xyz with rotation phi 65 | Radius: 0.05 m 66 | :param x: x pos 67 | :param y: y pos 68 | :param z: z pos 69 | :param phi: The counterclockwise rotation of the first element in the array (from the x-axis) 70 | :return: 71 | """ 72 | if self.channels == 2: 73 | # special orientation for 2 mics. 
-> speaker at broadside 74 | phi += np.pi / 2 75 | R = pra.beamforming.circular_2D_array( 76 | [x, y], self.channels, phi0=phi, radius=0.05) 77 | R = np.vstack((R, [[z] * self.channels])) 78 | self.room.add_microphone_array(pra.Beamformer(R, self.room.fs)) 79 | 80 | def add_source(self, position: np.ndarray, signal: np.ndarray, delay: float): 81 | """ 82 | Add signal source in room with a delay 83 | :param position: position [x, y, z] 84 | :param signal: The signal played by the source 85 | :param delay: A time delay until the source signal starts in the simulation 86 | :return: None 87 | """ 88 | self.room.add_source(position, signal, delay) 89 | 90 | def measure_time(self): 91 | """ 92 | Get measured RT60 93 | :return: rt60 in seconds 94 | """ 95 | self.room.compute_rir() 96 | return self.room.measure_rt60() 97 | 98 | 99 | class SPRoomSimulator: 100 | """ 101 | Generate dataset for the training of the nonlinear training 102 | """ 103 | 104 | def __init__(self, channels=3, seed=13, mode="train"): 105 | self.training = True 106 | if mode == "train": 107 | path = "si_tr_s" 108 | elif mode == "val": 109 | path = "si_dt_20" 110 | elif mode == "test": 111 | path = "si_et_05" 112 | self.training = False 113 | path = f'{WSJ0_PATH}/{path}' 114 | 115 | self.speaker = glob.glob(path + "/*") 116 | 117 | self.rng = np.random.default_rng(seed) 118 | self.rng.shuffle(self.speaker) 119 | self.channels = channels 120 | self.exp_room = RoomSimulation(channels=self.channels) 121 | self.dry_room = RoomSimulation(channels=self.channels) 122 | self.fs = 16000 123 | 124 | def create_sample(self, 125 | speaker_list: List[str], 126 | seed2: int, 127 | reverb: bool = True, 128 | target_angle: float = 0, 129 | rt60_min: float = 0.2, 130 | rt60_max: float = 1, 131 | snr_min: int = -10, 132 | snr_max: int = 5, 133 | side_room: int = 20): 134 | """ 135 | Create for a list of speech signals (first one is the target signal) the spatial image using a randomly placed 136 | microphone array and distributing the interfering speakers (len(speaker_list)-1) uniformly around the array. 
137 | 138 | :param n_interfering: number of interfering speaker 139 | :param seed2: Seed for the random audio files and positions 140 | :target_angle: place source at given fixed angle (given in degree) 141 | :reverb: turn off reverberation if set to False 142 | :rt60_min: minimum T60 reverb time 143 | :rt60_max: maximum T60 reverb time 144 | :side_room: angle of closest interfering source (default: 45 deg) 145 | :return: the audio signals as numpy array [N_SPEAKERS, N_CHANNELS, N_SAMPLES] and corresponding meta data 146 | """ 147 | # set seed for this sample 148 | self.rng = np.random.default_rng(seed2) 149 | meta = {} 150 | 151 | signal = [] 152 | for file in speaker_list: 153 | audio, fs = sf.read(file) 154 | signal.append(audio / np.max(np.abs(audio)) * 0.3) 155 | 156 | # ensure noise signal is long enough and does not start with zeros always 157 | offset_indices = np.random.randint( 158 | low=-8000, high=8000, size=len(speaker_list)-1) 159 | target_signal_len = len(signal[0]) 160 | for i in range(len(speaker_list)-1): 161 | new_signal = np.roll( 162 | np.resize(signal[1+i], target_signal_len), shift=offset_indices[i]) 163 | signal[1+i] = new_signal 164 | 165 | # room properties 166 | RT = self.rng.uniform(rt60_min, rt60_max) if reverb else 0 167 | meta["rt"] = RT 168 | room_dim = np.squeeze(np.array([np.random.uniform( 169 | 2.5, 5, 1), np.random.uniform(3, 9, 1), np.random.uniform(2.2, 3.5, 1)])) 170 | meta["room_dim"] = [room_dim[0], room_dim[1], room_dim[2]] 171 | self.exp_room.set_room_properties(RT, np.array(room_dim)) 172 | self.dry_room.set_room_properties(0, np.array(room_dim)) 173 | 174 | 175 | # random mic position in room (min 1 m to wall) 176 | mic_pos = self.rng.random(3) * (room_dim - 2.02) + 1.01 177 | mic_pos[2] = 1.5 178 | phi = self.rng.random() * 2 * np.pi # microphone rotation 179 | self.exp_room.set_microphones(mic_pos[0], mic_pos[1], mic_pos[2], phi) 180 | self.dry_room.set_microphones(mic_pos[0], mic_pos[1], mic_pos[2], phi) 181 | 182 | meta["mic_pos"] = mic_pos.tolist() 183 | meta["mic_phi"] = phi 184 | 185 | # target speaker 186 | target_phi = phi + target_angle / 360 * 2 * np.pi 187 | main_source = mic_pos + \ 188 | normal_vec(target_phi) * (self.rng.random() * 0.7 + 0.3) 189 | main_source[2] = self.rng.normal(1.60, 0.08) # height of speaker 190 | 191 | self.exp_room.add_source(main_source, signal[0], 0) 192 | self.dry_room.add_source(main_source, signal[0], 0) 193 | 194 | meta["target_file"] = speaker_list[0].split( 195 | "wsj0")[-1].replace("\\", "/") 196 | meta["n_samples"] = len(signal[0]) 197 | meta["target_pos"] = main_source.tolist() 198 | meta["target_angle"] = target_angle 199 | 200 | # interering speakers 201 | n_interfering = len(speaker_list) - 1 202 | for interf_idx, interf_path in enumerate(speaker_list[1:]): 203 | for moveback in np.arange(0, 8, 0.25): 204 | # if pos outside from room, move back to the microphone 205 | # distance max 7 m, min 1 m 206 | side_room_rad = 2*np.pi/360*side_room 207 | speaker_range = (2*np.pi-2*side_room_rad)/n_interfering 208 | 209 | interf_source = mic_pos + normal_vec( 210 | target_phi + side_room_rad + speaker_range * self.rng.random() + interf_idx * speaker_range) \ 211 | * max(1, self.rng.random() * 7 - moveback) 212 | 213 | # height of speaker is round about the height of standing people 214 | interf_source[2] = self.rng.normal(1.60, 0.08) 215 | if self.exp_room.room.is_inside(interf_source) and np.all(interf_source >= 0): 216 | # if inside room, no need to move further to the mic 217 | break 218 | 219 | 
self.exp_room.add_source(interf_source, signal[interf_idx + 1], 0) 220 | meta[f"interf{interf_idx}_file"] = interf_path.split( 221 | "wsj0")[-1].replace("\\", "/") 222 | meta[f"interf{interf_idx}_pos"] = interf_source.tolist() 223 | 224 | # return_premix allows separation of speaker signals 225 | mic_signals = self.exp_room.room.simulate(return_premix=True) 226 | dry_target_signal = self.dry_room.room.simulate(return_premix=True) 227 | 228 | reverb_target_signal = mic_signals[0, ...] 229 | noise_signal = np.sum(mic_signals[1:, ...], axis=0) 230 | dry_target_signal = dry_target_signal[0, ...] 231 | 232 | # scale to SNR 233 | if not snr_min is np.nan: 234 | target_snr = self.rng.uniform(snr_min, snr_max) 235 | noise_factor = snr_scale_factor( 236 | reverb_target_signal, noise_signal, target_snr) 237 | noise_signal = noise_signal * noise_factor 238 | 239 | meta["snr"] = target_snr 240 | 241 | return reverb_target_signal, noise_signal, dry_target_signal, meta 242 | 243 | def get_room(self): 244 | return self.exp_room.room 245 | 246 | def plot(self): 247 | self.exp_room.plot() 248 | 249 | 250 | def normal_vec(phi): 251 | return np.array([np.cos(phi), np.sin(phi), 0]) 252 | 253 | 254 | def snr_scale_factor(speech: np.ndarray, noise: np.ndarray, snr: int): 255 | """ 256 | Compute the scale factor that has to be applied to a noise signal in order for the noisy (sum of noise and clean) 257 | to have the specified SNR. 258 | 259 | :param speech: the clean speech signal [..., SAMPLES] 260 | :param noise: the noise signal [..., SAMPLES] 261 | :param snr: the SNR of the mixture 262 | :return: the scaling factor 263 | """ 264 | 265 | noise_var = np.mean(np.var(noise, axis=-1)) 266 | speech_var = np.mean(np.var(speech, axis=-1)) 267 | 268 | factor = np.sqrt( 269 | speech_var / np.maximum((noise_var * 10. ** (snr / 10.)), 10**(-6))) 270 | 271 | return factor 272 | 273 | 274 | def prep_speaker_mix_data(store_dir: str, 275 | post_fix: str = None, 276 | wsj0_path: str = 'whatever', 277 | n_channels: int = 3, 278 | n_interfering_speakers: int = 3, 279 | target_fs: int = 16000, 280 | num_files: dict = {'train': -1, 281 | 'val': -1, 282 | 'test': -1}, 283 | reverb: bool = True, 284 | target_angle: float = 0, 285 | side_room: int = 20, 286 | rt60_min=0.2, 287 | rt60_max=0.8, 288 | snr_min=-10, 289 | snr_max=5 290 | ): 291 | """ 292 | Preparation of speaker mix dataset. The target speaker is placed in a fixed position relative to the microphone array. The interfering speakers are placed randomly with one speaker per angle segment. 
293 | 294 | :param store_dir: path to directory in which to store the dataset 295 | :param post_fix: postfix to specify the characteristics of the dataset 296 | :param wsj0_path: path the the raw WSJ0 data 297 | :param n_channels: number of channels in the microphone array 298 | :param n_interfering_speakers: the number of interfering speakers 299 | :param target_fs: the target sampling rate for the dataset 300 | :param num_files: a dictionary specifying the number of examples per stage 301 | :param reverb: turn off reverberation if set to False 302 | :param rt60_min: min RT60 time (uniformly sampled if reverb) 303 | :param rt60_max: max RT60 time (uniformly sampled if reverb) 304 | :param snr_min: min SNR (uniformly sampled) 305 | :param snr_max: max SNR (uniformely sampled) 306 | :param side_room: angle of closest interfering source (default: 20 deg) 307 | :return: 308 | """ 309 | prep_store_name = f"prep_mix{'_' + post_fix if post_fix else ''}.hdf5" 310 | 311 | train_samples = list( 312 | sorted(glob.glob(os.path.join(wsj0_path, 'si_tr_s/*/*.wav')))) 313 | val_samples = list( 314 | sorted(glob.glob(os.path.join(wsj0_path, 'si_dt_20/*/*.wav')))) 315 | test_samples = list( 316 | sorted(glob.glob(os.path.join(wsj0_path, 'si_et_05/*/*.wav')))) 317 | 318 | meta = {} 319 | with h5py.File(os.path.join(store_dir, prep_store_name), 'w') as prep_storage: 320 | for data_set, samples in (('train', train_samples), 321 | ('val', val_samples), 322 | ('test', test_samples)): 323 | if num_files[data_set] == 0: 324 | continue 325 | 326 | n_dataset_samples = num_files[data_set] if num_files[data_set] > 0 else len(samples) 327 | random.shuffle(samples) # pick random speakers 328 | 329 | MAX_SAMPLES_PER_FILE = 12 * target_fs 330 | audio_dataset = prep_storage.create_dataset(data_set, 331 | shape=( 332 | n_dataset_samples, 3, n_channels, MAX_SAMPLES_PER_FILE), 333 | chunks=( 334 | 1, 3, n_channels, MAX_SAMPLES_PER_FILE), 335 | dtype=np.float32, 336 | compression="gzip", 337 | shuffle=True) 338 | 339 | set_meta = {} 340 | 341 | sproom = SPRoomSimulator(channels=n_channels, mode=data_set) 342 | 343 | for target_idx, target_path in enumerate(samples[:n_dataset_samples]): 344 | 345 | # select interfering speakers 346 | interfering_speakers = random.choices( 347 | samples[:n_dataset_samples], k=n_interfering_speakers) 348 | 349 | 350 | reverb_target_signal, noise_signal, dry_target_signal, sample_meta = sproom.create_sample( 351 | speaker_list=[target_path] + interfering_speakers, 352 | seed2=target_idx, 353 | reverb=reverb, 354 | target_angle=target_angle, 355 | side_room=side_room, 356 | rt60_min=rt60_min, 357 | rt60_max=rt60_max, 358 | snr_min=snr_min, 359 | snr_max=snr_max) 360 | n_audio_samples = min( 361 | sample_meta['n_samples'], MAX_SAMPLES_PER_FILE) 362 | sample_meta['n_samples'] = n_audio_samples 363 | 364 | # store reverb clean 365 | audio_dataset[target_idx, 0, :, 366 | :n_audio_samples] = reverb_target_signal[:, :n_audio_samples] 367 | audio_dataset[target_idx, 0, 368 | :, n_audio_samples:MAX_SAMPLES_PER_FILE] = 0 369 | 370 | # store noise 371 | audio_dataset[target_idx, 1, :, 372 | :n_audio_samples] = noise_signal[:, :n_audio_samples] 373 | audio_dataset[target_idx, 1, 374 | :, n_audio_samples:MAX_SAMPLES_PER_FILE] = 0 375 | 376 | set_meta[target_idx] = sample_meta 377 | 378 | # store dry clean 379 | audio_dataset[target_idx, 2, :, 380 | :n_audio_samples] = dry_target_signal[:, :n_audio_samples] 381 | audio_dataset[target_idx, 2, 382 | :, n_audio_samples:MAX_SAMPLES_PER_FILE] = 0 383 | 384 | if 
target_idx % 10 == 0: 385 | print( 386 | f'{data_set}: {target_idx} of {n_dataset_samples}') 387 | 388 | meta[data_set] = set_meta 389 | 390 | with open(os.path.join(store_dir, f"prep_mix_meta{'_' + post_fix if post_fix else ''}.json"), 391 | 'w') as prep_meta_storage: 392 | json.dump(meta, prep_meta_storage, indent=4) 393 | 394 | 395 | if __name__ == '__main__': 396 | prep_speaker_mix_data(SIM_DATA_PATH, 397 | 'ch3_sp5_small', 398 | WSJ0_PATH, 399 | n_interfering_speakers=5, 400 | n_channels=3, 401 | num_files={'train': 6000, 'val': 1000, 'test': 600}, 402 | reverb=True, 403 | target_angle=0, 404 | rt60_min=0.2, 405 | rt60_max=0.5, 406 | snr_min=np.nan, 407 | snr_max=5, 408 | side_room=15) -------------------------------------------------------------------------------- /src/data/data_gen_var_pos.py: -------------------------------------------------------------------------------- 1 | import soundfile as sf 2 | from scipy.io import wavfile 3 | import matplotlib.pyplot as plt 4 | import pyroomacoustics as pra 5 | import glob 6 | import numpy as np 7 | import h5py 8 | import os 9 | import random 10 | random.seed(12345) 11 | from typing import List 12 | import json 13 | """ 14 | Dataset generation for a target speaker in a variable location. The target speaker angle is relative to the microphone array orientation. This preprocessing script creates a HDF5 file with three datasets: 15 | - train 16 | - val 17 | - test 18 | 19 | Each dataset has the shape [NUM_SAMPLES, 3, CHANNELS, MAX_SAMPLES_PER_FILE]. In the second axis, we store in this order the 20 | - spatialized target signal (includes reverb) 21 | - the spatialized noise signal (sum of all interfering speakers) 22 | - the dry target signal (including the time-shift caused by the direct path) 23 | 24 | 25 | The code in this file was partially written by Nils Mohrmann. 26 | """ 27 | 28 | # WSJ0 dataset path 29 | WSJ0_PATH = "/informatik2/sp/intern/databases_DEPRECATED/Good/WSJ/CSR-1-WSJ-0/WAV/wsj0" 30 | # Path where to save the simulated data 31 | SIM_DATA_PATH = "prep/conditional/" 32 | 33 | 34 | class RoomSimulation: 35 | 36 | def __init__(self, channels): 37 | """ 38 | Room simulation of sp acoustic lab with 3 dimensions 39 | 40 | :param channels: number of microphones in a uniform circular array 41 | """ 42 | self.channels = channels 43 | self.room_dim = np.array([9.3, 5.04, 2.84]) 44 | self.rt60_tgt = 0.3 # only a init value 45 | e_absorption, max_order = pra.inverse_sabine(self.rt60_tgt, self.room_dim) 46 | 47 | self.room = pra.ShoeBox(self.room_dim, 48 | fs=16000, 49 | materials=pra.Material(e_absorption), 50 | max_order=max_order) 51 | 52 | def set_room_properties(self, rt: float, room_dim: np.ndarray): 53 | """ 54 | Recreate room with a new reverberation time, this deletes all sources and mics. 
55 | :param rt: reverberation time 56 | :param room_dim: room dimension ([x,y,z]) 57 | :return: None 58 | """ 59 | self.rt60_tgt = rt 60 | if self.rt60_tgt > 0: 61 | e_absorption, max_order = pra.inverse_sabine(self.rt60_tgt, room_dim) 62 | 63 | self.room = pra.ShoeBox(room_dim, 64 | fs=16000, 65 | materials=pra.Material(e_absorption), 66 | max_order=max_order) 67 | else: 68 | e_absorption, max_order = pra.inverse_sabine(0.5, room_dim) 69 | self.room = pra.ShoeBox(room_dim, 70 | fs=16000, 71 | materials=pra.Material(e_absorption), 72 | max_order=0) 73 | 74 | def reset(self): 75 | """ 76 | Reset the room, delete all sources and microphones 77 | Keeps RT60 78 | :return: None 79 | """ 80 | self.set_room_properties(self.rt60_tgt, self.room_dim) 81 | 82 | def set_microphones(self, x: float, y: float, z: float, phi: float, mic_offset: np.ndarray): 83 | """ 84 | Add microphone array at position xyz with rotation phi 85 | Radius: 0.05 m 86 | :param x: x pos 87 | :param y: y pos 88 | :param z: z pos 89 | :param phi: The counterclockwise rotation of the first element in the array (from the x-axis) 90 | :return: 91 | """ 92 | if self.channels == 2: 93 | # special orientation for 2 mics. -> speaker at broadside 94 | phi += np.pi / 2 95 | R = pra.beamforming.circular_2D_array([x, y], self.channels, phi0=phi, radius=0.05) 96 | R = np.vstack((R, [[z] * self.channels])) 97 | R += mic_offset 98 | self.room.add_microphone_array(pra.Beamformer(R, self.room.fs)) 99 | 100 | def add_source(self, position: np.ndarray, signal: np.ndarray, delay: float): 101 | """ 102 | Add signal source in room with a delay 103 | :param position: position [x, y, z] 104 | :param signal: The signal played by the source 105 | :param delay: A time delay until the source signal starts in the simulation 106 | :return: None 107 | """ 108 | self.room.add_source(position, signal, delay) 109 | 110 | def plot(self): 111 | """ 112 | Plot the room (see pyroomacoustic examples) 113 | :return: None 114 | """ 115 | fig, ax = self.room.plot() 116 | ax.set_xlim([0, 10]) 117 | ax.set_ylim([0, 6]) 118 | ax.set_zlim([0, 3]) 119 | ax.view_init(elev=90, azim=0) 120 | 121 | ## Now compute the delay and sum weights for the beamformer 122 | # room.mic_array.rake_delay_and_sum_weights(room.sources[0][:1]) 123 | ## plot the room and resulting beamformer 124 | # room.plot(freq=[1000, 2000, 4000, 8000], img_order=0) 125 | plt.show() 126 | 127 | # self.room.compute_rir() 128 | # room.plot_rir() 129 | # plt.plot(self.room.rir[1][0]) 130 | # plt.show() 131 | 132 | def measuretime(self): 133 | """ 134 | Get measured RT60 135 | :return: rt60 in seconds 136 | """ 137 | self.room.compute_rir() 138 | return self.room.measure_rt60() 139 | 140 | class SPRoomSimulator: 141 | """ 142 | Generate dataset for the training of the steerable non-linear filter. 
143 | """ 144 | 145 | def __init__(self, channels=3, seed=13, mode="train"): 146 | self.training = True 147 | if mode == "train": 148 | path = "si_tr_s" 149 | elif mode == "val": 150 | path = "si_dt_20" 151 | elif mode == "test": 152 | path = "si_et_05" 153 | self.training = False 154 | path = f'{WSJ0_PATH}/{path}' 155 | 156 | self.speaker = glob.glob(path + "/*") 157 | 158 | self.rng = np.random.default_rng(seed) 159 | self.rng.shuffle(self.speaker) 160 | self.channels = channels 161 | self.exp_room = RoomSimulation(channels=self.channels) 162 | self.dry_room = RoomSimulation(channels=self.channels) 163 | self.fs = 16000 164 | 165 | def create_sample(self, 166 | speaker_list: List[str], 167 | seed2: int, 168 | target_angle: float = 0, 169 | reverb: bool = True, 170 | rt60_min:float=0.2, 171 | rt60_max:float=1, 172 | snr_min: int =-10, 173 | snr_max: int = 5, 174 | min_dist=0.8, 175 | max_dist=1.2, 176 | mic_pert_std: float = 0, 177 | min_angle_dist: int = 10, ): 178 | """ 179 | Create for a list of speech signals (first one is the target signal) the spatial image using a randomly placed 180 | microphone array and distributing the interfering speakers (len(speaker_list)-1) uniformly around the array. 181 | 182 | :param speaker_list: List of paths to speaker utterances 183 | :param seed2: Seed for the random audio files and positions 184 | :param target_angle: The DOA of the target speaker in degree. 185 | :param reverb: Create reverberant signals 186 | :param rt60_min, rt60_mx: The RT60 is sampled uniformly from the range (rt60_min, rt60_max) 187 | :param snr: The SNR is sampled uniformly from the range (snr_min, snr_max). The noise signal is rescaled to match the chosen SNR. If snr_min is None, no rescaling is performed. 188 | :param min_dist, max_dist: The range (min_dist, max_dist) from which the sources (also interfering sources) are sampled uniformly. Unit is meters. 189 | :param mic_pert_std: Add noise to the microphone positions sampled from a Gaussian with zero mean and specified standard deviation. Unit is cm. 
190 | :param min_angle_dist: Minimum angle distance between two sources (target-interfering and interfering-interfering) 191 | 192 | :return: the audio signals as numpy array [N_SPEAKERS, N_CHANNELS, N_SAMPLES] and corresponding meta data 193 | """ 194 | # set seed for this sample 195 | self.rng = np.random.default_rng(seed2) 196 | meta = {} 197 | 198 | signal = [] 199 | for file in speaker_list: 200 | audio, fs = sf.read(file) 201 | signal.append(audio / np.max(np.abs(audio)) * 0.3) 202 | 203 | # ensure noise signal is long enough and does not start with zeros always 204 | offset_indices = np.random.randint(low=-8000, high=8000, size=len(speaker_list)-1) 205 | target_signal_len = len(signal[0]) 206 | for i in range(len(speaker_list)-1): 207 | new_signal = np.roll(np.resize(signal[1+i], target_signal_len), shift=offset_indices[i]) 208 | signal[1+i] = new_signal 209 | 210 | # room configuration 211 | RT = self.rng.uniform(rt60_min, rt60_max) if reverb else 0 212 | meta["rt"] = RT 213 | 214 | room_dim = np.squeeze(np.array([self.rng.uniform(2.5,5,1), self.rng.uniform(3,9,1), self.rng.uniform(2.2, 3.5, 1)])) 215 | meta["room_dim"] = [room_dim[0], room_dim[1], room_dim[2]] 216 | 217 | self.exp_room.set_room_properties(RT, np.array(room_dim)) 218 | self.dry_room.set_room_properties(0, np.array(room_dim)) 219 | 220 | # mic array at random position in room (min 1.2 m to wall) 221 | mic_pos = self.rng.random(3) * (room_dim - 2.42) + 1.21 222 | mic_pos[2] = 1.5 223 | 224 | if mic_pert_std > 0: 225 | mic_offset = self.rng.normal(loc=0, scale=mic_pert_std, size=(3, self.channels)) 226 | else: 227 | mic_offset = np.zeros((3, self.channels)) 228 | 229 | phi = self.rng.random() * 2 * np.pi # microphone rotation 230 | self.exp_room.set_microphones(mic_pos[0], mic_pos[1], mic_pos[2], phi, mic_offset) 231 | self.dry_room.set_microphones(mic_pos[0], mic_pos[1], mic_pos[2], phi, mic_offset) 232 | 233 | meta["mic_pos"] = mic_pos.tolist() 234 | meta["mic_phi"] = phi 235 | 236 | # target speaker 237 | target_phi = phi + target_angle/ 360 * 2 * np.pi 238 | speaker_phis = [target_phi] 239 | main_source = mic_pos + normal_vec(target_phi) * ((self.rng.random() * (max_dist-min_dist) + min_dist)) 240 | main_source[2] = self.rng.normal(1.60, 0.08) # height of speaker 241 | 242 | self.exp_room.add_source(main_source, signal[0], 0) 243 | self.dry_room.add_source(main_source, signal[0], 0) 244 | 245 | meta["target_file"] = speaker_list[0].split("wsj0")[-1].replace("\\", "/") 246 | meta["n_samples"] = len(signal[0]) 247 | meta["target_pos"] = main_source.tolist() 248 | meta["target_angle"] = target_angle 249 | n_interfering = len(speaker_list) - 1 250 | for interf_idx, interf_path in enumerate(speaker_list[1:]): 251 | 252 | # distance max 1.2 m, min 0.8 m 253 | min_angle_dist_rad = 2*np.pi/360*min_angle_dist 254 | speaker_range = (2*np.pi-2*min_angle_dist_rad)/n_interfering 255 | 256 | too_close = True 257 | while too_close: # make sure the selected angle is not too close to other source 258 | speaker_phi = target_phi + min_angle_dist_rad + speaker_range * self.rng.random() + interf_idx * speaker_range 259 | interf_source = mic_pos + normal_vec(speaker_phi) * (self.rng.random() * (max_dist-min_dist) + min_dist) 260 | 261 | # height of speaker is round about the height of standing people 262 | interf_source[2] = self.rng.normal(1.60, 0.08) 263 | 264 | if len(speaker_phis) == 0: 265 | too_close = False 266 | speaker_phis.append(speaker_phi) 267 | else: 268 | if speaker_phi - speaker_phis[-1] < np.deg2rad(min_angle_dist) or 
(speaker_phis[0] + 2*np.pi)- speaker_phi < np.deg2rad(min_angle_dist): 269 | # previous speaker or first speaker too close 270 | too_close = True 271 | else: 272 | too_close = False 273 | speaker_phis.append(speaker_phi) 274 | 275 | self.exp_room.add_source(interf_source, signal[interf_idx + 1], 0) 276 | meta[f"interf{interf_idx}_file"] = interf_path.split("wsj0")[-1].replace("\\", "/") 277 | meta[f"interf{interf_idx}_pos"] = interf_source.tolist() 278 | 279 | # return_premix allows separation of speaker signals 280 | self.exp_room.room.compute_rir() 281 | mic_signals = self.exp_room.room.simulate(return_premix=True) 282 | 283 | # direct path target 284 | self.dry_room.room.compute_rir() 285 | target_signal = self.dry_room.room.simulate(return_premix=True) 286 | 287 | #scale to SNR 288 | reverb_target_signal = mic_signals[0, ...] 289 | noise_signal = np.sum(mic_signals[1:, ...], axis=0) 290 | target_signal = target_signal[0, ...] 291 | 292 | if not snr_min is np.nan: 293 | target_snr = self.rng.uniform(snr_min, snr_max) 294 | noise_factor = snr_scale_factor(reverb_target_signal, noise_signal, target_snr) 295 | noise_signal = noise_signal * noise_factor 296 | 297 | meta["snr"] = target_snr 298 | 299 | return reverb_target_signal, noise_signal, target_signal, meta 300 | 301 | def get_room(self): 302 | return self.exp_room.room 303 | 304 | def plot(self): 305 | self.exp_room.plot() 306 | 307 | def normal_vec(phi): 308 | return np.array([np.cos(phi), np.sin(phi), 0]) 309 | 310 | def snr_scale_factor(speech: np.ndarray, noise: np.ndarray, snr: int): 311 | """ 312 | Compute the scale factor that has to be applied to a noise signal in order for the noisy (sum of noise and clean) 313 | to have the specified SNR. 314 | 315 | :param speech: the clean speech signal [..., SAMPLES] 316 | :param noise: the noise signal [..., SAMPLES] 317 | :param snr: the SNR of the mixture 318 | :return: the scaling factor 319 | """ 320 | 321 | noise_var = np.mean(np.var(noise, axis=-1)) 322 | speech_var = np.mean(np.var(speech, axis=-1)) 323 | 324 | factor = np.sqrt(speech_var / np.maximum((noise_var * 10. ** (snr / 10.)), 10**(-6))) 325 | 326 | return factor 327 | 328 | def prep_speaker_mix_data(store_dir: str, 329 | post_fix: str = None, 330 | wsj0_path: str = 'whatever', 331 | n_channels: int = 3, 332 | n_interfering_speakers: int = 3, 333 | target_fs: int = 16000, 334 | num_files: dict = {'train': -1, 335 | 'val': -1, 336 | 'test': -1}, 337 | angle_settings: dict = None, 338 | reverb: bool = True, 339 | side_room: int = 10, 340 | rt60_min=0.2, 341 | rt60_max=0.8, 342 | snr_min=-10, 343 | snr_max=5, 344 | mic_pert=0, 345 | min_dist=0.8, 346 | max_dist=1.2, 347 | ): 348 | """ 349 | Preparation of speaker mix dataset. The target speaker is placed in a fixed position relative to the microphone 350 | array. The interfering speakers are placed randomly with one speaker per angle segment. 351 | 352 | If angle_settings are provided, the function can also create a dataset with a moving speaker placed 353 | on a range of angles. 
354 | 355 | :param store_dir: path to directory in which to store the dataset 356 | :param post_fix: postfix to specify the characteristics of the dataset 357 | :param wsj0_path: path the the raw WSJ0 data 358 | :param n_channels: number of channels in the microphone array 359 | :param n_interfering_speakers: the number of interfering speakers 360 | :param target_fs: the target sampling rate for the dataset 361 | :param num_files: a dictionary specifying the number of examples per stage 362 | :param angle_settings: a dict {'start': -45, 'stop': 45, 'step': 1, 'n_samples_per_angle': 100} 363 | :param reverb: turn off reverberation if set to False 364 | :param rt60_min: min RT60 time (uniformly sampled if reverb) 365 | :param rt60_max: max RT60 time (uniformly sampled if reverb) 366 | :param snr_min: min SNR (uniformly sampled) 367 | :param snr_max: max SNR (uniformely sampled) 368 | :param side_room: minimum angle difference between two sources (default: 10 deg) 369 | 370 | """ 371 | prep_store_name = f"prep_mix{'_' + post_fix if post_fix else ''}.hdf5" 372 | 373 | train_samples = list(sorted(glob.glob(os.path.join(wsj0_path, 'si_tr_s/*/*.wav')))) 374 | val_samples = list(sorted(glob.glob(os.path.join(wsj0_path, 'si_dt_20/*/*.wav')))) 375 | test_samples = list(sorted(glob.glob(os.path.join(wsj0_path, 'si_et_05/*/*.wav')))) 376 | 377 | n_angles = len(range(angle_settings['start'], 378 | angle_settings['stop'], angle_settings['step'])) 379 | 380 | meta = {} 381 | with h5py.File(os.path.join(store_dir, prep_store_name), 'w') as prep_storage: 382 | for data_set, samples in (('train', train_samples), 383 | ('val', val_samples), 384 | ('test', test_samples)): 385 | if num_files[data_set] == 0: 386 | continue 387 | n_dataset_samples = num_files[data_set] if num_files[data_set] > 0 else len(samples) 388 | 389 | # Variable target speaker position distributed over some range 390 | n_dataset_samples_full = n_dataset_samples*n_angles 391 | angle_start = angle_settings['start'] 392 | angle_stop = angle_settings['stop'] 393 | angle_step = angle_settings['step'] 394 | 395 | MAX_SAMPLES_PER_FILE = 12 * target_fs 396 | audio_dataset = prep_storage.create_dataset(data_set, 397 | shape=(n_dataset_samples_full, 3, n_channels, MAX_SAMPLES_PER_FILE), 398 | chunks=(1, 3, n_channels, MAX_SAMPLES_PER_FILE), 399 | dtype=np.float32, 400 | compression="gzip", 401 | shuffle=True) 402 | 403 | set_meta = {} 404 | 405 | sproom = SPRoomSimulator(channels=n_channels, mode=data_set) 406 | 407 | for i, fixed_angle in enumerate(range(angle_start, angle_stop, angle_step)): 408 | random.shuffle(samples) # pick random speakers 409 | for target_idx, target_path in enumerate(samples[:n_dataset_samples]): 410 | 411 | interfering_speakers = random.choices(samples, k=n_interfering_speakers) 412 | reverb_target_signal, noise_signal, dry_target_signal, sample_meta = sproom.create_sample( 413 | speaker_list=[target_path] + interfering_speakers, 414 | seed2=i*n_dataset_samples+target_idx, 415 | reverb = reverb, 416 | target_angle= fixed_angle, 417 | min_angle_dist = side_room, 418 | rt60_min=rt60_min, 419 | rt60_max=rt60_max, 420 | snr_min=snr_min, 421 | snr_max=snr_max, 422 | mic_pert_std = mic_pert, 423 | min_dist=min_dist, 424 | max_dist=max_dist,) 425 | n_audio_samples = min(sample_meta['n_samples'], MAX_SAMPLES_PER_FILE) 426 | sample_meta['n_samples'] = n_audio_samples 427 | sample_meta['target_dir'] = fixed_angle 428 | 429 | # store reverb clean 430 | audio_dataset[i*n_dataset_samples+target_idx, 0, :, :n_audio_samples] = 
reverb_target_signal[:, :n_audio_samples] 431 | audio_dataset[i*n_dataset_samples+target_idx, 0, :, n_audio_samples:MAX_SAMPLES_PER_FILE] = 0 432 | 433 | # store noise 434 | audio_dataset[i*n_dataset_samples+target_idx, 1, :, :n_audio_samples] = noise_signal[:, :n_audio_samples] 435 | audio_dataset[i*n_dataset_samples+target_idx, 1, :, n_audio_samples:MAX_SAMPLES_PER_FILE] = 0 436 | 437 | set_meta[i*n_dataset_samples+target_idx] = sample_meta 438 | 439 | # store dry clean 440 | audio_dataset[i*n_dataset_samples+target_idx, 2, :, :n_audio_samples] = dry_target_signal[:, :n_audio_samples] 441 | audio_dataset[i*n_dataset_samples+target_idx, 2, :, n_audio_samples:MAX_SAMPLES_PER_FILE] = 0 442 | 443 | if target_idx % 10 == 0: 444 | print(f'{data_set}/{fixed_angle}: {target_idx} of {n_dataset_samples}') 445 | 446 | meta[data_set] = set_meta 447 | with open(os.path.join(store_dir, f"prep_mix_meta{'_' + post_fix if post_fix else ''}.json"), 448 | 'w') as prep_meta_storage: 449 | json.dump(meta, prep_meta_storage, indent=4) 450 | 451 | 452 | if __name__ == '__main__': 453 | channels = 3 454 | prefix = f'ch{channels}_sp5_var_target' 455 | 456 | store_path = os.path.join(SIM_DATA_PATH) 457 | if not os.path.exists(store_path): 458 | os.makedirs(store_path) 459 | prep_speaker_mix_data(store_path, 460 | prefix, 461 | WSJ0_PATH, 462 | n_interfering_speakers=5, 463 | n_channels=channels, 464 | num_files={'train': 300, 'val': 15, 'test': 10}, # files per angle! 465 | angle_settings={'start': -180, 'stop': 180, 'step': 2}, 466 | reverb=True, 467 | rt60_min=0.2, 468 | rt60_max=0.5, 469 | snr_min=np.nan, 470 | snr_max=np.nan, 471 | mic_pert = 0, 472 | min_dist=0.8, 473 | max_dist=1.2, 474 | side_room=10, 475 | ) 476 | -------------------------------------------------------------------------------- /src/data/datamodule.py: -------------------------------------------------------------------------------- 1 | from data.dataset import MixDataset 2 | from typing import List 3 | from torch.utils.data import DataLoader 4 | import pytorch_lightning as pl 5 | 6 | 7 | class HDF5DataModule(pl.LightningDataModule): 8 | 9 | def __init__(self, 10 | n_channels: int, 11 | batch_size: int, 12 | prep_files: dict, 13 | stft_length_samples: int, 14 | stft_shift_samples: int, 15 | snr_range: List[int], 16 | meta_frame_length: int, 17 | n_workers: int, 18 | dry_target: bool = True, 19 | target_dir = 0, 20 | noise_snr: List[int] = None, 21 | fs: int = 16000 22 | ): 23 | """ 24 | Init the data module for the simulated mixture dataset. 25 | 26 | :param n_channels: number of channels in the microphone array (must be smaller or equal to prep_files) 27 | :param batch_size: the batch size 28 | :param prep_files: a dictionary specifying the HDF5 data file and meta data file generated by the preprocessing.py files. Keys are 'data', 'meta', 'train_data', 'train_meta', 'val_data', 'val_meta', 'test_data' and 'test_meta'. The 'data' and 'meta' keys serve as default if other keys are not specified. 29 | :param meta_frame_length: the metaframe length (e.g. 
randomly cut 1 second of data -> 16000 samples per metaframe) 30 | :param dry_target: use dry clean signal if True else use reverberant signal 31 | :param fs: sampling rate (default 16 kHz) 32 | :param snr_range: the list of possible SNRs (sampled from for training example generation) 33 | :param noise_snr: a list of SNRs for additive white noise (no noise if None) 34 | :param n_workers: number of workers for data loading 35 | """ 36 | super().__init__() 37 | 38 | self.batch_size = batch_size 39 | self.fs = fs 40 | self.meta_frame_len = meta_frame_length 41 | self.snr_range = snr_range 42 | 43 | self.n_channels = n_channels 44 | self.stft_len = stft_length_samples 45 | self.stft_shift = stft_shift_samples 46 | 47 | self.target_dir = target_dir 48 | self.noise_snr = noise_snr 49 | 50 | self.n_workers = n_workers 51 | 52 | self.train_dataset = MixDataset(stage='train', 53 | prep_files=prep_files, 54 | n_channels=self.n_channels, 55 | meta_frame_length=self.meta_frame_len, 56 | disable_random=False, 57 | snr_range = self.snr_range, 58 | dry_target=dry_target, 59 | target_dir= self.target_dir, 60 | noise_snr=self.noise_snr) 61 | self.val_dataset = MixDataset(stage='val', 62 | prep_files=prep_files, 63 | n_channels=self.n_channels, 64 | meta_frame_length=self.meta_frame_len, 65 | disable_random=True, 66 | snr_range = self.snr_range, 67 | dry_target=dry_target, 68 | target_dir= self.target_dir, 69 | noise_snr=self.noise_snr) 70 | 71 | 72 | def train_dataloader(self): 73 | return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.n_workers, shuffle=True) 74 | 75 | def val_dataloader(self): 76 | return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.n_workers, shuffle=False) -------------------------------------------------------------------------------- /src/data/dataset.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import json 3 | import random 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from typing import Literal, List 7 | 8 | 9 | class MixDataset(Dataset): 10 | """ 11 | Provides access to the utterances of a simulated dataset (noisy, clean, and noise images for n_channels). 12 | 13 | Data: 14 | The underlying HDF5 file provides access to 15 | 1. The clean speech signal (low reverberation, but propagation delay to the microphones included) 16 | 2. The noise signal (including reverberation) 17 | 3. The clean speech image (including reverberation) 18 | The dataset provides 19 | - noisy (clean image + noise) 20 | - noise 21 | - target (clean or clean image depending on dry_target variable) 22 | The SNR of the dataset itself if used if the snr_range is None and otherwise the signals are rescaled to match one of the specified SNR values. 23 | """ 24 | def __init__(self, 25 | stage: Literal['train', 'val', 'test'], 26 | prep_files: dict, 27 | n_channels: int, 28 | meta_frame_length: int, 29 | dry_target: bool, 30 | disable_random: bool, 31 | snr_range: List[int] = None, 32 | target_dir = 0, 33 | noise_snr: List[int] = None 34 | ): 35 | """ 36 | Initialize the dataset. 37 | 38 | :param stage: the dataset stage ('train', 'val', 'test) 39 | :param prep_files: a dictionary specifying the HDF5 data file and meta data file generated by the preprocessing.py files. Keys are 'data', 'meta', 'train_data', 'train_meta', 'val_data', 'val_meta', 'test_data' and 'test_meta'. The 'data' and 'meta' keys serve as default if other keys are not specified. 
40 | :param n_channels: the number of microphone channels 41 | :param meta_frame_length: the metaframe length (e.g. randomly cut 1 second of data -> 16000 samples per metaframe 42 | :param dry_target: use the dry signal as target (as opposed to the reverberant) 43 | :param has_dry: if set to False the given clean 44 | :param disable_random: turn off random metaframe selection in every epoch (only makes sense for validation or training debugging) 45 | :param target_dir: angle of the target direction 46 | :param noise_snr: a list of SNRs for additive white noise (no noise if None) 47 | """ 48 | self.stage = stage 49 | 50 | self.data_prep_path = prep_files.get('data', None) 51 | self.meta_prep_path = prep_files.get('meta', None) 52 | 53 | stage_data_path = prep_files.get(f'{stage}_data', None) 54 | stage_meta_path = prep_files.get(f'{stage}_meta', None) 55 | 56 | if stage_data_path: 57 | self.data_prep_path = stage_data_path 58 | if stage_meta_path: 59 | self.meta_prep_path = stage_meta_path 60 | if self.data_prep_path is None: 61 | raise ValueError(f'Specified prep paths are not valid: {prep_files}') 62 | if self.meta_prep_path is None: 63 | raise ValueError(f'Specified prep paths are not valid: {prep_files}') 64 | 65 | self.n_channels = n_channels 66 | 67 | self.meta_frame_length = meta_frame_length 68 | 69 | self.use_dry_target = dry_target 70 | 71 | self.snr_range = snr_range 72 | self.noise_snr = noise_snr 73 | 74 | self.target_dir = target_dir 75 | 76 | with h5py.File(self.data_prep_path, 'r') as prep_file: 77 | self.n_samples = prep_file[stage].shape[0] 78 | with open(self.meta_prep_path, 'r') as meta_file: 79 | self.meta_data = json.load(meta_file)[stage] 80 | 81 | self.disable_random = disable_random 82 | self.start_idxs = None 83 | if disable_random: 84 | self._init_seg_start_idxs() 85 | 86 | def _open_hdf5(self): 87 | self.prep_file = h5py.File(self.data_prep_path, 'r') 88 | # self.audio_data = self.prep_file[self.stage] 89 | 90 | def _init_seg_start_idxs(self): 91 | """ 92 | Initializes the random cut for every utterance. Can be used to disable random segment selection during 93 | validation which can help to visualize improvements. 94 | """ 95 | start_idxs = {} 96 | for i in range(self.n_samples): 97 | data = self.__getitem__(i) 98 | start_idxs[i] = data['start_idx'] 99 | self.start_idxs = start_idxs 100 | 101 | def __len__(self): 102 | return self.n_samples 103 | 104 | def __getitem__(self, idx, start_idx: int = -1): 105 | """ 106 | Get sample for given index starting and random position not specified otherwise. 
107 | 108 | :param idx: the sample index 109 | :param start_idx: the start index for some random cut (random if -1 is given) 110 | :return: clean, noise, and noisy arrays [CHANNEL, SAMPLES], snr and start index 111 | """ 112 | 113 | if self.disable_random and not self.start_idxs is None: 114 | start_idx = self.start_idxs[idx] 115 | reverb_clean_audio, dry_clean_audio, noise_audio, start_idx = self._read_audio_segment(idx, start_idx) 116 | 117 | if not self.snr_range is None and not len(self.snr_range) == 0: 118 | snr = random.choice(self.snr_range) 119 | noise_scale = snr_scale_factor(reverb_clean_audio, noise_audio, snr) 120 | else: 121 | noise_scale = 1 122 | 123 | mix_td = reverb_clean_audio + noise_scale * noise_audio 124 | 125 | if not self.noise_snr is None: 126 | snr = random.choice(self.noise_snr) 127 | noise = np.random.randn(*reverb_clean_audio.shape) 128 | noise_scale = snr_scale_factor(mix_td, noise, snr) 129 | mix_td += noise_scale * noise 130 | 131 | target_dir = self.meta_data[str(idx)].get('target_dir', self.target_dir) 132 | 133 | return {'noisy_td': mix_td, 134 | 'clean_td': dry_clean_audio if self.use_dry_target else reverb_clean_audio, 135 | 'reverb_clean_td': reverb_clean_audio, 136 | 'noise_td': noise_audio, 137 | 'start_idx': start_idx, 138 | 'sample_idx': idx, 139 | 'target_dir': target_dir} 140 | 141 | def get_utterance(self, idx): 142 | """ 143 | Get the full stored utterance from the datast. 144 | 145 | :param idx: the sample idex 146 | :return: noisy, clean and noise utterance and sample name 147 | """ 148 | sample_name = self.meta_data[str(idx)].get('name', f'sample_{idx}') 149 | n_samples = self.meta_data[str(idx)]['n_samples'] 150 | 151 | reverb_clean_audio, dry_clean_audio, noise_audio = self._read_audio(idx) 152 | 153 | target_dir = self.meta_data[str(idx)].get('target_dir', self.target_dir) 154 | 155 | return (reverb_clean_audio + noise_audio)[:, :n_samples], \ 156 | (dry_clean_audio if self.use_dry_target else reverb_clean_audio)[:, :n_samples], \ 157 | noise_audio[:, :n_samples], \ 158 | sample_name, \ 159 | target_dir 160 | 161 | def _read_audio_segment(self, idx, start_idx): 162 | """ 163 | Get clean and noise signal segments for given index starting at random position if not specified otherwise. 
164 | 165 | :param idx: the sample index 166 | :param start_idx: the start index for some random cut (random if -1 is given) 167 | :return: clean and noise arrays [CHANNEL, SAMPLES] and start index 168 | """ 169 | if not hasattr(self, 'prep_file'): 170 | self._open_hdf5() 171 | 172 | audio = self.prep_file[self.stage][idx] 173 | n_samples = min(self.meta_data[str(idx)]['n_samples'], audio.shape[-1]) 174 | 175 | if self.meta_frame_length < 0: 176 | return audio[..., :n_samples], 0 # return full audio if meta_frame_length is -1 177 | 178 | possible_start = n_samples - self.meta_frame_length 179 | 180 | if possible_start < 0: # example shorter than selected meta_frame length 181 | return np.concatenate((audio[0, :self.n_channels, :n_samples], 182 | np.zeros((self.n_channels, self.meta_frame_length - n_samples), dtype=np.float32)), 183 | axis=-1), \ 184 | np.concatenate((audio[2, :self.n_channels, :n_samples], 185 | np.zeros((self.n_channels, self.meta_frame_length - n_samples), dtype=np.float32)), 186 | axis=-1), \ 187 | np.concatenate((audio[1, :self.n_channels, :n_samples], 188 | np.zeros((self.n_channels, self.meta_frame_length - n_samples), dtype=np.float32)), 189 | axis=-1), \ 190 | 0 191 | elif start_idx >= 0: # use given start index 192 | return audio[0, :self.n_channels, start_idx:start_idx + self.meta_frame_length], \ 193 | audio[2, :self.n_channels, start_idx:start_idx + self.meta_frame_length], \ 194 | audio[1, :self.n_channels, start_idx:start_idx + self.meta_frame_length], \ 195 | start_idx 196 | else: # cut random metaframe from utterance 197 | start = random.randint(0, possible_start) 198 | end = start + self.meta_frame_length 199 | return audio[0, :self.n_channels, start: end], \ 200 | audio[2, :self.n_channels, start: end], \ 201 | audio[1, :self.n_channels, start: end], \ 202 | start 203 | 204 | def _read_audio(self, idx): 205 | """ 206 | Get full clean and noise utterance for the given index. 207 | 208 | :param idx: the sample index 209 | :return: clean and noise arrays [CHANNEL, SAMPLES] 210 | """ 211 | if not hasattr(self, 'prep_file'): 212 | self._open_hdf5() 213 | 214 | audio = self.prep_file[self.stage][idx] 215 | 216 | return audio[0], audio[2], audio[1] 217 | 218 | 219 | def snr_scale_factor(speech: np.ndarray, noise: np.ndarray, snr: int): 220 | """ 221 | Compute the scale factor that has to be applied to a noise signal in order for the noisy (sum of noise and clean) 222 | to have the specified SNR. 223 | 224 | :param speech: the clean speech signal [..., SAMPLES] 225 | :param noise: the noise signal [..., SAMPLES] 226 | :param snr: the SNR of the mixture 227 | :return: the scaling factor 228 | """ 229 | 230 | noise_var = np.mean(np.var(noise, axis=-1)) 231 | speech_var = np.mean(np.var(speech, axis=-1)) 232 | 233 | factor = np.sqrt(speech_var / np.maximum((noise_var * 10. ** (snr / 10.)), 10**(-6))) 234 | 235 | return factor 236 | 237 | def target_level_scale_factor(audio: np.ndarray, target_level: int, eps: float = 1e-6): 238 | """ 239 | Compute the scale factor that has to be applied to the signal to normalize to the specified target level. 
240 | 241 | :param audio: the time domain signal [..., SAMPLES] 242 | :param target_level: the target level in db 243 | :return: the scaling factor 244 | """ 245 | rms = np.mean(np.square(np.abs(audio))) 246 | scale = np.sqrt((10 ** (target_level / 10)) / rms) / np.abs(audio).max() 247 | return scale 248 | -------------------------------------------------------------------------------- /src/models/exp_enhancement.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | from typing import List, Union, Literal 6 | from utils.log_images import make_image_grid 7 | from torch.optim import Adam 8 | 9 | 10 | class EnhancementExp(pl.LightningModule): 11 | 12 | def __init__(self, 13 | model: nn.Module, 14 | cirm_comp_K: float, 15 | cirm_comp_C: float, 16 | scheduler_type: str = None, 17 | scheduler_params: dict = None 18 | ): 19 | super(EnhancementExp, self).__init__() 20 | 21 | self.model = model 22 | 23 | self.cirm_K = cirm_comp_K 24 | self.cirm_C = cirm_comp_C 25 | 26 | self.scheduler_type = scheduler_type 27 | self.scheduler_params = scheduler_params 28 | 29 | def forward(self, input): 30 | pass 31 | 32 | def training_step(self, batch, batch_idx): 33 | return self.shared_step(batch, batch_idx, stage='train') 34 | 35 | def validation_step(self, batch, batch_idx): 36 | return self.shared_step(batch, batch_idx, stage='val') 37 | 38 | 39 | def shared_step(self, batch, batch_idx, stage: Literal['train', 'val']): 40 | pass 41 | 42 | def loss(self, clean_td, est_clean_td, noise_td, est_noise_td, 43 | clean_stft, est_clean_stft, noise_stft, est_noise_stft): 44 | """ 45 | Compute the loss based on L1-norms of time domain speech and noise signals and frequency magnitudes. 46 | 47 | :param clean_td: target clean signal in time domain 48 | :param est_clean_td: estimated clean signal in time domain 49 | :param noise_td: target noise signal in time domain 50 | :param est_noise_td: estimated noise signal in time domain 51 | :param clean_stft: target clean signal in STFT domain 52 | :param est_clean_stft: estimated clean signal in STFT domain 53 | :param noise_stft: target noise signal in STFT domain 54 | :param est_noise_stft: estimated noise signal in STFT domain 55 | :return: four loss terms based on L1-loss 56 | """ 57 | clean_td_loss = torch.mean(torch.abs(clean_td - est_clean_td), dim=1) 58 | noise_td_loss = torch.mean(torch.abs(noise_td - est_noise_td), dim=1) 59 | clean_mag_loss = torch.mean(torch.abs(torch.abs(clean_stft) - torch.abs(est_clean_stft))) 60 | noise_mag_loss = torch.mean(torch.abs(torch.abs(noise_stft) - torch.abs(est_noise_stft))) 61 | 62 | return clean_td_loss, noise_td_loss, clean_mag_loss, noise_mag_loss 63 | 64 | def compute_global_si_sdr(self, est_clean_td, clean_td): 65 | """ 66 | Compute the SI-SDR for a whole utterance. 
67 | 68 | :param enhanced_td: estimated clean signal in the time domain 69 | :param clean_td: clean signal in the time domain 70 | """ 71 | 72 | def si_sdr(s, s_hat): 73 | alpha = torch.einsum('cs,cs->c', s_hat, s) / torch.einsum('cs,cs->c', s, s) 74 | scaled_ref = torch.unsqueeze(alpha, dim=1) * s 75 | sdr = 10 * torch.log10(torch.einsum('cs,cs->c', scaled_ref, scaled_ref) / ( 76 | torch.einsum('cs,cs->c', scaled_ref - s_hat, scaled_ref - s_hat) + 1e-14)) 77 | return sdr 78 | 79 | enhanced_si_sdr = si_sdr(clean_td, est_clean_td) 80 | 81 | return enhanced_si_sdr 82 | 83 | def configure_optimizers(self): 84 | 85 | opt = Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 86 | 87 | if self.scheduler_type == "ReduceLROnPLateau": 88 | lr_scheduler={ 89 | 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=opt, **self.scheduler_params), 90 | 'name': 'lr_schedule', 91 | 'monitor': 'monitor_loss' 92 | } 93 | return opt, lr_scheduler 94 | if self.scheduler_type == "MultiStepLR": 95 | lr_scheduler = { 96 | 'scheduler': torch.optim.lr_scheduler.MultiStepLR(optimizer=opt, **self.scheduler_params), 97 | 'name': 'lr_schedule' 98 | } 99 | return {"optimizer": opt, "lr_scheduler": lr_scheduler} 100 | 101 | return opt 102 | 103 | def get_complex_masks_from_stacked(self, real_mask): 104 | """ 105 | Construct the complex clean speech and noise mask from the estimated stacked clean speech mask. Inverts the 106 | compression by tanh output activation to the range [-inf, inf] for real and imaginary components. 107 | 108 | :param real_mask: estimated mask with stacked real and imaginary components [BATCH, 2, F, T] 109 | :return: the complex masks [B, F, T] 110 | """ 111 | compressed_complex_speech_mask = real_mask[:, 0, ...] + (1j) * real_mask[:, 1, ...] 112 | 113 | complex_speech_mask = (-1 / self.cirm_C) * torch.log( 114 | (self.cirm_K - self.cirm_K * compressed_complex_speech_mask) / ( 115 | self.cirm_K + self.cirm_K * compressed_complex_speech_mask)) 116 | complex_noise_mask = (1 - torch.real(complex_speech_mask)) - (1j) * torch.imag(complex_speech_mask) 117 | 118 | return complex_speech_mask, complex_noise_mask 119 | 120 | def get_stft_rep(self, *td_signals, return_complex=True): 121 | """ 122 | Compute the STFT for the given time-domain signals. 
123 | 124 | :param *td_signals: the time-domain signals (list of arrays [CHANNEL, SAMPLES] or [SAMPLES]) 125 | :return: list of stfts [BATCH, CHANNEL, FREQ, TIME] 126 | """ 127 | result = [] 128 | window = torch.sqrt(torch.hann_window(self.stft_length)).to(device=self.device) 129 | #window = torch.hann_window(self.stft_length).to(device=self.device) 130 | for td_signal in td_signals: 131 | if len(td_signal.shape) == 1: # single-channel 132 | stft = torch.stft(td_signal, self.stft_length, self.stft_shift, window=window, center=True, 133 | onesided=True, return_complex=return_complex) 134 | result.append(stft) 135 | else: # multi-channel and/or multiple speakers 136 | signal_shape = td_signal.shape 137 | reshaped_signal = td_signal.reshape((signal_shape[:-1].numel(), signal_shape[-1])) 138 | stfts = torch.stft(reshaped_signal, self.stft_length, self.stft_shift, window=window, center=True, onesided=True, return_complex=return_complex) 139 | if return_complex: 140 | combined_dim, freq_dim, time_dim = stfts.shape 141 | stfts = stfts.reshape(signal_shape[:-1]+(freq_dim, time_dim)) 142 | else: 143 | comined_dim, freq_dim, time_dim, complex_dim = stfts.shape 144 | stfts = stfts.reshape(signal_shape[:-1]+(freq_dim, time_dim, complex_dim)) 145 | result.append(stfts) 146 | 147 | return result 148 | 149 | def get_td_rep(self, *stfts): 150 | """ 151 | Compute the time domain represetnation for the given STFTs. 152 | :param stfts: list of STFTs [BATCH, FREQ, TIME] 153 | :return: list of time domain signals [BATCH, SAMPLES] 154 | """ 155 | result = [] 156 | window = torch.sqrt(torch.hann_window(self.stft_length)).to(device=self.device) 157 | for stft in stfts: 158 | has_complex_dim = stft.shape[-1] == 2 159 | if (not has_complex_dim and len(stft.shape) <= 3) or (has_complex_dim and len(stft.shape) <= 4): # single-channel 160 | td_signal = torch.istft(stft, self.stft_length, self.stft_shift, window=window, center=True, 161 | onesided=True, 162 | return_complex=False) 163 | result.append(td_signal) 164 | else: # multi-channel 165 | signal_shape = stft.shape 166 | if not has_complex_dim: 167 | reshaped_signal = stft.reshape((signal_shape[:-2].numel(), signal_shape[-2], signal_shape[-1])) 168 | td_signals = torch.istft(reshaped_signal, self.stft_length, self.stft_shift, window=window, center=True, onesided=True, return_complex=False) 169 | combined_dim, n_samples = td_signals.shape 170 | td_signals = td_signals.reshape(signal_shape[:-2]+(n_samples,)) 171 | else: 172 | reshaped_signal = stft.reshape((signal_shape[:-3].numel(), signal_shape[-3], signal_shape[-2], signal_shape[-1])) 173 | td_signals = torch.istft(reshaped_signal, self.stft_length, self.stft_shift, window=window, center=True, onesided=True, return_complex=False) 174 | combined_dim, n_samples = td_signals.shape 175 | td_signals = td_signals.reshape(signal_shape[:-3]+(n_samples,)) 176 | result.append(td_signals) 177 | return result 178 | 179 | def log_batch_detailed_spectrograms(self, 180 | stfts: List[torch.Tensor], 181 | batch_idx: Union[int, None], 182 | tag: str = 'train', 183 | n_samples: int = -1): 184 | """ 185 | Write spectrograms for a batch. 186 | 187 | The spectrograms are reordered so that the ith sample of all STFTs are displayed in the same row 188 | (e.g. noisy, clean, noise and enhanced side-by-side). 189 | 190 | :param stfts: a list of the STFTs [BATCH, FREQ, TIME] 191 | :param batch_idx: the batch index 192 | :param tag: the logging tag (e.g. val vs. 
train) 193 | :param n_samples: the number of samples to log (default is full batch) 194 | """ 195 | 196 | tensorboard = self.logger.experiment 197 | 198 | log_name = f"{tag}/spectrogram{'_' + str(batch_idx) if not batch_idx is None else ''}" 199 | 200 | combined_stfts = torch.flatten(torch.stack(stfts, dim=1), start_dim=0, end_dim=1) 201 | if n_samples > 0: 202 | combined_stfts = combined_stfts[:n_samples * len(stfts)] 203 | spectrograms_db = 10 * torch.log10(torch.maximum(torch.square(torch.abs(combined_stfts)), 204 | (10 ** (-15)) * torch.ones_like(combined_stfts, 205 | dtype=torch.float32))) 206 | spectrograms_db = torch.flip(torch.unsqueeze(spectrograms_db, dim=1), dims=[-2]) 207 | tensorboard.add_image(log_name, make_image_grid(spectrograms_db, vmin=-80, vmax=20, n_img_per_row=len(stfts)), 208 | global_step=self.current_epoch) 209 | 210 | 211 | def log_batch_detailed_audio(self, noisy_td, enhanced_td, batch_idx: Union[int, None], tag: str, 212 | n_samples: int = 10): 213 | """ 214 | Write audio logs for a batch. 215 | 216 | :param noisy_stft: the noisy stft [BATCH, FREQ, TIME] 217 | :param enhanced_stft: the enhanced stft 218 | :param batch_idx: the batch index 219 | :param tag: the logging tag (e.g. val vs. train) 220 | """ 221 | tensorboard = self.logger.experiment 222 | 223 | cur_samples = len(noisy_td) 224 | for i in range(min(self.trainer.datamodule.batch_size, n_samples, cur_samples)): 225 | log_noisy_name = f"{tag}/{str(batch_idx) if not batch_idx is None else ''}_{i}_noisy" 226 | tensorboard.add_audio(log_noisy_name, noisy_td[i], global_step=self.current_epoch, 227 | sample_rate=self.trainer.datamodule.fs) 228 | 229 | log_enhanced_name = f"{tag}/{str(batch_idx) if not batch_idx is None else ''}_{i}_enhanced" 230 | tensorboard.add_audio(log_enhanced_name, enhanced_td[i], global_step=self.current_epoch, 231 | sample_rate=self.trainer.datamodule.fs) 232 | -------------------------------------------------------------------------------- /src/models/exp_jnf.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | import torch 3 | from torch import nn 4 | from models.exp_enhancement import EnhancementExp 5 | from models.models import FTJNF 6 | 7 | class JNFExp(EnhancementExp): 8 | 9 | def __init__(self, 10 | model: nn.Module, 11 | learning_rate: float, 12 | weight_decay: float, 13 | loss_alpha: float, 14 | stft_length: int, 15 | stft_shift: int, 16 | cirm_comp_K: float, 17 | cirm_comp_C: float, 18 | reference_channel: int = 0): 19 | super(JNFExp, self).__init__(model=model, cirm_comp_K=cirm_comp_K, cirm_comp_C=cirm_comp_C) 20 | 21 | self.model = model 22 | 23 | self.stft_length = stft_length 24 | self.stft_shift = stft_shift 25 | 26 | self.cirm_K = cirm_comp_K 27 | self.cirm_C = cirm_comp_C 28 | 29 | self.learning_rate = learning_rate 30 | self.weight_decay = weight_decay 31 | self.loss_alpha = loss_alpha 32 | 33 | self.reference_channel = reference_channel 34 | 35 | #self.example_input_array = torch.from_numpy(np.ones((2, 6, 513, 75), dtype=np.float32)) 36 | 37 | def forward(self, input): 38 | speech_mask = self.model(input) 39 | return speech_mask 40 | 41 | def shared_step(self, batch, batch_idx, stage: Literal['train', 'val']): 42 | 43 | noisy_td, clean_td, noise_td = batch['noisy_td'], batch['clean_td'], batch['noise_td'] 44 | noisy_stft, clean_stft, noise_stft = self.get_stft_rep(noisy_td, clean_td, noise_td) 45 | 46 | # compute mask estimate 47 | stacked_noisy_stft = torch.concat((torch.real(noisy_stft), 
torch.imag(noisy_stft)), dim=1) 48 | 49 | if self.model.output_type == 'IRM': 50 | irm_speech_mask = self.model(stacked_noisy_stft) 51 | speech_mask, noise_mask = irm_speech_mask, 1-irm_speech_mask 52 | elif self.model.output_type == 'CRM': 53 | stacked_speech_mask = self.model(stacked_noisy_stft) 54 | speech_mask, noise_mask = self.get_complex_masks_from_stacked(stacked_speech_mask) 55 | else: 56 | raise ValueError(f'The output type {self.model.output_type} is not supported.') 57 | 58 | # compute estimates 59 | est_clean_stft = noisy_stft[:, self.reference_channel, ...] * speech_mask 60 | est_noise_stft = noisy_stft[:, self.reference_channel, ...] * noise_mask 61 | clean_td, noise_td, est_clean_td, est_noise_td = self.get_td_rep(clean_stft[:, self.reference_channel, ...], noise_stft[:, self.reference_channel, ...], 62 | est_clean_stft, est_noise_stft) 63 | 64 | # compute loss 65 | clean_td_loss, noise_td_loss, clean_mag_loss, noise_mag_loss = self.loss(clean_td, est_clean_td, noise_td, 66 | est_noise_td, 67 | clean_stft[:, self.reference_channel, ...], est_clean_stft, 68 | noise_stft[:, self.reference_channel, ...], 69 | est_noise_stft) 70 | 71 | loss = torch.mean(self.loss_alpha * (clean_td_loss + noise_td_loss) + (clean_mag_loss + noise_mag_loss)) 72 | 73 | # logging 74 | on_step = False 75 | self.log(f'{stage}/loss', loss, on_step=on_step, on_epoch=True, logger=True, sync_dist=True) 76 | self.log(f'{stage}/clean_td_loss', clean_td_loss.mean(), on_step=on_step, on_epoch=True, logger=True, sync_dist=True) 77 | self.log(f'{stage}/noise_td_loss', noise_td_loss.mean(), on_step=on_step, on_epoch=True, logger=True, sync_dist=True) 78 | self.log(f'{stage}/clean_mag_loss', clean_mag_loss.mean(), on_step=on_step, on_epoch=True, logger=True, sync_dist=True) 79 | self.log(f'{stage}/noise_mag_loss', noise_mag_loss.mean(), on_step=on_step, on_epoch=True, logger=True, sync_dist=True) 80 | if batch_idx < 1: 81 | self.log_batch_detailed_audio(noisy_td[:, self.reference_channel, ...], est_clean_td, batch_idx, stage) 82 | self.log_batch_detailed_spectrograms( 83 | [noisy_stft[:, self.reference_channel, ...], clean_stft[:, self.reference_channel, ...], noise_stft[:, self.reference_channel, ...], est_clean_stft, est_noise_stft], 84 | batch_idx, 85 | stage, n_samples=10) 86 | # self.log_batch_detailed_maks([complex_speech_mask.abs(), complex_noise_mask.abs()], batch_idx, stage, n_samples=10) 87 | if stage == 'val': 88 | self.log(f'monitor_loss', loss, on_step=on_step, on_epoch=True, logger=True) 89 | global_si_sdr = self.compute_global_si_sdr(est_clean_td, clean_td) 90 | self.log('val/si_sdr', global_si_sdr.mean(), on_epoch=True, logger=True, sync_dist=True) 91 | 92 | return loss -------------------------------------------------------------------------------- /src/models/exp_ssf.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | from models.exp_enhancement import EnhancementExp 6 | from torchmetrics.aggregation import RunningMean 7 | from typing import Literal 8 | 9 | 10 | class SSFExp(EnhancementExp): 11 | 12 | def __init__( 13 | self, 14 | model: nn.Module, 15 | learning_rate: float, 16 | weight_decay: float, 17 | loss_alpha: float, 18 | stft_length: int, 19 | stft_shift: int, 20 | cirm_comp_K: float, 21 | cirm_comp_C: float, 22 | n_cond_emb_dim: int, 23 | condition_enc_type: Literal["index", "arange"], 24 | cond_arange_params: tuple = None, 25 | loss_type: str = "l1", 
26 | scheduler_type: str = None, 27 | scheduler_params: dict = None, 28 | reference_channel: int = 0, 29 | ): 30 | super(SSFExp, self).__init__( 31 | model=model, 32 | cirm_comp_K=cirm_comp_K, 33 | cirm_comp_C=cirm_comp_C, 34 | scheduler_type=scheduler_type, 35 | scheduler_params=scheduler_params, 36 | ) 37 | 38 | self.model = model 39 | 40 | self.stft_length = stft_length 41 | self.stft_shift = stft_shift 42 | 43 | self.cirm_K = cirm_comp_K 44 | self.cirm_C = cirm_comp_C 45 | 46 | self.reference_channel = reference_channel 47 | 48 | self.learning_rate = learning_rate 49 | self.weight_decay = weight_decay 50 | self.loss_alpha = loss_alpha 51 | 52 | self.n_cond_emb_dim = n_cond_emb_dim 53 | self.condition_enc_type = condition_enc_type 54 | self.cond_arange_params = cond_arange_params 55 | self.loss_type = loss_type 56 | 57 | self.running_loss = RunningMean(window=20) 58 | 59 | if self.condition_enc_type == "arange": 60 | assert cond_arange_params is not None, "Angle range parameters are missing" 61 | start, stop, step = cond_arange_params 62 | angles = range(start, stop, step) 63 | n_angles = len(angles) 64 | indices = range(n_angles) 65 | assert ( 66 | n_angles == n_cond_emb_dim 67 | ), "The embedding dim does not match the angle range params" 68 | self.angle_index_map = dict(zip(angles, indices)) 69 | 70 | # self.example_input_array = torch.from_numpy(np.ones((2, 6, 513, 75), dtype=np.float32)) 71 | 72 | def forward(self, input, target_dir): 73 | target_dir_enc = self.encode_condition(target_dir) 74 | speech_mask = self.model(input, target_dir_enc, device=self.device) 75 | return speech_mask 76 | 77 | def training_step(self, batch, batch_idx): 78 | return self.shared_step(batch, batch_idx, stage="train") 79 | 80 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 81 | return self.shared_step(batch, 82 | batch_idx, 83 | stage="val", 84 | dataloader_idx=dataloader_idx) 85 | 86 | def shared_step(self, 87 | batch, 88 | batch_idx, 89 | stage: Literal["train", "val"], 90 | dataloader_idx=0): 91 | noisy_td, clean_td, noise_td = ( 92 | batch["noisy_td"], 93 | batch["clean_td"], 94 | batch["noise_td"], 95 | ) 96 | noisy_stft, clean_stft, noise_stft = self.get_stft_rep( 97 | noisy_td, clean_td, noise_td) 98 | 99 | # compute mask estimate 100 | stacked_noisy_stft = torch.concat( 101 | (torch.real(noisy_stft), torch.imag(noisy_stft)), dim=1) 102 | 103 | target_dirs = batch["target_dir"] 104 | target_dirs_enc = self.encode_condition(target_dirs) 105 | 106 | if self.model.output_type == "IRM": 107 | irm_speech_mask = self.model(stacked_noisy_stft, 108 | target_dirs_enc, 109 | device=self.device) 110 | speech_mask, noise_mask = irm_speech_mask, 1 - irm_speech_mask 111 | elif self.model.output_type == "CRM": 112 | stacked_speech_mask = self.model(stacked_noisy_stft, 113 | target_dirs_enc, 114 | device=self.device) 115 | speech_mask, noise_mask = self.get_complex_masks_from_stacked( 116 | stacked_speech_mask) 117 | else: 118 | raise ValueError( 119 | f"The output type {self.model.output_type} is not supported.") 120 | 121 | # compute estimates 122 | est_clean_stft = noisy_stft[:, self.reference_channel, 123 | ...] * speech_mask 124 | est_noise_stft = noisy_stft[:, self.reference_channel, 125 | ...] 
* noise_mask 126 | clean_td, noise_td, est_clean_td, est_noise_td = self.get_td_rep( 127 | clean_stft[:, self.reference_channel, ...], 128 | noise_stft[:, self.reference_channel, ...], 129 | est_clean_stft, 130 | est_noise_stft, 131 | ) 132 | 133 | # compute loss 134 | if self.loss_type == "l1": 135 | clean_td_loss, noise_td_loss, clean_mag_loss, noise_mag_loss = self.loss( 136 | clean_td, 137 | est_clean_td, 138 | noise_td, 139 | est_noise_td, 140 | clean_stft[:, 0, ...], 141 | est_clean_stft, 142 | noise_stft[:, 0, ...], 143 | est_noise_stft, 144 | ) 145 | 146 | loss = torch.mean(self.loss_alpha * clean_td_loss + clean_mag_loss) 147 | elif self.loss_type == "sisdr": 148 | loss = -torch.mean( 149 | self.compute_global_si_sdr(est_clean_td, clean_td)) 150 | 151 | # logging 152 | if stage == "train" or dataloader_idx == 0 or dataloader_idx is None: 153 | add_dataloader_idx = False 154 | else: 155 | add_dataloader_idx = True 156 | 157 | self.running_loss(loss) 158 | on_step = True if stage == 'train' else False 159 | self.log( 160 | f"{stage}/loss", 161 | self.running_loss.compute(), 162 | on_step=on_step, 163 | on_epoch=True, 164 | logger=True, 165 | add_dataloader_idx=add_dataloader_idx, 166 | sync_dist=True, 167 | prog_bar=True, 168 | ) 169 | if self.loss_type == "l1": 170 | self.log( 171 | f"{stage}/clean_td_loss", 172 | clean_td_loss.mean(), 173 | on_step=on_step, 174 | on_epoch=True, 175 | logger=True, 176 | add_dataloader_idx=add_dataloader_idx, 177 | sync_dist=True, 178 | ) 179 | 180 | self.log( 181 | f"{stage}/clean_mag_loss", 182 | clean_mag_loss.mean(), 183 | on_step=on_step, 184 | on_epoch=True, 185 | logger=True, 186 | add_dataloader_idx=add_dataloader_idx, 187 | sync_dist=True, 188 | ) 189 | 190 | if batch_idx < 1: 191 | self.log_batch_detailed_audio(noisy_td[:, 0, ...], est_clean_td, 192 | batch_idx, stage) 193 | self.log_batch_detailed_spectrograms( 194 | [ 195 | noisy_stft[:, self.reference_channel, ...], 196 | clean_stft[:, self.reference_channel, ...], 197 | noise_stft[:, self.reference_channel, ...], 198 | est_clean_stft, 199 | est_noise_stft, 200 | ], 201 | batch_idx, 202 | stage, 203 | n_samples=10, 204 | ) 205 | # self.log_batch_detailed_maks([complex_speech_mask.abs(), complex_noise_mask.abs()], batch_idx, stage, n_samples=10) 206 | if stage == "val": 207 | if dataloader_idx == 0: 208 | self.log( 209 | "monitor_loss", 210 | loss, 211 | on_step=False, 212 | on_epoch=True, 213 | logger=True, 214 | add_dataloader_idx=add_dataloader_idx, 215 | sync_dist=True, 216 | ) 217 | global_si_sdr = torch.mean( 218 | self.compute_global_si_sdr(est_clean_td, clean_td)) 219 | self.log( 220 | "val/si_sdr", 221 | global_si_sdr, 222 | on_epoch=True, 223 | logger=True, 224 | add_dataloader_idx=add_dataloader_idx, 225 | sync_dist=True, 226 | ) 227 | 228 | return loss 229 | 230 | def encode_condition(self, target_dirs): 231 | """ 232 | Provide an encoding of the target direction of length self.n_cond_emb_dim using the specified encoding strategy. 233 | 234 | Encoding strategys: 235 | - index: The target dir is already an index of the direction and only need a one_hot encoding. 
236 | 237 | - arange: The target dirs have been generated using the arange function and will first be mapped to indices and then be one_hot encoded 238 | 239 | """ 240 | 241 | if self.condition_enc_type == "index": 242 | return torch.nn.functional.one_hot(target_dirs, 243 | self.n_cond_emb_dim).float() 244 | 245 | elif self.condition_enc_type == "arange": 246 | index_mapped = (target_dirs.cpu().apply_( 247 | self.angle_index_map.get).to(self.device)) 248 | return torch.nn.functional.one_hot(index_mapped, 249 | self.n_cond_emb_dim).float() 250 | -------------------------------------------------------------------------------- /src/models/models.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from torch import nn 3 | import torch 4 | 5 | """ 6 | This file provides the torch modules for the paper 7 | 8 | Kristina Tesch and Timo Gerkmann, "Insights into Deep Non-linear Filters for Improved Multi-channel Speech Enhancement", 9 | IEEE/ACM Transactions of Audio, Speech and Language Processing, vol 31, pp. 563-575, 2023. 10 | 11 | and also 12 | 13 | Kristina Tesch and Timo Gerkmann, "Multi-channel Speech Separation Using Spatially Selective Deeo Non-linear Filters", submitted to IEEE/ACM Transactions of Audio, Speech and Language Processing. 14 | 15 | Included networks are: 16 | JNF (implements T-JNF, F-JNF, T-NSF and F-NSF) 17 | FTJNF (implements FT-JNF and FT-NSF) 18 | JNF_SSF (implements FT-JNF with a conditioning on the DoA angle) 19 | 20 | The network architecture T-JNF corresponds to the network proposed in 21 | 22 | X. Li und R. Horaud, „Multichannel Speech Enhancement Based On Time-Frequency Masking Using Subband Long Short-Term Memory“, 23 | in 2019 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA), Okt. 2019, p. 298–302. 24 | """ 25 | 26 | 27 | class JNF(nn.Module): 28 | 29 | def __init__(self, 30 | n_time_steps: int, 31 | n_freqs: int, 32 | n_channels: int, 33 | n_lstm_hidden1: int = 256, 34 | n_lstm_hidden2: int = 128, 35 | bidirectional: bool = True, 36 | output_type: Literal['IRM', 'CRM'] = 'CRM', 37 | output_activation: Literal['sigmoid', 'tanh', 'linear'] = 'tanh', 38 | dropout: float = 0, 39 | append_freq_idx: bool = False, 40 | permute_freqs: bool = False, 41 | narrow_band: bool = False): 42 | """ 43 | Initialize model. 
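        A minimal construction sketch for the wide-band variant (the shapes and hyper-parameter
        values below are illustrative assumptions, not the configuration used in the papers):

            model = JNF(n_time_steps=100, n_freqs=257, n_channels=3,
                        output_type='CRM', output_activation='tanh', narrow_band=False)
            mask = model(torch.randn(2, 2 * 3, 257, 100))   # -> [2, 2, 257, 100]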
44 | 45 | :param n_time_steps: number of STFT time frames in the input signal 46 | :param n_freqs: number of STFT frequency bins in the input signal 47 | :param n_channels: number of channel in the input signal 48 | :param n_lstm_hidden1: number of LSTM units in the first LSTM layer 49 | :param n_lstm_hidden2: number of LSTM units in the second LSTM layer 50 | :param bidirectional: set to True for a bidirectional LSTM 51 | :param output_type: set to 'IRM' for real-valued ideal ratio mask (IRM) and to 'CRM' for complex IRM 52 | :param output_activation: the activation function applied to the network output (options: 'sigmoid', 'tanh', 'linear') 53 | :param dropout: dropout percentage (default: no dropout) 54 | :param append_freq_idx: add the frequency-bin index to the input of the LSTM when using permuted sequences 55 | :param permute_freqs: permute the LSTM input sequence 56 | :param narrow_band: use narrow-band input if narrow_band else use wide-band input 57 | """ 58 | super(JNF, self).__init__() 59 | 60 | self.n_time_steps = n_time_steps 61 | self.n_freqs = n_freqs 62 | self.n_channels = n_channels 63 | self.n_lstm_hidden1 = n_lstm_hidden1 64 | self.n_lstm_hidden2 = n_lstm_hidden2 65 | self.dropout = dropout 66 | self.bidirectional = bidirectional 67 | self.output_type = output_type 68 | self.output_activation = output_activation 69 | self.append_freq_idx = append_freq_idx 70 | self.permute = permute_freqs 71 | self.narrow_band = narrow_band 72 | 73 | lstm_input = 2*n_channels 74 | if self.append_freq_idx and self.permute: 75 | lstm_input += 1 76 | 77 | self.lstm1 = nn.LSTM(input_size=lstm_input, hidden_size=self.n_lstm_hidden1, bidirectional=bidirectional, batch_first=False) 78 | self.lstm2 = nn.LSTM(input_size=2*self.n_lstm_hidden1, hidden_size=self.n_lstm_hidden2, bidirectional=bidirectional, batch_first=False) 79 | 80 | self.dropout = nn.Dropout(p=self.dropout) 81 | 82 | if self.output_type == 'IRM': 83 | self.linear_out_features = 1 84 | elif self.output_type == 'CRM': 85 | self.linear_out_features = 2 86 | else: 87 | raise ValueError(f'The output type {output_type} is not supported.') 88 | self.ff = nn.Linear(2*self.n_lstm_hidden2, out_features=self.linear_out_features) 89 | 90 | if self.output_activation == 'sigmoid': 91 | self.mask_activation = nn.Sigmoid() 92 | elif self.output_activation == 'tanh': 93 | self.mask_activation = nn.Tanh() 94 | elif self.output_activation == 'linear': 95 | self.mask_activation = nn.Identity() 96 | 97 | 98 | def forward(self, x: torch.Tensor): 99 | """ 100 | Implements the forward pass of the model. 
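        The CHANNEL axis is expected to carry the real and imaginary parts of the multi-channel
        noisy STFT stacked along dim 1 (i.e. CHANNEL = 2 * n_channels), which is how the experiment
        classes prepare the input. For example (illustrative), a [2, 6, 257, 100] input for a
        3-channel signal yields a [2, 2, 257, 100] mask when output_type is 'CRM'.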
101 | 102 | :param x: input with shape [BATCH, CHANNEL, FREQ, TIME] 103 | :return: the output mask [BATCH, 1 (IRM) or 2 (CRM) , FREQ, TIME] 104 | """ 105 | n_batch, n_channel, n_freq, n_times = x.shape 106 | 107 | if self.narrow_band: 108 | seq_len = n_times 109 | tmp_batch = n_batch*n_freq 110 | x = x.permute(3,0,2,1).reshape(n_times, n_batch*n_freq, n_channel) 111 | else: # wide_band 112 | seq_len = n_freq 113 | tmp_batch = n_batch * n_times 114 | x = x.permute(2,0,3,1).reshape(n_freq, n_batch*n_times, n_channel) 115 | 116 | if self.permute: 117 | perm = torch.randperm(seq_len) 118 | inv_perm = torch.zeros(seq_len, dtype=int) 119 | for i, val in enumerate(perm): 120 | inv_perm[val] = i 121 | x = x[perm] 122 | 123 | if self.append_freq_idx: 124 | if self.narrow_band: 125 | freq_bins = torch.arange(n_freq).repeat(n_batch*n_times).reshape(seq_len, tmp_batch, 1).to(x.device) 126 | x = torch.concat((x, freq_bins), dim=-1) 127 | else: 128 | freq_bins = torch.arange(n_freq).repeat(int(seq_len/n_freq))[perm] 129 | freq_bins = freq_bins.unsqueeze(1).unsqueeze(1).broadcast_to((seq_len, tmp_batch, 1)).to(x.device) 130 | x = torch.concat((x, freq_bins), dim=-1) 131 | 132 | x, _ = self.lstm1(x) 133 | x = self.dropout(x) 134 | x, _ = self.lstm2(x) 135 | x = self.dropout(x) 136 | x = self.ff(x) 137 | 138 | if self.permute: 139 | x = x[inv_perm] 140 | 141 | if self.narrow_band: 142 | x = x.reshape(n_times, n_batch, n_freq, self.linear_out_features).permute(1,3,2,0) 143 | else: # wide_band 144 | x = x.reshape(n_freq, n_batch, n_times, self.linear_out_features).permute(1,3,0,2) 145 | x = self.mask_activation(x) 146 | return x 147 | 148 | 149 | class FTJNF(nn.Module): 150 | """ 151 | Mask estimation network composed of two LSTM layers. One LSTM layer uses the frequency-dimension as sequence input 152 | and the other LSTM uses the time-dimension as input. 153 | """ 154 | def __init__(self, 155 | n_channels: int, 156 | n_lstm_hidden1: int = 512, 157 | n_lstm_hidden2: int = 128, 158 | bidirectional: bool = True, 159 | freq_first: bool = True, 160 | output_type: Literal['IRM', 'CRM'] = 'CRM', 161 | output_activation: Literal['sigmoid', 'tanh', 'linear'] = 'tanh', 162 | dropout: float = 0, 163 | append_freq_idx: bool = False, 164 | permute_freqs: bool = False): 165 | """ 166 | Initialize model. 
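        A minimal construction sketch (the hyper-parameter values below are illustrative
        assumptions, not the configuration used in the papers):

            model = FTJNF(n_channels=3, n_lstm_hidden1=256, n_lstm_hidden2=128,
                          bidirectional=True, freq_first=True,
                          output_type='CRM', output_activation='tanh')
            mask = model(torch.randn(2, 2 * 3, 257, 100))   # -> [2, 2, 257, 100]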
167 | 168 | :param n_channels: number of channel in the input signal 169 | :param n_lstm_hidden1: number of LSTM units in the first LSTM layer 170 | :param n_lstm_hidden2: number of LSTM units in the second LSTM layer 171 | :param bidirectional: set to True for a bidirectional LSTM 172 | :param freq_first: process frequency dimension first if freq_first else process time dimension first 173 | :param output_type: output_type: set to 'IRM' for real-valued ideal ratio mask (IRM) and to 'CRM' for complex IRM 174 | :param output_activation: the activation function applied to the network output (options: 'sigmoid', 'tanh', 'linear') 175 | :param dropout: dropout percentage (default: no dropout) 176 | """ 177 | super(FTJNF, self).__init__() 178 | 179 | self.n_channels = n_channels 180 | self.n_lstm_hidden1 = n_lstm_hidden1 181 | self.n_lstm_hidden2 = n_lstm_hidden2 182 | self.dropout = dropout 183 | self.bidirectional = bidirectional 184 | self.output_type = output_type 185 | self.output_activation = output_activation 186 | self.freq_first = freq_first 187 | self.append_freq_idx = append_freq_idx 188 | self.permute = permute_freqs 189 | 190 | lstm_input = 2*n_channels 191 | if self.append_freq_idx: 192 | lstm_input += 1 193 | 194 | self.lstm1 = nn.LSTM(input_size=lstm_input, hidden_size=self.n_lstm_hidden1, bidirectional=bidirectional, batch_first=False) 195 | 196 | self.lstm1_out = 2*self.n_lstm_hidden1 if self.bidirectional else self.n_lstm_hidden1 197 | lstm2_input = self.lstm1_out 198 | if self.append_freq_idx: 199 | lstm2_input+= 1 200 | 201 | self.lstm2 = nn.LSTM(input_size=lstm2_input, hidden_size=self.n_lstm_hidden2, bidirectional=bidirectional, batch_first=False) 202 | self.lstm2_out = 2*self.n_lstm_hidden2 if self.bidirectional else self.n_lstm_hidden2 203 | 204 | self.dropout = nn.Dropout(p=self.dropout) 205 | 206 | if self.output_type == 'IRM': 207 | self.linear_out_features = 1 208 | elif self.output_type == 'CRM': 209 | self.linear_out_features = 2 210 | else: 211 | raise ValueError(f'The output type {output_type} is not supported.') 212 | self.ff = nn.Linear(self.lstm2_out, out_features=self.linear_out_features) 213 | 214 | if self.output_activation == 'sigmoid': 215 | self.mask_activation = nn.Sigmoid() 216 | elif self.output_activation == 'tanh': 217 | self.mask_activation = nn.Tanh() 218 | elif self.output_activation == 'linear': 219 | self.mask_activation = nn.Identity() 220 | 221 | 222 | def forward(self, x: torch.Tensor): 223 | """ 224 | Implements the forward pass of the model. 
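        With freq_first=True the first LSTM treats the frequency axis as the sequence (each
        batch/time-frame pair forms one element of the temporary batch) and the second LSTM then
        runs along the time axis (each batch/frequency-bin pair forms one element); with
        freq_first=False the order is swapped. This two-stage scheme realizes FT-JNF / FT-NSF.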
225 | 226 | :param x: input with shape [BATCH, CHANNEL, FREQ, TIME] 227 | :return: the output mask [BATCH, 1 (IRM) or 2 (CRM) , FREQ, TIME] 228 | """ 229 | n_batch, n_channel, n_freq, n_times = x.shape 230 | 231 | if not self.freq_first: # narrow_band 232 | seq_len = n_times 233 | tmp_batch = n_batch*n_freq 234 | x = x.permute(3,0,2,1).reshape(n_times, n_batch*n_freq, n_channel) 235 | else: # wide_band 236 | seq_len = n_freq 237 | tmp_batch = n_batch*n_times 238 | x = x.permute(2,0,3,1).reshape(n_freq, n_batch*n_times, n_channel) 239 | 240 | if self.permute: 241 | perm = torch.randperm(seq_len) 242 | inv_perm = torch.zeros(seq_len, dtype=int) 243 | for i, val in enumerate(perm): 244 | inv_perm[val] = i 245 | x = x[perm] 246 | else: 247 | perm = torch.arange(seq_len) 248 | 249 | if self.append_freq_idx: 250 | if not self.freq_first: # narrow_band: 251 | freq_bins = torch.arange(n_freq).repeat(n_batch*n_times).reshape(seq_len, tmp_batch, 1).to(x.device) 252 | x = torch.concat((x, freq_bins), dim=-1) 253 | else: # wide_band 254 | freq_bins = torch.arange(n_freq).repeat(int(seq_len/n_freq))[perm] 255 | freq_bins = freq_bins.unsqueeze(1).unsqueeze(1).broadcast_to((seq_len, tmp_batch, 1)).to(x.device) 256 | x = torch.concat((x, freq_bins), dim=-1) 257 | 258 | x, _ = self.lstm1(x) 259 | x = self.dropout(x) 260 | 261 | if self.permute: 262 | x = x[inv_perm] 263 | 264 | if not self.freq_first: # narrow_band -> wide_band 265 | seq_len = n_freq 266 | tmp_batch = n_batch*n_times 267 | x = x.reshape(n_times, n_batch, n_freq, self.lstm1_out).permute(2,1,0,3).reshape(n_freq, n_batch*n_times, self.lstm1_out) 268 | else: # wide_band -> narrow_band 269 | seq_len = n_times 270 | tmp_batch = n_batch*n_freq 271 | x = x.reshape(n_freq, n_batch, n_times, self.lstm1_out).permute(2,1,0,3).reshape(n_times, n_batch*n_freq, self.lstm1_out) 272 | 273 | if self.permute: 274 | perm = torch.randperm(seq_len) 275 | inv_perm = torch.zeros(seq_len, dtype=int) 276 | for i, val in enumerate(perm): 277 | inv_perm[val] = i 278 | x = x[perm] 279 | else: 280 | perm = torch.arange(seq_len) 281 | 282 | if self.append_freq_idx: 283 | if self.freq_first: # wide_band 284 | freq_bins = torch.arange(n_freq).repeat(n_batch*n_times).reshape(seq_len, tmp_batch, 1).to(x.device) 285 | x = torch.concat((x, freq_bins), dim=-1) 286 | else: # narrow_band 287 | freq_bins = torch.arange(n_freq).repeat(int(seq_len/n_freq))[perm] 288 | freq_bins = freq_bins.unsqueeze(1).unsqueeze(1).broadcast_to((seq_len, tmp_batch, 1)).to(x.device) 289 | x = torch.concat((x, freq_bins), dim=-1) 290 | 291 | x, _ = self.lstm2(x) 292 | x = self.dropout(x) 293 | 294 | if self.permute: 295 | x = x[inv_perm] 296 | 297 | x = self.ff(x) 298 | 299 | if not self.freq_first: # wide_band -> input shape 300 | x = x.reshape(n_freq, n_batch, n_times, self.linear_out_features).permute(1,3,0,2) 301 | else: # narrow_band -> input shape 302 | x = x.reshape(n_times, n_batch, n_freq, self.linear_out_features).permute(1,3,2,0) 303 | 304 | x = self.mask_activation(x) 305 | return x 306 | 307 | 308 | class JNF_SSF(nn.Module): 309 | """ 310 | Mask estimation network composed of two LSTM layers. One LSTM layer uses the frequency-dimension as sequence input 311 | and the other LSTM uses the time-dimension as input. 312 | In addition to the noisy input, the network also gets a one-hot encoded DoA angle vector to indicate in which direction the target speaker is located. 
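    The conditioning is injected by passing the one-hot direction vector through a linear layer and
    using the result as the initial cell state of the conditioned LSTM layer(s), while the initial
    hidden state is set to zero (see forward()). A minimal call sketch (the shapes, angle resolution
    and directions below are illustrative assumptions):

        model = JNF_SSF(n_channels=3, n_lstm_hidden1=256, n_lstm_hidden2=128, n_cond_emb_dim=180,
                        bidirectional=True, output_type='CRM', output_activation='tanh')
        x = torch.randn(2, 2 * 3, 257, 100)                                 # stacked real/imag STFT
        target_dirs = torch.nn.functional.one_hot(torch.tensor([0, 90]), 180).float()
        mask = model(x, target_dirs, device='cpu')                          # -> [2, 2, 257, 100]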
313 | """ 314 | def __init__(self, 315 | n_channels: int, 316 | n_lstm_hidden1: int, 317 | n_lstm_hidden2: int, 318 | n_cond_emb_dim: int, 319 | bidirectional: bool, 320 | output_type: Literal['IRM', 'CRM'], 321 | output_activation: Literal['sigmoid', 'tanh', 'linear'], 322 | dropout: float = 0, 323 | causal: bool = False, 324 | condition_nb_only: bool = False, 325 | condition_wb_only: bool = True): 326 | """ 327 | Initialize model. 328 | 329 | :param n_channels: number of channel in the input signal 330 | :param n_lstm_hidden1: number of LSTM units in the first LSTM layer 331 | :param n_lstm_hidden2: number of LSTM units in the second LSTM layer 332 | :param bidirectional: set to True for a bidirectional LSTM 333 | :param output_type: output_type: set to 'IRM' for real-valued ideal ratio mask (IRM) and to 'CRM' for complex IRM 334 | :param output_activation: the activation function applied to the network output (options: 'sigmoid', 'tanh', 'linear') 335 | :param dropout: dropout percentage (default: no dropout) 336 | :param condition_wb_only: flag indicating if both LSTM layers or only the wide-band (first) should be conditioned on the target DoA 337 | :param condition_nb_only: flag indicating if both LSTM layers or only the narrowband (second) should be conditioned on the target DoA 338 | """ 339 | super(JNF_SSF, self).__init__() 340 | 341 | self.n_channels = n_channels 342 | self.n_lstm_hidden1 = n_lstm_hidden1 343 | self.n_lstm_hidden2 = n_lstm_hidden2 344 | self.n_cond_emb_dim = n_cond_emb_dim 345 | self.dropout = dropout 346 | self.bidirectional = bidirectional 347 | self.output_type = output_type 348 | self.output_activation = output_activation 349 | self.condition_nb_only = condition_nb_only 350 | self.condition_wb_only = condition_wb_only 351 | 352 | assert not (condition_nb_only and condition_wb_only), "Config does not make sense." 353 | 354 | lstm_input = 2*n_channels 355 | 356 | if not self.condition_nb_only: 357 | self.cond_emb1 = nn.Linear(n_cond_emb_dim, self.n_lstm_hidden1) 358 | if not self.condition_wb_only: 359 | self.cond_emb2 = nn.Linear(n_cond_emb_dim, self.n_lstm_hidden2) 360 | 361 | self.lstm1 = nn.LSTM(input_size=lstm_input, hidden_size=self.n_lstm_hidden1, 362 | bidirectional=self.bidirectional, batch_first=False) 363 | 364 | self.lstm1_out = 2*self.n_lstm_hidden1 if self.bidirectional else self.n_lstm_hidden1 365 | lstm2_input = self.lstm1_out 366 | 367 | self.bidirectional_second = (bidirectional and not causal) 368 | self.lstm2 = nn.LSTM(input_size=lstm2_input, hidden_size=self.n_lstm_hidden2, 369 | bidirectional=self.bidirectional_second, batch_first=False) 370 | self.lstm2_out = 2*self.n_lstm_hidden2 if self.bidirectional_second else self.n_lstm_hidden2 371 | 372 | self.dropout = nn.Dropout(p=self.dropout) 373 | 374 | if self.output_type == 'IRM': 375 | self.linear_out_features = 1 376 | elif self.output_type == 'CRM': 377 | self.linear_out_features = 2 378 | else: 379 | raise ValueError( 380 | f'The output type {output_type} is not supported.') 381 | self.ff = nn.Linear( 382 | self.lstm2_out, out_features=self.linear_out_features) 383 | 384 | if self.output_activation == 'sigmoid': 385 | self.mask_activation = nn.Sigmoid() 386 | elif self.output_activation == 'tanh': 387 | self.mask_activation = nn.Tanh() 388 | elif self.output_activation == 'linear': 389 | self.mask_activation = nn.Identity() 390 | 391 | def forward(self, x: torch.Tensor, target_dirs: torch.Tensor, device: str): 392 | """ 393 | Implements the forward pass of the model. 
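        Note that when this model is driven by SSFExp, target_dirs is the already one-hot encoded
        direction produced by encode_condition, i.e. a float tensor of shape [BATCH, n_cond_emb_dim];
        the linear conditioning layers expect this encoded form rather than a raw angle value.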
394 | 395 | :param x: input with shape [BATCH, CHANNEL, FREQ, TIME] 396 | :param target_dirs: the conditional input [BATCH, 1 (IDX)] 397 | :return: the output mask [BATCH, 1 (IRM) or 2 (CRM) , FREQ, TIME] 398 | """ 399 | n_batch, n_channel, n_freq, n_times = x.shape 400 | 401 | 402 | # wide_band 403 | tmp_batch = n_batch*n_times 404 | bidirectional_dim = 2 if self.bidirectional else 1 405 | x = x.permute(2, 0, 3, 1).reshape( 406 | n_freq, n_batch*n_times, n_channel) 407 | 408 | if self.condition_nb_only: 409 | x, _ = self.lstm1(x) 410 | else: 411 | x_cond_emb1 = self.cond_emb1(target_dirs) 412 | x_cond_emb1_reshaped1 = x_cond_emb1.unsqueeze(0).unsqueeze(2).repeat( 413 | bidirectional_dim, 1, n_times, 1).reshape(bidirectional_dim, tmp_batch, self.n_lstm_hidden1) 414 | x, _ = self.lstm1(x, 415 | (torch.zeros(bidirectional_dim, tmp_batch, self.n_lstm_hidden1, device=device), x_cond_emb1_reshaped1)) 416 | x = self.dropout(x) 417 | 418 | # narrow_band 419 | tmp_batch = n_batch*n_freq 420 | x = x.reshape(n_freq, n_batch, n_times, self.lstm1_out).permute( 421 | 2, 1, 0, 3).reshape(n_times, n_batch*n_freq, self.lstm1_out) 422 | if self.condition_wb_only: 423 | x, _ = self.lstm2(x) 424 | else: 425 | x_cond_emb2 = self.cond_emb2(target_dirs) 426 | x_cond_emb2_reshaped2 = x_cond_emb2.unsqueeze(0).unsqueeze(2).repeat( 427 | bidirectional_dim, 1, n_freq, 1).reshape(bidirectional_dim, tmp_batch, self.n_lstm_hidden2) 428 | 429 | x, _ = self.lstm2(x, 430 | (torch.zeros(bidirectional_dim, tmp_batch, self.n_lstm_hidden2, device=device), x_cond_emb2_reshaped2)) 431 | x = self.dropout(x) 432 | 433 | x = self.ff(x) 434 | 435 | # time_slice -> input shape 436 | x = x.reshape(n_times, n_batch, n_freq, 437 | self.linear_out_features).permute(1, 3, 2, 0) 438 | 439 | x = self.mask_activation(x) 440 | return x -------------------------------------------------------------------------------- /src/scripts/train_jnf.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning import loggers as pl_loggers 3 | from pytorch_lightning.callbacks import ModelSummary 4 | from models.exp_jnf import JNFExp 5 | from models.models import FTJNF 6 | from data.datamodule import HDF5DataModule 7 | from typing import Optional 8 | import yaml 9 | 10 | EXP_NAME='JNF' 11 | 12 | def setup_logging(tb_log_dir: str, version_id: Optional[int]= None): 13 | """ 14 | Set-up a Tensorboard logger. 15 | 16 | :param tb_log_dir: path to the log dir 17 | :param version_id: the version id (integer). Consecutive numbering is used if no number is given. 
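    A usage sketch (the log directory below is an illustrative assumption):

        tb_logger, version_id = setup_logging("./tb_logs")   # version number is assigned automatically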
18 | """ 19 | 20 | if version_id is None: 21 | tb_logger = pl_loggers.TensorBoardLogger(tb_log_dir, name=EXP_NAME, log_graph=False) 22 | 23 | # get current version id 24 | version_id = int((tb_logger.log_dir).split('_')[-1]) 25 | else: 26 | tb_logger = pl_loggers.TensorBoardLogger(tb_log_dir, name=EXP_NAME, log_graph=False, version=version_id) 27 | 28 | return tb_logger, version_id 29 | 30 | def load_model(ckpt_file: str, 31 | _config): 32 | init_params = JNFExp.get_init_params(_config) 33 | model = JNFExp.load_from_checkpoint(ckpt_file, **init_params) 34 | model.to('cuda') 35 | return model 36 | 37 | def get_trainer(devices, logger, max_epochs, gradient_clip_val, gradient_clip_algorithm, strategy, accelerator): 38 | return pl.Trainer(enable_model_summary=True, 39 | logger=logger, 40 | devices=devices, 41 | log_every_n_steps=1, 42 | max_epochs=max_epochs, 43 | gradient_clip_val = gradient_clip_val, 44 | gradient_clip_algorithm = gradient_clip_algorithm, 45 | strategy = strategy, 46 | accelerator = accelerator, 47 | callbacks=[ 48 | #setup_checkpointing(), 49 | ModelSummary(max_depth=2) 50 | ], 51 | 52 | ) 53 | 54 | if __name__=="__main__": 55 | 56 | with open('config/jnf_config.yaml') as config_file: 57 | config = yaml.safe_load(config_file) 58 | 59 | ## REPRODUCIBILITY 60 | pl.seed_everything(config.get('seed', 0), workers=True) 61 | 62 | ## LOGGING 63 | tb_logger, version = setup_logging(config['logging']['tb_log_dir']) 64 | 65 | ## DATA 66 | data_config = config['data'] 67 | stft_length = data_config.get('stft_length_samples', 512) 68 | stft_shift = data_config.get('stft_shift_samples', 256) 69 | dm = HDF5DataModule(**data_config) 70 | 71 | ## CONFIGURE EXPERIMENT 72 | ckpt_file = config['training'].get('resume_ckpt', None) 73 | if not ckpt_file is None: 74 | exp = load_model(ckpt_file, config) 75 | else: 76 | model = FTJNF(**config['network']) 77 | exp = JNFExp(model=model, 78 | stft_length=stft_length, 79 | stft_shift=stft_shift, 80 | **config['experiment']) 81 | 82 | ## TRAIN 83 | trainer = get_trainer(logger=tb_logger, **config['training']) 84 | trainer.fit(exp, dm) 85 | 86 | -------------------------------------------------------------------------------- /src/scripts/train_ssf.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning import loggers as pl_loggers 3 | from pytorch_lightning.callbacks import ModelSummary 4 | from models.exp_ssf import SSFExp 5 | from models.models import JNF_SSF 6 | from data.datamodule import HDF5DataModule 7 | from typing import Optional 8 | import yaml 9 | 10 | EXP_NAME='JNF-SSF' 11 | 12 | def setup_logging(tb_log_dir: str, version_id: Optional[int]= None): 13 | """ 14 | Set-up a Tensorboard logger. 15 | 16 | :param tb_log_dir: path to the log dir 17 | :param version_id: the version id (integer). Consecutive numbering is used if no number is given. 
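    A usage sketch (the log directory and version number below are illustrative assumptions):

        tb_logger, version_id = setup_logging("./tb_logs", version_id=3)   # continue an existing version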
18 | """ 19 | 20 | if version_id is None: 21 | tb_logger = pl_loggers.TensorBoardLogger(tb_log_dir, name=EXP_NAME, log_graph=False) 22 | 23 | # get current version id 24 | version_id = int((tb_logger.log_dir).split('_')[-1]) 25 | else: 26 | tb_logger = pl_loggers.TensorBoardLogger(tb_log_dir, name=EXP_NAME, log_graph=False, version=version_id) 27 | 28 | return tb_logger, version_id 29 | 30 | def load_model(ckpt_file: str, 31 | _config): 32 | init_params = SSFExp.get_init_params(_config) 33 | model = SSFExp.load_from_checkpoint(ckpt_file, **init_params) 34 | model.to('cuda') 35 | return model 36 | 37 | def get_trainer(devices, logger, max_epochs, gradient_clip_val, gradient_clip_algorithm, strategy, accelerator): 38 | return pl.Trainer(enable_model_summary=True, 39 | logger=logger, 40 | devices=devices, 41 | log_every_n_steps=100, 42 | max_epochs=max_epochs, 43 | gradient_clip_val = gradient_clip_val, 44 | gradient_clip_algorithm = gradient_clip_algorithm, 45 | strategy = strategy, 46 | accelerator = accelerator, 47 | callbacks=[ 48 | #setup_checkpointing(), 49 | ModelSummary(max_depth=2) 50 | ], 51 | 52 | ) 53 | 54 | if __name__=="__main__": 55 | 56 | with open('config/ssf_config.yaml') as config_file: 57 | config = yaml.safe_load(config_file) 58 | 59 | ## REPRODUCIBILITY 60 | pl.seed_everything(config.get('seed', 0), workers=True) 61 | 62 | ## LOGGING 63 | tb_logger, version = setup_logging(config['logging']['tb_log_dir']) 64 | 65 | ## DATA 66 | data_config = config['data'] 67 | stft_length = data_config.get('stft_length_samples', 512) 68 | stft_shift = data_config.get('stft_shift_samples', 256) 69 | dm = HDF5DataModule(**data_config) 70 | 71 | ## CONFIGURE EXPERIMENT 72 | ckpt_file = config['training'].get('resume_ckpt', None) 73 | if ckpt_file is not None: 74 | exp = load_model(ckpt_file, config) 75 | else: 76 | model = JNF_SSF(**config['network']) 77 | exp = SSFExp(model=model, 78 | stft_length=stft_length, 79 | stft_shift=stft_shift, 80 | **config['experiment']) 81 | 82 | ## TRAIN 83 | trainer = get_trainer(logger=tb_logger, **config['training']) 84 | trainer.fit(exp, dm) 85 | 86 | -------------------------------------------------------------------------------- /src/utils/log_images.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.utils import make_grid 3 | import matplotlib 4 | #matplotlib.use('Qt5Agg') 5 | import matplotlib.pyplot as plt 6 | import matplotlib.cm as cm 7 | import numpy as np 8 | from typing import Union 9 | 10 | def make_image_grid(image_batch: torch.Tensor, vmin: Union[float, None], vmax: Union[float, None], num_images: int = 11 | -1, n_img_per_row: int = 8): 12 | """ 13 | Take a batch of 2D data and create a grid of visualizations. 14 | 15 | :param image_batch: the 2D data [BATCH, XAXIS, YAXIS] 16 | :param num_images: the number of images to be displayed (default: full batch) 17 | :param vmin, vmax: the min and max value in the data (e.g. vmin=0, vmax=1 for a Wiener-like mask). If vmin and 18 | vmax are None, they will be inferred from the data. Be aware that the color range might change between epochs then. 19 | :param n_img_per_row: number of images per row 20 | """ 21 | 22 | image_batch = image_batch[:num_images if num_images > 0 else len(image_batch)].cpu().detach().numpy() 23 | 24 | def rgba_to_rgb(rgba): 25 | """ 26 | Converts a numpy RGBA array with shape [rows, columns, channels=rgba] to a numpy RGB array [rows, columns, 27 | channels=rgb] assuming a black background. 
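        A shape sketch (the input values are illustrative assumptions):

            rgba = np.random.rand(64, 48, 4)    # [rows, columns, RGBA]
            rgb = rgba_to_rgb(rgba)             # -> shape (64, 48, 3)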
28 | 29 | :param rgba: the rgba data in a numpy array 30 | :return: a numpy array with the same shape as the input array but with the number of channels reduced to 3 31 | """ 32 | a = rgba[..., 3] 33 | rgba[..., 0] = a * rgba[..., 0] + (1 - a) * 255 34 | rgba[..., 1] = a * rgba[..., 1] + (1 - a) * 255 35 | rgba[..., 2] = a * rgba[..., 2] + (1 - a) * 255 36 | 37 | return rgba[..., :3] 38 | 39 | def plot(index): 40 | fig = plt.figure() 41 | ax = fig.add_subplot(111) 42 | # 10 * np.log10(np.maximum(np.square(np.abs(signal_stft)), 10 ** (-15))) 43 | im = ax.imshow(image_batch[index, 0], cmap='viridis', 44 | origin='lower', 45 | aspect='auto') 46 | fig.colorbar(im, orientation="vertical", pad=0.2) 47 | plt.show() 48 | 49 | cmap = cm.ScalarMappable(matplotlib.colors.Normalize(vmin=vmin, vmax=vmax, clip=True), 'viridis') 50 | 51 | rgba_data = cmap.to_rgba(image_batch, norm=True) 52 | rgb_data = np.moveaxis(np.squeeze(rgba_to_rgb(rgba_data)), 3, 1) 53 | image_data = torch.from_numpy(rgb_data) 54 | norm_image_data = ((image_data + 1) * 127.5).type(torch.ByteTensor) 55 | grid = make_grid(norm_image_data, nrow=n_img_per_row, padding=2, normalize=False) 56 | 57 | return grid 58 | --------------------------------------------------------------------------------
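As a small closing illustration, the snippet below sketches how `make_image_grid` from `src/utils/log_images.py` could be used on its own to push a grid of spectrograms to TensorBoard, mirroring the dB conversion and value range used in `log_batch_detailed_spectrograms`. The writer path and the random stand-in STFTs are illustrative assumptions, and the import assumes the script is run from the `src` directory.

```python
import torch
from torch.utils.tensorboard import SummaryWriter

from utils.log_images import make_image_grid

# Stand-in complex STFTs [BATCH, FREQ, TIME]; in practice these would come from the model.
stfts = torch.randn(8, 257, 100) + 1j * torch.randn(8, 257, 100)

# Convert to a dB power spectrogram, clipped at -150 dB as in the Lightning logging code.
spec_db = 10 * torch.log10(torch.clamp(stfts.abs() ** 2, min=1e-15))

# make_image_grid is fed [BATCH, 1, FREQ, TIME]; flipping the frequency axis puts low frequencies at the bottom.
spec_db = torch.flip(spec_db.unsqueeze(1), dims=[-2])
grid = make_image_grid(spec_db, vmin=-80, vmax=20, n_img_per_row=4)

writer = SummaryWriter("./tb_logs/grid_example")   # hypothetical log directory
writer.add_image("example/spectrograms", grid, global_step=0)
writer.close()
```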