├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── complex_neural_source_localization ├── __init__.py ├── datasets │ ├── __init__.py │ └── dcase_2019_task3_dataset.py ├── feature_extractors.py ├── loss.py ├── model.py ├── trainer.py └── utils │ ├── __init__.py │ ├── base_trainer.py │ ├── complexPyTorch │ ├── complexFunctions.py │ └── complexLayers.py │ ├── conv_block.py │ ├── load_config.py │ ├── model_utilities.py │ └── model_visualization.py ├── config ├── config.yaml ├── dataset │ └── dcase_2019_task3_dataset.yaml ├── model.yaml └── training.yaml ├── notebooks ├── DOA Visualization.ipynb ├── Model Visualization.ipynb └── Test Model.ipynb ├── requirements.txt ├── tdoa ├── correlation.py ├── math_utils.py ├── metrics.py ├── srp.py ├── tdoa.py └── visualization.py ├── tests ├── .gitignore ├── __init__.py ├── baselines │ ├── __init__.py │ └── test_crnns.py ├── fixtures │ ├── 0.0_split1_ir0_ov1_3.wav │ ├── fold1_room1_mix001.csv │ ├── fold1_room1_mix001.wav │ ├── p225_001.wav │ ├── test_real.pickle │ └── weights.ckpt ├── test_correlation.py ├── test_feature_extractors.py ├── test_loss.py ├── test_math_utils.py ├── test_model.py ├── test_tdoa.py └── test_visualize_convolutional_feature_maps.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | logs*/ 2 | runs/ 3 | .vscode 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | 138 | # pytype static type analyzer 139 | .pytype/ 140 | 141 | # Cython debug symbols 142 | cython_debug/ 143 | *.DS_Store 144 | 145 | generated_dataset/ 146 | validation_dataset/ 147 | training_dataset/ 148 | test_dataset/ 149 | tests/temp 150 | outputs/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Imperial College London 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: install clean lint preprocessing 2 | 3 | test: 4 | @pytest tests/ 5 | 6 | train: 7 | @python train.py 8 | 9 | visualizations: 10 | @python visualizations.py 11 | 12 | install: 13 | @pip install -r requirements.txt 14 | 15 | lint: 16 | @flake8 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Complex neural source localization 2 | This repository contains the code for the paper 3 | "Deep complex-valued convolutional-recurrent networks for single source doa estimation" to be published at the 4 | International Workshop on Acoustic Signal Enhancement (IWAENC) 2022. 5 | 6 | https://hal.science/hal-03779970/document 7 | https://ieeexplore.ieee.org/abstract/document/9914747 8 | 9 | 10 | ## Installation 11 | 12 | To test the code without installing anything, we suggest running it using this [Kaggle notebook](https://www.kaggle.com/code/egrinstein/neural-doa-training-notebook). To install it locally, follow the instructions below. 13 | 14 | 15 | ### Requirements 16 | * Python 3 17 | 18 | run `pip install -r requirements.txt` to install the python libraries needed 19 | 20 | Download the [Kaggle dataset](https://www.kaggle.com/datasets/egrinstein/dcase-2019-single-source) containing the data, and change the file 'config/dcase_2019_task3_dataset.yaml' to point at the correct train, validation and test datasets. 21 | 22 | Then, change the working directory to this project and run `python train.py` or `make train` to start training the model. Every time you start training a model, a folder will be created in the outputs/ 23 | 24 | 25 | ## Unit tests 26 | To execute all unit tests, run either: 27 | 28 | `pytest tests` 29 | or 30 | `make tests` 31 | ` 32 | -------------------------------------------------------------------------------- /complex_neural_source_localization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SOUNDS-RESEARCH/complex_neural_source_localization/d3b4e3ca1ea216f4515abb011f02809183e375c9/complex_neural_source_localization/__init__.py -------------------------------------------------------------------------------- /complex_neural_source_localization/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .dcase_2019_task3_dataset import DCASE2019Task3Dataset 4 | 5 | 6 | def create_dataloaders(config): 7 | dataset_train = DCASE2019Task3Dataset(config["dataset"]) 8 | dataset_val = DCASE2019Task3Dataset(config["dataset"]) 9 | dataset_test = DCASE2019Task3Dataset(config["dataset"]) 10 | 11 | batch_size = config["training"]["batch_size"] 12 | n_workers = config["training"]["n_workers"] 13 | 14 | dataloader_train = _create_torch_dataloader(dataset_train, batch_size, 15 | n_workers, shuffle=True) 16 | dataloader_val = _create_torch_dataloader(dataset_val, batch_size, 17 | n_workers) 18 | dataloader_test = _create_torch_dataloader(dataset_test, batch_size, 19 | n_workers) 20 | 21 | return dataloader_train, dataloader_val, dataloader_test 22 | 23 | 24 | def _create_torch_dataloader(torch_dataset, batch_size, n_workers, shuffle=False): 25 | return torch.utils.data.DataLoader(torch_dataset, 26 | batch_size=batch_size, 27 | shuffle=shuffle, 28 | pin_memory=True, 29 | drop_last=False, 30 | num_workers=n_workers) 31 | -------------------------------------------------------------------------------- /complex_neural_source_localization/datasets/dcase_2019_task3_dataset.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | 6 | from math import isnan 7 | from pathlib import Path 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class DCASE2019Task3Dataset(Dataset): 12 | def __init__(self, dataset_config, mode="train"): 13 | self.config = dataset_config 14 | self.sr = dataset_config["sr"] 15 | self.sample_duration_in_seconds = dataset_config["sample_duration_in_seconds"] 16 | self.num_mics = dataset_config["num_mics"] 17 | self.n_max_sources = dataset_config["n_max_sources"] 18 | 19 | if mode == "train": 20 | annotations_path = dataset_config["train_csv_path"] 21 | self.wav_path = Path(dataset_config["train_wav_path"]) 22 | elif mode == "validation": 23 | annotations_path = dataset_config["validation_csv_path"] 24 | self.wav_path = Path(dataset_config["validation_wav_path"]) 25 | elif mode == "test": 26 | annotations_path = dataset_config["test_csv_path"] 27 | self.wav_path = Path(dataset_config["test_wav_path"]) 28 | else: 29 | raise ValueError(f"Dataset mode {mode} is invalid") 30 | 31 | self.df = pd.read_csv(annotations_path) 32 | 33 | def __getitem__(self, index): 34 | annotation = self.df.iloc[index] 35 | wav_file_path = self.wav_path / (annotation["file_name"] + ".wav") 36 | 37 | signal = load_multichannel_wav(wav_file_path, self.sr, self.sample_duration_in_seconds) 38 | 39 | # TODO: Remove samples from the dataset which are incomplete 40 | 41 | azimuth_in_radians = np.deg2rad(annotation["azi"]) 42 | azimuth_2d_point = _angle_to_point(azimuth_in_radians) 43 | 44 | y = { 45 | "azimuth_2d_point": azimuth_2d_point, 46 | "azimuth_in_radians": torch.Tensor([azimuth_in_radians]) 47 | # "start_time": torch.Tensor([annotation["start_time"]]), 48 | # "end_time": torch.Tensor([annotation["end_time"]]) 49 | } 50 | 51 | # If multi-source 52 | if self.n_max_sources == 2: 53 | # If there is a second source, use it. Else, copy first source 54 | if not isnan(annotation["azi_2"]): 55 | azi = annotation["azi_2"] 56 | else: 57 | azi = annotation["azi"] 58 | azimuth_in_radians_2 = np.deg2rad(azi) 59 | azimuth_2d_point_2 = _angle_to_point(azimuth_in_radians_2) 60 | y["azimuth_2d_point_2"] = azimuth_2d_point_2 61 | 62 | 63 | return (signal, y) 64 | 65 | def __len__(self): 66 | return self.df.shape[0] 67 | 68 | 69 | def load_multichannel_wav(wav_file_path, sr, duration_in_secs): 70 | signal, _ = librosa.load(wav_file_path, sr=sr, 71 | mono=False, dtype=np.float32) 72 | duration_in_samples = int(duration_in_secs*sr) 73 | padded_signal = np.zeros((signal.shape[0], duration_in_samples)) 74 | padded_signal[:, :signal.shape[1]] = signal 75 | return torch.Tensor(padded_signal) 76 | 77 | 78 | 79 | def _angle_to_point(angle): 80 | return torch.Tensor([ 81 | torch.Tensor([np.cos(angle)]), 82 | torch.Tensor([np.sin(angle)]) 83 | ]) 84 | -------------------------------------------------------------------------------- /complex_neural_source_localization/feature_extractors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.nn import Module 4 | from torchaudio.transforms import MelSpectrogram 5 | 6 | 7 | class MfccArray(Module): 8 | def __init__(self, model_config, dataset_config): 9 | 10 | super().__init__() 11 | 12 | self.mel_spectrogram = MelSpectrogram( 13 | sample_rate=dataset_config["sr"], 14 | n_fft=model_config["n_fft"], 15 | n_mels=model_config["n_mels"] 16 | ) 17 | 18 | def forward(self, X): 19 | "Expected input has shape (batch_size, n_arrays, time_steps)" 20 | 21 | result = [] 22 | 23 | n_arrays = X.shape[1] 24 | 25 | for i in range(n_arrays): 26 | x = X[:, i, :] 27 | result.append(self.mel_spectrogram(x)) 28 | 29 | return torch.stack(result, dim=1) 30 | 31 | 32 | class StftArray(Module): 33 | def __init__(self, model_config): 34 | 35 | super().__init__() 36 | 37 | self.n_fft = model_config["n_fft"] 38 | self.onesided = model_config["use_onesided_fft"] 39 | 40 | def forward(self, X): 41 | "Expected input has shape (batch_size, n_arrays, time_steps)" 42 | 43 | result = [] 44 | n_arrays = X.shape[1] 45 | 46 | for i in range(n_arrays): 47 | x = X[:, i, :] 48 | stft_output = torch.stft(x, self.n_fft, onesided=self.onesided, return_complex=True) 49 | result.append( 50 | stft_output[:, 1:, :] 51 | ) # Ignore frequency 0 52 | 53 | result = torch.stack(result, dim=1) 54 | return result 55 | 56 | 57 | class StftMagnitudeArray(StftArray): 58 | def forward(self, X): 59 | stft = super().forward(X) 60 | return stft.abs() 61 | 62 | 63 | class StftPhaseArray(StftArray): 64 | def forward(self, X): 65 | stft = super().forward(X) 66 | return stft.angle() 67 | 68 | 69 | class DecoupledStftArray(StftArray): 70 | "Stft where the real and imaginary channels are modeled as separate channels" 71 | def forward(self, X): 72 | 73 | stft = super().forward(X) 74 | 75 | # stft.real.shape = (batch_size, num_mics, num_channels, time_steps) 76 | result = torch.cat((stft.real, stft.imag), dim=2) 77 | 78 | return result 79 | 80 | 81 | class CrossSpectra(Module): 82 | def __init__(self, model_config): 83 | 84 | super().__init__() 85 | 86 | self.n_fft = model_config["n_fft"] 87 | self.stft_extractor = StftArray(model_config) 88 | 89 | def forward(self, X): 90 | "Expected input has shape (batch_size, n_channels, time_steps)" 91 | batch_size, n_channels, time_steps = X.shape 92 | 93 | stfts = self.stft_extractor(X) 94 | # (batch_size, n_channels, n_freq_bins, n_time_bins) 95 | cross_spectra = [] 96 | 97 | for sample_idx in range(batch_size): 98 | # TODO: maybe there is a way to vectorize the loops below, 99 | # although it would probably repeat many operations 100 | sample_cross_spectra = [] 101 | for channel_1 in range(n_channels): 102 | for channel_2 in range(channel_1, n_channels): 103 | cross_spectrum = stfts[sample_idx][channel_1]*stfts[sample_idx][channel_2].conj() 104 | sample_cross_spectra.append(cross_spectrum) 105 | cross_spectra.append( 106 | torch.stack(sample_cross_spectra, axis=0) 107 | ) 108 | 109 | result = torch.stack(cross_spectra, dim=0) 110 | return result 111 | 112 | 113 | FEATURE_NAME_TO_CLASS_MAP = { 114 | "mfcc": MfccArray, 115 | "stft": StftArray, 116 | "stft_magnitude": StftMagnitudeArray, 117 | "stft_phase": StftPhaseArray, 118 | "cross_spectra": CrossSpectra 119 | } -------------------------------------------------------------------------------- /complex_neural_source_localization/loss.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module, CosineSimilarity 2 | 3 | 4 | class AngularLoss(Module): 5 | def __init__(self, model_output_type="scalar"): 6 | # See https://pytorch.org/docs/stable/generated/torch.nn.CosineEmbeddingLoss.html 7 | # for a related implementation used for NLP 8 | super().__init__() 9 | 10 | dim = 1 if model_output_type == "scalar" else 2 11 | self.cosine_similarity = CosineSimilarity(dim=dim) 12 | # dim=0 -> batch | dim=1 -> time steps | dim=2 -> azimuth 13 | 14 | def forward(self, model_output, targets, mean_reduce=True): 15 | values = 1 - self.cosine_similarity(model_output, targets["azimuth_2d_point"]) 16 | if mean_reduce: 17 | values = values.mean() 18 | return values 19 | -------------------------------------------------------------------------------- /complex_neural_source_localization/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from complex_neural_source_localization.feature_extractors import ( 5 | FEATURE_NAME_TO_CLASS_MAP 6 | ) 7 | from complex_neural_source_localization.utils.conv_block import ConvBlock 8 | 9 | from complex_neural_source_localization.utils.complexPyTorch.complexLayers import ( 10 | ComplexGRU, ComplexLinear 11 | ) 12 | from complex_neural_source_localization.utils.model_utilities import init_gru, init_layer 13 | 14 | DEFAULT_CONV_CONFIG = [ 15 | {"type": "complex_single", "n_channels": 64, "dropout_rate":0}, 16 | {"type": "complex_single", "n_channels": 64, "dropout_rate":0}, 17 | {"type": "complex_single", "n_channels": 64, "dropout_rate":0}, 18 | {"type": "complex_single", "n_channels": 64, "dropout_rate":0}, 19 | ] 20 | 21 | DEFAULT_STFT_CONFIG = {"n_fft": 1024, "use_onesided_fft":True} 22 | 23 | 24 | class DOACNet(nn.Module): 25 | def __init__(self, output_type="scalar", n_input_channels=4, n_sources=2, 26 | pool_type="avg", pool_size=(1,2), kernel_size=(2, 2), 27 | feature_type="stft", 28 | conv_layers_config=DEFAULT_CONV_CONFIG, 29 | stft_config=DEFAULT_STFT_CONFIG, 30 | fc_layer_dropout_rate=0.5, 31 | activation="relu", 32 | use_complex_rnn=False, 33 | init_real_layers=True, 34 | **kwargs): 35 | 36 | super().__init__() 37 | 38 | # 1. Store configuration 39 | self.output_type = output_type 40 | self.n_input_channels = n_input_channels 41 | self.n_sources = n_sources 42 | self.pool_type = pool_type 43 | self.pool_size = pool_size 44 | self.kernel_size = kernel_size 45 | self.activation = activation 46 | self.max_filters = conv_layers_config[-1]["n_channels"] 47 | self.is_rnn_complex = use_complex_rnn 48 | 49 | # 2. Create feature extractor 50 | self.feature_extractor = self._create_feature_extractor(feature_type, stft_config) 51 | 52 | # 3. Create convolutional blocks 53 | self.conv_blocks = self._create_conv_blocks(conv_layers_config, init_weights=init_real_layers) 54 | 55 | # 4. Create recurrent block 56 | self.rnn = self._create_rnn_block() 57 | 58 | # 5. Create linear block 59 | self.azimuth_fc = self._create_linear_block(n_sources, fc_layer_dropout_rate) 60 | 61 | # If using a real valued rnn, initialize gru and fc layers 62 | if not use_complex_rnn and init_real_layers: 63 | init_gru(self.rnn) 64 | init_layer(self.azimuth_fc) 65 | 66 | def forward(self, x): 67 | # input: (batch_size, mic_channels, time_steps) 68 | 69 | # 1. Extract STFT of signals 70 | x = self.feature_extractor(x) 71 | # (batch_size, mic_channels, n_freqs, stft_time_steps) 72 | x = x.transpose(2, 3) 73 | # (batch_size, mic_channels, stft_time_steps, n_freqs) 74 | 75 | # 2. Extract features using convolutional layers 76 | for conv_block in self.conv_blocks: 77 | if x.is_complex() and conv_block.is_real: 78 | x = complex_to_real(x) 79 | x = conv_block(x) 80 | # (batch_size, feature_maps, time_steps, n_freqs) 81 | 82 | # 3. Average across all frequencies 83 | x = torch.mean(x, dim=3) 84 | # (batch_size, feature_maps, time_steps) 85 | 86 | # Preprocessing for RNN 87 | if x.is_complex() and not self.is_rnn_complex: 88 | x = complex_to_real(x) 89 | x = x.transpose(1,2) 90 | # (batch_size, time_steps, feature_maps): 91 | 92 | # 4. Use features as input to RNN 93 | (x, _) = self.rnn(x) 94 | # (batch_size, time_steps, feature_maps): 95 | # Average across all time steps 96 | x = torch.mean(x, dim=1) 97 | 98 | # 5. Fully connected layer 99 | x = self.azimuth_fc(x) 100 | # (batch_size, class_num) 101 | 102 | if x.is_complex(): 103 | x = complex_to_real(x) 104 | return x 105 | 106 | def _create_feature_extractor(self, feature_type, stft_config): 107 | if feature_type == "cross_spectra": 108 | self.n_input_channels = sum(range(self.n_input_channels + 1)) 109 | 110 | return FEATURE_NAME_TO_CLASS_MAP[feature_type](stft_config) 111 | 112 | 113 | def _create_conv_blocks(self, conv_layers_config, init_weights): 114 | 115 | conv_blocks = [ 116 | ConvBlock(self.n_input_channels, conv_layers_config[0]["n_channels"], 117 | block_type=conv_layers_config[0]["type"], 118 | dropout_rate=conv_layers_config[0]["dropout_rate"], 119 | pool_size=self.pool_size, 120 | activation=self.activation, 121 | kernel_size=self.kernel_size, 122 | init=init_weights), 123 | ] 124 | 125 | for i, config in enumerate(conv_layers_config[1:]): 126 | last_layer = conv_blocks[-1] 127 | in_channels = last_layer.out_channels 128 | if last_layer.is_real == False and "complex" not in config["type"]: 129 | # complex convolutions are performed using 2 convolutions of half the filters 130 | in_channels *= 2 131 | conv_blocks.append( 132 | ConvBlock(in_channels, config["n_channels"], 133 | block_type=config["type"], init=init_weights, 134 | dropout_rate=config["dropout_rate"], 135 | pool_size=self.pool_size, 136 | activation=self.activation, 137 | kernel_size=self.kernel_size) 138 | ) 139 | 140 | return nn.ModuleList(conv_blocks) 141 | 142 | def _create_rnn_block(self): 143 | if self.is_rnn_complex: 144 | return ComplexGRU(input_size=self.max_filters//2, 145 | hidden_size=self.max_filters//4, 146 | batch_first=True, bidirectional=True) 147 | else: 148 | return nn.GRU(input_size=self.max_filters, 149 | hidden_size=self.max_filters//2, 150 | batch_first=True, bidirectional=True) 151 | 152 | def _create_linear_block(self, n_sources, fc_layer_dropout_rate): 153 | if self.is_rnn_complex: 154 | # TODO: Use dropout on complex linear block 155 | return ComplexLinear(self.max_filters//2, n_sources) 156 | else: 157 | n_last_layer = 2*n_sources # 2 cartesian dimensions for each source 158 | if fc_layer_dropout_rate > 0: 159 | return nn.Sequential( 160 | nn.Linear(self.max_filters, n_last_layer, bias=True), 161 | nn.Dropout(fc_layer_dropout_rate) 162 | ) 163 | else: 164 | return nn.Linear(self.max_filters, n_last_layer, bias=True) 165 | 166 | def track_feature_maps(self): 167 | "Make all the intermediate layers accessible through the 'feature_maps' dictionary" 168 | 169 | self.feature_maps = {} 170 | 171 | hook_fn = self._create_hook_fn("stft") 172 | self.feature_extractor.register_forward_hook(hook_fn) 173 | 174 | for i, conv_layer in enumerate(self.conv_blocks): 175 | hook_fn = self._create_hook_fn(f"conv_{i}") 176 | conv_layer.register_forward_hook(hook_fn) 177 | 178 | hook_fn = self._create_hook_fn("rnn") 179 | self.rnn.register_forward_hook(hook_fn) 180 | 181 | hook_fn = self._create_hook_fn("azimuth_fc") 182 | self.azimuth_fc.register_forward_hook(hook_fn) 183 | 184 | def _create_hook_fn(self, layer_id): 185 | def fn(_, __, output): 186 | if type(output) == tuple: 187 | output = output[0] 188 | self.feature_maps[layer_id] = output.detach().cpu() #.cpu().detach() 189 | return fn 190 | 191 | 192 | def complex_to_real(x, mode="real_imag", axis=1): 193 | if mode == "real_imag": 194 | x = torch.cat([x.real, x.imag], axis=axis) 195 | elif mode == "magnitude": 196 | x = x.abs() 197 | elif mode == "phase": 198 | x = x.angle() 199 | elif mode == "amp_phase": 200 | x = torch.cat([x.abs(), x.angle()], axis=axis) 201 | else: 202 | raise ValueError(f"Invalid complex mode :{mode}") 203 | 204 | return x 205 | -------------------------------------------------------------------------------- /complex_neural_source_localization/trainer.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import torch 3 | 4 | from torch.optim.lr_scheduler import MultiStepLR 5 | 6 | from complex_neural_source_localization.model import DOACNet 7 | from complex_neural_source_localization.loss import AngularLoss 8 | from complex_neural_source_localization.utils.base_trainer import ( 9 | BaseTrainer, BaseLightningModule 10 | ) 11 | 12 | 13 | class DOACNetTrainer(BaseTrainer): 14 | def __init__(self, config): 15 | lightning_module = DOACNetLightniningModule(config) 16 | super().__init__(lightning_module, 17 | config["training"]["n_epochs"]) 18 | 19 | def fit(self, train_dataloaders, val_dataloaders=None): 20 | super().fit(self._lightning_module, train_dataloaders, 21 | val_dataloaders=val_dataloaders) 22 | 23 | def test(self, test_dataloaders): 24 | super().test(self._lightning_module, test_dataloaders, ckpt_path="best") 25 | 26 | 27 | class DOACNetLightniningModule(BaseLightningModule): 28 | """This class abstracts the 29 | training/validation/testing procedures 30 | used for training a DOACNet 31 | """ 32 | 33 | def __init__(self, config): 34 | config = OmegaConf.to_container(config) 35 | self.config = config 36 | 37 | n_sources = self.config["dataset"]["n_max_sources"] 38 | 39 | stft_config = { 40 | "n_fft": config["model"]["n_fft"], 41 | "use_onesided_fft": config["model"]["use_onesided_fft"] 42 | } 43 | 44 | model = DOACNet(n_sources=n_sources, 45 | stft_config=stft_config, 46 | **config["model"]) 47 | 48 | loss = AngularLoss() 49 | 50 | super().__init__(model, loss) 51 | 52 | def configure_optimizers(self): 53 | lr = self.config["training"]["learning_rate"] 54 | decay_step = self.config["training"]["learning_rate_decay_steps"] 55 | decay_value = self.config["training"]["learning_rate_decay_values"] 56 | 57 | optimizer = torch.optim.Adam(self.parameters(), lr=lr) 58 | scheduler = MultiStepLR(optimizer, decay_step, decay_value) 59 | 60 | return [optimizer], [scheduler] 61 | -------------------------------------------------------------------------------- /complex_neural_source_localization/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SOUNDS-RESEARCH/complex_neural_source_localization/d3b4e3ca1ea216f4515abb011f02809183e375c9/complex_neural_source_localization/utils/__init__.py -------------------------------------------------------------------------------- /complex_neural_source_localization/utils/base_trainer.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import pickle 3 | import pytorch_lightning as pl 4 | import torch 5 | 6 | from pytorch_lightning.callbacks import ( 7 | Callback, ModelCheckpoint, TQDMProgressBar 8 | ) 9 | from pytorch_lightning import loggers as pl_loggers 10 | 11 | from complex_neural_source_localization.utils.model_utilities import merge_list_of_dicts 12 | 13 | SAVE_DIR = "logs/" 14 | LOG_EVERY_N_STEPS = 50 15 | 16 | class BaseTrainer(pl.Trainer): 17 | def __init__(self, lightning_module, n_epochs): 18 | 19 | gpus = 1 if torch.cuda.is_available() else 0 20 | 21 | progress_bar = CustomProgressBar() 22 | 23 | checkpoint_callback = ModelCheckpoint( 24 | monitor="validation_loss", 25 | save_last=True, 26 | filename='weights-{epoch:02d}-{validation_loss:.2f}', 27 | save_weights_only=True 28 | ) 29 | 30 | tb_logger = pl_loggers.TensorBoardLogger(save_dir=SAVE_DIR) 31 | csv_logger = pl_loggers.CSVLogger(save_dir=SAVE_DIR) 32 | 33 | super().__init__( 34 | max_epochs=n_epochs, 35 | callbacks=[ 36 | checkpoint_callback, progress_bar, # feature_map_callback 37 | ], 38 | logger=[tb_logger, csv_logger], 39 | gpus=gpus, 40 | log_every_n_steps=25 41 | ) 42 | 43 | self._lightning_module = lightning_module 44 | 45 | 46 | class BaseLightningModule(pl.LightningModule): 47 | """Class which abstracts interactions with Hydra 48 | and basic training/testing/validation conventions 49 | """ 50 | 51 | def __init__(self, model, loss, 52 | log_step=50): 53 | super().__init__() 54 | 55 | self.is_cuda_available = torch.cuda.is_available() 56 | 57 | self.model = model 58 | self.loss = loss 59 | 60 | self.log_step = log_step 61 | 62 | def _step(self, batch, batch_idx, log_model_output=False, 63 | log_labels=False): 64 | 65 | x, y = batch 66 | 67 | # 1. Compute model output and loss 68 | output = self.model(x) 69 | loss = self.loss(output, y, mean_reduce=False) 70 | 71 | output_dict = { 72 | "loss_vector": loss 73 | } 74 | 75 | # TODO: Add these to a callback 76 | # 2. Log model output 77 | if log_model_output: 78 | output_dict["model_output"] = output 79 | # 3. Log ground truth labels 80 | if log_labels: 81 | output_dict.update(y) 82 | 83 | output_dict["loss"] = output_dict["loss_vector"].mean() 84 | output_dict["loss_vector"] = output_dict["loss_vector"].detach().cpu() 85 | 86 | # 4. Log step metrics 87 | self.log("loss_step", output_dict["loss"], on_step=True, prog_bar=False) 88 | 89 | return output_dict 90 | 91 | def training_step(self, batch, batch_idx): 92 | return self._step(batch, batch_idx) 93 | 94 | def validation_step(self, batch, batch_idx): 95 | return self._step(batch, batch_idx, 96 | log_model_output=True, log_labels=True) 97 | 98 | def test_step(self, batch, batch_idx): 99 | return self._step(batch, batch_idx, 100 | log_model_output=True, log_labels=True) 101 | 102 | def _epoch_end(self, outputs, epoch_type="train", save_pickle=False): 103 | # 1. Compute epoch metrics 104 | outputs = merge_list_of_dicts(outputs) 105 | epoch_stats = { 106 | f"{epoch_type}_loss": outputs["loss"].mean(), 107 | f"{epoch_type}_std": outputs["loss"].std() 108 | } 109 | 110 | # 2. Log epoch metrics 111 | for key, value in epoch_stats.items(): 112 | self.log(key, value, on_epoch=True, prog_bar=True) 113 | 114 | # 3. Save complete epoch data on pickle 115 | if save_pickle: 116 | pickle_filename = f"{epoch_type}.pickle" 117 | with open(pickle_filename, "wb") as f: 118 | pickle.dump(outputs, f) 119 | 120 | return epoch_stats 121 | 122 | def training_epoch_end(self, outputs): 123 | self._epoch_end(outputs) 124 | 125 | def validation_epoch_end(self, outputs): 126 | self._epoch_end(outputs, epoch_type="validation") 127 | 128 | def test_epoch_end(self, outputs): 129 | self._epoch_end(outputs, epoch_type="test", save_pickle=True) 130 | 131 | def forward(self, x): 132 | return self.model(x) 133 | 134 | def fit(self, dataset_train, dataset_val): 135 | super().fit(self.model, dataset_train, val_dataloaders=dataset_val) 136 | 137 | def test(self, dataset_test): 138 | super().test(self.model, dataset_test, ckpt_path="best") 139 | 140 | 141 | class FeatureMapLoggerCallback(Callback): 142 | def on_test_start(self, trainer: BaseTrainer, pl_module: BaseLightningModule): 143 | pl_module.model.track_feature_maps() 144 | 145 | self.output_file = h5py.File("test_feature_maps.hdf5", "a") 146 | 147 | def on_test_batch_end(self, trainer: BaseTrainer, pl_module: BaseLightningModule, 148 | outputs, batch, batch_idx: int, dataloader_idx: int): 149 | feature_maps = pl_module.model.feature_maps 150 | 151 | group = self.output_file.create_group(str(batch_idx)) 152 | for feature_name, feature_map in feature_maps.items(): 153 | group.create_dataset(feature_name, data=feature_map.numpy()) 154 | 155 | self.output_file 156 | 157 | def on_test_end(self, trainer, pl_module): 158 | self.output_file.close() 159 | 160 | 161 | class CustomProgressBar(TQDMProgressBar): 162 | def get_metrics(self, trainer, model): 163 | # don't show the version number 164 | items = super().get_metrics(trainer, model) 165 | items.pop("v_num", None) 166 | return items 167 | -------------------------------------------------------------------------------- /complex_neural_source_localization/utils/complexPyTorch/complexFunctions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | @author: spopoff 6 | """ 7 | 8 | import torch 9 | 10 | from torch.nn.functional import max_pool2d, avg_pool2d, dropout, dropout2d, interpolate 11 | from torch import tanh, relu, sigmoid 12 | 13 | 14 | def complex_matmul(A, B): 15 | ''' 16 | Performs the matrix product between two complex matricess 17 | ''' 18 | 19 | outp_real = torch.matmul(A.real, B.real) - torch.matmul(A.imag, B.imag) 20 | outp_imag = torch.matmul(A.real, B.imag) + torch.matmul(A.imag, B.real) 21 | 22 | return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64) 23 | 24 | 25 | def complex_avg_pool2d(input, *args, **kwargs): 26 | ''' 27 | Perform complex average pooling. 28 | ''' 29 | absolute_value_real = avg_pool2d(input.real, *args, **kwargs) 30 | absolute_value_imag = avg_pool2d(input.imag, *args, **kwargs) 31 | 32 | return absolute_value_real.type(torch.complex64)+1j*absolute_value_imag.type(torch.complex64) 33 | 34 | 35 | def complex_normalize(input): 36 | ''' 37 | Perform complex normalization 38 | ''' 39 | real_value, imag_value = input.real, input.imag 40 | real_norm = (real_value - real_value.mean()) / real_value.std() 41 | imag_norm = (imag_value - imag_value.mean()) / imag_value.std() 42 | 43 | return real_norm.type(torch.complex64) + 1j*imag_norm.type(torch.complex64) 44 | 45 | 46 | def complex_relu(input): 47 | return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64) 48 | 49 | 50 | def complex_amp_tanh(input): 51 | "https://link.springer.com/book/10.1007/978-3-642-27632-3" 52 | return relu(input.abs()) * torch.exp(1.j * input.angle()) 53 | 54 | 55 | def complex_sigmoid(input): 56 | return sigmoid(input.real).type(torch.complex64)+1j*sigmoid(input.imag).type(torch.complex64) 57 | 58 | 59 | def complex_tanh(input): 60 | return tanh(input.real).type(torch.complex64)+1j*tanh(input.imag).type(torch.complex64) 61 | 62 | 63 | def complex_opposite(input): 64 | return -(input.real).type(torch.complex64)+1j*(-(input.imag).type(torch.complex64)) 65 | 66 | 67 | def complex_stack(input, dim): 68 | input_real = [x.real for x in input] 69 | input_imag = [x.imag for x in input] 70 | return torch.stack(input_real, dim).type(torch.complex64)+1j*torch.stack(input_imag, dim).type(torch.complex64) 71 | 72 | 73 | def _retrieve_elements_from_indices(tensor, indices): 74 | flattened_tensor = tensor.flatten(start_dim=-2) 75 | output = flattened_tensor.gather(dim=-1, index=indices.flatten(start_dim=-2)).view_as(indices) 76 | return output 77 | 78 | 79 | def complex_upsample(input, size=None, scale_factor=None, mode='nearest', 80 | align_corners=None, recompute_scale_factor=None): 81 | ''' 82 | Performs upsampling by separately interpolating the real and imaginary part and recombining 83 | ''' 84 | outp_real = interpolate(input.real, size=size, scale_factor=scale_factor, mode=mode, 85 | align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) 86 | outp_imag = interpolate(input.imag, size=size, scale_factor=scale_factor, mode=mode, 87 | align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) 88 | 89 | return outp_real.type(torch.complex64) + 1j * outp_imag.type(torch.complex64) 90 | 91 | 92 | def complex_upsample2(input, size=None, scale_factor=None, mode='nearest', 93 | align_corners=None, recompute_scale_factor=None): 94 | ''' 95 | Performs upsampling by separately interpolating the amplitude and phase part and recombining 96 | ''' 97 | outp_abs = interpolate(input.abs(), size=size, scale_factor=scale_factor, mode=mode, 98 | align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) 99 | angle = torch.atan2(input.imag,input.real) 100 | outp_angle = interpolate(angle, size=size, scale_factor=scale_factor, mode=mode, 101 | align_corners=align_corners, recompute_scale_factor=recompute_scale_factor) 102 | 103 | return outp_abs \ 104 | * (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64)) 105 | 106 | 107 | def complex_max_pool2d(input,kernel_size, stride=None, padding=0, 108 | dilation=1, ceil_mode=False, return_indices=False): 109 | ''' 110 | Perform complex max pooling by selecting on the absolute value on the complex values. 111 | ''' 112 | absolute_value, indices = max_pool2d( 113 | input.abs(), 114 | kernel_size = kernel_size, 115 | stride = stride, 116 | padding = padding, 117 | dilation = dilation, 118 | ceil_mode = ceil_mode, 119 | return_indices = True 120 | ) 121 | # performs the selection on the absolute values 122 | absolute_value = absolute_value.type(torch.complex64) 123 | # retrieve the corresonding phase value using the indices 124 | # unfortunately, the derivative for 'angle' is not implemented 125 | angle = torch.atan2(input.imag,input.real) 126 | # get only the phase values selected by max pool 127 | angle = _retrieve_elements_from_indices(angle, indices) 128 | return absolute_value \ 129 | * (torch.cos(angle).type(torch.complex64)+1j*torch.sin(angle).type(torch.complex64)) 130 | 131 | 132 | def complex_dropout(input, p=0.5, training=True): 133 | # need to have the same dropout mask for real and imaginary part, 134 | # this not a clean solution! 135 | #mask = torch.ones_like(input).type(torch.float32) 136 | mask = torch.ones(*input.shape, dtype = torch.float32) 137 | mask = dropout(mask, p, training)*1/(1-p) 138 | mask.type(input.dtype) 139 | return mask*input 140 | 141 | 142 | def complex_dropout2d(input, p=0.5, training=True): 143 | # need to have the same dropout mask for real and imaginary part, 144 | # this not a clean solution! 145 | mask = torch.ones(*input.shape, dtype = torch.float32) 146 | mask = dropout2d(mask, p, training)*1/(1-p) 147 | mask.type(input.dtype) 148 | return mask*input 149 | -------------------------------------------------------------------------------- /complex_neural_source_localization/utils/complexPyTorch/complexLayers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Mar 19 10:30:02 2019 5 | 6 | @author: Sebastien M. Popoff 7 | 8 | 9 | Based on https://openreview.net/forum?id=H1T2hmZAb 10 | """ 11 | 12 | import torch 13 | from torch.nn import ( 14 | Module, Parameter, init, 15 | Conv2d, ConvTranspose2d, Linear, LSTM, GRU, 16 | BatchNorm1d, BatchNorm2d, 17 | PReLU 18 | ) 19 | 20 | from .complexFunctions import ( 21 | complex_amp_tanh, 22 | complex_relu, 23 | complex_tanh, 24 | complex_sigmoid, 25 | complex_max_pool2d, 26 | complex_avg_pool2d, 27 | complex_dropout, 28 | complex_dropout2d, 29 | complex_opposite, 30 | ) 31 | 32 | 33 | def apply_complex(fr, fi, input, dtype=torch.complex64): 34 | return (fr(input.real)-fi(input.imag)).type(dtype) \ 35 | + 1j*(fr(input.imag)+fi(input.real)).type(dtype) 36 | 37 | 38 | class ComplexDropout(Module): 39 | def __init__(self, p=0.5): 40 | super().__init__() 41 | self.p = p 42 | 43 | def forward(self, input): 44 | if self.training: 45 | return complex_dropout(input, self.p) 46 | else: 47 | return input 48 | 49 | 50 | class ComplexDropout2d(Module): 51 | def __init__(self, p=0.5): 52 | super().__init__() 53 | self.p = p 54 | 55 | def forward(self, input): 56 | if self.training: 57 | return complex_dropout2d(input, self.p) 58 | else: 59 | return input 60 | 61 | 62 | class ComplexMaxPool2d(Module): 63 | 64 | def __init__(self, kernel_size, stride=None, padding=0, 65 | dilation=1, return_indices=False, ceil_mode=False): 66 | super().__init__() 67 | self.kernel_size = kernel_size 68 | self.stride = stride 69 | self.padding = padding 70 | self.dilation = dilation 71 | self.ceil_mode = ceil_mode 72 | self.return_indices = return_indices 73 | 74 | def forward(self, input): 75 | return complex_max_pool2d(input, kernel_size=self.kernel_size, 76 | stride=self.stride, padding=self.padding, 77 | dilation=self.dilation, ceil_mode=self.ceil_mode, 78 | return_indices=self.return_indices) 79 | 80 | 81 | class ComplexAvgPool2d(Module): 82 | 83 | def __init__(self, kernel_size, stride=None, padding=0, 84 | ceil_mode=False): 85 | super().__init__() 86 | self.kernel_size = kernel_size 87 | self.stride = stride 88 | self.padding = padding 89 | self.ceil_mode = ceil_mode 90 | 91 | def forward(self, input): 92 | return complex_avg_pool2d(input, kernel_size=self.kernel_size, 93 | stride=self.stride, padding=self.padding, 94 | ceil_mode=self.ceil_mode) 95 | 96 | 97 | class ComplexReLU(Module): 98 | 99 | def forward(self, input): 100 | return complex_relu(input) 101 | 102 | 103 | class ComplexSigmoid(Module): 104 | 105 | def forward(self, input): 106 | return complex_sigmoid(input) 107 | 108 | 109 | class ComplexTanh(Module): 110 | 111 | def forward(self, input): 112 | return complex_tanh(input) 113 | 114 | 115 | class ComplexAmpTanh(Module): 116 | def forward(self, input): 117 | return complex_amp_tanh(input) 118 | 119 | 120 | class ComplexPReLU(Module): 121 | def __init__(self): 122 | super().__init__() 123 | self.r_prelu = PReLU() 124 | self.i_prelu = PReLU() 125 | 126 | 127 | def forward(self, input): 128 | return self.r_prelu(input.real) + 1j*self.i_prelu(input.imag) 129 | 130 | 131 | class ComplexConvTranspose2d(Module): 132 | 133 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 134 | output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros'): 135 | 136 | super().__init__() 137 | 138 | self.conv_tran_r = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, 139 | output_padding, groups, bias, dilation, padding_mode) 140 | self.conv_tran_i = ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, 141 | output_padding, groups, bias, dilation, padding_mode) 142 | 143 | def forward(self, input): 144 | return apply_complex(self.conv_tran_r, self.conv_tran_i, input) 145 | 146 | 147 | class ComplexConv2d(Module): 148 | 149 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, 150 | dilation=1, groups=1, bias=True): 151 | super().__init__() 152 | self.conv_r = Conv2d(in_channels, out_channels, 153 | kernel_size, stride, padding, dilation, groups, bias) 154 | self.conv_i = Conv2d(in_channels, out_channels, 155 | kernel_size, stride, padding, dilation, groups, bias) 156 | 157 | def forward(self, input): 158 | return apply_complex(self.conv_r, self.conv_i, input) 159 | 160 | 161 | class ComplexLinear(Module): 162 | 163 | def __init__(self, in_features, out_features): 164 | super().__init__() 165 | self.fc_r = Linear(in_features, out_features) 166 | self.fc_i = Linear(in_features, out_features) 167 | 168 | def forward(self, input): 169 | return apply_complex(self.fc_r, self.fc_i, input) 170 | 171 | 172 | class NaiveComplexBatchNorm1d(Module): 173 | ''' 174 | Naive approach to complex batch norm, perform batch norm independently on real and imaginary part. 175 | ''' 176 | 177 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 178 | track_running_stats=True): 179 | super().__init__() 180 | self.bn_r = BatchNorm1d( 181 | num_features, eps, momentum, affine, track_running_stats) 182 | self.bn_i = BatchNorm1d( 183 | num_features, eps, momentum, affine, track_running_stats) 184 | 185 | def forward(self, input): 186 | return self.bn_r(input.real).type(torch.complex64) + 1j*self.bn_i(input.imag).type(torch.complex64) 187 | 188 | 189 | class NaiveComplexBatchNorm2d(Module): 190 | ''' 191 | Naive approach to complex batch norm, perform batch norm independently on real and imaginary part. 192 | ''' 193 | 194 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 195 | track_running_stats=True): 196 | super(NaiveComplexBatchNorm2d, self).__init__() 197 | self.bn_r = BatchNorm2d( 198 | num_features, eps, momentum, affine, track_running_stats) 199 | self.bn_i = BatchNorm2d( 200 | num_features, eps, momentum, affine, track_running_stats) 201 | 202 | def forward(self, input): 203 | return self.bn_r(input.real).type(torch.complex64) + 1j*self.bn_i(input.imag).type(torch.complex64) 204 | 205 | 206 | class _ComplexBatchNorm(Module): 207 | 208 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 209 | track_running_stats=True): 210 | super().__init__() 211 | self.num_features = num_features 212 | self.eps = eps 213 | self.momentum = momentum 214 | self.affine = affine 215 | self.track_running_stats = track_running_stats 216 | if self.affine: 217 | self.weight = Parameter(torch.Tensor(num_features, 3)) 218 | self.bias = Parameter(torch.Tensor(num_features, 2)) 219 | else: 220 | self.register_parameter('weight', None) 221 | self.register_parameter('bias', None) 222 | if self.track_running_stats: 223 | self.register_buffer('running_mean', torch.zeros( 224 | num_features, dtype=torch.complex64)) 225 | self.register_buffer('running_covar', torch.zeros(num_features, 3)) 226 | self.running_covar[:, 0] = 1.4142135623730951 227 | self.running_covar[:, 1] = 1.4142135623730951 228 | self.register_buffer('num_batches_tracked', 229 | torch.tensor(0, dtype=torch.long)) 230 | else: 231 | self.register_parameter('running_mean', None) 232 | self.register_parameter('running_covar', None) 233 | self.register_parameter('num_batches_tracked', None) 234 | self.reset_parameters() 235 | 236 | def reset_running_stats(self): 237 | if self.track_running_stats: 238 | self.running_mean.zero_() 239 | self.running_covar.zero_() 240 | self.running_covar[:, 0] = 1.4142135623730951 241 | self.running_covar[:, 1] = 1.4142135623730951 242 | self.num_batches_tracked.zero_() 243 | 244 | def reset_parameters(self): 245 | self.reset_running_stats() 246 | if self.affine: 247 | init.constant_(self.weight[:, :2], 1.4142135623730951) 248 | init.zeros_(self.weight[:, 2]) 249 | init.zeros_(self.bias) 250 | 251 | 252 | class ComplexBatchNorm2d(_ComplexBatchNorm): 253 | 254 | def forward(self, input): 255 | exponential_average_factor = 0.0 256 | 257 | if self.training and self.track_running_stats: 258 | if self.num_batches_tracked is not None: 259 | self.num_batches_tracked += 1 260 | if self.momentum is None: # use cumulative moving average 261 | exponential_average_factor = 1.0 / \ 262 | float(self.num_batches_tracked) 263 | else: # use exponential moving average 264 | exponential_average_factor = self.momentum 265 | 266 | if self.training or (not self.training and not self.track_running_stats): 267 | # calculate mean of real and imaginary part 268 | # mean does not support automatic differentiation for outputs with complex dtype. 269 | mean_r = input.real.mean([0, 2, 3]).type(torch.complex64) 270 | mean_i = input.imag.mean([0, 2, 3]).type(torch.complex64) 271 | mean = mean_r + 1j*mean_i 272 | else: 273 | mean = self.running_mean 274 | 275 | if self.training and self.track_running_stats: 276 | # update running mean 277 | with torch.no_grad(): 278 | self.running_mean = exponential_average_factor * mean\ 279 | + (1 - exponential_average_factor) * self.running_mean 280 | 281 | input = input - mean[None, :, None, None] 282 | 283 | if self.training or (not self.training and not self.track_running_stats): 284 | # Elements of the covariance matrix (biased for train) 285 | n = input.numel() / input.size(1) 286 | Crr = 1./n*input.real.pow(2).sum(dim=[0, 2, 3])+self.eps 287 | Cii = 1./n*input.imag.pow(2).sum(dim=[0, 2, 3])+self.eps 288 | Cri = (input.real.mul(input.imag)).mean(dim=[0, 2, 3]) 289 | else: 290 | Crr = self.running_covar[:, 0]+self.eps 291 | Cii = self.running_covar[:, 1]+self.eps 292 | Cri = self.running_covar[:, 2] # +self.eps 293 | 294 | if self.training and self.track_running_stats: 295 | with torch.no_grad(): 296 | self.running_covar[:, 0] = exponential_average_factor * Crr * n / (n - 1)\ 297 | + (1 - exponential_average_factor) * \ 298 | self.running_covar[:, 0] 299 | 300 | self.running_covar[:, 1] = exponential_average_factor * Cii * n / (n - 1)\ 301 | + (1 - exponential_average_factor) * \ 302 | self.running_covar[:, 1] 303 | 304 | self.running_covar[:, 2] = exponential_average_factor * Cri * n / (n - 1)\ 305 | + (1 - exponential_average_factor) * \ 306 | self.running_covar[:, 2] 307 | 308 | # calculate the inverse square root the covariance matrix 309 | det = Crr*Cii-Cri.pow(2) 310 | s = torch.sqrt(det) 311 | t = torch.sqrt(Cii+Crr + 2 * s) 312 | inverse_st = 1.0 / (s * t) 313 | Rrr = (Cii + s) * inverse_st 314 | Rii = (Crr + s) * inverse_st 315 | Rri = -Cri * inverse_st 316 | 317 | input = (Rrr[None, :, None, None]*input.real+Rri[None, :, None, None]*input.imag).type(torch.complex64) \ 318 | + 1j*(Rii[None, :, None, None]*input.imag+Rri[None, 319 | :, None, None]*input.real).type(torch.complex64) 320 | 321 | if self.affine: 322 | input = (self.weight[None, :, 0, None, None]*input.real+self.weight[None, :, 2, None, None]*input.imag + 323 | self.bias[None, :, 0, None, None]).type(torch.complex64) \ 324 | + 1j*(self.weight[None, :, 2, None, None]*input.real+self.weight[None, :, 1, None, None]*input.imag + 325 | self.bias[None, :, 1, None, None]).type(torch.complex64) 326 | 327 | return input 328 | 329 | 330 | class ComplexBatchNorm1d(_ComplexBatchNorm): 331 | 332 | def forward(self, input): 333 | 334 | exponential_average_factor = 0.0 335 | 336 | if self.training and self.track_running_stats: 337 | if self.num_batches_tracked is not None: 338 | self.num_batches_tracked += 1 339 | if self.momentum is None: # use cumulative moving average 340 | exponential_average_factor = 1.0 / \ 341 | float(self.num_batches_tracked) 342 | else: # use exponential moving average 343 | exponential_average_factor = self.momentum 344 | 345 | if self.training or (not self.training and not self.track_running_stats): 346 | # calculate mean of real and imaginary part 347 | mean_r = input.real.mean(dim=0).type(torch.complex64) 348 | mean_i = input.imag.mean(dim=0).type(torch.complex64) 349 | mean = mean_r + 1j*mean_i 350 | else: 351 | mean = self.running_mean 352 | 353 | if self.training and self.track_running_stats: 354 | # update running mean 355 | with torch.no_grad(): 356 | self.running_mean = exponential_average_factor * mean\ 357 | + (1 - exponential_average_factor) * self.running_mean 358 | 359 | input = input - mean[None, ...] 360 | 361 | if self.training or (not self.training and not self.track_running_stats): 362 | # Elements of the covariance matrix (biased for train) 363 | n = input.numel() / input.size(1) 364 | Crr = input.real.var(dim=0, unbiased=False)+self.eps 365 | Cii = input.imag.var(dim=0, unbiased=False)+self.eps 366 | Cri = (input.real.mul(input.imag)).mean(dim=0) 367 | else: 368 | Crr = self.running_covar[:, 0]+self.eps 369 | Cii = self.running_covar[:, 1]+self.eps 370 | Cri = self.running_covar[:, 2] 371 | 372 | if self.training and self.track_running_stats: 373 | self.running_covar[:, 0] = exponential_average_factor * Crr * n / (n - 1)\ 374 | + (1 - exponential_average_factor) * self.running_covar[:, 0] 375 | 376 | self.running_covar[:, 1] = exponential_average_factor * Cii * n / (n - 1)\ 377 | + (1 - exponential_average_factor) * self.running_covar[:, 1] 378 | 379 | self.running_covar[:, 2] = exponential_average_factor * Cri * n / (n - 1)\ 380 | + (1 - exponential_average_factor) * self.running_covar[:, 2] 381 | 382 | # calculate the inverse square root the covariance matrix 383 | det = Crr*Cii-Cri.pow(2) 384 | s = torch.sqrt(det) 385 | t = torch.sqrt(Cii+Crr + 2 * s) 386 | inverse_st = 1.0 / (s * t) 387 | Rrr = (Cii + s) * inverse_st 388 | Rii = (Crr + s) * inverse_st 389 | Rri = -Cri * inverse_st 390 | 391 | input = (Rrr[None, :]*input.real+Rri[None, :]*input.imag).type(torch.complex64) \ 392 | + 1j*(Rii[None, :]*input.imag+Rri[None, :] 393 | * input.real).type(torch.complex64) 394 | 395 | if self.affine: 396 | input = (self.weight[None, :, 0]*input.real+self.weight[None, :, 2]*input.imag + 397 | self.bias[None, :, 0]).type(torch.complex64) \ 398 | + 1j*(self.weight[None, :, 2]*input.real+self.weight[None, :, 1]*input.imag + 399 | self.bias[None, :, 1]).type(torch.complex64) 400 | 401 | del Crr, Cri, Cii, Rrr, Rii, Rri, det, s, t 402 | return input 403 | 404 | 405 | class ComplexGRUCell(Module): 406 | """ 407 | A GRU cell for complex-valued inputs 408 | """ 409 | 410 | def __init__(self, input_length, hidden_length): 411 | super().__init__() 412 | self.input_length = input_length 413 | self.hidden_length = hidden_length 414 | 415 | # reset gate components 416 | self.linear_reset_w1 = ComplexLinear( 417 | self.input_length, self.hidden_length) 418 | self.linear_reset_r1 = ComplexLinear( 419 | self.hidden_length, self.hidden_length) 420 | 421 | self.linear_reset_w2 = ComplexLinear( 422 | self.input_length, self.hidden_length) 423 | self.linear_reset_r2 = ComplexLinear( 424 | self.hidden_length, self.hidden_length) 425 | 426 | # update gate components 427 | self.linear_gate_w3 = ComplexLinear( 428 | self.input_length, self.hidden_length) 429 | self.linear_gate_r3 = ComplexLinear( 430 | self.hidden_length, self.hidden_length) 431 | 432 | self.activation_gate = ComplexSigmoid() 433 | self.activation_candidate = ComplexTanh() 434 | 435 | def reset_gate(self, x, h): 436 | x_1 = self.linear_reset_w1(x) 437 | h_1 = self.linear_reset_r1(h) 438 | # gate update 439 | reset = self.activation_gate(x_1 + h_1) 440 | return reset 441 | 442 | def update_gate(self, x, h): 443 | x_2 = self.linear_reset_w2(x) 444 | h_2 = self.linear_reset_r2(h) 445 | z = self.activation_gate(h_2 + x_2) 446 | return z 447 | 448 | def update_component(self, x, h, r): 449 | x_3 = self.linear_gate_w3(x) 450 | h_3 = r * self.linear_gate_r3(h) # element-wise multiplication 451 | gate_update = self.activation_candidate(x_3 + h_3) 452 | return gate_update 453 | 454 | def forward(self, x, h): 455 | # Equation 1. reset gate vector 456 | r = self.reset_gate(x, h) 457 | 458 | # Equation 2: the update gate - the shared update gate vector z 459 | z = self.update_gate(x, h) 460 | 461 | # Equation 3: The almost output component 462 | n = self.update_component(x, h, r) 463 | 464 | # Equation 4: the new hidden state 465 | h_new = (1 + complex_opposite(z)) * n + \ 466 | z * h # element-wise multiplication 467 | 468 | return h_new 469 | 470 | 471 | class ComplexBNGRUCell(Module): 472 | """ 473 | A BN-GRU cell for complex-valued inputs 474 | """ 475 | 476 | def __init__(self, input_length=10, hidden_length=20): 477 | super().__init__() 478 | self.input_length = input_length 479 | self.hidden_length = hidden_length 480 | 481 | # reset gate components 482 | self.linear_reset_w1 = ComplexLinear( 483 | self.input_length, self.hidden_length) 484 | self.linear_reset_r1 = ComplexLinear( 485 | self.hidden_length, self.hidden_length) 486 | 487 | self.linear_reset_w2 = ComplexLinear( 488 | self.input_length, self.hidden_length) 489 | self.linear_reset_r2 = ComplexLinear( 490 | self.hidden_length, self.hidden_length) 491 | 492 | # update gate components 493 | self.linear_gate_w3 = ComplexLinear( 494 | self.input_length, self.hidden_length) 495 | self.linear_gate_r3 = ComplexLinear( 496 | self.hidden_length, self.hidden_length) 497 | 498 | self.activation_gate = ComplexSigmoid() 499 | self.activation_candidate = ComplexTanh() 500 | 501 | self.bn = ComplexBatchNorm2d(1) 502 | 503 | def reset_gate(self, x, h): 504 | x_1 = self.linear_reset_w1(x) 505 | h_1 = self.linear_reset_r1(h) 506 | # gate update 507 | reset = self.activation_gate(self.bn(x_1) + self.bn(h_1)) 508 | return reset 509 | 510 | def update_gate(self, x, h): 511 | x_2 = self.linear_reset_w2(x) 512 | h_2 = self.linear_reset_r2(h) 513 | z = self.activation_gate(self.bn(h_2) + self.bn(x_2)) 514 | return z 515 | 516 | def update_component(self, x, h, r): 517 | x_3 = self.linear_gate_w3(x) 518 | # element-wise multiplication 519 | h_3 = r * self.bn(self.linear_gate_r3(h)) 520 | gate_update = self.activation_candidate(self.bn(self.bn(x_3) + h_3)) 521 | return gate_update 522 | 523 | def forward(self, x, h): 524 | # Equation 1. reset gate vector 525 | r = self.reset_gate(x, h) 526 | 527 | # Equation 2: the update gate - the shared update gate vector z 528 | z = self.update_gate(x, h) 529 | 530 | # Equation 3: The almost output component 531 | n = self.update_component(x, h, r) 532 | 533 | # Equation 4: the new hidden state 534 | h_new = (1 + complex_opposite(z)) * n + \ 535 | z * h # element-wise multiplication 536 | 537 | return h_new 538 | 539 | 540 | class ComplexGRU(Module): 541 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, 542 | batch_first=False, dropout=0, bidirectional=False): 543 | super().__init__() 544 | 545 | self.gru_re = GRU(input_size=input_size, hidden_size=hidden_size, 546 | num_layers=num_layers, bias=bias, 547 | batch_first=batch_first, dropout=dropout, 548 | bidirectional=bidirectional) 549 | self.gru_im = GRU(input_size=input_size, hidden_size=hidden_size, 550 | num_layers=num_layers, bias=bias, 551 | batch_first=batch_first, dropout=dropout, 552 | bidirectional=bidirectional) 553 | 554 | def forward(self, x): 555 | real, state_real = self._forward_real(x) 556 | imaginary, state_imag = self._forward_imaginary(x) 557 | 558 | output = torch.complex(real, imaginary) 559 | state = torch.complex(state_real, state_imag) 560 | 561 | return output, state 562 | 563 | def forward(self, x): 564 | r2r_out = self.gru_re(x.real)[0] 565 | r2i_out = self.gru_im(x.real)[0] 566 | i2r_out = self.gru_re(x.imag)[0] 567 | i2i_out = self.gru_im(x.imag)[0] 568 | real_out = r2r_out - i2i_out 569 | imag_out = i2r_out + r2i_out 570 | 571 | return torch.complex(real_out, imag_out), None 572 | 573 | def _forward_real(self, x): 574 | real_real, h_real = self.gru_re(x.real) 575 | imag_imag, h_imag = self.gru_im(x.imag) 576 | real = real_real - imag_imag 577 | 578 | return real, torch.complex(h_real, h_imag) 579 | 580 | def _forward_imaginary(self, x): 581 | imag_real, h_real = self.gru_re(x.imag) 582 | real_imag, h_imag = self.gru_im(x.real) 583 | imaginary = imag_real + real_imag 584 | 585 | return imaginary, torch.complex(h_real, h_imag) 586 | 587 | 588 | class ComplexLSTM(Module): 589 | def __init__(self, input_size, hidden_size, num_layers=1, bias=True, 590 | batch_first=False, dropout=0, bidirectional=False): 591 | super().__init__() 592 | self.num_layer = num_layers 593 | self.hidden_size = hidden_size 594 | self.batch_dim = 0 if batch_first else 1 595 | self.bidirectional = bidirectional 596 | 597 | self.lstm_re = LSTM(input_size=input_size, hidden_size=hidden_size, 598 | num_layers=num_layers, bias=bias, 599 | batch_first=batch_first, dropout=dropout, 600 | bidirectional=bidirectional) 601 | self.lstm_im = LSTM(input_size=input_size, hidden_size=hidden_size, 602 | num_layers=num_layers, bias=bias, 603 | batch_first=batch_first, dropout=dropout, 604 | bidirectional=bidirectional) 605 | def forward(self, x): 606 | real, state_real = self._forward_real(x) 607 | imaginary, state_imag = self._forward_imaginary(x) 608 | 609 | output = torch.complex(real, imaginary) 610 | 611 | return output, (state_real, state_imag) 612 | 613 | def _forward_real(self, x): 614 | h_real, h_imag, c_real, c_imag = self._init_state(self._get_batch_size(x), x.is_cuda) 615 | real_real, (h_real, c_real) = self.lstm_re(x.real, (h_real, c_real)) 616 | imag_imag, (h_imag, c_imag) = self.lstm_im(x.imag, (h_imag, c_imag)) 617 | real = real_real - imag_imag 618 | return real, ((h_real, c_real), (h_imag, c_imag)) 619 | 620 | def _forward_imaginary(self, x): 621 | h_real, h_imag, c_real, c_imag = self._init_state(self._get_batch_size(x), x.is_cuda) 622 | imag_real, (h_real, c_real) = self.lstm_re(x.imag, (h_real, c_real)) 623 | real_imag, (h_imag, c_imag) = self.lstm_im(x.real, (h_imag, c_imag)) 624 | imaginary = imag_real + real_imag 625 | 626 | return imaginary, ((h_real, c_real), (h_imag, c_imag)) 627 | 628 | def _init_state(self, batch_size, to_gpu=False): 629 | dim_0 = 2 if self.bidirectional else 1 630 | dims = (dim_0, batch_size, self.hidden_size) 631 | 632 | h_real, h_imag, c_real, c_imag = [ 633 | torch.zeros(dims) for i in range(4)] 634 | 635 | if to_gpu: 636 | h_real, h_imag, c_real, c_imag = [ 637 | t.cuda() for t in [h_real, h_imag, c_real, c_imag]] 638 | 639 | 640 | return h_real, h_imag, c_real, c_imag 641 | 642 | def _get_batch_size(self, x): 643 | return x.size(self.batch_dim) 644 | -------------------------------------------------------------------------------- /complex_neural_source_localization/utils/conv_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from complex_neural_source_localization.utils.complexPyTorch.complexLayers import ( 4 | ComplexAmpTanh, ComplexConv2d, ComplexBatchNorm2d, ComplexDropout, 5 | ComplexReLU, ComplexTanh, ComplexPReLU, ComplexAvgPool2d 6 | ) 7 | from complex_neural_source_localization.utils.model_utilities import init_layer 8 | 9 | 10 | class ConvBlock(nn.Module): 11 | def __init__(self, in_channels, out_channels, 12 | kernel_size=(3,3), stride=(1,1), 13 | padding=(1,1), pool_size=(2, 2), 14 | block_type="real_double", 15 | init=False, 16 | dropout_rate=0.1, 17 | activation="relu"): 18 | 19 | super().__init__() 20 | self.block_type = block_type 21 | self.pool_size=pool_size 22 | self.dropout_rate = dropout_rate 23 | 24 | if "complex" in block_type: 25 | conv_block = ComplexConv2d 26 | bn_block = ComplexBatchNorm2d 27 | dropout_block = ComplexDropout 28 | if activation == "relu": 29 | self.activation = ComplexReLU() 30 | elif activation == "amp_tanh": 31 | self.activation = ComplexAmpTanh() 32 | elif activation == "tanh": 33 | self.activation = ComplexTanh() 34 | elif activation == "prelu": 35 | self.activation = ComplexPReLU() 36 | 37 | self.pooling = ComplexAvgPool2d(pool_size) 38 | self.is_real = False 39 | out_channels = out_channels//2 40 | else: 41 | conv_block = nn.Conv2d 42 | bn_block = nn.BatchNorm2d 43 | dropout_block = nn.Dropout 44 | self.activation = nn.ReLU() 45 | self.pooling = nn.AvgPool2d(pool_size) 46 | self.is_real = True 47 | 48 | self.conv1 = conv_block(in_channels=in_channels, 49 | out_channels=out_channels, 50 | kernel_size=kernel_size, stride=stride, 51 | padding=padding, bias=False) 52 | self.bn1 = bn_block(out_channels) 53 | self.dropout = dropout_block(dropout_rate) 54 | 55 | if "double" in block_type: 56 | self.conv2 = conv_block(in_channels=out_channels, 57 | out_channels=out_channels, 58 | kernel_size=kernel_size, stride=stride, 59 | padding=padding, bias=False) 60 | self.bn2 = bn_block(out_channels) 61 | 62 | self.in_channels = in_channels 63 | self.out_channels = out_channels 64 | 65 | if init and "real" in block_type: 66 | # TODO: Implement complex-valued initialization 67 | self._init_weights() 68 | 69 | def forward(self, x): 70 | x = self.activation(self.bn1(self.conv1(x))) 71 | if "double" in self.block_type: 72 | x = self.activation(self.bn2(self.conv2(x))) 73 | x = self.pooling(x) 74 | 75 | if self.dropout_rate > 0: 76 | x = self.dropout(x) 77 | return x 78 | 79 | def _init_weights(self): 80 | init_layer(self.conv1) 81 | init_layer(self.bn1) 82 | if self.block_type == "real_double": 83 | init_layer(self.conv2) 84 | init_layer(self.bn2) -------------------------------------------------------------------------------- /complex_neural_source_localization/utils/load_config.py: -------------------------------------------------------------------------------- 1 | from hydra import compose, initialize 2 | from hydra.core.global_hydra import GlobalHydra 3 | 4 | 5 | def load_config(key=None, reset_hydra=True): 6 | if reset_hydra: 7 | GlobalHydra.instance().clear() 8 | 9 | initialize(config_path="../../config", job_name="_") 10 | cfg = compose(config_name="config") 11 | 12 | if key is not None: 13 | return cfg[key] 14 | 15 | return cfg 16 | -------------------------------------------------------------------------------- /complex_neural_source_localization/utils/model_utilities.py: -------------------------------------------------------------------------------- 1 | # Credits to Yin Cao et al: 2 | # https://github.com/yinkalario/Two-Stage-Polyphonic-Sound-Event-Detection-and-Localization/blob/master/models/model_utilities.py 3 | 4 | 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | def init_layer(layer, nonlinearity='leaky_relu'): 13 | """Initialize a convolutional or linear layer""" 14 | classname = layer.__class__.__name__ 15 | if (classname.find('Conv') != -1) or (classname.find('Linear') != -1): 16 | nn.init.kaiming_uniform_(layer.weight, nonlinearity=nonlinearity) 17 | if hasattr(layer, 'bias'): 18 | if layer.bias is not None: 19 | nn.init.constant_(layer.bias, 0.0) 20 | elif classname.find('BatchNorm') != -1: 21 | nn.init.normal_(layer.weight, 1.0, 0.02) 22 | nn.init.constant_(layer.bias, 0.0) 23 | 24 | 25 | def init_gru(rnn): 26 | """Initialize a GRU layer. """ 27 | 28 | def _concat_init(tensor, init_funcs): 29 | (length, fan_out) = tensor.shape 30 | fan_in = length // len(init_funcs) 31 | 32 | for (i, init_func) in enumerate(init_funcs): 33 | init_func(tensor[i * fan_in : (i + 1) * fan_in, :]) 34 | 35 | def _inner_uniform(tensor): 36 | fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in') 37 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) 38 | 39 | for i in range(rnn.num_layers): 40 | _concat_init( 41 | getattr(rnn, 'weight_ih_l{}'.format(i)), 42 | [_inner_uniform, _inner_uniform, _inner_uniform] 43 | ) 44 | torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0) 45 | 46 | _concat_init( 47 | getattr(rnn, 'weight_hh_l{}'.format(i)), 48 | [_inner_uniform, _inner_uniform, nn.init.orthogonal_] 49 | ) 50 | torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0) 51 | 52 | 53 | def merge_list_of_dicts(list_of_dicts): 54 | result = {} 55 | 56 | def _add_to_dict(key, value): 57 | if len(value.shape) == 0: # 0-dimensional tensor 58 | value = value.unsqueeze(0) 59 | 60 | if key not in result: 61 | result[key] = value 62 | else: 63 | result[key] = torch.cat([ 64 | result[key], value 65 | ]) 66 | 67 | for d in list_of_dicts: 68 | for key, value in d.items(): 69 | _add_to_dict(key, value) 70 | 71 | return result 72 | 73 | 74 | def get_all_layers(model: nn.Module, layer_types=None, name_prefix=""): 75 | 76 | layers = {} 77 | 78 | for name, layer in model.named_children(): 79 | if name_prefix: 80 | name = f"{name_prefix}.{name}" 81 | if isinstance(layer, nn.Sequential) or isinstance(layer, nn.ModuleList): 82 | layers.update(get_all_layers(layer, layer_types, name)) 83 | else: 84 | layers[name] = layer 85 | 86 | if layer_types is not None: 87 | layers = { 88 | layer_id: layer 89 | for layer_id, layer in layers.items() 90 | if any([ 91 | isinstance(layer, layer_type) 92 | for layer_type in layer_types 93 | ]) 94 | } 95 | 96 | return layers 97 | -------------------------------------------------------------------------------- /complex_neural_source_localization/utils/model_visualization.py: -------------------------------------------------------------------------------- 1 | from inspect import unwrap 2 | import librosa 3 | import librosa.display 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import os 7 | import torch 8 | 9 | from pathlib import Path 10 | from torchvision.utils import make_grid 11 | from tqdm import tqdm 12 | 13 | from complex_neural_source_localization.utils.model_utilities import ( 14 | get_all_layers 15 | ) 16 | from complex_neural_source_localization.utils.conv_block import ConvBlock 17 | 18 | 19 | def plot_multichannel_spectrogram(multichannel_spectrogram, unwrap=True, unwrap_mode="freq", mode="column", 20 | axs=None, figsize=(10, 5), output_path=None, close=True, db=True, 21 | colorbar=True): 22 | 23 | num_channels, num_freq_bins, num_time_steps = multichannel_spectrogram.shape 24 | 25 | if axs is None: 26 | if mode == "column": 27 | n_rows, n_cols = (2, num_channels) 28 | share_y = "col" 29 | elif mode == "row": 30 | n_rows, n_cols = (num_channels, 2) 31 | share_y = "row" 32 | 33 | fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=figsize, sharey=share_y) 34 | 35 | # If torch.Tensor, move to cpu and convert to numpy 36 | if isinstance(multichannel_spectrogram, torch.Tensor): 37 | multichannel_spectrogram = multichannel_spectrogram.cpu().detach().numpy() 38 | 39 | # Plot spectrograms for all channels 40 | for n_channel in range(num_channels): 41 | if mode == "row": 42 | channel_axs = [axs[n_channel, 0], axs[n_channel, 1]] 43 | elif mode == "column": 44 | channel_axs = [axs[0, n_channel], axs[1, n_channel]] 45 | 46 | else: 47 | raise ValueError("Allowed modes are 'row' and 'column'") 48 | 49 | (mag_mesh, phase_mesh), _ = plot_spectrogram(multichannel_spectrogram[n_channel], 50 | axs=channel_axs, unwrap=unwrap, unwrap_mode=unwrap_mode, 51 | db=db, colorbar=False) 52 | 53 | if colorbar: 54 | if mode == "column": 55 | location = "right" 56 | mag_axs, phase_axs = axs[0, :], axs[1, :] 57 | elif mode == "row": 58 | location = "top" 59 | mag_axs, phase_axs = axs[:, 0], axs[:, 1] 60 | 61 | plt.colorbar(mag_mesh, ax=mag_axs, format="%+2.f", 62 | location=location) 63 | plt.colorbar(phase_mesh, ax=phase_axs, 64 | location=location) 65 | 66 | if output_path is not None: 67 | plt.savefig(output_path) 68 | if close: 69 | plt.close() 70 | 71 | return axs 72 | 73 | 74 | def plot_spectrogram(spectrogram, unwrap=True, unwrap_mode="freq", db=True, 75 | mode="column", figsize=(10, 5), axs=None, output_path=None, close=True, 76 | colorbar=True): 77 | if axs is None: 78 | if mode == "column": 79 | n_rows, n_cols = (2, 1) 80 | elif mode == "row": 81 | n_rows, n_cols = (1, 2) 82 | 83 | fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=figsize) 84 | 85 | # If torch.Tensor, move to cpu and convert to numpy 86 | if isinstance(spectrogram, torch.Tensor): 87 | spectrogram = spectrogram.cpu().detach().numpy() 88 | 89 | # Extract magnitude and phase, then unwrap phase 90 | spectrogram_mag = np.abs(spectrogram) 91 | if db: 92 | spectrogram_mag = librosa.amplitude_to_db(spectrogram_mag, ref=np.max) 93 | spectrogram_phase = np.angle(spectrogram) 94 | if unwrap: 95 | axis = 0 if unwrap_mode == "freq" else 1 96 | spectrogram_phase = np.unwrap(spectrogram_phase, axis=axis) 97 | 98 | # img_mag = librosa.display.specshow(spectrogram_mag, ax=axs[0]) 99 | # img_phase = librosa.display.specshow(spectrogram_phase, ax=axs[1]) 100 | 101 | # https://matplotlib.org/stable/tutorials/colors/colormaps.html for beautiful colormaps 102 | mag_mesh = axs[0].pcolormesh(spectrogram_mag, cmap="RdBu_r") 103 | phase_mesh = axs[1].pcolormesh(spectrogram_phase, cmap="RdBu_r") 104 | 105 | axs[0].xaxis.set_ticklabels([]) 106 | axs[0].yaxis.set_ticklabels([]) 107 | axs[1].xaxis.set_ticklabels([]) 108 | axs[1].yaxis.set_ticklabels([]) 109 | 110 | if colorbar: 111 | plt.colorbar(mag_mesh, ax=axs[0], format="%+2.f dB") 112 | plt.colorbar(phase_mesh, ax=axs[1], format="%+4.f rad") 113 | 114 | if output_path is not None: 115 | plt.savefig(output_path) 116 | if close: 117 | plt.close() 118 | 119 | return (mag_mesh, phase_mesh), axs 120 | 121 | 122 | def plot_model_output(feature_maps, metadata=None, unwrap=True, 123 | batch_start_idx=0, output_dir_path=None, close_after_saving=True): 124 | 125 | output_dir_path = Path(output_dir_path) 126 | os.makedirs(output_dir_path, exist_ok=True) 127 | 128 | batch_size = feature_maps["stft"].shape[0] 129 | 130 | 131 | stft_output = feature_maps["stft"] 132 | conv_output = { 133 | feature_name: feature_map.transpose(2, 3) 134 | for feature_name, feature_map in feature_maps.items() 135 | if "conv" in feature_name 136 | } 137 | rnn_output = feature_maps["rnn"].transpose(1, 2) 138 | 139 | for i in tqdm(range(batch_size)): 140 | sample_idx = batch_size + batch_start_idx 141 | multichannel_spectrogram_filename = output_dir_path / f"{sample_idx}_multichannel_stft.png" 142 | 143 | plot_multichannel_spectrogram( 144 | stft_output[i], unwrap=unwrap, 145 | output_path=multichannel_spectrogram_filename, close=close_after_saving) 146 | 147 | for conv_id, conv_map in conv_output.items(): 148 | if conv_map.is_complex(): 149 | conv_filename = output_dir_path / f"{sample_idx}_{conv_id}.png" 150 | plot_multichannel_spectrogram( 151 | conv_map[i], unwrap=unwrap, output_path=conv_filename, mode="row", 152 | figsize=(5, 10), close=close_after_saving) 153 | else: 154 | print("real") 155 | 156 | rnn_filename = output_dir_path / f"{sample_idx}_rnn_output.png" 157 | plot_spectrogram(rnn_output[i], unwrap=unwrap, output_path=rnn_filename, close=close_after_saving) 158 | 159 | 160 | def plot_real_feature_maps(feature_maps, mode="column", axs=None, figsize=(10, 5), 161 | output_path=None, close=True): 162 | 163 | num_channels, num_freq_bins, num_time_steps = feature_maps.shape 164 | 165 | if axs is None: 166 | if mode == "column": 167 | n_rows, n_cols = (1, num_channels) 168 | elif mode == "row": 169 | n_rows, n_cols = (num_channels, 1) 170 | 171 | fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=figsize) 172 | 173 | # If torch.Tensor, move to cpu and convert to numpy 174 | if isinstance(feature_maps, torch.Tensor): 175 | multichannel_spectrogram = feature_maps.cpu().detach().numpy() 176 | 177 | # Plot spectrograms for all channels 178 | for n_channel in range(num_channels): 179 | channel_ax = axs[n_channel] 180 | 181 | plot_real_feature_map(multichannel_spectrogram[n_channel], ax=channel_ax) 182 | 183 | if output_path is not None: 184 | plt.savefig(output_path) 185 | if close: 186 | plt.close() 187 | 188 | return axs 189 | 190 | 191 | def plot_real_feature_map(feature_map, mode="column", figsize=(10, 5), ax=None, output_path=None, close=True): 192 | if ax is None: 193 | fig, ax = plt.subplots(figsize=figsize) 194 | 195 | # If torch.Tensor, move to cpu and convert to numpy 196 | if isinstance(feature_map, torch.Tensor): 197 | feature_map = feature_map.cpu().detach().numpy() 198 | 199 | 200 | librosa.display.specshow(feature_map, ax=ax) 201 | 202 | if output_path is not None: 203 | plt.savefig(output_path) 204 | if close: 205 | plt.close() 206 | 207 | return ax 208 | 209 | 210 | class ConvolutionalFeatureMapLogger: 211 | def __init__(self, model, trainer): 212 | # 1. Find convolutional layers in the model 213 | self.conv_layers = get_all_layers(model, [ConvBlock]) 214 | # 2. Variable to store the output feature maps produced in a forward pass 215 | self.feature_maps = {} 216 | # 3. Create a forward hook to fill the variable above at every pass 217 | 218 | for layer_id, layer in self.conv_layers.items(): 219 | fn = self._create_hook(layer_id) 220 | layer.register_forward_hook(fn) 221 | 222 | self.trainer = trainer 223 | 224 | def log(self): 225 | n_epoch = self.trainer.current_epoch 226 | for layer, feature_maps in self.feature_maps.items(): 227 | batch_sample_idx = 0 # Always select first example on batch 228 | feature_maps = feature_maps[batch_sample_idx] 229 | 230 | # Transform grayscale to RGB image 231 | # Make R and G channels 0 so we get a nice blue picture 232 | # Transpose time and frequency channels to get the format 233 | # B x C x H x W required by torchvision's "make_grid" function 234 | feature_maps = feature_maps.unsqueeze(1).repeat([1, 3, 1, 1]) 235 | feature_maps[:, 0:2, :, :] = 0 236 | 237 | if feature_maps.dtype == torch.complex64: 238 | feature_maps_mag = feature_maps.abs() 239 | feature_maps_phase = feature_maps.angle() 240 | # TODO: Phase unwrapping 241 | 242 | feature_maps_mag = make_grid(feature_maps_mag, normalize=True, padding=5) 243 | self.trainer.logger.experiment.add_image(f"{layer}.mag.epoch{n_epoch}", feature_maps_mag) 244 | feature_maps_phase = make_grid(feature_maps_phase, normalize=True, padding=5) 245 | self.trainer.logger.experiment.add_image(f"{layer}.phase.epoch{n_epoch}", feature_maps_phase) 246 | else: 247 | feature_maps = make_grid(feature_maps, normalize=True, padding=5) 248 | 249 | self.trainer.logger.experiment.add_image(f"{layer}.epoch{n_epoch}", feature_maps) 250 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model 3 | - training 4 | - dataset: dcase_2019_task3_dataset -------------------------------------------------------------------------------- /config/dataset/dcase_2019_task3_dataset.yaml: -------------------------------------------------------------------------------- 1 | sr: 24000 2 | train_wav_path: /Users/ezajlerg/datasets/dcase_2019_task_3/mic_dev_splitted 3 | validation_wav_path: /Users/ezajlerg/datasets/dcase_2019_task_3/mic_dev_splitted 4 | test_wav_path: /Users/ezajlerg/datasets/dcase_2019_task_3/mic_eval_splitted 5 | train_csv_path: /Users/ezajlerg/datasets/dcase_2019_task_3/train_annotations.csv 6 | validation_csv_path: /Users/ezajlerg/datasets/dcase_2019_task_3/validation_annotations.csv 7 | test_csv_path: /Users/ezajlerg/datasets/dcase_2019_task_3/test_annotations.csv 8 | test_wavs_path: /Users/ezajlerg/datasets/dcase_2019_task_3/mic_eval_splitted 9 | num_mics: 4 10 | sample_duration_in_seconds: 1 11 | n_max_sources: 1 12 | name: "dcase_2019_task3" -------------------------------------------------------------------------------- /config/model.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | # Input params 3 | n_input_channels: 4 4 | n_fft: 1024 5 | window: hann 6 | feature_type: stft # stft | cross_spectra 7 | use_onesided_fft: True 8 | 9 | # Output and optimization settings 10 | output_type: scalar # scalar | frame 11 | target: source_coordinates # model target: azimuth_2d_point | normalized_tdoa | source_coordinates 12 | loss: angular # angular | magnitude | l1 13 | 14 | # Convolutional layer settings 15 | pool_size: [2, 2] # [1, 1] = No pooling | [1, 2] = Frequency pooling | [2, 2] = Time and frequency pooling 16 | kernel_size: [2, 2] # [1, 1] = "Beamforming" convolution | [2, 1] = Time averaging | [1, 2] = Frequency averaging 17 | conv_layers_config: 18 | - layer_1: 19 | type: complex_single # complex_single | real_single | complex_double | real_double 20 | n_channels: 64 21 | dropout_rate: 0.0 22 | - layer_2: 23 | type: complex_single 24 | n_channels: 128 25 | dropout_rate: 0.0 26 | - layer_3: 27 | type: complex_single 28 | n_channels: 256 29 | dropout_rate: 0.0 30 | - layer_4: 31 | type: complex_single 32 | n_channels: 512 33 | dropout_rate: 0.0 34 | 35 | # Miscellaneous model settings 36 | fc_layer_dropout_rate: 0.0 37 | activation: prelu # tanh | relu | prelu | amp_tanh 38 | pool_type: avg # max | avg 39 | use_complex_rnn: True 40 | init_real_layers: True 41 | -------------------------------------------------------------------------------- /config/training.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | batch_size: 32 3 | n_epochs: 20 4 | learning_rate: 0.00005 5 | learning_rate_decay_steps: [5, 10] 6 | learning_rate_decay_values: 0.5 7 | n_workers: 4 8 | delete_datasets_after_training: False 9 | -------------------------------------------------------------------------------- /notebooks/Test Model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 22, 6 | "id": "aa124043", 7 | "metadata": { 8 | "scrolled": true 9 | }, 10 | "outputs": [ 11 | { 12 | "name": "stderr", 13 | "output_type": "stream", 14 | "text": [ 15 | "\r", 16 | " 0%| | 0/500 [00:00" 88 | ] 89 | }, 90 | "metadata": { 91 | "needs_background": "light" 92 | }, 93 | "output_type": "display_data" 94 | } 95 | ], 96 | "source": [ 97 | "import numpy as np\n", 98 | "import matplotlib.pyplot as plt\n", 99 | "\n", 100 | "from tqdm import tqdm, trange\n", 101 | "\n", 102 | "eps = np.finfo(np.float32).eps\n", 103 | "\n", 104 | "\n", 105 | "def error(y_gt, y_pred):\n", 106 | " return np.sqrt((y_gt - y_pred)**2)\n", 107 | "\n", 108 | " \n", 109 | "def eval_dataset(model, dataset, sort=True, plot_results=True):\n", 110 | "\n", 111 | " output_dicts = []\n", 112 | " \n", 113 | " for sample in tqdm(dataset):\n", 114 | " x, y = sample\n", 115 | " prediction = float(eval_model(model, x.unsqueeze(0)).numpy())\n", 116 | " gt = y[\"tdoa\"]\n", 117 | " error_val = error(gt, prediction)\n", 118 | " #print(f\"ground_truth: {gt} prediction: {prediction}, error: {error_val}\")\n", 119 | " \n", 120 | " y[\"prediction\"] = prediction\n", 121 | " y[\"error\"] = error_val\n", 122 | " y[\"signals\"] = x\n", 123 | " y[\"energy\"] = 10*torch.log10(torch.sum(x[1]**2)/torch.sum(x[0]**2))\n", 124 | " output_dicts.append(y)\n", 125 | " \n", 126 | " if plot_results:\n", 127 | " plot_sample_results(y)\n", 128 | " \n", 129 | " if sort: \n", 130 | " output_dicts.sort(key=lambda x: x[\"error\"], reverse=True)\n", 131 | " \n", 132 | " print_dataset_stats(output_dicts) \n", 133 | " \n", 134 | " return output_dicts\n", 135 | "\n", 136 | "def plot_sample_results(d):\n", 137 | " delay_in_ms = d['mic_delays'][1]*1000\n", 138 | " print(f\"Prediction {d['prediction']}, Error: {d['error']}\")\n", 139 | " plot_mics_and_sources(d[\"room_dims\"], d[\"mic_coordinates\"], [d[\"source_coordinates\"]])\n", 140 | " \n", 141 | " plt.show()\n", 142 | "\n", 143 | " #plot_microphone_signals(d[\"signals\"], d, share_axis=False)\n", 144 | " \n", 145 | " plt.show()\n", 146 | " \n", 147 | " \n", 148 | "def print_dataset_stats(output_dicts):\n", 149 | " errors = [d[\"error\"] for d in output_dicts]\n", 150 | " avg_error = np.mean(errors)\n", 151 | " std_error = np.std(errors)\n", 152 | " print(f\"Error mean: {avg_error} Error std: {std_error}\")\n", 153 | " \n", 154 | "\n", 155 | "def filter_points(x, y, low, high):\n", 156 | " zipped = zip(x, y)\n", 157 | " \n", 158 | " filtered = filter(lambda x: x[0] >= low and x[0] < high, zipped)\n", 159 | " \n", 160 | " x, y = zip(*filtered)\n", 161 | "\n", 162 | " return x, y\n", 163 | "\n", 164 | "\n", 165 | "def denormalize(normalized, min_v, max_v):\n", 166 | " return (normalized * (max_v-min_v) + min_v)\n", 167 | "\n", 168 | "\n", 169 | "def scatter_prediction_vs_gt(output_dicts, label, plot_line=True, ax=None):\n", 170 | " if ax is None:\n", 171 | " fig, ax = plt.subplots()\n", 172 | " \n", 173 | " predictions = [d[\"prediction\"] for d in output_dicts]\n", 174 | " targets = [d[\"tdoa\"] for d in output_dicts]\n", 175 | " \n", 176 | " predictions = np.array([1000*p for p in predictions]) # convert to milliseconds\n", 177 | " targets = np.array([1000*t for t in targets])\n", 178 | " \n", 179 | " coef = np.polyfit(np.float32(targets),predictions,1, rcond = len(targets)*eps)\n", 180 | " poly1d_fn = np.poly1d(coef)\n", 181 | " ax.scatter(targets, predictions, label=label)\n", 182 | " #plt.plot(targets, poly1d_fn(targets), '--g', label=\"regression\")\n", 183 | " \n", 184 | " plt.xlabel(\"True TDOA (ms)\")\n", 185 | " plt.ylabel(\"Estimated TDOA (ms)\")\n", 186 | " \n", 187 | " if plot_line:\n", 188 | " max_tdoa = 1/343\n", 189 | " min_tdoa = - max_tdoa\n", 190 | " expected_line = np.linspace(1000*min_tdoa, 1000*max_tdoa, len(predictions))\n", 191 | " ax.plot(expected_line, expected_line, label=\"Ground truth\", color=\"lime\")\n", 192 | " ax.legend()\n", 193 | " #plt.title(label)\n", 194 | " return ax\n", 195 | "\n", 196 | " \n", 197 | "\n", 198 | "def scatter_energy_vs_prediction(output_dicts, ax=None):\n", 199 | " if ax is None:\n", 200 | " fig, ax = plt.subplots()\n", 201 | " \n", 202 | " predictions = [d[\"prediction\"] for d in output_dicts]\n", 203 | " log_energy_ratios = [d[\"energy\"] for d in output_dicts] #multiply by 10 to be dB\n", 204 | " \n", 205 | " predictions = np.array([1000*p for p in predictions]) # convert to milliseconds\n", 206 | " \n", 207 | " log_energy_ratios, predictions = filter_points(log_energy_ratios, predictions, -2, 2)\n", 208 | " \n", 209 | " ax.scatter(log_energy_ratios, predictions, color=\"orange\")\n", 210 | " \n", 211 | " ax.set_xlabel(\"Log energy ratio (dB)\")\n", 212 | " ax.set_ylabel(\"Estimated TDOA (ms)\")\n", 213 | "\n", 214 | " ax.legend()\n", 215 | " \n", 216 | " return ax\n", 217 | "\n", 218 | "#output_dicts2 = eval_dataset(model, dataset, sort=True)\n", 219 | "\n", 220 | "fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(5, 4))\n", 221 | "\n", 222 | "scatter_prediction_vs_gt(output_dicts, \"No gain\", ax=axes[1])\n", 223 | "scatter_prediction_vs_gt(output_dicts2, \"Gain=0.5 on first mic.\", False, ax=axes[1])\n", 224 | "\n", 225 | "scatter_energy_vs_prediction(output_dicts, axes[0])\n", 226 | "# scatter_energy_vs_prediction(output_dicts, axes[1])\n", 227 | "\n", 228 | "plt.tight_layout()\n", 229 | "plt.show()" 230 | ] 231 | } 232 | ], 233 | "metadata": { 234 | "kernelspec": { 235 | "display_name": "Python 3 (ipykernel)", 236 | "language": "python", 237 | "name": "python3" 238 | }, 239 | "language_info": { 240 | "codemirror_mode": { 241 | "name": "ipython", 242 | "version": 3 243 | }, 244 | "file_extension": ".py", 245 | "mimetype": "text/x-python", 246 | "name": "python", 247 | "nbconvert_exporter": "python", 248 | "pygments_lexer": "ipython3", 249 | "version": "3.9.7" 250 | } 251 | }, 252 | "nbformat": 4, 253 | "nbformat_minor": 5 254 | } 255 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.1.0 2 | librosa==0.8.1 3 | matplotlib==3.4.2 4 | pandas==1.2.4 5 | pyroomasync==0.0.23 6 | pytorch-lightning==1.4.2 7 | scikit-learn==0.24.2 8 | torch==1.9.0 9 | torchaudio==0.9.0 10 | tqdm==4.61.0 -------------------------------------------------------------------------------- /tdoa/correlation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.fft as fft 3 | from scipy.signal.signaltools import correlate, correlation_lags 4 | 5 | 6 | def compute_correlations(simulation_results, fs, mode="gcc-phat"): 7 | n_microphones = simulation_results.shape[0] 8 | 9 | correlations = {} 10 | for n_microphone_1 in range(n_microphones): 11 | for n_microphone_2 in range(n_microphone_1 + 1, n_microphones): 12 | key = (n_microphone_1, n_microphone_2) 13 | cc, lag_indexes = _compute_correlation( 14 | simulation_results[n_microphone_1], 15 | simulation_results[n_microphone_2], 16 | fs, 17 | mode=mode 18 | ) 19 | tdoa = lag_indexes[np.argmax(np.abs(cc))] 20 | correlations[key] = tdoa, cc 21 | 22 | return correlations 23 | 24 | 25 | 26 | def _compute_correlation(x1, x2, fs, mode="gcc-phat"): 27 | if mode == "gcc-phat": 28 | cc, lag_indexes = gcc_phat(x1, x2, fs) 29 | else: 30 | cc, lag_indexes = temporal_cross_correlation(x1, x2, fs) 31 | 32 | return cc, lag_indexes 33 | 34 | 35 | def temporal_cross_correlation(x1, x2, fs): 36 | 37 | # Normalize signals for a normalized correlation 38 | # https://github.com/numpy/numpy/issues/2310 39 | x1 = (x1 - np.mean(x1)) / (np.std(x1) * len(x1)) 40 | x2 = (x2 - np.mean(x2)) / (np.std(x2) * len(x2)) 41 | 42 | cc = correlate(x1, x2, mode="same") 43 | lag_indexes = correlation_lags(x1.shape[0], x2.shape[0], mode="same") 44 | 45 | cc = np.abs(cc) 46 | 47 | return cc, lag_indexes/fs 48 | 49 | 50 | def gcc_phat(x1, x2, fs): 51 | ''' 52 | This function computes the offset between the signal sig and the reference signal x2 53 | using the Generalized Cross Correlation - Phase Transform (GCC-PHAT) method. 54 | Implementation based on http://www.xavieranguera.com/phdthesis/node92.html 55 | ''' 56 | 57 | n = x1.shape[0] # + x2.shape[0] 58 | 59 | X1 = np.fft.rfft(x1, n=n) 60 | X2 = np.fft.rfft(x2, n=n) 61 | R = X1 * np.conj(X2) 62 | Gphat = R / np.abs(R) 63 | cc = np.fft.irfft(Gphat, n=n) 64 | 65 | max_shift = n // 2 66 | 67 | cc = np.concatenate((cc[-max_shift:], cc[:max_shift+1])) 68 | 69 | indxs = np.zeros_like(cc) 70 | indxs[0:max_shift] = - np.arange(max_shift, 0, -1) 71 | indxs[max_shift:] = np.arange(0, max_shift + 1) 72 | indxs = indxs/fs 73 | 74 | return cc, indxs -------------------------------------------------------------------------------- /tdoa/math_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.fft 4 | 5 | SPEED_OF_SOUND = 343.0 6 | 7 | 8 | def compute_distance(p1, p2, mode="numpy"): 9 | "Compute the euclidean distance between two points" 10 | 11 | if mode == "numpy": 12 | p1 = np.array(p1) 13 | p2 = np.array(p2) 14 | return np.linalg.norm(p1 - p2) 15 | elif mode == "torch": 16 | return torch.linalg.norm(p1 - p2) 17 | 18 | 19 | def compute_tdoa(source, microphones): 20 | dist_0 = compute_distance(source, microphones[0]) 21 | dist_1 = compute_distance(source, microphones[1]) 22 | 23 | return (dist_0 - dist_1)/SPEED_OF_SOUND 24 | 25 | 26 | def normalize_tdoa(tdoa, mic_distance): 27 | max_tdoa = mic_distance/SPEED_OF_SOUND 28 | min_tdoa = -max_tdoa 29 | 30 | return normalize(tdoa, min_tdoa, max_tdoa) 31 | 32 | 33 | def normalize(x, min_x, max_x): 34 | return (x - min_x)/(max_x - min_x) 35 | 36 | 37 | def denormalize(x, min_x, max_x): 38 | return x*(max_x - min_x) + min_x 39 | 40 | 41 | def gcc_phat(x1, x2, fs): 42 | """ 43 | This function computes the offset between the signal sig and the reference signal x2 44 | using the Generalized Cross Correlation - Phase Transform (GCC-PHAT) method. 45 | Implementation based on http://www.xavieranguera.com/phdthesis/node92.html 46 | """ 47 | 48 | n = x1.shape[0] # + x2.shape[0] 49 | 50 | X1 = torch.fft.rfft(x1, n=n) 51 | X2 = torch.fft.rfft(x2, n=n) 52 | R = X1 * torch.conj(X2) 53 | Gphat = R / torch.abs(R) 54 | cc = torch.fft.irfft(Gphat, n=n) 55 | 56 | max_shift = n // 2 57 | 58 | cc = torch.cat((cc[-max_shift:], cc[:max_shift+1])) 59 | 60 | indxs = torch.zeros_like(cc) 61 | indxs[0:max_shift] = - torch.arange(max_shift, 0, -1) 62 | indxs[max_shift:] = torch.arange(0, max_shift + 1) 63 | indxs = indxs/fs 64 | 65 | return cc, indxs 66 | 67 | 68 | def compute_doa(m1, m2, s, radians=True): 69 | """Get the direction of arrival between two microphones and a source. 70 | The referential used is the direction of the two sources, that is, 71 | the vector m1 - m2. 72 | 73 | For more details, see: 74 | https://math.stackexchange.com/questions/878785/how-to-find-an-angle-in-range0-360-between-2-vectors/879474 75 | 76 | Args: 77 | m1 (np.array): 2d coordinates of microphone 1 78 | m2 (np.array): 2d coordinates of microphone 2 79 | s (np.array): 2d coordinates of the source 80 | radians (bool): If True, result is between [-pi, pi). Else, result is between [0, 360) 81 | """ 82 | 83 | reference_direction = m1 - m2 84 | mic_centre = (m1 + m2)/2 85 | source_direction = s - mic_centre 86 | 87 | doa = compute_angle_between_vectors(reference_direction, source_direction, 88 | radians=radians) 89 | 90 | return doa 91 | 92 | 93 | def compute_angle_between_vectors(v1, v2, radians=True): 94 | dot = np.dot(v1, v2) 95 | det = np.linalg.det([v1, v2]) 96 | 97 | doa = np.arctan2(det, dot) 98 | 99 | if not radians: 100 | doa = np.rad2deg(doa) 101 | 102 | return doa 103 | -------------------------------------------------------------------------------- /tdoa/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tdoa.math_utils import compute_distance 4 | 5 | 6 | def compute_error(candidate, mic_1, mic_2, tdoa, norm="l1"): 7 | "Get a measure of how far a candidate point (x, y) is from a computed doa" 8 | 9 | dist_1 = compute_distance(candidate, mic_1) 10 | dist_2 = compute_distance(candidate, mic_2) 11 | 12 | error = tdoa - np.abs(dist_1 - dist_2) 13 | 14 | if norm == "l2": 15 | error = error**2 16 | elif norm == "l1": 17 | error = np.abs(error) 18 | 19 | return error 20 | -------------------------------------------------------------------------------- /tdoa/srp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pyroomacoustics as pra 3 | 4 | from pyroomacoustics.doa.srp import SRP 5 | 6 | from pyroomasync.settings import SPEED_OF_SOUND 7 | 8 | NFFT = 1024 9 | 10 | 11 | def srp_phat(simulation_results, room, nfft=NFFT): 12 | 13 | microphone_positions = np.array(room.microphones.get_positions()).T 14 | 15 | sources = room.sources 16 | fs = room.fs 17 | 18 | srp = SRP( 19 | microphone_positions, 20 | fs, 21 | nfft, 22 | c=SPEED_OF_SOUND, 23 | num_src=len(sources), 24 | mode='near' 25 | ) 26 | 27 | # Compute the STFT frames needed 28 | simulation_results_stft = np.array( 29 | [ 30 | pra.transform.stft.analysis(signal, nfft, nfft // 2).T 31 | for signal in simulation_results 32 | ] 33 | ) 34 | 35 | srp.locate_sources(simulation_results_stft) 36 | 37 | return srp 38 | -------------------------------------------------------------------------------- /tdoa/tdoa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tdoa.metrics import compute_error 4 | from tdoa.correlation import compute_correlations 5 | 6 | 7 | def compute_tdoas(simulation_results, simulation_fs, mode="gcc-phat"): 8 | correlations = compute_correlations( 9 | simulation_results, simulation_fs, mode=mode) 10 | tdoas = { 11 | key: value[0] 12 | for key, value in correlations.items() 13 | } 14 | 15 | return tdoas 16 | 17 | 18 | def get_top_candidate_points( 19 | mic_1, mic_2, room_dims, tdoa, norm="l2", step=0.1, n_candidates=200): 20 | candidates = [] 21 | 22 | for x in np.arange(0, room_dims[0], step): 23 | for y in np.arange(0, room_dims[1], step): 24 | error = compute_error((x, y), mic_1, mic_2, tdoa, norm=norm) 25 | candidates.append((x, y, error)) 26 | 27 | candidates = np.array(candidates) 28 | sorted_candidates = candidates[candidates[:, 2].argsort()] 29 | top_candidates = sorted_candidates[0:n_candidates] 30 | 31 | return top_candidates 32 | 33 | 34 | def tdoa_sum_error_grid(room, microphone_distances, norm="l2", grid_step=0.1): 35 | room_dims = room.dims[0:2] 36 | mics = [ 37 | mic.loc[0:2] for mic in 38 | room.microphones.mic_array 39 | ] 40 | 41 | grids = [] 42 | for mic_ixs, distance in microphone_distances.items(): 43 | mic_1 = mics[mic_ixs[0]] 44 | mic_2 = mics[mic_ixs[1]] 45 | 46 | grid = _tdoa_error_grid(mic_1, mic_2, room_dims, distance, 47 | norm=norm, grid_step=grid_step) 48 | grids.append(grid) 49 | 50 | return np.sum(grids, 0) 51 | 52 | 53 | def _tdoa_error_grid(mic_1, mic_2, room_dims, target_tdoa, 54 | norm="l2", grid_step=0.1): 55 | 56 | num_x_points = int(room_dims[0]/grid_step) 57 | num_y_points = int(room_dims[1]/grid_step) 58 | 59 | x_points = np.arange(0, room_dims[0], grid_step) 60 | y_points = np.arange(0, room_dims[1], grid_step) 61 | 62 | grid = np.zeros((num_x_points, num_y_points)) 63 | 64 | for ix, x in enumerate(x_points): 65 | for iy, y in enumerate(y_points): 66 | error = compute_error((x, y), mic_1, mic_2, target_tdoa, norm=norm) 67 | grid[ix, iy] = error 68 | 69 | return grid 70 | 71 | -------------------------------------------------------------------------------- /tdoa/visualization.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import seaborn as sns 4 | 5 | from tdoa.tdoa import get_top_candidate_points, tdoa_sum_error_grid 6 | 7 | from tdoa.math_utils import compute_distance 8 | 9 | def plot_correlations(correlations, output_path=None): 10 | for key, value in correlations.items(): 11 | plt.plot(value[0], value[1], label=key) 12 | 13 | plt.legend() 14 | if output_path is not None: 15 | plt.savefig(output_path) 16 | else: 17 | plt.show() 18 | 19 | 20 | def plot_location_candidates(room, microphone_distances, output_path=None, grid_step=0.1): 21 | room_dims = room.dims 22 | mics = room.microphones.mic_array 23 | sources = room.sources.source_array 24 | 25 | sns.set_theme() 26 | 27 | error_grid = tdoa_sum_error_grid(room, microphone_distances, grid_step=grid_step) 28 | 29 | x_points = np.arange(0, room_dims[0], grid_step) 30 | y_points = np.arange(0, room_dims[1], grid_step) 31 | ax = sns.heatmap(error_grid) #, xticklabels=x_points, yticklabels=y_points) 32 | 33 | 34 | ax = draw_mics_and_sources(ax, 35 | mics, 36 | sources, 37 | x_max=len(x_points), y_max=len(y_points)) 38 | 39 | if output_path is not None: 40 | plt.savefig(output_path) 41 | else: 42 | plt.show() 43 | 44 | 45 | def plot_top_candidate_points(room, tdoas, output_path=None): 46 | ax = get_2d_room_plot_axis(room, plot_mics_and_sources=False) 47 | _plot_top_candidates(room, tdoas, ax) 48 | 49 | room_dims = room.dims 50 | mics = room.microphones.mic_array 51 | sources = room.sources.source_array 52 | ax = draw_mics_and_sources(ax, mics, sources) 53 | 54 | if output_path is not None: 55 | plt.savefig(output_path) 56 | else: 57 | plt.show() 58 | 59 | 60 | def get_2d_room_plot_axis(room, plot_mics_and_sources=True, 61 | plot_distances=True): 62 | plt.figure() 63 | ax = plt.gca() 64 | plt.xlim(0, room.dims[0]) 65 | plt.ylim(0, room.dims[1]) 66 | 67 | if plot_mics_and_sources: 68 | room_dims = room.dims 69 | mics = room.microphones.mic_array 70 | sources = room.sources.source_array 71 | ax = draw_mics_and_sources(ax, mics, sources) 72 | 73 | if plot_distances: 74 | _plot_source_to_microphone_distances(room, plt.gca()) 75 | 76 | return ax 77 | 78 | 79 | def plot_mics_and_sources(room_dims, mics, sources): 80 | plt.figure() 81 | ax = plt.gca() 82 | plt.xlim(0, room_dims[0]) 83 | plt.ylim(0, room_dims[1]) 84 | 85 | draw_mics_and_sources(ax, mics, sources) 86 | _plot_source_to_microphone_distances(ax, mics, sources) 87 | 88 | plt.xlabel("Width (m)") 89 | plt.ylabel("Length (m)") 90 | return ax 91 | 92 | 93 | def draw_mics_and_sources(ax, mics, sources, x_max=None, y_max=None): 94 | "Draw microphones and sources in an existing room" 95 | 96 | mics_x = [mic[0] for mic in mics] 97 | mics_y = [mic[1] for mic in mics] 98 | sources_x = [source[0] for source in sources] 99 | sources_y = [source[1] for source in sources] 100 | 101 | ax.scatter(mics_x, mics_y, marker="^", label="microphones") 102 | ax.scatter(sources_x, sources_y, marker="o", label="sources") 103 | ax.legend() 104 | ax.grid() 105 | 106 | return ax 107 | 108 | 109 | def _plot_source_to_microphone_distances(ax, mics, sources): 110 | 111 | distance = compute_distance(mics[0], mics[1]) 112 | ax.plot( 113 | [mics[0][0], mics[1][0]], 114 | [mics[0][1], mics[1][1]], 115 | "--", color="blue", 116 | label="distance={:.2f}m".format(distance) 117 | ) 118 | for source in sources: 119 | for mic in mics: 120 | distance = compute_distance(source, mic) 121 | ax.plot( 122 | [source[0], mic[0]], 123 | [source[1], mic[1]], 124 | "--", color="grey", 125 | label="distance={:.2f}m".format(distance) 126 | ) 127 | 128 | ax.legend() 129 | 130 | return ax 131 | 132 | 133 | def _plot_top_candidates(room, microphone_distances, ax): 134 | # Filter only first and second dimensions 135 | room_dims = room.dims[0:2] 136 | mics = [ 137 | mic.loc[0:2] for mic in 138 | room.microphones.mic_array 139 | ] 140 | 141 | for mic_ixs, distance in microphone_distances.items(): 142 | mic_1 = mics[mic_ixs[0]] 143 | mic_2 = mics[mic_ixs[1]] 144 | 145 | candidates = get_top_candidate_points( 146 | mic_1, mic_2, room_dims, distance 147 | ) 148 | candidates_x = [candidate[0] for candidate in candidates] 149 | candidates_y = [candidate[1] for candidate in candidates] 150 | 151 | ax.scatter(candidates_x, candidates_y, label="candidates") 152 | 153 | return ax 154 | -------------------------------------------------------------------------------- /tests/.gitignore: -------------------------------------------------------------------------------- 1 | fixtures/csv_frames 2 | fixtures/splitted_sample -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SOUNDS-RESEARCH/complex_neural_source_localization/d3b4e3ca1ea216f4515abb011f02809183e375c9/tests/__init__.py -------------------------------------------------------------------------------- /tests/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SOUNDS-RESEARCH/complex_neural_source_localization/d3b4e3ca1ea216f4515abb011f02809183e375c9/tests/baselines/__init__.py -------------------------------------------------------------------------------- /tests/baselines/test_crnns.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | 5 | from complex_neural_source_localization.model import DOACNet 6 | 7 | 8 | def test_crnn10(): 9 | input_file_name = "tests/fixtures/0.0_split1_ir0_ov1_3.wav" 10 | 11 | signal = librosa.load(input_file_name, sr=24000, 12 | mono=False, dtype=np.float32)[0][np.newaxis] 13 | 14 | signal = torch.Tensor(signal) 15 | model = DOACNet() 16 | 17 | result = model(signal) 18 | -------------------------------------------------------------------------------- /tests/fixtures/0.0_split1_ir0_ov1_3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SOUNDS-RESEARCH/complex_neural_source_localization/d3b4e3ca1ea216f4515abb011f02809183e375c9/tests/fixtures/0.0_split1_ir0_ov1_3.wav -------------------------------------------------------------------------------- /tests/fixtures/fold1_room1_mix001.csv: -------------------------------------------------------------------------------- 1 | 12,3,0,-80,40 2 | 13,3,0,-80,40 3 | 14,3,0,-80,40 4 | 15,3,0,-80,40 5 | 16,3,0,-80,40 6 | 17,3,0,-80,40 7 | 18,3,0,-80,40 8 | 19,3,0,-80,40 9 | 20,3,0,-80,40 10 | 21,3,0,-80,40 11 | 22,3,0,-80,40 12 | 23,3,0,-80,40 13 | 24,3,0,-80,40 14 | 25,3,0,-80,40 15 | 26,3,0,-80,40 16 | 27,3,0,-80,40 17 | 28,3,0,-80,40 18 | 29,3,0,-80,40 19 | 44,4,1,-179,20 20 | 45,4,1,-179,20 21 | 46,4,1,-179,20 22 | 47,4,1,-179,20 23 | 48,4,1,-179,20 24 | 49,4,1,-179,20 25 | 50,4,1,-179,20 26 | 51,4,1,-179,20 27 | 51,5,2,-54,-10 28 | 52,5,2,-54,-10 29 | 53,5,2,-54,-10 30 | 54,5,2,-54,-10 31 | 55,5,2,-54,-10 32 | 56,5,2,-54,-10 33 | 57,5,2,-54,-10 34 | 58,5,2,-54,-10 35 | 59,5,2,-54,-10 36 | 60,5,2,-54,-10 37 | 61,5,2,-54,-10 38 | 62,5,2,-54,-10 39 | 63,5,2,-54,-10 40 | 64,5,2,-54,-10 41 | 65,5,2,-54,-10 42 | 65,8,3,4,-30 43 | 66,5,2,-54,-10 44 | 66,8,3,4,-30 45 | 66,1,4,95,-10 46 | 67,5,2,-54,-10 47 | 67,8,3,4,-30 48 | 67,1,4,97,-10 49 | 68,5,2,-54,-10 50 | 68,8,3,4,-30 51 | 68,1,4,99,-10 52 | 69,5,2,-54,-10 53 | 69,8,3,4,-30 54 | 69,1,4,101,-10 55 | 70,5,2,-54,-10 56 | 70,8,3,4,-30 57 | 70,1,4,103,-10 58 | 71,5,2,-54,-10 59 | 71,8,3,4,-30 60 | 71,1,4,105,-10 61 | 72,8,3,4,-30 62 | 72,1,4,107,-10 63 | 73,8,3,4,-30 64 | 73,1,4,109,-10 65 | 74,8,3,4,-30 66 | 75,8,3,4,-30 67 | 76,8,3,4,-30 68 | 76,1,4,115,-10 69 | 77,8,3,4,-30 70 | 77,1,4,117,-10 71 | 78,8,3,4,-30 72 | 78,1,4,119,-10 73 | 79,8,3,4,-30 74 | 79,1,4,121,-10 75 | 80,8,3,4,-30 76 | 80,1,4,123,-10 77 | 81,8,3,4,-30 78 | 81,1,4,125,-10 79 | 82,8,3,4,-30 80 | 82,1,4,127,-10 81 | 83,1,4,129,-10 82 | 84,1,4,131,-10 83 | 85,1,4,133,-10 84 | 88,1,4,139,-10 85 | 89,1,4,141,-10 86 | 90,1,4,143,-10 87 | 91,1,4,145,-10 88 | 92,1,4,147,-10 89 | 93,1,4,149,-10 90 | 94,1,4,151,-10 91 | 98,1,4,159,-10 92 | 99,1,4,161,-10 93 | 100,1,4,163,-10 94 | 101,1,4,165,-10 95 | 102,1,4,167,-10 96 | 103,1,4,169,-10 97 | 104,1,4,171,-10 98 | 105,1,4,173,-10 99 | 106,1,4,175,-10 100 | 107,1,4,177,-10 101 | 108,1,4,179,-10 102 | 108,0,5,-36,0 103 | 109,1,4,180,-20 104 | 109,0,5,-36,0 105 | 110,1,4,178,-20 106 | 110,0,5,-36,0 107 | 111,1,4,176,-20 108 | 111,0,5,-36,0 109 | 112,1,4,174,-20 110 | 112,0,5,-36,0 111 | 113,1,4,172,-20 112 | 113,0,5,-36,0 113 | 114,1,4,170,-20 114 | 114,0,5,-36,0 115 | 115,1,4,168,-20 116 | 115,0,5,-36,0 117 | 116,1,4,166,-20 118 | 116,0,5,-36,0 119 | 117,1,4,164,-20 120 | 117,0,5,-36,0 121 | 118,1,4,162,-20 122 | 118,0,5,-36,0 123 | 119,1,4,160,-20 124 | 119,0,5,-36,0 125 | 120,1,4,158,-20 126 | 120,0,5,-36,0 127 | 121,1,4,156,-20 128 | 121,0,5,-36,0 129 | 122,0,5,-36,0 130 | 123,0,5,-36,0 131 | 124,1,4,150,-20 132 | 124,0,5,-36,0 133 | 125,1,4,148,-20 134 | 125,0,5,-36,0 135 | 126,1,4,146,-20 136 | 126,0,5,-36,0 137 | 127,1,4,144,-20 138 | 127,0,5,-36,0 139 | 128,1,4,142,-20 140 | 128,0,5,-36,0 141 | 129,1,4,140,-20 142 | 129,0,5,-36,0 143 | 130,1,4,138,-20 144 | 130,0,5,-36,0 145 | 131,1,4,136,-20 146 | 131,0,5,-36,0 147 | 132,1,4,134,-20 148 | 132,0,5,-36,0 149 | 133,1,4,132,-20 150 | 133,0,5,-36,0 151 | 134,1,4,130,-20 152 | 134,0,5,-36,0 153 | 135,1,4,128,-20 154 | 135,0,5,-36,0 155 | 136,1,4,126,-20 156 | 136,0,5,-36,0 157 | 137,1,4,124,-20 158 | 137,0,5,-36,0 159 | 138,1,4,122,-20 160 | 138,0,5,-36,0 161 | 139,1,4,120,-20 162 | 139,0,5,-36,0 163 | 140,1,4,118,-20 164 | 140,0,5,-36,0 165 | 140,0,6,79,10 166 | 141,1,4,116,-20 167 | 141,0,5,-36,0 168 | 141,0,6,80,10 169 | 142,0,5,-36,0 170 | 142,0,6,81,10 171 | 143,0,5,-36,0 172 | 143,0,6,82,10 173 | 144,1,4,110,-20 174 | 144,0,5,-36,0 175 | 144,0,6,83,10 176 | 145,1,4,108,-20 177 | 145,0,5,-36,0 178 | 145,0,6,84,10 179 | 146,1,4,106,-20 180 | 146,0,5,-36,0 181 | 146,0,6,85,10 182 | 147,1,4,104,-20 183 | 147,0,5,-36,0 184 | 147,0,6,86,10 185 | 148,1,4,102,-20 186 | 148,0,5,-36,0 187 | 148,0,6,87,10 188 | 149,1,4,100,-20 189 | 149,0,5,-36,0 190 | 150,1,4,98,-20 191 | 150,0,5,-36,0 192 | 151,1,4,96,-20 193 | 151,0,5,-36,0 194 | 152,1,4,94,-20 195 | 152,0,5,-36,0 196 | 152,0,6,91,10 197 | 153,1,4,92,-20 198 | 153,0,5,-36,0 199 | 153,0,6,92,10 200 | 154,1,4,90,-20 201 | 154,0,5,-36,0 202 | 154,0,6,93,10 203 | 155,1,4,88,-20 204 | 155,0,5,-36,0 205 | 155,0,6,94,10 206 | 156,1,4,86,-20 207 | 156,0,5,-36,0 208 | 156,0,6,95,10 209 | 157,1,4,84,-20 210 | 157,0,5,-36,0 211 | 157,0,6,96,10 212 | 158,1,4,82,-20 213 | 158,0,5,-36,0 214 | 158,0,6,97,10 215 | 159,1,4,80,-20 216 | 159,0,5,-36,0 217 | 159,0,6,98,10 218 | 160,1,4,78,-20 219 | 160,0,5,-36,0 220 | 160,0,6,99,10 221 | 161,0,5,-36,0 222 | 161,0,6,100,10 223 | 162,1,4,74,-20 224 | 162,0,5,-36,0 225 | 163,1,4,72,-20 226 | 163,0,5,-36,0 227 | 164,1,4,70,-20 228 | 164,0,5,-36,0 229 | 164,0,6,103,10 230 | 165,1,4,68,-20 231 | 165,0,5,-36,0 232 | 165,0,6,104,10 233 | 166,1,4,66,-20 234 | 166,0,5,-36,0 235 | 166,0,6,105,10 236 | 167,1,4,64,-20 237 | 167,0,5,-36,0 238 | 167,0,6,106,10 239 | 168,1,4,62,-20 240 | 168,0,5,-36,0 241 | 168,0,6,107,10 242 | 169,1,4,60,-20 243 | 169,0,5,-36,0 244 | 169,0,6,108,10 245 | 170,1,4,58,-20 246 | 170,0,5,-36,0 247 | 170,0,6,109,10 248 | 171,1,4,56,-20 249 | 171,0,5,-36,0 250 | 171,0,6,110,10 251 | 172,1,4,54,-20 252 | 172,0,5,-36,0 253 | 172,0,6,111,10 254 | 173,1,4,52,-20 255 | 173,0,5,-36,0 256 | 174,1,4,50,-20 257 | 174,0,5,-36,0 258 | 175,1,4,48,-20 259 | 175,0,5,-36,0 260 | 175,0,6,114,10 261 | 176,1,4,46,-20 262 | 176,0,5,-36,0 263 | 176,0,6,115,10 264 | 177,1,4,44,-20 265 | 177,0,5,-36,0 266 | 177,0,6,116,10 267 | 178,1,4,42,-20 268 | 178,0,5,-36,0 269 | 178,0,6,117,10 270 | 179,1,4,40,-20 271 | 179,0,5,-36,0 272 | 179,0,6,118,10 273 | 180,1,4,38,-20 274 | 180,0,5,-36,0 275 | 180,0,6,119,10 276 | 181,0,5,-36,0 277 | 181,0,6,120,10 278 | 182,0,5,-36,0 279 | 182,0,6,121,10 280 | 183,0,5,-36,0 281 | 183,0,6,122,10 282 | 184,1,4,30,-20 283 | 184,0,5,-36,0 284 | 184,0,6,123,10 285 | 185,1,4,28,-20 286 | 185,0,5,-36,0 287 | 186,1,4,26,-20 288 | 186,0,5,-36,0 289 | 187,1,4,24,-20 290 | 187,0,5,-36,0 291 | 188,1,4,22,-20 292 | 188,0,5,-36,0 293 | 188,0,6,127,10 294 | 189,1,4,20,-20 295 | 189,0,5,-36,0 296 | 189,0,6,128,10 297 | 190,1,4,18,-20 298 | 190,0,5,-36,0 299 | 190,0,6,129,10 300 | 191,1,4,16,-20 301 | 191,0,5,-36,0 302 | 191,0,6,130,10 303 | 192,0,5,-36,0 304 | 192,0,6,131,10 305 | 193,0,5,-36,0 306 | 193,0,6,132,10 307 | 194,0,5,-36,0 308 | 194,0,6,133,10 309 | 195,0,5,-36,0 310 | 195,0,6,134,10 311 | 196,0,5,-36,0 312 | 196,0,6,135,10 313 | 197,0,5,-36,0 314 | 198,0,5,-36,0 315 | 199,0,5,-36,0 316 | 199,0,6,138,10 317 | 200,0,5,-36,0 318 | 200,0,6,139,10 319 | 201,0,5,-36,0 320 | 201,0,6,140,10 321 | 202,0,5,-36,0 322 | 202,0,6,141,10 323 | 203,0,5,-36,0 324 | 203,0,6,142,10 325 | 204,0,5,-36,0 326 | 204,0,6,143,10 327 | 205,0,5,-36,0 328 | 205,0,6,144,10 329 | 206,0,5,-36,0 330 | 206,0,6,145,10 331 | 207,0,5,-36,0 332 | 207,0,6,146,10 333 | 208,0,5,-36,0 334 | 208,0,6,147,10 335 | 209,0,5,-36,0 336 | 210,0,5,-36,0 337 | 211,0,5,-36,0 338 | 211,0,6,150,10 339 | 212,0,5,-36,0 340 | 212,0,6,151,10 341 | 213,0,5,-36,0 342 | 213,0,6,152,10 343 | 214,0,5,-36,0 344 | 214,0,6,153,10 345 | 215,0,5,-36,0 346 | 215,0,6,154,10 347 | 216,0,5,-36,0 348 | 216,0,6,155,10 349 | 217,0,5,-36,0 350 | 217,0,6,156,10 351 | 218,0,5,-36,0 352 | 218,0,6,157,10 353 | 219,0,5,-36,0 354 | 219,0,6,158,10 355 | 220,0,5,-36,0 356 | 220,0,6,159,10 357 | 221,0,5,-36,0 358 | 222,0,5,-36,0 359 | 223,0,6,162,10 360 | 224,0,6,163,10 361 | 225,0,6,164,10 362 | 226,0,6,165,10 363 | 227,0,6,166,10 364 | 228,0,6,167,10 365 | 229,0,6,168,10 366 | 230,0,6,169,10 367 | 231,0,6,170,10 368 | 233,8,7,-33,10 369 | 234,8,7,-33,10 370 | 235,0,6,174,10 371 | 235,8,7,-33,10 372 | 236,0,6,175,10 373 | 236,8,7,-33,10 374 | 237,0,6,176,10 375 | 237,8,7,-33,10 376 | 238,0,6,177,10 377 | 238,8,7,-33,10 378 | 239,0,6,178,10 379 | 239,8,7,-33,10 380 | 239,2,8,111,10 381 | 240,0,6,179,10 382 | 240,8,7,-33,10 383 | 240,2,8,110,10 384 | 241,0,6,180,10 385 | 241,8,7,-33,10 386 | 241,2,8,109,10 387 | 242,0,6,180,0 388 | 242,8,7,-33,10 389 | 242,2,8,108,10 390 | 243,0,6,179,0 391 | 243,8,7,-33,10 392 | 243,2,8,107,10 393 | 244,8,7,-33,10 394 | 244,2,8,106,10 395 | 245,8,7,-33,10 396 | 245,2,8,105,10 397 | 246,0,6,176,0 398 | 246,2,8,104,10 399 | 247,0,6,175,0 400 | 247,2,8,103,10 401 | 248,0,6,174,0 402 | 248,2,8,102,10 403 | 249,0,6,173,0 404 | 249,2,8,101,10 405 | 250,0,6,172,0 406 | 250,2,8,100,10 407 | 251,0,6,171,0 408 | 251,2,8,99,10 409 | 252,0,6,170,0 410 | 252,2,8,98,10 411 | 253,0,6,169,0 412 | 253,2,8,97,10 413 | 254,0,6,168,0 414 | 254,2,8,96,10 415 | 255,0,6,167,0 416 | 255,2,8,95,10 417 | 256,2,8,94,10 418 | 257,2,8,93,10 419 | 258,0,6,164,0 420 | 258,2,8,92,10 421 | 259,0,6,163,0 422 | 259,2,8,91,10 423 | 260,0,6,162,0 424 | 260,2,8,90,10 425 | 261,0,6,161,0 426 | 261,2,8,89,10 427 | 262,0,6,160,0 428 | 262,2,8,88,10 429 | 263,0,6,159,0 430 | 263,2,8,87,10 431 | 264,0,6,158,0 432 | 264,2,8,86,10 433 | 265,0,6,157,0 434 | 265,2,8,85,10 435 | 266,0,6,156,0 436 | 266,2,8,84,10 437 | 267,0,6,155,0 438 | 267,2,8,83,10 439 | 268,2,8,82,10 440 | 287,2,9,56,0 441 | 288,2,9,56,0 442 | 289,2,9,56,0 443 | 290,2,9,56,0 444 | 291,2,9,56,0 445 | 292,2,9,56,0 446 | 293,2,9,56,0 447 | 294,2,9,56,0 448 | 295,2,9,56,0 449 | 296,2,9,56,0 450 | 297,2,9,56,0 451 | 298,2,9,56,0 452 | 299,2,9,56,0 453 | 300,2,9,56,0 454 | 301,2,9,56,0 455 | 302,2,9,56,0 456 | 303,2,9,56,0 457 | 304,2,9,56,0 458 | 305,2,9,56,0 459 | 306,2,9,56,0 460 | 316,4,10,180,-20 461 | 317,4,10,177,-10 462 | 318,4,10,173,-10 463 | 319,4,10,169,-10 464 | 320,4,10,165,-10 465 | 321,4,10,161,-10 466 | 322,4,10,157,-10 467 | 323,4,10,153,-10 468 | 324,4,10,149,-10 469 | 325,4,10,145,-10 470 | 326,4,10,141,-10 471 | 327,4,10,137,-10 472 | 328,4,10,133,-10 473 | 329,4,10,129,-10 474 | 330,4,10,125,-10 475 | 331,4,10,121,-10 476 | 332,4,10,117,-10 477 | 333,4,10,113,-10 478 | 334,4,10,109,-10 479 | 337,4,10,97,-10 480 | 338,4,10,93,-10 481 | 339,4,10,89,-10 482 | 340,4,10,85,-10 483 | 341,4,10,81,-10 484 | 342,4,10,77,-10 485 | 343,4,10,73,-10 486 | 344,4,10,69,-10 487 | 345,4,10,65,-10 488 | 346,4,10,61,-10 489 | 347,4,10,57,-10 490 | 348,4,10,53,-10 491 | 349,4,10,49,-10 492 | 350,4,10,45,-10 493 | 351,4,10,41,-10 494 | 352,4,10,37,-10 495 | 353,4,10,33,-10 496 | 354,4,10,29,-10 497 | 358,1,11,-51,20 498 | 359,1,11,-51,20 499 | 360,1,11,-51,20 500 | 361,1,11,-51,20 501 | 362,1,11,-51,20 502 | 363,1,11,-51,20 503 | 364,1,11,-51,20 504 | 365,1,11,-51,20 505 | 366,1,11,-51,20 506 | 367,1,11,-51,20 507 | 368,1,11,-51,20 508 | 369,1,11,-51,20 509 | 370,1,11,-51,20 510 | 371,4,10,-39,-10 511 | 371,1,11,-51,20 512 | 372,4,10,-43,-10 513 | 372,1,11,-51,20 514 | 373,4,10,-47,-10 515 | 373,1,11,-51,20 516 | 374,4,10,-51,-10 517 | 374,1,11,-51,20 518 | 375,4,10,-55,-10 519 | 375,1,11,-51,20 520 | 376,4,10,-59,-10 521 | 376,1,11,-51,20 522 | 377,4,10,-63,-10 523 | 377,1,11,-51,20 524 | 378,4,10,-67,-10 525 | 378,1,11,-51,20 526 | 379,4,10,-71,-10 527 | 379,1,11,-51,20 528 | 380,4,10,-75,-10 529 | 380,1,11,-51,20 530 | 381,4,10,-79,-10 531 | 381,1,11,-51,20 532 | 382,4,10,-83,-10 533 | 382,1,11,-51,20 534 | 383,1,11,-51,20 535 | 384,1,11,-51,20 536 | 384,2,12,-75,-20 537 | 385,1,11,-51,20 538 | 385,2,12,-75,-20 539 | 386,1,11,-51,20 540 | 386,2,12,-75,-20 541 | 387,1,11,-51,20 542 | 387,2,12,-75,-20 543 | 388,1,11,-51,20 544 | 388,2,12,-75,-20 545 | 389,1,11,-51,20 546 | 389,2,12,-75,-20 547 | 390,1,11,-51,20 548 | 390,2,12,-75,-20 549 | 391,1,11,-51,20 550 | 391,2,12,-75,-20 551 | 392,1,11,-51,20 552 | 392,2,12,-75,-20 553 | 393,1,11,-51,20 554 | 393,2,12,-75,-20 555 | 394,1,11,-51,20 556 | 394,2,12,-75,-20 557 | 395,1,11,-51,20 558 | 395,2,12,-75,-20 559 | 396,1,11,-51,20 560 | 396,2,12,-75,-20 561 | 397,1,11,-51,20 562 | 397,2,12,-75,-20 563 | 398,1,11,-51,20 564 | 398,2,12,-75,-20 565 | 399,1,11,-51,20 566 | 399,2,12,-75,-20 567 | 400,1,11,-51,20 568 | 400,2,12,-75,-20 569 | 401,1,11,-51,20 570 | 402,1,11,-51,20 571 | 403,1,11,-51,20 572 | 404,1,11,-51,20 573 | 405,1,11,-51,20 574 | 406,1,11,-51,20 575 | 407,1,11,-51,20 576 | 408,1,11,-51,20 577 | 409,1,11,-51,20 578 | 410,1,11,-51,20 579 | 411,1,11,-51,20 580 | 412,1,11,-51,20 581 | 413,1,11,-51,20 582 | 414,1,11,-51,20 583 | 415,1,11,-51,20 584 | 416,1,11,-51,20 585 | 417,1,11,-51,20 586 | 417,8,14,123,-20 587 | 418,1,11,-51,20 588 | 418,8,14,123,-20 589 | 419,1,11,-51,20 590 | 419,8,14,123,-20 591 | 420,1,11,-51,20 592 | 420,0,13,40,-30 593 | 420,8,14,123,-20 594 | 421,1,11,-51,20 595 | 421,0,13,36,-30 596 | 421,8,14,123,-20 597 | 422,1,11,-51,20 598 | 422,0,13,32,-30 599 | 422,8,14,123,-20 600 | 423,1,11,-51,20 601 | 423,0,13,28,-30 602 | 423,8,14,123,-20 603 | 424,1,11,-51,20 604 | 424,0,13,24,-30 605 | 424,8,14,123,-20 606 | 425,1,11,-51,20 607 | 425,0,13,20,-30 608 | 425,8,14,123,-20 609 | 426,1,11,-51,20 610 | 426,0,13,16,-30 611 | 426,8,14,123,-20 612 | 427,1,11,-51,20 613 | 427,0,13,12,-30 614 | 427,8,14,123,-20 615 | 428,1,11,-51,20 616 | 428,0,13,8,-30 617 | 428,8,14,123,-20 618 | 429,1,11,-51,20 619 | 429,0,13,4,-30 620 | 429,8,14,123,-20 621 | 430,1,11,-51,20 622 | 430,0,13,0,-30 623 | 430,8,14,123,-20 624 | 431,1,11,-51,20 625 | 431,0,13,-4,-30 626 | 431,8,14,123,-20 627 | 432,1,11,-51,20 628 | 432,0,13,-8,-30 629 | 432,8,14,123,-20 630 | 433,1,11,-51,20 631 | 433,0,13,-12,-30 632 | 433,8,14,123,-20 633 | 434,1,11,-51,20 634 | 434,0,13,-16,-30 635 | 434,8,14,123,-20 636 | 435,1,11,-51,20 637 | 435,0,13,-20,-30 638 | 435,8,14,123,-20 639 | 436,1,11,-51,20 640 | 436,0,13,-24,-30 641 | 436,8,14,123,-20 642 | 437,1,11,-51,20 643 | 437,0,13,-28,-30 644 | 437,8,14,123,-20 645 | 438,1,11,-51,20 646 | 438,0,13,-32,-30 647 | 438,8,14,123,-20 648 | 439,1,11,-51,20 649 | 439,0,13,-36,-30 650 | 439,8,14,123,-20 651 | 440,1,11,-51,20 652 | 440,0,13,-40,-30 653 | 440,8,14,123,-20 654 | 441,1,11,-51,20 655 | 441,0,13,-44,-30 656 | 442,1,11,-51,20 657 | 442,0,13,-48,-30 658 | 443,1,11,-51,20 659 | 443,0,13,-52,-30 660 | 444,1,11,-51,20 661 | 444,0,13,-56,-30 662 | 445,1,11,-51,20 663 | 445,0,13,-60,-30 664 | 446,1,11,-51,20 665 | 446,0,13,-64,-30 666 | 447,1,11,-51,20 667 | 447,0,13,-68,-30 668 | 448,1,11,-51,20 669 | 448,0,13,-72,-30 670 | 449,1,11,-51,20 671 | 449,0,13,-76,-30 672 | 450,1,11,-51,20 673 | 450,0,13,-80,-30 674 | 451,1,11,-51,20 675 | 451,0,13,-84,-30 676 | 452,1,11,-51,20 677 | 452,0,13,-88,-30 678 | 453,1,11,-51,20 679 | 453,0,13,-92,-30 680 | 454,1,11,-51,20 681 | 454,0,13,-96,-30 682 | 455,1,11,-51,20 683 | 455,0,13,-100,-30 684 | 456,1,11,-51,20 685 | 456,0,13,-104,-30 686 | 457,1,11,-51,20 687 | 457,0,13,-108,-30 688 | 458,1,11,-51,20 689 | 458,0,13,-112,-30 690 | 459,1,11,-51,20 691 | 459,0,13,-116,-30 692 | 460,1,11,-51,20 693 | 460,0,13,-120,-30 694 | 461,1,11,-51,20 695 | 461,0,13,-124,-30 696 | 462,1,11,-51,20 697 | 462,0,13,-128,-30 698 | 463,1,11,-51,20 699 | 463,0,13,-132,-30 700 | 464,1,11,-51,20 701 | 464,0,13,-136,-30 702 | 465,1,11,-51,20 703 | 465,0,13,-140,-30 704 | 466,1,11,-51,20 705 | 466,0,13,-144,-30 706 | 467,1,11,-51,20 707 | 467,0,13,-148,-30 708 | 468,1,11,-51,20 709 | 468,0,13,-152,-30 710 | 469,1,11,-51,20 711 | 469,0,13,-156,-30 712 | 469,11,15,-175,0 713 | 470,1,11,-51,20 714 | 470,0,13,-160,-30 715 | 470,11,15,-175,0 716 | 471,1,11,-51,20 717 | 471,0,13,-164,-30 718 | 471,11,15,-175,0 719 | 472,1,11,-51,20 720 | 472,0,13,-168,-30 721 | 472,11,15,-175,0 722 | 473,1,11,-51,20 723 | 473,0,13,-172,-30 724 | 473,11,15,-175,0 725 | 474,1,11,-51,20 726 | 474,0,13,-176,-30 727 | 474,11,15,-175,0 728 | 475,1,11,-51,20 729 | 475,0,13,-180,-30 730 | 475,11,15,-175,0 731 | 476,1,11,-51,20 732 | 476,0,13,-177,-20 733 | 476,11,15,-175,0 734 | 477,1,11,-51,20 735 | 477,0,13,-173,-20 736 | 477,11,15,-175,0 737 | 478,1,11,-51,20 738 | 478,0,13,-169,-20 739 | 478,11,15,-175,0 740 | 479,1,11,-51,20 741 | 479,0,13,-165,-20 742 | 479,11,15,-175,0 743 | 480,1,11,-51,20 744 | 480,0,13,-161,-20 745 | 481,1,11,-51,20 746 | 481,0,13,-157,-20 747 | 482,1,11,-51,20 748 | 482,0,13,-153,-20 749 | 483,1,11,-51,20 750 | 483,0,13,-149,-20 751 | 484,1,11,-51,20 752 | 484,0,13,-145,-20 753 | 485,1,11,-51,20 754 | 485,0,13,-141,-20 755 | 486,1,11,-51,20 756 | 486,0,13,-137,-20 757 | 487,1,11,-51,20 758 | 487,0,13,-133,-20 759 | 488,1,11,-51,20 760 | 488,0,13,-129,-20 761 | 489,1,11,-51,20 762 | 489,0,13,-125,-20 763 | 490,1,11,-51,20 764 | 490,0,13,-121,-20 765 | 491,1,11,-51,20 766 | 491,0,13,-117,-20 767 | 492,1,11,-51,20 768 | 492,0,13,-113,-20 769 | 493,1,11,-51,20 770 | 493,0,13,-109,-20 771 | 494,1,11,-51,20 772 | 494,0,13,-105,-20 773 | 495,1,11,-51,20 774 | 495,0,13,-101,-20 775 | 496,1,11,-51,20 776 | 496,0,13,-97,-20 777 | 497,1,11,-51,20 778 | 497,0,13,-93,-20 779 | 498,1,11,-51,20 780 | 498,0,13,-89,-20 781 | 499,1,11,-51,20 782 | 499,0,13,-85,-20 783 | 500,1,11,-51,20 784 | 500,0,13,-81,-20 785 | 501,1,11,-51,20 786 | 501,0,13,-77,-20 787 | 502,1,11,-51,20 788 | 502,0,13,-73,-20 789 | 503,1,11,-51,20 790 | 503,0,13,-69,-20 791 | 504,1,11,-51,20 792 | 504,0,13,-65,-20 793 | 505,1,11,-51,20 794 | 505,0,13,-61,-20 795 | 506,1,11,-51,20 796 | 506,0,13,-57,-20 797 | 507,1,11,-51,20 798 | 507,0,13,-53,-20 799 | 508,1,11,-51,20 800 | 508,0,13,-49,-20 801 | 509,1,11,-51,20 802 | 509,0,13,-45,-20 803 | 510,1,11,-51,20 804 | 510,0,13,-41,-20 805 | 511,1,11,-51,20 806 | 511,0,13,-37,-20 807 | 512,1,11,-51,20 808 | 512,0,13,-33,-20 809 | 513,1,11,-51,20 810 | 513,0,13,-29,-20 811 | 514,1,11,-51,20 812 | 514,0,13,-25,-20 813 | 515,1,11,-51,20 814 | 515,0,13,-21,-20 815 | 516,1,11,-51,20 816 | 516,0,13,-17,-20 817 | 517,1,11,-51,20 818 | 517,0,13,-13,-20 819 | 518,1,11,-51,20 820 | 518,0,13,-9,-20 821 | 519,1,11,-51,20 822 | 519,0,13,-5,-20 823 | 520,1,11,-51,20 824 | 520,0,13,-1,-20 825 | 521,1,11,-51,20 826 | 521,0,13,3,-20 827 | 522,1,11,-51,20 828 | 522,0,13,7,-20 829 | 523,1,11,-51,20 830 | 523,0,13,11,-20 831 | 524,1,11,-51,20 832 | 524,0,13,15,-20 833 | 525,1,11,-51,20 834 | 525,0,13,19,-20 835 | 526,1,11,-51,20 836 | 526,0,13,23,-20 837 | 527,1,11,-51,20 838 | 527,0,13,27,-20 839 | 528,1,11,-51,20 840 | 528,0,13,31,-20 841 | 529,1,11,-51,20 842 | 529,0,13,35,-20 843 | 529,6,16,110,-10 844 | 530,1,11,-51,20 845 | 530,0,13,39,-20 846 | 530,6,16,110,-10 847 | 531,1,11,-51,20 848 | 531,0,13,43,-20 849 | 531,6,16,110,-10 850 | 532,1,11,-51,20 851 | 532,0,13,47,-20 852 | 532,6,16,110,-10 853 | 533,1,11,-51,20 854 | 533,0,13,51,-20 855 | 533,6,16,110,-10 856 | 534,1,11,-51,20 857 | 534,0,13,55,-20 858 | 534,6,16,110,-10 859 | 535,1,11,-51,20 860 | 535,0,13,59,-20 861 | 535,6,16,110,-10 862 | 536,1,11,-51,20 863 | 536,0,13,63,-20 864 | 536,6,16,110,-10 865 | 537,1,11,-51,20 866 | 537,0,13,67,-20 867 | 537,6,16,110,-10 868 | 538,1,11,-51,20 869 | 538,0,13,71,-20 870 | 538,6,16,110,-10 871 | 539,1,11,-51,20 872 | 539,0,13,75,-20 873 | 539,6,16,110,-10 874 | 540,1,11,-51,20 875 | 540,0,13,79,-20 876 | 540,6,16,110,-10 877 | 541,1,11,-51,20 878 | 541,0,13,83,-20 879 | 541,6,16,110,-10 880 | 542,1,11,-51,20 881 | 542,0,13,87,-20 882 | 542,6,16,110,-10 883 | 543,1,11,-51,20 884 | 543,0,13,91,-20 885 | 543,6,16,110,-10 886 | 544,1,11,-51,20 887 | 544,0,13,95,-20 888 | 544,6,16,110,-10 889 | 545,1,11,-51,20 890 | 545,0,13,99,-20 891 | 545,6,16,110,-10 892 | 546,1,11,-51,20 893 | 546,0,13,103,-20 894 | 546,6,16,110,-10 895 | 547,1,11,-51,20 896 | 547,0,13,107,-20 897 | 547,6,16,110,-10 898 | 548,1,11,-51,20 899 | 548,0,13,111,-20 900 | 548,6,16,110,-10 901 | 549,1,11,-51,20 902 | 549,0,13,115,-20 903 | 549,6,16,110,-10 904 | 550,1,11,-51,20 905 | 550,0,13,119,-20 906 | 550,6,16,110,-10 907 | 551,1,11,-51,20 908 | 551,0,13,123,-20 909 | 551,6,16,110,-10 910 | 552,1,11,-51,20 911 | 552,0,13,127,-20 912 | 552,6,16,110,-10 913 | 553,1,11,-51,20 914 | 553,0,13,131,-20 915 | 553,6,16,110,-10 916 | 554,1,11,-51,20 917 | 554,0,13,135,-20 918 | 554,6,16,110,-10 919 | 555,1,11,-51,20 920 | 555,0,13,139,-20 921 | 555,6,16,110,-10 922 | 556,1,11,-51,20 923 | 556,0,13,143,-20 924 | 556,6,16,110,-10 925 | 557,1,11,-51,20 926 | 557,0,13,147,-20 927 | 557,6,16,110,-10 928 | 558,1,11,-51,20 929 | 558,0,13,151,-20 930 | 558,6,16,110,-10 931 | 559,1,11,-51,20 932 | 559,0,13,155,-20 933 | 560,1,11,-51,20 934 | 560,0,13,159,-20 935 | 560,6,16,110,-10 936 | 561,1,11,-51,20 937 | 561,0,13,163,-20 938 | 561,6,16,110,-10 939 | 562,1,11,-51,20 940 | 562,0,13,167,-20 941 | 562,6,16,110,-10 942 | 563,1,11,-51,20 943 | 563,0,13,171,-20 944 | 563,6,16,110,-10 945 | 564,1,11,-51,20 946 | 564,0,13,175,-20 947 | 564,6,16,110,-10 948 | 565,1,11,-51,20 949 | 565,0,13,179,-20 950 | 565,6,16,110,-10 951 | 566,1,11,-51,20 952 | 566,0,13,178,-10 953 | 566,6,16,110,-10 954 | 567,1,11,-51,20 955 | 567,0,13,174,-10 956 | 567,6,16,110,-10 957 | 568,1,11,-51,20 958 | 568,0,13,170,-10 959 | 568,6,16,110,-10 960 | 569,1,11,-51,20 961 | 569,0,13,166,-10 962 | 569,6,16,110,-10 963 | 570,1,11,-51,20 964 | 570,0,13,162,-10 965 | 570,6,16,110,-10 966 | 571,1,11,-51,20 967 | 571,0,13,158,-10 968 | 571,6,16,110,-10 969 | 572,1,11,-51,20 970 | 572,0,13,154,-10 971 | 572,6,16,110,-10 972 | 573,1,11,-51,20 973 | 573,0,13,150,-10 974 | 573,6,16,110,-10 975 | 574,1,11,-51,20 976 | 574,0,13,146,-10 977 | 574,6,16,110,-10 978 | 575,1,11,-51,20 979 | 575,0,13,142,-10 980 | 575,6,16,110,-10 981 | 576,1,11,-51,20 982 | 576,0,13,138,-10 983 | 576,6,16,110,-10 984 | 577,1,11,-51,20 985 | 577,0,13,134,-10 986 | 577,6,16,110,-10 987 | 578,1,11,-51,20 988 | 578,0,13,130,-10 989 | 578,6,16,110,-10 990 | 579,1,11,-51,20 991 | 579,0,13,126,-10 992 | 579,6,16,110,-10 993 | 580,1,11,-51,20 994 | 580,0,13,122,-10 995 | 580,6,16,110,-10 996 | 581,1,11,-51,20 997 | 581,0,13,118,-10 998 | 581,6,16,110,-10 999 | 582,1,11,-51,20 1000 | 582,0,13,114,-10 1001 | 582,6,16,110,-10 1002 | 583,1,11,-51,20 1003 | 583,0,13,110,-10 1004 | 583,6,16,110,-10 1005 | 584,1,11,-51,20 1006 | 584,0,13,106,-10 1007 | 584,6,16,110,-10 1008 | 585,1,11,-51,20 1009 | 585,0,13,102,-10 1010 | 585,6,16,110,-10 1011 | 586,1,11,-51,20 1012 | 586,0,13,98,-10 1013 | 586,6,16,110,-10 1014 | 587,1,11,-51,20 1015 | 587,0,13,94,-10 1016 | 587,6,16,110,-10 1017 | 588,1,11,-51,20 1018 | 588,0,13,90,-10 1019 | 588,6,16,110,-10 1020 | 589,1,11,-51,20 1021 | 589,0,13,86,-10 1022 | 589,6,16,110,-10 1023 | 590,1,11,-51,20 1024 | 590,0,13,82,-10 1025 | 590,6,16,110,-10 1026 | 591,1,11,-51,20 1027 | 591,0,13,78,-10 1028 | 591,6,16,110,-10 1029 | 592,1,11,-51,20 1030 | 592,0,13,74,-10 1031 | 592,6,16,110,-10 1032 | 593,1,11,-51,20 1033 | 593,0,13,70,-10 1034 | 593,6,16,110,-10 1035 | 594,1,11,-51,20 1036 | 594,0,13,66,-10 1037 | 594,6,16,110,-10 1038 | 595,1,11,-51,20 1039 | 595,0,13,62,-10 1040 | 595,6,16,110,-10 1041 | 596,1,11,-51,20 1042 | 596,0,13,58,-10 1043 | 596,6,16,110,-10 1044 | 597,1,11,-51,20 1045 | 597,0,13,54,-10 1046 | 597,6,16,110,-10 1047 | 598,1,11,-51,20 1048 | 598,0,13,50,-10 1049 | 598,6,16,110,-10 1050 | 599,1,11,-51,20 1051 | 599,0,13,46,-10 1052 | 599,6,16,110,-10 1053 | -------------------------------------------------------------------------------- /tests/fixtures/fold1_room1_mix001.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SOUNDS-RESEARCH/complex_neural_source_localization/d3b4e3ca1ea216f4515abb011f02809183e375c9/tests/fixtures/fold1_room1_mix001.wav -------------------------------------------------------------------------------- /tests/fixtures/p225_001.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SOUNDS-RESEARCH/complex_neural_source_localization/d3b4e3ca1ea216f4515abb011f02809183e375c9/tests/fixtures/p225_001.wav -------------------------------------------------------------------------------- /tests/fixtures/test_real.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SOUNDS-RESEARCH/complex_neural_source_localization/d3b4e3ca1ea216f4515abb011f02809183e375c9/tests/fixtures/test_real.pickle -------------------------------------------------------------------------------- /tests/fixtures/weights.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SOUNDS-RESEARCH/complex_neural_source_localization/d3b4e3ca1ea216f4515abb011f02809183e375c9/tests/fixtures/weights.ckpt -------------------------------------------------------------------------------- /tests/test_correlation.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import librosa 3 | import os 4 | from pyroomasync.room import ConnectedShoeBox 5 | from pyroomasync.simulator import simulate 6 | from tdoa.tdoa import compute_correlations 7 | from tdoa.visualization import plot_correlations 8 | 9 | 10 | def test_plot_correlation(): 11 | _test_plot_correlation("correlation") 12 | 13 | def test_plot_gcc_phat(): 14 | _test_plot_correlation("gcc-phat") 15 | 16 | 17 | def _test_plot_correlation(correlation_mode): 18 | os.makedirs("tests/temp", exist_ok=True) 19 | temp_file_path = f"tests/temp/tdoa_plot_correlation_{correlation_mode}.png" 20 | if os.path.exists(temp_file_path): 21 | os.remove(temp_file_path) 22 | 23 | input_signal, fs = librosa.load("tests/fixtures/p225_001.wav") 24 | 25 | room_dim = [5, 5, 3] 26 | source_location = [2, 3, 1] 27 | mic_locations = [[2, 2, 1], [3, 3, 1], [4, 4, 1]] 28 | 29 | room = ConnectedShoeBox(room_dim, fs=fs) 30 | room.add_source(source_location, input_signal) 31 | room.add_microphone_array(mic_locations) 32 | 33 | simulation_results = simulate(room) 34 | correlations = compute_correlations(simulation_results, fs, correlation_mode) 35 | 36 | # plot_correlations(correlations, output_path=temp_file_path) 37 | 38 | # plt.close() 39 | # assert os.path.exists(temp_file_path) 40 | -------------------------------------------------------------------------------- /tests/test_feature_extractors.py: -------------------------------------------------------------------------------- 1 | from numpy import unwrap 2 | import torchaudio 3 | 4 | from hydra import compose, initialize 5 | from hydra.core.global_hydra import GlobalHydra 6 | 7 | from complex_neural_source_localization.feature_extractors import ( 8 | StftArray, CrossSpectra 9 | ) 10 | from complex_neural_source_localization.utils.model_visualization import ( 11 | plot_multichannel_spectrogram 12 | ) 13 | 14 | 15 | def test_spectrogram_array(): 16 | cfg = _load_config() 17 | 18 | sample_path = "tests/fixtures/fold1_room1_mix001.wav" 19 | signal = torchaudio.load(sample_path)[0].unsqueeze(0) 20 | 21 | spec_array = StftArray(cfg["model"]) 22 | 23 | result = spec_array(signal) 24 | 25 | plot_multichannel_spectrogram(result[0], output_path="tests/temp/multichannel_stft.png") 26 | 27 | assert result.shape == (1, 4, 512, 5626) 28 | """Batch size, n_array, n_fft//2 + 1, time_steps""" 29 | 30 | 31 | def test_cross_spectrum_array(): 32 | cfg = _load_config() 33 | 34 | sample_path = "tests/fixtures/fold1_room1_mix001.wav" 35 | signal, sr = torchaudio.load(sample_path) 36 | cross_spec_array = CrossSpectra(cfg["model"]) 37 | 38 | result = cross_spec_array(signal.unsqueeze(0)) 39 | 40 | plot_multichannel_spectrogram(result[0], output_path="tests/temp/cross_spectra.png", mode="row", unwrap=True, db=True) 41 | 42 | assert result.shape == (1, 10, 512, 5626) 43 | """Batch size, n_array, n_fft//2 + 1, time_steps""" 44 | 45 | 46 | 47 | def _load_config(): 48 | GlobalHydra.instance().clear() 49 | 50 | initialize(config_path="../config", job_name="test_app") 51 | cfg = compose(config_name="config") 52 | 53 | return cfg 54 | -------------------------------------------------------------------------------- /tests/test_loss.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | 5 | from complex_neural_source_localization.loss import Loss 6 | from complex_neural_source_localization.model import DOACNet 7 | 8 | 9 | def test_neural_tdoa_loss(): 10 | 11 | loss_fn = Loss() 12 | model = DOACNet(n_sources=1) 13 | 14 | sample_path = "tests/fixtures/0.0_split1_ir0_ov1_3.wav" 15 | 16 | sample = librosa.load(sample_path, sr=24000, mono=False, dtype=np.float32)[0] 17 | sample = torch.from_numpy(sample).unsqueeze(0) 18 | 19 | target = { 20 | "azimuth_2d_point": torch.Tensor([[0.0, 1.0]]), 21 | } 22 | 23 | model_output = model(sample) 24 | 25 | _ = loss_fn(model_output, target) -------------------------------------------------------------------------------- /tests/test_math_utils.py: -------------------------------------------------------------------------------- 1 | from re import M 2 | import numpy as np 3 | from tdoa.math_utils import compute_doa 4 | 5 | 6 | def test_compute_doa_colinear(): 7 | # m1 and m2 are on the x axis 8 | m1 = np.array([1, 0]) 9 | m2 = np.array([0, 0]) 10 | # s is also on the x axis 11 | s = np.array([2, 0]) 12 | 13 | doa_radians = compute_doa(m1, m2, s) 14 | doa_degrees = compute_doa(m1, m2, s, radians=False) 15 | assert doa_radians == 0 16 | assert doa_degrees == 0 17 | 18 | 19 | def test_compute_doa_perpendicular(): 20 | # m1 and m2 are on the x axis 21 | m1 = np.array([1, 0]) 22 | m2 = np.array([-1, 0]) 23 | 24 | # s is between the mics 25 | s = np.array([0, 1]) 26 | 27 | doa_degrees = compute_doa(m1, m2, s, radians=False) 28 | doa_degrees_2 = compute_doa(m1, m2, -s, radians=False) 29 | 30 | assert doa_degrees == 90 31 | assert doa_degrees_2 == -90 32 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | from complex_neural_source_localization.datasets.dcase_2019_task3_dataset import load_multichannel_wav 2 | from complex_neural_source_localization.model import DOACNet 3 | 4 | 5 | def test_tdoa_crnn10_with_stft(): 6 | _test_tdoa_crnn10("stft") 7 | 8 | 9 | def test_tdoa_crnn10_with_mfcc(): 10 | _test_tdoa_crnn10("mfcc") 11 | 12 | 13 | def _test_tdoa_crnn10(feature_type): 14 | 15 | model = DOACNet(n_sources=1) 16 | 17 | sample = load_multichannel_wav("tests/fixtures/0.0_split1_ir0_ov1_3.wav", 16000, 1) 18 | 19 | model_output = model(sample.unsqueeze(0)) 20 | 21 | assert model_output.shape == (1, 2) 22 | -------------------------------------------------------------------------------- /tests/test_tdoa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pyroomasync.room import ConnectedShoeBox 4 | from pyroomasync.simulator import simulate 5 | from tdoa.tdoa import compute_tdoas 6 | 7 | 8 | def test_compute_tdoas(): 9 | def sinusoid(freq_in_hz, duration, sr): 10 | linear_samples = np.arange(duration*sr) 11 | return np.sin(linear_samples*freq_in_hz) 12 | 13 | fs = 48000 14 | input_signal = sinusoid(10, 1, fs) 15 | 16 | room_dim = [4, 6, 3] 17 | source_location = [1, 5, 1] 18 | mic_locations = [[2, 2, 1], [3, 3, 1]] 19 | 20 | room = ConnectedShoeBox(room_dim) 21 | room.add_source(source_location, input_signal) 22 | room.add_microphone_array(mic_locations) 23 | 24 | simulation_results = simulate(room) 25 | tdoas = compute_tdoas(simulation_results, room.base_fs) 26 | # Currently no assertions, just seeing if it goes through 27 | -------------------------------------------------------------------------------- /tests/test_visualize_convolutional_feature_maps.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from complex_neural_source_localization.utils.model_utilities import( 4 | get_all_layers 5 | ) 6 | from complex_neural_source_localization.model import DOACNet 7 | 8 | 9 | def test_get_all_layers(): 10 | model = DOACNet() 11 | 12 | nm = [i for i in model.named_modules()] 13 | 14 | layers = get_all_layers(model) 15 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | from omegaconf import DictConfig 4 | 5 | from complex_neural_source_localization.datasets import create_dataloaders 6 | from complex_neural_source_localization.trainer import DOACNetTrainer 7 | 8 | 9 | @hydra.main(config_path="config", config_name="config", version_base=None) 10 | def train(config: DictConfig): 11 | """Runs the training procedure using Pytorch lightning 12 | And tests the model with the best validation score against the test dataset. 13 | 14 | Args: 15 | config (DictConfig): Configuration automatically loaded by Hydra. 16 | See the config/ directory for the configuration 17 | """ 18 | 19 | dataset_train, dataset_val, dataset_test = create_dataloaders(config) 20 | 21 | trainer = DOACNetTrainer(config) 22 | 23 | trainer.fit(dataset_train, val_dataloaders=dataset_val) 24 | trainer.test(dataset_test) 25 | 26 | 27 | if __name__ == "__main__": 28 | train() 29 | --------------------------------------------------------------------------------