├── .gitignore
├── LICENSE
├── README.md
├── cfg.py
├── data.py
├── dnn_models.py
├── experiments
│   └── ngccphat
│       ├── cfg.py
│       ├── eval.txt
│       ├── eval_anechoic.txt
│       ├── evaluations.npz
│       ├── evaluations_anechoic.npz
│       ├── log.txt
│       └── model.pth
├── helpers.py
├── main.py
├── model.py
├── ngcc.png
├── requirements.txt
├── speech.wav
└── tdoa_example.ipynb

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Axel Berg
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Generalized Cross Correlations (NGCC)
2 | 
3 | ![](ngcc.png)
4 | 
5 | This is the official implementation of [*Extending GCC-PHAT using Shift Equivariant Neural Networks*](https://arxiv.org/abs/2208.04654), published at Interspeech 2022. We propose a neural GCC-PHAT (NGCC-PHAT): the signals are passed through a convolutional network that outputs multiple filtered versions of each signal. Several GCC-PHAT correlations are then computed and combined into a single probability distribution over candidate delays.
6 | 
7 | For a quick start on how to use NGCC-PHAT for TDOA estimation and positioning, try out the notebook [tdoa_example.ipynb](tdoa_example.ipynb).
8 | 
9 | ## Dependencies
10 | 
11 | To install the required dependencies, run `pip install -r requirements.txt`.
12 | 
13 | ## Room simulation
14 | 
15 | Signal propagation is simulated using [pyroomacoustics](https://github.com/LCAV/pyroomacoustics); a minimal example is given at the end of this README. For each simulation, a room with the specified dimensions is created with the microphones at the specified positions. The speaker is then placed at a random position inside the room, and the acoustic waveform is propagated from the speaker to the microphones, so that the received signals are noisy, delayed copies of the original signal. The acoustic conditions are controlled by two parameters: the reverberation time (T60) and the signal-to-noise ratio (SNR).
16 | 
17 | ## Dataset
18 | 
19 | The LibriSpeech dataset contains recordings of English audiobooks, sampled at 16 kHz. For our experiments, we use recordings from the `test-clean` split, which can be downloaded from https://www.openslr.org/12. Silent parts are removed from each recording, and recordings shorter than two seconds are discarded. At training time, a randomly selected time window of 2048 samples is extracted and transmitted from the speaker position to each microphone. For reproducibility, 15 fixed, consecutive time windows are used for each snippet at test time.
20 | 
21 | Note that each recording is marked with a speaker id. We use recordings from three randomly selected speakers for validation and another three for testing. The remaining 34 speakers are used for training.
22 | 
23 | The dataset will be automatically downloaded the first time the training/evaluation script is run.
24 | 
25 | ## Training a model from scratch
26 | 
27 | To set up an experiment, first have a look at the configuration parameters in [cfg.py](cfg.py). Once everything is set up, you can train your own model by running
28 | 
29 | ```
30 | python main.py --exp_name=name_of_my_experiment
31 | ```
32 | 
33 | The training will run on a single GPU, if available. During training, both the SNR and the T60 are randomized within the provided intervals. Note that training time increases with the reverberation time T60, so in order to train faster, try reducing the T60 range.
34 | 
35 | ## Pre-trained models
36 | 
37 | A pre-trained model is provided in [experiments/ngccphat](experiments/ngccphat), which was trained with the default parameters in the config file. To evaluate it in the test room, run
38 | 
39 | ```
40 | python main.py --evaluate --exp_name=ngccphat
41 | ```
42 | 
43 | The log will output the MAE, RMSE and accuracy for both the model and GCC-PHAT. For further analysis, all ground-truth delays and model predictions are also stored in an `.npz` file.
44 | 
45 | ## Citation
46 | 
47 | If you use this code repository, please cite the following paper:
48 | 
49 | ```
50 | @inproceedings{berg22_interspeech,
51 |   author={Axel Berg and Mark O'Connor and Kalle Åström and Magnus Oskarsson},
52 |   title={{Extending GCC-PHAT using Shift Equivariant Neural Networks}},
53 |   year=2022,
54 |   booktitle={Proc. Interspeech 2022},
55 |   pages={1791--1795},
56 |   doi={10.21437/Interspeech.2022-524}
57 | }
58 | ```
59 | 
60 | ## Acknowledgements
61 | 
62 | The network backbone is borrowed from SincNet, and we thank the authors for sharing their code with the community. The original repository can be found [here](https://github.com/mravanelli/SincNet).
63 | 
64 | The PGCC-PHAT implementation is based on the description in the paper ["Time Delay Estimation for Speaker Localization Using CNN-Based Parametrized GCC-PHAT Features"](https://www.isca-speech.org/archive/pdfs/interspeech_2021/salvati21_interspeech.pdf) by Salvati et al.
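
## Appendix: minimal examples

The room simulation described above is implemented in [data.py](data.py). As a quick reference, the sketch below mirrors its core pyroomacoustics calls for a single simulation. The concrete values here (a T60 of 0.4 s, 20 dB SNR, a noise burst instead of a speech snippet, and an arbitrary source position) are placeholders for illustration only; the parameters actually used are defined in [cfg.py](cfg.py).

```
import numpy as np
import pyroomacoustics as pra

fs = 16000
room_dim = [7.0, 5.0, 3.0]                                   # training room from cfg.py
mic_locs = np.array([[3.5, 2.25, 1.5], [3.5, 2.75, 1.5]]).T  # shape (3, n_mics)
source_loc = [1.0, 1.0, 1.5]                                 # example position; normally drawn at random
signal = np.random.randn(fs)                                 # stand-in for a speech snippet

# convert the desired reverberation time into an absorption coefficient and image order
e_absorption, max_order = pra.inverse_sabine(rt60=0.4, room_dim=room_dim)
room = pra.ShoeBox(room_dim, fs=fs,
                   materials=pra.Material(e_absorption), max_order=max_order)
room.add_source(source_loc, signal=signal)
room.add_microphone(mic_locs)
room.simulate(reference_mic=0, snr=20)

# ground-truth TDOA in samples, computed from the path-length difference as in data.py
d = np.linalg.norm(mic_locs[:, 0] - source_loc) - np.linalg.norm(mic_locs[:, 1] - source_loc)
delay = d * fs / room.c
x1, x2 = room.mic_array.signals[0], room.mic_array.signals[1]
```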
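For comparison, the classical GCC-PHAT baseline fits in a few lines. The function below is only a sketch of the standard frequency-domain algorithm, not the repository's own implementation (that is the `GCC` module in [model.py](model.py)), and the sign convention, returning a positive delay when `x2` lags `x1`, is an assumption of this example.

```
import numpy as np

def gcc_phat(x1, x2, max_tau):
    # cross-power spectrum with PHAT weighting: discard magnitude, keep phase
    n = len(x1) + len(x2)
    X1, X2 = np.fft.rfft(x1, n=n), np.fft.rfft(x2, n=n)
    R = X2 * np.conj(X1)
    cc = np.fft.irfft(R / (np.abs(R) + 1e-15), n=n)
    # gather the correlation values for lags -max_tau ... +max_tau
    cc = np.concatenate((cc[-max_tau:], cc[:max_tau + 1]))
    # the estimated delay is the lag with the strongest correlation
    return np.argmax(np.abs(cc)) - max_tau
```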
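Finally, note how the accuracy threshold reported in the logs is derived in [cfg.py](cfg.py): the tolerance of `t_cm = 10` cm is converted into samples via `t = t_cm * fs / (343 * 100)`, i.e. 10 * 16000 / 34300 ≈ 4.66 samples at 16 kHz with a speed of sound of 343 m/s. A predicted delay therefore counts as correct if it lies within roughly 4.7 samples of the ground truth.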
65 | -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Training room simulation parameters 4 | # room dimensions in meters 5 | dx_train = 7.0 6 | dy_train = 5.0 7 | dz_train = 3.0 8 | room_dim_train = [dx_train, dy_train, dz_train] 9 | xyz_min_train = [0.0, 0.0, 0.0] 10 | xyz_max_train = room_dim_train 11 | 12 | # microphone locations 13 | mic_locs_train = np.array([[3.5, 2.25, 1.5], [3.5, 2.75, 1.5]]).T 14 | 15 | # Test room parameters 16 | dx_test = 6.0 17 | dy_test = 4.0 18 | dz_test = 2.5 19 | room_dim_test = [dx_test, dy_test, dz_test] 20 | xyz_min_test = [0.0, 0.0, 0.0] 21 | xyz_max_test = room_dim_test 22 | 23 | mic_locs_test = np.array([[3.0, 1.75, 1.25], [3.0, 2.25, 1.25]]).T 24 | 25 | # Testing environment configuration 26 | # the model will be evaluated for all SNRs and T60s in the lists 27 | snr_range = [0, 6, 12, 18, 24, 30] 28 | t60_range = [0.2, 0.4, 0.6] 29 | 30 | # accuracy threshold in cm 31 | t_cm = 10 32 | 33 | # Training hyperparams 34 | seed = 0 35 | batch_size = 32 36 | epochs = 30 37 | lr = 0.001 # learning rate 38 | wd = 0.0 # weight decay 39 | ls = 0.0 # label smoothing 40 | 41 | # Model parameters 42 | model = 'NGCCPHAT' # choices: NGCCPHAT, PGCCPHAT 43 | max_delay = 23 44 | num_channels = 128 # number of channels in final layer of NGCCPHAT backbone 45 | head = 'classifier' # final layer type. Choices: 'classifier', 'regression' 46 | loss = 'ce' # use 'ce' loss for classifier and 'mse' loss for regression 47 | # Set to true in order to replace Sinc filters with regular convolutional layers 48 | no_sinc = False 49 | 50 | # training environment 51 | snr = [0, 30] # during training, snr will be drawn uniformly from this interval 52 | t60 = [0.2, 1.0] # during training, t60 will be drawn uniformly from this interval 53 | fs = 16000 # sampling rate 54 | sig_len = 2048 # length of snippet used for tdoa estimation 55 | anechoic = False # set to True to use anechoic environment without reverberation 56 | 57 | # threshold in samples 58 | t = t_cm * fs / (343 * 100) 59 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from torchaudio.datasets import LIBRISPEECH 2 | import pyroomacoustics as pra 3 | import numpy as np 4 | from typing import Tuple 5 | from torch import Tensor 6 | import torch 7 | import random 8 | from librosa.effects import split 9 | 10 | 11 | def remove_silence(signal, top_db=20, frame_length=2048, hop_length=512): 12 | ''' 13 | Remove silence from speech signal 14 | ''' 15 | signal = signal.squeeze() 16 | clips = split(signal, top_db=top_db, 17 | frame_length=frame_length, hop_length=hop_length) 18 | output = [] 19 | for ii in clips: 20 | start, end = ii 21 | output.append(signal[start:end]) 22 | 23 | return torch.cat(output) 24 | 25 | 26 | class LibriSpeechLocations(LIBRISPEECH): 27 | ''' 28 | Class of LibriSpeech recordings. Each recording is annotated with a speaker location. 
29 | ''' 30 | 31 | def __init__(self, source_locs, split): 32 | super().__init__("./", url=split, download=True) 33 | 34 | self.source_locs = source_locs 35 | 36 | def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int, int, float, int]: 37 | 38 | source_loc = self.source_locs[n] 39 | seed = n 40 | waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_number = super().__getitem__(n) 41 | return (waveform, sample_rate, transcript, speaker_id, utterance_number), source_loc, seed 42 | 43 | 44 | def one_random_delay(room_dim, fs, t60, mic_locs, signal, xyz_min, xyz_max, snr, anechoic=False): 45 | ''' 46 | Simulate signal propagation using pyroomacoustics using random source location. 47 | ''' 48 | 49 | if anechoic: 50 | e_absorption = 1.0 51 | max_order = 0 52 | else: 53 | e_absorption, max_order = pra.inverse_sabine(t60, room_dim) 54 | 55 | room = pra.ShoeBox(room_dim, fs=fs, materials=pra.Material( 56 | e_absorption), max_order=max_order) 57 | 58 | source_loc = np.random.uniform(low=xyz_min, high=xyz_max, size=(3)) 59 | room.add_source(source_loc, signal=signal.squeeze()) 60 | room.add_microphone(mic_locs) 61 | c = room.c 62 | d = np.sqrt(np.sum((mic_locs[:, 0] - source_loc)**2)) - \ 63 | np.sqrt(np.sum((mic_locs[:, 1] - source_loc)**2)) 64 | delay = d * fs / c 65 | room.simulate(reference_mic=0, snr=snr) 66 | x1 = room.mic_array.signals[0, :] 67 | x2 = room.mic_array.signals[1, :] 68 | 69 | return x1, x2, delay, room 70 | 71 | 72 | def one_delay(room_dim, fs, t60, mic_locs, signal, source_loc, snr=1000, anechoic=False): 73 | ''' 74 | Simulate signal propagation using pyroomacoustics for a given source location. 75 | ''' 76 | 77 | if anechoic: 78 | e_absorption = 1.0 79 | max_order = 0 80 | else: 81 | e_absorption, max_order = pra.inverse_sabine(t60, room_dim) 82 | 83 | room = pra.ShoeBox(room_dim, fs=fs, materials=pra.Material( 84 | e_absorption), max_order=max_order) 85 | 86 | room.add_source(source_loc, signal=signal.squeeze()) 87 | room.add_microphone(mic_locs) 88 | c = room.c 89 | d = np.sqrt(np.sum((mic_locs[:, 0] - source_loc)**2)) - \ 90 | np.sqrt(np.sum((mic_locs[:, 1] - source_loc)**2)) 91 | delay = d * fs / c 92 | room.simulate(reference_mic=0, snr=snr) 93 | x1 = room.mic_array.signals[0, :] 94 | x2 = room.mic_array.signals[1, :] 95 | 96 | return x1, x2, delay, room 97 | 98 | 99 | def pad_sequence(batch): 100 | # Make all tensor in a batch the same length by padding with zeros 101 | batch = [item.t() for item in batch] 102 | batch = torch.nn.utils.rnn.pad_sequence( 103 | batch, batch_first=True, padding_value=0.) 104 | return batch 105 | 106 | 107 | class DelaySimulator(object): 108 | ''' 109 | Given a batch of LibrispeechLocation samples, simulate signal 110 | propagation from source to the microphone locations. 
111 | ''' 112 | 113 | def __init__(self, room_dim, fs, N, t60, mic_locs, max_tau, anechoic, train=True, snr=1000, lower_bound=16000, upper_bound=48000): 114 | 115 | self.room_dim = room_dim 116 | self.fs = fs 117 | self.N = N 118 | self.mic_locs = mic_locs 119 | self.max_tau = max_tau 120 | self.snr = snr 121 | self.t60 = t60 122 | self.anechoic = anechoic 123 | self.train = train 124 | 125 | self.lower_bound = lower_bound 126 | self.upper_bound = upper_bound 127 | 128 | def __call__(self, batch): 129 | # A data tuple has the form: 130 | # waveform, sample_rate, label, speaker_id, utterance_number 131 | 132 | tensors1, tensors2, targets = [], [], [] 133 | 134 | # Gather in lists, and encode labels as indices 135 | with torch.no_grad(): 136 | for (waveform, sample_rate, _, _, _), source_loc, seed in batch: 137 | 138 | waveform = waveform.squeeze() 139 | signal = remove_silence(waveform, frame_length=self.N) 140 | 141 | # use random seed for training, fixed for val/test 142 | # this controls the randomness in sound propagation when simulating the room 143 | if not self.train: 144 | torch.manual_seed(seed) 145 | random.seed(seed) 146 | np.random.seed(seed) 147 | 148 | # sample random reverberation time and SNR 149 | this_t60 = np.random.uniform(low=self.t60[0], high=self.t60[1]) 150 | this_snr = np.random.uniform(low=self.snr[0], high=self.snr[1]) 151 | 152 | x1, x2, delay, _ = one_delay(room_dim=self.room_dim, fs=self.fs, t60=this_t60, 153 | mic_locs=self.mic_locs, signal=signal, 154 | source_loc=source_loc, snr=this_snr, 155 | anechoic=self.anechoic) 156 | 157 | if self.train: 158 | start_idx = torch.randint( 159 | self.lower_bound, self.upper_bound - self.N - 1, (1,)) 160 | else: 161 | start_idx = self.lower_bound 162 | 163 | end_idx = start_idx + self.N 164 | x1 = x1[start_idx:end_idx] 165 | x2 = x2[start_idx:end_idx] 166 | 167 | tensors1 += [torch.as_tensor(x1, dtype=torch.float)] 168 | tensors2 += [torch.as_tensor(x2, dtype=torch.float)] 169 | targets += [delay+self.max_tau] 170 | 171 | # Group the list of tensors into a batched tensor 172 | tensors1 = pad_sequence(tensors1).unsqueeze(1) 173 | tensors2 = pad_sequence(tensors2).unsqueeze(1) 174 | targets = torch.Tensor(targets) 175 | 176 | return tensors1, tensors2, targets 177 | -------------------------------------------------------------------------------- /dnn_models.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file contains the implementation of SincNet, by Mirco Ravanelli and Yoshua Bengio 3 | Circular padding has been added before each convolution. 
4 | Source: https://github.com/mravanelli/SincNet 5 | ''' 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.nn as nn 11 | from torch.autograd import Variable 12 | import math 13 | from torch_same_pad import get_pad 14 | 15 | 16 | def flip(x, dim): 17 | xsize = x.size() 18 | dim = x.dim() + dim if dim < 0 else dim 19 | x = x.contiguous() 20 | x = x.view(-1, *xsize[dim:]) 21 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, 22 | -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :] 23 | return x.view(xsize) 24 | 25 | 26 | def sinc(band, t_right): 27 | y_right = torch.sin(2*math.pi*band*t_right)/(2*math.pi*band*t_right) 28 | y_left = flip(y_right, 0) 29 | 30 | y = torch.cat([y_left, Variable(torch.ones(1)).cuda(), y_right]) 31 | 32 | return y 33 | 34 | 35 | class SincConv_fast(nn.Module): 36 | """Sinc-based convolution 37 | Parameters 38 | ---------- 39 | in_channels : `int` 40 | Number of input channels. Must be 1. 41 | out_channels : `int` 42 | Number of filters. 43 | kernel_size : `int` 44 | Filter length. 45 | sample_rate : `int`, optional 46 | Sample rate. Defaults to 16000. 47 | Usage 48 | ----- 49 | See `torch.nn.Conv1d` 50 | Reference 51 | --------- 52 | Mirco Ravanelli, Yoshua Bengio, 53 | "Speaker Recognition from raw waveform with SincNet". 54 | https://arxiv.org/abs/1808.00158 55 | """ 56 | 57 | @staticmethod 58 | def to_mel(hz): 59 | return 2595 * np.log10(1 + hz / 700) 60 | 61 | @staticmethod 62 | def to_hz(mel): 63 | return 700 * (10 ** (mel / 2595) - 1) 64 | 65 | def __init__(self, out_channels, kernel_size, sample_rate=16000, in_channels=1, 66 | stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=50, min_band_hz=50): 67 | 68 | super(SincConv_fast, self).__init__() 69 | 70 | if in_channels != 1: 71 | #msg = (f'SincConv only support one input channel ' 72 | # f'(here, in_channels = {in_channels:d}).') 73 | msg = "SincConv only support one input channel (here, in_channels = {%i})" % ( 74 | in_channels) 75 | raise ValueError(msg) 76 | 77 | self.out_channels = out_channels 78 | self.kernel_size = kernel_size 79 | 80 | # Forcing the filters to be odd (i.e, perfectly symmetrics) 81 | if kernel_size % 2 == 0: 82 | self.kernel_size = self.kernel_size+1 83 | 84 | self.stride = stride 85 | self.padding = padding 86 | self.dilation = dilation 87 | 88 | if bias: 89 | raise ValueError('SincConv does not support bias.') 90 | if groups > 1: 91 | raise ValueError('SincConv does not support groups.') 92 | 93 | self.sample_rate = sample_rate 94 | self.min_low_hz = min_low_hz 95 | self.min_band_hz = min_band_hz 96 | 97 | # initialize filterbanks such that they are equally spaced in Mel scale 98 | low_hz = 30 99 | high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz) 100 | 101 | mel = np.linspace(self.to_mel(low_hz), 102 | self.to_mel(high_hz), 103 | self.out_channels + 1) 104 | hz = self.to_hz(mel) 105 | 106 | # filter lower frequency (out_channels, 1) 107 | self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1)) 108 | 109 | # filter frequency band (out_channels, 1) 110 | self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1)) 111 | 112 | # Hamming window 113 | #self.window_ = torch.hamming_window(self.kernel_size) 114 | # computing only half of the window 115 | n_lin = torch.linspace(0, (self.kernel_size/2)-1, 116 | steps=int((self.kernel_size/2))) 117 | self.window_ = 0.54-0.46*torch.cos(2*math.pi*n_lin/self.kernel_size) 118 | 119 | # (1, kernel_size/2) 120 | n = 
(self.kernel_size - 1) / 2.0 121 | # Due to symmetry, I only need half of the time axes 122 | self.n_ = 2*math.pi*torch.arange(-n, 0).view(1, -1) / self.sample_rate 123 | 124 | def forward(self, waveforms): 125 | """ 126 | Parameters 127 | ---------- 128 | waveforms : `torch.Tensor` (batch_size, 1, n_samples) 129 | Batch of waveforms. 130 | Returns 131 | ------- 132 | features : `torch.Tensor` (batch_size, out_channels, n_samples_out) 133 | Batch of sinc filters activations. 134 | """ 135 | 136 | self.n_ = self.n_.to(waveforms.device) 137 | 138 | self.window_ = self.window_.to(waveforms.device) 139 | 140 | low = self.min_low_hz + torch.abs(self.low_hz_) 141 | 142 | high = torch.clamp(low + self.min_band_hz 143 | + torch.abs(self.band_hz_), self.min_low_hz, self.sample_rate/2) 144 | band = (high-low)[:, 0] 145 | 146 | f_times_t_low = torch.matmul(low, self.n_) 147 | f_times_t_high = torch.matmul(high, self.n_) 148 | 149 | # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations. 150 | band_pass_left = ((torch.sin(f_times_t_high) 151 | - torch.sin(f_times_t_low))/(self.n_/2))*self.window_ 152 | band_pass_center = 2*band.view(-1, 1) 153 | band_pass_right = torch.flip(band_pass_left, dims=[1]) 154 | 155 | band_pass = torch.cat( 156 | [band_pass_left, band_pass_center, band_pass_right], dim=1) 157 | 158 | band_pass = band_pass / (2*band[:, None]) 159 | 160 | self.filters = (band_pass).view( 161 | self.out_channels, 1, self.kernel_size) 162 | 163 | return F.conv1d(waveforms, self.filters, stride=self.stride, 164 | padding=self.padding, dilation=self.dilation, 165 | bias=None, groups=1) 166 | 167 | 168 | class sinc_conv(nn.Module): 169 | 170 | def __init__(self, N_filt, Filt_dim, fs): 171 | super(sinc_conv, self).__init__() 172 | 173 | # Mel Initialization of the filterbanks 174 | low_freq_mel = 80 175 | high_freq_mel = (2595 * np.log10(1 + (fs / 2) / 700) 176 | ) # Convert Hz to Mel 177 | # Equally spaced in Mel scale 178 | mel_points = np.linspace(low_freq_mel, high_freq_mel, N_filt) 179 | f_cos = (700 * (10**(mel_points / 2595) - 1)) # Convert Mel to Hz 180 | b1 = np.roll(f_cos, 1) 181 | b2 = np.roll(f_cos, -1) 182 | b1[0] = 30 183 | b2[-1] = (fs/2)-100 184 | 185 | self.freq_scale = fs*1.0 186 | self.filt_b1 = nn.Parameter(torch.from_numpy(b1/self.freq_scale)) 187 | self.filt_band = nn.Parameter( 188 | torch.from_numpy((b2-b1)/self.freq_scale)) 189 | 190 | self.N_filt = N_filt 191 | self.Filt_dim = Filt_dim 192 | self.fs = fs 193 | 194 | def forward(self, x): 195 | 196 | filters = Variable(torch.zeros((self.N_filt, self.Filt_dim))).cuda() 197 | N = self.Filt_dim 198 | t_right = Variable(torch.linspace( 199 | 1, (N-1)/2, steps=int((N-1)/2))/self.fs).cuda() 200 | 201 | min_freq = 50.0 202 | min_band = 50.0 203 | 204 | filt_beg_freq = torch.abs(self.filt_b1)+min_freq/self.freq_scale 205 | filt_end_freq = filt_beg_freq + \ 206 | (torch.abs(self.filt_band)+min_band/self.freq_scale) 207 | 208 | n = torch.linspace(0, N, steps=N) 209 | 210 | # Filter window (hamming) 211 | window = 0.54-0.46*torch.cos(2*math.pi*n/N) 212 | window = Variable(window.float().cuda()) 213 | 214 | for i in range(self.N_filt): 215 | 216 | low_pass1 = 2 * \ 217 | filt_beg_freq[i].float()*sinc(filt_beg_freq[i].float() 218 | * self.freq_scale, t_right) 219 | low_pass2 = 2 * \ 220 | filt_end_freq[i].float()*sinc(filt_end_freq[i].float() 221 | * self.freq_scale, t_right) 222 | band_pass = 
(low_pass2-low_pass1) 223 | 224 | band_pass = band_pass/torch.max(band_pass) 225 | 226 | filters[i, :] = band_pass.cuda()*window 227 | 228 | out = F.conv1d(x, filters.view(self.N_filt, 1, self.Filt_dim)) 229 | 230 | return out 231 | 232 | 233 | def act_fun(act_type): 234 | 235 | if act_type == "relu": 236 | return nn.ReLU() 237 | 238 | if act_type == "tanh": 239 | return nn.Tanh() 240 | 241 | if act_type == "sigmoid": 242 | return nn.Sigmoid() 243 | 244 | if act_type == "leaky_relu": 245 | return nn.LeakyReLU(0.2) 246 | 247 | if act_type == "elu": 248 | return nn.ELU() 249 | 250 | if act_type == "softmax": 251 | return nn.LogSoftmax(dim=1) 252 | 253 | if act_type == "linear": 254 | return nn.LeakyReLU(1) # initializzed like this, but not used in forward! 255 | 256 | 257 | class LayerNorm(nn.Module): 258 | 259 | def __init__(self, features, eps=1e-6): 260 | super(LayerNorm, self).__init__() 261 | self.gamma = nn.Parameter(torch.ones(features)) 262 | self.beta = nn.Parameter(torch.zeros(features)) 263 | self.eps = eps 264 | 265 | def forward(self, x): 266 | mean = x.mean(-1, keepdim=True) 267 | std = x.std(-1, keepdim=True) 268 | return self.gamma * (x - mean) / (std + self.eps) + self.beta 269 | 270 | 271 | class MLP(nn.Module): 272 | def __init__(self, options): 273 | super(MLP, self).__init__() 274 | 275 | self.input_dim = int(options['input_dim']) 276 | self.fc_lay = options['fc_lay'] 277 | self.fc_drop = options['fc_drop'] 278 | self.fc_use_batchnorm = options['fc_use_batchnorm'] 279 | self.fc_use_laynorm = options['fc_use_laynorm'] 280 | self.fc_use_laynorm_inp = options['fc_use_laynorm_inp'] 281 | self.fc_use_batchnorm_inp = options['fc_use_batchnorm_inp'] 282 | self.fc_act = options['fc_act'] 283 | 284 | self.wx = nn.ModuleList([]) 285 | self.bn = nn.ModuleList([]) 286 | self.ln = nn.ModuleList([]) 287 | self.act = nn.ModuleList([]) 288 | self.drop = nn.ModuleList([]) 289 | 290 | # input layer normalization 291 | if self.fc_use_laynorm_inp: 292 | self.ln0 = LayerNorm(self.input_dim) 293 | 294 | # input batch normalization 295 | if self.fc_use_batchnorm_inp: 296 | self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05) 297 | 298 | self.N_fc_lay = len(self.fc_lay) 299 | 300 | current_input = self.input_dim 301 | 302 | # Initialization of hidden layers 303 | 304 | for i in range(self.N_fc_lay): 305 | 306 | # dropout 307 | self.drop.append(nn.Dropout(p=self.fc_drop[i])) 308 | 309 | # activation 310 | self.act.append(act_fun(self.fc_act[i])) 311 | 312 | add_bias = True 313 | 314 | # layer norm initialization 315 | self.ln.append(LayerNorm(self.fc_lay[i])) 316 | self.bn.append(nn.BatchNorm1d(self.fc_lay[i], momentum=0.05)) 317 | 318 | if self.fc_use_laynorm[i] or self.fc_use_batchnorm[i]: 319 | add_bias = False 320 | 321 | # Linear operations 322 | self.wx.append( 323 | nn.Linear(current_input, self.fc_lay[i], bias=add_bias)) 324 | 325 | # weight initialization 326 | self.wx[i].weight = torch.nn.Parameter(torch.Tensor(self.fc_lay[i], current_input).uniform_( 327 | -np.sqrt(0.01/(current_input+self.fc_lay[i])), np.sqrt(0.01/(current_input+self.fc_lay[i])))) 328 | self.wx[i].bias = torch.nn.Parameter(torch.zeros(self.fc_lay[i])) 329 | 330 | current_input = self.fc_lay[i] 331 | 332 | def forward(self, x): 333 | 334 | # Applying Layer/Batch Norm 335 | if bool(self.fc_use_laynorm_inp): 336 | x = self.ln0((x)) 337 | 338 | if bool(self.fc_use_batchnorm_inp): 339 | x = self.bn0((x)) 340 | 341 | for i in range(self.N_fc_lay): 342 | 343 | if self.fc_act[i] != 'linear': 344 | 345 | if 
self.fc_use_laynorm[i]: 346 | x = self.drop[i](self.act[i](self.ln[i](self.wx[i](x)))) 347 | 348 | if self.fc_use_batchnorm[i]: 349 | x = self.drop[i](self.act[i](self.bn[i](self.wx[i](x)))) 350 | 351 | if self.fc_use_batchnorm[i] == False and self.fc_use_laynorm[i] == False: 352 | x = self.drop[i](self.act[i](self.wx[i](x))) 353 | 354 | else: 355 | if self.fc_use_laynorm[i]: 356 | x = self.drop[i](self.ln[i](self.wx[i](x))) 357 | 358 | if self.fc_use_batchnorm[i]: 359 | x = self.drop[i](self.bn[i](self.wx[i](x))) 360 | 361 | if self.fc_use_batchnorm[i] == False and self.fc_use_laynorm[i] == False: 362 | x = self.drop[i](self.wx[i](x)) 363 | 364 | return x 365 | 366 | 367 | class SincNet(nn.Module): 368 | 369 | def __init__(self, options): 370 | super(SincNet, self).__init__() 371 | 372 | self.cnn_N_filt = options['cnn_N_filt'] 373 | self.cnn_len_filt = options['cnn_len_filt'] 374 | self.cnn_max_pool_len = options['cnn_max_pool_len'] 375 | 376 | self.cnn_act = options['cnn_act'] 377 | self.cnn_drop = options['cnn_drop'] 378 | 379 | self.cnn_use_laynorm = options['cnn_use_laynorm'] 380 | self.cnn_use_batchnorm = options['cnn_use_batchnorm'] 381 | self.cnn_use_laynorm_inp = options['cnn_use_laynorm_inp'] 382 | self.cnn_use_batchnorm_inp = options['cnn_use_batchnorm_inp'] 383 | 384 | self.input_dim = int(options['input_dim']) 385 | 386 | self.fs = options['fs'] 387 | 388 | self.N_cnn_lay = len(options['cnn_N_filt']) 389 | self.conv = nn.ModuleList([]) 390 | self.bn = nn.ModuleList([]) 391 | self.ln = nn.ModuleList([]) 392 | self.act = nn.ModuleList([]) 393 | self.drop = nn.ModuleList([]) 394 | self.use_sinc = options['use_sinc'] 395 | 396 | if self.cnn_use_laynorm_inp: 397 | self.ln0 = LayerNorm(self.input_dim) 398 | 399 | if self.cnn_use_batchnorm_inp: 400 | self.bn0 = nn.BatchNorm1d([self.input_dim], momentum=0.05) 401 | 402 | current_input = self.input_dim 403 | 404 | for i in range(self.N_cnn_lay): 405 | 406 | N_filt = int(self.cnn_N_filt[i]) 407 | len_filt = int(self.cnn_len_filt[i]) 408 | 409 | # dropout 410 | self.drop.append(nn.Dropout(p=self.cnn_drop[i])) 411 | 412 | # activation 413 | self.act.append(act_fun(self.cnn_act[i])) 414 | 415 | # layer norm initialization 416 | #self.ln.append(LayerNorm([N_filt,int((current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i])])) 417 | 418 | self.bn.append(nn.BatchNorm1d(N_filt, momentum=0.05)) 419 | 420 | if i == 0: 421 | if self.use_sinc: 422 | self.conv.append(SincConv_fast( 423 | self.cnn_N_filt[0], self.cnn_len_filt[0], self.fs)) 424 | else: 425 | self.conv.append( 426 | nn.Conv1d(1, self.cnn_N_filt[i], self.cnn_len_filt[i])) 427 | 428 | else: 429 | self.conv.append( 430 | nn.Conv1d(self.cnn_N_filt[i-1], self.cnn_N_filt[i], self.cnn_len_filt[i])) 431 | 432 | current_input = int( 433 | (current_input-self.cnn_len_filt[i]+1)/self.cnn_max_pool_len[i]) 434 | 435 | self.out_dim = current_input*N_filt 436 | 437 | def forward(self, x): 438 | batch = x.shape[0] 439 | seq_len = x.shape[-1] 440 | 441 | if bool(self.cnn_use_laynorm_inp): 442 | x = self.ln0((x)) 443 | 444 | if bool(self.cnn_use_batchnorm_inp): 445 | x = self.bn0((x)) 446 | 447 | x = x.view(batch, 1, seq_len) 448 | 449 | for i in range(self.N_cnn_lay): 450 | 451 | s = x.shape[2] 452 | padding = get_pad( 453 | size=s, kernel_size=self.cnn_len_filt[i], stride=1, dilation=1) 454 | x = F.pad(x, pad=padding, mode='circular') 455 | 456 | if self.cnn_use_laynorm[i]: 457 | if i == 0: 458 | x = self.drop[i](self.act[i](self.ln[i](F.max_pool1d( 459 | torch.abs(self.conv[i](x)), 
self.cnn_max_pool_len[i])))) 460 | else: 461 | x = self.drop[i](self.act[i](self.ln[i]( 462 | F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))) 463 | 464 | if self.cnn_use_batchnorm[i]: 465 | x = self.drop[i](self.act[i](self.bn[i]( 466 | F.max_pool1d(self.conv[i](x), self.cnn_max_pool_len[i])))) 467 | 468 | if self.cnn_use_batchnorm[i] == False and self.cnn_use_laynorm[i] == False: 469 | x = self.drop[i](self.act[i](F.max_pool1d( 470 | self.conv[i](x), self.cnn_max_pool_len[i]))) 471 | 472 | #x = x.view(batch,-1) 473 | 474 | return x 475 | -------------------------------------------------------------------------------- /experiments/ngccphat/cfg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Training room simulation parameters 4 | # room dimensions in meters 5 | dx_train = 7.0 6 | dy_train = 5.0 7 | dz_train = 3.0 8 | room_dim_train = [dx_train, dy_train, dz_train] 9 | xyz_min_train = [0.0, 0.0, 0.0] 10 | xyz_max_train = room_dim_train 11 | 12 | # microphone locations 13 | mic_locs_train = np.array([[3.5, 2.25, 1.5], [3.5, 2.75, 1.5]]).T 14 | 15 | # Test room parameters 16 | dx_test = 6.0 17 | dy_test = 4.0 18 | dz_test = 2.5 19 | room_dim_test = [dx_test, dy_test, dz_test] 20 | xyz_min_test = [0.0, 0.0, 0.0] 21 | xyz_max_test = room_dim_test 22 | 23 | mic_locs_test = np.array([[3.0, 1.75, 1.25], [3.0, 2.25, 1.25]]).T 24 | 25 | # Testing environment configuration 26 | # the model will be evaluated for all SNRs and T60s in the lists 27 | snr_range = [0, 6, 12, 18, 24, 30] 28 | t60_range = [0.2, 0.4, 0.6] 29 | 30 | # accuracy threshold in cm 31 | t_cm = 10 32 | 33 | # Training hyperparams 34 | seed = 0 35 | batch_size = 32 36 | epochs = 30 37 | lr = 0.001 # learning rate 38 | wd = 0.0 # weight decay 39 | ls = 0.0 # label smoothing 40 | 41 | # Model parameters 42 | model = 'NGCCPHAT' # choices: NGCCPHAT, PGCCPHAT 43 | max_delay = 23 44 | num_channels = 128 # number of channels in final layer of NGCCPHAT backbone 45 | head = 'classifier' # final layer type. 
Choices: 'classifier', 'regression' 46 | loss = 'ce' # use 'ce' loss for classifier and 'mse' loss for regression 47 | # Set to true in order to replace Sinc filters with regular convolutional layers 48 | no_sinc = False 49 | 50 | # training environment 51 | snr = [0, 30] # during training, snr will be drawn uniformly from this interval 52 | t60 = [0.2, 1.0] # during training, t60 will be drawn uniformly from this interval 53 | fs = 16000 # sampling rate 54 | sig_len = 2048 # length of snippet used for tdoa estimation 55 | anechoic = False # set to True to use anechoic environment without reverberation 56 | 57 | # threshold in samples 58 | t = t_cm * fs / (343 * 100) 59 | -------------------------------------------------------------------------------- /experiments/ngccphat/eval.txt: -------------------------------------------------------------------------------- 1 | Namespace(evaluate=True, exp_name='ngccphat') 2 | SNR: 0, T60: 0.200000, loss: 2.345504, RMSE: 8.276028, GCC RMSE: 10.105556, MAE: 4.339522, GCC MAE: 5.684419, ACC: 0.722531, GCC ACC: 0.658642 3 | 4 | SNR: 0, T60: 0.400000, loss: 3.078989, RMSE: 12.271958, GCC RMSE: 13.830562, MAE: 7.947325, GCC MAE: 9.344838, ACC: 0.520988, GCC ACC: 0.466667 5 | 6 | SNR: 0, T60: 0.600000, loss: 3.351296, RMSE: 13.627183, GCC RMSE: 15.206236, MAE: 9.400230, GCC MAE: 10.932364, ACC: 0.437654, GCC ACC: 0.383642 7 | 8 | SNR: 6, T60: 0.200000, loss: 1.714704, RMSE: 6.226385, GCC RMSE: 7.733721, MAE: 2.770981, GCC MAE: 3.718708, ACC: 0.825926, GCC ACC: 0.769753 9 | 10 | SNR: 6, T60: 0.400000, loss: 2.633392, RMSE: 10.736298, GCC RMSE: 12.238252, MAE: 6.431622, GCC MAE: 7.740408, ACC: 0.606481, GCC ACC: 0.540123 11 | 12 | SNR: 6, T60: 0.600000, loss: 2.999080, RMSE: 12.347720, GCC RMSE: 13.653988, MAE: 7.989209, GCC MAE: 9.365261, ACC: 0.515741, GCC ACC: 0.445370 13 | 14 | SNR: 12, T60: 0.200000, loss: 1.266595, RMSE: 4.935671, GCC RMSE: 6.300320, MAE: 1.875904, GCC MAE: 2.697714, ACC: 0.888889, GCC ACC: 0.835185 15 | 16 | SNR: 12, T60: 0.400000, loss: 2.277108, RMSE: 9.627654, GCC RMSE: 11.165002, MAE: 5.315495, GCC MAE: 6.743690, ACC: 0.670988, GCC ACC: 0.587654 17 | 18 | SNR: 12, T60: 0.600000, loss: 2.718469, RMSE: 11.530568, GCC RMSE: 12.630047, MAE: 7.103642, GCC MAE: 8.401045, ACC: 0.562963, GCC ACC: 0.484568 19 | 20 | SNR: 18, T60: 0.200000, loss: 0.972733, RMSE: 4.472382, GCC RMSE: 5.698935, MAE: 1.517281, GCC MAE: 2.281355, ACC: 0.916049, GCC ACC: 0.861111 21 | 22 | SNR: 18, T60: 0.400000, loss: 2.008470, RMSE: 8.843968, GCC RMSE: 10.577561, MAE: 4.654292, GCC MAE: 6.219817, ACC: 0.708333, GCC ACC: 0.619444 23 | 24 | SNR: 18, T60: 0.600000, loss: 2.517107, RMSE: 11.143929, GCC RMSE: 12.068566, MAE: 6.645403, GCC MAE: 7.850311, ACC: 0.591358, GCC ACC: 0.512963 25 | 26 | SNR: 24, T60: 0.200000, loss: 0.824509, RMSE: 3.997636, GCC RMSE: 5.553971, MAE: 1.280917, GCC MAE: 2.189512, ACC: 0.930247, GCC ACC: 0.860185 27 | 28 | SNR: 24, T60: 0.400000, loss: 1.862640, RMSE: 8.658301, GCC RMSE: 10.035186, MAE: 4.416655, GCC MAE: 5.917069, ACC: 0.726235, GCC ACC: 0.618210 29 | 30 | SNR: 24, T60: 0.600000, loss: 2.411842, RMSE: 10.701366, GCC RMSE: 11.480780, MAE: 6.321869, GCC MAE: 7.507458, ACC: 0.601852, GCC ACC: 0.513580 31 | 32 | SNR: 30, T60: 0.200000, loss: 0.806157, RMSE: 3.748568, GCC RMSE: 6.117130, MAE: 1.217734, GCC MAE: 2.570721, ACC: 0.931790, GCC ACC: 0.829321 33 | 34 | SNR: 30, T60: 0.400000, loss: 1.853323, RMSE: 8.552352, GCC RMSE: 9.858856, MAE: 4.382690, GCC MAE: 5.980089, ACC: 0.722840, GCC ACC: 0.601543 35 | 36 | SNR: 30, T60: 
0.600000, loss: 2.408468, RMSE: 10.564530, GCC RMSE: 11.238611, MAE: 6.247636, GCC MAE: 7.650620, ACC: 0.603395, GCC ACC: 0.489198 37 | 38 | -------------------------------------------------------------------------------- /experiments/ngccphat/eval_anechoic.txt: -------------------------------------------------------------------------------- 1 | Namespace(evaluate=True, exp_name='ngccphat') 2 | SNR: 0, T60: 0.000000, loss: 1.525557, RMSE: 3.846320, GCC RMSE: 5.904639, MAE: 1.404341, GCC MAE: 2.140092, ACC: 0.925926, GCC ACC: 0.891667 3 | 4 | SNR: 6, T60: 0.000000, loss: 1.005062, RMSE: 2.656893, GCC RMSE: 4.455696, MAE: 0.815528, GCC MAE: 1.192619, ACC: 0.966049, GCC ACC: 0.952161 5 | 6 | SNR: 12, T60: 0.000000, loss: 0.768119, RMSE: 2.376456, GCC RMSE: 4.016468, MAE: 0.658064, GCC MAE: 0.914014, ACC: 0.973765, GCC ACC: 0.969753 7 | 8 | SNR: 18, T60: 0.000000, loss: 0.634981, RMSE: 2.160124, GCC RMSE: 3.897248, MAE: 0.568244, GCC MAE: 0.825095, ACC: 0.976852, GCC ACC: 0.972840 9 | 10 | SNR: 24, T60: 0.000000, loss: 0.572766, RMSE: 2.160010, GCC RMSE: 3.989211, MAE: 0.552147, GCC MAE: 0.856189, ACC: 0.977469, GCC ACC: 0.968827 11 | 12 | SNR: 30, T60: 0.000000, loss: 0.558504, RMSE: 2.099877, GCC RMSE: 3.970475, MAE: 0.553722, GCC MAE: 0.866308, ACC: 0.975617, GCC ACC: 0.966049 13 | 14 | -------------------------------------------------------------------------------- /experiments/ngccphat/evaluations.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axeber01/ngcc/2210bf55395669c1c39d49f65f286fc9a1523fc9/experiments/ngccphat/evaluations.npz -------------------------------------------------------------------------------- /experiments/ngccphat/evaluations_anechoic.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axeber01/ngcc/2210bf55395669c1c39d49f65f286fc9a1523fc9/experiments/ngccphat/evaluations_anechoic.npz -------------------------------------------------------------------------------- /experiments/ngccphat/log.txt: -------------------------------------------------------------------------------- 1 | Train epoch 0, loss: 2.940215, MAE: 7.181322, GCC-MAE: 7.486875, ACC: 0.559197, GCC-ACC: 0.544926 2 | 3 | Val epoch 0, loss: 2.884162, MAE: 6.889122, GCC MAE: 9.349649, ACC: 0.558511, GCC ACC: 0.526596 4 | 5 | Train epoch 1, loss: 2.787211, MAE: 6.788748, GCC-MAE: 7.788082, ACC: 0.570296, GCC-ACC: 0.530655 6 | 7 | Val epoch 1, loss: 2.742727, MAE: 6.311578, GCC MAE: 7.847068, ACC: 0.606383, GCC ACC: 0.526596 8 | 9 | Train epoch 2, loss: 2.786114, MAE: 6.864980, GCC-MAE: 7.763416, ACC: 0.576110, GCC-ACC: 0.542812 10 | 11 | Val epoch 2, loss: 2.783464, MAE: 6.974626, GCC MAE: 8.069162, ACC: 0.579787, GCC ACC: 0.547872 12 | 13 | Train epoch 3, loss: 2.728366, MAE: 6.782053, GCC-MAE: 7.530729, ACC: 0.594080, GCC-ACC: 0.545983 14 | 15 | Val epoch 3, loss: 2.793950, MAE: 7.060416, GCC MAE: 7.868888, ACC: 0.595745, GCC ACC: 0.515957 16 | 17 | Train epoch 4, loss: 2.702268, MAE: 6.665494, GCC-MAE: 7.758763, ACC: 0.593552, GCC-ACC: 0.538055 18 | 19 | Val epoch 4, loss: 2.785724, MAE: 6.022076, GCC MAE: 7.304094, ACC: 0.622340, GCC ACC: 0.558511 20 | 21 | Train epoch 5, loss: 2.686963, MAE: 6.533638, GCC-MAE: 7.522538, ACC: 0.600951, GCC-ACC: 0.542812 22 | 23 | Val epoch 5, loss: 2.883787, MAE: 6.642621, GCC MAE: 7.626397, ACC: 0.601064, GCC ACC: 0.585106 24 | 25 | Train epoch 6, loss: 2.717674, MAE: 6.909814, GCC-MAE: 8.070430, ACC: 0.580867, GCC-ACC: 0.519027 26 | 
27 | Val epoch 6, loss: 2.704891, MAE: 6.972696, GCC MAE: 8.151320, ACC: 0.585106, GCC ACC: 0.569149 28 | 29 | Train epoch 7, loss: 2.644675, MAE: 7.133255, GCC-MAE: 7.937805, ACC: 0.599894, GCC-ACC: 0.540698 30 | 31 | Val epoch 7, loss: 2.541397, MAE: 6.568586, GCC MAE: 8.323135, ACC: 0.627660, GCC ACC: 0.547872 32 | 33 | Train epoch 8, loss: 2.643423, MAE: 6.726465, GCC-MAE: 7.733125, ACC: 0.608351, GCC-ACC: 0.537526 34 | 35 | Val epoch 8, loss: 2.526285, MAE: 5.720998, GCC MAE: 7.081119, ACC: 0.622340, GCC ACC: 0.601064 36 | 37 | Train epoch 9, loss: 2.674665, MAE: 6.774623, GCC-MAE: 7.598134, ACC: 0.613108, GCC-ACC: 0.544926 38 | 39 | Val epoch 9, loss: 2.556621, MAE: 6.677650, GCC MAE: 7.194186, ACC: 0.595745, GCC ACC: 0.611702 40 | 41 | Train epoch 10, loss: 2.616363, MAE: 6.510004, GCC-MAE: 7.530625, ACC: 0.613108, GCC-ACC: 0.544397 42 | 43 | Val epoch 10, loss: 2.607125, MAE: 6.551370, GCC MAE: 7.197229, ACC: 0.622340, GCC ACC: 0.622340 44 | 45 | Train epoch 11, loss: 2.665833, MAE: 7.170951, GCC-MAE: 7.672150, ACC: 0.592495, GCC-ACC: 0.555497 46 | 47 | Val epoch 11, loss: 2.852714, MAE: 7.226460, GCC MAE: 6.524045, ACC: 0.558511, GCC ACC: 0.590425 48 | 49 | Train epoch 12, loss: 2.574891, MAE: 6.466700, GCC-MAE: 7.532702, ACC: 0.618393, GCC-ACC: 0.545983 50 | 51 | Val epoch 12, loss: 2.669662, MAE: 7.141157, GCC MAE: 7.917722, ACC: 0.574468, GCC ACC: 0.569149 52 | 53 | Train epoch 13, loss: 2.553060, MAE: 6.785158, GCC-MAE: 7.660157, ACC: 0.604123, GCC-ACC: 0.547569 54 | 55 | Val epoch 13, loss: 2.649580, MAE: 8.002169, GCC MAE: 8.145764, ACC: 0.542553, GCC ACC: 0.574468 56 | 57 | Train epoch 14, loss: 2.568930, MAE: 6.318080, GCC-MAE: 7.738595, ACC: 0.624207, GCC-ACC: 0.554440 58 | 59 | Val epoch 14, loss: 2.578147, MAE: 5.910570, GCC MAE: 8.069002, ACC: 0.659574, GCC ACC: 0.574468 60 | 61 | Train epoch 15, loss: 2.529045, MAE: 6.825778, GCC-MAE: 7.732957, ACC: 0.624207, GCC-ACC: 0.547569 62 | 63 | Val epoch 15, loss: 2.596268, MAE: 6.740964, GCC MAE: 8.441404, ACC: 0.585106, GCC ACC: 0.537234 64 | 65 | Train epoch 16, loss: 2.598708, MAE: 6.985491, GCC-MAE: 7.621158, ACC: 0.585095, GCC-ACC: 0.531712 66 | 67 | Val epoch 16, loss: 2.582580, MAE: 6.616679, GCC MAE: 7.342334, ACC: 0.595745, GCC ACC: 0.531915 68 | 69 | Train epoch 17, loss: 2.568331, MAE: 6.851851, GCC-MAE: 7.666514, ACC: 0.602537, GCC-ACC: 0.538055 70 | 71 | Val epoch 17, loss: 2.678940, MAE: 6.694512, GCC MAE: 6.874386, ACC: 0.617021, GCC ACC: 0.590425 72 | 73 | Train epoch 18, loss: 2.519191, MAE: 6.452097, GCC-MAE: 7.753897, ACC: 0.626850, GCC-ACC: 0.542812 74 | 75 | Val epoch 18, loss: 2.515379, MAE: 5.877830, GCC MAE: 7.439689, ACC: 0.627660, GCC ACC: 0.531915 76 | 77 | Train epoch 19, loss: 2.493746, MAE: 6.489126, GCC-MAE: 7.925293, ACC: 0.625264, GCC-ACC: 0.529598 78 | 79 | Val epoch 19, loss: 2.401553, MAE: 5.710040, GCC MAE: 6.816219, ACC: 0.670213, GCC ACC: 0.606383 80 | 81 | Train epoch 20, loss: 2.472274, MAE: 6.792701, GCC-MAE: 7.641991, ACC: 0.619979, GCC-ACC: 0.555497 82 | 83 | Val epoch 20, loss: 2.578272, MAE: 6.044822, GCC MAE: 7.669500, ACC: 0.627660, GCC ACC: 0.494681 84 | 85 | Train epoch 21, loss: 2.523828, MAE: 6.601887, GCC-MAE: 7.756747, ACC: 0.632664, GCC-ACC: 0.533298 86 | 87 | Val epoch 21, loss: 2.653243, MAE: 6.380878, GCC MAE: 7.714487, ACC: 0.595745, GCC ACC: 0.537234 88 | 89 | Train epoch 22, loss: 2.487474, MAE: 6.303914, GCC-MAE: 7.985081, ACC: 0.634249, GCC-ACC: 0.539641 90 | 91 | Val epoch 22, loss: 2.572427, MAE: 6.981922, GCC MAE: 6.552242, ACC: 0.585106, GCC ACC: 0.601064 92 
| 93 | Train epoch 23, loss: 2.476933, MAE: 6.151340, GCC-MAE: 7.636378, ACC: 0.636364, GCC-ACC: 0.545983 94 | 95 | Val epoch 23, loss: 2.580400, MAE: 6.973416, GCC MAE: 8.839357, ACC: 0.569149, GCC ACC: 0.526596 96 | 97 | Train epoch 24, loss: 2.503137, MAE: 6.486834, GCC-MAE: 7.923493, ACC: 0.630550, GCC-ACC: 0.542283 98 | 99 | Val epoch 24, loss: 2.371652, MAE: 4.948560, GCC MAE: 6.409157, ACC: 0.680851, GCC ACC: 0.611702 100 | 101 | Train epoch 25, loss: 2.476206, MAE: 6.343758, GCC-MAE: 7.475555, ACC: 0.618922, GCC-ACC: 0.547040 102 | 103 | Val epoch 25, loss: 2.491279, MAE: 6.011263, GCC MAE: 7.586675, ACC: 0.611702, GCC ACC: 0.585106 104 | 105 | Train epoch 26, loss: 2.426955, MAE: 6.387320, GCC-MAE: 7.977196, ACC: 0.629493, GCC-ACC: 0.539112 106 | 107 | Val epoch 26, loss: 2.510316, MAE: 6.783352, GCC MAE: 7.655274, ACC: 0.579787, GCC ACC: 0.531915 108 | 109 | Train epoch 27, loss: 2.424105, MAE: 6.232253, GCC-MAE: 7.980004, ACC: 0.642706, GCC-ACC: 0.544926 110 | 111 | Val epoch 27, loss: 2.541463, MAE: 5.960946, GCC MAE: 7.405776, ACC: 0.606383, GCC ACC: 0.563830 112 | 113 | Train epoch 28, loss: 2.427154, MAE: 6.513635, GCC-MAE: 7.655016, ACC: 0.628964, GCC-ACC: 0.535941 114 | 115 | Val epoch 28, loss: 2.424474, MAE: 6.830526, GCC MAE: 7.450115, ACC: 0.627660, GCC ACC: 0.579787 116 | 117 | Train epoch 29, loss: 2.529005, MAE: 6.522229, GCC-MAE: 7.772382, ACC: 0.623150, GCC-ACC: 0.533827 118 | 119 | Val epoch 29, loss: 2.423645, MAE: 6.273998, GCC MAE: 7.093273, ACC: 0.601064, GCC ACC: 0.590425 120 | 121 | -------------------------------------------------------------------------------- /experiments/ngccphat/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/axeber01/ngcc/2210bf55395669c1c39d49f65f286fc9a1523fc9/experiments/ngccphat/model.pth -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LabelSmoothing(nn.Module): 6 | """NLL loss with label smoothing. 7 | """ 8 | 9 | def __init__(self, smoothing=0.0): 10 | """Constructor for the LabelSmoothing module. 
11 | :param smoothing: label smoothing factor 12 | """ 13 | super(LabelSmoothing, self).__init__() 14 | self.confidence = 1.0 - smoothing 15 | self.smoothing = smoothing 16 | 17 | def forward(self, x, target): 18 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 19 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 20 | nll_loss = nll_loss.squeeze(1) 21 | smooth_loss = -logprobs.mean(dim=-1) 22 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 23 | return loss.mean() 24 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torchaudio 5 | import torch.utils.data as data_utils 6 | from tqdm import tqdm 7 | import numpy as np 8 | import random 9 | from torchinfo import summary 10 | import argparse 11 | import os 12 | 13 | from model import NGCCPHAT, PGCCPHAT, GCC 14 | from data import LibriSpeechLocations, DelaySimulator, one_random_delay, remove_silence 15 | from helpers import LabelSmoothing 16 | import cfg 17 | 18 | # Librispeech dataset constants 19 | DATA_LEN = 2620 20 | VAL_IDS = [260, 672, 908] # use these speaker ids for validation 21 | TEST_IDS = [61, 121, 237] # use these speaker ids for testing 22 | NUM_TEST_WINS = 15 23 | MIN_SIG_LEN = 2 # only use snippets longer than 2 seconds 24 | 25 | parser = argparse.ArgumentParser( 26 | description='Time Difference of Arrival Training') 27 | parser.add_argument('--exp_name', type=str, 28 | default='tdoa_exp', help='Name of the experiment') 29 | parser.add_argument('--evaluate', action='store_true', 30 | help='Set to true in order to evaluate the model across a range of SNRs and T60s') 31 | args = parser.parse_args() 32 | 33 | if not os.path.exists('experiments'): 34 | os.makedirs('experiments') 35 | if not os.path.exists('experiments/'+args.exp_name): 36 | os.makedirs('experiments/'+args.exp_name) 37 | 38 | if not args.evaluate: 39 | LOG_DIR = os.path.join('experiments/'+args.exp_name+'/') 40 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log.txt'), 'w') 41 | os.system('cp cfg.py experiments/' + args.exp_name + '/cfg.py') 42 | 43 | 44 | def log_string(out_str): 45 | LOG_FOUT.write(out_str+'\n') 46 | LOG_FOUT.flush() 47 | print(out_str) 48 | 49 | 50 | # for reproducibility 51 | torch.manual_seed(cfg.seed) 52 | random.seed(cfg.seed) 53 | np.random.seed(cfg.seed) 54 | 55 | # calculate the max_delay for gcc 56 | max_tau_gcc = int(np.floor(np.linalg.norm( 57 | cfg.mic_locs_train[:, 0] - cfg.mic_locs_train[:, 1]) * cfg.fs / 343)) 58 | 59 | # training parameters 60 | max_tau = cfg.max_delay 61 | snr = cfg.snr 62 | t60 = cfg.t60 63 | fs = cfg.fs 64 | sig_len = cfg.sig_len 65 | epochs = cfg.epochs 66 | batch_size = cfg.batch_size 67 | lr = cfg.lr 68 | wd = cfg.wd 69 | label_smooth = cfg.ls 70 | 71 | source_locs_train = np.random.uniform( 72 | low=cfg.xyz_min_train, high=cfg.xyz_max_train, size=(DATA_LEN, 3)) 73 | source_locs_val = np.random.uniform( 74 | low=cfg.xyz_min_train, high=cfg.xyz_max_train, size=(DATA_LEN, 3)) 75 | source_locs_test = np.random.uniform( 76 | low=cfg.xyz_min_test, high=cfg.xyz_max_test, size=(DATA_LEN, 3)) 77 | 78 | # fetch audio snippets within the range of [0, 2] seconds during training 79 | lower_bound = 0 80 | upper_bound = fs * MIN_SIG_LEN 81 | 82 | # create datasets 83 | train_set = LibriSpeechLocations(source_locs_train, split="test-clean") 84 | print('Total data set size: ' + 
str(len(train_set))) 85 | 86 | # remove silence and keep only waveforms longer than MIN_SIG_LEN seconds 87 | valid_idx = [i if len(remove_silence(waveform, frame_length=sig_len)) 88 | > fs * MIN_SIG_LEN else None for i, ((waveform, sample_rate, 89 | transcript, speaker_id, utterance_number), pos, seed) 90 | in enumerate(train_set)] 91 | inds = [i for i in valid_idx if i is not None] 92 | train_set = torch.utils.data.dataset.Subset(train_set, inds) 93 | print('Total data set size after removing silence: ' + str(len(train_set))) 94 | 95 | # create val and test split based on speaker ids 96 | val_set = LibriSpeechLocations(source_locs_val, split="test-clean") 97 | test_set = LibriSpeechLocations(source_locs_test, split="test-clean") 98 | 99 | indices_test = [i for i, ((waveform, sample_rate, transcript, speaker_id, utterance_number), pos, seed) 100 | in enumerate(train_set) if speaker_id in TEST_IDS] 101 | indices_val = [i for i, ((waveform, sample_rate, transcript, speaker_id, utterance_number), pos, seed) 102 | in enumerate(train_set) if speaker_id in VAL_IDS] 103 | indices_train = [i for i, ((waveform, sample_rate, transcript, speaker_id, utterance_number), pos, seed) 104 | in enumerate(train_set) if speaker_id not in TEST_IDS and speaker_id not in VAL_IDS] 105 | 106 | train_set = data_utils.Subset(train_set, indices_train) 107 | val_set = data_utils.Subset(val_set, indices_val) 108 | test_set = data_utils.Subset(test_set, indices_test) 109 | 110 | train_len = len(train_set) 111 | val_len = len(val_set) 112 | test_len = len(test_set) 113 | 114 | print('Training data size after removing silence: ' + str(train_len)) 115 | print('Validation data size after removing silence: ' + str(val_len)) 116 | print('Test data size after removing silence: ' + str(test_len)) 117 | 118 | (waveform, sample_rate, transcript, speaker_id, 119 | utterance_number), pos, seed = train_set[0] 120 | transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=fs) 121 | transformed = transform(waveform) 122 | 123 | # get delay statistics for normalization when using regression loss 124 | if cfg.loss == "mse": 125 | delays = [] 126 | for i in range(100): 127 | _, x_, delay, _ = one_random_delay(room_dim=cfg.room_dim_train, fs=fs, t60=0., 128 | mic_locs=cfg.mic_locs_train, signal=transformed, 129 | xyz_min=cfg.xyz_min_train, xyz_max=cfg.xyz_max_train, 130 | snr=0, anechoic=True) 131 | delays.append(delay) 132 | 133 | delay_mu = np.mean(delays) 134 | delay_sigma = np.std(delays) 135 | 136 | 137 | # use GPU if available, else CPU 138 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 139 | print("Using device: " + str(device)) 140 | 141 | if device == "cuda": 142 | num_workers = 1 143 | pin_memory = True 144 | else: 145 | num_workers = 0 146 | pin_memory = False 147 | 148 | # load model 149 | if cfg.model == 'NGCCPHAT': 150 | use_sinc = True if not cfg.no_sinc else False 151 | model = NGCCPHAT(max_tau, cfg.head, use_sinc, 152 | sig_len, cfg.num_channels, fs) 153 | elif cfg.model == 'PGCCPHAT': 154 | model = PGCCPHAT(max_tau=max_tau_gcc, head=cfg.head) 155 | else: 156 | raise Exception("Please specify a valid model") 157 | 158 | model = model.to(device) 159 | model.eval() 160 | summary(model, [(1, 1, sig_len), (1, 1, sig_len)]) 161 | 162 | gcc = GCC(max_tau=max_tau_gcc) 163 | 164 | optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) 165 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs) 166 | 167 | if cfg.loss == 'ce': 168 | loss_fn = 
LabelSmoothing(label_smooth) 169 | elif cfg.loss == 'mse': 170 | loss_fn = nn.MSELoss() 171 | else: 172 | raise Exception("Please specify a valid loss function") 173 | 174 | delay_simulator_train = DelaySimulator(cfg.room_dim_train, fs, sig_len, t60, cfg.mic_locs_train, max_tau, 175 | cfg.anechoic, train=True, snr=snr, lower_bound=lower_bound, upper_bound=upper_bound) 176 | delay_simulator_val = DelaySimulator(cfg.room_dim_train, fs, sig_len, t60, cfg.mic_locs_train, max_tau, 177 | cfg.anechoic, train=True, snr=snr, lower_bound=lower_bound, upper_bound=upper_bound) 178 | delay_simulator_test = DelaySimulator(cfg.room_dim_test, fs, sig_len, t60, cfg.mic_locs_test, max_tau, 179 | cfg.anechoic, train=False, snr=snr, lower_bound=lower_bound, upper_bound=upper_bound) 180 | 181 | print('Using loss function: ' + str(loss_fn)) 182 | 183 | train_loader = torch.utils.data.DataLoader( 184 | train_set, 185 | batch_size=batch_size, 186 | shuffle=True, 187 | collate_fn=delay_simulator_train, 188 | num_workers=num_workers, 189 | pin_memory=pin_memory, 190 | ) 191 | val_loader = torch.utils.data.DataLoader( 192 | val_set, 193 | batch_size=batch_size, 194 | shuffle=False, 195 | drop_last=False, 196 | collate_fn=delay_simulator_val, 197 | num_workers=num_workers, 198 | pin_memory=pin_memory, 199 | ) 200 | test_loader = torch.utils.data.DataLoader( 201 | test_set, 202 | batch_size=batch_size, 203 | shuffle=False, 204 | drop_last=False, 205 | collate_fn=delay_simulator_test, 206 | num_workers=num_workers, 207 | pin_memory=pin_memory, 208 | ) 209 | 210 | for e in range(epochs): 211 | if args.evaluate: 212 | break 213 | mae = 0 214 | gcc_mae = 0 215 | acc = 0 216 | gcc_acc = 0 217 | train_loss = 0 218 | logs = {} 219 | model.train() 220 | pbar_update = batch_size 221 | with tqdm(total=len(train_set)) as pbar: 222 | for batch_idx, (x1, x2, delays) in enumerate(train_loader): 223 | bs = x1.shape[0] 224 | 225 | x1 = x1.to(device) 226 | x2 = x2.to(device) 227 | delays = delays.to(device) 228 | y_hat = model(x1, x2) 229 | 230 | cc = gcc(x1.squeeze(), x2.squeeze()) 231 | shift_gcc = torch.argmax(cc, dim=-1) - max_tau_gcc 232 | 233 | if cfg.loss == 'ce': 234 | delays_loss = torch.round(delays).type(torch.LongTensor) 235 | shift = torch.argmax(y_hat, dim=-1) - max_tau 236 | else: 237 | delays_loss = (delays - delay_mu) / delay_sigma 238 | shift = y_hat * delay_sigma + delay_mu - max_tau 239 | 240 | gt = delays - max_tau 241 | mae += torch.sum(torch.abs(shift-gt)) 242 | gcc_mae += torch.sum(torch.abs(shift_gcc-gt)) 243 | 244 | acc += torch.sum(torch.abs(shift-gt) < cfg.t) 245 | gcc_acc += torch.sum(torch.abs(shift_gcc-gt) < cfg.t) 246 | 247 | loss = loss_fn(y_hat, delays_loss.to(device)) 248 | optimizer.zero_grad() 249 | loss.backward() 250 | optimizer.step() 251 | 252 | train_loss += loss.detach().item() * bs 253 | 254 | pbar.update(pbar_update) 255 | 256 | train_loss = train_loss / train_len 257 | mae = mae / train_len 258 | gcc_mae = gcc_mae / train_len 259 | acc = acc / train_len 260 | gcc_acc = gcc_acc / train_len 261 | 262 | outstr = 'Train epoch %d, loss: %.6f, MAE: %.6f, GCC-MAE: %.6f, ACC: %.6f, GCC-ACC: %.6f' % (e, 263 | train_loss, 264 | mae, 265 | gcc_mae, 266 | acc, 267 | gcc_acc) 268 | 269 | log_string(outstr+'\n') 270 | 271 | scheduler.step() 272 | 273 | torch.cuda.empty_cache() 274 | 275 | # Validation 276 | model.eval() 277 | mae = 0. 278 | gcc_mae = 0. 279 | acc = 0. 280 | gcc_acc = 0. 281 | val_loss = 0. 
282 |     with tqdm(total=len(val_set)) as pbar:
283 |         for batch_idx, (x1, x2, delays) in enumerate(val_loader):
284 |             with torch.no_grad():
285 |                 bs = x1.shape[0]
286 |                 x1 = x1.to(device)
287 |                 x2 = x2.to(device)
288 |                 delays = delays.to(device)
289 |                 y_hat = model(x1, x2)
290 | 
291 |                 cc = gcc(x1.squeeze(), x2.squeeze())
292 |                 shift_gcc = torch.argmax(cc, dim=-1) - max_tau_gcc
293 | 
294 |                 if cfg.loss == 'ce':
295 |                     delays_loss = torch.round(delays).type(torch.LongTensor)
296 |                     shift = torch.argmax(y_hat, dim=-1) - max_tau  # decode class index to signed delay
297 |                 else:
298 |                     delays_loss = (delays - delay_mu) / delay_sigma
299 |                     shift = y_hat * delay_sigma + delay_mu - max_tau  # de-normalize the regression output
300 | 
301 |                 gt = delays - max_tau
302 |                 mae += torch.sum(torch.abs(shift-gt))
303 |                 gcc_mae += torch.sum(torch.abs(shift_gcc-gt))
304 | 
305 |                 acc += torch.sum(torch.abs(shift-gt) < cfg.t)
306 |                 gcc_acc += torch.sum(torch.abs(shift_gcc-gt) < cfg.t)
307 | 
308 |                 loss = loss_fn(y_hat, delays_loss.to(device))
309 |                 val_loss += loss.detach().item() * bs
310 | 
311 |                 pbar.update(pbar_update)
312 | 
313 |     mae = mae / val_len
314 |     gcc_mae = gcc_mae / val_len
315 |     acc = acc / val_len
316 |     gcc_acc = gcc_acc / val_len
317 |     val_loss = val_loss / val_len
318 | 
319 |     outstr = 'Val epoch %d, loss: %.6f, MAE: %.6f, GCC MAE: %.6f, ACC: %.6f, GCC ACC: %.6f' % (e,
320 |               val_loss,
321 |               mae,
322 |               gcc_mae,
323 |               acc,
324 |               gcc_acc)
325 |     log_string(outstr+'\n')
326 | 
327 |     torch.cuda.empty_cache()
328 | 
329 | 
330 | # Save the model
331 | if not args.evaluate:
332 |     torch.save(model.state_dict(), 'experiments/'
333 |                + args.exp_name+'/'+'model.pth')
334 |     LOG_FOUT.close()
335 | 
336 | 
337 | if args.evaluate:
338 |     # load pre-trained model and evaluate on each window in the test set, for
339 |     # each SNR and t60 in the list
340 | 
341 |     model.load_state_dict(torch.load(
342 |         "experiments/"+args.exp_name+"/model.pth", map_location=device))
343 | 
344 |     model.eval()
345 | 
346 |     LOG_DIR = os.path.join('experiments/'+args.exp_name+'/')
347 |     if cfg.anechoic:
348 |         name = 'eval_anechoic.txt'
349 |     else:
350 |         name = 'eval.txt'
351 |     LOG_FOUT = open(os.path.join(LOG_DIR, name), 'w')
352 |     LOG_FOUT.write(str(args)+'\n')
353 | 
354 |     if cfg.anechoic:
355 |         t60_range = [0.0]
356 |     else:
357 |         t60_range = cfg.t60_range
358 | 
359 |     ground_truth = np.empty(
360 |         (test_len * NUM_TEST_WINS, len(cfg.snr_range), len(t60_range)))
361 |     preds = np.empty(
362 |         (test_len * NUM_TEST_WINS, len(cfg.snr_range), len(t60_range)))
363 |     preds_gcc = np.empty(
364 |         (test_len * NUM_TEST_WINS, len(cfg.snr_range), len(t60_range)))
365 | 
366 |     for snr_index, this_snr in enumerate(cfg.snr_range):
367 |         for t60_index, this_t60 in enumerate(t60_range):
368 | 
369 |             mse = 0.
370 |             gcc_mse = 0.
371 |             mae = 0.
372 |             gcc_mae = 0.
373 |             acc = 0
374 |             gcc_acc = 0
375 |             test_loss = 0.
376 | 
377 |             start_index = 0
378 |             end_index = 0
379 | 
380 |             pbar_update = batch_size
381 |             with tqdm(total=len(test_set)*NUM_TEST_WINS) as pbar:
382 | 
383 |                 for win in range(NUM_TEST_WINS):
384 |                     delay_simulator_test = DelaySimulator(cfg.room_dim_test, fs, sig_len, [this_t60, this_t60],
385 |                                                           cfg.mic_locs_test, max_tau, cfg.anechoic, False, [this_snr, this_snr], lower_bound=lower_bound+win*sig_len)
386 | 
387 |                     test_loader = torch.utils.data.DataLoader(
388 |                         test_set,
389 |                         batch_size=batch_size,
390 |                         shuffle=False,
391 |                         drop_last=False,
392 |                         collate_fn=delay_simulator_test,
393 |                         num_workers=num_workers,
394 |                         pin_memory=pin_memory,
395 |                     )
396 | 
397 |                     for batch_idx, (x1, x2, delays) in enumerate(test_loader):
398 |                         with torch.no_grad():
399 |                             bs = x1.shape[0]
400 |                             x1 = x1.to(device)
401 |                             x2 = x2.to(device)
402 |                             delays = delays.to(device)
403 |                             y_hat = model(x1, x2)
404 | 
405 |                             cc = gcc(x1.squeeze(), x2.squeeze())
406 |                             shift_gcc = torch.argmax(cc, dim=-1) - max_tau_gcc
407 | 
408 |                             if cfg.loss == 'ce':
409 |                                 delays_loss = torch.round(
410 |                                     delays).type(torch.LongTensor)
411 |                                 shift = torch.argmax(y_hat, dim=-1) - max_tau
412 |                             else:
413 |                                 delays_loss = (delays - delay_mu) / delay_sigma
414 |                                 shift = y_hat * delay_sigma + delay_mu - max_tau
415 | 
416 |                             gt = delays - max_tau
417 |                             mse += torch.sum(torch.abs(shift-gt)**2)
418 |                             gcc_mse += torch.sum(torch.abs(shift_gcc-gt)**2)
419 |                             mae += torch.sum(torch.abs(shift-gt))
420 |                             gcc_mae += torch.sum(torch.abs(shift_gcc-gt))
421 |                             acc += torch.sum(torch.abs(shift-gt) < cfg.t)
422 |                             gcc_acc += torch.sum(torch.abs(shift_gcc-gt)
423 |                                                  < cfg.t)
424 | 
425 |                             end_index = end_index + bs
426 |                             ground_truth[start_index:end_index,
427 |                                          snr_index, t60_index] = gt.cpu().numpy()
428 |                             preds[start_index:end_index, snr_index,
429 |                                   t60_index] = shift.cpu().numpy()
430 |                             preds_gcc[start_index:end_index, snr_index,
431 |                                       t60_index] = shift_gcc.cpu().numpy()
432 |                             start_index = start_index + bs
433 | 
434 |                             loss = loss_fn(y_hat, delays_loss.to(device))
435 |                             test_loss += loss.item() * bs
436 | 
437 |                             pbar.update(pbar_update)
438 | 
439 |             rmse = torch.sqrt(mse / (test_len * NUM_TEST_WINS))
440 |             gcc_rmse = torch.sqrt(gcc_mse / (test_len * NUM_TEST_WINS))
441 |             mae = mae / (test_len * NUM_TEST_WINS)
442 |             gcc_mae = gcc_mae / (test_len * NUM_TEST_WINS)
443 |             acc = acc / (test_len * NUM_TEST_WINS)
444 |             gcc_acc = gcc_acc / (test_len * NUM_TEST_WINS)
445 |             test_loss = test_loss / (test_len * NUM_TEST_WINS)
446 | 
447 |             outstr = 'SNR: %d, T60: %.6f, loss: %.6f, RMSE: %.6f, GCC RMSE: %.6f, MAE: %.6f, GCC MAE: %.6f, ACC: %.6f, GCC ACC: %.6f' % (this_snr,
448 |                       this_t60,
449 |                       test_loss,
450 |                       rmse,
451 |                       gcc_rmse,
452 |                       mae,
453 |                       gcc_mae,
454 |                       acc,
455 |                       gcc_acc)
456 |             log_string(outstr+'\n')
457 | 
458 |             torch.cuda.empty_cache()
459 | 
460 |     # Store all the ground truth delays and predictions
461 |     if cfg.anechoic:
462 |         np.savez('experiments/'+args.exp_name+'/'
463 |                  + 'evaluations_anechoic.npz', ground_truth, preds, preds_gcc)
464 |     else:
465 |         np.savez('experiments/'+args.exp_name+'/'+'evaluations.npz',
466 |                  ground_truth, preds, preds_gcc)
467 | 
468 |     LOG_FOUT.close()
469 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from torch_same_pad import get_pad
6 | from dnn_models import SincNet
7 | import torch.fft
8 | 
9 | class GCC(nn.Module):
10 |     def __init__(self, max_tau=None, dim=2, filt='phat', epsilon=0.001, beta=None):
11 |         super().__init__()
12 | 
13 |         ''' GCC implementation based on Knapp and Carter,
14 |         "The Generalized Correlation Method for Estimation of Time Delay",
15 |         IEEE Trans. Acoust., Speech, Signal Processing, August, 1976 '''
16 | 
17 |         self.max_tau = max_tau
18 |         self.dim = dim
19 |         self.filt = filt
20 |         self.epsilon = epsilon
21 |         self.beta = beta
22 | 
23 |     def forward(self, x, y):
24 | 
25 |         n = x.shape[-1] + y.shape[-1]  # FFT length: zero-pad to avoid circular wrap-around
26 | 
27 |         # Generalized Cross Correlation Phase Transform
28 |         X = torch.fft.rfft(x, n=n)
29 |         Y = torch.fft.rfft(y, n=n)
30 |         Gxy = X * torch.conj(Y)
31 | 
32 |         if self.filt == 'phat':
33 |             phi = 1 / (torch.abs(Gxy) + self.epsilon)
34 | 
35 |         elif self.filt == 'roth':
36 |             phi = 1 / (X * torch.conj(X) + self.epsilon)
37 | 
38 |         elif self.filt == 'scot':
39 |             Gxx = X * torch.conj(X)
40 |             Gyy = Y * torch.conj(Y)
41 |             phi = 1 / (torch.sqrt(Gxx * Gyy) + self.epsilon)
42 | 
43 |         elif self.filt == 'ht':
44 |             Gxx = X * torch.conj(X)
45 |             Gyy = Y * torch.conj(Y)
46 |             gamma = Gxy / torch.sqrt(Gxx * Gyy)  # coherence is defined via Gxx * Gyy
47 |             phi = torch.abs(gamma)**2 / (torch.abs(Gxy)
48 |                                          * (1 - torch.abs(gamma)**2) + self.epsilon)
49 | 
50 |         elif self.filt == 'cc':
51 |             phi = 1.0
52 | 
53 |         else:
54 |             raise ValueError('Unsupported filter function')
55 | 
56 |         if self.beta is not None:
57 |             cc = []
58 |             for i in range(self.beta.shape[0]):
59 |                 cc.append(torch.fft.irfft(
60 |                     Gxy * torch.pow(phi, self.beta[i]), n))
61 | 
62 |             cc = torch.cat(cc, dim=1)
63 | 
64 |         else:
65 |             cc = torch.fft.irfft(Gxy * phi, n)
66 | 
67 |         max_shift = int(n / 2)
68 |         if self.max_tau is not None:
69 |             max_shift = np.minimum(self.max_tau, int(max_shift))
70 | 
71 |         if self.dim == 2:
72 |             cc = torch.cat((cc[:, -max_shift:], cc[:, :max_shift+1]), dim=-1)
73 |         elif self.dim == 3:
74 |             cc = torch.cat(
75 |                 (cc[:, :, -max_shift:], cc[:, :, :max_shift+1]), dim=-1)
76 | 
77 |         return cc
78 | 
79 | 
80 | class NGCCPHAT(nn.Module):
81 |     def __init__(self, max_tau=42, head='classifier', use_sinc=True,
82 |                  sig_len=2048, num_channels=128, fs=16000):
83 |         super().__init__()
84 | 
85 |         '''
86 |         Neural GCC-PHAT with SincNet backbone
87 | 
88 |         arguments:
89 |         max_tau - the maximum possible delay considered
90 |         head - classifier or regression
91 |         use_sinc - use sincnet backbone if True, otherwise use regular conv layers
92 |         sig_len - length of input signal
93 |         num_channels - number of GCC correlation channels to use
94 |         fs - sampling frequency
95 |         '''
96 | 
97 |         self.max_tau = max_tau
98 |         self.head = head
99 | 
100 |         sincnet_params = {'input_dim': sig_len,
101 |                           'fs': fs,
102 |                           'cnn_N_filt': [128, 128, 128, num_channels],
103 |                           'cnn_len_filt': [1023, 11, 9, 7],
104 |                           'cnn_max_pool_len': [1, 1, 1, 1],
105 |                           'cnn_use_laynorm_inp': False,
106 |                           'cnn_use_batchnorm_inp': False,
107 |                           'cnn_use_laynorm': [False, False, False, False],
108 |                           'cnn_use_batchnorm': [True, True, True, True],
109 |                           'cnn_act': ['leaky_relu', 'leaky_relu', 'leaky_relu', 'linear'],
110 |                           'cnn_drop': [0.0, 0.0, 0.0, 0.0],
111 |                           'use_sinc': use_sinc,
112 |                           }
113 | 
114 |         self.backbone = SincNet(sincnet_params)
115 |         self.mlp_kernels = [11, 9, 7]
116 |         self.channels = [num_channels, 128, 128, 128]
117 |         self.final_kernel = [5]
118 | 
119 |         self.gcc = GCC(max_tau=self.max_tau, dim=3, filt='phat')
120 | 
121 |         self.mlp = nn.ModuleList([nn.Sequential(
122 |             nn.Conv1d(self.channels[i], self.channels[i+1], kernel_size=k),
123 |             nn.BatchNorm1d(self.channels[i+1]),
124 |             nn.LeakyReLU(0.2),
125 |             nn.Dropout(0.5)) for i, k in enumerate(self.mlp_kernels)])
126 | 
127 |         self.final_conv = nn.Conv1d(128, 1, kernel_size=self.final_kernel)
128 | 
129 |         if head == 'regression':
130 |             self.reg = nn.Sequential(
131 |                 nn.BatchNorm1d(2 * self.max_tau + 1),
132 |                 nn.LeakyReLU(0.2),
133 |                 nn.Linear(2 * self.max_tau + 1, 1))
134 | 
135 |     def forward(self, x1, x2):
136 | 
137 |         batch_size = x1.shape[0]
138 | 
139 |         y1 = self.backbone(x1)
140 |         y2 = self.backbone(x2)
141 | 
142 |         cc = self.gcc(y1, y2)
143 | 
144 |         for k, layer in enumerate(self.mlp):
145 |             s = cc.shape[2]
146 |             padding = get_pad(
147 |                 size=s, kernel_size=self.mlp_kernels[k], stride=1, dilation=1)
148 |             cc = F.pad(cc, pad=padding, mode='constant')
149 |             cc = layer(cc)
150 | 
151 |         s = cc.shape[2]
152 |         padding = get_pad(
153 |             size=s, kernel_size=self.final_kernel, stride=1, dilation=1)
154 |         cc = F.pad(cc, pad=padding, mode='constant')
155 |         cc = self.final_conv(cc).reshape([batch_size, -1])
156 |         if self.head == 'regression':
157 |             cc = self.reg(cc).squeeze()
158 | 
159 |         return cc
160 | 
161 | 
162 | class PGCCPHAT(nn.Module):
163 |     def __init__(self, beta=np.arange(0, 1.1, 0.1), max_tau=42, head='regression'):
164 |         super().__init__()
165 | 
166 |         '''
167 |         Implementation of CNN-Based Parametrized GCC-PHAT by Salvati et al.
168 |         https://www.isca-speech.org/archive/pdfs/interspeech_2021/salvati21_interspeech.pdf
169 |         '''
170 | 
171 |         self.beta = beta
172 |         self.gcc = GCC(max_tau=max_tau, dim=3, filt='phat', beta=beta)
173 |         self.head = head
174 |         self.max_tau = max_tau
175 | 
176 |         if head == 'regression':
177 |             n_out = 1
178 |         else:
179 |             n_out = 2 * self.max_tau + 1
180 | 
181 |         self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 3))
182 |         self.bn1 = nn.BatchNorm2d(32)
183 |         self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
184 |         self.bn2 = nn.BatchNorm2d(64)
185 |         self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3))
186 |         self.bn3 = nn.BatchNorm2d(128)
187 |         self.conv4 = nn.Conv2d(128, 256, kernel_size=(3, 3))
188 |         self.bn4 = nn.BatchNorm2d(256)
189 |         self.conv5 = nn.Conv2d(256, 512, kernel_size=(3, 3))
190 |         self.bn5 = nn.BatchNorm2d(512)
191 | 
192 |         self.mlp = nn.Sequential(
193 |             nn.Linear(512 * (2 * max_tau + 1 - 10), 512),
194 |             nn.BatchNorm1d(512),
195 |             nn.ReLU(),
196 |             nn.Dropout(0.2),
197 |             nn.Linear(512, 512),
198 |             nn.BatchNorm1d(512),
199 |             nn.ReLU(),
200 |             nn.Dropout(0.2),
201 |             nn.Linear(512, n_out)
202 |         )
203 | 
204 |     def forward(self, x1, x2):
205 | 
206 |         batch_size = x1.shape[0]
207 | 
208 |         x = self.gcc(x1, x2).unsqueeze(1)
209 | 
210 |         x = self.conv1(x)
211 |         x = F.relu(self.bn1(x))
212 | 
213 |         x = self.conv2(x)
214 |         x = F.relu(self.bn2(x))
215 | 
216 |         x = self.conv3(x)
217 |         x = F.relu(self.bn3(x))
218 | 
219 |         x = self.conv4(x)
220 |         x = F.relu(self.bn4(x))
221 | 
222 |         x = self.conv5(x)
223 |         x = F.relu(self.bn5(x))
224 |         x = self.mlp(x.reshape([batch_size, -1])).squeeze()
225 | 
226 |         return x
227 | 
--------------------------------------------------------------------------------
/ngcc.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/axeber01/ngcc/2210bf55395669c1c39d49f65f286fc9a1523fc9/ngcc.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch
3 | torchaudio
4 | torchinfo
5 | pyroomacoustics
6 | tqdm
7 | torch_same_pad
8 | librosa
9 | scipy
10 | 
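For orientation, here is a minimal sketch of how the GCC module from model.py above can be used on its own for time-delay estimation; it assumes model.py is importable from the working directory, and the random signal, the 25-sample shift, and the noise level are illustrative only:

import torch
from model import GCC

max_tau = 42
x = torch.randn(1, 4096)                       # reference channel
y = torch.roll(x, shifts=25, dims=-1)          # copy of x, delayed by 25 samples
y = y + 0.05 * torch.randn_like(y)             # mild additive noise

gcc = GCC(max_tau=max_tau, dim=2, filt='phat')
cc = gcc(x, y)                                 # correlation, shape (1, 2 * max_tau + 1)
tau_hat = torch.argmax(cc, dim=-1) - max_tau   # peak index -> signed delay in samples
print(tau_hat.item())                          # -25: with Gxy = X * conj(Y), a delayed y peaks at a negative lag

This mirrors the decoding in main.py, where the estimate is always argmax(cc) - max_tau; NGCCPHAT is called the same way, except its output is the learned distribution over candidate delays rather than a raw correlation.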
--------------------------------------------------------------------------------
/speech.wav:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/axeber01/ngcc/2210bf55395669c1c39d49f65f286fc9a1523fc9/speech.wav
--------------------------------------------------------------------------------
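Finally, for readers moving on to tdoa_example.ipynb: both GCC-PHAT and NGCC-PHAT return a delay in samples, which must be converted to a path-length difference before any positioning step. A minimal sketch, where the 16 kHz rate matches this repo, the speed of sound is the usual room-temperature assumption, and the delay value is illustrative:

# convert a TDOA estimate from samples to a path-length difference
fs = 16000        # sampling rate used throughout this repo (Hz)
c = 343.0         # speed of sound at room temperature (m/s)
tau_hat = -25     # delay estimate in samples (illustrative)

delta_d = tau_hat / fs * c
print(f'{delta_d:.3f} m')   # -0.536 m: one propagation path is about 0.54 m shorter

Each such difference constrains the source to one sheet of a hyperboloid with the two microphones as foci; intersecting the constraints from several microphone pairs yields the position estimate.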