├── gen_dataset
│   ├── run.sh
│   ├── align_brir.py
│   ├── wav2npy.py
│   └── gen_dataset_wav.py
├── .gitignore
├── images
│   ├── learning_curve_all.png
│   └── end2end-model-framework.png
├── utils
│   ├── file_reader_v2.py
│   ├── plot_learning_curve.py
│   └── file_reader.py
├── evaluate_mct.py
├── train_mct.py
├── README.md
└── WaveLoc.py
/gen_dataset/run.sh:
--------------------------------------------------------------------------------
1 | python align_brir.py       # align reverberant BRIRs to the anechoic BRIRs
2 | python gen_dataset_wav.py  # synthesize direct/reverb binaural recordings
3 | python wav2npy.py          # frame the recordings and pack them into npy batches
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | models/
2 | models_valid/
3 | __pycache__/
4 | .trash/
5 | pre_result
7 | models_pre/
8 |
--------------------------------------------------------------------------------
/images/learning_curve_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bingo-todd/WaveLoc/HEAD/images/learning_curve_all.png
--------------------------------------------------------------------------------
/images/end2end-model-framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bingo-todd/WaveLoc/HEAD/images/end2end-model-framework.png
--------------------------------------------------------------------------------
/utils/file_reader_v2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from BasicTools import get_fpath
3 |
4 |
5 | def file_reader(reverb_set_dir, batch_size=128, is_shuffle=True,
6 |                 frame_len=320, shift_len=160, n_azi=37):
7 |     """ read pre-batched npy files in the given directories, one file per time
8 |     Args:
9 |         reverb_set_dir: directory or list of directories where npy batches exist
10 |         batch_size: unused; the batch size is fixed when the npy files are generated
11 |         is_shuffle: whether to shuffle the file order
12 |     Returns:
13 |         samples generator, [x_r_batch, y_loc_batch]
14 |     """
15 | if isinstance(reverb_set_dir, list):
16 | dir_all = reverb_set_dir
17 | else:
18 | dir_all = [reverb_set_dir]
19 | #
20 | fpath_reverb_all = []
21 | for dir_fpath in dir_all:
22 | fpath_all_tmp = get_fpath(dir_fpath, '.npy', is_absolute=True)
23 | fpath_reverb_all.extend(fpath_all_tmp)
24 |
25 | if is_shuffle:
26 | np.random.shuffle(fpath_reverb_all)
27 |
28 | for fpath_reverb in fpath_reverb_all:
29 |         x_d_batch, x_r_batch, y_loc_batch, is_anechoic = np.load(
30 |             fpath_reverb, allow_pickle=True)
31 |         yield x_r_batch, y_loc_batch
32 |
--------------------------------------------------------------------------------
/utils/plot_learning_curve.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 | from matplotlib.ticker import MaxNLocator
4 | import numpy as np
5 | from BasicTools import plot_tools
6 | plt.rcParams.update({"font.size": "12"})
7 | room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
8 | reverb_room_all = ['Room_A', 'Room_B', 'Room_C', 'Room_D']
9 |
10 |
11 | def plot_train_process(model_dir):
12 |
13 | fig, ax = plt.subplots(1, 1)
14 | plot_settings = {'linewidth': 4}
15 | n_epoch_max = 0
16 | for room_i, room in enumerate(reverb_room_all):
17 | record_fpath = os.path.join(model_dir, room, 'train_record.npz')
18 | record_info = np.load(record_fpath)
19 | cost_record_valid = record_info['cost_record_valid']
20 | # rmse_record_valid = record_info['azi_rmse_record_valid']
21 | n_epoch = np.nonzero(cost_record_valid)[0][-1] + 1
22 | if n_epoch > n_epoch_max:
23 | n_epoch_max = n_epoch
24 | ax.plot(cost_record_valid[:n_epoch], **plot_settings, label=room[-1])
25 |
26 | ax.xaxis.set_major_locator(MaxNLocator(integer=True))
27 | ax.legend()
28 |     ax.set_ylabel('Cross entropy')
29 |     ax.set_xlabel('Epoch')
30 |
31 | plot_tools.savefig(fig, 'learning_curve_all.png', '../images')
32 |
33 |
34 | if __name__ == '__main__':
35 |
36 | model_dir = '../models/mct'
37 | plot_train_process(model_dir)
38 |
--------------------------------------------------------------------------------
/evaluate_mct.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
4 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
5 |
6 | import sys
7 | from WaveLoc import WaveLoc
8 | from BasicTools import plot_tools
9 | from utils import file_reader
10 |
11 |
12 | room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
13 | reverb_room_all = ['Room_A', 'Room_B', 'Room_C', 'Room_D']
14 | n_reverb_room = 4
15 | chunk_size = 25
16 | n_test = 4
17 |
18 |
19 | def evaluate_mct(model_dir_base):
20 | rmse_all = np.zeros((n_test, n_reverb_room))
21 | for room_i, room in enumerate(reverb_room_all):
22 | model_dir = os.path.join(model_dir_base, room)
23 | model_config_fpath = os.path.join(model_dir, 'config.cfg')
24 | model = WaveLoc(file_reader.file_reader,
25 | model_config_fpath, gpu_index=0)
26 | model.load_model(model_dir)
27 |
28 | for test_i in range(n_test):
29 | dataset_dir_test = os.path.join(
30 | '/home/st/Work_Space/Localize/WaveLoc/Data',
31 | f'v{test_i+1}/test/reverb/{room[-1]}')
32 | rmse_all[test_i, room_i] = model.evaluate_chunk_rmse(
33 | dataset_dir_test,
34 | chunk_size=chunk_size)
35 | return rmse_all
36 |
37 |
38 | if __name__ == '__main__':
39 |     model_dir = sys.argv[1]  # e.g. 'models/mct'
40 | rmse_all = evaluate_mct(model_dir)
41 |
42 | with open(os.path.join(model_dir, 'result.txt'), 'w') as result_file:
43 |         result_file.write(f'{rmse_all}\n')
44 | result_file.write('mean: {}\n'.format(np.mean(rmse_all, axis=0)))
45 | result_file.write('std: {}\n'.format(np.std(rmse_all, axis=0)))
46 |
47 | print(rmse_all)
48 | print('mean:', np.mean(rmse_all, axis=0))
49 | print('std:', np.std(rmse_all, axis=0))
50 |
--------------------------------------------------------------------------------
/utils/file_reader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import copy
4 | import os
5 | from BasicTools import get_fpath
6 | from BasicTools import wav_tools
7 |
8 | plt.rcParams.update({"font.size": "12"})
9 | reverb_room_all = ['A', 'B', 'C', 'D']
10 |
11 |
12 | def file_reader(record_set_dir, batch_size=-1, is_shuffle=True):
13 |     """ read wav files in the given directories, one file per time
14 |     Args:
15 |         record_set_dir: directory or list of directories where recordings exist
16 |         batch_size: frames per batch; if <= 0, one whole file is yielded per batch
17 |     Returns:
18 |         samples generator, [x_batch, y_batch]
19 |     """
19 | if isinstance(record_set_dir, list):
20 | dirs = record_set_dir
21 | else:
22 | dirs = [record_set_dir]
23 | #
24 | fpath_all = []
25 | for sub_set_dir in dirs:
26 | fpath_all_sub = get_fpath(sub_set_dir, '.wav', is_absolute=True)
27 | fpath_all.extend(fpath_all_sub)
28 |
29 | if is_shuffle:
30 | np.random.shuffle(fpath_all)
31 |
35 | if len(fpath_all) < 1:
36 | raise Exception('empty folder:{}'.format(record_set_dir))
37 |
38 | frame_len = 320
39 | shift_len = 160
40 | n_azi = 37
41 |
42 |     if batch_size > 0:
43 | x_all = np.zeros((0, frame_len, 2, 1))
44 | y_all = np.zeros((0, n_azi))
45 |
46 | for fpath in fpath_all:
47 | record, fs = wav_tools.read_wav(fpath)
48 | x_file_all = wav_tools.frame_data(record, frame_len, shift_len)
49 | x_file_all = np.expand_dims(x_file_all, axis=-1)
50 |
51 | # onehot azi label
52 | n_sample_file = x_file_all.shape[0]
53 | fname = os.path.basename(fpath)
54 | azi = np.int16(fname.split('_')[0])
55 | y_file_all = np.zeros((n_sample_file, n_azi))
56 | y_file_all[:, azi] = 1
57 |
58 | if batch_size > 0:
59 | x_all = np.concatenate((x_all, x_file_all), axis=0)
60 | y_all = np.concatenate((y_all, y_file_all), axis=0)
61 |
62 |             while x_all.shape[0] >= batch_size:
63 | x_batch = copy.deepcopy(x_all[:batch_size])
64 | y_batch = copy.deepcopy(y_all[:batch_size])
65 |
66 | x_all = x_all[batch_size:]
67 | y_all = y_all[batch_size:]
68 |
69 | yield [x_batch, y_batch]
70 | else:
71 | yield [x_file_all, y_file_all]
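72 |     # note: when batch_size > 0, frames left over after the last full batch
73 |     # are never yielded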
72 |
--------------------------------------------------------------------------------
/train_mct.py:
--------------------------------------------------------------------------------
1 | import os
2 | import configparser
3 | from multiprocessing import Process
4 |
5 | # select the GPU before tensorflow is first imported (via WaveLoc)
6 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
7 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
8 |
9 | from WaveLoc import WaveLoc
10 | from utils import file_reader_v2
10 |
11 |
12 | data_dir = 'Data'
13 | train_set_dir_base = os.path.join(data_dir, 'v1/npy/train')
14 | valid_set_dir_base = os.path.join(data_dir, 'v1/npy/valid')
15 |
16 | room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
17 | reverb_room_all = ['Room_A', 'Room_B', 'Room_C', 'Room_D']
18 |
19 | model_basic_settings = {'fs': 16000,
20 | 'n_band': 32,
21 | 'cf_low': 70,
22 | 'cf_high': 7000,
23 | 'frame_len': 320,
24 | 'shift_len': 160,
25 | 'filter_len': 320,
26 | 'azi_num': 37,
27 | 'is_use_gtf': False,
28 | 'is_padd': False}
29 | gpu_index = 1
30 |
31 |
32 | def train_mct(room_tar, model_dir):
33 | if not os.path.exists(model_dir):
34 | os.makedirs(model_dir)
35 |
36 | # filter out room_tar from room_all
37 | print('tar_room', room_tar)
38 | mct_room_all = [room for room in room_all if room != room_tar]
39 | config = configparser.ConfigParser()
40 | config['model'] = {**model_basic_settings}
41 | config['train'] = {'batch_size': 128,
42 | 'max_epoch': 50,
43 | 'is_print_log': False,
44 | 'train_set_dir': ';'.join(
45 | [os.path.join(train_set_dir_base, room)
46 | for room in mct_room_all]),
47 | 'valid_set_dir': ';'.join(
48 | [os.path.join(valid_set_dir_base, room)
49 | for room in mct_room_all])}
50 |
51 | config_fpath = os.path.join(model_dir, 'config.cfg')
52 | with open(config_fpath, 'w') as config_file:
53 | if config_file is None:
54 | raise Exception('fail to create file')
55 | config.write(config_file)
56 |
57 | model = WaveLoc(file_reader_v2.file_reader, config_fpath=config_fpath,
58 | gpu_index=gpu_index)
59 | model.train_model(model_dir)
60 |
61 |
62 | if __name__ == '__main__':
63 | thread_all = []
64 | for room_tar in reverb_room_all:
65 | model_dir = 'models/mct/{}'.format(room_tar)
66 |         thread = Process(target=train_mct, args=(room_tar, model_dir))
67 | thread.start()
68 | thread_all.append(thread)
69 |
70 | [thread.join() for thread in thread_all]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # End2End sound localization model
2 |
3 | Reference:
4 |
5 | P. Vecchiotti, N. Ma, S. Squartini, and G. J. Brown, “End-to-end binaural sound localisation from the raw waveform,” in Proc. IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2019, pp. 451–455.
6 |
7 | **Only WaveLoc-GTF is implemented**
8 |
9 | ## Model
10 | ![model framework](images/end2end-model-framework.png)
11 |
12 | ## Training
13 | ### Dataset
14 | - BRIR
15 |
16 | Surrey binaural room impulse response (BRIR) database, including an anechoic room and 4 reverberant rooms.
17 |
18 | | Room     | A    | B    | C    | D    |
19 | |----------|------|------|------|------|
20 | | RT_60(s) | 0.32 | 0.47 | 0.68 | 0.89 |
21 | | DDR(dB)  | 6.09 | 5.31 | 8.82 | 6.12 |
27 |
28 |
29 |
30 | - Sound source
31 |
32 | TIMIT database
33 |
34 | Sentences per azimuth
35 |
36 | | Train | Validate | Evaluate |
37 | |-------|----------|----------|
38 | | 24    | 6        | 15       |
39 |
40 | With 37 azimuths and 5 rooms, this amounts to 37 × 5 × 24 = 4440 source sentences for training, 1110 for validation, and 2775 per test set.
44 |
45 |
46 |
47 |
48 | ## Multi-conditional training (MCT)
49 |
50 | For each reverberant room, the remaining 3 reverberant rooms plus the anechoic room are used for training, as sketched below.
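51 |
52 | A minimal sketch of the room selection, mirroring the list comprehension in `train_mct.py`:
53 |
54 | ```python
55 | room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
56 |
57 | def mct_rooms(room_tar):
58 |     """Rooms used for training when room_tar is held out for testing."""
59 |     return [room for room in room_all if room != room_tar]
60 |
61 | print(mct_rooms('Room_A'))  # ['Anechoic', 'Room_B', 'Room_C', 'Room_D']
62 | ```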
52 |
53 | Training curves
54 |
55 | ![learning curves](images/learning_curve_all.png)
56 |
57 |
58 |
59 | ## Evaluation
60 | Root mean square error (RMSE) is used as the performance metric. For each reverberant room, the evaluation was performed 4 times (test sets v1–v4), regenerating the test dataset each time, to obtain more stable results.
61 |
62 | Since the binaural sound is fed to the model directly, without extra preprocessing, and speech contains short pauses, the localization result is reported per chunk rather than per frame. Each chunk consists of 25 consecutive frames; the chunk decision is sketched below.
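63 |
64 | A minimal NumPy sketch of the chunk decision implemented in `WaveLoc.evaluate_chunk_rmse` (`y_est` is the frame-wise posterior matrix the model outputs for one recording; the azimuth grid step is 5 degrees):
65 |
66 | ```python
67 | import numpy as np
68 |
69 | def chunk_rmse(y_est, azi_true, chunk_size=25):
70 |     """Chunk-level azimuth RMSE in degrees for one recording.
71 |     y_est: (n_frame, n_azi) frame-wise posteriors; azi_true: true azimuth index
72 |     """
73 |     n_chunk = y_est.shape[0] - chunk_size + 1
74 |     sq_err = 0.0
75 |     for i in range(n_chunk):
76 |         # average the posteriors over consecutive frames, then decide
77 |         azi_est = np.argmax(np.mean(y_est[i:i + chunk_size], axis=0))
78 |         sq_err += (azi_est - azi_true) ** 2
79 |     return np.sqrt(sq_err / n_chunk) * 5  # 5 degrees per azimuth step
80 | ```
81 |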
63 |
64 | ### My result vs. paper
65 |
66 | | Reverberant room | A   | B   | C   | D   |
67 | |------------------|-----|-----|-----|-----|
68 | | My result        | 1.5 | 2.0 | 1.4 | 2.7 |
69 | | Result in paper  | 1.5 | 3.0 | 1.7 | 3.5 |
79 |
80 |
81 |
82 | ## Main dependencies
83 | - tensorflow-1.14
84 | - pysofa (can be installed by pip)
85 | - BasicTools (in my other [repository](https://github.com/bingo-todd/BasicTools))
86 |
--------------------------------------------------------------------------------
/gen_dataset/align_brir.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pysofa
4 | from BasicTools import DspTools
5 | import matplotlib.pyplot as plt
6 |
7 | """Align brirs of reverberant rooms to that from the anechoic room
8 | """
9 |
10 | brirs_dir = os.path.expanduser('~/Work_Space/Data/RealRoomBRIRs')
11 | brirs_aligned_dir = 'brirs_aligned'
12 | os.makedirs(brirs_aligned_dir, exist_ok=True)
13 |
14 | n_azi = 37
15 | n_channel = 2
16 | rever_room_all = ('Room_A', 'Room_B', 'Room_C', 'Room_D')
17 | room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
18 |
19 |
20 |
21 | def plot_brirs():
22 | for room in room_all:
23 | brirs_fpath = f'brirs_aligned/{room}.npy'
24 | brirs = np.load(brirs_fpath)
25 | print(brirs.shape)
26 | fig, ax = plt.subplots(1, 1)
27 | for azi_i in range(0, 37, 8):
28 | ax.plot(brirs[azi_i, :, 0])
29 | ax.set_xlim([0, 2000])
30 | fig.savefig(f'brirs_aligned/{room}.png')
31 | plt.close(fig)
32 |
33 |
34 | def load_brirs(room):
35 | brirs_fpath = os.path.join(brirs_dir, f'UniS_{room}_BRIR_16k.sofa')
36 | brirs = pysofa.SOFA(brirs_fpath).FIR.IR.transpose(0, 2, 1)
37 | return brirs
38 |
39 |
40 | def align_brirs():
41 | """
42 | For each reverberant room, calculate BRIRs delays of each sound position
43 | and align BRIRs according to the averaged delay
44 | """
45 |
46 | brirs_anechoic = load_brirs('Anechoic')
47 | np.save(os.path.join(brirs_aligned_dir, 'Anechoic.npy'), brirs_anechoic)
48 |
49 | delay_all = np.zeros((n_azi, n_channel))
50 | for reverb_room in rever_room_all:
51 | print(reverb_room)
52 | brirs = load_brirs(reverb_room)
53 | for azi_i in range(n_azi):
54 | for channel_i in range(n_channel):
55 | delay_all[azi_i, channel_i] = DspTools.cal_delay(
56 | brirs_anechoic[azi_i, :, channel_i],
57 | brirs[azi_i, :, channel_i])
58 | delay_mean = np.int16(np.round(np.mean(delay_all)))
59 | if delay_mean > 0:
60 | brirs_aligned = np.concatenate((np.zeros((n_azi, delay_mean,
61 | n_channel)),
62 | brirs),
63 | axis=1)
64 | else:
65 | brirs_aligned = brirs[:, -delay_mean:, :]
66 |
67 | np.save(os.path.join(brirs_aligned_dir, f'{reverb_room}.npy'),
68 | brirs_aligned)
69 |
70 | if False:
71 | fig, ax = plt.subplots(1, 1)
72 | ax.plot(brirs_anechoic[0, :, 0], label='anechoic')
73 | ax.plot(brirs_aligned[0, :, 0], label='reverb_aligned')
74 | ax.plot(brirs[0, :, 0], label='reverb')
75 | ax.set_xlim((0, 500))
76 | ax.legend()
77 | ax.set_title(reverb_room)
78 | fig.savefig(os.path.join(brirs_aligned_dir,
79 | f'{reverb_room}_aligned.png'))
80 | plt.close(fig)
81 |
82 |
83 | if __name__ == '__main__':
84 | align_brirs()
85 | # plot_brirs()
86 |
--------------------------------------------------------------------------------
/gen_dataset/wav2npy.py:
--------------------------------------------------------------------------------
1 | """
2 | frame raw waveform recordings and save them in batches
3 | """
4 |
5 |
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import copy
9 | import os
10 | from BasicTools import get_fpath, wav_tools, ProcessBar
11 |
12 | plt.rcParams.update({"font.size": "12"})
13 | room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
14 |
15 |
16 | def wav2npy(reverb_set_dir, npy_dir, is_anechoic):
17 |     """ frame the wav files in the given directory and save them to npy files
18 |     in batches
19 |     Args:
20 |         reverb_set_dir: directory where the reverberant recordings exist
21 |         npy_dir: output directory for the npy batch files
22 |         is_anechoic: flag saved alongside each batch
23 |     """
25 |
26 | frame_len = 320
27 | shift_len = 160
28 | n_azi = 37
29 | batch_size = 128
30 |
31 | os.makedirs(npy_dir, exist_ok=True)
32 |
33 | #
34 | fpath_reverb_all = get_fpath(reverb_set_dir, '.wav', is_absolute=True)
35 | if len(fpath_reverb_all) < 1:
36 | raise Exception('empty folder:{}'.format(reverb_set_dir))
37 |
38 | pb = ProcessBar(len(fpath_reverb_all))
39 |
40 | batch_count = 0
41 | x_r = np.zeros((0, frame_len, 2, 1))
42 | x_d = np.zeros((0, frame_len, 2, 1))
43 | y_loc = np.zeros((0, n_azi))
44 |
45 | for fpath_reverb in fpath_reverb_all:
46 | pb.update()
47 | # reverb signal
48 | record, fs = wav_tools.read_wav(fpath_reverb)
49 | x_r_file = np.expand_dims(
50 | wav_tools.frame_data(record, frame_len, shift_len),
51 | axis=-1)
52 | # direct signal
53 | fpath_direct = fpath_reverb.replace('reverb', 'direct')
54 | direct, fs = wav_tools.read_wav(fpath_direct)
55 | x_d_file = np.expand_dims(
56 | wav_tools.frame_data(direct, frame_len, shift_len),
57 | axis=-1)
58 |
59 | # onehot azi label
60 | n_sample_file = x_d_file.shape[0]
61 | if x_r_file.shape[0] != n_sample_file:
62 |             raise Exception('frame counts of direct and reverb signals do not match')
63 |
64 | fname = os.path.basename(fpath_reverb)
65 | azi = np.int16(fname.split('_')[0])
66 | y_loc_file = np.zeros((n_sample_file, n_azi))
67 | y_loc_file[:, azi] = 1
68 |
69 | x_r = np.concatenate((x_r, x_r_file), axis=0)
70 | x_d = np.concatenate((x_d, x_d_file), axis=0)
71 | y_loc = np.concatenate((y_loc, y_loc_file), axis=0)
72 |
73 |         while x_d.shape[0] >= batch_size:
74 | x_r_batch = x_r[:batch_size]
75 | x_d_batch = x_d[:batch_size]
76 | y_loc_batch = y_loc[:batch_size]
77 |
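78 |             # each npy file stores one batch: [x_d, x_r, y_loc, is_anechoic];
79 |             # utils/file_reader_v2.py unpacks the saved array in this order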
78 | npy_fpath = os.path.join(npy_dir, '{}.npy'.format(batch_count))
79 |             np.save(npy_fpath, [x_d_batch, x_r_batch, y_loc_batch, is_anechoic])
80 | batch_count = batch_count + 1
81 |
82 | x_r = x_r[batch_size:]
83 | x_d = x_d[batch_size:]
84 | y_loc = y_loc[batch_size:]
85 |
86 |
87 |
88 |
89 | if __name__ == '__main__':
90 |
91 | for set_type in ['train', 'valid']:
92 | for room in room_all:
93 | print(room)
94 | wav_set_dir = '../Data/v1/{}/reverb/{}'.format(set_type, room)
95 | npy_set_dir = '../Data/v1/npy/{}/{}'.format(set_type, room)
96 |         wav2npy(wav_set_dir, npy_set_dir, room == 'Anechoic')
97 |
98 |
99 |     for test_i in range(1, 5):  # test sets v1-v4
100 | for room in room_all:
101 | print(room)
102 | wav_set_dir = f'../Data/v{test_i}/test/reverb/{room}'
103 | npy_set_dir = f'../Data/v{test_i}/npy/test/{room}'
104 |         wav2npy(wav_set_dir, npy_set_dir, room == 'Anechoic')
--------------------------------------------------------------------------------
/gen_dataset/gen_dataset_wav.py:
--------------------------------------------------------------------------------
1 | """
2 | synthesize spatial recordings
3 | """
4 |
5 | import numpy as np
6 | import os
7 | from multiprocessing import Process
8 | from BasicTools import ProcessBarMulti
9 | from BasicTools import wav_tools
10 | from BasicTools.Filter_GPU import Filter_GPU
11 | from BasicTools.get_fpath import get_fpath
12 |
13 |
14 | TIMIT_dir = os.path.expanduser('~/Work_Space/Data/TIMIT_wav')
15 | brirs_dir = 'brirs_aligned'
16 | data_dir = '../Data'
17 |
18 | room_all = ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
19 | n_room = 5
20 | n_azi = 37 # -90~90 in step of 5
21 | n_wav_per_azi_all = {'train': 24,
22 | 'valid': 6,
23 | 'test': 15}
24 |
25 | fs = 16e3
26 | frame_len = int(20e-3*fs)
27 | shift_len = int(10e-3*fs)
28 |
29 |
30 | def load_brirs(room):
31 | """load brirs of given room
32 | Args:
33 |         room: room name, one of ['Anechoic', 'Room_A', 'Room_B', 'Room_C', 'Room_D']
34 | """
35 | brirs_fpath = os.path.join(brirs_dir, f'{room}.npy')
36 | brirs = np.load(brirs_fpath)
37 | return brirs
38 |
39 |
40 | def truncate_silence(x):
41 |     """ strip silent frames at the beginning and the end (one frame each)
42 | """
43 | vad_flag = wav_tools.vad(x, frame_len, shift_len)
44 | if not vad_flag[0]: # silence in the first frame
45 | x = x[frame_len:]
46 | if not vad_flag[-1]: # silence in the last frame
47 | x = x[:-frame_len]
48 | return x
49 |
50 |
51 | def syn_record(src_fpath_all, set_dir, n_wav_per_azi, task_i, pb):
52 |     """synthesize spatial recordings as well as the corresponding direct
53 |     sound for each set
54 | """
55 | filter_gpu = Filter_GPU(gpu_index=1)
56 |
57 | brirs_direct = load_brirs('Anechoic')
58 | wav_count = 0
59 | for room in room_all:
60 | direct_dir = os.path.join(set_dir, 'direct', room)
61 | os.makedirs(direct_dir, exist_ok=True)
62 | rever_dir = os.path.join(set_dir, 'reverb', room)
63 | os.makedirs(rever_dir, exist_ok=True)
64 |
65 | brirs_room = load_brirs(room)
66 | for azi_i in range(n_azi):
67 | for i in range(n_wav_per_azi):
68 | pb.update(task_i)
69 | src_fpath = src_fpath_all[wav_count]
70 | wav_count = wav_count+1
71 |
72 | src, fs = wav_tools.read_wav(src_fpath)
73 | src = truncate_silence(src)
74 |
75 | direct = filter_gpu.brir_filter(src, brirs_direct[azi_i])
76 | # direct = wav_tools.brir_filter(src, brirs_direct[azi_i])
77 | direct_fpath = os.path.join(direct_dir, f'{azi_i}_{i}.wav')
78 | wav_tools.write_wav(direct, fs, direct_fpath)
79 |
80 | reverb = filter_gpu.brir_filter(src, brirs_room[azi_i])
81 | # reverb = wav_tools.brir_filter(src, brirs_room[azi_i])
82 | reverb_fpath = os.path.join(rever_dir, f'{azi_i}_{i}.wav')
83 | wav_tools.write_wav(reverb, fs, reverb_fpath)
84 |
85 |
86 | def gen_dataset(dir_path, set_type_all):
87 |
88 | n_wav_train = n_azi * n_room * n_wav_per_azi_all['train']
89 | n_wav_valid = n_azi * n_room * n_wav_per_azi_all['valid']
90 | n_wav_test = n_azi * n_room * n_wav_per_azi_all['test']
91 | # prepare sound source
92 | # train and validate
93 | if not os.path.exists('fpath_TIMIT_train_all.npy'):
94 | TIMIT_train_dir = os.path.join(TIMIT_dir, 'TIMIT/TRAIN')
95 | src_fpath_all = get_fpath(TIMIT_train_dir, '.wav',
96 | is_absolute=True)
97 | np.save('fpath_TIMIT_train_all.npy', src_fpath_all)
98 | src_fpath_all = np.load('fpath_TIMIT_train_all.npy')
99 | print('train', len(src_fpath_all))
100 | print('train+valid', n_wav_train+n_wav_valid)
101 | np.random.shuffle(src_fpath_all)
102 | src_fpath_train_all = src_fpath_all[:n_wav_train]
103 | src_fpath_valid_all = src_fpath_all[n_wav_train:]
104 |
105 | # test
106 | if not os.path.exists('fpath_TIMIT_test_all.npy'):
107 | TIMIT_test_dir = os.path.join(TIMIT_dir, 'TIMIT/TEST')
108 | src_fpath_test_all = get_fpath(TIMIT_test_dir, '.wav',
109 | is_absolute=True)
110 | np.save('fpath_TIMIT_test_all.npy', src_fpath_test_all)
111 | src_fpath_test_all = np.load('fpath_TIMIT_test_all.npy')
112 | print('test', len(src_fpath_test_all))
113 | print('test', n_wav_test)
114 | np.random.shuffle(src_fpath_test_all)
115 |
116 | # np.save(os.path.join(dir_path, 'src_fpath_all.npy'),
117 | # [src_fpath_train_all, src_fpath_valid_all, src_fpath_test_all])
118 |
119 |     src_fpath_all = {'train': src_fpath_train_all,
120 |                      'valid': src_fpath_valid_all,
121 |                      'test': src_fpath_test_all}
122 |
123 |     # index the source lists by set name, so that e.g.
124 |     # gen_dataset(dir_path, ['test']) picks the TIMIT TEST paths
125 |     n_wav_all = [len(src_fpath_all[set_type]) for set_type in set_type_all]
126 |     pb = ProcessBarMulti(n_wav_all, desc_all=set_type_all)
127 |     proc_all = []
128 |     for i, set_type in enumerate(set_type_all):
129 |         print(set_type)
130 |         set_dir = os.path.join(dir_path, set_type)
131 |         proc = Process(target=syn_record,
132 |                        args=(src_fpath_all[set_type], set_dir,
133 |                              n_wav_per_azi_all[set_type],
134 |                              str(i), pb))
135 |         proc.start()
136 |         proc_all.append(proc)
137 |     [proc.join() for proc in proc_all]
136 |
137 |
138 | if __name__ == '__main__':
139 |
140 | # train dataset and validation dataset
141 | dataset_dir = os.path.join(data_dir, 'v1')
142 | os.makedirs(dataset_dir, exist_ok=True)
143 | gen_dataset(dir_path=dataset_dir,
144 | set_type_all=['train', 'valid'])
145 |
146 | # test dataset
147 | for i in range(1, 5):
148 | dataset_dir = os.path.join(data_dir, f'v{i}')
149 | os.makedirs(dataset_dir, exist_ok=True)
150 | gen_dataset(dir_path=dataset_dir, set_type_all=['test'])
151 |
--------------------------------------------------------------------------------
/WaveLoc.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import os
5 | import configparser
6 | import time
7 | import gammatone.filters as gt_filters
8 | import tensorflow as tf
9 |
10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
11 |
12 |
13 | class WaveLoc(object):
14 |     """End-to-end binaural sound localization model (WaveLoc-GTF)
15 |     """
16 | def __init__(self, file_reader, config_fpath=None, gpu_index=0):
17 | """
18 | """
19 |
20 | # constant settings
21 | self.epsilon = 1e-20
22 | self._file_reader = file_reader
23 |
24 | self._graph = tf.Graph()
25 | config = tf.compat.v1.ConfigProto()
26 | config.gpu_options.allow_growth = True
27 | # config.gpu_options.visible_device_list = '{}'.format(gpu_index)
28 | self._sess = tf.compat.v1.Session(graph=self._graph, config=config)
29 |
30 | self._load_cfg(config_fpath)
31 | self._build_model()
32 |
33 | def _add_log(self, log_info):
34 | self._log_file.write(log_info)
35 | self._log_file.write('\n')
36 | self._log_file.flush()
37 | if self.is_print_log:
38 | print(log_info)
39 |
40 | def _load_cfg(self, config_fpath):
41 | if config_fpath is not None and os.path.exists(config_fpath):
42 | config = configparser.ConfigParser()
43 | config.read(config_fpath)
44 |
45 | # settings for model
46 | self.fs = np.int16(config['model']['fs'])
47 | self.n_band = np.int16(config['model']['n_band'])
48 | self.cf_low = np.int16(config['model']['cf_low'])
49 | self.cf_high = np.int16(config['model']['cf_high'])
50 | self.frame_len = np.int16(config['model']['frame_len'])
51 | self.shift_len = np.int16(config['model']['shift_len'])
52 | self.filter_len = np.int16(config['model']['filter_len'])
53 | self.is_padd = config['model']['is_padd'] == 'True'
54 | self.n_azi = np.int16(config['model']['azi_num'])
55 |
56 | # settings for training
57 | self.batch_size = np.int16(config['train']['batch_size'])
58 | self.max_epoch = np.int16(config['train']['max_epoch'])
59 | self.is_print_log = config['train']['is_print_log'] == 'True'
60 | self.train_set_dir = config['train']['train_set_dir'].split(';')
61 | self.valid_set_dir = config['train']['valid_set_dir'].split(';')
62 | if self.valid_set_dir[0] == '':
63 | self.valid_set_dir = None
64 |
65 | print('Train set:')
66 | [print('\t{}'.format(item)) for item in self.train_set_dir]
67 |
68 | print('Valid set:')
69 | [print('\t{}'.format(item)) for item in self.valid_set_dir]
70 |
71 |         else:
72 |             raise OSError(f'config file not found: {config_fpath}')
74 |
75 | def get_gtf_kernel(self):
76 |         """Generate gammatone filterbank impulse responses, used as the
77 |         fixed kernel of the first convolution layer
78 |         """
78 | cfs = gt_filters.erb_space(self.cf_low, self.cf_high, self.n_band)
79 | self.cfs = cfs
80 |
81 | sample_times = np.arange(0, self.filter_len, 1)/self.fs
82 | irs = np.zeros((self.filter_len, self.n_band), dtype=np.float32)
83 |
84 | EarQ = 9.26449
85 | minBW = 24.7
86 | order = 1
87 | N = 4
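88 |         # 4th-order gammatone impulse response (amplitude-normalized below):
89 |         #   g(t) = t**(N-1) * cos(2*pi*cf*t) * exp(-2*pi*b*t), b = 1.019*ERB(cf)
90 |         # where ERB(cf) = cf/EarQ + minBW (Glasberg & Moore constants above)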
88 | for band_i in range(self.n_band):
89 | ERB = ((cfs[band_i]/EarQ)**order+minBW**order)**(1/order)
90 | b = 1.019*ERB
91 | numerator = np.multiply(sample_times**(N-1),
92 | np.cos(2*np.pi*cfs[band_i]*sample_times))
93 | denominator = np.exp(2*np.pi*b*sample_times)
94 | irs[:, band_i] = np.divide(numerator, denominator)
95 |
96 | gain = np.max(np.abs(np.fft.fft(irs, axis=0)), axis=0)
97 | irs_gain_norm = np.divide(np.flipud(irs), gain)
98 | if self.is_padd:
99 | kernel = np.concatenate((irs_gain_norm,
100 | np.zeros((self.filter_len, self.n_band))),
101 | axis=0)
102 | else:
103 | kernel = irs_gain_norm
104 | return kernel
105 |
106 | def _fcn_layers(self, input, *layers_setting):
107 | for setting in layers_setting:
108 | fcn_size = setting['fcn_size']
109 | activation = setting['activation']
110 | rate = setting['rate']
111 |
112 | layer_fcn = tf.keras.layers.Dense(units=fcn_size,
113 | activation=activation)
114 | if rate > 0:
115 | layer_drop = tf.keras.layers.Dropout(rate=rate)
116 | output = layer_fcn(layer_drop(input))
117 | elif rate == 0:
118 | output = layer_fcn(input)
119 | else:
120 | raise Exception('illegal dropout rate')
121 | input = output
122 | return output
123 |
124 | def _build_model_subband(self, input):
125 | """
126 | """
127 | layer1_conv = tf.keras.layers.Conv2D(filters=6,
128 | kernel_size=[18, 2],
129 | strides=[1, 1],
130 | activation=tf.nn.relu)
131 | layer1_pool = tf.keras.layers.MaxPool2D([4, 1], [4, 1])
132 | layer1_out = layer1_pool(layer1_conv(input))
133 |
134 | layer2_conv = tf.keras.layers.Conv2D(filters=12,
135 | kernel_size=[6, 1],
136 | strides=[1, 1],
137 | activation=tf.nn.relu)
138 | layer2_pool = tf.keras.layers.MaxPool2D([4, 1], [4, 1])
139 | layer2_out = layer2_pool(layer2_conv(layer1_out))
140 |
141 | flatten_len = np.prod(layer2_out.get_shape().as_list()[1:])
142 | out = tf.reshape(layer2_out, [-1, flatten_len]) # flatten
143 | return out
144 |
145 | def _build_model(self):
146 | """Build graph
147 | """
148 | # gammatone layer kernel initalizer
149 | with self._graph.as_default():
150 | kernel_initializer = tf.constant_initializer(
151 | self.get_gtf_kernel())
152 |
153 | if self.is_padd:
154 | gtf_kernel_len = 2*self.filter_len
155 | else:
156 | gtf_kernel_len = self.filter_len
157 |
158 | x = tf.compat.v1.placeholder(shape=[None, self.frame_len, 2, 1],
159 | dtype=tf.float32,
160 | name='x') #
161 |
162 | gt_layer = tf.keras.layers.Conv2D(
163 | filters=self.n_band,
164 | kernel_size=[gtf_kernel_len, 1],
165 | strides=[1, 1],
166 | padding='same',
167 | kernel_initializer=kernel_initializer,
168 | trainable=False, use_bias=False)
169 |
170 | # add to model for test
171 | # self.layer1_conv = gt_layer
172 |
173 |             # amplitude normalization across frequency channels
174 | # problem: silence ?
175 | x_band_all = gt_layer(x)
176 | amp_max = tf.reduce_max(
177 | tf.reduce_max(
178 | tf.reduce_max(
179 | tf.abs(x_band_all),
180 | axis=1, keepdims=True),
181 | axis=2, keepdims=True),
182 | axis=3, keepdims=True)
183 | x_band_norm_all = tf.divide(x_band_all, amp_max)
184 |
185 | # layer1_pool
186 | gt_layer_pool = tf.keras.layers.MaxPool2D([2, 1], [2, 1])
187 | gt_layer_output = gt_layer_pool(x_band_norm_all)
188 |
189 | band_out_list = []
190 | for band_i in range(self.n_band):
191 | band_output = self._build_model_subband(
192 | tf.expand_dims(
193 | gt_layer_output[:, :, :, band_i],
194 | axis=-1))
195 | band_out_list.append(band_output)
196 | band_out = tf.concat(band_out_list, axis=1)
197 |
198 | layer4 = {'fcn_size': 1024,
199 | 'activation': tf.nn.relu,
200 | 'rate': 0.5}
201 | layer5 = {'fcn_size': 1024,
202 | 'activation': tf.nn.relu,
203 | 'rate': 0.5}
204 | output_layer = {'fcn_size': self.n_azi,
205 | 'activation': tf.nn.softmax,
206 | 'rate': 0}
207 |
208 | y_est = self._fcn_layers(band_out, layer4, layer5, output_layer)
209 |
210 | # groundtruth of two tasks
211 | y = tf.compat.v1.placeholder(shape=[None, self.n_azi],
212 | dtype=tf.float32)
213 | # cost function
214 | cost = self._cal_cross_entropy(y_est, y)
215 | # additional measurement of localization
216 | azi_rmse = self._cal_azi_rmse(y_est, y)
217 |
218 | #
219 | lr = tf.compat.v1.placeholder(tf.float32, shape=[])
220 | opt_step = tf.compat.v1.train.AdamOptimizer(
221 | learning_rate=lr).minimize(cost)
222 |
223 | # initialize of model
224 | init = tf.compat.v1.global_variables_initializer()
225 | self._sess.run(init)
226 |
227 | # input and output
228 | self._x = x
229 | self._y_est = y_est
230 | # groundtruth
231 | self._y = y
232 | # cost function and optimizer
233 | self._cost = cost
234 | self._azi_rmse = azi_rmse
235 | self._lr = lr
236 | self._opt_step = opt_step
237 |
238 | def _cal_cross_entropy(self, y_est, y):
239 | cross_entropy = -tf.reduce_mean(
240 | tf.reduce_sum(
241 | tf.multiply(
242 | y, tf.math.log(y_est+self.epsilon)),
243 | axis=1))
244 | return cross_entropy
245 |
246 |     def _cal_mse(self, y_est, y):
247 |         mse = tf.reduce_mean(tf.reduce_sum((y-y_est)**2, axis=1))
248 |         return mse
249 |
250 | def _cal_azi_rmse(self, y_est, y):
251 | azi_est = tf.argmax(y_est, axis=1)
252 | azi = tf.argmax(y, axis=1)
253 | diff = tf.cast(azi_est - azi, tf.float32)
254 | return tf.sqrt(tf.reduce_mean(tf.pow(diff, 2)))
255 |
256 | def _cal_cp(self, y_est, y):
257 | equality = tf.equal(tf.argmax(y_est, axis=1), tf.argmax(y, axis=1))
258 | cp = tf.reduce_mean(tf.cast(equality, tf.float32))
259 | return cp
260 |
261 | def load_model(self, model_dir):
262 | """load model"""
263 | if not os.path.exists(model_dir):
264 | raise Exception('no model exists in {}'.format(model_dir))
265 |
266 | with self._graph.as_default():
267 | # restore model
268 | saver = tf.compat.v1.train.Saver()
269 | ckpt = tf.train.get_checkpoint_state(model_dir)
270 | if ckpt and ckpt.model_checkpoint_path:
271 | saver.restore(self._sess, ckpt.model_checkpoint_path)
272 |
273 | print(f'load model from {model_dir}')
274 |
275 | def _train_record_init(self, model_dir, is_load_model):
276 | """
277 | """
278 | if is_load_model:
279 | record_info = np.load(os.path.join(model_dir,
280 | 'train_record.npz'))
281 | cost_record_valid = record_info['cost_record_valid']
282 | azi_rmse_record_valid = record_info['azi_rmse_record_valid']
283 | lr_value = record_info['lr']
284 | best_epoch = record_info['best_epoch']
285 | min_valid_cost = record_info['min_valid_cost']
286 | last_epoch = np.nonzero(cost_record_valid)[0][-1]
287 | else:
288 | cost_record_valid = np.zeros(self.max_epoch)
289 | azi_rmse_record_valid = np.zeros(self.max_epoch)
290 | lr_value = 1e-3
291 |             min_valid_cost = np.inf
292 | best_epoch = 0
293 | last_epoch = -1
294 | return [cost_record_valid, azi_rmse_record_valid,
295 | lr_value, min_valid_cost, best_epoch, last_epoch]
296 |
297 | def train_model(self, model_dir, is_load_model=False):
298 |         """Train the model either from its initial state (self._build_model())
299 |         or from an existing checkpoint
300 |         """
301 | if is_load_model:
302 | self.load_model(model_dir)
303 |
304 | if not os.path.exists(model_dir):
305 |             os.makedirs(model_dir)
306 |
307 | # open text file for logging
308 | self._log_file = open(os.path.join(model_dir, 'log.txt'), 'a')
309 |
310 | with self._graph.as_default():
311 |
312 | [cost_record_valid, azi_rmse_record_valid,
313 | lr_value, min_valid_cost,
314 | best_epoch, last_epoch] = self._train_record_init(model_dir,
315 | is_load_model)
316 |
317 | saver = tf.compat.v1.train.Saver()
318 | print('start training')
319 | for epoch in range(last_epoch+1, self.max_epoch):
320 | t_start = time.time()
321 | print(f'epoch {epoch}')
322 | batch_generator = self._file_reader(self.train_set_dir)
323 | for x, y in batch_generator:
324 | self._sess.run(self._opt_step,
325 | feed_dict={self._x: x,
326 | self._y: y,
327 | self._lr: lr_value})
328 | # model test
329 | [cost_record_valid[epoch],
330 | azi_rmse_record_valid[epoch]] = self.evaluate(
331 | self.valid_set_dir)
332 |
333 | # write to log
334 | iter_time = time.time()-t_start
335 | self._add_log(' '.join((f'epoch:{epoch}',
336 | f'lr:{lr_value}',
337 | f'time:{iter_time:.2f}\n')))
338 |
339 | log_template = '\t cost:{:.2f} azi_rmse:{:.2f}\n'
340 | self._add_log('\t valid ')
341 | self._add_log(log_template.format(
342 | cost_record_valid[epoch],
343 | azi_rmse_record_valid[epoch]))
344 |
345 | #
346 | if min_valid_cost > cost_record_valid[epoch]:
347 | self._add_log('find new optimal\n')
348 | best_epoch = epoch
349 | min_valid_cost = cost_record_valid[epoch]
350 | saver.save(self._sess, os.path.join(model_dir,
351 | 'model'),
352 | global_step=epoch)
353 |
354 | # save record info
355 | np.savez(os.path.join(model_dir, 'train_record'),
356 | cost_record_valid=cost_record_valid,
357 | azi_rmse_record_valid=azi_rmse_record_valid,
358 | lr=lr_value,
359 | best_epoch=best_epoch,
360 | min_valid_cost=min_valid_cost)
361 |
362 | # early stop
363 | if epoch-best_epoch > 5:
364 | print('early stop\n', min_valid_cost)
365 |                     self._add_log('early stop {}\n'.format(min_valid_cost))
366 | break
367 |
368 | # learning rate decay
369 |                 if epoch > 2:  # no better performance in the last 2 epochs
370 | min_valid_cost_local = np.min(
371 | cost_record_valid[epoch-1:epoch+1])
372 | if cost_record_valid[epoch-2] < min_valid_cost_local:
373 | lr_value = lr_value*.2
374 |
375 | self._log_file.close()
376 |
377 | if True:
378 | fig, ax = plt.subplots(2, 1, sharex=True, tight_layout=True)
379 | ax[0].plot(cost_record_valid)
380 |             ax[0].set_ylabel('cross entropy')
381 |
382 | ax[1].plot(azi_rmse_record_valid)
383 | ax[1].set_ylabel('rmse(deg)')
384 | #
385 | fig_path = os.path.join(model_dir, 'train_curve.png')
386 | plt.savefig(fig_path)
387 |
388 | def predict(self, x):
389 | """Model output of x
390 | """
391 | y_est = self._sess.run(self._y_est, feed_dict={self._x: x})
392 | return y_est
393 |
394 | def evaluate(self, set_dir):
395 | cost_all = 0.
396 | rmse_all = 0.
397 | n_sample_all = 0
398 | batch_generator = self._file_reader(set_dir)
399 | for x, y in batch_generator:
400 | n_sample_tmp = x.shape[0]
401 | [cost_tmp, rmse_tmp] = self._sess.run([self._cost, self._azi_rmse],
402 | feed_dict={self._x: x,
403 | self._y: y})
404 | #
405 | n_sample_all = n_sample_all+n_sample_tmp
406 | cost_all = cost_all+n_sample_tmp*cost_tmp
407 | rmse_all = rmse_all+n_sample_tmp*(rmse_tmp**2)
408 |
409 | # average across all set
410 | cost_all = cost_all/n_sample_all
411 | rmse_all = np.sqrt(rmse_all/n_sample_all)
412 | return [cost_all, rmse_all]
413 |
414 |     def evaluate_chunk_rmse(self, record_set_dir, chunk_size=25):
415 |         """ Evaluate chunk-level localization RMSE on the given dataset
416 |         Args:
417 |             record_set_dir: directory of test recordings
418 |             chunk_size: number of consecutive frames per chunk
419 |         Returns:
420 |             rmse_chunk: azimuth RMSE in degrees (grid step is 5 degrees)
421 |         """
421 | rmse_chunk = 0.
422 | n_chunk = 0
423 |
424 | for x, y in self._file_reader(record_set_dir, is_shuffle=False):
425 | sample_num = x.shape[0]
426 | azi_true = np.argmax(y[0])
427 |
428 | y_est = self.predict(x)
429 | for sample_i in range(0, sample_num-chunk_size+1):
430 | azi_est_chunk = np.argmax(
431 | np.mean(
432 | y_est[sample_i:sample_i+chunk_size],
433 | axis=0))
434 | rmse_chunk = rmse_chunk+(azi_est_chunk-azi_true)**2
435 | n_chunk = n_chunk+1
436 |
437 |         rmse_chunk = np.sqrt(rmse_chunk/n_chunk)*5  # 5 degrees per azimuth step
438 | return rmse_chunk
439 |
--------------------------------------------------------------------------------