├── gen_dataset
│   ├── run.sh
│   ├── align_brir.py
│   ├── wav2npy.py
│   └── gen_dataset_wav.py
├── .gitignore
├── images
│   ├── learning_curve_all.png
│   └── end2end-model-framework.png
├── utils
│   ├── file_reader_v2.py
│   ├── plot_learning_curve.py
│   └── file_reader.py
├── evaluate_mct.py
├── train_mct.py
├── README.md
└── WaveLoc.py

/gen_dataset/run.sh:
--------------------------------------------------------------------------------
python align_brir.py
python gen_dataset_wav.py
python wav2npy.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
models/
models_valid/
__pycache__/
.trash/
pre_result
models
models_pre/
--------------------------------------------------------------------------------
/images/learning_curve_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bingo-todd/WaveLoc/HEAD/images/learning_curve_all.png
--------------------------------------------------------------------------------
/images/end2end-model-framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bingo-todd/WaveLoc/HEAD/images/end2end-model-framework.png
--------------------------------------------------------------------------------
/utils/file_reader_v2.py:
--------------------------------------------------------------------------------
import numpy as np
from BasicTools import get_fpath


def file_reader(reverb_set_dir, batch_size=128, is_shuffle=True,
                frame_len=320, shift_len=160, n_azi=37):
    """read npy batch files in the given directories, one file at a time
    Args:
        reverb_set_dir: directory or list of directories where npy batches
            exist
        batch_size:
        is_shuffle:
    Returns:
        samples generator, [samples, label_all]
    """
    if isinstance(reverb_set_dir, list):
        dir_all = reverb_set_dir
    else:
        dir_all = [reverb_set_dir]
    #
    fpath_reverb_all = []
    for dir_fpath in dir_all:
        fpath_all_tmp = get_fpath(dir_fpath, '.npy', is_absolute=True)
        fpath_reverb_all.extend(fpath_all_tmp)

    if is_shuffle:
        np.random.shuffle(fpath_reverb_all)

    for fpath_reverb in fpath_reverb_all:
        x_d_batch, x_r_batch, y_loc_batch, is_anechoic = np.load(
            fpath_reverb, allow_pickle=True)
        # if x_d.shape[0] == batch_size and x_r.shape[0] == batch_size and y_loc.shape[0] == batch_size:
        yield x_r_batch, y_loc_batch
--------------------------------------------------------------------------------
/utils/plot_learning_curve.py:
--------------------------------------------------------------------------------
import os
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import numpy as np
from BasicTools import plot_tools

plt.rcParams.update({"font.size": "12"})
room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
reverb_room_all = ['Room_A', 'Room_B', 'Room_C', 'Room_D']


def plot_train_process(model_dir):

    fig, ax = plt.subplots(1, 1)
    plot_settings = {'linewidth': 4}
    n_epoch_max = 0
    for room_i, room in enumerate(reverb_room_all):
        record_fpath = os.path.join(model_dir, room, 'train_record.npz')
        record_info = np.load(record_fpath)
        cost_record_valid = record_info['cost_record_valid']
        # rmse_record_valid = record_info['azi_rmse_record_valid']
        n_epoch = np.nonzero(cost_record_valid)[0][-1] + 1
        if n_epoch > n_epoch_max:
            n_epoch_max = n_epoch
        ax.plot(cost_record_valid[:n_epoch], **plot_settings, label=room[-1])

    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.legend()
    ax.set_ylabel('Cross entropy')
    ax.set_xlabel('Epoch')

    plot_tools.savefig(fig, 'learning_curve_all.png', '../images')


if __name__ == '__main__':

    model_dir = '../models/mct'
    plot_train_process(model_dir)
--------------------------------------------------------------------------------
/evaluate_mct.py:
--------------------------------------------------------------------------------
import numpy as np
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import sys
from WaveLoc import WaveLoc
from BasicTools import plot_tools
from utils import file_reader


room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
reverb_room_all = ['Room_A', 'Room_B', 'Room_C', 'Room_D']
n_reverb_room = 4
chunk_size = 25
n_test = 4


def evaluate_mct(model_dir_base):
    rmse_all = np.zeros((n_test, n_reverb_room))
    for room_i, room in enumerate(reverb_room_all):
        model_dir = os.path.join(model_dir_base, room)
        model_config_fpath = os.path.join(model_dir, 'config.cfg')
        model = WaveLoc(file_reader.file_reader,
                        model_config_fpath, gpu_index=0)
        model.load_model(model_dir)

        for test_i in range(n_test):
            dataset_dir_test = os.path.join(
                '/home/st/Work_Space/Localize/WaveLoc/Data',
                f'v{test_i+1}/test/reverb/{room[-1]}')
            rmse_all[test_i, room_i] = model.evaluate_chunk_rmse(
                dataset_dir_test,
                chunk_size=chunk_size)
    return rmse_all


if __name__ == '__main__':
    model_dir = sys.argv[1]  # e.g. 'models/mct'
    rmse_all = evaluate_mct(model_dir)

    with open(os.path.join(model_dir, 'result.txt'), 'w') as result_file:
        result_file.write(f'{rmse_all}\n')
        result_file.write('mean: {}\n'.format(np.mean(rmse_all, axis=0)))
        result_file.write('std: {}\n'.format(np.std(rmse_all, axis=0)))

    print(rmse_all)
    print('mean:', np.mean(rmse_all, axis=0))
    print('std:', np.std(rmse_all, axis=0))
--------------------------------------------------------------------------------
/utils/file_reader.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import copy
import os
from BasicTools import get_fpath
from BasicTools import wav_tools

plt.rcParams.update({"font.size": "12"})
reverb_room_all = ['A', 'B', 'C', 'D']


def file_reader(record_set_dir, batch_size=-1, is_shuffle=True):
    """read wav files in the given directories, one file at a time
    Args:
        record_set_dir: directory or list of directories where recordings
            exist
    Returns:
        samples generator, [samples, label_all]
    """
    if isinstance(record_set_dir, list):
        dirs = record_set_dir
    else:
        dirs = [record_set_dir]
    #
    fpath_all = []
    for sub_set_dir in dirs:
        fpath_all_sub = get_fpath(sub_set_dir, '.wav', is_absolute=True)
        fpath_all.extend(fpath_all_sub)

    if is_shuffle:
        np.random.shuffle(fpath_all)

    if len(fpath_all) < 1:
        raise Exception('empty folder:{}'.format(record_set_dir))

    frame_len = 320
    shift_len = 160
    n_azi = 37

    if batch_size > 0:
        x_all = np.zeros((0, frame_len, 2, 1))
        y_all = np.zeros((0, n_azi))

    for fpath in fpath_all:
        record, fs = wav_tools.read_wav(fpath)
        x_file_all = wav_tools.frame_data(record, frame_len, shift_len)
        x_file_all = np.expand_dims(x_file_all, axis=-1)

        # onehot azi label
        n_sample_file = x_file_all.shape[0]
        fname = os.path.basename(fpath)
        azi = np.int16(fname.split('_')[0])
        y_file_all = np.zeros((n_sample_file, n_azi))
        y_file_all[:, azi] = 1

        if batch_size > 0:
            x_all = np.concatenate((x_all, x_file_all), axis=0)
            y_all = np.concatenate((y_all, y_file_all), axis=0)

            while x_all.shape[0] > batch_size:
                x_batch = copy.deepcopy(x_all[:batch_size])
                y_batch = copy.deepcopy(y_all[:batch_size])

                x_all = x_all[batch_size:]
                y_all = y_all[batch_size:]

                yield [x_batch, y_batch]
        else:
            yield [x_file_all, y_file_all]
--------------------------------------------------------------------------------
/train_mct.py:
--------------------------------------------------------------------------------
import sys
import os
import configparser
from multiprocessing import Process
from WaveLoc import WaveLoc
from utils import file_reader_v2

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


data_dir = 'Data'
train_set_dir_base = os.path.join(data_dir, 'v1/npy/train')
valid_set_dir_base = os.path.join(data_dir, 'v1/npy/valid')

room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
reverb_room_all = ['Room_A', 'Room_B', 'Room_C', 'Room_D']

model_basic_settings = {'fs': 16000,
                        'n_band': 32,
                        'cf_low': 70,
                        'cf_high': 7000,
                        'frame_len': 320,
                        'shift_len': 160,
                        'filter_len': 320,
                        'azi_num': 37,
                        'is_use_gtf': False,
                        'is_padd': False}
gpu_index = 1


def train_mct(room_tar, model_dir):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # leave room_tar out of room_all; train on the remaining rooms
    print('tar_room', room_tar)
    mct_room_all = [room for room in room_all if room != room_tar]
    config = configparser.ConfigParser()
    config['model'] = {**model_basic_settings}
    config['train'] = {'batch_size': 128,
                       'max_epoch': 50,
                       'is_print_log': False,
                       'train_set_dir': ';'.join(
                           [os.path.join(train_set_dir_base, room)
                            for room in mct_room_all]),
                       'valid_set_dir': ';'.join(
                           [os.path.join(valid_set_dir_base, room)
                            for room in mct_room_all])}

    config_fpath = os.path.join(model_dir, 'config.cfg')
    with open(config_fpath, 'w') as config_file:
        config.write(config_file)

    model = WaveLoc(file_reader_v2.file_reader, config_fpath=config_fpath,
                    gpu_index=gpu_index)
    model.train_model(model_dir)


if __name__ == '__main__':
    thread_all = []
    for room_tar in reverb_room_all:
        model_dir = 'models/mct/{}'.format(room_tar)
        thread = Process(target=train_mct, args=(room_tar, model_dir))
        thread.start()
        thread_all.append(thread)

    [thread.join() for thread in thread_all]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# End2End sound localization model

Reference:

P. Vecchiotti, N. Ma, S. Squartini, and G. J. Brown, "End-to-end binaural sound localisation from the raw waveform," in Proc. IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2019, pp. 451–455.

**Only WaveLoc-GTF is implemented**

## Model

![end-to-end model framework](images/end2end-model-framework.png)

## Training

### Dataset

- BRIR

  Surrey binaural room impulse response (BRIR) database, covering an anechoic room and four reverberant rooms:

| Room      | A    | B    | C    | D    |
|-----------|------|------|------|------|
| RT_60 (s) | 0.32 | 0.47 | 0.68 | 0.89 |
| DRR (dB)  | 6.09 | 5.31 | 8.82 | 6.12 |

- Sound source

  TIMIT database.

  Sentences per azimuth:

| Train | Validate | Evaluate |
|-------|----------|----------|
| 24    | 6        | 15       |
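
With 37 azimuths (-90° to 90° in 5° steps) and 5 rooms, these per-azimuth counts fix the size of each set; a quick sanity check that mirrors the constants in `gen_dataset/gen_dataset_wav.py`:

```python
n_azi = 37   # -90 to 90 degrees in steps of 5
n_room = 5   # anechoic room + 4 reverberant rooms
n_wav_per_azi = {'train': 24, 'valid': 6, 'test': 15}

for set_type, n in n_wav_per_azi.items():
    # one TIMIT sentence per (azimuth, room, repetition)
    print(set_type, n_azi * n_room * n)
# -> train 4440, valid 1110, test 2775
```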

## Multi-conditional training (MCT)

For each reverberant room, the remaining three reverberant rooms plus the anechoic room are used for training, so the held-out room is never seen during training.
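
A minimal sketch of this leave-one-room-out split, mirroring the list comprehension in `train_mct.py`:

```python
room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']

for room_tar in room_all[1:]:  # each reverberant room is held out in turn
    mct_room_all = [room for room in room_all if room != room_tar]
    print(room_tar, '-> trained on', mct_room_all)
```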

Training curves:

![learning curves of the four MCT models](images/learning_curve_all.png)

## Evaluation

Root mean square error (RMSE) of azimuth is used as the performance metric. For each reverberant room, the evaluation was performed 3 times to get more stable results, and the test dataset was regenerated each time.

Since binaural sound is fed to the model directly, without extra preprocessing, and speech contains short pauses, the localization result was reported on chunks rather than on single frames. Each chunk consisted of 25 consecutive frames.
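
A minimal sketch of this chunk-level decoding, mirroring `evaluate_chunk_rmse` in `WaveLoc.py` (the helper name `chunk_rmse` is illustrative; azimuth indices are 5° apart, hence the conversion factor):

```python
import numpy as np

def chunk_rmse(y_est, azi_true, chunk_size=25, azi_step=5):
    """y_est: (n_frame, n_azi) frame-level posteriors of one recording"""
    sq_err = []
    for i in range(y_est.shape[0] - chunk_size + 1):
        # average the posteriors of 25 consecutive frames, then decode
        azi_est = np.argmax(np.mean(y_est[i:i + chunk_size], axis=0))
        sq_err.append((azi_est - azi_true) ** 2)
    # convert the azimuth-index RMSE to degrees
    return np.sqrt(np.mean(sq_err)) * azi_step
```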

### My result vs. paper

| Reverberant room | A   | B   | C   | D   |
|------------------|-----|-----|-----|-----|
| My result        | 1.5 | 2.0 | 1.4 | 2.7 |
| Result in paper  | 1.5 | 3.0 | 1.7 | 3.5 |

## Main dependencies

- tensorflow-1.14
- pysofa (can be installed by pip)
- BasicTools (in my other [repository](https://github.com/bingo-todd/BasicTools))
--------------------------------------------------------------------------------
/gen_dataset/align_brir.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import pysofa
from BasicTools import DspTools
import matplotlib.pyplot as plt

"""Align BRIRs of the reverberant rooms to those of the anechoic room
"""

brirs_dir = os.path.expanduser('~/Work_Space/Data/RealRoomBRIRs')
brirs_aligned_dir = 'brirs_aligned'
os.makedirs(brirs_aligned_dir, exist_ok=True)

n_azi = 37
n_channel = 2
rever_room_all = ('Room_A', 'Room_B', 'Room_C', 'Room_D')
room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']


def plot_brirs():
    for room in room_all:
        brirs_fpath = f'brirs_aligned/{room}.npy'
        brirs = np.load(brirs_fpath)
        print(brirs.shape)
        fig, ax = plt.subplots(1, 1)
        for azi_i in range(0, 37, 8):
            ax.plot(brirs[azi_i, :, 0])
        ax.set_xlim([0, 2000])
        fig.savefig(f'brirs_aligned/{room}.png')
        plt.close(fig)


def load_brirs(room):
    brirs_fpath = os.path.join(brirs_dir, f'UniS_{room}_BRIR_16k.sofa')
    brirs = pysofa.SOFA(brirs_fpath).FIR.IR.transpose(0, 2, 1)
    return brirs


def align_brirs():
    """
    For each reverberant room, calculate the BRIR delay at each sound
    position and align the BRIRs according to the averaged delay
    """

    brirs_anechoic = load_brirs('Anechoic')
    np.save(os.path.join(brirs_aligned_dir, 'Anechoic.npy'), brirs_anechoic)

    delay_all = np.zeros((n_azi, n_channel))
    for reverb_room in rever_room_all:
        print(reverb_room)
        brirs = load_brirs(reverb_room)
        for azi_i in range(n_azi):
            for channel_i in range(n_channel):
                delay_all[azi_i, channel_i] = DspTools.cal_delay(
                    brirs_anechoic[azi_i, :, channel_i],
                    brirs[azi_i, :, channel_i])
        delay_mean = np.int16(np.round(np.mean(delay_all)))
        if delay_mean > 0:
            brirs_aligned = np.concatenate((np.zeros((n_azi, delay_mean,
                                                      n_channel)),
                                            brirs),
                                           axis=1)
        else:
            brirs_aligned = brirs[:, -delay_mean:, :]

        np.save(os.path.join(brirs_aligned_dir, f'{reverb_room}.npy'),
                brirs_aligned)

        if False:  # diagnostic plot
            fig, ax = plt.subplots(1, 1)
            ax.plot(brirs_anechoic[0, :, 0], label='anechoic')
            ax.plot(brirs_aligned[0, :, 0], label='reverb_aligned')
            ax.plot(brirs[0, :, 0], label='reverb')
            ax.set_xlim((0, 500))
            ax.legend()
            ax.set_title(reverb_room)
            fig.savefig(os.path.join(brirs_aligned_dir,
                                     f'{reverb_room}_aligned.png'))
            plt.close(fig)


if __name__ == '__main__':
    align_brirs()
    # plot_brirs()
--------------------------------------------------------------------------------
/gen_dataset/wav2npy.py:
--------------------------------------------------------------------------------
"""
frame raw waveform and save in batches
"""


import numpy as np
import matplotlib.pyplot as plt
import copy
import os
from BasicTools import get_fpath, wav_tools, ProcessBar

plt.rcParams.update({"font.size": "12"})
room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']


def wav2npy(reverb_set_dir, npy_dir, is_anechoic):
    """read wav files in the given directory, frame them, and save the
    frames in batches as npy files
    Args:
        reverb_set_dir: directory where reverberant recordings exist
        npy_dir: output directory for the npy batches
        is_anechoic: flag saved alongside each batch
    """

    frame_len = 320
    shift_len = 160
    n_azi = 37
    batch_size = 128

    os.makedirs(npy_dir, exist_ok=True)

    #
    fpath_reverb_all = get_fpath(reverb_set_dir, '.wav', is_absolute=True)
    if len(fpath_reverb_all) < 1:
        raise Exception('empty folder:{}'.format(reverb_set_dir))

    pb = ProcessBar(len(fpath_reverb_all))

    batch_count = 0
    x_r = np.zeros((0, frame_len, 2, 1))
    x_d = np.zeros((0, frame_len, 2, 1))
    y_loc = np.zeros((0, n_azi))

    for fpath_reverb in fpath_reverb_all:
        pb.update()
        # reverb signal
        record, fs = wav_tools.read_wav(fpath_reverb)
        x_r_file = np.expand_dims(
            wav_tools.frame_data(record, frame_len, shift_len),
            axis=-1)
        # direct signal
        fpath_direct = fpath_reverb.replace('reverb', 'direct')
        direct, fs = wav_tools.read_wav(fpath_direct)
        x_d_file = np.expand_dims(
            wav_tools.frame_data(direct, frame_len, shift_len),
            axis=-1)

        # onehot azi label
        n_sample_file = x_d_file.shape[0]
        if x_r_file.shape[0] != n_sample_file:
            raise Exception('sample numbers do not match')

        fname = os.path.basename(fpath_reverb)
        azi = np.int16(fname.split('_')[0])
        y_loc_file = np.zeros((n_sample_file, n_azi))
        y_loc_file[:, azi] = 1

        x_r = np.concatenate((x_r, x_r_file), axis=0)
        x_d = np.concatenate((x_d, x_d_file), axis=0)
        y_loc = np.concatenate((y_loc, y_loc_file), axis=0)

        while x_d.shape[0] > batch_size:
            x_r_batch = x_r[:batch_size]
            x_d_batch = x_d[:batch_size]
            y_loc_batch = y_loc[:batch_size]

            npy_fpath = os.path.join(npy_dir, '{}.npy'.format(batch_count))
            np.save(npy_fpath, [x_d_batch, x_r_batch, y_loc_batch,
                                is_anechoic])
            batch_count = batch_count + 1

            x_r = x_r[batch_size:]
            x_d = x_d[batch_size:]
            y_loc = y_loc[batch_size:]


if __name__ == '__main__':

    for set_type in ['train', 'valid']:
        for room in room_all:
            print(room)
            wav_set_dir = '../Data/v1/{}/reverb/{}'.format(set_type, room)
            npy_set_dir = '../Data/v1/npy/{}/{}'.format(set_type, room)
            wav2npy(wav_set_dir, npy_set_dir, False)

    for test_i in range(1, 5):
        for room in room_all:
            print(room)
            wav_set_dir = f'../Data/v{test_i}/test/reverb/{room}'
            npy_set_dir = f'../Data/v{test_i}/npy/test/{room}'
            wav2npy(wav_set_dir, npy_set_dir, False)
--------------------------------------------------------------------------------
/gen_dataset/gen_dataset_wav.py:
--------------------------------------------------------------------------------
"""
synthesize spatial recordings
"""

import numpy as np
import os
from multiprocessing import Process
from BasicTools import ProcessBarMulti
from BasicTools import wav_tools
from BasicTools.Filter_GPU import Filter_GPU
from BasicTools.get_fpath import get_fpath


TIMIT_dir = os.path.expanduser('~/Work_Space/Data/TIMIT_wav')
brirs_dir = 'brirs_aligned'
data_dir = '../Data'

room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
n_room = 5
n_azi = 37  # -90~90 in steps of 5
n_wav_per_azi_all = {'train': 24,
                     'valid': 6,
                     'test': 15}

fs = 16e3
frame_len = int(20e-3*fs)
shift_len = int(10e-3*fs)


def load_brirs(room):
    """load brirs of the given room
    Args:
        room: room name from room_all
    """
    brirs_fpath = os.path.join(brirs_dir, f'{room}.npy')
    brirs = np.load(brirs_fpath)
    return brirs


def truncate_silence(x):
    """strip off silent frames at the beginning and the end (one frame each)
    """
    vad_flag = wav_tools.vad(x, frame_len, shift_len)
    if not vad_flag[0]:  # silence in the first frame
        x = x[frame_len:]
    if not vad_flag[-1]:  # silence in the last frame
        x = x[:-frame_len]
    return x


def syn_record(src_fpath_all, set_dir, n_wav_per_azi, task_i, pb):
    """synthesize spatial recordings as well as the corresponding direct
    sound for each set
    """
    filter_gpu = Filter_GPU(gpu_index=1)

    brirs_direct = load_brirs('Anechoic')
    wav_count = 0
    for room in room_all:
        direct_dir = os.path.join(set_dir, 'direct', room)
        os.makedirs(direct_dir, exist_ok=True)
        rever_dir = os.path.join(set_dir, 'reverb', room)
        os.makedirs(rever_dir, exist_ok=True)

        brirs_room = load_brirs(room)
        for azi_i in range(n_azi):
            for i in range(n_wav_per_azi):
                pb.update(task_i)
                src_fpath = src_fpath_all[wav_count]
                wav_count = wav_count+1

                src, fs = wav_tools.read_wav(src_fpath)
                src = truncate_silence(src)

                direct = filter_gpu.brir_filter(src, brirs_direct[azi_i])
                # direct = wav_tools.brir_filter(src, brirs_direct[azi_i])
                direct_fpath = os.path.join(direct_dir, f'{azi_i}_{i}.wav')
                wav_tools.write_wav(direct, fs, direct_fpath)

                reverb = filter_gpu.brir_filter(src, brirs_room[azi_i])
                # reverb = wav_tools.brir_filter(src, brirs_room[azi_i])
                reverb_fpath = os.path.join(rever_dir, f'{azi_i}_{i}.wav')
                wav_tools.write_wav(reverb, fs, reverb_fpath)


def gen_dataset(dir_path, set_type_all):

    n_wav_train = n_azi * n_room * n_wav_per_azi_all['train']
    n_wav_valid = n_azi * n_room * n_wav_per_azi_all['valid']
    n_wav_test = n_azi * n_room * n_wav_per_azi_all['test']
    # prepare sound sources
    # train and validate
    if not os.path.exists('fpath_TIMIT_train_all.npy'):
        TIMIT_train_dir = os.path.join(TIMIT_dir, 'TIMIT/TRAIN')
        src_fpath_all = get_fpath(TIMIT_train_dir, '.wav',
                                  is_absolute=True)
        np.save('fpath_TIMIT_train_all.npy', src_fpath_all)
    src_fpath_all = np.load('fpath_TIMIT_train_all.npy')
    print('train', len(src_fpath_all))
    print('train+valid', n_wav_train+n_wav_valid)
    np.random.shuffle(src_fpath_all)
    src_fpath_train_all = src_fpath_all[:n_wav_train]
    src_fpath_valid_all = src_fpath_all[n_wav_train:]

    # test
    if not os.path.exists('fpath_TIMIT_test_all.npy'):
        TIMIT_test_dir = os.path.join(TIMIT_dir, 'TIMIT/TEST')
        src_fpath_test_all = get_fpath(TIMIT_test_dir, '.wav',
                                       is_absolute=True)
        np.save('fpath_TIMIT_test_all.npy', src_fpath_test_all)
    src_fpath_test_all = np.load('fpath_TIMIT_test_all.npy')
    print('test', len(src_fpath_test_all))
    print('test', n_wav_test)
    np.random.shuffle(src_fpath_test_all)

    # np.save(os.path.join(dir_path, 'src_fpath_all.npy'),
    #         [src_fpath_train_all, src_fpath_valid_all, src_fpath_test_all])

    # map each set type to its own source list so that the function also
    # works when called with only ['test']
    src_fpath_all = {'train': src_fpath_train_all,
                     'valid': src_fpath_valid_all,
                     'test': src_fpath_test_all}

    n_wav_all = [len(src_fpath_all[set_type]) for set_type in set_type_all]
    pb = ProcessBarMulti(n_wav_all, desc_all=set_type_all)
    proc_all = []
    for i, set_type in enumerate(set_type_all):
        print(set_type)
        set_dir = os.path.join(dir_path, set_type)
        proc = Process(target=syn_record,
                       args=(src_fpath_all[set_type], set_dir,
                             n_wav_per_azi_all[set_type],
                             str(i), pb))
        proc.start()
        proc_all.append(proc)
    [proc.join() for proc in proc_all]


if __name__ == '__main__':

    # train dataset and validation dataset
    dataset_dir = os.path.join(data_dir, 'v1')
    os.makedirs(dataset_dir, exist_ok=True)
    gen_dataset(dir_path=dataset_dir,
                set_type_all=['train', 'valid'])

    # test datasets
    for i in range(1, 5):
        dataset_dir = os.path.join(data_dir, f'v{i}')
        os.makedirs(dataset_dir, exist_ok=True)
        gen_dataset(dir_path=dataset_dir, set_type_all=['test'])
--------------------------------------------------------------------------------
/WaveLoc.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import os
import configparser
import time
import gammatone.filters as gt_filters
import tensorflow as tf

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


class WaveLoc(object):
    """end-to-end binaural sound localization model (WaveLoc-GTF)
    """
    def __init__(self, file_reader, config_fpath=None, gpu_index=0):
        """
        """

        # constant settings
        self.epsilon = 1e-20
        self._file_reader = file_reader

        self._graph = tf.Graph()
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        # config.gpu_options.visible_device_list = '{}'.format(gpu_index)
        self._sess = tf.compat.v1.Session(graph=self._graph, config=config)

        self._load_cfg(config_fpath)
        self._build_model()

    def _add_log(self, log_info):
        self._log_file.write(log_info)
        self._log_file.write('\n')
        self._log_file.flush()
        if self.is_print_log:
            print(log_info)

    def _load_cfg(self, config_fpath):
        if config_fpath is not None and os.path.exists(config_fpath):
            config = configparser.ConfigParser()
            config.read(config_fpath)

            # settings for model
            self.fs = np.int16(config['model']['fs'])
            self.n_band = np.int16(config['model']['n_band'])
            self.cf_low = np.int16(config['model']['cf_low'])
            self.cf_high = np.int16(config['model']['cf_high'])
            self.frame_len = np.int16(config['model']['frame_len'])
            self.shift_len = np.int16(config['model']['shift_len'])
            self.filter_len = np.int16(config['model']['filter_len'])
            self.is_padd = config['model']['is_padd'] == 'True'
            self.n_azi = np.int16(config['model']['azi_num'])

            # settings for training
            self.batch_size = np.int16(config['train']['batch_size'])
            self.max_epoch = np.int16(config['train']['max_epoch'])
            self.is_print_log = config['train']['is_print_log'] == 'True'
            self.train_set_dir = config['train']['train_set_dir'].split(';')
            self.valid_set_dir = config['train']['valid_set_dir'].split(';')
            if self.valid_set_dir[0] == '':
                self.valid_set_dir = None

            print('Train set:')
            [print('\t{}'.format(item)) for item in self.train_set_dir]

            print('Valid set:')
            [print('\t{}'.format(item)) for item in self.valid_set_dir]

        else:
            print(config_fpath)
            raise OSError('config file not found')

    def get_gtf_kernel(self):
        """generate the impulse responses of a gammatone filterbank
        """
        cfs = gt_filters.erb_space(self.cf_low, self.cf_high, self.n_band)
        self.cfs = cfs

        sample_times = np.arange(0, self.filter_len, 1)/self.fs
        irs = np.zeros((self.filter_len, self.n_band), dtype=np.float32)

        EarQ = 9.26449
        minBW = 24.7
        order = 1
        N = 4
        for band_i in range(self.n_band):
            ERB = ((cfs[band_i]/EarQ)**order+minBW**order)**(1/order)
            b = 1.019*ERB
            numerator = np.multiply(sample_times**(N-1),
                                    np.cos(2*np.pi*cfs[band_i]*sample_times))
            denominator = np.exp(2*np.pi*b*sample_times)
            irs[:, band_i] = np.divide(numerator, denominator)

        gain = np.max(np.abs(np.fft.fft(irs, axis=0)), axis=0)
        irs_gain_norm = np.divide(np.flipud(irs), gain)
        if self.is_padd:
            kernel = np.concatenate((irs_gain_norm,
                                     np.zeros((self.filter_len,
                                               self.n_band))),
                                    axis=0)
        else:
            kernel = irs_gain_norm
        return kernel

    def _fcn_layers(self, input, *layers_setting):
        for setting in layers_setting:
            fcn_size = setting['fcn_size']
            activation = setting['activation']
            rate = setting['rate']

            layer_fcn = tf.keras.layers.Dense(units=fcn_size,
                                              activation=activation)
            if rate > 0:
                layer_drop = tf.keras.layers.Dropout(rate=rate)
                output = layer_fcn(layer_drop(input))
            elif rate == 0:
                output = layer_fcn(input)
            else:
                raise Exception('illegal dropout rate')
            input = output
        return output

    def _build_model_subband(self, input):
        """two conv+pool blocks applied to a single frequency band
        """
        layer1_conv = tf.keras.layers.Conv2D(filters=6,
                                             kernel_size=[18, 2],
                                             strides=[1, 1],
                                             activation=tf.nn.relu)
        layer1_pool = tf.keras.layers.MaxPool2D([4, 1], [4, 1])
        layer1_out = layer1_pool(layer1_conv(input))

        layer2_conv = tf.keras.layers.Conv2D(filters=12,
                                             kernel_size=[6, 1],
                                             strides=[1, 1],
                                             activation=tf.nn.relu)
        layer2_pool = tf.keras.layers.MaxPool2D([4, 1], [4, 1])
        layer2_out = layer2_pool(layer2_conv(layer1_out))

        flatten_len = np.prod(layer2_out.get_shape().as_list()[1:])
        out = tf.reshape(layer2_out, [-1, flatten_len])  # flatten
        return out

    def _build_model(self):
        """Build graph
        """
        # gammatone layer kernel initializer
        with self._graph.as_default():
            kernel_initializer = tf.constant_initializer(
                self.get_gtf_kernel())

            if self.is_padd:
                gtf_kernel_len = 2*self.filter_len
            else:
                gtf_kernel_len = self.filter_len

            x = tf.compat.v1.placeholder(shape=[None, self.frame_len, 2, 1],
                                         dtype=tf.float32,
                                         name='x')

            gt_layer = tf.keras.layers.Conv2D(
                filters=self.n_band,
                kernel_size=[gtf_kernel_len, 1],
                strides=[1, 1],
                padding='same',
                kernel_initializer=kernel_initializer,
                trainable=False, use_bias=False)

            # add to model for test
            # self.layer1_conv = gt_layer

            # amplitude normalization across frequency channels
            # problem: silence ?
            x_band_all = gt_layer(x)
            amp_max = tf.reduce_max(
                tf.reduce_max(
                    tf.reduce_max(
                        tf.abs(x_band_all),
                        axis=1, keepdims=True),
                    axis=2, keepdims=True),
                axis=3, keepdims=True)
            x_band_norm_all = tf.divide(x_band_all, amp_max)

            # layer1_pool
            gt_layer_pool = tf.keras.layers.MaxPool2D([2, 1], [2, 1])
            gt_layer_output = gt_layer_pool(x_band_norm_all)

            band_out_list = []
            for band_i in range(self.n_band):
                band_output = self._build_model_subband(
                    tf.expand_dims(
                        gt_layer_output[:, :, :, band_i],
                        axis=-1))
                band_out_list.append(band_output)
            band_out = tf.concat(band_out_list, axis=1)

            layer4 = {'fcn_size': 1024,
                      'activation': tf.nn.relu,
                      'rate': 0.5}
            layer5 = {'fcn_size': 1024,
                      'activation': tf.nn.relu,
                      'rate': 0.5}
            output_layer = {'fcn_size': self.n_azi,
                            'activation': tf.nn.softmax,
                            'rate': 0}

            y_est = self._fcn_layers(band_out, layer4, layer5, output_layer)

            # groundtruth
            y = tf.compat.v1.placeholder(shape=[None, self.n_azi],
                                         dtype=tf.float32)
            # cost function
            cost = self._cal_cross_entropy(y_est, y)
            # additional measurement of localization
            azi_rmse = self._cal_azi_rmse(y_est, y)

            #
            lr = tf.compat.v1.placeholder(tf.float32, shape=[])
            opt_step = tf.compat.v1.train.AdamOptimizer(
                learning_rate=lr).minimize(cost)

            # initialization of model
            init = tf.compat.v1.global_variables_initializer()
            self._sess.run(init)

            # input and output
            self._x = x
            self._y_est = y_est
            # groundtruth
            self._y = y
            # cost function and optimizer
            self._cost = cost
            self._azi_rmse = azi_rmse
            self._lr = lr
            self._opt_step = opt_step

    def _cal_cross_entropy(self, y_est, y):
        cross_entropy = -tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    y, tf.math.log(y_est+self.epsilon)),
                axis=1))
        return cross_entropy

    def _cal_mse(self, y_est, y):
        mse = tf.reduce_mean(tf.reduce_sum((y-y_est)**2, axis=1))
        return mse

    def _cal_azi_rmse(self, y_est, y):
        azi_est = tf.argmax(y_est, axis=1)
        azi = tf.argmax(y, axis=1)
        diff = tf.cast(azi_est - azi, tf.float32)
        return tf.sqrt(tf.reduce_mean(tf.pow(diff, 2)))

    def _cal_cp(self, y_est, y):
        equality = tf.equal(tf.argmax(y_est, axis=1), tf.argmax(y, axis=1))
        cp = tf.reduce_mean(tf.cast(equality, tf.float32))
        return cp

    def load_model(self, model_dir):
        """load model"""
        if not os.path.exists(model_dir):
            raise Exception('no model exists in {}'.format(model_dir))

        with self._graph.as_default():
            # restore model
            saver = tf.compat.v1.train.Saver()
            ckpt = tf.train.get_checkpoint_state(model_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(self._sess, ckpt.model_checkpoint_path)

        print(f'load model from {model_dir}')

    def _train_record_init(self, model_dir, is_load_model):
        """
        """
        if is_load_model:
            record_info = np.load(os.path.join(model_dir,
                                               'train_record.npz'))
            cost_record_valid = record_info['cost_record_valid']
            azi_rmse_record_valid = record_info['azi_rmse_record_valid']
            lr_value = record_info['lr']
            best_epoch = record_info['best_epoch']
            min_valid_cost = record_info['min_valid_cost']
            last_epoch = np.nonzero(cost_record_valid)[0][-1]
        else:
            cost_record_valid = np.zeros(self.max_epoch)
            azi_rmse_record_valid = np.zeros(self.max_epoch)
            lr_value = 1e-3
            min_valid_cost = np.infty
            best_epoch = 0
            last_epoch = -1
        return [cost_record_valid, azi_rmse_record_valid,
                lr_value, min_valid_cost, best_epoch, last_epoch]

    def train_model(self, model_dir, is_load_model=False):
        """Train model either from the initial state (self._build_model())
        or from an already saved model
        """
        if is_load_model:
            self.load_model(model_dir)

        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        # open text file for logging
        self._log_file = open(os.path.join(model_dir, 'log.txt'), 'a')

        with self._graph.as_default():

            [cost_record_valid, azi_rmse_record_valid,
             lr_value, min_valid_cost,
             best_epoch, last_epoch] = self._train_record_init(model_dir,
                                                               is_load_model)

            saver = tf.compat.v1.train.Saver()
            print('start training')
            for epoch in range(last_epoch+1, self.max_epoch):
                t_start = time.time()
                print(f'epoch {epoch}')
                batch_generator = self._file_reader(self.train_set_dir)
                for x, y in batch_generator:
                    self._sess.run(self._opt_step,
                                   feed_dict={self._x: x,
                                              self._y: y,
                                              self._lr: lr_value})
                # model validation
                [cost_record_valid[epoch],
                 azi_rmse_record_valid[epoch]] = self.evaluate(
                     self.valid_set_dir)

                # write to log
                iter_time = time.time()-t_start
                self._add_log(' '.join((f'epoch:{epoch}',
                                        f'lr:{lr_value}',
                                        f'time:{iter_time:.2f}\n')))

                log_template = '\t cost:{:.2f} azi_rmse:{:.2f}\n'
                self._add_log('\t valid ')
                self._add_log(log_template.format(
                    cost_record_valid[epoch],
                    azi_rmse_record_valid[epoch]))

                #
                if min_valid_cost > cost_record_valid[epoch]:
                    self._add_log('find new optimal\n')
                    best_epoch = epoch
                    min_valid_cost = cost_record_valid[epoch]
                    saver.save(self._sess, os.path.join(model_dir,
                                                        'model'),
                               global_step=epoch)

                # save record info
                np.savez(os.path.join(model_dir, 'train_record'),
                         cost_record_valid=cost_record_valid,
                         azi_rmse_record_valid=azi_rmse_record_valid,
                         lr=lr_value,
                         best_epoch=best_epoch,
                         min_valid_cost=min_valid_cost)

                # early stop
                if epoch-best_epoch > 5:
                    print('early stop\n', min_valid_cost)
                    self._add_log('early stop{}\n'.format(min_valid_cost))
                    break

                # learning rate decay
                if epoch > 2:  # no better performance in the last 2 epochs
                    min_valid_cost_local = np.min(
                        cost_record_valid[epoch-1:epoch+1])
                    if cost_record_valid[epoch-2] < min_valid_cost_local:
                        lr_value = lr_value*.2

        self._log_file.close()

        if True:
            fig, ax = plt.subplots(2, 1, sharex=True, tight_layout=True)
            ax[0].plot(cost_record_valid)
            ax[0].set_ylabel('cross entropy')

            ax[1].plot(azi_rmse_record_valid)
            ax[1].set_ylabel('rmse(deg)')
            #
            fig_path = os.path.join(model_dir, 'train_curve.png')
            plt.savefig(fig_path)

    def predict(self, x):
        """Model output for x
        """
        y_est = self._sess.run(self._y_est, feed_dict={self._x: x})
        return y_est

    def evaluate(self, set_dir):
        cost_all = 0.
        rmse_all = 0.
        n_sample_all = 0
        batch_generator = self._file_reader(set_dir)
        for x, y in batch_generator:
            n_sample_tmp = x.shape[0]
            [cost_tmp, rmse_tmp] = self._sess.run([self._cost, self._azi_rmse],
                                                  feed_dict={self._x: x,
                                                             self._y: y})
            #
            n_sample_all = n_sample_all+n_sample_tmp
            cost_all = cost_all+n_sample_tmp*cost_tmp
            rmse_all = rmse_all+n_sample_tmp*(rmse_tmp**2)

        # average across the whole set
        cost_all = cost_all/n_sample_all
        rmse_all = np.sqrt(rmse_all/n_sample_all)
        return [cost_all, rmse_all]

    def evaluate_chunk_rmse(self, record_set_dir, chunk_size=25):
        """Evaluate model on the given dataset, localization only
        Args:
            record_set_dir:
            chunk_size: number of consecutive frames per chunk
        Returns:
            rmse_chunk: chunk-level azimuth RMSE in degrees
        """
        rmse_chunk = 0.
        n_chunk = 0

        for x, y in self._file_reader(record_set_dir, is_shuffle=False):
            sample_num = x.shape[0]
            azi_true = np.argmax(y[0])

            y_est = self.predict(x)
            for sample_i in range(0, sample_num-chunk_size+1):
                azi_est_chunk = np.argmax(
                    np.mean(
                        y_est[sample_i:sample_i+chunk_size],
                        axis=0))
                rmse_chunk = rmse_chunk+(azi_est_chunk-azi_true)**2
                n_chunk = n_chunk+1

        rmse_chunk = np.sqrt(rmse_chunk/n_chunk)*5
        return rmse_chunk
--------------------------------------------------------------------------------