├── .gitignore ├── LICENSE ├── README.md ├── dataset ├── __init__.py ├── configs │ └── dcase2020_seld_data_config.yml ├── database.py ├── dataloader.py ├── datamodule.py ├── feature_extraction.py ├── label_mappings.py └── meta │ ├── eval.csv │ └── original │ ├── test.csv │ ├── train.csv │ ├── trainval.csv │ └── val.csv ├── experiments ├── configs │ └── sed.yml └── train.py ├── figures ├── crnn_block.png ├── experimental_results.png ├── model_descriptions.png └── seld_framework.png ├── metrics ├── SELD_evaluation_metrics.py ├── evaluation_metrics.py └── pl_metrics.py ├── models ├── model_utils.py ├── sed_decoders.py ├── sed_encoders.py └── sed_models.py ├── paper └── general_seld.pdf ├── pretrained_models └── README.md ├── py37_environment.yml └── utilities ├── builder_utils.py ├── experiments_utils.py ├── learning_utils.py ├── plot_utils.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # pytype static type analyzer
132 | .pytype/
133 |
134 | # Cython debug symbols
135 | cython_debug/
136 |
137 | # My git ignore
138 | .idea
139 |
140 | # generated data files
141 | outputs/
142 | # pretrained_models/
143 | .vscode/
144 | lightning_logs/
145 |
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Tho Nguyen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # General-network-architecture-for-sound-event-localization-and-detection
2 | This repository contains code for our paper: **A General Network Architecture for Sound Event Localization and Detection Using Transfer Learning and Recurrent Neural Network**.
3 | A link to the paper can be found [here](https://arxiv.org/abs/2011.07859) or in the *paper* folder.
4 |
5 | We are sorry for uploading the code late. We are refactoring our code using PyTorch Lightning to improve readability and reproducibility.
6 | Stay tuned.
7 |
8 | ## Task description
9 | Sound event localization and detection (SELD) is an audio-related task that aims to detect the onsets and offsets of sound events, estimate the directions-of-arrival of the sound sources, and recognize the sound classes.
10 | SELD consists of two subtasks, which are sound event detection (SED) and direction-of-arrival (DOA) estimation.
11 | While the SED and DOA estimation tasks are quite mature in the audio field,
12 | the SELD task is relatively new and has been gaining popularity thanks to its appearance in the last two SELD challenges, [DCASE2019](http://dcase.community/challenge2019/task-sound-event-localization-and-detection) and [DCASE2020](http://dcase.community/challenge2020/task-sound-event-localization-and-detection).
13 |
14 | SELD is challenging because not only do we need to solve the two subtasks, SED and DOA estimation, well,
15 | but we also need to match the correct directions with the correct sound classes. There are several approaches to the SELD task.
16 | We can jointly solve SED and DOA estimation in a single stage, or we can solve SELD in several stages.
17 | Conventional challenges of SED, DOA estimation, and SELD are noise, reverberation, and overlapping sources.
18 |
19 | Another big challenge of SELD is the lack of joint datasets.
20 | We have much larger datasets for single-channel sound classification, such as [AudioSet](https://research.google.com/audioset/) or [FSD50K](https://annotator.freesound.org/fsd/release/FSD50K/).
21 | However, since multi-channel datasets are restricted to particular microphone array geometries,
22 | we do not have a large, general dataset for DOA estimation.
23 | The current largest publicly available joint dataset for SELD is simulated and limited to 13.3 hours.
24 |
25 | ## Proposed method
26 | We propose a general and flexible network architecture for SELD, as shown in the figures below.
27 | ![alt text](figures/seld_framework.png)
28 | We decouple the SED and DOA estimation tasks.
29 | An alignment module is used to align the output predictions of the SED and DOA estimation modules.
30 | The SED and DOA estimation modules are pretrained on their respective tasks and can be fine-tuned again during the training of the alignment module.
31 | The advantages of the proposed network architecture are:
32 |
33 | 1. It is easier to optimize the SED and DOA estimation modules separately, since SED and DOA estimation rely on different forms of audio input features.
34 |
35 | 2. The model reduces unwanted associations between sound classes and DOAs in the training set, since the SED and DOA estimation modules are pretrained separately.
36 |
37 | 3. The network architecture is highly practical:
38 |    * It is more flexible to select SED and DOA estimation algorithms that are suitable for a specific application.
39 |    For example, lightweight algorithms are preferred for edge devices.
40 |    * The SED module can be trained on single-channel datasets instead of multi-channel datasets, to leverage the much larger available datasets.
41 |    * We can use either traditional signal processing-based or deep learning-based algorithms for DOA estimation.
42 |    The signal processing-based methods do not require training data.
43 |    * The joint dataset required to train the alignment module can be much smaller than the joint dataset required to train the whole SELD network end-to-end.
44 |
45 |
46 | An example of the network architecture of each module is shown below.
47 | ![alt text](figures/crnn_block.png)
48 |
49 | ## Dataset
50 | We use the [TAU-NIGENS Spatial Sound Events 2020](https://zenodo.org/record/3870859) dataset for our experiments.
51 | For more information, please refer to the [DCASE2020 SELD challenge](http://dcase.community/challenge2020/task-sound-event-localization-and-detection).
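Each ground-truth metadata file in this dataset is a CSV whose rows contain a frame index, a sound class index, a track number, and the source azimuth and elevation in degrees. As a small, self-contained illustration (this snippet is not part of the repository's training code), the function below converts one azimuth/elevation pair into the Cartesian DOA format that `dataset/database.py` uses as a regression target:

```python
import numpy as np

def doa_to_xyz(azimuth_deg: float, elevation_deg: float):
    """Convert an azimuth/elevation pair in degrees to a unit vector (x, y, z)."""
    azi = np.deg2rad(azimuth_deg)
    ele = np.deg2rad(elevation_deg)
    x = np.cos(azi) * np.cos(ele)
    y = np.sin(azi) * np.cos(ele)
    z = np.sin(ele)
    return x, y, z

# e.g. a source at azimuth 30 degrees, elevation -10 degrees
print(doa_to_xyz(30.0, -10.0))
```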
52 |
53 | ## Experimental Results
54 | We trained different SED and DOA estimation models, as described below:
55 | ![alt text](figures/model_descriptions.png)
56 | One of the SED models is initialized with the pretrained weights of a model trained on AudioSet (single-channel) to demonstrate
57 | that the proposed method also works when SED models are trained on single-channel datasets.
58 | We mix and match different SED and DOA models and train an alignment network to match the SED and DOA estimation outputs.
59 | The experimental results are shown below:
60 | ![alt text](figures/experimental_results.png)
61 | The experimental results are obtained without fine-tuning the SED and DOA estimation modules.
62 |
63 | ## How to use the provided code
64 | The code was implemented using Python 3.7.
65 | ### Requirements
66 | The virtual environment can be installed using conda:
67 | ```commandline
68 | conda env create -f py37_environment.yml
69 | ```
70 | ### Download data and pretrained model
71 | The dataset can be downloaded from this [page](https://zenodo.org/record/3870859). Your data folder might look like this:
72 | ```text
73 | |__SELD2020/
74 |    |__foa_dev/
75 |    |__foa_eval/
76 |    |__mic_dev/
77 |    |__mic_eval/
78 |    |__metadata_dev/
79 |    |__metadata_eval/
80 |    |__metadata_eval_info.csv
81 | ```
82 |
83 | The pretrained model `Cnn14_mAP=0.431.pth` for SED can be downloaded from this [page](https://zenodo.org/record/3987831).
84 |
85 | ### Running Scripts
86 | To be updated. A tentative feature-extraction command is sketched at the end of this README.
87 | ## Citation
88 | If you use this code in your research, please cite our paper:
89 | ```text
90 | @article{nguyen2020general,
91 |   title={A General Network Architecture for Sound Event Localization and Detection Using Transfer Learning and Recurrent Neural Network},
92 |   author={Nguyen, Thi Ngoc Tho and Nguyen, Ngoc Khanh and Phan, Huy and Pham, Lam and Ooi, Kenneth and Jones, Douglas L and Gan, Woon-Seng},
93 |   journal={arXiv preprint arXiv:2011.07859},
94 |   year={2020}
95 | }
96 | ```
97 |
98 | ## External links
99 | 1. http://dcase.community/challenge2020/task-sound-event-localization-and-detection
100 | 2. https://zenodo.org/record/3870859
101 | 3. https://github.com/qiuqiangkong/audioset_tagging_cnn
102 | 4. https://zenodo.org/record/3987831
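## Feature extraction example
The training scripts are still to be updated, but feature extraction is already exposed through a `fire` command-line interface in `dataset/feature_extraction.py`. A typical invocation might look like the following sketch (the module path, config path, and flag values are illustrative assumptions; adjust them to your setup):
```commandline
python -m dataset.feature_extraction --data_config=dataset/configs/dcase2020_seld_data_config.yml --feature_type=logmeliv --task=feature_scaler
```
In the code, `feature_type` can be `logmel`, `iv`, `logmeliv`, `gcc`, or `logmelgcc`, and `task` can be `feature`, `scaler`, or `feature_scaler`.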
103 |
104 |
105 |
106 |
-------------------------------------------------------------------------------- /dataset/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/thomeou/General-network-architecture-for-sound-event-localization-and-detection/03b3aaccf3c87dd8fb857960e765ae768ad36625/dataset/__init__.py
-------------------------------------------------------------------------------- /dataset/configs/dcase2020_seld_data_config.yml: --------------------------------------------------------------------------------
1 | data_dir: '/media/tho_nguyen/disk1/audio_datasets/dcase2020/task3'
2 | feature_dir: '/media/tho_nguyen/disk2/new_seld/dcase2020/features'
3 | data:
4 |   format: 'foa'
5 |   fs: 24000
6 |   n_fft: 1024
7 |   hop_len: 300 # 240 for 10ms, 300 for 12.5ms, 480 for 20 ms
8 |   fmin: 50
9 |   fmax: 12000
10 |   n_mels: 128
11 |   is_std_norm: false
12 |   is_bg_norm: false # not using at the moment
13 |
-------------------------------------------------------------------------------- /dataset/database.py: --------------------------------------------------------------------------------
1 | """
2 | This module contains database classes that handle the different data splits, load data into memory, divide long
3 | audio files into segments, and tokenize these segments.
4 | Note: 32bit is sufficient for DL models.
5 | Note on terminology for audio length: frames -> segments -> chunk/clip -> file
6 | """
7 | import logging
8 | import os
9 | from typing import List
10 |
11 | import h5py
12 | import numpy as np
13 | import pandas as pd
14 |
15 |
16 | class SedDoaDatabase:
17 |     """
18 |     Class to handle different extracted features for SED or DOA separately.
19 |     """
20 |     def __init__(self,
21 |                  feature_root_dir: str = '/media/tho_nguyen/disk2/new_seld/dcase2020/features/'
22 |                                          'logmel_norm/24000fs_1024nfft_300nhop_128nmels_Falsestd',
23 |                  gt_meta_root_dir: str = '/media/tho_nguyen/disk1/audio_datasets/dcase2020/task3',
24 |                  audio_format: str = 'foa', n_classes: int = 14, fs: int = 24000, n_fft: int = 1024, hop_len: int = 300,
25 |                  label_rate: float = 10, train_chunk_len_s: float = 4.0, train_chunk_hop_len_s: float = 0.5,
26 |                  test_chunk_len_s: float = 4.0, test_chunk_hop_len_s: float = 2.0, scaler_type: str = 'vector'):
27 |         """
28 |         :param feature_root_dir: Feature directory. Can contain either SED or DOA features.
29 |             The data are organized in the following format:
30 |             |__feature_root_dir/
31 |                |__foa_dev/
32 |                |__foa_eval/
33 |                |__mic_dev/
34 |                |__mic_eval/
35 |                |__foa_feature_scaler.h5
36 |                |__mic_feature_scaler.h5
37 |         :param gt_meta_root_dir: Directory that contains groundtruth meta data.
38 | The data are orgamized in the following format: 39 | |__gt_meta_dir/ 40 | |__/metadata_dev/ 41 | |__/metadata_eval/ 42 | |__metadata_eval_info.csv 43 | """ 44 | self.feature_root_dir = feature_root_dir 45 | self.gt_meta_root_dir = gt_meta_root_dir 46 | self.audio_format = audio_format 47 | self.n_classes = n_classes 48 | self.fs = fs 49 | self.n_fft = n_fft 50 | self.hop_len = hop_len 51 | self.label_rate = label_rate 52 | self.train_chunk_len = self.second2frame(train_chunk_len_s) 53 | self.train_chunk_hop_len = self.second2frame(train_chunk_hop_len_s) 54 | self.test_chunk_len = self.second2frame(test_chunk_len_s) 55 | self.test_chunk_hop_len = self.second2frame(test_chunk_hop_len_s) 56 | self.scaler_type = scaler_type 57 | 58 | assert audio_format in ['foa', 'mic'], 'Incorrect value for audio format {}'.format(audio_format) 59 | assert os.path.isdir(os.path.join(self.feature_root_dir, self.audio_format + '_dev')), \ 60 | '"dev" folder is not found' 61 | 62 | self.chunk_len = None 63 | self.chunk_hop_len = None 64 | self.n_frames = int(np.floor((self.fs * 60 - (self.n_fft - self.hop_len)) / self.hop_len)) + 2 #+ 2 because of padding 65 | self.feature_rate = self.fs/self.hop_len # Frame rate per second 66 | self.label_upsample_ratio = int(self.feature_rate / self.label_rate) 67 | self.mean, self.std = self.load_scaler() 68 | 69 | logger = logging.getLogger('lightning') 70 | logger.info('Load feature database from {}.'.format(self.feature_root_dir)) 71 | logger.info('train_chunk_len = {}, train_chunk_hop_len = {}'.format( 72 | self.train_chunk_len, self.train_chunk_hop_len)) 73 | logger.info('test_chunk_len = {}, test_chunk_hop_len = {}'.format( 74 | self.test_chunk_len, self.test_chunk_hop_len)) 75 | 76 | def get_split(self, split: str, split_meta_dir: str = '/meta/original', doa_format: str = 'xyz'): 77 | """ 78 | Function to load all data of a split into memory, divide long audio clip/file into smaller chunks, and assign 79 | labels for clips and chunks. List of SED labels: 80 | 81 | :param split: Split of data, choices: 82 | 'train', 'val', 'test', 'eval': load chunk of data 83 | :param split_meta_dir: Directory where meta of split is stored. 84 | :param doa_format: Choices are 'xyz' or 'polar'. 
85 | :return: 86 | """ 87 | assert doa_format in ['xyz', 'polar'], 'Incorrect value for doa format {}'.format(doa_format) 88 | # Get feature dir, filename list, and gt_meta_dir 89 | if split == 'eval': 90 | split_feature_dir = os.path.join(self.feature_root_dir, self.audio_format + '_eval') 91 | csv_filename = os.path.join(os.path.split(split_meta_dir)[0], 'eval.csv') 92 | gt_meta_dir = os.path.join(self.gt_meta_root_dir, 'metadata_eval') 93 | else: 94 | split_feature_dir = os.path.join(self.feature_root_dir, self.audio_format + '_dev') 95 | csv_filename = os.path.join(split_meta_dir, split + '.csv') 96 | gt_meta_dir = os.path.join(self.gt_meta_root_dir, 'metadata_dev') 97 | meta_df = pd.read_csv(csv_filename) 98 | split_filenames = meta_df['filename'].tolist() 99 | # Get chunk len and chunk hop len 100 | if split in ['train', 'trainval', 'val']: 101 | self.chunk_len = self.train_chunk_len 102 | self.chunk_hop_len = self.train_chunk_hop_len 103 | elif split in ['test', 'eval']: 104 | self.chunk_len = self.test_chunk_len 105 | self.chunk_hop_len = self.test_chunk_hop_len 106 | else: 107 | raise NotImplementedError('chunk len is not assigned for split {}'.format(split)) 108 | 109 | # Load and crop data 110 | features, sed_targets, doa_targets, chunk_idxes, filename_list, test_batch_size = self.load_chunk_data( 111 | split_filenames=split_filenames, split_feature_dir=split_feature_dir, gt_meta_dir=gt_meta_dir, 112 | doa_format=doa_format, split=split) 113 | # pack data 114 | db_data = { 115 | 'features': features, 116 | 'sed_targets': sed_targets, 117 | 'doa_targets': doa_targets, 118 | 'chunk_idxes': chunk_idxes, 119 | 'filename_list': filename_list, 120 | 'test_batch_size': test_batch_size 121 | } 122 | 123 | return db_data 124 | 125 | def second2frame(self, second): 126 | """ 127 | Convert seconds to frame unit. 128 | """ 129 | sample = int(second * self.fs) 130 | frame = int(round(sample/self.hop_len)) 131 | return frame 132 | 133 | def load_scaler(self): 134 | scaler_fn = os.path.join(self.feature_root_dir, self.audio_format + '_feature_scaler.h5') 135 | if self.scaler_type == 'vector': 136 | with h5py.File(scaler_fn, 'r') as hf: 137 | mean = hf['mean'][:] 138 | std = hf['std'][:] 139 | elif self.scaler_type == 'scalar': 140 | with h5py.File(scaler_fn, 'r') as hf: 141 | mean = hf['scalar_mean'][:] 142 | std = hf['scalar_std'][:] 143 | else: 144 | mean = 0 145 | std = 1 146 | return mean, std 147 | 148 | def load_chunk_data(self, split_filenames: List, split_feature_dir: str, gt_meta_dir: str, 149 | doa_format: str = 'xyz', split: str = 'train'): 150 | """ 151 | Load feature, crop data and assign labels. 152 | :param split_filenames: List of filename in the split. 153 | :param split_feature_dir: Feature directory of the split 154 | :param gt_meta_dir: Ground truth meta directory of the split. 155 | :param doa_format: Choices are 'xyz' or 'polar'. 156 | :param split: Name of split, can be 'train', 'trainval', 'val', 'test', 'eval'. 
157 | :return: features, targets, chunk_idxes, filename_list 158 | """ 159 | pointer = 0 160 | features_list = [] 161 | filename_list = [] 162 | sed_targets_list = [] 163 | doa_targets_list = [] 164 | idxes_list = [] 165 | for filename in split_filenames: 166 | feature_fn = os.path.join(split_feature_dir, filename + '.h5') 167 | # Load feature 168 | with h5py.File(feature_fn, 'r') as hf: 169 | feature = hf['feature'][:] 170 | # Normalize feature 171 | feature = (feature - self.mean) / self.std 172 | n_frames = feature.shape[1] 173 | # Load gt info from metadata 174 | gt_meta_fn = os.path.join(gt_meta_dir, filename + '.csv') 175 | df = pd.read_csv(gt_meta_fn, header=None, 176 | names=['frame_number', 'sound_class_idx', 'track_number', 'azimuth', 'elevation']) 177 | frame_number = df['frame_number'].values 178 | sound_class_idx = df['sound_class_idx'].values 179 | track_number = df['track_number'].values 180 | azimuth = df['azimuth'].values 181 | elevation = df['elevation'].values 182 | # Generate target data 183 | sed_target = np.zeros((n_frames, self.n_classes), dtype=np.float32) 184 | azi_target = np.zeros((n_frames, self.n_classes), dtype=np.float32) 185 | ele_target = np.zeros((n_frames, self.n_classes), dtype=np.float32) 186 | nsources_target = np.zeros((n_frames, 3), dtype=np.float32) 187 | count_sources_target = np.zeros((n_frames,), dtype=np.float32) 188 | for itrack in np.arange(5): 189 | track_idx = track_number == itrack 190 | frame_number_1 = frame_number[track_idx] 191 | sound_class_idx_1 = sound_class_idx[track_idx] 192 | azimuth_1 = azimuth[track_idx] 193 | elevation_1 = elevation[track_idx] 194 | for idx, iframe in enumerate(frame_number_1): 195 | start_idx = int(iframe * self.label_upsample_ratio - self.label_upsample_ratio//2) 196 | start_idx = np.max((0, start_idx)) 197 | end_idx = int(start_idx + self.label_upsample_ratio) 198 | end_idx = np.min((end_idx, n_frames)) 199 | class_idx = int(sound_class_idx_1[idx]) 200 | sed_target[start_idx:end_idx, class_idx] = 1.0 201 | azi_target[start_idx:end_idx, class_idx] = azimuth_1[idx] * np.pi / 180.0 # Radian unit 202 | ele_target[start_idx:end_idx, class_idx] = elevation_1[idx] * np.pi / 180.0 # Radian unit 203 | count_sources_target[start_idx:end_idx] += 1 204 | # Convert nsources to one-hot encoding 205 | for i in np.arange(3): 206 | idx = count_sources_target == i 207 | nsources_target[idx, i] = 1.0 208 | # Doa target 209 | if doa_format == 'polar': 210 | doa_target = np.concatenate((azi_target, ele_target), axis=-1) 211 | elif doa_format == 'xyz': 212 | x = np.cos(azi_target) * np.cos(ele_target) 213 | y = np.sin(azi_target) * np.cos(ele_target) 214 | z = np.sin(ele_target) 215 | doa_target = np.concatenate((x, y, z), axis=-1) 216 | # Get segment indices 217 | n_crop_frames = n_frames 218 | assert self.chunk_len <= n_crop_frames, 'Number of cropped frame is less than chunk len' 219 | idxes = np.arange(pointer, pointer + n_crop_frames - self.chunk_len + 1, self.chunk_hop_len).tolist() 220 | # Include the leftover of the cropped data 221 | if (n_crop_frames - self.chunk_len) % self.chunk_hop_len != 0: 222 | idxes.append(pointer + n_crop_frames - self.chunk_len) 223 | pointer += n_crop_frames 224 | # Append data 225 | features_list.append(feature) 226 | filename_list.extend([filename] * len(idxes)) 227 | sed_targets_list.append(sed_target) 228 | doa_targets_list.append(doa_target) 229 | idxes_list.append(idxes) 230 | 231 | if len(features_list) > 0: 232 | features = np.concatenate(features_list, axis=1) 233 | sed_targets = 
np.concatenate(sed_targets_list, axis=0)
234 |             doa_targets = np.concatenate(doa_targets_list, axis=0)
235 |             chunk_idxes = np.concatenate(idxes_list, axis=0)
236 |             test_batch_size = len(idxes)  # to load all chunks of the same file
237 |             return features, sed_targets, doa_targets, chunk_idxes, filename_list, test_batch_size
238 |         else:
239 |             return None, None, None, None, None, None
240 |
241 | class SeldDatabase:
242 |     """
243 |     Database class to handle two input streams, one for SED and one for DOA. Use this database to train the alignment module.
244 |     """
245 |     pass
246 |
247 |
248 | if __name__ == '__main__':
249 |     tp_db = SedDoaDatabase()
250 |     db_data = tp_db.get_split(split='val', split_meta_dir='meta/original')
251 |     print(db_data['features'].shape)
252 |     print(db_data['sed_targets'].shape)
253 |     print(len(db_data['chunk_idxes']))
254 |     print(len(db_data['filename_list']))
255 |
-------------------------------------------------------------------------------- /dataset/dataloader.py: --------------------------------------------------------------------------------
1 | """
2 | Module for dataloaders.
3 | """
4 | from typing import List, Tuple
5 |
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import Dataset, Sampler
9 |
10 | from dataset.database import SedDoaDatabase
11 |
12 |
13 | class SedDoaChunkDataset(Dataset):
14 |     """
15 |     Chunk dataset for the SED or DOA task. For training and chunk evaluation.
16 |     """
17 |     def __init__(self, db_data, chunk_len, transform=None, is_mixup: bool = False):
18 |         self.features = db_data['features']
19 |         self.sed_targets = db_data['sed_targets']
20 |         self.doa_targets = db_data['doa_targets']
21 |         self.chunk_idxes = db_data['chunk_idxes']
22 |         self.filename_list = db_data['filename_list']
23 |         self.chunk_len = chunk_len
24 |         self.transform = transform
25 |         self.is_mixup = is_mixup
26 |         self.n_samples = len(self.chunk_idxes)
27 |
28 |     def __len__(self):
29 |         """
30 |         Total number of training samples.
31 | """ 32 | return len(self.chunk_idxes) 33 | 34 | def __getitem__(self, index): 35 | """ 36 | Generate one sample of data 37 | """ 38 | # Select sample 39 | chunk_idx = self.chunk_idxes[index] 40 | 41 | # get filename 42 | filename = self.filename_list[index] 43 | 44 | # Load data and get label 45 | X = self.features[:, chunk_idx: chunk_idx + self.chunk_len, :] # (n_channels, n_timesteps, n_mels) 46 | sed_labels = self.sed_targets[chunk_idx: chunk_idx + self.chunk_len] # (n_timesteps, n_classes) 47 | doa_labels = self.doa_targets[chunk_idx: chunk_idx + self.chunk_len] # (n_timesteps, x*n_classes) 48 | 49 | # Mixup mainly for SED 50 | if self.is_mixup: 51 | a1 = np.random.beta(0.5, 0.5) 52 | if np.random.rand() < 0.8 and np.abs(a1 - 0.5) > 0.2: 53 | random_index = np.random.randint(0, self.n_samples, 1)[0] 54 | random_chunk_idx = self.chunk_idxes[random_index] 55 | X_1 = self.features[:, random_chunk_idx: random_chunk_idx + self.chunk_len, :] 56 | sed_labels_1 = self.sed_targets[random_chunk_idx: random_chunk_idx + self.chunk_len] 57 | doa_labels_1 = self.doa_targets[random_chunk_idx: random_chunk_idx + self.chunk_len] 58 | X = a1 * X + (1 - a1) * X_1 59 | sed_labels = a1 * sed_labels + (1 - a1) * sed_labels_1 60 | doa_labels = a1 * doa_labels + (1 - a1) * doa_labels_1 61 | 62 | if self.transform is not None: 63 | X = self.transform(X) 64 | 65 | return X, sed_labels, doa_labels, filename 66 | 67 | 68 | class SeldChunkDataset(Dataset): 69 | """ 70 | Chunk dataset for SELD task 71 | """ 72 | pass 73 | 74 | 75 | if __name__ == '__main__': 76 | # test dataloader 77 | db = SedDoaDatabase() 78 | data_db = db.get_split(split='val') 79 | 80 | # create train dataset 81 | dataset = SedDoaChunkDataset(db_data=data_db, chunk_len=db.chunk_len) 82 | print('Number of training samples: {}'.format(len(dataset))) 83 | 84 | # load one sample 85 | index = np.random.randint(len(dataset)) 86 | sample = dataset[index] 87 | for item in sample[:-1]: 88 | print(item.shape) 89 | print(sample[-1]) 90 | 91 | # test data generator 92 | batch_size = 8 93 | dataloader = torch.utils.data.DataLoader(dataset=dataset, 94 | batch_size=batch_size, 95 | shuffle=False, 96 | num_workers=4) 97 | print('Number of batches: {}'.format(len(dataloader))) # steps_per_epoch 98 | for train_iter, (X, sed_labels, doa_labels, filenames) in enumerate(dataloader): 99 | if train_iter == 0: 100 | print(X.dtype) 101 | print(X.shape) 102 | print(sed_labels.dtype) 103 | print(sed_labels.shape) 104 | print(doa_labels.dtype) 105 | print(doa_labels.shape) 106 | print(type(filenames)) 107 | print(filenames) 108 | break -------------------------------------------------------------------------------- /dataset/datamodule.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | 6 | from dataset.dataloader import SedDoaChunkDataset 7 | from utilities.transforms import CompositeCutout 8 | 9 | 10 | class SedDoaDataModule(pl.LightningDataModule): 11 | """ 12 | DataModule that group train and validation data for SED or DOA task loader under on hood. 
13 | """ 14 | def __init__(self, feature_db, split_meta_dir: str = '/dataset/meta/original/', train_batch_size: int = 32, 15 | val_batch_size: int = 32, mode: str = 'crossval'): 16 | super().__init__() 17 | self.feature_db = feature_db 18 | self.split_meta_dir = split_meta_dir 19 | self.train_batch_size = train_batch_size 20 | self.val_batch_size = val_batch_size 21 | self.train_dataset = None 22 | self.val_dataset = None 23 | self.test_dataset = None 24 | self.test_batch_size = None 25 | self.lit_logger = logging.getLogger('lightning') 26 | self.lit_logger.info('Create DataModule using tran val split at {}.'.format(split_meta_dir)) 27 | if mode == 'crossval': 28 | self.train_split = 'train' 29 | self.val_split = 'val' 30 | self.test_split = 'test' 31 | elif mode == 'eval': 32 | self.train_split = 'trainval' 33 | self.val_split = 'test' 34 | self.test_split = 'eval' 35 | else: 36 | raise NotImplementedError('Mode {} is not implemented!'.format(mode)) 37 | 38 | # Data augmentation 39 | self.train_transform = CompositeCutout(image_aspect_ratio=self.feature_db.train_chunk_len/128) # 128 n_mels 40 | 41 | def setup(self, stage: str = None): 42 | """ 43 | :param stage: can be 'fit', 'test. 44 | """ 45 | # Get train and val data during training 46 | if stage == 'fit': # to use clip for validation 47 | train_db = self.feature_db.get_split(split=self.train_split, split_meta_dir=self.split_meta_dir) 48 | self.train_dataset = SedDoaChunkDataset(db_data=train_db, chunk_len=self.feature_db.train_chunk_len, 49 | transform=self.train_transform, is_mixup=True) 50 | val_db = self.feature_db.get_split(split=self.val_split, split_meta_dir=self.split_meta_dir) 51 | self.val_dataset = SedDoaChunkDataset(db_data=val_db, chunk_len=self.feature_db.train_chunk_len) 52 | elif stage == 'test': 53 | test_db = self.feature_db.get_split(split=self.test_split, split_meta_dir=self.split_meta_dir) 54 | self.test_dataset = SedDoaChunkDataset(db_data=test_db, chunk_len=self.feature_db.test_chunk_len) 55 | self.test_batch_size = test_db['test_batch_size'] 56 | self.lit_logger.info('In datamodule: test batch size = {}'.format(self.test_batch_size)) 57 | else: 58 | raise NotImplementedError('stage {} is not implemented for datamodule'.format(stage)) 59 | 60 | def train_dataloader(self): 61 | return DataLoader(dataset=self.train_dataset, 62 | batch_size=self.train_batch_size, 63 | shuffle=True, 64 | pin_memory=True, 65 | num_workers=4) 66 | 67 | def val_dataloader(self): 68 | return DataLoader(dataset=self.val_dataset, 69 | batch_size=self.val_batch_size, 70 | shuffle=False, 71 | pin_memory=True, 72 | num_workers=4) 73 | 74 | def test_dataloader(self): 75 | return DataLoader(dataset=self.test_dataset, 76 | batch_size=self.test_batch_size, 77 | shuffle=False, 78 | pin_memory=True, 79 | num_workers=4) 80 | 81 | 82 | class SeldDataModule(pl.LightningDataModule): 83 | """ 84 | DataModule that group train and validation data for SELD task loader under on hood. 85 | """ -------------------------------------------------------------------------------- /dataset/feature_extraction.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module includes classes and functions to extract audio features and compute global mean and standard deviation of 3 | the extracted features. 
4 | Reference: ff551c5 https://github.com/yinkalario/Two-Stage-Polyphonic-Sound-Event-Detection-and-Localization/blob/ 5 | master/utils/feature_extractor.py 6 | """ 7 | 8 | import os 9 | import shutil 10 | 11 | import fire 12 | import h5py 13 | import librosa 14 | import numpy as np 15 | import yaml 16 | from sklearn import preprocessing 17 | from timeit import default_timer as timer 18 | from tqdm import tqdm 19 | 20 | from utilities import noise 21 | 22 | 23 | class FeatureExtractor: 24 | """ 25 | Base class for feature extraction. 26 | """ 27 | def __init__(self, fs: int, n_fft: int, hop_length: int, n_mels: int, fmin: int = 50, fmax: int = None, 28 | window: str = 'hann'): 29 | """ 30 | :param fs: Sampling rate. 31 | :param n_fft: Number of FFT points. 32 | :param hop_length: Number of sample for hopping. 33 | :param n_mels: Number of mel bands. 34 | :param fmin: Min frequency to extract feature (Hz). 35 | :param fmax: Max frequency to extract feature (Hz). 36 | :param window: Type of window. 37 | """ 38 | self.n_fft = n_fft 39 | self.hop_length = hop_length 40 | self.window = window 41 | self.melW = librosa.filters.mel(sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) 42 | 43 | def extract(self, audio_input: np.ndarray) -> np.ndarray: 44 | """ 45 | :param audio_input: . 46 | """ 47 | raise NotImplementedError 48 | 49 | 50 | class LogMelGccExtractor(FeatureExtractor): 51 | """ 52 | Extract logmel and GCC-PHAT features. 53 | """ 54 | def __init__(self, fs: int, n_fft: int, hop_length: int, n_mels: int, fmin: int = 50, fmax: int = None, 55 | window: str = 'hann'): 56 | """ 57 | :param fs: Sampling rate. 58 | :param n_fft: Number of FFT points. 59 | :param hop_length: Number of sample for hopping. 60 | :param n_mels: Number of mel bands. 61 | :param fmin: Min frequency to extract feature (Hz). 62 | :param fmax: Max frequency to extract feature (Hz). 63 | :param window: Type of window. 64 | """ 65 | super().__init__(fs=fs, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax, window=window) 66 | self.n_mels = n_mels 67 | 68 | def gcc_phat(self, sig, refsig) -> np.ndarray: 69 | """ 70 | Compute GCC-PHAT between sig and refsig. 71 | :param sig: 74 | """ 75 | ncorr = 2 * self.n_fft - 1 76 | n_fft = int(2 ** np.ceil(np.log2(np.abs(ncorr)))) 77 | assert n_fft == self.n_fft, 'Please choose nfft in the form of 2**x' 78 | Px = librosa.stft(y=np.asfortranarray(sig), 79 | n_fft=self.n_fft, 80 | hop_length=self.hop_length, 81 | center=True, 82 | window=self.window, 83 | pad_mode='reflect') 84 | Px_ref = librosa.stft(y=np.asfortranarray(refsig), 85 | n_fft=self.n_fft, 86 | hop_length=self.hop_length, 87 | center=True, 88 | window=self.window, 89 | pad_mode='reflect') 90 | R = Px * np.conj(Px_ref) 91 | n_frames = R.shape[1] 92 | gcc_phat = [] 93 | for i in range(n_frames): 94 | spec = R[:, i].flatten() 95 | cc = np.fft.irfft(np.exp(1.j * np.angle(spec))) 96 | cc = np.concatenate((cc[self.n_mels // 2:], cc[:self.n_mels // 2])) 97 | gcc_phat.append(cc) 98 | gcc_phat = np.array(gcc_phat) 99 | gcc_phat = gcc_phat[None, :, :] 100 | 101 | return gcc_phat 102 | 103 | def logmel(self, sig) -> np.ndarray: 104 | """ 105 | Compute logmel of single channel signal 106 | :param sig: . 
108 | """ 109 | spec = np.abs(librosa.stft(y=np.asfortranarray(audio_input[i_channel]), 110 | n_fft=self.n_fft, 111 | hop_length=self.hop_length, 112 | center=True, 113 | window=self.window, 114 | pad_mode='reflect')) 115 | 116 | mel_spec = np.dot(self.melW, spec ** 2).T 117 | logmel_spec = librosa.power_to_db(mel_spec, ref=1.0, amin=1e-10, top_db=None) 118 | logmel_spec = np.expand_dims(logmel_spec, axis=0) 119 | 120 | return logmel_spec 121 | 122 | def extract(self, audio_input: np.ndarray) -> np.ndarray: 123 | """ 124 | :param audio_input: . 125 | :return: logmel_features . 126 | """ 127 | n_channels = audio_input.shape[0] 128 | features = [] 129 | gcc_features = [] 130 | for n in range(n_channels): 131 | features.append(self.logmel(audio_input[n])) 132 | for m in range(n + 1, n_channels): 133 | gcc_features.append(self.gcc_phat(sig=audio_input[m], refsig=audio_input[n])) 134 | 135 | features.extend(gcc_features) 136 | features = np.concatenate(features, axis=0) 137 | 138 | return features 139 | 140 | 141 | class GccExtractor(FeatureExtractor): 142 | """ 143 | Extract GCC-PHAT features. 144 | """ 145 | 146 | def __init__(self, fs: int, n_fft: int, hop_length: int, n_mels: int, fmin: int = 50, fmax: int = None, 147 | window: str = 'hann'): 148 | """ 149 | :param fs: Sampling rate. 150 | :param n_fft: Number of FFT points. 151 | :param hop_length: Number of sample for hopping. 152 | :param n_mels: Number of mel bands. 153 | :param fmin: Min frequency to extract feature (Hz). 154 | :param fmax: Max frequency to extract feature (Hz). 155 | :param window: Type of window. 156 | """ 157 | super().__init__(fs=fs, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax, window=window) 158 | self.n_mels = n_mels 159 | 160 | def gcc_phat(self, sig, refsig) -> np.ndarray: 161 | """ 162 | Compute GCC-PHAT between sig and refsig. 163 | :param sig: 166 | """ 167 | ncorr = 2 * self.n_fft - 1 168 | n_fft = int(2 ** np.ceil(np.log2(np.abs(ncorr)))) 169 | assert n_fft == self.n_fft, 'Please choose nfft in the form of 2**x' 170 | Px = librosa.stft(y=np.asfortranarray(sig), 171 | n_fft=self.n_fft, 172 | hop_length=self.hop_length, 173 | center=True, 174 | window=self.window, 175 | pad_mode='reflect') 176 | Px_ref = librosa.stft(y=np.asfortranarray(refsig), 177 | n_fft=self.n_fft, 178 | hop_length=self.hop_length, 179 | center=True, 180 | window=self.window, 181 | pad_mode='reflect') 182 | R = Px * np.conj(Px_ref) 183 | n_frames = R.shape[1] 184 | gcc_phat = [] 185 | for i in range(n_frames): 186 | spec = R[:, i].flatten() 187 | cc = np.fft.irfft(np.exp(1.j * np.angle(spec))) 188 | cc = np.concatenate((cc[self.n_mels // 2:], cc[:self.n_mels // 2])) 189 | gcc_phat.append(cc) 190 | gcc_phat = np.array(gcc_phat) 191 | gcc_phat = gcc_phat[None, :, :] 192 | 193 | return gcc_phat 194 | 195 | def extract(self, audio_input: np.ndarray) -> np.ndarray: 196 | """ 197 | :param audio_input: . 198 | :return: logmel_features . 199 | """ 200 | n_channels = audio_input.shape[0] 201 | features = [] 202 | for n in range(n_channels): 203 | for m in range(n + 1, n_channels): 204 | features.append(self.gcc_phat(sig=audio_input[m], refsig=audio_input[n])) 205 | features = np.concatenate(features, axis=0) 206 | 207 | return features 208 | 209 | 210 | class LogMelIvExtractor(FeatureExtractor): 211 | """ 212 | Extract Logmel and Intensity vector from FOA format. 
213 | """ 214 | def __init__(self, fs: int, n_fft: int, hop_length: int, n_mels: int, fmin: int = 50, fmax: int = None, 215 | window: str = 'hann'): 216 | """ 217 | :param fs: Sampling rate. 218 | :param n_fft: Number of FFT points. 219 | :param hop_length: Number of sample for hopping. 220 | :param n_mels: Number of mel bands. 221 | :param fmin: Min frequency to extract feature (Hz). 222 | :param fmax: Max frequency to extract feature (Hz). 223 | :param window: Type of window. 224 | """ 225 | super().__init__(fs=fs, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax, window=window) 226 | self.eps = 1e-8 227 | 228 | def extract(self, audio_input: np.ndarray) -> np.ndarray: 229 | """ 230 | :param audio_input: . 231 | :return: feature_logmel . 232 | """ 233 | n_channels = audio_input.shape[0] 234 | features = [] 235 | X = [] 236 | 237 | for i_channel in range(n_channels): 238 | spec = librosa.stft(y=np.asfortranarray(audio_input[i_channel]), 239 | n_fft=self.n_fft, 240 | hop_length=self.hop_length, 241 | center=True, 242 | window=self.window, 243 | pad_mode='reflect') 244 | X.append(np.expand_dims(spec, axis=0)) # 1 x n_bins x n_frames 245 | 246 | # compute logmel 247 | mel_spec = np.dot(self.melW, np.abs(spec) ** 2).T 248 | logmel_spec = librosa.power_to_db(mel_spec, ref=1.0, amin=1e-10, top_db=None) 249 | logmel_spec = np.expand_dims(logmel_spec, axis=0) 250 | features.append(logmel_spec) 251 | 252 | # compute intensity vector: for ambisonic signal, n_channels = 4 253 | X = np.concatenate(X, axis=0) # 4 x n_bins x n_frames 254 | IVx = np.real(np.conj(X[0, :, :]) * X[1, :, :]) 255 | IVy = np.real(np.conj(X[0, :, :]) * X[2, :, :]) 256 | IVz = np.real(np.conj(X[0, :, :]) * X[3, :, :]) 257 | 258 | normal = np.sqrt(IVx ** 2 + IVy ** 2 + IVz ** 2) + self.eps 259 | IVx = np.dot(self.melW, IVx / normal).T # n_frames x n_mels 260 | IVy = np.dot(self.melW, IVy / normal).T 261 | IVz = np.dot(self.melW, IVz / normal).T 262 | 263 | # add intensity vector to logmel 264 | features.append(np.expand_dims(IVx, axis=0)) 265 | features.append(np.expand_dims(IVy, axis=0)) 266 | features.append(np.expand_dims(IVz, axis=0)) 267 | feature = np.concatenate(features, axis=0) 268 | 269 | return feature 270 | 271 | 272 | class IvExtractor(FeatureExtractor): 273 | """ 274 | Extract Intensity vector from FOA format. 275 | """ 276 | def __init__(self, fs: int, n_fft: int, hop_length: int, n_mels: int, fmin: int = 50, fmax: int = None, 277 | window: str = 'hann'): 278 | """ 279 | :param fs: Sampling rate. 280 | :param n_fft: Number of FFT points. 281 | :param hop_length: Number of sample for hopping. 282 | :param n_mels: Number of mel bands. 283 | :param fmin: Min frequency to extract feature (Hz). 284 | :param fmax: Max frequency to extract feature (Hz). 285 | :param window: Type of window. 286 | """ 287 | super().__init__(fs=fs, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax, window=window) 288 | self.eps = 1e-8 289 | 290 | def extract(self, audio_input: np.ndarray) -> np.ndarray: 291 | """ 292 | :param audio_input: . 293 | :return: feature_logmel . 
294 | """ 295 | n_channels = audio_input.shape[0] 296 | features = [] 297 | X = [] 298 | 299 | for i_channel in range(n_channels): 300 | spec = librosa.stft(y=np.asfortranarray(audio_input[i_channel]), 301 | n_fft=self.n_fft, 302 | hop_length=self.hop_length, 303 | center=True, 304 | window=self.window, 305 | pad_mode='reflect') 306 | X.append(np.expand_dims(spec, axis=0)) # 1 x n_bins x n_frames 307 | 308 | # compute intensity vector: for ambisonic signal, n_channels = 4 309 | X = np.concatenate(X, axis=0) # 4 x n_bins x n_frames 310 | IVx = np.real(np.conj(X[0, :, :]) * X[1, :, :]) 311 | IVy = np.real(np.conj(X[0, :, :]) * X[2, :, :]) 312 | IVz = np.real(np.conj(X[0, :, :]) * X[3, :, :]) 313 | 314 | normal = np.sqrt(IVx ** 2 + IVy ** 2 + IVz ** 2) + self.eps 315 | IVx = np.dot(self.melW, IVx / normal).T # n_frames x n_mels 316 | IVy = np.dot(self.melW, IVy / normal).T 317 | IVz = np.dot(self.melW, IVz / normal).T 318 | 319 | # add intensity vector to features 320 | features.append(np.expand_dims(IVx, axis=0)) 321 | features.append(np.expand_dims(IVy, axis=0)) 322 | features.append(np.expand_dims(IVz, axis=0)) 323 | features = np.concatenate(features, axis=0) 324 | 325 | return features 326 | 327 | 328 | class LogMelExtractor(FeatureExtractor): 329 | """ 330 | Extract single-channel or multi-channel logmel spectrograms. 331 | """ 332 | def __init__(self, fs: int, n_fft: int, hop_length: int, n_mels: int, fmin: int = 50, fmax: int = None, 333 | window: str = 'hann'): 334 | """ 335 | :param fs: Sampling rate. 336 | :param n_fft: Number of FFT points. 337 | :param hop_length: Number of sample for hopping. 338 | :param n_mels: Number of mel bands. 339 | :param fmin: Min frequency to extract feature (Hz). 340 | :param fmax: Max frequency to extract feature (Hz). 341 | :param window: Type of window. 342 | """ 343 | super().__init__(fs=fs, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax, window=window) 344 | 345 | def extract(self, audio_input: np.ndarray) -> np.ndarray: 346 | """ 347 | :param audio_input: . 348 | :return: logmel_features . 349 | """ 350 | n_channels = audio_input.shape[0] 351 | logmel_features = [] 352 | 353 | for i_channel in range(n_channels): 354 | spec = np.abs(librosa.stft(y=np.asfortranarray(audio_input[i_channel]), 355 | n_fft=self.n_fft, 356 | hop_length=self.hop_length, 357 | center=True, 358 | window=self.window, 359 | pad_mode='reflect')) 360 | 361 | mel_spec = np.dot(self.melW, spec**2).T 362 | logmel_spec = librosa.power_to_db(mel_spec, ref=1.0, amin=1e-10, top_db=None) 363 | logmel_spec = np.expand_dims(logmel_spec, axis=0) 364 | logmel_features.append(logmel_spec) 365 | 366 | logmel_features = np.concatenate(logmel_features, axis=0) 367 | 368 | return logmel_features 369 | 370 | 371 | def select_extractor(feature_type: str, fs: int, n_fft: int, hop_length: int, n_mels: int, fmin: int, fmax: int = None)\ 372 | -> None: 373 | """ 374 | Select feature extractor based on feature_type. 375 | :param feature_type: Choices are: 376 | 'logmel': logmel. 377 | 'gcc': gcc-phat. 378 | 'logmelgcc': logmel + gcc. 379 | 'iv': intensity vector. 380 | 'logmeliv': logmel + iv 381 | :param fs: Sampling rate. 382 | :param n_fft: Number of FFT points. 383 | :param hop_length: Number of sample for hopping. 384 | :param n_mels: Number of mel bands. 385 | :param fmin: Min frequency to extract feature (Hz). 386 | :param fmax: Max frequency to extract feature (Hz). 
387 | """ 388 | if feature_type == 'logmel': 389 | extractor = LogMelExtractor(fs=fs, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax) 390 | elif feature_type == 'iv': 391 | extractor = IvExtractor(fs=fs, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax) 392 | elif feature_type == 'logmeliv': 393 | extractor = LogMelIvExtractor(fs=fs, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax) 394 | elif feature_type == 'gcc': 395 | extractor = GccExtractor(fs=fs, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax) 396 | elif feature_type == 'logmelgcc': 397 | extractor = LogMelGccExtractor(fs=fs, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, fmin=fmin, fmax=fmax) 398 | else: 399 | raise NotImplementedError('Feature type {} is not implemented!'.format(feature_type)) 400 | 401 | return extractor 402 | 403 | 404 | def compute_scaler(feature_dir: str, audio_format: str) -> None: 405 | """ 406 | Compute feature mean and std vectors for normalization. 407 | :param feature_dir: Feature directory that contains train and test folder. 408 | :param audio_format: Audio format, either 'foa' or 'mic' 409 | """ 410 | print('============> Start calculating scaler') 411 | start_time = timer() 412 | 413 | # Get list of feature filenames 414 | train_feature_dir = os.path.join(feature_dir, audio_format + '_dev') 415 | feature_fn_list = os.listdir(train_feature_dir) 416 | 417 | # Get the dimensions of feature by reading one feature files 418 | full_feature_fn = os.path.join(train_feature_dir, feature_fn_list[0]) 419 | with h5py.File(full_feature_fn, 'r') as hf: 420 | afeature = hf['feature'][:] # (n_chanels, n_timesteps, n_features) 421 | n_channels = afeature.shape[0] 422 | scaler_dict = {} 423 | scalar_scaler_dict = {} 424 | for i_chan in range(n_channels): 425 | scaler_dict[i_chan] = preprocessing.StandardScaler() 426 | scalar_scaler_dict[i_chan] = preprocessing.StandardScaler() 427 | 428 | # Iterate through data 429 | for count, feature_fn in enumerate(tqdm(feature_fn_list)): 430 | full_feature_fn = os.path.join(train_feature_dir, feature_fn) 431 | with h5py.File(full_feature_fn, 'r') as hf: 432 | afeature = hf['feature'][:] # (n_chanels, n_timesteps, n_features) 433 | for i_chan in range(n_channels): 434 | scaler_dict[i_chan].partial_fit(afeature[i_chan, :, :]) # (n_timesteps, n_features) 435 | scalar_scaler_dict[i_chan].partial_fit(np.reshape(afeature[i_chan, :, :], (-1, 1))) # (n_timesteps * n_features, 1) 436 | 437 | # Extract mean and std 438 | feature_mean = [] 439 | feature_std = [] 440 | feature_mean_scalar = [] 441 | feature_std_scalar = [] 442 | for i_chan in range(n_channels): 443 | feature_mean.append(scaler_dict[i_chan].mean_) 444 | feature_std.append(np.sqrt(scaler_dict[i_chan].var_)) 445 | feature_mean_scalar.append(scalar_scaler_dict[i_chan].mean_) 446 | feature_std_scalar.append(np.sqrt(scalar_scaler_dict[i_chan].var_)) 447 | 448 | feature_mean = np.array(feature_mean) 449 | feature_std = np.array(feature_std) 450 | feature_mean_scalar = np.array(feature_mean_scalar) 451 | feature_std_scalar = np.array(feature_std_scalar) 452 | 453 | # Expand dims for timesteps: (n_chanels, n_timesteps, n_features) 454 | feature_mean = np.expand_dims(feature_mean, axis=1) 455 | feature_std = np.expand_dims(feature_std, axis=1) 456 | feature_mean_scalar = np.expand_dims(feature_mean_scalar, axis=1) 457 | feature_std_scalar = np.expand_dims(feature_std_scalar, axis=1) 458 | 459 | scaler_path = os.path.join(feature_dir, 
audio_format + '_feature_scaler.h5') 460 | with h5py.File(scaler_path, 'w') as hf: 461 | hf.create_dataset('mean', data=feature_mean, dtype=np.float32) 462 | hf.create_dataset('std', data=feature_std, dtype=np.float32) 463 | hf.create_dataset('scalar_mean', data=feature_mean_scalar, dtype=np.float32) 464 | hf.create_dataset('scalar_std', data=feature_std_scalar, dtype=np.float32) 465 | 466 | print('Features shape: {}'.format(afeature.shape)) 467 | print('mean {}: {}'.format(feature_mean.shape, feature_mean)) 468 | print('std {}: {}'.format(feature_std.shape, feature_std)) 469 | print('scalar mean {}: {}'.format(feature_mean_scalar.shape, feature_mean_scalar)) 470 | print('scalar std {}: {}'.format(feature_std_scalar.shape, feature_std_scalar)) 471 | print('Scaler path: {}'.format(scaler_path)) 472 | print('Elapsed time: {:.3f} s'.format(timer() - start_time)) 473 | 474 | 475 | def extract_features(data_config: str = 'configs/dcase2020_seld_data_config.yml', feature_type: str = 'logmel_norm', 476 | task: str = 'feature_scaler') -> None: 477 | """ 478 | Extract features 479 | :param data_config: Path to data config file. 480 | :param feature_type: Choices are: 481 | 'logmel': single channel logmel. 482 | 'logmel_norm': logmel with bg normalization. 483 | 'logmel_bg': logmel & logmel with bg normalization. 484 | 'gcc': gcc-phat. 485 | 'logmelgcc': logmel + gcc. 486 | 'iv': intensity vector. 487 | 'logmeliv': logmel + iv 488 | :param task: 'feature_scaler': extract feature and scaler, 'feature': only extract feature, 'scaler': only extract 489 | scaler. 490 | """ 491 | # Load data config files 492 | with open(data_config, 'r') as stream: 493 | try: 494 | cfg = yaml.safe_load(stream) 495 | except yaml.YAMLError as exc: 496 | print(exc) 497 | 498 | # Parse config file 499 | audio_format = cfg['data']['format'] 500 | fs = cfg['data']['fs'] 501 | n_fft = cfg['data']['n_fft'] 502 | hop_length = cfg['data']['hop_len'] 503 | fmin = cfg['data']['fmin'] 504 | fmax = cfg['data']['fmax'] 505 | n_mels = cfg['data']['n_mels'] 506 | fmax = np.min((fmax, fs//2)) 507 | 508 | # Get feature descriptions 509 | feature_description = '{}fs_{}nfft_{}nhop_{}nmels_{}std'.format( 510 | fs, n_fft, hop_length, n_mels, cfg['data']['is_std_norm']) 511 | 512 | # Get feature extractor 513 | feature_extractor = select_extractor(feature_type=feature_type, fs=fs, n_fft=n_fft, hop_length=hop_length, 514 | n_mels=n_mels, fmin=fmin, fmax=fmax) 515 | 516 | if audio_format == 'foa': 517 | splits = ['foa_dev', 'foa_eval'] 518 | elif audio_format == 'mic': 519 | splits = ['mic_dev', 'mic_eval'] 520 | else: 521 | raise ValueError('Unknown audio format {}'.format(audio_format)) 522 | 523 | # Extract features 524 | if task in ['feature_scaler', 'feature']: 525 | for split in splits: 526 | print('============> Start extracting features for {} split'.format(split)) 527 | start_time = timer() 528 | # Required directories 529 | audio_dir = os.path.join(cfg['data_dir'], split) 530 | feature_dir = os.path.join(cfg['feature_dir'], feature_type, feature_description, split) 531 | # Empty feature folder 532 | shutil.rmtree(feature_dir, ignore_errors=True) 533 | os.makedirs(feature_dir, exist_ok=True) 534 | 535 | # Get audio list 536 | audio_fn_list = sorted(os.listdir(audio_dir)) 537 | 538 | # Extract features 539 | for count, audio_fn in enumerate(tqdm(audio_fn_list)): 540 | full_audio_fn = os.path.join(audio_dir, audio_fn) 541 | audio_input, _ = librosa.load(full_audio_fn, sr=fs, mono=False, dtype=np.float32) 542 | if 
cfg['data']['is_std_norm']: 543 | sig_std = np.std(audio_input) 544 | audio_input = audio_input / sig_std * 0.1 545 | audio_feature = feature_extractor.extract(audio_input) # (n_channels, n_timesteps, n_mels) 546 | 547 | # Write features to file 548 | feature_fn = os.path.join(feature_dir, audio_fn.replace('wav', 'h5')) 549 | with h5py.File(feature_fn, 'w') as hf: 550 | hf.create_dataset('feature', data=audio_feature, dtype=np.float32) 551 | tqdm.write('{}, {}, {}'.format(count, audio_fn, audio_feature.shape)) 552 | 553 | print("Extracting feature finished! Elapsed time: {:.3f} s".format(timer() - start_time)) 554 | 555 | # Compute feature mean and std for train set. For simplification, we use same mean and std for validation and 556 | # evaluation 557 | if task in ['feature_scaler', 'scaler']: 558 | feature_dir = os.path.join(cfg['feature_dir'], feature_type, feature_description) 559 | compute_scaler(feature_dir=feature_dir, audio_format=audio_format) 560 | 561 | 562 | if __name__ == '__main__': 563 | fire.Fire(extract_features) 564 | -------------------------------------------------------------------------------- /dataset/label_mappings.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module includes labels mapping for different dataset and different setting 3 | """ 4 | 5 | 6 | class SELD2020Config: 7 | """ 8 | Data configs for SELD 2020 dataset. 9 | """ 10 | def __init__(self): 11 | self.event_labels = ['alarm', 'baby', 'crash', 'dog', 'engine', 12 | 'female_scream', 'female_speech', 'fire', 'footsteps', 13 | 'knock', 'male_scream', 'male_speech', 'phone', 'piano'] 14 | self.n_classes = len(self.event_labels) 15 | self.lb_to_ix = {lb: i for i, lb in enumerate(self.event_labels)} 16 | self.ix_to_lb = {i: lb for i, lb in enumerate(self.event_labels)} -------------------------------------------------------------------------------- /dataset/meta/eval.csv: -------------------------------------------------------------------------------- 1 | filename 2 | mix001 3 | mix002 4 | mix003 5 | mix004 6 | mix005 7 | mix006 8 | mix007 9 | mix008 10 | mix009 11 | mix010 12 | mix011 13 | mix012 14 | mix013 15 | mix014 16 | mix015 17 | mix016 18 | mix017 19 | mix018 20 | mix019 21 | mix020 22 | mix021 23 | mix022 24 | mix023 25 | mix024 26 | mix025 27 | mix026 28 | mix027 29 | mix028 30 | mix029 31 | mix030 32 | mix031 33 | mix032 34 | mix033 35 | mix034 36 | mix035 37 | mix036 38 | mix037 39 | mix038 40 | mix039 41 | mix040 42 | mix041 43 | mix042 44 | mix043 45 | mix044 46 | mix045 47 | mix046 48 | mix047 49 | mix048 50 | mix049 51 | mix050 52 | mix051 53 | mix052 54 | mix053 55 | mix054 56 | mix055 57 | mix056 58 | mix057 59 | mix058 60 | mix059 61 | mix060 62 | mix061 63 | mix062 64 | mix063 65 | mix064 66 | mix065 67 | mix066 68 | mix067 69 | mix068 70 | mix069 71 | mix070 72 | mix071 73 | mix072 74 | mix073 75 | mix074 76 | mix075 77 | mix076 78 | mix077 79 | mix078 80 | mix079 81 | mix080 82 | mix081 83 | mix082 84 | mix083 85 | mix084 86 | mix085 87 | mix086 88 | mix087 89 | mix088 90 | mix089 91 | mix090 92 | mix091 93 | mix092 94 | mix093 95 | mix094 96 | mix095 97 | mix096 98 | mix097 99 | mix098 100 | mix099 101 | mix100 102 | mix101 103 | mix102 104 | mix103 105 | mix104 106 | mix105 107 | mix106 108 | mix107 109 | mix108 110 | mix109 111 | mix110 112 | mix111 113 | mix112 114 | mix113 115 | mix114 116 | mix115 117 | mix116 118 | mix117 119 | mix118 120 | mix119 121 | mix120 122 | mix121 123 | mix122 124 | mix123 125 | mix124 126 | mix125 
127 | mix126 128 | mix127 129 | mix128 130 | mix129 131 | mix130 132 | mix131 133 | mix132 134 | mix133 135 | mix134 136 | mix135 137 | mix136 138 | mix137 139 | mix138 140 | mix139 141 | mix140 142 | mix141 143 | mix142 144 | mix143 145 | mix144 146 | mix145 147 | mix146 148 | mix147 149 | mix148 150 | mix149 151 | mix150 152 | mix151 153 | mix152 154 | mix153 155 | mix154 156 | mix155 157 | mix156 158 | mix157 159 | mix158 160 | mix159 161 | mix160 162 | mix161 163 | mix162 164 | mix163 165 | mix164 166 | mix165 167 | mix166 168 | mix167 169 | mix168 170 | mix169 171 | mix170 172 | mix171 173 | mix172 174 | mix173 175 | mix174 176 | mix175 177 | mix176 178 | mix177 179 | mix178 180 | mix179 181 | mix180 182 | mix181 183 | mix182 184 | mix183 185 | mix184 186 | mix185 187 | mix186 188 | mix187 189 | mix188 190 | mix189 191 | mix190 192 | mix191 193 | mix192 194 | mix193 195 | mix194 196 | mix195 197 | mix196 198 | mix197 199 | mix198 200 | mix199 201 | mix200 202 | -------------------------------------------------------------------------------- /dataset/meta/original/test.csv: -------------------------------------------------------------------------------- 1 | filename 2 | fold1_room1_mix001_ov1 3 | fold1_room1_mix002_ov1 4 | fold1_room1_mix003_ov1 5 | fold1_room1_mix004_ov1 6 | fold1_room1_mix005_ov1 7 | fold1_room1_mix006_ov1 8 | fold1_room1_mix007_ov1 9 | fold1_room1_mix008_ov1 10 | fold1_room1_mix009_ov1 11 | fold1_room1_mix010_ov1 12 | fold1_room1_mix011_ov1 13 | fold1_room1_mix012_ov1 14 | fold1_room1_mix013_ov1 15 | fold1_room1_mix014_ov1 16 | fold1_room1_mix015_ov1 17 | fold1_room1_mix016_ov1 18 | fold1_room1_mix017_ov1 19 | fold1_room1_mix018_ov1 20 | fold1_room1_mix019_ov1 21 | fold1_room1_mix020_ov1 22 | fold1_room1_mix021_ov1 23 | fold1_room1_mix022_ov1 24 | fold1_room1_mix023_ov1 25 | fold1_room1_mix024_ov1 26 | fold1_room1_mix025_ov1 27 | fold1_room1_mix026_ov2 28 | fold1_room1_mix027_ov2 29 | fold1_room1_mix028_ov2 30 | fold1_room1_mix029_ov2 31 | fold1_room1_mix030_ov2 32 | fold1_room1_mix031_ov2 33 | fold1_room1_mix032_ov2 34 | fold1_room1_mix033_ov2 35 | fold1_room1_mix034_ov2 36 | fold1_room1_mix035_ov2 37 | fold1_room1_mix036_ov2 38 | fold1_room1_mix037_ov2 39 | fold1_room1_mix038_ov2 40 | fold1_room1_mix039_ov2 41 | fold1_room1_mix040_ov2 42 | fold1_room1_mix041_ov2 43 | fold1_room1_mix042_ov2 44 | fold1_room1_mix043_ov2 45 | fold1_room1_mix044_ov2 46 | fold1_room1_mix045_ov2 47 | fold1_room1_mix046_ov2 48 | fold1_room1_mix047_ov2 49 | fold1_room1_mix048_ov2 50 | fold1_room1_mix049_ov2 51 | fold1_room1_mix050_ov2 52 | fold1_room2_mix001_ov1 53 | fold1_room2_mix002_ov1 54 | fold1_room2_mix003_ov1 55 | fold1_room2_mix004_ov1 56 | fold1_room2_mix005_ov1 57 | fold1_room2_mix006_ov1 58 | fold1_room2_mix007_ov1 59 | fold1_room2_mix008_ov1 60 | fold1_room2_mix009_ov1 61 | fold1_room2_mix010_ov1 62 | fold1_room2_mix011_ov1 63 | fold1_room2_mix012_ov1 64 | fold1_room2_mix013_ov1 65 | fold1_room2_mix014_ov1 66 | fold1_room2_mix015_ov1 67 | fold1_room2_mix016_ov1 68 | fold1_room2_mix017_ov1 69 | fold1_room2_mix018_ov1 70 | fold1_room2_mix019_ov1 71 | fold1_room2_mix020_ov1 72 | fold1_room2_mix021_ov1 73 | fold1_room2_mix022_ov1 74 | fold1_room2_mix023_ov1 75 | fold1_room2_mix024_ov1 76 | fold1_room2_mix025_ov1 77 | fold1_room2_mix026_ov2 78 | fold1_room2_mix027_ov2 79 | fold1_room2_mix028_ov2 80 | fold1_room2_mix029_ov2 81 | fold1_room2_mix030_ov2 82 | fold1_room2_mix031_ov2 83 | fold1_room2_mix032_ov2 84 | fold1_room2_mix033_ov2 85 | fold1_room2_mix034_ov2 86 | 
fold1_room2_mix035_ov2 87 | fold1_room2_mix036_ov2 88 | fold1_room2_mix037_ov2 89 | fold1_room2_mix038_ov2 90 | fold1_room2_mix039_ov2 91 | fold1_room2_mix040_ov2 92 | fold1_room2_mix041_ov2 93 | fold1_room2_mix042_ov2 94 | fold1_room2_mix043_ov2 95 | fold1_room2_mix044_ov2 96 | fold1_room2_mix045_ov2 97 | fold1_room2_mix046_ov2 98 | fold1_room2_mix047_ov2 99 | fold1_room2_mix048_ov2 100 | fold1_room2_mix049_ov2 101 | fold1_room2_mix050_ov2 102 | -------------------------------------------------------------------------------- /dataset/meta/original/train.csv: -------------------------------------------------------------------------------- 1 | filename 2 | fold3_room1_mix001_ov1 3 | fold3_room1_mix002_ov1 4 | fold3_room1_mix003_ov1 5 | fold3_room1_mix004_ov1 6 | fold3_room1_mix005_ov1 7 | fold3_room1_mix006_ov1 8 | fold3_room1_mix007_ov1 9 | fold3_room1_mix008_ov1 10 | fold3_room1_mix009_ov1 11 | fold3_room1_mix010_ov1 12 | fold3_room1_mix011_ov1 13 | fold3_room1_mix012_ov1 14 | fold3_room1_mix013_ov1 15 | fold3_room1_mix014_ov1 16 | fold3_room1_mix015_ov1 17 | fold3_room1_mix016_ov1 18 | fold3_room1_mix017_ov1 19 | fold3_room1_mix018_ov1 20 | fold3_room1_mix019_ov1 21 | fold3_room1_mix020_ov1 22 | fold3_room1_mix021_ov1 23 | fold3_room1_mix022_ov1 24 | fold3_room1_mix023_ov1 25 | fold3_room1_mix024_ov1 26 | fold3_room1_mix025_ov1 27 | fold3_room1_mix026_ov2 28 | fold3_room1_mix027_ov2 29 | fold3_room1_mix028_ov2 30 | fold3_room1_mix029_ov2 31 | fold3_room1_mix030_ov2 32 | fold3_room1_mix031_ov2 33 | fold3_room1_mix032_ov2 34 | fold3_room1_mix033_ov2 35 | fold3_room1_mix034_ov2 36 | fold3_room1_mix035_ov2 37 | fold3_room1_mix036_ov2 38 | fold3_room1_mix037_ov2 39 | fold3_room1_mix038_ov2 40 | fold3_room1_mix039_ov2 41 | fold3_room1_mix040_ov2 42 | fold3_room1_mix041_ov2 43 | fold3_room1_mix042_ov2 44 | fold3_room1_mix043_ov2 45 | fold3_room1_mix044_ov2 46 | fold3_room1_mix045_ov2 47 | fold3_room1_mix046_ov2 48 | fold3_room1_mix047_ov2 49 | fold3_room1_mix048_ov2 50 | fold3_room1_mix049_ov2 51 | fold3_room1_mix050_ov2 52 | fold3_room2_mix001_ov1 53 | fold3_room2_mix002_ov1 54 | fold3_room2_mix003_ov1 55 | fold3_room2_mix004_ov1 56 | fold3_room2_mix005_ov1 57 | fold3_room2_mix006_ov1 58 | fold3_room2_mix007_ov1 59 | fold3_room2_mix008_ov1 60 | fold3_room2_mix009_ov1 61 | fold3_room2_mix010_ov1 62 | fold3_room2_mix011_ov1 63 | fold3_room2_mix012_ov1 64 | fold3_room2_mix013_ov1 65 | fold3_room2_mix014_ov1 66 | fold3_room2_mix015_ov1 67 | fold3_room2_mix016_ov1 68 | fold3_room2_mix017_ov1 69 | fold3_room2_mix018_ov1 70 | fold3_room2_mix019_ov1 71 | fold3_room2_mix020_ov1 72 | fold3_room2_mix021_ov1 73 | fold3_room2_mix022_ov1 74 | fold3_room2_mix023_ov1 75 | fold3_room2_mix024_ov1 76 | fold3_room2_mix025_ov1 77 | fold3_room2_mix026_ov2 78 | fold3_room2_mix027_ov2 79 | fold3_room2_mix028_ov2 80 | fold3_room2_mix029_ov2 81 | fold3_room2_mix030_ov2 82 | fold3_room2_mix031_ov2 83 | fold3_room2_mix032_ov2 84 | fold3_room2_mix033_ov2 85 | fold3_room2_mix034_ov2 86 | fold3_room2_mix035_ov2 87 | fold3_room2_mix036_ov2 88 | fold3_room2_mix037_ov2 89 | fold3_room2_mix038_ov2 90 | fold3_room2_mix039_ov2 91 | fold3_room2_mix040_ov2 92 | fold3_room2_mix041_ov2 93 | fold3_room2_mix042_ov2 94 | fold3_room2_mix043_ov2 95 | fold3_room2_mix044_ov2 96 | fold3_room2_mix045_ov2 97 | fold3_room2_mix046_ov2 98 | fold3_room2_mix047_ov2 99 | fold3_room2_mix048_ov2 100 | fold3_room2_mix049_ov2 101 | fold3_room2_mix050_ov2 102 | fold4_room1_mix001_ov1 103 | fold4_room1_mix002_ov1 104 | fold4_room1_mix003_ov1 105 | 
fold4_room1_mix004_ov1 106 | fold4_room1_mix005_ov1 107 | fold4_room1_mix006_ov1 108 | fold4_room1_mix007_ov1 109 | fold4_room1_mix008_ov1 110 | fold4_room1_mix009_ov1 111 | fold4_room1_mix010_ov1 112 | fold4_room1_mix011_ov1 113 | fold4_room1_mix012_ov1 114 | fold4_room1_mix013_ov1 115 | fold4_room1_mix014_ov1 116 | fold4_room1_mix015_ov1 117 | fold4_room1_mix016_ov1 118 | fold4_room1_mix017_ov1 119 | fold4_room1_mix018_ov1 120 | fold4_room1_mix019_ov1 121 | fold4_room1_mix020_ov1 122 | fold4_room1_mix021_ov1 123 | fold4_room1_mix022_ov1 124 | fold4_room1_mix023_ov1 125 | fold4_room1_mix024_ov1 126 | fold4_room1_mix025_ov1 127 | fold4_room1_mix026_ov2 128 | fold4_room1_mix027_ov2 129 | fold4_room1_mix028_ov2 130 | fold4_room1_mix029_ov2 131 | fold4_room1_mix030_ov2 132 | fold4_room1_mix031_ov2 133 | fold4_room1_mix032_ov2 134 | fold4_room1_mix033_ov2 135 | fold4_room1_mix034_ov2 136 | fold4_room1_mix035_ov2 137 | fold4_room1_mix036_ov2 138 | fold4_room1_mix037_ov2 139 | fold4_room1_mix038_ov2 140 | fold4_room1_mix039_ov2 141 | fold4_room1_mix040_ov2 142 | fold4_room1_mix041_ov2 143 | fold4_room1_mix042_ov2 144 | fold4_room1_mix043_ov2 145 | fold4_room1_mix044_ov2 146 | fold4_room1_mix045_ov2 147 | fold4_room1_mix046_ov2 148 | fold4_room1_mix047_ov2 149 | fold4_room1_mix048_ov2 150 | fold4_room1_mix049_ov2 151 | fold4_room1_mix050_ov2 152 | fold4_room2_mix001_ov1 153 | fold4_room2_mix002_ov1 154 | fold4_room2_mix003_ov1 155 | fold4_room2_mix004_ov1 156 | fold4_room2_mix005_ov1 157 | fold4_room2_mix006_ov1 158 | fold4_room2_mix007_ov1 159 | fold4_room2_mix008_ov1 160 | fold4_room2_mix009_ov1 161 | fold4_room2_mix010_ov1 162 | fold4_room2_mix011_ov1 163 | fold4_room2_mix012_ov1 164 | fold4_room2_mix013_ov1 165 | fold4_room2_mix014_ov1 166 | fold4_room2_mix015_ov1 167 | fold4_room2_mix016_ov1 168 | fold4_room2_mix017_ov1 169 | fold4_room2_mix018_ov1 170 | fold4_room2_mix019_ov1 171 | fold4_room2_mix020_ov1 172 | fold4_room2_mix021_ov1 173 | fold4_room2_mix022_ov1 174 | fold4_room2_mix023_ov1 175 | fold4_room2_mix024_ov1 176 | fold4_room2_mix025_ov1 177 | fold4_room2_mix026_ov2 178 | fold4_room2_mix027_ov2 179 | fold4_room2_mix028_ov2 180 | fold4_room2_mix029_ov2 181 | fold4_room2_mix030_ov2 182 | fold4_room2_mix031_ov2 183 | fold4_room2_mix032_ov2 184 | fold4_room2_mix033_ov2 185 | fold4_room2_mix034_ov2 186 | fold4_room2_mix035_ov2 187 | fold4_room2_mix036_ov2 188 | fold4_room2_mix037_ov2 189 | fold4_room2_mix038_ov2 190 | fold4_room2_mix039_ov2 191 | fold4_room2_mix040_ov2 192 | fold4_room2_mix041_ov2 193 | fold4_room2_mix042_ov2 194 | fold4_room2_mix043_ov2 195 | fold4_room2_mix044_ov2 196 | fold4_room2_mix045_ov2 197 | fold4_room2_mix046_ov2 198 | fold4_room2_mix047_ov2 199 | fold4_room2_mix048_ov2 200 | fold4_room2_mix049_ov2 201 | fold4_room2_mix050_ov2 202 | fold5_room1_mix001_ov1 203 | fold5_room1_mix002_ov1 204 | fold5_room1_mix003_ov1 205 | fold5_room1_mix004_ov1 206 | fold5_room1_mix005_ov1 207 | fold5_room1_mix006_ov1 208 | fold5_room1_mix007_ov1 209 | fold5_room1_mix008_ov1 210 | fold5_room1_mix009_ov1 211 | fold5_room1_mix010_ov1 212 | fold5_room1_mix011_ov1 213 | fold5_room1_mix012_ov1 214 | fold5_room1_mix013_ov1 215 | fold5_room1_mix014_ov1 216 | fold5_room1_mix015_ov1 217 | fold5_room1_mix016_ov1 218 | fold5_room1_mix017_ov1 219 | fold5_room1_mix018_ov1 220 | fold5_room1_mix019_ov1 221 | fold5_room1_mix020_ov1 222 | fold5_room1_mix021_ov1 223 | fold5_room1_mix022_ov1 224 | fold5_room1_mix023_ov1 225 | fold5_room1_mix024_ov1 226 | fold5_room1_mix025_ov1 227 | 
fold5_room1_mix026_ov2 228 | fold5_room1_mix027_ov2 229 | fold5_room1_mix028_ov2 230 | fold5_room1_mix029_ov2 231 | fold5_room1_mix030_ov2 232 | fold5_room1_mix031_ov2 233 | fold5_room1_mix032_ov2 234 | fold5_room1_mix033_ov2 235 | fold5_room1_mix034_ov2 236 | fold5_room1_mix035_ov2 237 | fold5_room1_mix036_ov2 238 | fold5_room1_mix037_ov2 239 | fold5_room1_mix038_ov2 240 | fold5_room1_mix039_ov2 241 | fold5_room1_mix040_ov2 242 | fold5_room1_mix041_ov2 243 | fold5_room1_mix042_ov2 244 | fold5_room1_mix043_ov2 245 | fold5_room1_mix044_ov2 246 | fold5_room1_mix045_ov2 247 | fold5_room1_mix046_ov2 248 | fold5_room1_mix047_ov2 249 | fold5_room1_mix048_ov2 250 | fold5_room1_mix049_ov2 251 | fold5_room1_mix050_ov2 252 | fold5_room2_mix001_ov1 253 | fold5_room2_mix002_ov1 254 | fold5_room2_mix003_ov1 255 | fold5_room2_mix004_ov1 256 | fold5_room2_mix005_ov1 257 | fold5_room2_mix006_ov1 258 | fold5_room2_mix007_ov1 259 | fold5_room2_mix008_ov1 260 | fold5_room2_mix009_ov1 261 | fold5_room2_mix010_ov1 262 | fold5_room2_mix011_ov1 263 | fold5_room2_mix012_ov1 264 | fold5_room2_mix013_ov1 265 | fold5_room2_mix014_ov1 266 | fold5_room2_mix015_ov1 267 | fold5_room2_mix016_ov1 268 | fold5_room2_mix017_ov1 269 | fold5_room2_mix018_ov1 270 | fold5_room2_mix019_ov1 271 | fold5_room2_mix020_ov1 272 | fold5_room2_mix021_ov1 273 | fold5_room2_mix022_ov1 274 | fold5_room2_mix023_ov1 275 | fold5_room2_mix024_ov1 276 | fold5_room2_mix025_ov1 277 | fold5_room2_mix026_ov2 278 | fold5_room2_mix027_ov2 279 | fold5_room2_mix028_ov2 280 | fold5_room2_mix029_ov2 281 | fold5_room2_mix030_ov2 282 | fold5_room2_mix031_ov2 283 | fold5_room2_mix032_ov2 284 | fold5_room2_mix033_ov2 285 | fold5_room2_mix034_ov2 286 | fold5_room2_mix035_ov2 287 | fold5_room2_mix036_ov2 288 | fold5_room2_mix037_ov2 289 | fold5_room2_mix038_ov2 290 | fold5_room2_mix039_ov2 291 | fold5_room2_mix040_ov2 292 | fold5_room2_mix041_ov2 293 | fold5_room2_mix042_ov2 294 | fold5_room2_mix043_ov2 295 | fold5_room2_mix044_ov2 296 | fold5_room2_mix045_ov2 297 | fold5_room2_mix046_ov2 298 | fold5_room2_mix047_ov2 299 | fold5_room2_mix048_ov2 300 | fold5_room2_mix049_ov2 301 | fold5_room2_mix050_ov2 302 | fold6_room1_mix001_ov1 303 | fold6_room1_mix002_ov1 304 | fold6_room1_mix003_ov1 305 | fold6_room1_mix004_ov1 306 | fold6_room1_mix005_ov1 307 | fold6_room1_mix006_ov1 308 | fold6_room1_mix007_ov1 309 | fold6_room1_mix008_ov1 310 | fold6_room1_mix009_ov1 311 | fold6_room1_mix010_ov1 312 | fold6_room1_mix011_ov1 313 | fold6_room1_mix012_ov1 314 | fold6_room1_mix013_ov1 315 | fold6_room1_mix014_ov1 316 | fold6_room1_mix015_ov1 317 | fold6_room1_mix016_ov1 318 | fold6_room1_mix017_ov1 319 | fold6_room1_mix018_ov1 320 | fold6_room1_mix019_ov1 321 | fold6_room1_mix020_ov1 322 | fold6_room1_mix021_ov1 323 | fold6_room1_mix022_ov1 324 | fold6_room1_mix023_ov1 325 | fold6_room1_mix024_ov1 326 | fold6_room1_mix025_ov1 327 | fold6_room1_mix026_ov1 328 | fold6_room1_mix027_ov1 329 | fold6_room1_mix028_ov1 330 | fold6_room1_mix029_ov1 331 | fold6_room1_mix030_ov1 332 | fold6_room1_mix031_ov1 333 | fold6_room1_mix032_ov1 334 | fold6_room1_mix033_ov1 335 | fold6_room1_mix034_ov1 336 | fold6_room1_mix035_ov1 337 | fold6_room1_mix036_ov1 338 | fold6_room1_mix037_ov1 339 | fold6_room1_mix038_ov1 340 | fold6_room1_mix039_ov1 341 | fold6_room1_mix040_ov1 342 | fold6_room1_mix041_ov1 343 | fold6_room1_mix042_ov1 344 | fold6_room1_mix043_ov1 345 | fold6_room1_mix044_ov1 346 | fold6_room1_mix045_ov1 347 | fold6_room1_mix046_ov1 348 | fold6_room1_mix047_ov1 349 | 
fold6_room1_mix048_ov1 350 | fold6_room1_mix049_ov1 351 | fold6_room1_mix050_ov1 352 | fold6_room1_mix051_ov2 353 | fold6_room1_mix052_ov2 354 | fold6_room1_mix053_ov2 355 | fold6_room1_mix054_ov2 356 | fold6_room1_mix055_ov2 357 | fold6_room1_mix056_ov2 358 | fold6_room1_mix057_ov2 359 | fold6_room1_mix058_ov2 360 | fold6_room1_mix059_ov2 361 | fold6_room1_mix060_ov2 362 | fold6_room1_mix061_ov2 363 | fold6_room1_mix062_ov2 364 | fold6_room1_mix063_ov2 365 | fold6_room1_mix064_ov2 366 | fold6_room1_mix065_ov2 367 | fold6_room1_mix066_ov2 368 | fold6_room1_mix067_ov2 369 | fold6_room1_mix068_ov2 370 | fold6_room1_mix069_ov2 371 | fold6_room1_mix070_ov2 372 | fold6_room1_mix071_ov2 373 | fold6_room1_mix072_ov2 374 | fold6_room1_mix073_ov2 375 | fold6_room1_mix074_ov2 376 | fold6_room1_mix075_ov2 377 | fold6_room1_mix076_ov2 378 | fold6_room1_mix077_ov2 379 | fold6_room1_mix078_ov2 380 | fold6_room1_mix079_ov2 381 | fold6_room1_mix080_ov2 382 | fold6_room1_mix081_ov2 383 | fold6_room1_mix082_ov2 384 | fold6_room1_mix083_ov2 385 | fold6_room1_mix084_ov2 386 | fold6_room1_mix085_ov2 387 | fold6_room1_mix086_ov2 388 | fold6_room1_mix087_ov2 389 | fold6_room1_mix088_ov2 390 | fold6_room1_mix089_ov2 391 | fold6_room1_mix090_ov2 392 | fold6_room1_mix091_ov2 393 | fold6_room1_mix092_ov2 394 | fold6_room1_mix093_ov2 395 | fold6_room1_mix094_ov2 396 | fold6_room1_mix095_ov2 397 | fold6_room1_mix096_ov2 398 | fold6_room1_mix097_ov2 399 | fold6_room1_mix098_ov2 400 | fold6_room1_mix099_ov2 401 | fold6_room1_mix100_ov2 402 | -------------------------------------------------------------------------------- /dataset/meta/original/trainval.csv: -------------------------------------------------------------------------------- 1 | filename 2 | fold2_room1_mix001_ov1 3 | fold2_room1_mix002_ov1 4 | fold2_room1_mix003_ov1 5 | fold2_room1_mix004_ov1 6 | fold2_room1_mix005_ov1 7 | fold2_room1_mix006_ov1 8 | fold2_room1_mix007_ov1 9 | fold2_room1_mix008_ov1 10 | fold2_room1_mix009_ov1 11 | fold2_room1_mix010_ov1 12 | fold2_room1_mix011_ov1 13 | fold2_room1_mix012_ov1 14 | fold2_room1_mix013_ov1 15 | fold2_room1_mix014_ov1 16 | fold2_room1_mix015_ov1 17 | fold2_room1_mix016_ov1 18 | fold2_room1_mix017_ov1 19 | fold2_room1_mix018_ov1 20 | fold2_room1_mix019_ov1 21 | fold2_room1_mix020_ov1 22 | fold2_room1_mix021_ov1 23 | fold2_room1_mix022_ov1 24 | fold2_room1_mix023_ov1 25 | fold2_room1_mix024_ov1 26 | fold2_room1_mix025_ov1 27 | fold2_room1_mix026_ov2 28 | fold2_room1_mix027_ov2 29 | fold2_room1_mix028_ov2 30 | fold2_room1_mix029_ov2 31 | fold2_room1_mix030_ov2 32 | fold2_room1_mix031_ov2 33 | fold2_room1_mix032_ov2 34 | fold2_room1_mix033_ov2 35 | fold2_room1_mix034_ov2 36 | fold2_room1_mix035_ov2 37 | fold2_room1_mix036_ov2 38 | fold2_room1_mix037_ov2 39 | fold2_room1_mix038_ov2 40 | fold2_room1_mix039_ov2 41 | fold2_room1_mix040_ov2 42 | fold2_room1_mix041_ov2 43 | fold2_room1_mix042_ov2 44 | fold2_room1_mix043_ov2 45 | fold2_room1_mix044_ov2 46 | fold2_room1_mix045_ov2 47 | fold2_room1_mix046_ov2 48 | fold2_room1_mix047_ov2 49 | fold2_room1_mix048_ov2 50 | fold2_room1_mix049_ov2 51 | fold2_room1_mix050_ov2 52 | fold2_room2_mix001_ov1 53 | fold2_room2_mix002_ov1 54 | fold2_room2_mix003_ov1 55 | fold2_room2_mix004_ov1 56 | fold2_room2_mix005_ov1 57 | fold2_room2_mix006_ov1 58 | fold2_room2_mix007_ov1 59 | fold2_room2_mix008_ov1 60 | fold2_room2_mix009_ov1 61 | fold2_room2_mix010_ov1 62 | fold2_room2_mix011_ov1 63 | fold2_room2_mix012_ov1 64 | fold2_room2_mix013_ov1 65 | fold2_room2_mix014_ov1 66 | 
fold2_room2_mix015_ov1 67 | fold2_room2_mix016_ov1 68 | fold2_room2_mix017_ov1 69 | fold2_room2_mix018_ov1 70 | fold2_room2_mix019_ov1 71 | fold2_room2_mix020_ov1 72 | fold2_room2_mix021_ov1 73 | fold2_room2_mix022_ov1 74 | fold2_room2_mix023_ov1 75 | fold2_room2_mix024_ov1 76 | fold2_room2_mix025_ov1 77 | fold2_room2_mix026_ov2 78 | fold2_room2_mix027_ov2 79 | fold2_room2_mix028_ov2 80 | fold2_room2_mix029_ov2 81 | fold2_room2_mix030_ov2 82 | fold2_room2_mix031_ov2 83 | fold2_room2_mix032_ov2 84 | fold2_room2_mix033_ov2 85 | fold2_room2_mix034_ov2 86 | fold2_room2_mix035_ov2 87 | fold2_room2_mix036_ov2 88 | fold2_room2_mix037_ov2 89 | fold2_room2_mix038_ov2 90 | fold2_room2_mix039_ov2 91 | fold2_room2_mix040_ov2 92 | fold2_room2_mix041_ov2 93 | fold2_room2_mix042_ov2 94 | fold2_room2_mix043_ov2 95 | fold2_room2_mix044_ov2 96 | fold2_room2_mix045_ov2 97 | fold2_room2_mix046_ov2 98 | fold2_room2_mix047_ov2 99 | fold2_room2_mix048_ov2 100 | fold2_room2_mix049_ov2 101 | fold2_room2_mix050_ov2 102 | fold3_room1_mix001_ov1 103 | fold3_room1_mix002_ov1 104 | fold3_room1_mix003_ov1 105 | fold3_room1_mix004_ov1 106 | fold3_room1_mix005_ov1 107 | fold3_room1_mix006_ov1 108 | fold3_room1_mix007_ov1 109 | fold3_room1_mix008_ov1 110 | fold3_room1_mix009_ov1 111 | fold3_room1_mix010_ov1 112 | fold3_room1_mix011_ov1 113 | fold3_room1_mix012_ov1 114 | fold3_room1_mix013_ov1 115 | fold3_room1_mix014_ov1 116 | fold3_room1_mix015_ov1 117 | fold3_room1_mix016_ov1 118 | fold3_room1_mix017_ov1 119 | fold3_room1_mix018_ov1 120 | fold3_room1_mix019_ov1 121 | fold3_room1_mix020_ov1 122 | fold3_room1_mix021_ov1 123 | fold3_room1_mix022_ov1 124 | fold3_room1_mix023_ov1 125 | fold3_room1_mix024_ov1 126 | fold3_room1_mix025_ov1 127 | fold3_room1_mix026_ov2 128 | fold3_room1_mix027_ov2 129 | fold3_room1_mix028_ov2 130 | fold3_room1_mix029_ov2 131 | fold3_room1_mix030_ov2 132 | fold3_room1_mix031_ov2 133 | fold3_room1_mix032_ov2 134 | fold3_room1_mix033_ov2 135 | fold3_room1_mix034_ov2 136 | fold3_room1_mix035_ov2 137 | fold3_room1_mix036_ov2 138 | fold3_room1_mix037_ov2 139 | fold3_room1_mix038_ov2 140 | fold3_room1_mix039_ov2 141 | fold3_room1_mix040_ov2 142 | fold3_room1_mix041_ov2 143 | fold3_room1_mix042_ov2 144 | fold3_room1_mix043_ov2 145 | fold3_room1_mix044_ov2 146 | fold3_room1_mix045_ov2 147 | fold3_room1_mix046_ov2 148 | fold3_room1_mix047_ov2 149 | fold3_room1_mix048_ov2 150 | fold3_room1_mix049_ov2 151 | fold3_room1_mix050_ov2 152 | fold3_room2_mix001_ov1 153 | fold3_room2_mix002_ov1 154 | fold3_room2_mix003_ov1 155 | fold3_room2_mix004_ov1 156 | fold3_room2_mix005_ov1 157 | fold3_room2_mix006_ov1 158 | fold3_room2_mix007_ov1 159 | fold3_room2_mix008_ov1 160 | fold3_room2_mix009_ov1 161 | fold3_room2_mix010_ov1 162 | fold3_room2_mix011_ov1 163 | fold3_room2_mix012_ov1 164 | fold3_room2_mix013_ov1 165 | fold3_room2_mix014_ov1 166 | fold3_room2_mix015_ov1 167 | fold3_room2_mix016_ov1 168 | fold3_room2_mix017_ov1 169 | fold3_room2_mix018_ov1 170 | fold3_room2_mix019_ov1 171 | fold3_room2_mix020_ov1 172 | fold3_room2_mix021_ov1 173 | fold3_room2_mix022_ov1 174 | fold3_room2_mix023_ov1 175 | fold3_room2_mix024_ov1 176 | fold3_room2_mix025_ov1 177 | fold3_room2_mix026_ov2 178 | fold3_room2_mix027_ov2 179 | fold3_room2_mix028_ov2 180 | fold3_room2_mix029_ov2 181 | fold3_room2_mix030_ov2 182 | fold3_room2_mix031_ov2 183 | fold3_room2_mix032_ov2 184 | fold3_room2_mix033_ov2 185 | fold3_room2_mix034_ov2 186 | fold3_room2_mix035_ov2 187 | fold3_room2_mix036_ov2 188 | fold3_room2_mix037_ov2 189 | 
fold3_room2_mix038_ov2 190 | fold3_room2_mix039_ov2 191 | fold3_room2_mix040_ov2 192 | fold3_room2_mix041_ov2 193 | fold3_room2_mix042_ov2 194 | fold3_room2_mix043_ov2 195 | fold3_room2_mix044_ov2 196 | fold3_room2_mix045_ov2 197 | fold3_room2_mix046_ov2 198 | fold3_room2_mix047_ov2 199 | fold3_room2_mix048_ov2 200 | fold3_room2_mix049_ov2 201 | fold3_room2_mix050_ov2 202 | fold4_room1_mix001_ov1 203 | fold4_room1_mix002_ov1 204 | fold4_room1_mix003_ov1 205 | fold4_room1_mix004_ov1 206 | fold4_room1_mix005_ov1 207 | fold4_room1_mix006_ov1 208 | fold4_room1_mix007_ov1 209 | fold4_room1_mix008_ov1 210 | fold4_room1_mix009_ov1 211 | fold4_room1_mix010_ov1 212 | fold4_room1_mix011_ov1 213 | fold4_room1_mix012_ov1 214 | fold4_room1_mix013_ov1 215 | fold4_room1_mix014_ov1 216 | fold4_room1_mix015_ov1 217 | fold4_room1_mix016_ov1 218 | fold4_room1_mix017_ov1 219 | fold4_room1_mix018_ov1 220 | fold4_room1_mix019_ov1 221 | fold4_room1_mix020_ov1 222 | fold4_room1_mix021_ov1 223 | fold4_room1_mix022_ov1 224 | fold4_room1_mix023_ov1 225 | fold4_room1_mix024_ov1 226 | fold4_room1_mix025_ov1 227 | fold4_room1_mix026_ov2 228 | fold4_room1_mix027_ov2 229 | fold4_room1_mix028_ov2 230 | fold4_room1_mix029_ov2 231 | fold4_room1_mix030_ov2 232 | fold4_room1_mix031_ov2 233 | fold4_room1_mix032_ov2 234 | fold4_room1_mix033_ov2 235 | fold4_room1_mix034_ov2 236 | fold4_room1_mix035_ov2 237 | fold4_room1_mix036_ov2 238 | fold4_room1_mix037_ov2 239 | fold4_room1_mix038_ov2 240 | fold4_room1_mix039_ov2 241 | fold4_room1_mix040_ov2 242 | fold4_room1_mix041_ov2 243 | fold4_room1_mix042_ov2 244 | fold4_room1_mix043_ov2 245 | fold4_room1_mix044_ov2 246 | fold4_room1_mix045_ov2 247 | fold4_room1_mix046_ov2 248 | fold4_room1_mix047_ov2 249 | fold4_room1_mix048_ov2 250 | fold4_room1_mix049_ov2 251 | fold4_room1_mix050_ov2 252 | fold4_room2_mix001_ov1 253 | fold4_room2_mix002_ov1 254 | fold4_room2_mix003_ov1 255 | fold4_room2_mix004_ov1 256 | fold4_room2_mix005_ov1 257 | fold4_room2_mix006_ov1 258 | fold4_room2_mix007_ov1 259 | fold4_room2_mix008_ov1 260 | fold4_room2_mix009_ov1 261 | fold4_room2_mix010_ov1 262 | fold4_room2_mix011_ov1 263 | fold4_room2_mix012_ov1 264 | fold4_room2_mix013_ov1 265 | fold4_room2_mix014_ov1 266 | fold4_room2_mix015_ov1 267 | fold4_room2_mix016_ov1 268 | fold4_room2_mix017_ov1 269 | fold4_room2_mix018_ov1 270 | fold4_room2_mix019_ov1 271 | fold4_room2_mix020_ov1 272 | fold4_room2_mix021_ov1 273 | fold4_room2_mix022_ov1 274 | fold4_room2_mix023_ov1 275 | fold4_room2_mix024_ov1 276 | fold4_room2_mix025_ov1 277 | fold4_room2_mix026_ov2 278 | fold4_room2_mix027_ov2 279 | fold4_room2_mix028_ov2 280 | fold4_room2_mix029_ov2 281 | fold4_room2_mix030_ov2 282 | fold4_room2_mix031_ov2 283 | fold4_room2_mix032_ov2 284 | fold4_room2_mix033_ov2 285 | fold4_room2_mix034_ov2 286 | fold4_room2_mix035_ov2 287 | fold4_room2_mix036_ov2 288 | fold4_room2_mix037_ov2 289 | fold4_room2_mix038_ov2 290 | fold4_room2_mix039_ov2 291 | fold4_room2_mix040_ov2 292 | fold4_room2_mix041_ov2 293 | fold4_room2_mix042_ov2 294 | fold4_room2_mix043_ov2 295 | fold4_room2_mix044_ov2 296 | fold4_room2_mix045_ov2 297 | fold4_room2_mix046_ov2 298 | fold4_room2_mix047_ov2 299 | fold4_room2_mix048_ov2 300 | fold4_room2_mix049_ov2 301 | fold4_room2_mix050_ov2 302 | fold5_room1_mix001_ov1 303 | fold5_room1_mix002_ov1 304 | fold5_room1_mix003_ov1 305 | fold5_room1_mix004_ov1 306 | fold5_room1_mix005_ov1 307 | fold5_room1_mix006_ov1 308 | fold5_room1_mix007_ov1 309 | fold5_room1_mix008_ov1 310 | fold5_room1_mix009_ov1 311 | 
fold5_room1_mix010_ov1 312 | fold5_room1_mix011_ov1 313 | fold5_room1_mix012_ov1 314 | fold5_room1_mix013_ov1 315 | fold5_room1_mix014_ov1 316 | fold5_room1_mix015_ov1 317 | fold5_room1_mix016_ov1 318 | fold5_room1_mix017_ov1 319 | fold5_room1_mix018_ov1 320 | fold5_room1_mix019_ov1 321 | fold5_room1_mix020_ov1 322 | fold5_room1_mix021_ov1 323 | fold5_room1_mix022_ov1 324 | fold5_room1_mix023_ov1 325 | fold5_room1_mix024_ov1 326 | fold5_room1_mix025_ov1 327 | fold5_room1_mix026_ov2 328 | fold5_room1_mix027_ov2 329 | fold5_room1_mix028_ov2 330 | fold5_room1_mix029_ov2 331 | fold5_room1_mix030_ov2 332 | fold5_room1_mix031_ov2 333 | fold5_room1_mix032_ov2 334 | fold5_room1_mix033_ov2 335 | fold5_room1_mix034_ov2 336 | fold5_room1_mix035_ov2 337 | fold5_room1_mix036_ov2 338 | fold5_room1_mix037_ov2 339 | fold5_room1_mix038_ov2 340 | fold5_room1_mix039_ov2 341 | fold5_room1_mix040_ov2 342 | fold5_room1_mix041_ov2 343 | fold5_room1_mix042_ov2 344 | fold5_room1_mix043_ov2 345 | fold5_room1_mix044_ov2 346 | fold5_room1_mix045_ov2 347 | fold5_room1_mix046_ov2 348 | fold5_room1_mix047_ov2 349 | fold5_room1_mix048_ov2 350 | fold5_room1_mix049_ov2 351 | fold5_room1_mix050_ov2 352 | fold5_room2_mix001_ov1 353 | fold5_room2_mix002_ov1 354 | fold5_room2_mix003_ov1 355 | fold5_room2_mix004_ov1 356 | fold5_room2_mix005_ov1 357 | fold5_room2_mix006_ov1 358 | fold5_room2_mix007_ov1 359 | fold5_room2_mix008_ov1 360 | fold5_room2_mix009_ov1 361 | fold5_room2_mix010_ov1 362 | fold5_room2_mix011_ov1 363 | fold5_room2_mix012_ov1 364 | fold5_room2_mix013_ov1 365 | fold5_room2_mix014_ov1 366 | fold5_room2_mix015_ov1 367 | fold5_room2_mix016_ov1 368 | fold5_room2_mix017_ov1 369 | fold5_room2_mix018_ov1 370 | fold5_room2_mix019_ov1 371 | fold5_room2_mix020_ov1 372 | fold5_room2_mix021_ov1 373 | fold5_room2_mix022_ov1 374 | fold5_room2_mix023_ov1 375 | fold5_room2_mix024_ov1 376 | fold5_room2_mix025_ov1 377 | fold5_room2_mix026_ov2 378 | fold5_room2_mix027_ov2 379 | fold5_room2_mix028_ov2 380 | fold5_room2_mix029_ov2 381 | fold5_room2_mix030_ov2 382 | fold5_room2_mix031_ov2 383 | fold5_room2_mix032_ov2 384 | fold5_room2_mix033_ov2 385 | fold5_room2_mix034_ov2 386 | fold5_room2_mix035_ov2 387 | fold5_room2_mix036_ov2 388 | fold5_room2_mix037_ov2 389 | fold5_room2_mix038_ov2 390 | fold5_room2_mix039_ov2 391 | fold5_room2_mix040_ov2 392 | fold5_room2_mix041_ov2 393 | fold5_room2_mix042_ov2 394 | fold5_room2_mix043_ov2 395 | fold5_room2_mix044_ov2 396 | fold5_room2_mix045_ov2 397 | fold5_room2_mix046_ov2 398 | fold5_room2_mix047_ov2 399 | fold5_room2_mix048_ov2 400 | fold5_room2_mix049_ov2 401 | fold5_room2_mix050_ov2 402 | fold6_room1_mix001_ov1 403 | fold6_room1_mix002_ov1 404 | fold6_room1_mix003_ov1 405 | fold6_room1_mix004_ov1 406 | fold6_room1_mix005_ov1 407 | fold6_room1_mix006_ov1 408 | fold6_room1_mix007_ov1 409 | fold6_room1_mix008_ov1 410 | fold6_room1_mix009_ov1 411 | fold6_room1_mix010_ov1 412 | fold6_room1_mix011_ov1 413 | fold6_room1_mix012_ov1 414 | fold6_room1_mix013_ov1 415 | fold6_room1_mix014_ov1 416 | fold6_room1_mix015_ov1 417 | fold6_room1_mix016_ov1 418 | fold6_room1_mix017_ov1 419 | fold6_room1_mix018_ov1 420 | fold6_room1_mix019_ov1 421 | fold6_room1_mix020_ov1 422 | fold6_room1_mix021_ov1 423 | fold6_room1_mix022_ov1 424 | fold6_room1_mix023_ov1 425 | fold6_room1_mix024_ov1 426 | fold6_room1_mix025_ov1 427 | fold6_room1_mix026_ov1 428 | fold6_room1_mix027_ov1 429 | fold6_room1_mix028_ov1 430 | fold6_room1_mix029_ov1 431 | fold6_room1_mix030_ov1 432 | fold6_room1_mix031_ov1 433 | 
fold6_room1_mix032_ov1 434 | fold6_room1_mix033_ov1 435 | fold6_room1_mix034_ov1 436 | fold6_room1_mix035_ov1 437 | fold6_room1_mix036_ov1 438 | fold6_room1_mix037_ov1 439 | fold6_room1_mix038_ov1 440 | fold6_room1_mix039_ov1 441 | fold6_room1_mix040_ov1 442 | fold6_room1_mix041_ov1 443 | fold6_room1_mix042_ov1 444 | fold6_room1_mix043_ov1 445 | fold6_room1_mix044_ov1 446 | fold6_room1_mix045_ov1 447 | fold6_room1_mix046_ov1 448 | fold6_room1_mix047_ov1 449 | fold6_room1_mix048_ov1 450 | fold6_room1_mix049_ov1 451 | fold6_room1_mix050_ov1 452 | fold6_room1_mix051_ov2 453 | fold6_room1_mix052_ov2 454 | fold6_room1_mix053_ov2 455 | fold6_room1_mix054_ov2 456 | fold6_room1_mix055_ov2 457 | fold6_room1_mix056_ov2 458 | fold6_room1_mix057_ov2 459 | fold6_room1_mix058_ov2 460 | fold6_room1_mix059_ov2 461 | fold6_room1_mix060_ov2 462 | fold6_room1_mix061_ov2 463 | fold6_room1_mix062_ov2 464 | fold6_room1_mix063_ov2 465 | fold6_room1_mix064_ov2 466 | fold6_room1_mix065_ov2 467 | fold6_room1_mix066_ov2 468 | fold6_room1_mix067_ov2 469 | fold6_room1_mix068_ov2 470 | fold6_room1_mix069_ov2 471 | fold6_room1_mix070_ov2 472 | fold6_room1_mix071_ov2 473 | fold6_room1_mix072_ov2 474 | fold6_room1_mix073_ov2 475 | fold6_room1_mix074_ov2 476 | fold6_room1_mix075_ov2 477 | fold6_room1_mix076_ov2 478 | fold6_room1_mix077_ov2 479 | fold6_room1_mix078_ov2 480 | fold6_room1_mix079_ov2 481 | fold6_room1_mix080_ov2 482 | fold6_room1_mix081_ov2 483 | fold6_room1_mix082_ov2 484 | fold6_room1_mix083_ov2 485 | fold6_room1_mix084_ov2 486 | fold6_room1_mix085_ov2 487 | fold6_room1_mix086_ov2 488 | fold6_room1_mix087_ov2 489 | fold6_room1_mix088_ov2 490 | fold6_room1_mix089_ov2 491 | fold6_room1_mix090_ov2 492 | fold6_room1_mix091_ov2 493 | fold6_room1_mix092_ov2 494 | fold6_room1_mix093_ov2 495 | fold6_room1_mix094_ov2 496 | fold6_room1_mix095_ov2 497 | fold6_room1_mix096_ov2 498 | fold6_room1_mix097_ov2 499 | fold6_room1_mix098_ov2 500 | fold6_room1_mix099_ov2 501 | fold6_room1_mix100_ov2 502 | -------------------------------------------------------------------------------- /dataset/meta/original/val.csv: -------------------------------------------------------------------------------- 1 | filename 2 | fold2_room1_mix001_ov1 3 | fold2_room1_mix002_ov1 4 | fold2_room1_mix003_ov1 5 | fold2_room1_mix004_ov1 6 | fold2_room1_mix005_ov1 7 | fold2_room1_mix006_ov1 8 | fold2_room1_mix007_ov1 9 | fold2_room1_mix008_ov1 10 | fold2_room1_mix009_ov1 11 | fold2_room1_mix010_ov1 12 | fold2_room1_mix011_ov1 13 | fold2_room1_mix012_ov1 14 | fold2_room1_mix013_ov1 15 | fold2_room1_mix014_ov1 16 | fold2_room1_mix015_ov1 17 | fold2_room1_mix016_ov1 18 | fold2_room1_mix017_ov1 19 | fold2_room1_mix018_ov1 20 | fold2_room1_mix019_ov1 21 | fold2_room1_mix020_ov1 22 | fold2_room1_mix021_ov1 23 | fold2_room1_mix022_ov1 24 | fold2_room1_mix023_ov1 25 | fold2_room1_mix024_ov1 26 | fold2_room1_mix025_ov1 27 | fold2_room1_mix026_ov2 28 | fold2_room1_mix027_ov2 29 | fold2_room1_mix028_ov2 30 | fold2_room1_mix029_ov2 31 | fold2_room1_mix030_ov2 32 | fold2_room1_mix031_ov2 33 | fold2_room1_mix032_ov2 34 | fold2_room1_mix033_ov2 35 | fold2_room1_mix034_ov2 36 | fold2_room1_mix035_ov2 37 | fold2_room1_mix036_ov2 38 | fold2_room1_mix037_ov2 39 | fold2_room1_mix038_ov2 40 | fold2_room1_mix039_ov2 41 | fold2_room1_mix040_ov2 42 | fold2_room1_mix041_ov2 43 | fold2_room1_mix042_ov2 44 | fold2_room1_mix043_ov2 45 | fold2_room1_mix044_ov2 46 | fold2_room1_mix045_ov2 47 | fold2_room1_mix046_ov2 48 | fold2_room1_mix047_ov2 49 | fold2_room1_mix048_ov2 50 | 
fold2_room1_mix049_ov2 51 | fold2_room1_mix050_ov2 52 | fold2_room2_mix001_ov1 53 | fold2_room2_mix002_ov1 54 | fold2_room2_mix003_ov1 55 | fold2_room2_mix004_ov1 56 | fold2_room2_mix005_ov1 57 | fold2_room2_mix006_ov1 58 | fold2_room2_mix007_ov1 59 | fold2_room2_mix008_ov1 60 | fold2_room2_mix009_ov1 61 | fold2_room2_mix010_ov1 62 | fold2_room2_mix011_ov1 63 | fold2_room2_mix012_ov1 64 | fold2_room2_mix013_ov1 65 | fold2_room2_mix014_ov1 66 | fold2_room2_mix015_ov1 67 | fold2_room2_mix016_ov1 68 | fold2_room2_mix017_ov1 69 | fold2_room2_mix018_ov1 70 | fold2_room2_mix019_ov1 71 | fold2_room2_mix020_ov1 72 | fold2_room2_mix021_ov1 73 | fold2_room2_mix022_ov1 74 | fold2_room2_mix023_ov1 75 | fold2_room2_mix024_ov1 76 | fold2_room2_mix025_ov1 77 | fold2_room2_mix026_ov2 78 | fold2_room2_mix027_ov2 79 | fold2_room2_mix028_ov2 80 | fold2_room2_mix029_ov2 81 | fold2_room2_mix030_ov2 82 | fold2_room2_mix031_ov2 83 | fold2_room2_mix032_ov2 84 | fold2_room2_mix033_ov2 85 | fold2_room2_mix034_ov2 86 | fold2_room2_mix035_ov2 87 | fold2_room2_mix036_ov2 88 | fold2_room2_mix037_ov2 89 | fold2_room2_mix038_ov2 90 | fold2_room2_mix039_ov2 91 | fold2_room2_mix040_ov2 92 | fold2_room2_mix041_ov2 93 | fold2_room2_mix042_ov2 94 | fold2_room2_mix043_ov2 95 | fold2_room2_mix044_ov2 96 | fold2_room2_mix045_ov2 97 | fold2_room2_mix046_ov2 98 | fold2_room2_mix047_ov2 99 | fold2_room2_mix048_ov2 100 | fold2_room2_mix049_ov2 101 | fold2_room2_mix050_ov2 102 | -------------------------------------------------------------------------------- /experiments/configs/sed.yml: -------------------------------------------------------------------------------- 1 | # SED config 2 | name: sed 3 | feature_root_dir: '/media/tho_nguyen/disk2/new_seld/dcase2020/features/logmel/24000fs_1024nfft_300nhop_128nmels_Falsestd' 4 | gt_meta_root_dir: '/media/tho_nguyen/disk1/audio_datasets/dcase2020/task3' 5 | split_meta_dir: '/home/tho_nguyen/Documents/work/seld/dataset/meta/original' 6 | seed: 2018 7 | mode: 'crossval' # 'crossval' | 'eval' 8 | task: 'sed' # 'sed'| 'doa' | 'seld' 9 | data: 10 | fs: 24000 11 | n_fft: 1024 12 | hop_len: 300 13 | n_mels: 128 14 | audio_format: 'foa' # 'foa' | 'mic' 15 | label_rate: 10 # Label rate per second 16 | train_chunk_len_s: 4 17 | train_chunk_hop_len_s: 0.5 18 | test_chunk_len_s: 4 19 | test_chunk_hop_len_s: 4 20 | scaler_type: 'vector' 21 | n_classes: 14 22 | train_fraction: 1 23 | val_fraction: 1 24 | model: 25 | encoder: 26 | name: 'Cnn8' 27 | p_dropout: 0.0 28 | pretrained: false 29 | unfreeze_epoch: 0 30 | decoder: 31 | name: 'SedDecoder' 32 | freq_pool: 'avg' # 'avg' | 'max' | 'avg_max' 33 | decoder_type: 'gru' 34 | training: 35 | train_batch_size: 32 36 | val_batch_size: 16 37 | optimizer: 'adam' 38 | lr: 1.e-3 39 | max_epochs: 50 # epoch counting from [0 to n-1] 40 | val_interval: 1 41 | sed_threshold: 0.3 42 | 43 | 44 | -------------------------------------------------------------------------------- /experiments/train.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import logging 3 | import os 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pytorch_lightning.loggers import TensorBoardLogger 9 | 10 | # from models.rfcx import Rfcx 11 | from utilities.builder_utils import build_database, build_datamodule, build_model, build_task 12 | from utilities.experiments_utils import manage_experiments 13 | from utilities.learning_utils import LearningRateScheduler, 
MyLoggingCallback 14 | 15 | 16 | def train(exp_config: str = './configs/sed.yml', 17 | exp_group_dir: str = '/home/tho_nguyen/Documents/work/seld/outputs/', 18 | exp_suffix: str = '', 19 | resume: bool = False, 20 | empty: bool = False): 21 | """ 22 | Training script 23 | :param exp_config: Config file for experiments 24 | :param exp_group_dir: Parent directory to store all experiment results. 25 | :param exp_suffix: Experiment suffix. 26 | :param resume: If true, resume training from the last epoch. 27 | :param empty: If true, delete all previous data in experiment folder. 28 | """ 29 | # Load config, create folders, logging 30 | cfg = manage_experiments(exp_config=exp_config, exp_group_dir=exp_group_dir, exp_suffix=exp_suffix, empty=empty) 31 | logger = logging.getLogger('lightning') 32 | 33 | # Set random seed for reproducible 34 | pl.seed_everything(cfg.seed) 35 | 36 | # Resume training 37 | if resume: 38 | ckpt_list = [f for f in os.listdir(cfg.dir.model.checkpoint) if f.startswith('epoch') and f.endswith('ckpt')] 39 | if len(ckpt_list) > 0: 40 | resume_from_checkpoint = os.path.join(cfg.dir.model.checkpoint, sorted(ckpt_list)[-1]) 41 | logger.info('Found checkpoint to be resume training at {}'.format(resume_from_checkpoint)) 42 | else: 43 | resume_from_checkpoint = None 44 | else: 45 | resume_from_checkpoint = None 46 | 47 | # Load feature database - will use a builder function build_feature_db to select feature db. 48 | feature_db = build_database(cfg=cfg) 49 | 50 | # Load data module 51 | datamodule = build_datamodule(cfg=cfg, feature_db=feature_db) 52 | 53 | # Model checkpoint 54 | model_checkpoint = ModelCheckpoint(dirpath=cfg.dir.model.checkpoint, filename='{epoch:03d}') 55 | 56 | # Console logger 57 | console_logger = MyLoggingCallback() 58 | 59 | # Tensorboard logger 60 | tb_logger = TensorBoardLogger(save_dir=cfg.dir.tb_dir, name='my_model') 61 | 62 | # Build encoder and decoder 63 | encoder_params = cfg.model.encoder.__dict__ 64 | encoder_kwargs = {'n_input_channels': cfg.data.n_input_channels, **encoder_params} 65 | encoder = build_model(**encoder_kwargs) 66 | decoder_params = cfg.model.decoder.__dict__ 67 | decoder_params = {'n_classes': cfg.data.n_classes, 'encoder_output_channels': encoder.n_output_channels, 68 | **decoder_params} 69 | decoder = build_model(**decoder_params) 70 | 71 | # Build Lightning model 72 | model = build_task(encoder=encoder, decoder=decoder, cfg=cfg) 73 | 74 | # Train 75 | callback_list = [console_logger, model_checkpoint] 76 | trainer = pl.Trainer(gpus=torch.cuda.device_count(), resume_from_checkpoint=resume_from_checkpoint, 77 | max_epochs=cfg.training.max_epochs, logger=tb_logger, progress_bar_refresh_rate=2, 78 | check_val_every_n_epoch=cfg.training.val_interval, 79 | log_every_n_steps=100, flush_logs_every_n_steps=200, 80 | limit_train_batches=cfg.data.train_fraction, limit_val_batches=cfg.data.val_fraction, 81 | callbacks=callback_list) 82 | trainer.fit(model, datamodule) 83 | 84 | 85 | if __name__ == '__main__': 86 | fire.Fire(train) 87 | -------------------------------------------------------------------------------- /figures/crnn_block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomeou/General-network-architecture-for-sound-event-localization-and-detection/03b3aaccf3c87dd8fb857960e765ae768ad36625/figures/crnn_block.png -------------------------------------------------------------------------------- /figures/experimental_results.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomeou/General-network-architecture-for-sound-event-localization-and-detection/03b3aaccf3c87dd8fb857960e765ae768ad36625/figures/experimental_results.png -------------------------------------------------------------------------------- /figures/model_descriptions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomeou/General-network-architecture-for-sound-event-localization-and-detection/03b3aaccf3c87dd8fb857960e765ae768ad36625/figures/model_descriptions.png -------------------------------------------------------------------------------- /figures/seld_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomeou/General-network-architecture-for-sound-event-localization-and-detection/03b3aaccf3c87dd8fb857960e765ae768ad36625/figures/seld_framework.png -------------------------------------------------------------------------------- /metrics/SELD_evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from https://github.com/sharathadavanne/seld-dcase2020/blob/master/metrics/evaluation_metrics.py 3 | """ 4 | # 5 | # Implements the localization and detection metrics proposed in the paper 6 | # 7 | # Joint Measurement of Localization and Detection of Sound Events 8 | # Annamaria Mesaros, Sharath Adavanne, Archontis Politis, Toni Heittola, Tuomas Virtanen 9 | # WASPAA 2019 10 | # 11 | # 12 | # This script has MIT license 13 | # 14 | 15 | import numpy as np 16 | eps = np.finfo(np.float).eps 17 | from scipy.optimize import linear_sum_assignment 18 | 19 | 20 | class SELDMetrics(object): 21 | def __init__(self, doa_threshold=20, nb_classes=11): 22 | ''' 23 | This class implements both the class-sensitive localization and location-sensitive detection metrics. 24 | Additionally, based on the user input, the corresponding averaging is performed within the segment. 25 | 26 | :param nb_classes: Number of sound classes. In the paper, nb_classes = 11 27 | :param doa_thresh: DOA threshold for location sensitive detection. 
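Defaults to 20 degrees.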
28 | ''' 29 | 30 | self._TP = 0 31 | self._FP = 0 32 | self._TN = 0 33 | self._FN = 0 34 | 35 | self._S = 0 36 | self._D = 0 37 | self._I = 0 38 | 39 | self._Nref = 0 40 | self._Nsys = 0 41 | 42 | self._total_DE = 0 43 | self._DE_TP = 0 44 | 45 | self._spatial_T = doa_threshold 46 | self._nb_classes = nb_classes 47 | 48 | def compute_seld_scores(self): 49 | ''' 50 | Collect the final SELD scores 51 | 52 | :return: returns both location-sensitive detection scores and class-sensitive localization scores 53 | ''' 54 | 55 | # Location-senstive detection performance 56 | ER = (self._S + self._D + self._I) / float(self._Nref + eps) 57 | 58 | prec = float(self._TP) / float(self._Nsys + eps) 59 | recall = float(self._TP) / float(self._Nref + eps) 60 | F = 2 * prec * recall / (prec + recall + eps) 61 | 62 | # Class-sensitive localization performance 63 | if self._DE_TP: 64 | DE = self._total_DE / float(self._DE_TP + eps) 65 | else: 66 | # When the total number of prediction is zero 67 | DE = 180 68 | 69 | DE_prec = float(self._DE_TP) / float(self._Nsys + eps) 70 | DE_recall = float(self._DE_TP) / float(self._Nref + eps) 71 | DE_F = 2 * DE_prec * DE_recall / (DE_prec + DE_recall + eps) 72 | 73 | aux_metrics = [self._S/float(self._Nref + eps), 74 | self._D / float(self._Nref + eps), 75 | self._I / float(self._Nref + eps), 76 | prec, 77 | recall, 78 | DE_prec, 79 | DE_recall, 80 | self._Nsys/self._Nref] 81 | 82 | return ER, F, DE, DE_F, aux_metrics 83 | 84 | def update_seld_scores_xyz(self, pred, gt): 85 | ''' 86 | Implements the spatial error averaging according to equation [5] in the paper, using Cartesian distance 87 | 88 | :param pred: dictionary containing class-wise prediction results for each N-seconds segment block 89 | :param gt: dictionary containing class-wise groundtruth for each N-seconds segment block 90 | ''' 91 | for block_cnt in range(len(gt.keys())): 92 | # print('\nblock_cnt', block_cnt, end='') 93 | loc_FN, loc_FP = 0, 0 94 | for class_cnt in range(self._nb_classes): 95 | # print('\tclass:', class_cnt, end='') 96 | # Counting the number of ref and sys outputs should include the number of tracks for each class in the segment 97 | if class_cnt in gt[block_cnt]: 98 | self._Nref += 1 99 | if class_cnt in pred[block_cnt]: 100 | self._Nsys += 1 101 | 102 | if class_cnt in gt[block_cnt] and class_cnt in pred[block_cnt]: 103 | # True positives or False negative case 104 | 105 | # NOTE: For multiple tracks per class, identify multiple tracks using hungarian algorithm and then 106 | # calculate the spatial distance using the following code. In the current code, if there are multiple 107 | # tracks of the same class in a frame we are calculating the least cost between the groundtruth and predicted and using it. 
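# Illustrative layout of the inputs, inferred from the indexing below (not part of the upstream code):
#   gt[block_cnt][class_cnt] = [[[frame_0, frame_1, ...], [doas_frame_0, doas_frame_1, ...]]]
# where each doas_frame_i is a list of one or more [x, y, z] vectors (several vectors when multiple
# tracks of the same class are active in that frame); pred follows the same layout.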
108 | 109 | total_spatial_dist = 0 110 | total_framewise_matching_doa = 0 111 | gt_ind_list = gt[block_cnt][class_cnt][0][0] 112 | pred_ind_list = pred[block_cnt][class_cnt][0][0] 113 | for gt_ind, gt_val in enumerate(gt_ind_list): 114 | if gt_val in pred_ind_list: 115 | total_framewise_matching_doa += 1 116 | pred_ind = pred_ind_list.index(gt_val) 117 | 118 | gt_arr = np.array(gt[block_cnt][class_cnt][0][1][gt_ind]) 119 | pred_arr = np.array(pred[block_cnt][class_cnt][0][1][pred_ind]) 120 | 121 | if gt_arr.shape[0]==1 and pred_arr.shape[0]==1: 122 | total_spatial_dist += distance_between_cartesian_coordinates(gt_arr[0][0], gt_arr[0][1], gt_arr[0][2], pred_arr[0][0], pred_arr[0][1], pred_arr[0][2]) 123 | else: 124 | total_spatial_dist += least_distance_between_gt_pred(gt_arr, pred_arr) 125 | 126 | if total_spatial_dist == 0 and total_framewise_matching_doa == 0: 127 | loc_FN += 1 128 | self._FN += 1 129 | else: 130 | avg_spatial_dist = (total_spatial_dist / total_framewise_matching_doa) 131 | 132 | self._total_DE += avg_spatial_dist 133 | self._DE_TP += 1 134 | 135 | if avg_spatial_dist <= self._spatial_T: 136 | self._TP += 1 137 | else: 138 | loc_FN += 1 139 | self._FN += 1 140 | elif class_cnt in gt[block_cnt] and class_cnt not in pred[block_cnt]: 141 | # False negative 142 | loc_FN += 1 143 | self._FN += 1 144 | elif class_cnt not in gt[block_cnt] and class_cnt in pred[block_cnt]: 145 | # False positive 146 | loc_FP += 1 147 | self._FP += 1 148 | elif class_cnt not in gt[block_cnt] and class_cnt not in pred[block_cnt]: 149 | # True negative 150 | self._TN += 1 151 | 152 | self._S += np.minimum(loc_FP, loc_FN) 153 | self._D += np.maximum(0, loc_FN - loc_FP) 154 | self._I += np.maximum(0, loc_FP - loc_FN) 155 | return 156 | 157 | def update_seld_scores(self, pred_deg, gt_deg): 158 | ''' 159 | Implements the spatial error averaging according to equation [5] in the paper, using Polar distance 160 | Expects the angles in degrees 161 | 162 | :param pred_deg: dictionary containing class-wise prediction results for each N-seconds segment block 163 | :param gt_deg: dictionary containing class-wise groundtruth for each N-seconds segment block 164 | ''' 165 | for block_cnt in range(len(gt_deg.keys())): 166 | # print('\nblock_cnt', block_cnt, end='') 167 | loc_FN, loc_FP = 0, 0 168 | for class_cnt in range(self._nb_classes): 169 | # print('\tclass:', class_cnt, end='') 170 | # Counting the number of ref and sys outputs should include the number of tracks for each class in the segment 171 | if class_cnt in gt_deg[block_cnt]: 172 | self._Nref += 1 173 | if class_cnt in pred_deg[block_cnt]: 174 | self._Nsys += 1 175 | 176 | if class_cnt in gt_deg[block_cnt] and class_cnt in pred_deg[block_cnt]: 177 | # True positives or False negative case 178 | 179 | # NOTE: For multiple tracks per class, identify multiple tracks using hungarian algorithm and then 180 | # calculate the spatial distance using the following code. In the current code, if there are multiple 181 | # tracks of the same class in a frame we are calculating the least cost between the groundtruth and predicted and using it. 
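# Same input layout as illustrated in update_seld_scores_xyz() above, but each DOA is an [azimuth, elevation] pair in degrees.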
182 | total_spatial_dist = 0 183 | total_framewise_matching_doa = 0 184 | gt_ind_list = gt_deg[block_cnt][class_cnt][0][0] 185 | pred_ind_list = pred_deg[block_cnt][class_cnt][0][0] 186 | for gt_ind, gt_val in enumerate(gt_ind_list): 187 | if gt_val in pred_ind_list: 188 | total_framewise_matching_doa += 1 189 | pred_ind = pred_ind_list.index(gt_val) 190 | 191 | gt_arr = np.array(gt_deg[block_cnt][class_cnt][0][1][gt_ind]) * np.pi / 180 192 | pred_arr = np.array(pred_deg[block_cnt][class_cnt][0][1][pred_ind]) * np.pi / 180 193 | if gt_arr.shape[0]==1 and pred_arr.shape[0]==1: 194 | total_spatial_dist += distance_between_spherical_coordinates_rad(gt_arr[0][0], gt_arr[0][1], pred_arr[0][0], pred_arr[0][1]) 195 | else: 196 | total_spatial_dist += least_distance_between_gt_pred(gt_arr, pred_arr) 197 | 198 | if total_spatial_dist == 0 and total_framewise_matching_doa == 0: 199 | loc_FN += 1 200 | self._FN += 1 201 | else: 202 | avg_spatial_dist = (total_spatial_dist / total_framewise_matching_doa) 203 | 204 | self._total_DE += avg_spatial_dist 205 | self._DE_TP += 1 206 | 207 | if avg_spatial_dist <= self._spatial_T: 208 | self._TP += 1 209 | else: 210 | loc_FN += 1 211 | self._FN += 1 212 | elif class_cnt in gt_deg[block_cnt] and class_cnt not in pred_deg[block_cnt]: 213 | # False negative 214 | loc_FN += 1 215 | self._FN += 1 216 | elif class_cnt not in gt_deg[block_cnt] and class_cnt in pred_deg[block_cnt]: 217 | # False positive 218 | loc_FP += 1 219 | self._FP += 1 220 | elif class_cnt not in gt_deg[block_cnt] and class_cnt not in pred_deg[block_cnt]: 221 | # True negative 222 | self._TN += 1 223 | 224 | self._S += np.minimum(loc_FP, loc_FN) 225 | self._D += np.maximum(0, loc_FN - loc_FP) 226 | self._I += np.maximum(0, loc_FP - loc_FN) 227 | return 228 | 229 | 230 | def distance_between_spherical_coordinates_rad(az1, ele1, az2, ele2): 231 | """ 232 | Angular distance between two spherical coordinates 233 | MORE: https://en.wikipedia.org/wiki/Great-circle_distance 234 | 235 | :return: angular distance in degrees 236 | """ 237 | dist = np.sin(ele1) * np.sin(ele2) + np.cos(ele1) * np.cos(ele2) * np.cos(np.abs(az1 - az2)) 238 | # Making sure the dist values are in -1 to 1 range, else np.arccos kills the job 239 | dist = np.clip(dist, -1, 1) 240 | dist = np.arccos(dist) * 180 / np.pi 241 | return dist 242 | 243 | 244 | def distance_between_cartesian_coordinates(x1, y1, z1, x2, y2, z2): 245 | """ 246 | Angular distance between two cartesian coordinates 247 | MORE: https://en.wikipedia.org/wiki/Great-circle_distance 248 | Check 'From chord length' section 249 | 250 | :return: angular distance in degrees 251 | """ 252 | # Normalize the Cartesian vectors 253 | N1 = np.sqrt(x1**2 + y1**2 + z1**2 + 1e-10) 254 | N2 = np.sqrt(x2**2 + y2**2 + z2**2 + 1e-10) 255 | x1, y1, z1, x2, y2, z2 = x1/N1, y1/N1, z1/N1, x2/N2, y2/N2, z2/N2 256 | 257 | #Compute the distance 258 | dist = x1*x2 + y1*y2 + z1*z2 259 | dist = np.clip(dist, -1, 1) 260 | dist = np.arccos(dist) * 180 / np.pi 261 | return dist 262 | 263 | 264 | def least_distance_between_gt_pred(gt_list, pred_list): 265 | """ 266 |         Shortest distance between two sets of DOA coordinates. Given a set of groundtruth coordinates, 267 |         and its respective predicted coordinates, we calculate the distance between each of the 268 |         coordinate pairs resulting in a matrix of distances, where one axis represents the number of groundtruth 269 |         coordinates and the other the predicted coordinates. 
The number of estimated peaks need not be the same as in 270 |         groundtruth, thus the distance matrix is not always a square matrix. We use the hungarian algorithm to find the 271 |         least cost in this distance matrix. 272 |         :param gt_list_xyz: list of ground-truth Cartesian or Polar coordinates in Radians 273 |         :param pred_list_xyz: list of predicted Carteisan or Polar coordinates in Radians 274 |         :return: cost -  distance 275 |         :return: less - number of DOA's missed 276 |         :return: extra - number of DOA's over-estimated 277 |     """ 278 | gt_len, pred_len = gt_list.shape[0], pred_list.shape[0] 279 | ind_pairs = np.array([[x, y] for y in range(pred_len) for x in range(gt_len)]) 280 | cost_mat = np.zeros((gt_len, pred_len)) 281 | 282 | if gt_len and pred_len: 283 | if len(gt_list[0]) == 3: #Cartesian 284 | x1, y1, z1, x2, y2, z2 = gt_list[ind_pairs[:, 0], 0], gt_list[ind_pairs[:, 0], 1], gt_list[ind_pairs[:, 0], 2], pred_list[ind_pairs[:, 1], 0], pred_list[ind_pairs[:, 1], 1], pred_list[ind_pairs[:, 1], 2] 285 | cost_mat[ind_pairs[:, 0], ind_pairs[:, 1]] = distance_between_cartesian_coordinates(x1, y1, z1, x2, y2, z2) 286 | else: 287 | az1, ele1, az2, ele2 = gt_list[ind_pairs[:, 0], 0], gt_list[ind_pairs[:, 0], 1], pred_list[ind_pairs[:, 1], 0], pred_list[ind_pairs[:, 1], 1] 288 | cost_mat[ind_pairs[:, 0], ind_pairs[:, 1]] = distance_between_spherical_coordinates_rad(az1, ele1, az2, ele2) 289 | 290 | row_ind, col_ind = linear_sum_assignment(cost_mat) 291 | cost = cost_mat[row_ind, col_ind].sum() 292 | return cost 293 | 294 | 295 | def early_stopping_metric(sed_error, doa_error): 296 | """ 297 | Compute early stopping metric from sed and doa errors. 298 | 299 | :param sed_error: [error rate (0 to 1 range), f score (0 to 1 range)] 300 | :param doa_error: [doa error (in degrees), frame recall (0 to 1 range)] 301 | :return: early stopping metric result 302 | """ 303 | seld_metric = np.mean([ 304 | sed_error[0], 305 | 1 - sed_error[1], 306 | doa_error[0]/180, 307 | 1 - doa_error[1]] 308 | ) 309 | return seld_metric 310 | -------------------------------------------------------------------------------- /metrics/pl_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module wrap the metrics inside pytorch lightning metrics so that we can compute running metrics inside 3 | pytorch lightning module 4 | """ 5 | import math 6 | 7 | import numpy as np 8 | import torch 9 | from pytorch_lightning.metrics import Metric 10 | 11 | 12 | class SedMetrics(Metric): 13 | """ 14 | Rewrite class SELDMetrics in evaluation_metrics.py for Sed evaluation. 15 | """ 16 | def __init__(self, nb_frames_1s: int, dist_sync_on_step=False): 17 | """ 18 | :param nb_frames_1s: Number of frames per second. 
Often this is the label frame rate 19 | """ 20 | super().__init__(dist_sync_on_step=dist_sync_on_step) 21 | 22 | self.block_size = nb_frames_1s 23 | self.eps = 1e-8 24 | self.add_state('S', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx="sum") # dist_reduce_fx is used for distributed training 25 | self.add_state('D', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx="sum") 26 | self.add_state('I', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx="sum") 27 | self.add_state('TP', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx="sum") 28 | self.add_state('Nref', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx="sum") 29 | self.add_state('Nsys', default=torch.tensor(0.0, dtype=torch.float32), dist_reduce_fx="sum") 30 | 31 | def update(self, preds: torch.Tensor, target: torch.Tensor): 32 | """ 33 | :param preds: (batch_size, n_timesteps, n_classes) or (n_timesteps, n_classes) 34 | :param target: (batch_size, n_timesteps, n_classes) or (n_timesteps, n_classes) 35 | """ 36 | assert preds.shape == target.shape 37 | if preds.ndim == 3: 38 | if preds.shape[0] == 1: 39 | preds = torch.squeeze(preds, dim=0) 40 | target = torch.squeeze(target, dim=0) 41 | else: 42 | preds = torch.reshape(preds, (preds.shape[0] * preds.shape[1], -1)) 43 | target = torch.reshape(target, (target.shape[0] * target.shape[1], -1)) 44 | S, D, I = er_overall_1sec(preds, target, self.block_size) 45 | TP, Nref, Nsys = f1_overall_1sec(preds, target, self.block_size) 46 | self.S += S 47 | self.D += D 48 | self.I += I 49 | self.TP += TP 50 | self.Nref += Nref 51 | self.Nsys += Nsys 52 | 53 | def compute(self): 54 | if self.Nref == 0: 55 | self.Nref = self.Nref + 1.0 56 | ER = (self.S + self.D + self.I) / self.Nref 57 | prec = self.TP / (self.Nsys + self.eps) 58 | recall = self.TP / (self.Nref + self.eps) 59 | F = 2 * prec * recall / (prec + recall + self.eps) 60 | return ER, F 61 | 62 | 63 | def f1_overall_1sec(O, T, block_size): 64 | """ 65 | Legacy code, copied from SELD github repo. To compute F1 for SED metrics. 66 | :param O: predictions 67 | :param T: target 68 | :param block_size: number of frames per 1 s. 69 | :return: 70 | """ 71 | new_size = int(math.ceil(float(O.shape[0]) / block_size)) 72 | O_block = torch.zeros((new_size, O.shape[1]), dtype=torch.float32) 73 | T_block = torch.zeros((new_size, O.shape[1]), dtype=torch.float32) 74 | for i in range(0, new_size): 75 | O_block[i, :], _ = torch.max(O[int(i * block_size):int(i * block_size + block_size - 1), :], dim=0) 76 | T_block[i, :], _ = torch.max(T[int(i * block_size):int(i * block_size + block_size - 1), :], dim=0) 77 | TP = ((2 * T_block - O_block) == 1).sum() 78 | Nref, Nsys = T_block.sum(), O_block.sum() 79 | return TP, Nref, Nsys 80 | 81 | 82 | def er_overall_1sec(O, T, block_size): 83 | """ 84 | # TODO combine er_overall_1sec with f1_overall_1sec 85 | Legacy code, copied from SELD github repo. To compute error rate for SED metrics. 86 | :param O: predictions 87 | :param T: target 88 | :param block_size: number of frames per 1 s. 
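:return: (S, D, I) - segment-wise substitution, deletion and insertion error counts.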
89 | """ 90 | new_size = int(math.ceil(float(O.shape[0]) / block_size)) 91 | O_block = torch.zeros((new_size, O.shape[1])) 92 | T_block = torch.zeros((new_size, O.shape[1])) 93 | for i in range(0, new_size): 94 | O_block[i, :], _ = torch.max(O[int(i * block_size):int(i * block_size + block_size - 1), :], dim=0) 95 | T_block[i, :], _ = torch.max(T[int(i * block_size):int(i * block_size + block_size - 1), :], dim=0) 96 | FP = torch.logical_and(T_block == 0, O_block == 1).sum(1) 97 | FN = torch.logical_and(T_block == 1, O_block == 0).sum(1) 98 | S = torch.minimum(FP, FN).sum() 99 | D = torch.maximum(torch.tensor(0), FN - FP).sum() 100 | I = torch.maximum(torch.tensor(0), FP - FN).sum() 101 | return S, D, I 102 | -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions to construct model. 3 | Reference: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py 4 | The MIT License 5 | """ 6 | import math 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def freeze_model(model): 15 | for name, param in model.named_parameters(): 16 | param.requires_grad = False 17 | 18 | 19 | def unfreeze_model(model): 20 | for name, param in model.named_parameters(): 21 | param.requires_grad = True 22 | 23 | 24 | def interpolate_tensor(tensor, ratio: float = 1.0): 25 | """ 26 | Upsample or Downsample tensor in time dimension 27 | :param tensor: (batch_size, n_timesteps, n_classes) or (n_timesteps, n_classes). Torch tensor. 28 | :param: ratio. If ratio > 1: upsample, ratio < 1: downsample # ratio = output rate/input rate 29 | :return: new_tensor (batch_size, n_timestepss*ratio, n_classes) or (n_timestepss*ratio, n_classes) 30 | """ 31 | ratio = float(ratio) 32 | n_dims = tensor.ndim 33 | 34 | if n_dims == 2: 35 | n_input_frames, n_classes = tensor.shape[0], tensor.shape[1] 36 | 37 | n_output_frames = int(round(n_input_frames * ratio)) 38 | output_idx = torch.arange(n_output_frames) 39 | input_idx = torch.floor(output_idx / ratio).long() 40 | 41 | new_tensor = tensor[input_idx, :] 42 | 43 | elif n_dims == 3: 44 | batch_size, n_input_frames, n_classes = tensor.shape[0], tensor.shape[1], tensor.shape[2] 45 | 46 | n_output_frames = int(round(n_input_frames * ratio)) 47 | output_idx = torch.arange(n_output_frames) 48 | input_idx = torch.floor(output_idx / ratio).long() 49 | 50 | new_tensor = tensor[:, input_idx, :] 51 | 52 | else: 53 | raise NotImplementedError('Interpolate function does not work for tensor dimension above 3') 54 | 55 | return new_tensor 56 | 57 | 58 | def interpolate(x, ratio): 59 | """ 60 | To upsample tensor along time dimension. This is used to compensate the 61 | resolution reduction in downsampling of a CNN. 62 | 63 | Args: 64 | x: (batch_size, time_steps, classes_num) 65 | ratio: int, ratio to interpolate 66 | 67 | Returns: 68 | upsampled: (batch_size, time_steps * ratio, classes_num) 69 | """ 70 | (batch_size, time_steps, classes_num) = x.shape 71 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 72 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 73 | return upsampled 74 | 75 | 76 | def pad_framewise_output(framewise_output, frames_num): 77 | """Pad framewise_output to the same length as input frames. The pad value 78 | is the same as the value of the last frame. 
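Typically used together with interpolate() above, so that the upsampled output can be padded back to exactly the original number of input frames.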
79 | 80 | Args: 81 | framewise_output: (batch_size, frames_num, classes_num) 82 | frames_num: int, number of frames to pad 83 | 84 | Outputs: 85 | output: (batch_size, frames_num, classes_num) 86 | """ 87 | pad = framewise_output[:, -1:, :].repeat(1, frames_num - framewise_output.shape[1], 1) 88 | """tensor for padding""" 89 | 90 | output = torch.cat((framewise_output, pad), dim=1) 91 | """(batch_size, frames_num, classes_num)""" 92 | 93 | return output 94 | 95 | 96 | def init_weights(model): 97 | classname = model.__class__.__name__ 98 | if classname.find("Conv2d") != -1: 99 | nn.init.xavier_uniform_(model.weight, gain=np.sqrt(2)) 100 | model.bias.data.fill_(0) 101 | elif classname.find("BatchNorm") != -1: 102 | model.weight.data.normal_(1.0, 0.02) 103 | model.bias.data.fill_(0) 104 | elif classname.find("GRU") != -1: 105 | for weight in model.parameters(): 106 | if len(weight.size()) > 1: 107 | nn.init.orthogonal_(weight.data) 108 | elif classname.find("Linear") != -1: 109 | model.weight.data.normal_(0, 0.01) 110 | model.bias.data.zero_() 111 | 112 | 113 | def init_layer(layer, method='xavier_uniform'): 114 | """Initialize a Linear or Convolutional layer. """ 115 | if method == 'xavier_uniform': ## default 116 | nn.init.xavier_uniform_(layer.weight) 117 | elif method == 'xavier_normal': 118 | nn.init.xavier_normal_(layer.weight) 119 | elif method == 'kaiming_uniform': ## to try 120 | nn.init.kaiming_uniform_(layer.weight) 121 | elif method == 'kaiming_normal': 122 | nn.init.kaiming_normal_(layer.weight) 123 | elif method == 'orthogonal': 124 | nn.init.orthogonal_(layer.weight) 125 | else: 126 | raise NotImplementedError('init method {} is not implemented'.format(method)) 127 | 128 | if hasattr(layer, 'bias'): 129 | if layer.bias is not None: 130 | layer.bias.data.fill_(0.) 131 | 132 | 133 | def init_bn(bn): 134 | """Initialize a Batchnorm layer. """ 135 | bn.bias.data.fill_(0.) 136 | bn.weight.data.fill_(1.) 137 | 138 | 139 | def init_gru(rnn): 140 | """Initialize a GRU layer. 
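Input-to-hidden weights are initialized uniformly; the last of the three gate blocks of each hidden-to-hidden weight is initialized orthogonally, and all biases are set to zero.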
""" 141 | 142 | def _concat_init(tensor, init_funcs): 143 | (length, fan_out) = tensor.shape 144 | fan_in = length // len(init_funcs) 145 | 146 | for (i, init_func) in enumerate(init_funcs): 147 | init_func(tensor[i * fan_in: (i + 1) * fan_in, :]) 148 | 149 | def _inner_uniform(tensor): 150 | fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in') 151 | nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) 152 | 153 | for i in range(rnn.num_layers): 154 | _concat_init( 155 | getattr(rnn, 'weight_ih_l{}'.format(i)), 156 | [_inner_uniform, _inner_uniform, _inner_uniform] 157 | ) 158 | torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0) 159 | 160 | _concat_init( 161 | getattr(rnn, 'weight_hh_l{}'.format(i)), 162 | [_inner_uniform, _inner_uniform, nn.init.orthogonal_] 163 | ) 164 | torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0) 165 | 166 | 167 | class ConvBlock(nn.Module): 168 | def __init__(self, in_channels, out_channels): 169 | 170 | super(ConvBlock, self).__init__() 171 | 172 | self.conv1 = nn.Conv2d(in_channels=in_channels, 173 | out_channels=out_channels, 174 | kernel_size=(3, 3), stride=(1, 1), 175 | padding=(1, 1), bias=False) 176 | 177 | self.conv2 = nn.Conv2d(in_channels=out_channels, 178 | out_channels=out_channels, 179 | kernel_size=(3, 3), stride=(1, 1), 180 | padding=(1, 1), bias=False) 181 | 182 | self.bn1 = nn.BatchNorm2d(out_channels) 183 | self.bn2 = nn.BatchNorm2d(out_channels) 184 | 185 | self.init_weight() 186 | 187 | def init_weight(self): 188 | init_layer(self.conv1) 189 | init_layer(self.conv2) 190 | init_bn(self.bn1) 191 | init_bn(self.bn2) 192 | 193 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 194 | x = input 195 | x = F.relu_(self.bn1(self.conv1(x))) 196 | x = F.relu_(self.bn2(self.conv2(x))) 197 | if pool_type == 'max': 198 | x = F.max_pool2d(x, kernel_size=pool_size) 199 | elif pool_type == 'avg': 200 | x = F.avg_pool2d(x, kernel_size=pool_size) 201 | elif pool_type == 'avg+max': 202 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 203 | x2 = F.max_pool2d(x, kernel_size=pool_size) 204 | x = x1 + x2 205 | else: 206 | raise Exception('Incorrect argument!') 207 | 208 | return x 209 | 210 | 211 | class ConvBlock5x5(nn.Module): 212 | def __init__(self, in_channels, out_channels): 213 | 214 | super(ConvBlock5x5, self).__init__() 215 | 216 | self.conv1 = nn.Conv2d(in_channels=in_channels, 217 | out_channels=out_channels, 218 | kernel_size=(5, 5), stride=(1, 1), 219 | padding=(2, 2), bias=False) 220 | 221 | self.bn1 = nn.BatchNorm2d(out_channels) 222 | 223 | self.init_weight() 224 | 225 | def init_weight(self): 226 | init_layer(self.conv1) 227 | init_bn(self.bn1) 228 | 229 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 230 | x = input 231 | x = F.relu_(self.bn1(self.conv1(x))) 232 | if pool_type == 'max': 233 | x = F.max_pool2d(x, kernel_size=pool_size) 234 | elif pool_type == 'avg': 235 | x = F.avg_pool2d(x, kernel_size=pool_size) 236 | elif pool_type == 'avg+max': 237 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 238 | x2 = F.max_pool2d(x, kernel_size=pool_size) 239 | x = x1 + x2 240 | else: 241 | raise Exception('Incorrect argument!') 242 | 243 | return x 244 | 245 | 246 | class AttBlock(nn.Module): 247 | def __init__(self, n_in, n_out, activation='linear'): 248 | super(AttBlock, self).__init__() 249 | 250 | self.activation = activation 251 | self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 252 | self.cla = 
nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 253 | 254 | self.bn_att = nn.BatchNorm1d(n_out) 255 | self.init_weights() 256 | 257 | def init_weights(self): 258 | init_layer(self.att) 259 | init_layer(self.cla) 260 | init_bn(self.bn_att) 261 | 262 | def forward(self, x): 263 | # x: (n_samples, n_channels/n_features, n_time_steps) 264 | # norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) 265 | norm_att = torch.softmax(torch.tanh(self.att(x)), dim=-1) # (batch_size, n_classes, n_time_steps) softmax on n_time_steps dimension 266 | cla = self.nonlinear_transform(self.cla(x)) # (batch_size, n_classes, n_time_steps) 267 | x = torch.sum(norm_att * cla, dim=2) # sum over time dimension (batch_size, n_classes) 268 | return x, norm_att, cla 269 | 270 | def nonlinear_transform(self, x): 271 | if self.activation == 'linear': 272 | return x 273 | elif self.activation == 'sigmoid': 274 | return torch.sigmoid(x) 275 | 276 | 277 | class PositionalEncoding(nn.Module): 278 | def __init__(self, pos_len, d_model=512, pe_type='t', dropout=0.1): 279 | """ Positional encoding using sin and cos 280 | Args: 281 | pos_len: positional length 282 | d_model: number of feature maps 283 | pe_type: 't' | 'f' , time domain, frequency domain 284 | dropout: dropout probability 285 | """ 286 | super().__init__() 287 | 288 | self.pe_type = pe_type 289 | pe = torch.zeros(pos_len, d_model) 290 | pos = torch.arange(0, pos_len).float().unsqueeze(1) 291 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) 292 | pe[:, 0::2] = 0.1 * torch.sin(pos * div_term) 293 | pe[:, 1::2] = 0.1 * torch.cos(pos * div_term) 294 | pe = pe.unsqueeze(0).transpose(1, 2) # (N, C, T) 295 | self.register_buffer('pe', pe) 296 | self.dropout = nn.Dropout(p=dropout) 297 | 298 | def forward(self, x): 299 | # x is (N, C, T, F) or (N, C, T) or (N, C, F) 300 | if x.ndim == 4: 301 | if self.pe_type == 't': 302 | pe = self.pe.unsqueeze(3) 303 | x += pe[:, :, :x.shape[2]] 304 | elif self.pe_type == 'f': 305 | pe = self.pe.unsqueeze(2) 306 | x += pe[:, :, :, :x.shape[3]] 307 | elif x.ndim == 3: 308 | x += self.pe[:, :, :x.shape[2]] 309 | return self.dropout(x) 310 | 311 | 312 | def _resnet_conv3x3(in_planes, out_planes): 313 | # 3x3 convolution with padding 314 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 315 | padding=1, groups=1, bias=False, dilation=1) 316 | 317 | 318 | def _resnet_conv1x1(in_planes, out_planes): 319 | # 1x1 convolution 320 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False) 321 | 322 | 323 | class _ResnetBasicBlock(nn.Module): 324 | expansion = 1 325 | 326 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 327 | base_width=64, dilation=1, norm_layer=None): 328 | super(_ResnetBasicBlock, self).__init__() 329 | if norm_layer is None: 330 | norm_layer = nn.BatchNorm2d 331 | if groups != 1 or base_width != 64: 332 | raise ValueError('_ResnetBasicBlock only supports groups=1 and base_width=64') 333 | if dilation > 1: 334 | raise NotImplementedError("Dilation > 1 not supported in _ResnetBasicBlock") 335 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 336 | 337 | self.stride = stride 338 | 339 | self.conv1 = _resnet_conv3x3(inplanes, planes) 340 | self.bn1 = norm_layer(planes) 341 | self.relu = nn.ReLU(inplace=True) 342 | self.conv2 = _resnet_conv3x3(planes, planes) 343 | self.bn2 = norm_layer(planes) 344 | self.downsample = downsample 345 
| self.stride = stride 346 | 347 | self.init_weights() 348 | 349 | def init_weights(self): 350 | init_layer(self.conv1) 351 | init_bn(self.bn1) 352 | init_layer(self.conv2) 353 | init_bn(self.bn2) 354 | nn.init.constant_(self.bn2.weight, 0) 355 | 356 | def forward(self, x): 357 | identity = x 358 | 359 | if self.stride == 2: 360 | out = F.avg_pool2d(x, kernel_size=(2, 2)) 361 | else: 362 | out = x 363 | 364 | out = self.conv1(out) 365 | out = self.bn1(out) 366 | out = self.relu(out) 367 | out = F.dropout(out, p=0.1, training=self.training) 368 | 369 | out = self.conv2(out) 370 | out = self.bn2(out) 371 | 372 | if self.downsample is not None: 373 | identity = self.downsample(identity) 374 | 375 | out += identity 376 | out = self.relu(out) 377 | 378 | return out 379 | 380 | 381 | class _ResnetBottleneck(nn.Module): 382 | expansion = 4 383 | 384 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 385 | base_width=64, dilation=1, norm_layer=None): 386 | super(_ResnetBottleneck, self).__init__() 387 | if norm_layer is None: 388 | norm_layer = nn.BatchNorm2d 389 | width = int(planes * (base_width / 64.)) * groups 390 | self.stride = stride 391 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 392 | self.conv1 = _resnet_conv1x1(inplanes, width) 393 | self.bn1 = norm_layer(width) 394 | self.conv2 = _resnet_conv3x3(width, width) 395 | self.bn2 = norm_layer(width) 396 | self.conv3 = _resnet_conv1x1(width, planes * self.expansion) 397 | self.bn3 = norm_layer(planes * self.expansion) 398 | self.relu = nn.ReLU(inplace=True) 399 | self.downsample = downsample 400 | self.stride = stride 401 | 402 | self.init_weights() 403 | 404 | def init_weights(self): 405 | init_layer(self.conv1) 406 | init_bn(self.bn1) 407 | init_layer(self.conv2) 408 | init_bn(self.bn2) 409 | init_layer(self.conv3) 410 | init_bn(self.bn3) 411 | nn.init.constant_(self.bn3.weight, 0) 412 | 413 | def forward(self, x): 414 | identity = x 415 | 416 | if self.stride == 2: 417 | x = F.avg_pool2d(x, kernel_size=(2, 2)) 418 | 419 | out = self.conv1(x) 420 | out = self.bn1(out) 421 | out = self.relu(out) 422 | 423 | out = self.conv2(out) 424 | out = self.bn2(out) 425 | out = self.relu(out) 426 | out = F.dropout(out, p=0.1, training=self.training) 427 | 428 | out = self.conv3(out) 429 | out = self.bn3(out) 430 | 431 | if self.downsample is not None: 432 | identity = self.downsample(identity) 433 | 434 | out += identity 435 | out = self.relu(out) 436 | 437 | return out 438 | 439 | 440 | class _ResNet(nn.Module): 441 | def __init__(self, block, layers, zero_init_residual=False, 442 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 443 | norm_layer=None): 444 | super(_ResNet, self).__init__() 445 | 446 | if norm_layer is None: 447 | norm_layer = nn.BatchNorm2d 448 | self._norm_layer = norm_layer 449 | 450 | self.inplanes = 64 451 | self.dilation = 1 452 | if replace_stride_with_dilation is None: 453 | # each element in the tuple indicates if we should replace 454 | # the 2x2 stride with a dilated convolution instead 455 | replace_stride_with_dilation = [False, False, False] 456 | if len(replace_stride_with_dilation) != 3: 457 | raise ValueError("replace_stride_with_dilation should be None " 458 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 459 | self.groups = groups 460 | self.base_width = width_per_group 461 | 462 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 463 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 
464 | dilate=replace_stride_with_dilation[0]) 465 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 466 | dilate=replace_stride_with_dilation[1]) 467 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 468 | dilate=replace_stride_with_dilation[2]) 469 | 470 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 471 | norm_layer = self._norm_layer 472 | downsample = None 473 | previous_dilation = self.dilation 474 | if dilate: 475 | self.dilation *= stride 476 | stride = 1 477 | if stride != 1 or self.inplanes != planes * block.expansion: 478 | if stride == 1: 479 | downsample = nn.Sequential( 480 | _resnet_conv1x1(self.inplanes, planes * block.expansion), 481 | norm_layer(planes * block.expansion), 482 | ) 483 | init_layer(downsample[0]) 484 | init_bn(downsample[1]) 485 | elif stride == 2: 486 | downsample = nn.Sequential( 487 | nn.AvgPool2d(kernel_size=2), 488 | _resnet_conv1x1(self.inplanes, planes * block.expansion), 489 | norm_layer(planes * block.expansion), 490 | ) 491 | init_layer(downsample[1]) 492 | init_bn(downsample[2]) 493 | 494 | layers = [] 495 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 496 | self.base_width, previous_dilation, norm_layer)) 497 | self.inplanes = planes * block.expansion 498 | for _ in range(1, blocks): 499 | layers.append(block(self.inplanes, planes, groups=self.groups, 500 | base_width=self.base_width, dilation=self.dilation, 501 | norm_layer=norm_layer)) 502 | 503 | return nn.Sequential(*layers) 504 | 505 | def forward(self, x): 506 | x = self.layer1(x) 507 | x = self.layer2(x) 508 | x = self.layer3(x) 509 | x = self.layer4(x) 510 | 511 | return x 512 | 513 | 514 | class _ResNet3(nn.Module): 515 | def __init__(self, block, layers, zero_init_residual=False, 516 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 517 | norm_layer=None): 518 | super().__init__() 519 | 520 | if norm_layer is None: 521 | norm_layer = nn.BatchNorm2d 522 | self._norm_layer = norm_layer 523 | 524 | self.inplanes = 64 525 | self.dilation = 1 526 | if replace_stride_with_dilation is None: 527 | # each element in the tuple indicates if we should replace 528 | # the 2x2 stride with a dilated convolution instead 529 | replace_stride_with_dilation = [False, False, False] 530 | if len(replace_stride_with_dilation) != 3: 531 | raise ValueError("replace_stride_with_dilation should be None " 532 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 533 | self.groups = groups 534 | self.base_width = width_per_group 535 | 536 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 537 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 538 | dilate=replace_stride_with_dilation[0]) 539 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 540 | dilate=replace_stride_with_dilation[1]) 541 | 542 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 543 | norm_layer = self._norm_layer 544 | downsample = None 545 | previous_dilation = self.dilation 546 | if dilate: 547 | self.dilation *= stride 548 | stride = 1 549 | if stride != 1 or self.inplanes != planes * block.expansion: 550 | if stride == 1: 551 | downsample = nn.Sequential( 552 | _resnet_conv1x1(self.inplanes, planes * block.expansion), 553 | norm_layer(planes * block.expansion), 554 | ) 555 | init_layer(downsample[0]) 556 | init_bn(downsample[1]) 557 | elif stride == 2: 558 | downsample = nn.Sequential( 559 | nn.AvgPool2d(kernel_size=2), 560 | _resnet_conv1x1(self.inplanes, 
planes * block.expansion), 561 | norm_layer(planes * block.expansion), 562 | ) 563 | init_layer(downsample[1]) 564 | init_bn(downsample[2]) 565 | 566 | layers = [] 567 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 568 | self.base_width, previous_dilation, norm_layer)) 569 | self.inplanes = planes * block.expansion 570 | for _ in range(1, blocks): 571 | layers.append(block(self.inplanes, planes, groups=self.groups, 572 | base_width=self.base_width, dilation=self.dilation, 573 | norm_layer=norm_layer)) 574 | 575 | return nn.Sequential(*layers) 576 | 577 | def forward(self, x): 578 | x = self.layer1(x) 579 | x = self.layer2(x) 580 | x = self.layer3(x) 581 | 582 | return x 583 | 584 | 585 | if __name__ == '__main__': 586 | # test function interpolate_tensor 587 | input_tensor = torch.arange(24) 588 | input_tensor = torch.reshape(input_tensor, (2, -1, 3)) 589 | print(input_tensor.shape) 590 | print(input_tensor) 591 | output_tensor = interpolate_tensor(input_tensor, ratio=0.5) 592 | print(output_tensor.shape) 593 | print(output_tensor) -------------------------------------------------------------------------------- /models/sed_decoders.py: -------------------------------------------------------------------------------- 1 | """ 2 | This modules include decoders for sound classification. 3 | """ 4 | import logging 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from models.model_utils import AttBlock, ConvBlock, init_layer, init_gru, PositionalEncoding 11 | 12 | 13 | class SedDecoder(nn.Module): 14 | """ 15 | Decoder for SED 16 | """ 17 | def __init__(self, encoder_output_channels: int, n_classes: int = 14, freq_pool: str = 'avg', 18 | decoder_type: str = 'gru', **kwargs): 19 | """ 20 | :param encoder_output_channels: Number of output channels/filter of encoder. 21 | :param n_classes: Number of classes. 22 | :param freq_pool: Type of frequency pooling. 
Choices are: 23 | 'avg': average pooling 24 | 'max': max pooling 25 | 'avg_max': add average and max pooling 26 | :param decoder_type: Choices are: 27 | 'gru': one-layer bidirectional GRU 28 | """ 29 | super().__init__() 30 | self.decoder_input_size = encoder_output_channels 31 | self.n_classes = n_classes 32 | self.freq_pool = freq_pool 33 | self.decoder_type = decoder_type 34 | 35 | if self.decoder_type == 'gru': 36 | self.gru_hidden_size = self.decoder_input_size//2 37 | self.fc_size = self.gru_hidden_size * 2 38 | 39 | self.gru = nn.GRU(input_size=self.decoder_input_size, hidden_size=self.gru_hidden_size, num_layers=1, 40 | batch_first=True, bidirectional=True) 41 | else: 42 | raise NotImplementedError('decoder type {} is not implemented'.format(self.decoder_type)) 43 | 44 | self.event_fc = nn.Linear(self.fc_size, self.n_classes, bias=True) 45 | 46 | self.init_weights() 47 | 48 | def init_weights(self): 49 | if self.decoder_type == 'gru': 50 | init_gru(self.gru) 51 | init_layer(self.event_fc) 52 | 53 | def forward(self, x): 54 | """ 55 | Input x: (batch_size, n_channels, n_timesteps/n_frames (downsampled), n_features/n_freqs (downsampled)) 56 | """ 57 | if self.freq_pool == 'avg': 58 | x = torch.mean(x, dim=3) 59 | elif self.freq_pool == 'max': 60 | (x, _) = torch.max(x, dim=3) 61 | elif self.freq_pool == 'avg_max': 62 | x1 = torch.mean(x, dim=3) 63 | (x2, _) = torch.max(x, dim=3) 64 | x = x1 + x2 65 | else: 66 | raise ValueError('freq pooling {} is not implemented'.format(self.freq_pool)) 67 | '''(batch_size, feature_maps, time_steps)''' 68 | 69 | if self.decoder_type == 'gru': 70 | x = x.transpose(1, 2) 71 | ''' (batch_size, time_steps, feature_maps):''' 72 | (x, _) = self.gru(x) 73 | else: 74 | raise NotImplementedError('decoder type {} is not implemented'.format(self.decoder_type)) 75 | 76 | x = F.dropout(x, p=0.2, training=self.training) 77 | event_frame_logit = self.event_fc(x) 78 | '''(batch_size, time_steps, class_num)''' 79 | 80 | output = { 81 | 'event_frame_logit': event_frame_logit 82 | } 83 | 84 | return output 85 | -------------------------------------------------------------------------------- /models/sed_encoders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | 8 | from models.model_utils import ConvBlock, init_layer, _ResNet3, _ResNet, _ResnetBasicBlock 9 | 10 | 11 | class PannCnn14L6(nn.Module): 12 | """ 13 | Derived from PANN CNN14 network. PannCnn14L6 has 6 CNN layers (3 convblock) 14 | """ 15 | def __init__(self, n_input_channels: int = 1, p_dropout: float = 0.2, pretrained: bool = False, **kwargs): 16 | """ 17 | :param n_input_channels: Number of input channels. 18 | :param p_dropout: Dropout probability. 19 | :param pretrained: If True, load pretrained model. 
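        A minimal usage sketch (the input shape and channel count are illustrative assumptions, not requirements):
            encoder = PannCnn14L6(n_input_channels=7, pretrained=False)
            x = torch.rand(4, 7, 640, 128)    # (batch_size, n_channels, n_timesteps, n_mels)
            features = encoder(x)             # -> (4, 256, 80, 16): time and frequency both downsampled by 8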
20 | """ 21 | super().__init__() 22 | self.n_input_channels = n_input_channels 23 | self.p_dropout = p_dropout 24 | self.n_output_channels = 256 25 | self.time_downsample_ratio = 8 26 | self.freq_downsample_ratio = 8 27 | 28 | self.conv_block1 = ConvBlock(in_channels=n_input_channels, out_channels=64) 29 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 30 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 31 | 32 | # Load pretrained model 33 | self.load_pretrained_weight(pretrained=pretrained) 34 | 35 | def load_pretrained_weight(self, pretrained: bool = False): 36 | logger = logging.getLogger('lightning') 37 | pretrained_path = '../pretrained_models/Cnn14_DecisionLevelAtt_mAP=0.425.pth' 38 | if pretrained: 39 | checkpoint = torch.load(pretrained_path, map_location=lambda storage, loc: storage) 40 | try: 41 | self.load_state_dict(checkpoint['model'], strict=False) 42 | logger.info('Load pretrained weights from checkpoint {}.'.format(pretrained_path)) 43 | except: 44 | logger.info('WARNING: Coud not load pretrained weights from checkpoint {}.'.format(pretrained_path)) 45 | 46 | def forward(self, x): 47 | """ 48 | Input x: (batch_size, n_channels, n_timesteps/n_frames, n_features/n_freqs) 49 | """ 50 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 51 | x = F.dropout(x, p=self.p_dropout, training=self.training) 52 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 53 | x = F.dropout(x, p=self.p_dropout, training=self.training) 54 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 55 | x = F.dropout(x, p=self.p_dropout, training=self.training) 56 | 57 | return x 58 | 59 | @property 60 | def count_number_of_params(self): 61 | n_params = sum([param.numel() for param in self.parameters()]) 62 | n_trainable_params = sum(param.numel() for param in self.parameters() if param.requires_grad) 63 | return n_params, n_trainable_params 64 | 65 | 66 | class PannCnn14L6F64(nn.Module): 67 | """ 68 | Derived from PANN CNN14 network. PannCnn14L6 has 6 CNN layers (3 convblock) 69 | """ 70 | def __init__(self, n_input_channels: int = 1, p_dropout: float = 0.2, pretrained: bool = False, **kwargs): 71 | """ 72 | :param n_input_channels: Number of input channels. 73 | :param p_dropout: Dropout probability. 74 | :param pretrained: If True, load pretrained model. 
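        Illustrative output shape: with (2, 4) pooling per conv block, an input of (batch, channels, 640, 128) maps to (batch, 256, 80, 2), i.e. time is downsampled by 8 and frequency by 64.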
75 | """ 76 | super().__init__() 77 | self.n_input_channels = n_input_channels 78 | self.p_dropout = p_dropout 79 | self.n_output_channels = 256 80 | self.time_downsample_ratio = 8 81 | self.freq_downsample_ratio = 64 82 | 83 | self.conv_block1 = ConvBlock(in_channels=n_input_channels, out_channels=64) 84 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 85 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 86 | 87 | # Load pretrained model 88 | self.load_pretrained_weight(pretrained=pretrained) 89 | 90 | def load_pretrained_weight(self, pretrained: bool = False): 91 | logger = logging.getLogger('lightning') 92 | pretrained_path = '../pretrained_models/Cnn14_DecisionLevelAtt_mAP=0.425.pth' 93 | if pretrained: 94 | checkpoint = torch.load(pretrained_path, map_location=lambda storage, loc: storage) 95 | try: 96 | self.load_state_dict(checkpoint['model'], strict=False) 97 | logger.info('Load pretrained weights from checkpoint {}.'.format(pretrained_path)) 98 | except: 99 | logger.info('WARNING: Coud not load pretrained weights from checkpoint {}.'.format(pretrained_path)) 100 | 101 | def forward(self, x): 102 | """ 103 | Input x: (batch_size, n_channels, n_timesteps/n_frames, n_features/n_freqs) 104 | """ 105 | x = self.conv_block1(x, pool_size=(2, 4), pool_type='avg') 106 | x = F.dropout(x, p=self.p_dropout, training=self.training) 107 | x = self.conv_block2(x, pool_size=(2, 4), pool_type='avg') 108 | x = F.dropout(x, p=self.p_dropout, training=self.training) 109 | x = self.conv_block3(x, pool_size=(2, 4), pool_type='avg') 110 | x = F.dropout(x, p=self.p_dropout, training=self.training) 111 | 112 | return x 113 | 114 | @property 115 | def count_number_of_params(self): 116 | n_params = sum([param.numel() for param in self.parameters()]) 117 | n_trainable_params = sum(param.numel() for param in self.parameters() if param.requires_grad) 118 | return n_params, n_trainable_params 119 | 120 | 121 | class PannCnn14L8(nn.Module): 122 | """ 123 | Derived from PANN CNN14 network. PannCnn14L8 has 8 CNN layers (4 convblock) 124 | """ 125 | def __init__(self, n_input_channels: int = 1, p_dropout: float = 0.2, pretrained: bool = False, **kwargs): 126 | """ 127 | :param n_input_channels: Number of input channels. 128 | :param p_dropout: Dropout probability. 129 | :param pretrained: If True, load pretrained model. 
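        Illustrative output shape: with four conv blocks and (2, 2) pooling, an input of (batch, channels, 640, 128) maps to (batch, 512, 40, 8), i.e. both time and frequency are downsampled by 16.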
130 | """ 131 | super().__init__() 132 | self.n_input_channels = n_input_channels 133 | self.p_dropout = p_dropout 134 | self.n_output_channels = 512 135 | self.time_downsample_ratio = 16 136 | self.freq_downsample_ratio = 16 137 | 138 | self.conv_block1 = ConvBlock(in_channels=n_input_channels, out_channels=64) 139 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 140 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 141 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 142 | 143 | # Load pretrained model 144 | self.load_pretrained_weight(pretrained=pretrained) 145 | 146 | def load_pretrained_weight(self, pretrained: bool = False): 147 | logger = logging.getLogger('lightning') 148 | pretrained_path = '../pretrained_models/Cnn14_DecisionLevelAtt_mAP=0.425.pth' 149 | if pretrained: 150 | checkpoint = torch.load(pretrained_path, map_location=lambda storage, loc: storage) 151 | try: 152 | self.load_state_dict(checkpoint['model'], strict=False) 153 | logger.info('Load pretrained weights from checkpoint {}.'.format(pretrained_path)) 154 | except: 155 | logger.info('WARNING: Coud not load pretrained weights from checkpoint {}.'.format(pretrained_path)) 156 | 157 | def forward(self, x): 158 | """ 159 | Input x: (batch_size, n_channels, n_timesteps/n_frames, n_features/n_freqs) 160 | """ 161 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 162 | x = F.dropout(x, p=self.p_dropout, training=self.training) 163 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 164 | x = F.dropout(x, p=self.p_dropout, training=self.training) 165 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 166 | x = F.dropout(x, p=self.p_dropout, training=self.training) 167 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 168 | x = F.dropout(x, p=self.p_dropout, training=self.training) 169 | 170 | return x 171 | 172 | @property 173 | def count_number_of_params(self): 174 | n_params = sum([param.numel() for param in self.parameters()]) 175 | n_trainable_params = sum(param.numel() for param in self.parameters() if param.requires_grad) 176 | return n_params, n_trainable_params 177 | 178 | 179 | if __name__ == '__main__': 180 | encoder = PannCnn14L8() 181 | print(encoder.count_number_of_params) 182 | print(encoder) 183 | -------------------------------------------------------------------------------- /models/sed_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module include code to perform SED task 3 | """ 4 | import logging 5 | import os 6 | from typing import Tuple 7 | 8 | import h5py 9 | import numpy as np 10 | import pandas as pd 11 | import pytorch_lightning as pl 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from metrics.pl_metrics import SedMetrics 17 | from models.model_utils import interpolate_tensor, freeze_model, unfreeze_model 18 | 19 | 20 | class SedModel(pl.LightningModule): 21 | def __init__(self, encoder: nn.Module, decoder: nn.Module, sed_threshold: float = 0.3, label_rate: int = 10, 22 | encoder_unfreeze_epoch: int = 0, optimizer_name: str = 'Adam', lr: float = 1e-3, 23 | output_pred_dir: str = None, submission_filename: str = None, **kwargs): 24 | super().__init__() 25 | self.save_hyperparameters() 26 | self.encoder = encoder 27 | self.decoder = decoder 28 | self.sed_threshold = sed_threshold 29 | self.label_rate = label_rate 30 | self.encoder_unfreeze_epoch = encoder_unfreeze_epoch 31 | self.optimizer_name = 
optimizer_name 32 | self.lr = lr 33 | self.submission_fn = submission_filename 34 | self.output_pred_dir = output_pred_dir 35 | self.time_downsample_ratio = self.encoder.time_downsample_ratio 36 | self.n_classes = self.decoder.n_classes 37 | self.lit_logger = logging.getLogger('lightning') 38 | 39 | # Metrics 40 | self.train_sed_metrics = SedMetrics(nb_frames_1s=self.label_rate) 41 | self.valid_sed_metrics = SedMetrics(nb_frames_1s=self.label_rate) 42 | self.test_sed_metrics = SedMetrics(nb_frames_1s=self.label_rate) 43 | 44 | # Freeze encoder layer 45 | self.freeze_encoder() 46 | 47 | # Write submission files 48 | self.columns = ['frame_idx', 'event', 'azimuth', 'elevation'] 49 | self.submission = pd.DataFrame(columns=self.columns) 50 | 51 | # Write output prediction for test step 52 | if self.output_pred_dir is not None: 53 | os.makedirs(self.output_pred_dir, exist_ok=True) 54 | 55 | def freeze_encoder(self): 56 | if self.encoder_unfreeze_epoch == -1 or self.encoder_unfreeze_epoch > 0: 57 | freeze_model(self.encoder) 58 | 59 | def forward(self, x): 60 | """ 61 | x: (batch_size, n_channels, n_timesteps (n_frames), n_features). 62 | """ 63 | x = self.encoder(x) # (batch_size, n_channels, n_timesteps, n_features) 64 | output_dict = self.decoder(x) # (batch_size, n_timesteps, n_classes) 65 | # output_dict = { 66 | # "event_frame_logit": event_frame_logit, 67 | # } 68 | return output_dict 69 | 70 | def common_step(self, batch_data): 71 | x, y_sed, _, _ = batch_data 72 | y_sed = interpolate_tensor(y_sed, ratio=1.0 / self.time_downsample_ratio) # to match output dimension 73 | target_dict = { 74 | 'event_frame_gt': y_sed, 75 | } 76 | pred_dict = self.forward(x) 77 | event_frame_output = (torch.sigmoid(pred_dict['event_frame_logit']) > self.sed_threshold).type(torch.float32) 78 | return target_dict, pred_dict, event_frame_output 79 | 80 | def training_step(self, train_batch, batch_idx): 81 | target_dict, pred_dict, event_frame_output = self.common_step(train_batch) 82 | loss = self.compute_loss(target_dict=target_dict, pred_dict=pred_dict) 83 | self.train_sed_metrics(event_frame_output, target_dict['event_frame_gt']) 84 | # logging 85 | self.log('trl', loss, prog_bar=True, logger=True) 86 | training_step_outputs = {'loss': loss, 'event_frame_logit': pred_dict['event_frame_logit'], 87 | 'event_frame_gt': target_dict['event_frame_gt']} 88 | return training_step_outputs 89 | 90 | def training_epoch_end(self, training_step_outputs): 91 | # Unfreeze encoder 92 | if 0 < self.encoder_unfreeze_epoch == self.current_epoch: 93 | unfreeze_model(self.encoder) 94 | self.lit_logger.info('Unfreezing encoder at epoch: {}'.format(self.current_epoch)) 95 | # compute running metric 96 | ER, F1 = self.train_sed_metrics.compute() 97 | sed_error = (ER + 1 - F1)/2 98 | self.log('trER', ER) 99 | self.log('trF1', F1) 100 | self.log('trSedE', sed_error) 101 | self.lit_logger.info('Epoch {} - Training - ER: {:.4f} - F1: {:.4f} - SED error: {:.4f}'.format( 102 | self.current_epoch, ER, F1, sed_error)) 103 | 104 | def validation_step(self, val_batch, batch_idx): 105 | target_dict, pred_dict, event_frame_output = self.common_step(val_batch) 106 | loss = self.compute_loss(target_dict=target_dict, pred_dict=pred_dict) 107 | self.valid_sed_metrics(event_frame_output, target_dict['event_frame_gt']) 108 | # logging 109 | self.log('vall', loss, prog_bar=True, logger=True) 110 | return None 111 | 112 | def validation_epoch_end(self, validation_step_outputs): 113 | # compute running metric for SED 114 | ER, F1 = 
self.valid_sed_metrics.compute() 115 | sed_error = (ER + 1 - F1) / 2 116 | self.log('valER', ER) 117 | self.log('valF1', F1) 118 | self.log('valSedE', sed_error) 119 | self.lit_logger.info('Epoch {} - Validation - ER: {:.4f} - F1: {:.4f} - SED error: {:.4f}'.format( 120 | self.current_epoch, ER, F1, sed_error)) 121 | 122 | def test_step(self, test_batch, batch_idx): 123 | target_dict, pred_dict, event_frame_output = self.common_step(test_batch) 124 | self.test_sed_metrics(event_frame_output, target_dict['event_frame_gt']) 125 | filenames = test_batch[-1] 126 | # TODO submission file 127 | # # add output to submission dataframe 128 | # self.append_output_prediction(y_pred=y_pred_file, filenames=filenames) 129 | # Write output prediction 130 | if self.output_pred_dir: 131 | h5_filename = os.path.join(self.output_pred_dir, filenames[0] + '.h5') 132 | event_frame_pred = torch.sigmoid(pred_dict['event_frame_logit']).detach().cpu().numpy() 133 | event_frame_gt = target_dict['event_frame_gt'].detach().cpu().numpy() 134 | with h5py.File(h5_filename, 'w') as hf: 135 | hf.create_dataset('event_frame_pred', data=event_frame_pred, dtype=np.float32) 136 | hf.create_dataset('event_frame_gt', data=event_frame_gt, dtype=np.float32) 137 | hf.create_dataset('time_downsample_ratio', data=self.time_downsample_ratio, dtype=np.float32) 138 | return None 139 | 140 | def test_epoch_end(self, test_step_outputs): 141 | ER, F1 = self.test_sed_metrics.compute() 142 | sed_error = (ER + 1 - F1) / 2 143 | self.log('testER', ER) 144 | self.log('testF1', F1) 145 | self.log('testSedE', sed_error) 146 | self.lit_logger.info('Epoch {} - Test - ER: {:.4f} - F1: {:.4f} - SED error: {:.4f}'.format( 147 | self.current_epoch, ER, F1, sed_error)) 148 | # # TODO 149 | # # write to output file 150 | # self.write_output_submission() 151 | 152 | @staticmethod 153 | def compute_loss(target_dict, pred_dict): 154 | # Event frame loss 155 | sed_loss = F.binary_cross_entropy_with_logits(input=pred_dict['event_frame_logit'], 156 | target=target_dict['event_frame_gt']) 157 | return sed_loss 158 | 159 | def configure_optimizers(self): 160 | """ 161 | Pytorch lightning hook 162 | """ 163 | if self.optimizer_name in ['Adam', 'adam']: 164 | optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 165 | elif self.optimizer_name in ['AdamW', 'Adamw', 'adamw']: 166 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) 167 | else: 168 | raise NotImplementedError('Optimizer {} is not implemented!'.format(self.optimizer_name)) 169 | return optimizer 170 | 171 | def append_output_prediction(self, y_pred, filenames): 172 | assert len(set(filenames)) == 1, 'Test batch contains different audio files.' 
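# Assumption for this (still TODO) submission path: y_pred is expected to hold one (event, azimuth, elevation) row for the whole recording, written as a single submission row keyed by the recording filename.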
173 | if self.submission_fn is not None: 174 | filename = filenames[0] 175 | prediction = dict(zip(self.columns[1:], y_pred.cpu().numpy()[0])) 176 | prediction['recording_id'] = filename 177 | self.submission = self.submission.append(prediction, ignore_index=True) 178 | 179 | def write_output_submission(self): 180 | if self.submission_fn is not None: 181 | self.submission.to_csv(self.submission_fn, index=False) 182 | 183 | -------------------------------------------------------------------------------- /paper/general_seld.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomeou/General-network-architecture-for-sound-event-localization-and-detection/03b3aaccf3c87dd8fb857960e765ae768ad36625/paper/general_seld.pdf -------------------------------------------------------------------------------- /pretrained_models/README.md: -------------------------------------------------------------------------------- 1 | Please download a pretrained model name "Cnn14_mAP=0.431.pth" from https://zenodo.org/record/3987831 and put in this folder. -------------------------------------------------------------------------------- /py37_environment.yml: -------------------------------------------------------------------------------- 1 | name: py37 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_gnu 10 | - absl-py=0.11.0=py37h89c1867_0 11 | - aiohttp=3.7.3=py37h4abf009_0 12 | - appdirs=1.4.4=pyh9f0ad1d_0 13 | - argon2-cffi=20.1.0=py37h7b6447c_1 14 | - async-timeout=3.0.1=py_1000 15 | - async_generator=1.10=py37h28b3542_0 16 | - attrs=20.3.0=pyhd3eb1b0_0 17 | - audioread=2.1.8=py37hc8dfbb8_3 18 | - backcall=0.2.0=py_0 19 | - beautifulsoup4=4.9.3=pyhb0f4dca_0 20 | - blas=1.0=mkl 21 | - bleach=3.2.1=py_0 22 | - blinker=1.4=py_1 23 | - brotlipy=0.7.0=py37hb5d75c8_1001 24 | - bzip2=1.0.8=h7f98852_4 25 | - c-ares=1.17.1=h36c2ea0_0 26 | - ca-certificates=2020.12.5=ha878542_0 27 | - cachetools=4.1.1=py_0 28 | - captum=0.3.0=0 29 | - certifi=2020.12.5=py37h89c1867_0 30 | - cffi=1.14.4=py37h261ae71_0 31 | - chardet=3.0.4=py37he5f6b98_1008 32 | - click=7.1.2=pyh9f0ad1d_0 33 | - cryptography=3.2.1=py37hc72a4ac_0 34 | - cudatoolkit=11.0.221=h6bb024c_0 35 | - cycler=0.10.0=py37_0 36 | - cython=0.29.21=py37h2531618_0 37 | - dbus=1.13.18=hb2f20db_0 38 | - decorator=4.4.2=py_0 39 | - defusedxml=0.6.0=py_0 40 | - entrypoints=0.3=py37_0 41 | - expat=2.2.10=he6710b0_2 42 | - ffmpeg=4.3.1=h3215721_1 43 | - fire=0.3.1=pyh9f0ad1d_0 44 | - fontconfig=2.13.0=h9420a91_0 45 | - freetype=2.10.4=h5ab3b9f_0 46 | - fsspec=0.8.4=py_0 47 | - future=0.18.2=py37h89c1867_2 48 | - gettext=0.19.8.1=h0b5b191_1005 49 | - glib=2.66.1=h92f7085_0 50 | - gmp=6.2.1=h58526e2_0 51 | - gnutls=3.6.13=h85f3911_1 52 | - google-auth=1.23.0=pyhd8ed1ab_0 53 | - google-auth-oauthlib=0.4.1=py_2 54 | - grpcio=1.34.0=py37hb27c1af_0 55 | - gst-plugins-base=1.14.0=hbbd80ab_1 56 | - gstreamer=1.14.0=hb31296c_0 57 | - h5py=2.10.0=py37hd6299e0_1 58 | - hdf5=1.10.6=hb1b8bf9_0 59 | - icu=58.2=he6710b0_3 60 | - idna=2.10=pyh9f0ad1d_0 61 | - importlib-metadata=2.0.0=py_1 62 | - importlib_metadata=2.0.0=1 63 | - intel-openmp=2020.2=254 64 | - ipykernel=5.3.4=py37h5ca1d4c_0 65 | - ipython=7.19.0=py37hb070fc8_0 66 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 67 | - ipywidgets=7.5.1=py_1 68 | - jedi=0.17.2=py37h06a4308_1 69 | - jinja2=2.11.2=py_0 70 | - joblib=0.17.0=py_0 71 | - jpeg=9b=h024ee3a_2 72 | - 
jsonschema=3.2.0=py_2 73 | - jupyter=1.0.0=py37_7 74 | - jupyter_client=6.1.7=py_0 75 | - jupyter_console=6.2.0=py_0 76 | - jupyter_contrib_core=0.3.3=py_2 77 | - jupyter_contrib_nbextensions=0.5.1=py37hc8dfbb8_1 78 | - jupyter_core=4.7.0=py37h06a4308_0 79 | - jupyter_highlight_selected_word=0.2.0=py37hc8dfbb8_1002 80 | - jupyter_latex_envs=1.4.6=py37hc8dfbb8_1001 81 | - jupyter_nbextensions_configurator=0.4.1=py37hc8dfbb8_2 82 | - jupyterlab_pygments=0.1.2=py_0 83 | - kiwisolver=1.3.0=py37h2531618_0 84 | - lame=3.100=h14c3975_1001 85 | - lcms2=2.11=h396b838_0 86 | - ld_impl_linux-64=2.33.1=h53a641e_7 87 | - libedit=3.1.20191231=h14c3975_1 88 | - libffi=3.3=he6710b0_2 89 | - libflac=1.3.3=h9c3ff4c_1 90 | - libgcc-ng=9.3.0=h5dbcf3e_17 91 | - libgfortran-ng=7.3.0=hdf63c60_0 92 | - libgomp=9.3.0=h5dbcf3e_17 93 | - libiconv=1.16=h516909a_0 94 | - libllvm10=10.0.1=he513fc3_3 95 | - libogg=1.3.4=h7f98852_0 96 | - libopus=1.3.1=h7f98852_1 97 | - libpng=1.6.37=hbc83047_0 98 | - libprotobuf=3.14.0=h780b84a_0 99 | - librosa=0.8.0=pyh9f0ad1d_0 100 | - libsndfile=1.0.30=h9c3ff4c_1 101 | - libsodium=1.0.18=h7b6447c_0 102 | - libstdcxx-ng=9.3.0=h2ae2ef3_17 103 | - libtiff=4.1.0=h2733197_1 104 | - libuuid=1.0.3=h1bed415_2 105 | - libuv=1.40.0=h7b6447c_0 106 | - libvorbis=1.3.7=he1b5a44_0 107 | - libxcb=1.14=h7b6447c_0 108 | - libxml2=2.9.10=hb55368b_3 109 | - libxslt=1.1.34=hc22bd24_0 110 | - llvmlite=0.35.0=py37h9d7f4d0_0 111 | - lxml=4.6.2=py37h9120a33_0 112 | - lz4-c=1.9.2=heb0550a_3 113 | - markdown=3.3.3=pyh9f0ad1d_0 114 | - markupsafe=1.1.1=py37h14c3975_1 115 | - matplotlib=3.3.2=0 116 | - matplotlib-base=3.3.2=py37h817c723_0 117 | - mistune=0.8.4=py37h14c3975_1001 118 | - mkl=2020.2=256 119 | - mkl-service=2.3.0=py37he8ac12f_0 120 | - mkl_fft=1.2.0=py37h23d657b_0 121 | - mkl_random=1.1.1=py37h0573a6f_0 122 | - multidict=4.7.5=py37h8f50634_2 123 | - nbclient=0.5.1=py_0 124 | - nbconvert=6.0.7=py37_0 125 | - nbformat=5.0.8=py_0 126 | - ncurses=6.2=he6710b0_1 127 | - nest-asyncio=1.4.3=pyhd3eb1b0_0 128 | - nettle=3.6=he412f7d_0 129 | - ninja=1.10.2=py37hff7bd54_0 130 | - notebook=6.1.4=py37_0 131 | - numba=0.52.0=py37hdc94413_0 132 | - numpy=1.19.2=py37h54aff64_0 133 | - numpy-base=1.19.2=py37hfa32c7d_0 134 | - oauthlib=3.0.1=py_0 135 | - olefile=0.46=py37_0 136 | - openh264=2.1.1=h8b12597_0 137 | - openssl=1.1.1h=h516909a_0 138 | - packaging=20.7=pyhd3eb1b0_0 139 | - pandas=1.1.3=py37he6710b0_0 140 | - pandoc=2.11=hb0f4dca_0 141 | - pandocfilters=1.4.3=py37h06a4308_1 142 | - parso=0.7.0=py_0 143 | - pcre=8.44=he6710b0_0 144 | - pexpect=4.8.0=pyhd3eb1b0_3 145 | - pickleshare=0.7.5=pyhd3eb1b0_1003 146 | - pillow=8.0.1=py37he98fc37_0 147 | - pip=20.3=py37h06a4308_0 148 | - pooch=1.3.0=pyhd8ed1ab_0 149 | - prometheus_client=0.9.0=pyhd3eb1b0_0 150 | - prompt-toolkit=3.0.8=py_0 151 | - prompt_toolkit=3.0.8=0 152 | - protobuf=3.14.0=py37hcd2ae1e_0 153 | - ptyprocess=0.6.0=pyhd3eb1b0_2 154 | - pyasn1=0.4.8=py_0 155 | - pyasn1-modules=0.2.7=py_0 156 | - pycparser=2.20=py_2 157 | - pygments=2.7.2=pyhd3eb1b0_0 158 | - pyjwt=1.7.1=py_0 159 | - pyopenssl=20.0.0=pyhd8ed1ab_0 160 | - pyparsing=2.4.7=py_0 161 | - pyqt=5.9.2=py37h05f1152_2 162 | - pyrsistent=0.17.3=py37h7b6447c_0 163 | - pysocks=1.7.1=py37he5f6b98_2 164 | - pysoundfile=0.10.3.post1=pyhd3deb0d_0 165 | - python=3.7.9=h7579374_0 166 | - python-dateutil=2.8.1=py_0 167 | - python_abi=3.7=1_cp37m 168 | - pytorch=1.7.0=py3.7_cuda11.0.221_cudnn8.0.3_0 169 | - pytorch-lightning=1.0.8=pyhd8ed1ab_0 170 | - pytz=2020.4=pyhd3eb1b0_0 171 | - 
pyyaml=5.3.1=py37h7b6447c_1 172 | - pyzmq=20.0.0=py37h2531618_1 173 | - qt=5.9.7=h5867ecd_1 174 | - qtconsole=4.7.7=py_0 175 | - qtpy=1.9.0=py_0 176 | - readline=8.0=h7b6447c_0 177 | - requests=2.25.0=pyhd3deb0d_0 178 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 179 | - resampy=0.2.2=py_0 180 | - rsa=4.6=pyh9f0ad1d_0 181 | - scikit-learn=0.23.2=py37h0573a6f_0 182 | - scipy=1.5.2=py37h0b6359f_0 183 | - seaborn=0.11.0=py_0 184 | - send2trash=1.5.0=py37_0 185 | - setuptools=50.3.2=py37h06a4308_2 186 | - sip=4.19.8=py37hf484d3e_0 187 | - six=1.15.0=py37h06a4308_0 188 | - soupsieve=2.0.1=py_0 189 | - sqlite=3.33.0=h62c20be_0 190 | - tensorboard=2.4.0=pyhd8ed1ab_0 191 | - tensorboard-plugin-wit=1.7.0=pyh9f0ad1d_0 192 | - tensorboardx=2.1=py_0 193 | - termcolor=1.1.0=py_2 194 | - terminado=0.9.1=py37_0 195 | - testpath=0.4.4=py_0 196 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 197 | - tk=8.6.10=hbc83047_0 198 | - torchaudio=0.7.0=py37 199 | - torchvision=0.8.1=py37_cu110 200 | - tornado=6.1=py37h27cfd23_0 201 | - tqdm=4.54.1=pyhd8ed1ab_0 202 | - traitlets=5.0.5=py_0 203 | - typing-extensions=3.7.4.3=0 204 | - typing_extensions=3.7.4.3=py_0 205 | - urllib3=1.25.11=py_0 206 | - wcwidth=0.2.5=py_0 207 | - webencodings=0.5.1=py37_1 208 | - werkzeug=1.0.1=pyh9f0ad1d_0 209 | - wheel=0.36.0=pyhd3eb1b0_0 210 | - widgetsnbextension=3.5.1=py37_0 211 | - x264=1!152.20180806=h14c3975_0 212 | - xz=5.2.5=h7b6447c_0 213 | - yaml=0.2.5=h516909a_0 214 | - yarl=1.6.3=py37h4abf009_0 215 | - zeromq=4.3.3=he6710b0_3 216 | - zipp=3.4.0=pyhd3eb1b0_0 217 | - zlib=1.2.11=h7b6447c_3 218 | - zstd=1.4.5=h9ceee32_0 219 | - pip: 220 | - albumentations==0.5.2 221 | - imageio==2.9.0 222 | - imgaug==0.4.0 223 | - networkx==2.5 224 | - opencv-python==4.4.0.46 225 | - opencv-python-headless==4.4.0.46 226 | - pyaudio==0.2.11 227 | - pybind11==2.6.1 228 | - pyroomacoustics==0.4.2 229 | - pywavelets==1.1.1 230 | - scikit-image==0.17.2 231 | - shapely==1.7.1 232 | - tifffile==2020.12.4 233 | - torchsummary==1.5.1 234 | - youtube-dl==2020.12.7 235 | prefix: /home/tho_nguyen/anaconda3/envs/py37 236 | 237 | -------------------------------------------------------------------------------- /utilities/builder_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This modules consists code to select different components for 3 | feature_database 4 | models 5 | """ 6 | import logging 7 | 8 | import pytorch_lightning as pl 9 | import torch 10 | import torch.nn as nn 11 | 12 | import models 13 | from dataset.database import SedDoaDatabase 14 | from dataset.datamodule import SedDoaDataModule 15 | from models.sed_models import SedModel 16 | 17 | 18 | def build_database(cfg): 19 | """ 20 | Function to select database according to task 21 | :param cfg: Experiment config 22 | """ 23 | if cfg.task in ['sed', 'SED', 'doa', 'DOA']: 24 | feature_db = SedDoaDatabase(feature_root_dir=cfg.feature_root_dir, gt_meta_root_dir=cfg.gt_meta_root_dir, 25 | audio_format=cfg.data.audio_format, n_classes=cfg.data.n_classes, fs=cfg.data.fs, 26 | n_fft=cfg.data.n_fft, hop_len=cfg.data.hop_len, label_rate=cfg.data.label_rate, 27 | train_chunk_len_s=cfg.data.train_chunk_len_s, 28 | train_chunk_hop_len_s=cfg.data.train_chunk_hop_len_s, 29 | test_chunk_len_s=cfg.data.test_chunk_len_s, 30 | test_chunk_hop_len_s=cfg.data.test_chunk_hop_len_s, 31 | scaler_type=cfg.data.scaler_type) 32 | elif cfg.task in ['seld', 'SELD']: 33 | pass 34 | else: 35 | raise NotImplementedError('task {} is not implemented'.format(cfg.task)) 36 | 37 | return 
feature_db 38 | 39 | 40 | def build_datamodule(cfg, feature_db): 41 | """ 42 | Function to select pytorch lightning datamodule according to different tasks. 43 | :param cfg: Experiment config. 44 | :param feature_db: Feature database. 45 | """ 46 | if cfg.task in ['sed', 'SED', 'doa', 'DOA']: 47 | datamodule = SedDoaDataModule(feature_db=feature_db, split_meta_dir=cfg.split_meta_dir, mode=cfg.mode, 48 | train_batch_size=cfg.training.train_batch_size, 49 | val_batch_size=cfg.training.val_batch_size) 50 | 51 | elif cfg.task in ['seld', 'SELD']: 52 | pass 53 | else: 54 | raise NotImplementedError('task {} is not implemented'.format(cfg.task)) 55 | 56 | return datamodule 57 | 58 | 59 | def build_model(name: str, **kwargs) -> nn.Module: 60 | """ 61 | Build encoder. 62 | :param name: Name of the encoder. 63 | :return: encoder model 64 | """ 65 | logger = logging.getLogger('lightning') 66 | # Load model: 67 | model = models.__dict__[name](**kwargs) 68 | logger.info('Finish loading model {}.'.format(name)) 69 | 70 | return model 71 | 72 | 73 | def build_task(encoder, decoder, cfg, **kwargs) -> pl.LightningModule: 74 | """ 75 | Build task 76 | :param encoder: 77 | :param decoder: 78 | :param cfg: 79 | :return: Lightning module 80 | """ 81 | if cfg.task in ['sed', 'Sed', 'SED']: 82 | model = SedModel(encoder=encoder, decoder=decoder, encoder_unfreeze_epoch=cfg.model.encoder.unfreeze_epoch, 83 | sed_threshold=cfg.data.sed_threshold, label_rate=cfg.data.label_rate, 84 | optimizer_name=cfg.training.optimizer) 85 | else: 86 | pass 87 | 88 | return model 89 | 90 | -------------------------------------------------------------------------------- /utilities/experiments_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This modules consists utility functions to organize folders, create logger for experiments. 3 | """ 4 | import logging 5 | import os 6 | import shutil 7 | import random 8 | from datetime import datetime 9 | 10 | import numpy as np 11 | import torch 12 | import yaml 13 | from tqdm import tqdm 14 | 15 | 16 | def manage_experiments(exp_config: str = 'configs/exp1.yml', 17 | exp_group_dir: str = '/home/tho_nguyen/Documents/work/rfcx/outputs/rfcx/', 18 | exp_suffix: str = '_first_exp', 19 | empty: bool = False): 20 | """ 21 | Function to load config, create folder and logging. 22 | :param exp_config: Config file for experiments 23 | :param exp_group_dir: Parent directory to store all experiment results. 24 | :param exp_suffix: Experiment suffix. 25 | :param empty: If true, delete all previous data in experiment folder. 
26 | :return: config 27 | """ 28 | # Load data config files 29 | with open(exp_config, 'r') as stream: 30 | try: 31 | cfg_dict = yaml.safe_load(stream) 32 | except yaml.YAMLError as exc: 33 | print(exc) 34 | # Convert dictionary to object 35 | cfg = dict2obj(cfg_dict) 36 | 37 | # Parse feature type from config 38 | cfg.feature_type = os.path.split(os.path.split(cfg.feature_dir)[0])[-1] 39 | n_channels_dict = {'logmel': 4, 'logmelgcc': 10, 'logmeliv': 7, 'gcc': 6, 'iv': 3} 40 | cfg.data.n_input_channels = n_channels_dict[cfg.feature_type] 41 | 42 | # Create experiment folder 43 | exp_name = os.path.splitext(os.path.basename(exp_config))[0] + exp_suffix 44 | create_exp_folders(cfg=cfg, exp_group_dir=exp_group_dir, exp_name=exp_name, empty=empty) 45 | 46 | # Create logging 47 | create_logging(log_dir=cfg.dir.logs_dir, filemode='a') 48 | 49 | # Write config file to output folder 50 | yaml_config_fn = os.path.join(cfg.dir.config_dir, 51 | 'exp_config_{}.yml'.format(datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))) 52 | write_yaml_config(output_filename=yaml_config_fn, config_dict=cfg_dict) 53 | logger = logging.getLogger('lightning') 54 | logger.info('Write yaml config file to {}'.format(cfg.dir.config_dir)) 55 | logger.info('Finish parsing config file: {}.'.format(exp_config)) 56 | if empty: 57 | logger.info('Clear all directories.') 58 | 59 | return cfg 60 | 61 | 62 | class DummyClass: 63 | """ 64 | Dummy class for entry of config object. 65 | """ 66 | pass 67 | 68 | 69 | def dict2obj(d): 70 | """ 71 | Convert nested dictionary to object. 72 | Copied from https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/ 73 | :param d: Dictionary. 74 | :return: object. 75 | """ 76 | # checking whether object d is a 77 | # instance of class list 78 | if isinstance(d, list): 79 | d = [dict2obj(x) for x in d] 80 | 81 | # if d is not a instance of dict then 82 | # directly object is returned 83 | if not isinstance(d, dict): 84 | return d 85 | 86 | # declaring a class 87 | 88 | # constructor of the class passed to obj 89 | obj = DummyClass() 90 | 91 | for k in d: 92 | obj.__dict__[k] = dict2obj(d[k]) 93 | 94 | return obj 95 | 96 | 97 | def create_empty_folder(folder_name) -> None: 98 | shutil.rmtree(folder_name, ignore_errors=True) 99 | os.makedirs(folder_name, exist_ok=True) 100 | 101 | 102 | def create_exp_folders(cfg, exp_group_dir: str = '', exp_name: str = '', empty: bool = False) -> None: 103 | """ 104 | Create folders required for experiments. 105 | :param cfg: Experiment config object. 106 | :param exp_group_dir: Experiment directory. 107 | :param exp_name: Experiment name. 108 | :param empty: If true, delete all previous data in experiment folder. 109 | """ 110 | # 1. Experiment directory 111 | cfg.dir = DummyClass() 112 | cfg.dir.exp_dir = os.path.join(exp_group_dir, cfg.mode, cfg.task, cfg.feature_type, exp_name) 113 | if empty: 114 | create_empty_folder(cfg.dir.exp_dir) 115 | else: 116 | os.makedirs(cfg.dir.exp_dir, exist_ok=True) 117 | 118 | # 2. config directory 119 | cfg.dir.config_dir = os.path.join(cfg.dir.exp_dir, 'configs') 120 | os.makedirs(cfg.dir.config_dir, exist_ok=True) 121 | 122 | # 3. log directory 123 | cfg.dir.logs_dir = os.path.join(cfg.dir.exp_dir, 'logs') 124 | os.makedirs(cfg.dir.logs_dir, exist_ok=True) 125 | 126 | # 4. tensorboard directory 127 | cfg.dir.tb_dir = os.path.join(cfg.dir.exp_dir, 'tensorboard') 128 | os.makedirs(cfg.dir.tb_dir, exist_ok=True) 129 | 130 | # 5. 
model directory 131 | cfg.dir.model = DummyClass() 132 | # 5.1 model checkpoint 133 | cfg.dir.model.checkpoint = os.path.join(cfg.dir.exp_dir, 'models', 'checkpoint') 134 | os.makedirs(cfg.dir.model.checkpoint, exist_ok=True) 135 | # 5.2 best model 136 | cfg.dir.model.best = os.path.join(cfg.dir.exp_dir, 'models', 'best') 137 | os.makedirs(cfg.dir.model.best, exist_ok=True) 138 | # # 5.3 save all epochs 139 | # cfg.dir.model.epoch = os.path.join(cfg.dir.exp_dir, 'models', 'epoch') 140 | # os.makedirs(cfg.dir.model.epoch, exist_ok=True) 141 | 142 | # 6. output directory 143 | cfg.dir.output_dir = DummyClass() 144 | # 6.1 submission directory 145 | cfg.dir.output_dir.submission = os.path.join(cfg.dir.exp_dir, 'outputs', 'submissions') 146 | os.makedirs(cfg.dir.output_dir.submission, exist_ok=True) 147 | # 6.2 prediction directory 148 | cfg.dir.output_dir.prediction = os.path.join(cfg.dir.exp_dir, 'outputs', 'predictions') 149 | os.makedirs(cfg.dir.output_dir.prediction, exist_ok=True) 150 | 151 | # 7. temporatory output directory to save output during training for inspection 152 | 153 | 154 | class TqdmLoggingHandler(logging.Handler): 155 | """Log consistently when using the tqdm progress bar. 156 | From https://stackoverflow.com/questions/38543506/ 157 | change-logging-print-function-to-tqdm-write-so-logging-doesnt-interfere-wit 158 | """ 159 | 160 | def __init__(self, level=logging.NOTSET): 161 | super().__init__(level) 162 | 163 | def emit(self, record): 164 | try: 165 | msg = self.format(record) 166 | tqdm.write(msg) 167 | self.flush() 168 | except (KeyboardInterrupt, SystemExit): 169 | raise 170 | except: 171 | self.handleError(record) 172 | 173 | 174 | def create_logging(log_dir, filemode='a') -> None: 175 | """ 176 | Initialize logger. 177 | """ 178 | # log_filename 179 | log_filename = os.path.join(log_dir, 'log.txt') 180 | 181 | if not logging.getLogger().hasHandlers(): 182 | # basic config for logging 183 | logging.basicConfig( 184 | level=logging.DEBUG, 185 | format='%(filename)s[line:%(lineno)d] %(levelname)s %(message)s', 186 | datefmt='%a, %d %b %Y %H:%M:%S', 187 | filename=log_filename, 188 | filemode=filemode) 189 | 190 | # Get lightning logger. 191 | logger = logging.getLogger("lightning") 192 | logger.setLevel(logging.INFO) 193 | # Purge old handlers. 
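# (create_logging may be called more than once in the same process; removing previously attached handlers avoids duplicated log lines)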
194 | for old_handler in logger.handlers: 195 | logger.removeHandler(old_handler) 196 | 197 | # create tqdm handler 198 | handler = TqdmLoggingHandler() 199 | handler.setLevel(logging.INFO) 200 | formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') 201 | # handler.setFormatter(formatter) 202 | # add tqdm handler to current logger 203 | logger.addHandler(handler) 204 | 205 | # For normal code without lightning: logger = logging.getLogger('my_logger') 206 | # logger = logging.getLogger('my_logger') 207 | # if not logger.handlers: 208 | # console = logging.StreamHandler() 209 | # console.setLevel(logging.INFO) 210 | # formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') 211 | # console.setFormatter(formatter) 212 | # logging.getLogger('').addHandler(TqdmLoggingHandler(logging.INFO)) 213 | 214 | logger = logging.getLogger("lightning") 215 | logger.info('**********************************************************') 216 | logger.info('****** Start new experiment ******************************') 217 | logger.info('**********************************************************\n') 218 | logger.info('Timestamp: {}'.format(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))) 219 | logger.info('Log file is created in {}.'.format(log_dir)) 220 | 221 | 222 | def write_yaml_config(output_filename, config_dict) -> None: 223 | """ 224 | Write configs to yaml file for reference later. 225 | """ 226 | with open(output_filename, 'w') as outfile: 227 | yaml.dump(config_dict, outfile, default_flow_style=False, sort_keys=True) 228 | 229 | 230 | def set_random_seed(random_seed: int = 2020) -> None: 231 | """ 232 | Set random seed for pytorch, numpy, random. Replaced with pytorch lightning function. 233 | :param random_seed: Random seed. 234 | """ 235 | ''' Reproducible seed set''' 236 | torch.manual_seed(random_seed) 237 | if torch.cuda.is_available(): 238 | torch.cuda.manual_seed(random_seed + 1) 239 | torch.backends.cudnn.deterministic = True 240 | torch.backends.cudnn.benchmark = True 241 | np.random.seed(random_seed + 2) 242 | random.seed(random_seed + 3) 243 | 244 | -------------------------------------------------------------------------------- /utilities/learning_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module consists code to facilitate learning such as: 3 | learning rate schedule 4 | model checkpoint 5 | ... 6 | """ 7 | import logging 8 | import time 9 | from typing import Tuple 10 | 11 | import numpy as np 12 | import pytorch_lightning as pl 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | class LearningRateScheduler(pl.Callback): 18 | def __init__(self, steps_per_epoch, max_epochs: int = 50, milestones: Tuple = (0, 0.45, 0.9, 1.0), 19 | lrs: Tuple = (1e-4, 1e-2, 1e-3, 1e-4), moms: Tuple = (0.9, 0.8, 0.9, 0.9)): 20 | self.steps_per_epoch = steps_per_epoch 21 | self.max_epochs = max_epochs 22 | self.milestones = milestones 23 | self.lrs = lrs 24 | self.moms = moms 25 | self.n_steps = int(self.max_epochs * self.steps_per_epoch) 26 | self.step_milestones = [int(i * self.n_steps) for i in self.milestones] 27 | 28 | def on_train_start(self, trainer, pl_module): 29 | """ 30 | Pytorch lightning hook. 
31 | Set the initial learning rate and momentums (for Adam and AdamW optimizer) 32 | """ 33 | for opt_idx, optimizer in enumerate(trainer.optimizers): 34 | for param_group in optimizer.param_groups: 35 | param_group['lr'] = self.lrs[0] 36 | param_group['betas'] = (self.moms[0], 0.999) 37 | 38 | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): 39 | """ 40 | Pytorch lightning hook. 41 | Set new learning rate. 42 | """ 43 | current_global_step = trainer.current_epoch * self.steps_per_epoch + batch_idx 44 | lr = np.interp(current_global_step, self.step_milestones, self.lrs) 45 | mom = np.interp(current_global_step, self.step_milestones, self.moms) 46 | trainer.logger.log_metrics({'lr': lr}, step=trainer.global_step) # trainer.global_step same as current_global_step 47 | trainer.logger.log_metrics({'momentum': mom}, step=trainer.global_step) 48 | for opt_idx, optimizer in enumerate(trainer.optimizers): 49 | for param_group in optimizer.param_groups: 50 | param_group['lr'] = lr 51 | param_group['betas'] = (mom, 0.999) 52 | 53 | 54 | def count_model_params(model: nn.Module) -> None: 55 | """ 56 | Count and log the number of trainable and total params. 57 | :param model: Pytorch model. 58 | """ 59 | logger = logging.getLogger('lightning') 60 | logger.info('Model architecture: \n{}\n'.format(model)) 61 | logger.info('Model parameters and size:') 62 | for n, (name, param) in enumerate(model.named_parameters()): 63 | logger.info('{}: {}'.format(name, list(param.size()))) 64 | total_params = sum([param.numel() for param in model.parameters()]) 65 | trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad) 66 | logging.info('Total number of parameters: {}'.format(total_params)) 67 | logging.info('Total number of trainable parameters: {}'.format(trainable_params)) 68 | 69 | 70 | class MyLoggingCallback(pl.Callback): 71 | def __init__(self): 72 | self.lit_logger = logging.getLogger('lightning') 73 | self.train_start_time = None 74 | self.train_end_time = None 75 | self.val_start_time = None 76 | self.val_end_time = None 77 | self.fit_start_time = None 78 | self.fit_end_time = None 79 | self.test_start_time = None 80 | self.test_end_time = None 81 | 82 | def on_init_start(self, trainer): 83 | self.lit_logger.info('Start initiating trainer!') 84 | 85 | def on_init_end(self, trainer): 86 | self.lit_logger.info('Finish initiating trainer.') 87 | 88 | def on_fit_start(self, trainer, pl_module): 89 | self.lit_logger.info('Start training...') 90 | self.fit_start_time = time.time() 91 | 92 | def on_fit_end(self, trainer, pl_module): 93 | self.lit_logger.info('Finish training!') 94 | self.fit_end_time = time.time() 95 | duration = self.fit_end_time - self.fit_start_time 96 | self.lit_logger.info('Total training time: {} s'.format(time.strftime('%H:%M:%S', time.gmtime(duration)))) 97 | 98 | def on_test_start(self, trainer, pl_module): 99 | self.lit_logger.info('Start testing ...') 100 | self.test_start_time = time.time() 101 | 102 | def on_test_end(self, trainer, pl_module): 103 | self.lit_logger.info('Finish testing!') 104 | self.test_end_time = time.time() 105 | duration = self.test_end_time - self.test_start_time 106 | self.lit_logger.info('Total testing time: {} s'.format(time.strftime('%H:%M:%S', time.gmtime(duration)))) 107 | 108 | -------------------------------------------------------------------------------- /utilities/plot_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This 
module includes helper functions for plotting. 3 | """ 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | 8 | def plot_image(image, figsize=None, title=None, xlim=None, ylim=None, extent=None, colorbar=False, cmap=None): 9 | if cmap is None: 10 | cmap = 'nipy_spectral' 11 | if figsize is None: 12 | figsize = (9, 6) 13 | fig = plt.figure(figsize=figsize) 14 | if extent is None: 15 | plt.imshow(image, origin='lower', cmap=cmap) # 'hot' 'nipy_spectral' 16 | else: 17 | ax = fig.add_subplot(1, 1, 1) 18 | ax.imshow(image, origin='lower', cmap=cmap, extent=extent) 19 | if title: 20 | plt.title(title) 21 | if xlim is not None: 22 | plt.xlim(xlim) 23 | if ylim is not None: 24 | plt.ylim(ylim) 25 | if colorbar: 26 | plt.colorbar(orientation="horizontal") 27 | plt.show() 28 | 29 | 30 | def plot_images(images, figsize=None, title_list=None, xlim=None, ylim=None, extend=None, colorbar=False, cmap=None): 31 | if cmap is None: 32 | cmap = 'nipy_spectral' 33 | n_images = len(images) 34 | if figsize is None: 35 | figwidth = np.min((n_images * 6, 24)) 36 | figsize = (figwidth, 6) 37 | fig, axes = plt.subplots(1, n_images, figsize=figsize) 38 | for count, i_image in enumerate(images): 39 | if extend is None: 40 | im = axes[count].imshow(i_image, origin='lower', cmap=cmap) 41 | else: 42 | im = axes[count].imshow(i_image, origin='lower', cmap=cmap, extend=extend) 43 | if title_list is not None: 44 | axes[count].set_title(title_list[count]) 45 | if xlim is not None: 46 | axes[count].set_xlim(xlim) 47 | if ylim is not None: 48 | axes[count].set_ylim(ylim) 49 | if colorbar: 50 | fig.colorbar(im, ax=axes[count]) 51 | plt.show() 52 | 53 | 54 | def plot_graph(y, figsize=None, title=None, xlim=None, ylim=None): 55 | if figsize is None: 56 | figsize = (9, 6) 57 | plt.figure(figsize=figsize) 58 | plt.plot(y) 59 | if title: 60 | plt.title(title) 61 | if xlim is not None: 62 | plt.xlim(xlim) 63 | else: 64 | plt.xlim([0, len(y)]) 65 | if ylim is not None: 66 | plt.ylim(ylim) 67 | plt.show() 68 | 69 | 70 | def plot_graphs(graph_list, title_list=[], xlim=None, ylim=None): 71 | n_graphs = len(graph_list) 72 | fig_depth = 2 * n_graphs 73 | # fig_depth = np.max((4, fig_depth)) 74 | fig_depth = np.min((24, fig_depth)) 75 | plt.figure(figsize=(12, fig_depth)) 76 | for i in range(n_graphs): 77 | plt.subplot(n_graphs, 1, i + 1) 78 | plt.plot(graph_list[i]) 79 | if title_list: 80 | plt.title(title_list[i]) 81 | if xlim is not None: 82 | plt.xlim(xlim) 83 | if ylim is not None: 84 | plt.ylim(ylim) 85 | plt.tight_layout() 86 | plt.show() -------------------------------------------------------------------------------- /utilities/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module includes code to do data augmentation in STFT domain on numpy array: 3 | 1. random volume 4 | 2. random cutout 5 | 3. spec augment 6 | 4. freq shift 7 | 5. TTA: test time augmentation 8 | """ 9 | import random 10 | from typing import Tuple 11 | 12 | import numpy as np 13 | import torch 14 | 15 | 16 | class ComposeTransformNp: 17 | """ 18 | Compose a list of data augmentation on numpy array. 19 | """ 20 | def __init__(self, transforms: list): 21 | self.transforms = transforms 22 | 23 | def __call__(self, x: np.ndarray): 24 | for transform in self.transforms: 25 | x = transform(x) 26 | return x 27 | 28 | 29 | class DataAugSpectrogramNp: 30 | """ 31 | Base class for data augmentation for audio spectrogram of numpy array. 
This class does not alter label 32 | """ 33 | def __init__(self, always_apply: bool = False, p: float = 0.5): 34 | self.always_apply = always_apply 35 | self.p = p 36 | 37 | def __call__(self, x: np.ndarray): 38 | if self.always_apply: 39 | return self.apply(x) 40 | else: 41 | if np.random.rand() < self.p: 42 | return self.apply(x) 43 | else: 44 | return x 45 | 46 | def apply(self, x: np.ndarray): 47 | raise NotImplementedError 48 | 49 | 50 | class RandomCutoutNp(DataAugSpectrogramNp): 51 | """ 52 | This data augmentation randomly cutout a rectangular area from the input image. Tested. 53 | """ 54 | def __init__(self, always_apply: bool = False, p: float = 0.5, image_aspect_ratio: float = 1, 55 | random_value: float = None): 56 | """ 57 | :param always_apply: If True, always apply transform. 58 | :param p: If always_apply is false, p is the probability to apply transform. 59 | :param image_aspect_ratio: height/width ratio. For spectrogram: n_time_steps/ n_features. 60 | :param random_value: random value to fill in the cutout area. If None, randomly fill the cutout area with value 61 | between min and max of input. 62 | """ 63 | super().__init__(always_apply, p) 64 | self.random_value = random_value 65 | # Params: s: area, r: height/width ratio. 66 | self.s_l = 0.02 67 | self.s_h = 0.3 68 | self.r_1 = 0.3 69 | self.r_2 = 1 / 0.3 70 | if image_aspect_ratio > 1: 71 | self.r_1 = self.r_1 * image_aspect_ratio 72 | elif image_aspect_ratio < 1: 73 | self.r_2 = self.r_2 * image_aspect_ratio 74 | 75 | def apply(self, x: np.ndarray) -> np.ndarray: 76 | """ 77 | :param x: <(n_channels, n_time_steps, n_features) or (n_time_steps, n_features)>: input spectrogram. 78 | :return: random cutout x 79 | """ 80 | # get image size 81 | image_dim = x.ndim 82 | img_h = x.shape[-2] # time frame dimension 83 | img_w = x.shape[-1] # feature dimension 84 | min_value = np.min(x) 85 | max_value = np.max(x) 86 | # Initialize output 87 | output_img = x.copy() 88 | # random erase 89 | s = np.random.uniform(self.s_l, self.s_h) * img_h * img_w 90 | r = np.random.uniform(self.r_1, self.r_2) 91 | w = np.min((int(np.sqrt(s / r)), img_w - 1)) 92 | h = np.min((int(np.sqrt(s * r)), img_h - 1)) 93 | left = np.random.randint(0, img_w - w) 94 | top = np.random.randint(0, img_h - h) 95 | if self.random_value is None: 96 | c = np.random.uniform(min_value, max_value) 97 | else: 98 | c = self.random_value 99 | if image_dim == 2: 100 | output_img[top:top + h, left:left + w] = c 101 | else: 102 | output_img[:, top:top + h, left:left + w] = c 103 | 104 | return output_img 105 | 106 | 107 | class SpecAugmentNp(DataAugSpectrogramNp): 108 | """ 109 | This data augmentation randomly remove horizontal or vertical strips from image. Tested 110 | """ 111 | def __init__(self, always_apply: bool = False, p: float = 0.5, time_max_width: int = None, 112 | freq_max_width: int = None, n_time_stripes: int = 1, n_freq_stripes: int = 1): 113 | """ 114 | :param always_apply: If True, always apply transform. 115 | :param p: If always_apply is false, p is the probability to apply transform. 116 | :param time_max_width: maximum time width to remove. 117 | :param freq_max_width: maximum freq width to remove. 118 | :param n_time_stripes: number of time stripes to remove. 119 | :param n_freq_stripes: number of freq stripes to remove. 
120 | """ 121 | super().__init__(always_apply, p) 122 | self.time_max_width = time_max_width 123 | self.freq_max_width = freq_max_width 124 | self.n_time_stripes = n_time_stripes 125 | self.n_freq_stripes = n_freq_stripes 126 | 127 | def apply(self, x: np.ndarray) -> np.ndarray: 128 | """ 129 | :param x: <(n_channels, n_time_steps, n_features)>: input spectrogram. 130 | :return: augmented spectrogram. 131 | """ 132 | assert x.ndim == 3, 'Error: dimension of input spectrogram is not 3!' 133 | n_frames = x.shape[1] 134 | n_freqs = x.shape[2] 135 | min_value = np.min(x) 136 | max_value = np.max(x) 137 | if self.time_max_width is None: 138 | time_max_width = int(0.15 * n_frames) 139 | else: 140 | time_max_width = self.time_max_width 141 | time_max_width = np.max((1, time_max_width)) 142 | if self.freq_max_width is None: 143 | freq_max_width = int(0.2 * n_freqs) 144 | else: 145 | freq_max_width = self.freq_max_width 146 | freq_max_width = np.max((1, freq_max_width)) 147 | 148 | new_spec = x.copy() 149 | 150 | for i in np.arange(self.n_time_stripes): 151 | dur = np.random.randint(1, time_max_width, 1)[0] 152 | start_idx = np.random.randint(0, n_frames - dur, 1)[0] 153 | random_value = np.random.uniform(min_value, max_value, 1) 154 | new_spec[:, start_idx:start_idx + dur, :] = random_value 155 | 156 | for i in np.arange(self.n_freq_stripes): 157 | dur = np.random.randint(1, freq_max_width, 1)[0] 158 | start_idx = np.random.randint(0, n_freqs - dur, 1)[0] 159 | random_value = np.random.uniform(min_value, max_value, 1) 160 | new_spec[:, :, start_idx:start_idx + dur] = random_value 161 | 162 | return new_spec 163 | 164 | 165 | class RandomCutoutHoleNp(DataAugSpectrogramNp): 166 | """ 167 | This data augmentation randomly cutout a few small holes in the spectrogram. Tested. 168 | """ 169 | def __init__(self, always_apply: bool = False, p: float = 0.5, n_max_holes: int = 8, max_h_size: int = 8, 170 | max_w_size: int = 8, filled_value: float = None): 171 | """ 172 | :param always_apply: If True, always apply transform. 173 | :param p: If always_apply is false, p is the probability to apply transform. 174 | :param n_max_holes: Maximum number of holes to cutout. 175 | :param max_h_size: Maximum time frames of the cutout holes. 176 | :param max_w_size: Maximum freq bands of the cutout holes. 177 | :param filled_value: random value to fill in the cutout area. If None, randomly fill the cutout area with value 178 | between min and max of input. 179 | """ 180 | super().__init__(always_apply, p) 181 | self.n_max_holes = n_max_holes 182 | self.max_h_size = np.max((max_h_size, 5)) 183 | self.max_w_size = np.max((max_w_size, 5)) 184 | self.filled_value = filled_value 185 | 186 | def apply(self, x: np.ndarray): 187 | """ 188 | :param x: <(n_channels, n_time_steps, n_features)>: input spectrogram. 189 | :return: augmented spectrogram. 190 | """ 191 | assert x.ndim == 3, 'Error: dimension of input spectrogram is not 3!' 
192 | img_h = x.shape[-2] # time frame dimension 193 | img_w = x.shape[-1] # feature dimension 194 | min_value = np.min(x) 195 | max_value = np.max(x) 196 | new_spec = x.copy() 197 | # n_cutout_holes = np.random.randint(1, self.n_max_holes, 1)[0] 198 | n_cutout_holes = self.n_max_holes 199 | for ihole in np.arange(n_cutout_holes): 200 | # w = np.random.randint(4, self.max_w_size, 1)[0] 201 | # h = np.random.randint(4, self.max_h_size, 1)[0] 202 | w = self.max_w_size 203 | h = self.max_h_size 204 | left = np.random.randint(0, img_w - w) 205 | top = np.random.randint(0, img_h - h) 206 | if self.filled_value is None: 207 | new_spec[:, top:top + h, left:left + w] = np.random.uniform(min_value, max_value) 208 | else: 209 | new_spec[:, top:top + h, left:left + w] = self.filled_value 210 | 211 | return new_spec 212 | 213 | 214 | class CompositeCutout(DataAugSpectrogramNp): 215 | """ 216 | This data augmentation combine Random cutout, specaugment, cutout hole. 217 | """ 218 | def __init__(self, always_apply: bool = False, p: float = 0.5, image_aspect_ratio: float = 1): 219 | super().__init__(always_apply, p) 220 | self.random_cutout = RandomCutoutNp(always_apply=True, image_aspect_ratio=image_aspect_ratio) 221 | self.spec_augment = SpecAugmentNp(always_apply=True) 222 | self.random_cutout_hole = RandomCutoutHoleNp(always_apply=True) 223 | 224 | def apply(self, x: np.ndarray): 225 | choice = np.random.randint(0, 3, 1)[0] 226 | if choice == 0: 227 | return self.random_cutout(x) 228 | elif choice == 1: 229 | return self.spec_augment(x) 230 | elif choice == 2: 231 | return self.random_cutout_hole(x) 232 | 233 | 234 | class RandomFlipLeftRightNp(DataAugSpectrogramNp): 235 | """ 236 | This data augmentation randomly flip spectrogram left and right. 237 | """ 238 | def __init__(self, always_apply: bool = False, p: float = 0.5): 239 | super().__init__(always_apply, p) 240 | 241 | def apply(self, x: np.ndarray): 242 | """ 243 | :param x < np.ndarray (n_channels, n_time_steps, n_features) 244 | :return: 245 | """ 246 | new_x = x.copy() 247 | for ichan in np.arange(x.shape[0]): 248 | new_x[ichan, :] = np.flip(x[ichan, :], axis=0) 249 | return new_x 250 | 251 | 252 | class AdditiveGaussianNoiseNp(DataAugSpectrogramNp): 253 | """ 254 | This data augmentation add gaussian noise to spectrogram. Assume spectrograms are mean-var normalzied. 255 | """ 256 | def __init__(self, always_apply: bool = False, p: float = 0.5): 257 | super().__init__(always_apply, p) 258 | 259 | def apply(self, x: np.ndarray): 260 | """ 261 | :param x < np.ndarray (n_channels, n_time_steps, n_features) 262 | """ 263 | n_frames, n_features = x.shape[1], x.shape[2] 264 | jitter_std = np.random.uniform(0.05, 0.2, 1) 265 | jitter = np.random.normal(0, jitter_std, size=(n_frames, n_features)).astype(np.float32) 266 | new_spec = x.copy() 267 | new_spec = new_spec + jitter 268 | return new_spec 269 | 270 | 271 | class MultiplicativeGaussianNoiseNp(DataAugSpectrogramNp): 272 | """ 273 | This data augmentation multiply gaussian noise to spectrogram. Assume spectrograms are mean-var normalzied. 
274 | """ 275 | def __init__(self, always_apply: bool = False, p: float = 0.5): 276 | super().__init__(always_apply, p) 277 | 278 | def apply(self, x: np.ndarray): 279 | """ 280 | :param x < np.ndarray (n_channels, n_time_steps, n_features) 281 | """ 282 | n_frames, n_features = x.shape[1], x.shape[2] 283 | jitter_std = np.random.uniform(0.01, 0.1, 1) 284 | jitter = np.random.normal(0, jitter_std, size=(n_frames, n_features)).astype(np.float32) + 1 285 | new_spec = x.copy() 286 | new_spec = new_spec * jitter 287 | return new_spec 288 | 289 | 290 | class CosineGaussianNoiseNp(DataAugSpectrogramNp): 291 | """ 292 | This data augmentation add/multiply Gaussian noise whose power has sinusoidal pattern over time. 293 | """ 294 | def __init__(self, always_apply: bool = False, p: float = 0.5): 295 | super().__init__(always_apply, p) 296 | 297 | def apply(self, x: np.ndarray): 298 | """ 299 | :param x < np.ndarray (n_channels, n_time_steps, n_features) 300 | """ 301 | n_frames, n_features = x.shape[1], x.shape[2] 302 | jitter_std = np.random.uniform(0.01, 0.1, 1) 303 | jitter = np.random.normal(0, jitter_std, size=(n_frames, n_features)).astype(np.float32) 304 | cosine = np.cos(np.arange(n_frames) / n_frames * np.pi * 2).astype(np.float32) 305 | jitter = jitter * cosine[:, None] + 1 306 | new_spec = x.copy() 307 | new_spec = new_spec * jitter 308 | return new_spec 309 | 310 | 311 | class CompositeGaussianNoiseNp(DataAugSpectrogramNp): 312 | """ 313 | This data augmentation randomly select different method to add gaussian noise to the spectrograms 314 | """ 315 | def __init__(self, always_apply: bool = False, p: float = 0.5): 316 | super().__init__(always_apply, p) 317 | self.agauss = AdditiveGaussianNoiseNp(always_apply=True) 318 | self.mgauss = MultiplicativeGaussianNoiseNp(always_apply=True) 319 | self.cgauss = CosineGaussianNoiseNp(always_apply=True) 320 | 321 | def apply(self, x: np.ndarray): 322 | choice = np.random.randint(0, 2, 1)[0] 323 | if choice == 0: 324 | return self.agauss(x) 325 | elif choice == 1: 326 | return self.mgauss(x) 327 | 328 | 329 | class RandomVolumeNp(DataAugSpectrogramNp): 330 | """ 331 | This data augmentation randomly increase or decrease volume 332 | Reference from: https://github.com/koukyo1994/kaggle-birdcall-6th-place/blob/master/src/transforms.py 333 | """ 334 | def __init__(self, always_apply: bool = False, p: float = 0.5, limit=3): 335 | super().__init__(always_apply, p) 336 | self.limit = limit 337 | 338 | def apply(self, x: np.ndarray): 339 | db = np.random.uniform(-self.limit, self.limit) 340 | new_spec = x.copy() 341 | new_spec = new_spec * db2float(db, amplitude=False) 342 | return new_spec 343 | 344 | 345 | class CosineVolumeNp(DataAugSpectrogramNp): 346 | """ 347 | This data augmentation change volume in cosine pattern 348 | Reference from: https://github.com/koukyo1994/kaggle-birdcall-6th-place/blob/master/src/transforms.py 349 | """ 350 | def __init__(self, always_apply=False, p=0.5, limit=3): 351 | super().__init__(always_apply, p) 352 | self.limit = limit 353 | 354 | def apply(self, x: np.ndarray): 355 | """ 356 | :param x < np.ndarray (n_channels, n_time_steps, n_features) 357 | """ 358 | db = np.random.uniform(-self.limit, self.limit) 359 | n_time_steps = x.shape[1] 360 | cosine = np.cos(np.arange(n_time_steps) / n_time_steps * np.pi * 2) 361 | dbs = db2float(cosine * db) 362 | new_spec = x.copy() 363 | return new_spec * dbs[None, :, None] 364 | 365 | 366 | def db2float(db: float, amplitude=True): 367 | """Function to convert dB to float""" 368 | 
if amplitude: 369 | return 10**(db / 20) 370 | else: 371 | return 10 ** (db / 10) 372 | 373 | 374 | class CompositeVolumeNp(DataAugSpectrogramNp): 375 | """ 376 | This data augmentation randomly select different method to add gaussian noise to the spectrograms 377 | """ 378 | def __init__(self, always_apply: bool = False, p: float = 0.5): 379 | super().__init__(always_apply, p) 380 | self.random_volume = RandomVolumeNp(always_apply=True) 381 | self.cosine_volume = CosineVolumeNp(always_apply=True) 382 | 383 | def apply(self, x: np.ndarray): 384 | choice = np.random.randint(0, 2, 1)[0] 385 | if choice == 0: 386 | return self.random_volume(x) 387 | elif choice == 1: 388 | return self.cosine_volume(x) 389 | 390 | 391 | class RandomRotateNp(DataAugSpectrogramNp): 392 | """ 393 | This data augmentation rotate spectrogram along time dimension. 394 | """ 395 | def __init__(self, always_apply=False, p=0.5, max_length: int = None, direction: str = None): 396 | super().__init__(always_apply, p) 397 | self.max_length = max_length 398 | self.direction = direction 399 | 400 | def apply(self, x: np.ndarray): 401 | n_channels, n_timesteps, n_features = x.shape 402 | if self.max_length is None: 403 | self.max_length = int(n_timesteps * 0.1) 404 | rotate_len = np.random.randint(1, self.max_length, 1)[0] 405 | new_spec = x.copy() 406 | if self.direction is None: 407 | direction = np.random.choice(['left', 'right'], 1)[0] 408 | else: 409 | direction = self.direction 410 | if direction == 'left': # rotate left 411 | new_spec[:, 0:-rotate_len, :] = x[:, rotate_len:, :] 412 | new_spec[:, -rotate_len:, ] = x[:, 0: rotate_len, :] 413 | else: # rotate right 414 | new_spec[:, rotate_len:, :] = x[:, 0: -rotate_len, :] 415 | new_spec[:, 0: rotate_len, :] = x[:, -rotate_len:, :] 416 | return new_spec 417 | 418 | 419 | class RandomShiftUpDownNp(DataAugSpectrogramNp): 420 | """ 421 | This data augmentation random shift the spectrogram up or down. 422 | """ 423 | def __init__(self, always_apply=False, p=0.5, freq_shift_range: int = None, direction: str = None, mode='reflect'): 424 | super().__init__(always_apply, p) 425 | self.freq_shift_range = freq_shift_range 426 | self.direction = direction 427 | self.mode = mode 428 | 429 | def apply(self, x: np.ndarray): 430 | n_channels, n_timesteps, n_features = x.shape 431 | if self.freq_shift_range is None: 432 | self.freq_shift_range = int(n_features * 0.08) 433 | shift_len = np.random.randint(1, self.freq_shift_range, 1)[0] 434 | if self.direction is None: 435 | direction = np.random.choice(['up', 'down'], 1)[0] 436 | else: 437 | direction = self.direction 438 | new_spec = x.copy() 439 | if direction == 'up': 440 | new_spec = np.pad(new_spec, ((0, 0), (0, 0), (shift_len, 0)), mode=self.mode)[:, :, 0: n_features] 441 | else: 442 | new_spec = np.pad(new_spec, ((0, 0), (0, 0), (0, shift_len)), mode=self.mode)[:, :, shift_len:] 443 | return new_spec 444 | 445 | --------------------------------------------------------------------------------
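Usage note (not part of the repository): the augmentation classes in `utilities/transforms.py` above are designed to be chained with `ComposeTransformNp`, each firing with its own probability `p`. The sketch below is a minimal, illustrative example under the assumption that the `utilities` package is importable from the project root and that features are numpy arrays shaped `(n_channels, n_time_steps, n_features)`, as the docstrings state; the array sizes and the `train_transform` name are made up for illustration.

```python
# Minimal usage sketch (illustrative, not from the repository).
import numpy as np

from utilities.transforms import (
    ComposeTransformNp,
    CompositeCutout,
    RandomShiftUpDownNp,
    RandomVolumeNp,
)

# Build a training-time augmentation chain; each transform is applied with probability p.
train_transform = ComposeTransformNp([
    RandomVolumeNp(p=0.5, limit=3),                 # random gain in dB
    CompositeCutout(p=0.5, image_aspect_ratio=2),   # one of: random cutout / specaugment / cutout holes
    RandomShiftUpDownNp(p=0.5),                     # shift along the feature (frequency) axis
])

# Fake spectrogram: 4 channels, 300 time frames, 128 feature bins (illustrative sizes).
x = np.random.randn(4, 300, 128).astype(np.float32)
x_aug = train_transform(x)
print(x_aug.shape)  # (4, 300, 128) -- all transforms above preserve the input shape
```

All of these transforms operate on numpy arrays before the features are converted to tensors, so the chain would typically be applied inside the dataset's `__getitem__` for the training split only.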
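Similarly, the `LearningRateScheduler` callback in `utilities/learning_utils.py` above implements a piecewise-linear (one-cycle style) schedule: learning rate and momentum are interpolated with `np.interp` between fractional milestones of the total step budget. The sketch below reproduces that interpolation outside of PyTorch Lightning, using the callback's default milestones; `steps_per_epoch = 100` is an illustrative value, and the comment about `param_group` mirrors what the callback does for Adam-style optimizers.

```python
# Standalone sketch of the schedule computed by LearningRateScheduler (illustrative).
import numpy as np

# Defaults from the callback: warm up to 1e-2 at 45% of training,
# decay to 1e-3 at 90%, then back to 1e-4 at the end.
steps_per_epoch, max_epochs = 100, 50          # steps_per_epoch is an assumed value
milestones = (0.0, 0.45, 0.9, 1.0)
lrs = (1e-4, 1e-2, 1e-3, 1e-4)
moms = (0.9, 0.8, 0.9, 0.9)

n_steps = int(max_epochs * steps_per_epoch)
step_milestones = [int(m * n_steps) for m in milestones]

for step in (0, 1000, 2250, 4500, 5000):
    lr = np.interp(step, step_milestones, lrs)
    mom = np.interp(step, step_milestones, moms)
    # In the callback, these values are written into every param_group as
    # param_group['lr'] = lr and param_group['betas'] = (mom, 0.999).
    print(f"step {step:5d}: lr={lr:.2e}, momentum={mom:.3f}")
```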