├── README.md
├── appendixes
│   ├── fold1_plot.png
│   ├── results.png
│   ├── split1_ir0_ov1_1_pred.png
│   └── split1_ir0_ov1_1_ref.png
├── evaluation_tools
│   ├── cls_feature_class.py
│   └── evaluation_metrics.py
├── licenses
│   ├── MIT_LICENSE.md
│   └── TUT_LICENSE.md
├── pytorch
│   ├── evaluate.py
│   ├── losses.py
│   ├── main.py
│   ├── models.py
│   └── pytorch_utils.py
├── runme.sh
└── utils
    ├── config.py
    ├── data_generator.py
    ├── features.py
    ├── plot_results.py
    └── utilities.py

/README.md:
--------------------------------------------------------------------------------
# DCASE 2019 Task 3 Sound Event Localization and Detection

DCASE 2019 Task 3, Sound Event Localization and Detection (SELD), is the task of jointly recognizing individual sound events, detecting their temporal onset and offset times, and localizing them in space. A more detailed description of the task can be found at http://dcase.community/challenge2019/task-sound-event-localization-and-detection.

## DATASET
The dataset can be downloaded from http://dcase.community/challenge2019/task-sound-event-localization-and-detection. It contains 400 one-minute audio recordings sampled at 48 kHz. Each recording is provided in two formats, First-Order Ambisonics (FOA) and microphone array (MIC), and both formats have 4 channels. The recordings contain synthetic polyphonic sound events drawn from 11 sound event classes.

The statistics of the data are shown below:

|      |     Attributes    | Dev. recordings | Eval. recordings |
|:----:|:-----------------:|:---------------:|:----------------:|
| Data | FOA & MIC, 48 kHz |       400       |         -        |

The log mel spectrogram of a recording is shown below:

## Run the code

**0. Prepare data**

Download and unzip the data. The directory layout looks like:

```
dataset_root
├── metadata_dev (400 files)
│    ├── split1_ir0_ov1_10.csv
│    └── ...
├── foa_dev (400 files)
│    ├── split1_ir0_ov1_10.wav
│    └── ...
├── mic_dev (400 files)
│    ├── split1_ir0_ov1_10.wav
│    └── ...
└── ...
```
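Each metadata CSV annotates one recording. As a rough illustration of how those annotations become frame-level targets, the sketch below mirrors the SED-label logic of `read_desc_file` and `_get_se_labels` in `evaluation_tools/cls_feature_class.py`: a 20 ms hop (50 frames per second), onsets rounded down and offsets rounded up to frames, offsets clipped to the 3000-frame (60 s) length, and an inclusive end frame. The helper name `csv_to_sed_labels` and the header row in the example are hypothetical; only the column order (class, onset, offset, elevation, azimuth) and the class indices are taken from that code.

```python
import csv
import io
import math

import numpy as np

# Class-to-index map copied from FeatureClass._unique_classes in cls_feature_class.py.
CLASSES = {
    'knock': 0, 'drawer': 1, 'clearthroat': 2, 'phone': 3, 'keysDrop': 4,
    'speech': 5, 'keyboard': 6, 'pageturn': 7, 'cough': 8, 'doorslam': 9,
    'laughter': 10,
}
FRAME_RES = 50      # frames per second: 48000 Hz / 960-sample (20 ms) hop
MAX_FRAMES = 3000   # 60 s of audio at 50 frames per second


def csv_to_sed_labels(csv_text, nb_frames=MAX_FRAMES):
    """Hypothetical helper: metadata CSV text -> [nb_frames, nb_classes] 0/1 activity matrix.

    Assumes the column order used by read_desc_file: class, onset (s), offset (s), ele, azi.
    """
    labels = np.zeros((nb_frames, len(CLASSES)))
    reader = csv.reader(io.StringIO(csv_text))
    next(reader)  # skip the header row, as read_desc_file does
    for row in reader:
        event_class, onset_s, offset_s = row[0], float(row[1]), float(row[2])
        start = int(math.floor(onset_s * FRAME_RES))                 # onset rounded down to a frame
        end = min(int(math.ceil(offset_s * FRAME_RES)), nb_frames)   # offset rounded up, then clipped
        labels[start:end + 1, CLASSES[event_class]] = 1               # end frame is inclusive
    return labels


# The header names here are made up; only the column order matters.
example = "class,start,end,ele,azi\nknock,1.0,2.5,10,-40\nspeech,0.0,0.5,0,0\n"
labels = csv_to_sed_labels(example)
print(labels.shape)              # (3000, 11)
print(int(labels[:, 0].sum()))   # 76 active frames for 'knock' (frames 50..125)
```

The DOA targets follow the same frame ranges, with the azimuth and elevation written into the active frames instead of a 1.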
**1. Requirements**

Python 3.6 + PyTorch 1.0

**2. Then simply run:**

$ ./runme.sh

Or run the commands in runme.sh line by line. The commands include:

(1) Modify the paths of the dataset and your workspace

(2) Extract features

(3) Train the model

(4) Inference

## Model
We apply convolutional neural networks (CNNs) to the log mel spectrograms of the 4-channel audio. The targets are the onset and offset times, elevation, and azimuth of the sound events. With a 9-layer CNN and a mini-batch size of 32, training takes approximately 200 ms per iteration on a single GTX Titan Xp GPU, so the 5000 training iterations take roughly 1000 s in total (excluding validation). The training log looks like:
```
Load data time: 90.292 s
Training audio num: 300
Validation audio num: 100
------------------------------------
...
------------------------------------
iteration: 5000
train statistics:    total_loss: 0.076, event_loss: 0.007, position_loss: 0.069
    Total 10 files written to /vol/vssp/msos/qk/workspaces/dcase2019_task3/_temp/submissions/main/Cnn_9layers_foa_dev_logmel_64frames_64melbins
    sed_error_rate :     0.057
    sed_f1_score :       0.971
    doa_error :          8.902
    doa_frame_recall :   0.966
    seld_score :         0.042
validate statistics:  total_loss: 0.449, event_loss: 0.039, position_loss: 0.409
    Total 10 files written to /vol/vssp/msos/qk/workspaces/dcase2019_task3/_temp/submissions/main/Cnn_9layers_foa_dev_logmel_64frames_64melbins
    sed_error_rate :     0.206
    sed_f1_score :       0.875
    doa_error :          33.374
    doa_frame_recall :   0.894
    seld_score :         0.156
train time: 20.135 s, validate time: 7.023 s
Model saved to /vol/vssp/msos/qk/workspaces/dcase2019_task3/models/main/Cnn_9layers_foa_dev_logmel_64frames_64melbins/holdout_fold=1/md_5000_iters.pth
------------------------------------
...
```
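The `seld_score` in the log is consistent with the DCASE 2019 Task 3 combined metric: the mean of the SED error rate, 1 - F1 score, the DOA error divided by 180 degrees, and 1 - DOA frame recall, so lower is better. The sketch below reproduces both iteration-5000 values from the log; the function name `seld_score` is illustrative and not taken from this repository.

```python
def seld_score(sed_error_rate, sed_f1_score, doa_error_deg, doa_frame_recall):
    """Mean of the four DCASE 2019 Task 3 metrics, each oriented so that 0 is best."""
    return (sed_error_rate
            + (1.0 - sed_f1_score)
            + doa_error_deg / 180.0  # DOA error in degrees, scaled to [0, 1]
            + (1.0 - doa_frame_recall)) / 4.0


# Values copied from the iteration-5000 training log.
print(round(seld_score(0.057, 0.971, 8.902, 0.966), 3))   # train: 0.042
print(round(seld_score(0.206, 0.875, 33.374, 0.894), 3))  # validate: 0.156
```

Because the DOA error is normalized by its maximum of 180 degrees, a model that localizes poorly is penalized on the same scale as one that detects poorly.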
## Results

**Validation result on 400 audio files**

The 9-layer CNN achieves slightly better results than the other CNNs. The baseline system result is from [2], which uses phase information as an extra input and obtains a better DOA result. Our system uses only the log mel spectrogram magnitude as input, without phase.

**Results over different iterations**

The 5-layer and 9-layer CNNs achieve similar results. The 13-layer CNN tends to overfit.

**Visualization of the prediction**

We are able to predict the DOA using only the log mel spectrogram magnitude as input.

## Summary
This codebase provides a convolutional neural network (CNN) system for DCASE 2019 Challenge Task 3, Sound Event Localization and Detection.

## Citation

**If this codebase is helpful, please feel free to cite the following paper:**

**[1] Qiuqiang Kong, Yin Cao, Turab Iqbal, Yong Xu, Wenwu Wang, Mark D. Plumbley. Cross-task learning for audio tagging, sound event detection and spatial localization: DCASE 2019 baseline systems. arXiv preprint arXiv:1904.03476 (2019).**

## FAQ
If you run out of GPU memory, try reducing batch_size.

## License
The file evaluation_tools/cls_feature_class.py is under TUT_LICENSE.

All other files are under MIT_LICENSE.

## External links

[2] https://github.com/sharathadavanne/seld-dcase2019

[3] http://dcase.community/challenge2019/task-audio-tagging

--------------------------------------------------------------------------------
/appendixes/fold1_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiuqiangkong/dcase2019_task3/aa40091cd9ce49149201634c3a8da2fc01ffd67c/appendixes/fold1_plot.png

--------------------------------------------------------------------------------
/appendixes/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiuqiangkong/dcase2019_task3/aa40091cd9ce49149201634c3a8da2fc01ffd67c/appendixes/results.png

--------------------------------------------------------------------------------
/appendixes/split1_ir0_ov1_1_pred.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiuqiangkong/dcase2019_task3/aa40091cd9ce49149201634c3a8da2fc01ffd67c/appendixes/split1_ir0_ov1_1_pred.png

--------------------------------------------------------------------------------
/appendixes/split1_ir0_ov1_1_ref.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiuqiangkong/dcase2019_task3/aa40091cd9ce49149201634c3a8da2fc01ffd67c/appendixes/split1_ir0_ov1_1_ref.png

--------------------------------------------------------------------------------
/evaluation_tools/cls_feature_class.py:
--------------------------------------------------------------------------------
# Contains routines for labels creation, features extraction and normalization
#


import os
import numpy as np
import scipy.io.wavfile as wav
from sklearn import preprocessing
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
import matplotlib.pyplot as plot
import librosa
# plot.switch_backend('agg')


class FeatureClass:
    def __init__(self, dataset_dir='', feat_label_dir='', dataset='foa', is_eval=False):
        """

        :param dataset: string, dataset name, supported: foa - ambisonic or mic- microphone format
        :param is_eval: if True, does not load dataset labels.
        """

        # Input directories
        self._feat_label_dir = feat_label_dir
        self._dataset_dir = dataset_dir
        self._dataset_combination = '{}_{}'.format(dataset, 'eval' if is_eval else 'dev')
        self._aud_dir = os.path.join(self._dataset_dir, self._dataset_combination)

        self._desc_dir = None if is_eval else os.path.join(self._dataset_dir, 'metadata_dev')

        # Output directories
        self._label_dir = None
        self._feat_dir = None
        self._feat_dir_norm = None

        # Local parameters
        self._is_eval = is_eval

        self._fs = 48000
        self._hop_len_s = 0.02
        self._hop_len = int(self._fs * self._hop_len_s)
        self._frame_res = self._fs / float(self._hop_len)
        self._nb_frames_1s = int(self._frame_res)

        self._win_len = 2 * self._hop_len
        self._nfft = self._next_greater_power_of_2(self._win_len)

        self._dataset = dataset
        self._eps = np.spacing(1e-16)  # np.float was removed in NumPy 1.24
        self._nb_channels = 4

        # Sound event classes dictionary # DCASE 2016 Task 2 sound events
        self._unique_classes = dict()
        self._unique_classes = \
            {
                'clearthroat': 2,
                'cough': 8,
                'doorslam': 9,
                'drawer': 1,
                'keyboard': 6,
                'keysDrop': 4,
                'knock': 0,
                'laughter': 10,
                'pageturn': 7,
                'phone': 3,
                'speech': 5
            }

        self._doa_resolution = 10
        self._azi_list = range(-180, 180, self._doa_resolution)
        self._length = len(self._azi_list)
        self._ele_list = range(-40, 50, self._doa_resolution)
        self._height = len(self._ele_list)

        self._audio_max_len_samples = 60 * self._fs  # TODO: Fix the audio synthesis code to always generate 60s of
        # audio. Currently it generates audio till the last active sound event, which is not always 60s long. This is a
        # quick fix to overcome that. We need this because, for processing and training we need the length of features
        # to be fixed.

        # For regression task only
        self._default_azi = 180
        self._default_ele = 50

        if self._default_azi in self._azi_list:
            print('ERROR: chosen default_azi value {} should not exist in azi_list'.format(self._default_azi))
            exit()
        if self._default_ele in self._ele_list:
            print('ERROR: chosen default_ele value {} should not exist in ele_list'.format(self._default_ele))
            exit()

        self._max_frames = int(np.ceil(self._audio_max_len_samples / float(self._hop_len)))

    def _load_audio(self, audio_path):
        fs, audio = wav.read(audio_path)
        audio = audio[:, :self._nb_channels] / 32768.0 + self._eps
        if audio.shape[0] < self._audio_max_len_samples:
            zero_pad = np.zeros((self._audio_max_len_samples - audio.shape[0], audio.shape[1]))
            audio = np.vstack((audio, zero_pad))
        elif audio.shape[0] > self._audio_max_len_samples:
            audio = audio[:self._audio_max_len_samples, :]
        return audio, fs

    # INPUT FEATURES
    @staticmethod
    def _next_greater_power_of_2(x):
        return 2 ** (x - 1).bit_length()

    def _spectrogram(self, audio_input):
        _nb_ch = audio_input.shape[1]
        nb_bins = self._nfft // 2
        spectra = np.zeros((self._max_frames, nb_bins, _nb_ch), dtype=complex)
        for ch_cnt in range(_nb_ch):
            stft_ch = librosa.core.stft(audio_input[:, ch_cnt], n_fft=self._nfft, hop_length=self._hop_len,
                                        win_length=self._win_len, window='hann')
            spectra[:, :, ch_cnt] = stft_ch[1:, :self._max_frames].T
        return spectra

    def _extract_spectrogram_for_file(self, audio_filename):
        audio_in, fs = self._load_audio(os.path.join(self._aud_dir, audio_filename))
        audio_spec = self._spectrogram(audio_in)
        # print('\t{}'.format(audio_spec.shape))
        np.save(os.path.join(self._feat_dir, '{}.npy'.format(audio_filename.split('.')[0])), audio_spec.reshape(self._max_frames, -1))

    # OUTPUT LABELS
    def read_desc_file(self, desc_filename, in_sec=False):
        desc_file = {
            'class': list(), 'start': list(), 'end': list(), 'ele': list(), 'azi': list()
        }
        with open(desc_filename, 'r') as fid:
            next(fid)  # skip the header row
            for line in fid:
                split_line = line.strip().split(',')
                desc_file['class'].append(split_line[0])
                # desc_file['class'].append(split_line[0].split('.')[0][:-3])
                if in_sec:
                    # return onset-offset time in seconds
                    desc_file['start'].append(float(split_line[1]))
                    desc_file['end'].append(float(split_line[2]))
                else:
                    # return onset-offset time in frames
                    desc_file['start'].append(int(np.floor(float(split_line[1])*self._frame_res)))
                    desc_file['end'].append(int(np.ceil(float(split_line[2])*self._frame_res)))
                desc_file['ele'].append(int(split_line[3]))
                desc_file['azi'].append(int(split_line[4]))
        return desc_file

    def get_list_index(self, azi, ele):
        azi = (azi - self._azi_list[0]) // 10
        ele = (ele - self._ele_list[0]) // 10
        return azi * self._height + ele

    def get_matrix_index(self, ind):
        azi, ele = ind // self._height, ind % self._height
        azi = (azi * 10 + self._azi_list[0])
        ele = (ele * 10 + self._ele_list[0])
        return azi, ele

    def _get_doa_labels_regr(self, _desc_file):
        azi_label = self._default_azi*np.ones((self._max_frames, len(self._unique_classes)))
        ele_label = self._default_ele*np.ones((self._max_frames, len(self._unique_classes)))
        for i, ele_ang in enumerate(_desc_file['ele']):
            start_frame = _desc_file['start'][i]
            end_frame = self._max_frames if _desc_file['end'][i] > self._max_frames else _desc_file['end'][i]
            azi_ang = _desc_file['azi'][i]
            class_ind = self._unique_classes[_desc_file['class'][i]]
            if (azi_ang >= self._azi_list[0]) & (azi_ang <= self._azi_list[-1]) & \
                    (ele_ang >= self._ele_list[0]) & (ele_ang <= self._ele_list[-1]):
                azi_label[start_frame:end_frame + 1, class_ind] = azi_ang
                ele_label[start_frame:end_frame + 1, class_ind] = ele_ang
            else:
                print('bad_angle {} {}'.format(azi_ang, ele_ang))
        doa_label_regr = np.concatenate((azi_label, ele_label), axis=1)
        return doa_label_regr

    def _get_se_labels(self, _desc_file):
        se_label = np.zeros((self._max_frames, len(self._unique_classes)))
        for i, se_class in enumerate(_desc_file['class']):
            start_frame = _desc_file['start'][i]
            end_frame = self._max_frames if _desc_file['end'][i] > self._max_frames else _desc_file['end'][i]
            se_label[start_frame:end_frame + 1, self._unique_classes[se_class]] = 1
        return se_label

    def get_labels_for_file(self, _desc_file):
        """
        Reads description csv file and returns classification based SED labels and regression based DOA labels

        :param _desc_file: csv file
        :return: label_mat: labels of the format [sed_label, doa_label],
        where sed_label is of dimension [nb_frames, nb_classes] which is 1 for active sound event else zero
        where doa_labels is of dimension [nb_frames, 2*nb_classes], nb_classes each for azimuth and elevation angles,
        if active, the DOA values will be in degrees, else, it will contain default doa values given by
        self._default_ele and self._default_azi
        """

        se_label = self._get_se_labels(_desc_file)
        doa_label = self._get_doa_labels_regr(_desc_file)
        label_mat = np.concatenate((se_label, doa_label), axis=1)
        # print(label_mat.shape)
        return label_mat

    def get_clas_labels_for_file(self, _desc_file):
        """
        Reads description file and returns classification format labels for SELD

        :param _desc_file: csv file
        :return: _labels: matrix of SELD labels of dimension [nb_frames, nb_classes, nb_azi*nb_ele],
                          which is 1 for active sound event and location else zero
        """

        _labels = np.zeros((self._max_frames, len(self._unique_classes), len(self._azi_list) * len(self._ele_list)))
        for _ind, _start_frame in enumerate(_desc_file['start']):
            _tmp_class = self._unique_classes[_desc_file['class'][_ind]]
            _tmp_azi = _desc_file['azi'][_ind]
            _tmp_ele = _desc_file['ele'][_ind]
            _tmp_end = self._max_frames if _desc_file['end'][_ind] > self._max_frames else _desc_file['end'][_ind]
            _tmp_ind = self.get_list_index(_tmp_azi, _tmp_ele)
            _labels[_start_frame:_tmp_end + 1, _tmp_class, _tmp_ind] = 1

        return _labels

    # ------------------------------- EXTRACT FEATURE AND PREPROCESS IT -------------------------------
    def extract_all_feature(self):
        # setting up folders
        self._feat_dir = self.get_unnormalized_feat_dir()
        create_folder(self._feat_dir)

        # extraction starts
        print('Extracting spectrogram:')
        print('\t\taud_dir {}\n\t\tdesc_dir {}\n\t\tfeat_dir {}'.format(
            self._aud_dir, self._desc_dir, self._feat_dir))

        for file_cnt, file_name in enumerate(os.listdir(self._aud_dir)):
            print('{}: {}'.format(file_cnt, file_name))
            wav_filename = '{}.wav'.format(file_name.split('.')[0])
            self._extract_spectrogram_for_file(wav_filename)

    def preprocess_features(self):
        # Setting up folders and filenames
        self._feat_dir = self.get_unnormalized_feat_dir()
        self._feat_dir_norm = self.get_normalized_feat_dir()
        create_folder(self._feat_dir_norm)
        normalized_features_wts_file = self.get_normalized_wts_file()
        spec_scaler = None

        # pre-processing starts
        if self._is_eval:
            spec_scaler = joblib.load(normalized_features_wts_file)
            print('Normalized_features_wts_file: {}. Loaded.'.format(normalized_features_wts_file))

        else:
            print('Estimating weights for normalizing feature files:')
            print('\t\tfeat_dir: {}'.format(self._feat_dir))

            spec_scaler = preprocessing.StandardScaler()
            for file_cnt, file_name in enumerate(os.listdir(self._feat_dir)):
                print('{}: {}'.format(file_cnt, file_name))
                feat_file = np.load(os.path.join(self._feat_dir, file_name))
                spec_scaler.partial_fit(np.concatenate((np.abs(feat_file), np.angle(feat_file)), axis=1))
                del feat_file
            joblib.dump(
                spec_scaler,
                normalized_features_wts_file
            )
            print('Normalized_features_wts_file: {}. Saved.'.format(normalized_features_wts_file))

        print('Normalizing feature files:')
        print('\t\tfeat_dir_norm {}'.format(self._feat_dir_norm))
        for file_cnt, file_name in enumerate(os.listdir(self._feat_dir)):
            print('{}: {}'.format(file_cnt, file_name))
            feat_file = np.load(os.path.join(self._feat_dir, file_name))
            feat_file = spec_scaler.transform(np.concatenate((np.abs(feat_file), np.angle(feat_file)), axis=1))
            np.save(
                os.path.join(self._feat_dir_norm, file_name),
                feat_file
            )
            del feat_file

        print('normalized files written to {}'.format(self._feat_dir_norm))

    # ------------------------------- EXTRACT LABELS AND PREPROCESS IT -------------------------------
    def extract_all_labels(self):
        self._label_dir = self.get_label_dir()

        print('Extracting labels:')
        print('\t\taud_dir {}\n\t\tdesc_dir {}\n\t\tlabel_dir {}'.format(
            self._aud_dir, self._desc_dir, self._label_dir))
        create_folder(self._label_dir)

        for file_cnt, file_name in enumerate(os.listdir(self._desc_dir)):
            print('{}: {}'.format(file_cnt, file_name))
            wav_filename = '{}.wav'.format(file_name.split('.')[0])
            desc_file = self.read_desc_file(os.path.join(self._desc_dir, file_name))
            label_mat = self.get_labels_for_file(desc_file)
            np.save(os.path.join(self._label_dir, '{}.npy'.format(wav_filename.split('.')[0])), label_mat)

    # ------------------------------- Misc public functions -------------------------------
    def get_classes(self):
        return self._unique_classes

    def get_normalized_feat_dir(self):
        return os.path.join(
            self._feat_label_dir,
            '{}_norm'.format(self._dataset_combination)
        )

    def get_unnormalized_feat_dir(self):
        return os.path.join(
            self._feat_label_dir,
            '{}'.format(self._dataset_combination)
        )

    def get_label_dir(self):
        if self._is_eval:
            return None
        else:
            return os.path.join(
                self._feat_label_dir, '{}_label'.format(self._dataset_combination)
            )

    def get_normalized_wts_file(self):
        return os.path.join(
            self._feat_label_dir,
            '{}_wts'.format(self._dataset)
        )

    def get_default_azi_ele_regr(self):
        return self._default_azi, self._default_ele

    def get_nb_channels(self):
        return self._nb_channels

    def nb_frames_1s(self):
        return self._nb_frames_1s

    def get_hop_len_sec(self):
        return self._hop_len_s

    def get_azi_ele_list(self):
        return self._azi_list, self._ele_list

    def get_nb_frames(self):
        return self._max_frames


def create_folder(folder_name):
    if not os.path.exists(folder_name):
        print('{} folder does not exist, creating it.'.format(folder_name))
        os.makedirs(folder_name)

--------------------------------------------------------------------------------
/evaluation_tools/evaluation_metrics.py:
--------------------------------------------------------------------------------
#
# Implements the core metrics from sound event detection evaluation module http://tut-arg.github.io/sed_eval/ and
# The DOA metrics are explained in the SELDnet paper
#

import numpy as np
from scipy.optimize import linear_sum_assignment
eps = np.finfo(float).eps  # np.float was removed in NumPy 1.24


##########################################################################################
# SELD scoring functions - class implementation
#
# NOTE: Supports only one-hot labels for both SED and DOA. Doesnt work for baseline method
# directly, since it estimated DOA in regression approach. Check below the class for
# one shot (function) implementations of all metrics. The function implementation has
# support for both one-hot labels and regression values of DOA estimation.
##########################################################################################

class SELDMetrics(object):
    def __init__(self, nb_frames_1s=None, data_gen=None):
        # SED params
        self._S = 0
        self._D = 0
        self._I = 0
        self._TP = 0
        self._Nref = 0
        self._Nsys = 0
        self._block_size = nb_frames_1s

        # DOA params
        self._doa_loss_pred_cnt = 0
        self._nb_frames = 0

        self._doa_loss_pred = 0
        self._nb_good_pks = 0

        self._data_gen = data_gen

        self._less_est_cnt, self._less_est_frame_cnt = 0, 0
        self._more_est_cnt, self._more_est_frame_cnt = 0, 0

    def f1_overall_framewise(self, O, T):
        TP = ((2 * T - O) == 1).sum()
        Nref, Nsys = T.sum(), O.sum()
        self._TP += TP
        self._Nref += Nref
        self._Nsys += Nsys

    def er_overall_framewise(self, O, T):
        FP = np.logical_and(T == 0, O == 1).sum(1)
        FN = np.logical_and(T == 1, O == 0).sum(1)
        S = np.minimum(FP, FN).sum()
        D = np.maximum(0, FN - FP).sum()
        I = np.maximum(0, FP - FN).sum()
        self._S += S
        self._D += D
        self._I += I

    def f1_overall_1sec(self, O, T):
        new_size = int(np.ceil(O.shape[0] / self._block_size))
        O_block = np.zeros((new_size, O.shape[1]))
        T_block = np.zeros((new_size, O.shape[1]))
        for i in range(0, new_size):
            O_block[i, :] = np.max(O[int(i * self._block_size):int(i * self._block_size + self._block_size - 1), :], axis=0)
            T_block[i, :] = np.max(T[int(i * self._block_size):int(i * self._block_size + self._block_size - 1), :], axis=0)
        return self.f1_overall_framewise(O_block, T_block)

    def er_overall_1sec(self, O, T):
        new_size = int(O.shape[0] / self._block_size)
        O_block = np.zeros((new_size, O.shape[1]))
        T_block = np.zeros((new_size, O.shape[1]))
        for i in range(0, new_size):
            O_block[i, :] = np.max(O[int(i * self._block_size):int(i * self._block_size + self._block_size - 1), :], axis=0)
            T_block[i, :] = np.max(T[int(i * self._block_size):int(i * self._block_size + self._block_size - 1), :], axis=0)
        return self.er_overall_framewise(O_block, T_block)

    def update_sed_scores(self, pred, gt):
        """
        Computes SED metrics for one second segments

        :param pred: predicted matrix of dimension [nb_frames, nb_classes], with 1 when sound event is active else 0
        :param gt:   reference matrix of dimension [nb_frames, nb_classes], with 1 when sound event is active else 0
        :param nb_frames_1s: integer, number of frames in one second
        :return:
        """
        self.f1_overall_1sec(pred, gt)
        self.er_overall_1sec(pred, gt)

    def compute_sed_scores(self):
        ER = (self._S + self._D + self._I) / (self._Nref + 0.0)

        prec = float(self._TP) / float(self._Nsys + eps)
        recall = float(self._TP) / float(self._Nref + eps)
        F = 2 * prec * recall / (prec + recall + eps)

        return ER, F

    def update_doa_scores(self, pred_doa_thresholded, gt_doa):
        '''
        Compute DOA metrics when DOA is estimated using classification approach

        :param pred_doa_thresholded: predicted results of dimension [nb_frames, nb_classes, nb_azi*nb_ele],
                                    with value 1 when sound event active, else 0
        :param gt_doa: reference results of dimension [nb_frames, nb_classes, nb_azi*nb_ele],
                       with value 1 when sound event active, else 0
        :param data_gen_test: feature or data generator class

        :return: DOA metrics

        '''
        self._doa_loss_pred_cnt += np.sum(pred_doa_thresholded)
        self._nb_frames += pred_doa_thresholded.shape[0]

        for frame in range(pred_doa_thresholded.shape[0]):
            nb_gt_peaks = int(np.sum(gt_doa[frame, :]))
            nb_pred_peaks = int(np.sum(pred_doa_thresholded[frame, :]))

            # good_frame_cnt includes frames where the nb active sources were zero in both groundtruth and prediction
            if nb_gt_peaks == nb_pred_peaks:
                self._nb_good_pks += 1
            elif nb_gt_peaks > nb_pred_peaks:
                self._less_est_frame_cnt += 1
                self._less_est_cnt += (nb_gt_peaks - nb_pred_peaks)
            elif nb_pred_peaks > nb_gt_peaks:
                self._more_est_frame_cnt += 1
                self._more_est_cnt += (nb_pred_peaks - nb_gt_peaks)

            # when nb_ref_doa > nb_estimated_doa, ignores the extra ref doas and scores only the nearest matching doas
            # similarly, when nb_estimated_doa > nb_ref_doa, ignores the extra estimated doa and scores the remaining matching doas
            if nb_gt_peaks and nb_pred_peaks:
                pred_ind = np.where(pred_doa_thresholded[frame] == 1)[1]
                pred_list_rad = np.array(self._data_gen.get_matrix_index(pred_ind)) * np.pi / 180

                gt_ind = np.where(gt_doa[frame] == 1)[1]
                gt_list_rad = np.array(self._data_gen.get_matrix_index(gt_ind)) * np.pi / 180

                frame_dist = distance_between_gt_pred(gt_list_rad.T, pred_list_rad.T)
                self._doa_loss_pred += frame_dist

    def compute_doa_scores(self):
        doa_error = self._doa_loss_pred / self._doa_loss_pred_cnt
        frame_recall = self._nb_good_pks / float(self._nb_frames)
        return doa_error, frame_recall

    def reset(self):
        # SED params
        self._S = 0
        self._D = 0
        self._I = 0
        self._TP = 0
        self._Nref = 0
        self._Nsys = 0

        # DOA params
        self._doa_loss_pred_cnt = 0
        self._nb_frames = 0

        self._doa_loss_pred = 0
        self._nb_good_pks = 0

        self._less_est_cnt, self._less_est_frame_cnt = 0, 0
        self._more_est_cnt, self._more_est_frame_cnt = 0, 0


###############################################################
# SED scoring functions
###############################################################


def reshape_3Dto2D(A):
    return A.reshape(A.shape[0] * A.shape[1], A.shape[2])


def f1_overall_framewise(O, T):
    if len(O.shape) == 3:
        O, T = reshape_3Dto2D(O), reshape_3Dto2D(T)
    TP = ((2 * T - O) == 1).sum()
    Nref, Nsys = T.sum(), O.sum()

    prec = float(TP) / float(Nsys + eps)
    recall = float(TP) / float(Nref + eps)
    f1_score = 2 * prec * recall / (prec + recall + eps)
    return f1_score


def er_overall_framewise(O, T):
    if len(O.shape) == 3:
        O, T = reshape_3Dto2D(O), reshape_3Dto2D(T)

    FP = np.logical_and(T == 0, O == 1).sum(1)
    FN = np.logical_and(T == 1, O == 0).sum(1)

    S = np.minimum(FP, FN).sum()
    D = np.maximum(0, FN-FP).sum()
    I = np.maximum(0, FP-FN).sum()

    Nref = T.sum()
    ER = (S+D+I) / (Nref + 0.0)
    return ER


def f1_overall_1sec(O, T, block_size):
    if len(O.shape) == 3:
        O, T = reshape_3Dto2D(O), reshape_3Dto2D(T)
    new_size = int(np.ceil(O.shape[0] / block_size))
    O_block = np.zeros((new_size, O.shape[1]))
    T_block = np.zeros((new_size, O.shape[1]))
    for i in range(0, new_size):
        O_block[i, :] = np.max(O[int(i * block_size):int(i * block_size + block_size - 1), :], axis=0)
        T_block[i, :] = np.max(T[int(i * block_size):int(i * block_size + block_size - 1), :], axis=0)
    return f1_overall_framewise(O_block, T_block)


def er_overall_1sec(O, T, block_size):
    if len(O.shape) == 3:
        O, T = reshape_3Dto2D(O), reshape_3Dto2D(T)
    new_size = int(O.shape[0] / (block_size))
    O_block = np.zeros((new_size, O.shape[1]))
    T_block = np.zeros((new_size, O.shape[1]))
    for i in range(0, new_size):
        O_block[i, :] = np.max(O[int(i * block_size):int(i * block_size + block_size - 1), :], axis=0)
        T_block[i, :] = np.max(T[int(i * block_size):int(i * block_size + block_size - 1), :], axis=0)
    return er_overall_framewise(O_block, T_block)


def compute_sed_scores(pred, gt, nb_frames_1s):
    """
    Computes SED metrics for one second segments

    :param pred: predicted matrix of dimension [nb_frames, nb_classes], with 1 when sound event is active else 0
    :param gt:   reference matrix of dimension [nb_frames, nb_classes], with 1 when sound event is active else 0
    :param nb_frames_1s: integer, number of frames in one second
    :return:
    """
    f1o = f1_overall_1sec(pred, gt, nb_frames_1s)
    ero = er_overall_1sec(pred, gt, nb_frames_1s)
    scores = [ero, f1o]
    return scores


###############################################################
# DOA scoring functions
###############################################################


def compute_doa_scores_regr(pred_doa_rad, gt_doa_rad, pred_sed, gt_sed):
    """
    Compute DOA metrics when DOA is estimated using regression approach

    :param pred_doa_rad: predicted doa_labels is of dimension [nb_frames, 2*nb_classes],
                        nb_classes each for azimuth and elevation angles,
                        if active, the DOA values will be in RADIANS, else, it will contain default doa values
    :param gt_doa_rad: reference doa_labels is of dimension [nb_frames, 2*nb_classes],
                    nb_classes each for azimuth and elevation angles,
                    if active, the DOA values will be in RADIANS, else, it will contain default doa values
    :param pred_sed: predicted sed label of dimension [nb_frames, nb_classes] which is 1 for active sound event else zero
    :param gt_sed: reference sed label of dimension [nb_frames, nb_classes] which is 1 for active sound event else zero
    :return:
    """

    nb_src_gt_list = np.zeros(gt_doa_rad.shape[0]).astype(int)
    nb_src_pred_list = np.zeros(gt_doa_rad.shape[0]).astype(int)
    good_frame_cnt = 0
    doa_loss_pred = 0.0
    nb_sed = gt_sed.shape[-1]

    less_est_cnt, less_est_frame_cnt = 0, 0
    more_est_cnt, more_est_frame_cnt = 0, 0

    for frame_cnt, sed_frame in enumerate(gt_sed):
        nb_src_gt_list[frame_cnt] = int(np.sum(sed_frame))
        nb_src_pred_list[frame_cnt] = int(np.sum(pred_sed[frame_cnt]))

        # good_frame_cnt includes frames where the nb active sources were zero in both groundtruth and prediction
        if nb_src_gt_list[frame_cnt] == nb_src_pred_list[frame_cnt]:
            good_frame_cnt = good_frame_cnt + 1
        elif nb_src_gt_list[frame_cnt] > nb_src_pred_list[frame_cnt]:
            less_est_cnt = less_est_cnt + nb_src_gt_list[frame_cnt] - nb_src_pred_list[frame_cnt]
            less_est_frame_cnt = less_est_frame_cnt + 1
        elif nb_src_gt_list[frame_cnt] < nb_src_pred_list[frame_cnt]:
            more_est_cnt = more_est_cnt + nb_src_pred_list[frame_cnt] - nb_src_gt_list[frame_cnt]
            more_est_frame_cnt = more_est_frame_cnt + 1

        # when nb_ref_doa > nb_estimated_doa, ignores the extra ref doas and scores only the nearest matching doas
        # similarly, when nb_estimated_doa > nb_ref_doa, ignores the extra estimated doa and scores the remaining matching doas
        if nb_src_gt_list[frame_cnt] and nb_src_pred_list[frame_cnt]:
            # DOA Loss with respect to predicted confidence
            sed_frame_gt = gt_sed[frame_cnt]
            doa_frame_gt_azi = gt_doa_rad[frame_cnt][:nb_sed][sed_frame_gt == 1]
            doa_frame_gt_ele = gt_doa_rad[frame_cnt][nb_sed:][sed_frame_gt == 1]

            sed_frame_pred = pred_sed[frame_cnt]
            doa_frame_pred_azi = pred_doa_rad[frame_cnt][:nb_sed][sed_frame_pred == 1]
            doa_frame_pred_ele = pred_doa_rad[frame_cnt][nb_sed:][sed_frame_pred == 1]

            doa_loss_pred += distance_between_gt_pred(np.vstack((doa_frame_gt_azi, doa_frame_gt_ele)).T,
                                                      np.vstack((doa_frame_pred_azi, doa_frame_pred_ele)).T)

    doa_loss_pred_cnt = np.sum(nb_src_pred_list)
    if doa_loss_pred_cnt:
        doa_loss_pred /= doa_loss_pred_cnt

    frame_recall = good_frame_cnt / float(gt_sed.shape[0])
    er_metric = [doa_loss_pred, frame_recall, doa_loss_pred_cnt, good_frame_cnt, more_est_cnt, less_est_cnt]
    return er_metric


def compute_doa_scores_clas(pred_doa_thresholded, gt_doa, data_gen_test):
    '''
    Compute DOA metrics when DOA is estimated using classification approach

    :param pred_doa_thresholded: predicted results of dimension [nb_frames, nb_classes, nb_azi*nb_ele],
                                with value 1 when sound event active, else 0
    :param gt_doa: reference results of dimension [nb_frames, nb_classes, nb_azi*nb_ele],
                   with value 1 when sound event active, else 0
    :param data_gen_test: feature or data generator class

    :return: DOA metrics

    '''
    doa_loss_pred_cnt = np.sum(pred_doa_thresholded)

    doa_loss_pred = 0
    nb_good_pks = 0

    less_est_cnt, less_est_frame_cnt = 0, 0
    more_est_cnt, more_est_frame_cnt = 0, 0

    for frame in range(pred_doa_thresholded.shape[0]):
        nb_gt_peaks = int(np.sum(gt_doa[frame, :]))
        nb_pred_peaks = int(np.sum(pred_doa_thresholded[frame, :]))

        # good_frame_cnt includes frames where the nb active sources were zero in both groundtruth and prediction
        if nb_gt_peaks == nb_pred_peaks:
            nb_good_pks += 1
        elif
nb_gt_peaks > nb_pred_peaks: 339 | less_est_frame_cnt += 1 340 | less_est_cnt += (nb_gt_peaks - nb_pred_peaks) 341 | elif nb_pred_peaks > nb_gt_peaks: 342 | more_est_frame_cnt += 1 343 | more_est_cnt += (nb_pred_peaks - nb_gt_peaks) 344 | 345 | # when nb_ref_doa > nb_estimated_doa, ignores the extra ref doas and scores only the nearest matching doas 346 | # similarly, when nb_estimated_doa > nb_ref_doa, ignores the extra estimated doa and scores the remaining matching doas 347 | if nb_gt_peaks and nb_pred_peaks: 348 | pred_ind = np.where(pred_doa_thresholded[frame] == 1)[1] 349 | pred_list_rad = np.array(data_gen_test.get_matrix_index(pred_ind)) * np.pi / 180 350 | 351 | gt_ind = np.where(gt_doa[frame] == 1)[1] 352 | gt_list_rad = np.array(data_gen_test.get_matrix_index(gt_ind)) * np.pi / 180 353 | 354 | frame_dist = distance_between_gt_pred(gt_list_rad.T, pred_list_rad.T) 355 | doa_loss_pred += frame_dist 356 | 357 | if doa_loss_pred_cnt: 358 | doa_loss_pred /= doa_loss_pred_cnt 359 | 360 | frame_recall = nb_good_pks / float(pred_doa_thresholded.shape[0]) 361 | er_metric = [doa_loss_pred, frame_recall, doa_loss_pred_cnt, nb_good_pks, more_est_cnt, less_est_cnt] 362 | return er_metric 363 | 364 | 365 | def distance_between_gt_pred(gt_list_rad, pred_list_rad): 366 | """ 367 | Shortest total distance between two sets of spherical coordinates. Given a set of groundtruth spherical coordinates 368 | and its respective predicted coordinates, we calculate the spherical distance between every groundtruth/predicted 369 | pair, resulting in a distance matrix whose rows index the groundtruth coordinates and whose columns index the 370 | predicted coordinates. The number of estimated peaks need not equal the number of groundtruth peaks, so the 371 | distance matrix is not always square. We use the Hungarian algorithm (linear_sum_assignment) to find the 372 | minimum-cost assignment in this distance matrix.
373 | 374 | :param gt_list_rad: array of ground-truth [azimuth, elevation] pairs in radians, shape [nb_gt, 2] 375 | :param pred_list_rad: array of predicted [azimuth, elevation] pairs in radians, shape [nb_pred, 2] 376 | :return: cost - total angular distance in degrees over the optimally matched pairs; unmatched 377 | (missed or over-estimated) DOAs do not contribute, and the callers count them separately 378 | 379 | """ 380 | 381 | gt_len, pred_len = gt_list_rad.shape[0], pred_list_rad.shape[0] 382 | ind_pairs = np.array([[x, y] for y in range(pred_len) for x in range(gt_len)]) 383 | cost_mat = np.zeros((gt_len, pred_len)) 384 | 385 | # Slow implementation 386 | # cost_mat = np.zeros((gt_len, pred_len)) 387 | # for gt_cnt, gt in enumerate(gt_list_rad): 388 | # for pred_cnt, pred in enumerate(pred_list_rad): 389 | # cost_mat[gt_cnt, pred_cnt] = distance_between_spherical_coordinates_rad(gt[0], gt[1], pred[0], pred[1]) 390 | 391 | # Fast implementation 392 | if gt_len and pred_len: 393 | az1, ele1, az2, ele2 = gt_list_rad[ind_pairs[:, 0], 0], gt_list_rad[ind_pairs[:, 0], 1], \ 394 | pred_list_rad[ind_pairs[:, 1], 0], pred_list_rad[ind_pairs[:, 1], 1] 395 | cost_mat[ind_pairs[:, 0], ind_pairs[:, 1]] = distance_between_spherical_coordinates_rad(az1, ele1, az2, ele2) 396 | 397 | row_ind, col_ind = linear_sum_assignment(cost_mat) 398 | cost = cost_mat[row_ind, col_ind].sum() 399 | return cost 400 | 401 | 402 | def distance_between_spherical_coordinates_rad(az1, ele1, az2, ele2): 403 | """ 404 | Angular distance between two spherical coordinates 405 | MORE: https://en.wikipedia.org/wiki/Great-circle_distance 406 | :param az1, ele1, az2, ele2: azimuth and elevation angles in radians 407 | :return: angular distance in degrees 408 | """ 409 | dist = np.sin(ele1) * np.sin(ele2) + np.cos(ele1) * np.cos(ele2) * np.cos(np.abs(az1 - az2)) 410 | # Clip to [-1, 1]: floating-point rounding can push dist slightly outside this range, where np.arccos returns NaN 411 | dist = np.clip(dist, -1, 1) 412 | dist = np.arccos(dist) * 180 / np.pi 413 | return dist 414 | 415 | 416 | def distance_between_cartesian_coordinates(x1, y1, z1, x2, y2, z2): 417 | """ 418 | Angular distance between two cartesian coordinates 419 | MORE:
https://en.wikipedia.org/wiki/Great-circle_distance 420 | Check 'From chord length' section 421 | 422 | :return: angular distance in degrees 423 | """ 424 | dist = np.sqrt((x1-x2) ** 2 + (y1-y2) ** 2 + (z1-z2) ** 2) 425 | dist = 2 * np.arcsin(dist / 2.0) * 180/np.pi 426 | return dist 427 | 428 | 429 | def sph2cart(azimuth, elevation, r): 430 | ''' 431 | Convert spherical to cartesian coordinates 432 | 433 | :param azimuth: in radians 434 | :param elevation: in radians 435 | :param r: in meters 436 | :return: cartesian coordinates 437 | ''' 438 | 439 | x = r * np.cos(elevation) * np.cos(azimuth) 440 | y = r * np.cos(elevation) * np.sin(azimuth) 441 | z = r * np.sin(elevation) 442 | return x, y, z 443 | 444 | 445 | def cart2sph(x, y, z): 446 | ''' 447 | Convert cartesian to spherical coordinates 448 | 449 | :param x: 450 | :param y: 451 | :param z: 452 | :return: azi, ele in radians and r in meters 453 | ''' 454 | 455 | azimuth = np.arctan2(y,x) 456 | elevation = np.arctan2(z,np.sqrt(x**2 + y**2)) 457 | r = np.sqrt(x**2 + y**2 + z**2) 458 | return azimuth, elevation, r 459 | 460 | 461 | ############################################################### 462 | # SELD scoring functions 463 | ############################################################### 464 | 465 | 466 | def compute_seld_metric(sed_error, doa_error): 467 | """ 468 | Compute SELD metric from sed and doa errors. 
469 | 470 | :param sed_error: [error rate (0 to 1 range), f score (0 to 1 range)] 471 | :param doa_error: [doa error (in degrees), frame recall (0 to 1 range)] 472 | :return: seld metric result 473 | """ 474 | seld_metric = np.mean([ 475 | sed_error[0], 476 | 1 - sed_error[1], 477 | doa_error[0]/180, 478 | 1 - doa_error[1]] 479 | ) 480 | return seld_metric 481 | 482 | 483 | def compute_seld_metrics_from_output_format_dict(_pred_dict, _gt_dict, _feat_cls): 484 | """ 485 | Compute SELD metrics between _gt_dict and _pred_dict in DCASE output format 486 | 487 | :param _pred_dict: dcase output format dict 488 | :param _gt_dict: dcase output format dict 489 | :param _feat_cls: feature or data generator class 490 | :return: the seld metrics 491 | """ 492 | _gt_labels = output_format_dict_to_classification_labels(_gt_dict, _feat_cls) 493 | _pred_labels = output_format_dict_to_classification_labels(_pred_dict, _feat_cls) 494 | 495 | _er, _f = compute_sed_scores(_pred_labels.max(2), _gt_labels.max(2), _feat_cls.nb_frames_1s()) 496 | _doa_err, _frame_recall, d1, d2, d3, d4 = compute_doa_scores_clas(_pred_labels, _gt_labels, _feat_cls) 497 | _seld_scr = compute_seld_metric([_er, _f], [_doa_err, _frame_recall]) 498 | return _seld_scr, _er, _f, _doa_err, _frame_recall 499 | 500 | 501 | ############################################################### 502 | # Functions for format conversions 503 | ############################################################### 504 | 505 | def output_format_dict_to_classification_labels(_output_dict, _feat_cls): 506 | 507 | _unique_classes = _feat_cls.get_classes() 508 | _nb_classes = len(_unique_classes) 509 | _azi_list, _ele_list = _feat_cls.get_azi_ele_list() 510 | _max_frames = _feat_cls.get_nb_frames() 511 | _labels = np.zeros((_max_frames, _nb_classes, len(_azi_list) * len(_ele_list))) 512 | 513 | for _frame_cnt in _output_dict.keys(): 514 | if _frame_cnt < _max_frames: 515 | for _tmp_doa in _output_dict[_frame_cnt]: 516 | # Making sure the
doa's are within the limits 517 | _tmp_doa[1] = np.clip(_tmp_doa[1], _azi_list[0], _azi_list[-1]) 518 | _tmp_doa[2] = np.clip(_tmp_doa[2], _ele_list[0], _ele_list[-1]) 519 | 520 | # create label 521 | _labels[_frame_cnt, _tmp_doa[0], int(_feat_cls.get_list_index(_tmp_doa[1], _tmp_doa[2]))] = 1 522 | 523 | return _labels 524 | 525 | 526 | def regression_label_format_to_output_format(_feat_cls, _sed_labels, _doa_labels_deg): 527 | """ 528 | Converts the sed (classification) and doa labels predicted in regression format to dcase output format. 529 | 530 | :param _feat_cls: feature or data generator class instance 531 | :param _sed_labels: SED labels matrix [nb_frames, nb_classes] 532 | :param _doa_labels_deg: DOA labels matrix [nb_frames, 2*nb_classes] in degrees 533 | :return: _output_dict: returns a dict containing dcase output format 534 | """ 535 | 536 | _unique_classes = _feat_cls.get_classes() 537 | _nb_classes = len(_unique_classes) 538 | _azi_labels = _doa_labels_deg[:, :_nb_classes] 539 | _ele_labels = _doa_labels_deg[:, _nb_classes:] 540 | 541 | _output_dict = {} 542 | for _frame_ind in range(_sed_labels.shape[0]): 543 | _tmp_ind = np.where(_sed_labels[_frame_ind, :]) 544 | if len(_tmp_ind[0]): 545 | _output_dict[_frame_ind] = [] 546 | for _tmp_class in _tmp_ind[0]: 547 | _output_dict[_frame_ind].append([_tmp_class, _azi_labels[_frame_ind, _tmp_class], _ele_labels[_frame_ind, _tmp_class]]) 548 | return _output_dict 549 | 550 | 551 | def classification_label_format_to_output_format(_feat_cls, _labels): 552 | """ 553 | Converts the seld labels predicted in classification format to dcase output format. 
554 | 555 | :param _feat_cls: feature or data generator class instance 556 | :param _labels: SELD labels matrix [nb_frames, nb_classes, nb_azi*nb_ele] 557 | :return: _output_dict: returns a dict containing dcase output format 558 | """ 559 | _output_dict = {} 560 | for _frame_ind in range(_labels.shape[0]): 561 | _tmp_class_ind = np.where(_labels[_frame_ind].sum(1)) 562 | if len(_tmp_class_ind[0]): 563 | _output_dict[_frame_ind] = [] 564 | for _tmp_class in _tmp_class_ind[0]: 565 | _tmp_spatial_ind = np.where(_labels[_frame_ind, _tmp_class]) 566 | for _tmp_spatial in _tmp_spatial_ind[0]: 567 | _azi, _ele = _feat_cls.get_matrix_index(_tmp_spatial) 568 | _output_dict[_frame_ind].append( 569 | [_tmp_class, _azi, _ele]) 570 | 571 | return _output_dict 572 | 573 | 574 | def description_file_to_output_format(_desc_file_dict, _unique_classes, _hop_length_sec): 575 | """ 576 | Converts the parsed contents of a description (metadata) csv file into the dcase output format, i.e. a dict 577 | mapping each frame index to a list of [class index, azimuth, elevation] entries 578 | 579 | :param _unique_classes: unique classes dictionary, maps class name to class index 580 | :param _desc_file_dict: description file contents, a dict of lists with keys 'class', 'start', 'end', 'azi' and 'ele' 581 | :param _hop_length_sec: hop length in seconds 582 | 583 | :return: _output_dict: dcase output in dictionary format 584 | """ 585 | 586 | _output_dict = {} 587 | for _ind, _tmp_start_sec in enumerate(_desc_file_dict['start']): 588 | _tmp_class = _unique_classes[_desc_file_dict['class'][_ind]] 589 | _tmp_azi = _desc_file_dict['azi'][_ind] 590 | _tmp_ele = _desc_file_dict['ele'][_ind] 591 | _tmp_end_sec = _desc_file_dict['end'][_ind] 592 | 593 | _start_frame = int(_tmp_start_sec / _hop_length_sec) 594 | _end_frame = int(_tmp_end_sec / _hop_length_sec) 595 | for _frame_ind in range(_start_frame, _end_frame + 1): 596 | if _frame_ind not in _output_dict: 597 | _output_dict[_frame_ind] = [] 598 | _output_dict[_frame_ind].append([_tmp_class, _tmp_azi, _tmp_ele]) 599 | 600 | return _output_dict 601 | 602 | 603
| def load_output_format_file(_output_format_file): 604 | """ 605 | Loads DCASE output format csv file and returns it in dictionary format 606 | 607 | :param _output_format_file: DCASE output format CSV 608 | :return: _output_dict: dict mapping frame index to a list of [class index, azimuth, elevation] 609 | """ 610 | _output_dict = {} 611 | _fid = open(_output_format_file, 'r') 612 | # next(_fid) 613 | for _line in _fid: 614 | _words = _line.strip().split(',') 615 | _frame_ind = int(_words[0]) 616 | if _frame_ind not in _output_dict: 617 | _output_dict[_frame_ind] = [] 618 | _output_dict[_frame_ind].append([int(_words[1]), int(_words[2]), int(_words[3])]) 619 | _fid.close() 620 | return _output_dict 621 | 622 | 623 | def write_output_format_file(_output_format_file, _output_format_dict): 624 | """ 625 | Writes DCASE output format csv file, given output format dictionary 626 | 627 | :param _output_format_file: path of the csv file to write 628 | :param _output_format_dict: dict mapping frame index to a list of [class index, azimuth, elevation] 629 | :return: None 630 | """ 631 | _fid = open(_output_format_file, 'w') 632 | # _fid.write('{},{},{},{}\n'.format('frame number with 20ms hop (int)', 'class index (int)', 'azimuth angle (int)', 'elevation angle (int)')) 633 | for _frame_ind in _output_format_dict.keys(): 634 | for _value in _output_format_dict[_frame_ind]: 635 | _fid.write('{},{},{},{}\n'.format(int(_frame_ind), int(_value[0]), int(_value[1]), int(_value[2]))) 636 | _fid.close() 637 | -------------------------------------------------------------------------------- /licenses/MIT_LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2010-2017 Google, Inc.
http://angularjs.org 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /licenses/TUT_LICENSE.md: -------------------------------------------------------------------------------- 1 | -----------COPYRIGHT NOTICE STARTS WITH THIS LINE------------ Copyright (c) 2019 Tampere University and its licensors All rights reserved. 2 | 3 | Permission is hereby granted, without written agreement and without license or royalty fees, to use and copy the code for the Sound Event Localization and Detection using Convolutional Recurrent Neural Network method/architecture, present in the GitHub repository with the handle seld-dcase2019, (“Work”) described in the paper with title "Sound event localization and detection of overlapping sources using convolutional recurrent neural network" and composed of files with code in the Python programming language. 
This grant is only for experimental and non-commercial purposes, provided that the copyright notice in its entirety appear in all copies of this Work, and the original source of this Work, Audio Research Group at Tampere University, is acknowledged in any publication that reports research using this Work. 4 | 5 | Any commercial use of the Work or any part thereof is strictly prohibited. Commercial use include, but is not limited to: 6 | 7 | selling or reproducing the Work 8 | selling or distributing the results or content achieved by use of the Work 9 | providing services by using the Work. 10 | IN NO EVENT SHALL TAMPERE UNIVERSITY OR ITS LICENSORS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OF THIS WORK AND ITS DOCUMENTATION, EVEN IF TAMPERE UNIVERSITY OR ITS LICENSORS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 11 | 12 | TAMPERE UNIVERSITY AND ALL ITS LICENSORS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE WORK PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE TAMPERE UNIVERSITY HAS NO OBLIGATION TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 
13 | 14 | -----------COPYRIGHT NOTICE ENDS WITH THIS LINE------------ 15 | -------------------------------------------------------------------------------- /pytorch/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(1, os.path.join(sys.path[0], '../utils')) 4 | 5 | import numpy as np 6 | import time 7 | import logging 8 | import datetime 9 | import _pickle as cPickle 10 | import matplotlib.pyplot as plt 11 | 12 | from utilities import (get_filename, write_submission, calculate_metrics, 13 | inverse_scale) 14 | from pytorch_utils import forward 15 | from losses import event_spatial_loss 16 | import config 17 | 18 | 19 | class Evaluator(object): 20 | def __init__(self, model, data_generator, cuda=True): 21 | '''Evaluator to evaluate prediction performance. 22 | 23 | Args: 24 | model: object 25 | data_generator: object 26 | cuda: bool 27 | ''' 28 | 29 | self.model = model 30 | self.data_generator = data_generator 31 | self.cuda = cuda 32 | 33 | self.frames_per_second = config.frames_per_second 34 | self.submission_frames_per_second = config.submission_frames_per_second 35 | 36 | def evaluate(self, data_type, metadata_dir, submissions_dir, 37 | max_validate_num=None): 38 | '''Evaluate the performance. 
39 | 40 | Args: 41 | data_type: 'train' | 'validate' 42 | metadata_dir: string, directory of reference meta csvs 43 | submissions_dir: string, directory to write out submission csvs 44 | max_validate_num: None | int, maximum number of iterations to run, to speed up 45 | evaluation 46 | ''' 47 | 48 | # Forward 49 | generate_func=self.data_generator.generate_validate( 50 | data_type=data_type, max_validate_num=max_validate_num) 51 | 52 | list_dict = forward( 53 | model=self.model, 54 | generate_func=generate_func, 55 | cuda=self.cuda, 56 | return_target=True) 57 | 58 | # Calculate loss 59 | (total_loss, event_loss, position_loss) = self.calculate_loss(list_dict) 60 | 61 | logging.info('{:<20} {}: {:.3f}, {}: {:.3f}, {}: {:.3f}' 62 | ''.format(data_type + ' statistics: ', 'total_loss', total_loss, 63 | 'event_loss', event_loss, 'position_loss', position_loss)) 64 | 65 | # Write out submission and evaluate using the code provided by the organizers 66 | write_submission(list_dict, submissions_dir) 67 | 68 | prediction_paths = [os.path.join(submissions_dir, 69 | '{}.csv'.format(dict['name'])) for dict in list_dict] 70 | 71 | statistics = calculate_metrics(metadata_dir, prediction_paths) 72 | 73 | for key in statistics.keys(): 74 | logging.info(' {:<20} {:.3f}'.format(key + ' :', statistics[key])) 75 | 76 | return statistics 77 | 78 | def calculate_loss(self, list_dict): 79 | total_loss_list = [] 80 | event_loss_list = [] 81 | position_loss_list = [] 82 | 83 | for dict in list_dict: 84 | (output_dict, target_dict) = self._get_output_target_dict(dict) 85 | 86 | (total_loss, event_loss, position_loss) = event_spatial_loss( 87 | output_dict=output_dict, 88 | target_dict=target_dict, 89 | return_individual_loss=True) 90 | 91 | total_loss_list.append(total_loss) 92 | event_loss_list.append(event_loss) 93 | position_loss_list.append(position_loss) 94 | 95 | return np.mean(total_loss_list), np.mean(event_loss_list), np.mean(position_loss_list) 96 | 97 | def _get_output_target_dict(self, dict): 98 |
output_dict = { 99 | 'event': dict['output_event'], 100 | 'elevation': dict['output_elevation'], 101 | 'azimuth': dict['output_azimuth']} 102 | 103 | target_dict = { 104 | 'event': dict['target_event'], 105 | 'elevation': dict['target_elevation'], 106 | 'azimuth': dict['target_azimuth']} 107 | 108 | return output_dict, target_dict 109 | 110 | 111 | def visualize(self, data_type, max_validate_num=None): 112 | '''Visualize the log mel spectrogram, reference and prediction of 113 | sound events, elevation and azimuth. 114 | 115 | Args: 116 | data_type: 'train' | 'validate' 117 | max_validate_num: None | int, maximum iteration to run to speed up 118 | evaluation 119 | ''' 120 | 121 | mel_bins = config.mel_bins 122 | frames_per_second = config.frames_per_second 123 | classes_num = config.classes_num 124 | labels = config.labels 125 | 126 | # Forward 127 | generate_func=self.data_generator.generate_validate( 128 | data_type=data_type, max_validate_num=max_validate_num) 129 | 130 | list_dict = forward( 131 | model=self.model, 132 | generate_func=generate_func, 133 | cuda=self.cuda, 134 | return_input=True, 135 | return_target=True) 136 | 137 | for n, dict in enumerate(list_dict): 138 | print('File: {}'.format(dict['name'])) 139 | 140 | frames_num = dict['target_event'].shape[1] 141 | length_in_second = frames_num / float(frames_per_second) 142 | 143 | fig, axs = plt.subplots(4, 2, figsize=(15, 10)) 144 | logmel = inverse_scale(dict['feature'][0][0], 145 | self.data_generator.scalar['mean'], 146 | self.data_generator.scalar['std']) 147 | axs[0, 0].matshow(logmel.T, origin='lower', aspect='auto', cmap='jet') 148 | axs[1, 0].matshow(dict['target_event'][0].T, origin='lower', aspect='auto', cmap='jet') 149 | axs[2, 0].matshow(dict['output_event'][0].T, origin='lower', aspect='auto', cmap='jet') 150 | axs[0, 1].matshow(dict['target_elevation'][0].T, origin='lower', aspect='auto', cmap='jet') 151 | axs[1, 1].matshow(dict['target_azimuth'][0].T, origin='lower', aspect='auto', 
cmap='jet') 152 | masked_elevation = dict['output_elevation'] * dict['output_event'] 153 | axs[2, 1].matshow(masked_elevation[0].T, origin='lower', aspect='auto', cmap='jet') 154 | masked_azimuth = dict['output_azimuth'] * dict['output_event'] 155 | axs[3, 1].matshow(masked_azimuth[0].T, origin='lower', aspect='auto', cmap='jet') 156 | 157 | axs[0,0].set_title('Log mel spectrogram', color='r') 158 | axs[1,0].set_title('Reference sound events', color='r') 159 | axs[2,0].set_title('Predicted sound events', color='b') 160 | axs[0,1].set_title('Reference elevation', color='r') 161 | axs[1,1].set_title('Reference azimuth', color='r') 162 | axs[2,1].set_title('Predicted elevation', color='b') 163 | axs[3,1].set_title('Predicted azimuth', color='b') 164 | 165 | for i in range(4): 166 | for j in range(2): 167 | axs[i, j].set_xticks([0, frames_num]) 168 | axs[i, j].set_xticklabels(['0', '{:.1f} s'.format(length_in_second)]) 169 | axs[i, j].xaxis.set_ticks_position('bottom') 170 | axs[i, j].set_yticks(np.arange(classes_num)) 171 | axs[i, j].set_yticklabels(labels) 172 | axs[i, j].yaxis.grid(color='w', linestyle='solid', linewidth=0.2) 173 | 174 | axs[0, 0].set_ylabel('Mel bins') 175 | axs[0, 0].set_yticks([0, mel_bins]) 176 | axs[0, 0].set_yticklabels([0, mel_bins]) 177 | axs[3, 0].set_visible(False) 178 | 179 | fig.tight_layout() 180 | plt.show() 181 | 182 | 183 | class StatisticsContainer(object): 184 | def __init__(self, statistics_path): 185 | '''Container of statistics during training. 186 | 187 | Args: 188 | statistics_path: string, path to write out 189 | ''' 190 | self.statistics_path = statistics_path 191 | 192 | self.backup_statistics_path = '{}_{}.pickle'.format( 193 | os.path.splitext(self.statistics_path)[0], 194 | datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) 195 | 196 | self.statistics_list = [] 197 | 198 | def append_and_dump(self, iteration, statistics): 199 | '''Append statistics to container and dump the container.
200 | 201 | Args: 202 | iteration: int 203 | statistics: dict of statistics 204 | ''' 205 | statistics['iteration'] = iteration 206 | self.statistics_list.append(statistics) 207 | 208 | cPickle.dump(self.statistics_list, open(self.statistics_path, 'wb')) 209 | cPickle.dump(self.statistics_list, open(self.backup_statistics_path, 'wb')) 210 | logging.info(' Dump statistics to {}'.format(self.statistics_path)) -------------------------------------------------------------------------------- /pytorch/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def to_tensor(x): 6 | if type(x).__name__ == 'ndarray': 7 | return torch.Tensor(x) 8 | else: 9 | return x 10 | 11 | 12 | def binary_crossentropy(output, target): 13 | '''Binary crossentropy between output and target. 14 | 15 | Args: 16 | output: (batch_size, frames_num, classes_num) 17 | target: (batch_size, frames_num, classes_num) 18 | ''' 19 | output = to_tensor(output) 20 | target = to_tensor(target) 21 | 22 | # Truncate output and target to the same number of time steps. The length 23 | # mismatch is caused by pooling in the CNNs. 24 | N = min(output.shape[1], target.shape[1]) 25 | 26 | return F.binary_cross_entropy( 27 | output[:, 0 : N, :], 28 | target[:, 0 : N, :]) 29 | 30 | 31 | def mean_absolute_error(output, target, mask): 32 | '''Mean absolute error between output and target. 33 | 34 | Args: 35 | output: (batch_size, frames_num, classes_num) 36 | target: (batch_size, frames_num, classes_num) 37 | ''' 38 | output = to_tensor(output) 39 | target = to_tensor(target) 40 | mask = to_tensor(mask) 41 | 42 | # Truncate output and target to the same number of time steps. The length 43 | # mismatch is caused by pooling in the CNNs.
44 | N = min(output.shape[1], target.shape[1]) 45 | 46 | output = output[:, 0 : N, :] 47 | target = target[:, 0 : N, :] 48 | mask = mask[:, 0 : N, :] 49 | 50 | normalize_value = torch.clamp(torch.sum(mask), min=1e-8) # avoid 0 / 0 = NaN when no event is active 51 | 52 | return torch.sum(torch.abs(output - target) * mask) / normalize_value 53 | 54 | 55 | def event_spatial_loss(output_dict, target_dict, return_individual_loss=False): 56 | '''Joint event and spatial loss. 57 | 58 | Args: 59 | output_dict: {'event': (batch_size, frames_num, classes_num), 60 | 'elevation': (batch_size, frames_num, classes_num), 61 | 'azimuth': (batch_size, frames_num, classes_num)} 62 | target_dict: {'event': (batch_size, frames_num, classes_num), 63 | 'elevation': (batch_size, frames_num, classes_num), 64 | 'azimuth': (batch_size, frames_num, classes_num)} 65 | return_individual_loss: bool 66 | 67 | Returns: 68 | total_loss: scalar, or (total_loss, event_loss, position_loss) if return_individual_loss is True 69 | ''' 70 | 71 | event_loss = binary_crossentropy( 72 | output_dict['event'], 73 | target_dict['event']) 74 | 75 | elevation_loss = mean_absolute_error( 76 | output=output_dict['elevation'], 77 | target=target_dict['elevation'], 78 | mask=target_dict['event']) 79 | 80 | azimuth_loss = mean_absolute_error( 81 | output=output_dict['azimuth'], 82 | target=target_dict['azimuth'], 83 | mask=target_dict['event']) 84 | 85 | alpha = 0.01 # To control the balance between the event loss and position loss 86 | position_loss = alpha * (elevation_loss + azimuth_loss) 87 | 88 | total_loss = event_loss + position_loss 89 | 90 | if return_individual_loss: 91 | return total_loss, event_loss, position_loss 92 | else: 93 | return total_loss -------------------------------------------------------------------------------- /pytorch/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(1, os.path.join(sys.path[0], '../utils')) 4 | import numpy as np 5 | import argparse 6 | import h5py 7 | import math 8 | import time 9 | import logging 10 | import
matplotlib.pyplot as plt 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | 16 | from utilities import (create_folder, get_filename, create_logging, 17 | load_scalar, calculate_metrics) 18 | from data_generator import DataGenerator 19 | from models import (Cnn_5layers_AvgPooling, Cnn_9layers_AvgPooling, 20 | Cnn_9layers_MaxPooling, Cnn_13layers_AvgPooling) 21 | from losses import event_spatial_loss 22 | from evaluate import Evaluator, StatisticsContainer 23 | from pytorch_utils import move_data_to_gpu, forward 24 | import config 25 | 26 | 27 | def train(args): 28 | '''Train. Model will be saved after several iterations. 29 | 30 | Args: 31 | dataset_dir: string, directory of dataset 32 | workspace: string, directory of workspace 33 | audio_type: 'foa' | 'mic' 34 | holdout_fold: '1' | '2' | '3' | '4' | 'none', set to 'none' to train on all 35 | data without validation 36 | model_type: string, e.g. 'Cnn_9layers_AvgPooling' 37 | batch_size: int 38 | cuda: bool 39 | mini_data: bool, set True for debugging on a small part of data 40 | ''' 41 | 42 | # Arguments & parameters 43 | dataset_dir = args.dataset_dir 44 | workspace = args.workspace 45 | audio_type = args.audio_type 46 | holdout_fold = args.holdout_fold 47 | model_type = args.model_type 48 | batch_size = args.batch_size 49 | cuda = args.cuda and torch.cuda.is_available() 50 | mini_data = args.mini_data 51 | filename = args.filename 52 | 53 | mel_bins = config.mel_bins 54 | frames_per_second = config.frames_per_second 55 | classes_num = config.classes_num 56 | max_validate_num = None # Number of audio recordings to validate 57 | reduce_lr = True # Reduce learning rate after several iterations 58 | 59 | # Paths 60 | if mini_data: 61 | prefix = 'minidata_' 62 | else: 63 | prefix = '' 64 | 65 | metadata_dir = os.path.join(dataset_dir, 'metadata_dev') 66 | 67 | features_dir = os.path.join(workspace, 'features', 68 |
'{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 69 | 'dev', frames_per_second, mel_bins)) 70 | 71 | scalar_path = os.path.join(workspace, 'scalars', 72 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 73 | 'dev', frames_per_second, mel_bins), 'scalar.h5') 74 | 75 | checkpoints_dir = os.path.join(workspace, 'checkpoints', filename, 76 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 77 | 'dev', frames_per_second, mel_bins), model_type, 78 | 'holdout_fold={}'.format(holdout_fold)) 79 | create_folder(checkpoints_dir) 80 | 81 | # All folds result should write to the same directory 82 | temp_submissions_dir = os.path.join(workspace, '_temp', 'submissions', filename, 83 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 84 | 'dev', frames_per_second, mel_bins), model_type) 85 | create_folder(temp_submissions_dir) 86 | 87 | validate_statistics_path = os.path.join(workspace, 'statistics', filename, 88 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 89 | 'dev', frames_per_second, mel_bins), 'holdout_fold={}'.format(holdout_fold), 90 | model_type, 'validate_statistics.pickle') 91 | create_folder(os.path.dirname(validate_statistics_path)) 92 | 93 | logs_dir = os.path.join(args.workspace, 'logs', filename, args.mode, 94 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 'dev', 95 | frames_per_second, mel_bins), 'holdout_fold={}'.format(holdout_fold), 96 | model_type) 97 | create_logging(logs_dir, filemode='w') 98 | logging.info(args) 99 | 100 | if cuda: 101 | logging.info('Using GPU.') 102 | else: 103 | logging.info('Using CPU. 
Set --cuda flag to use GPU.') 104 | 105 | # Load scalar 106 | scalar = load_scalar(scalar_path) 107 | 108 | # Model 109 | Model = eval(model_type) 110 | model = Model(classes_num) 111 | 112 | if cuda: 113 | model.cuda() 114 | 115 | # Optimizer 116 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), 117 | eps=1e-08, weight_decay=0., amsgrad=True) 118 | 119 | # Data generator 120 | data_generator = DataGenerator( 121 | features_dir=features_dir, 122 | scalar=scalar, 123 | batch_size=batch_size, 124 | holdout_fold=holdout_fold) 125 | 126 | # Evaluator 127 | evaluator = Evaluator( 128 | model=model, 129 | data_generator=data_generator, 130 | cuda=cuda) 131 | 132 | # Statistics 133 | validate_statistics_container = StatisticsContainer(validate_statistics_path) 134 | 135 | train_bgn_time = time.time() 136 | iteration = 0 137 | 138 | # Train on mini batches 139 | for batch_data_dict in data_generator.generate_train(): 140 | 141 | # Evaluate 142 | if iteration % 200 == 0: 143 | 144 | logging.info('------------------------------------') 145 | logging.info('Iteration: {}'.format(iteration)) 146 | 147 | train_fin_time = time.time() 148 | 149 | ''' 150 | # Uncomment for evaluating on training dataset 151 | train_statistics = evaluator.evaluate( 152 | data_type='train', 153 | metadata_dir=metadata_dir, 154 | submissions_dir=temp_submissions_dir, 155 | max_validate_num=max_validate_num) 156 | ''' 157 | 158 | if holdout_fold != 'none': 159 | validate_statistics = evaluator.evaluate( 160 | data_type='validate', 161 | metadata_dir=metadata_dir, 162 | submissions_dir=temp_submissions_dir, 163 | max_validate_num=max_validate_num) 164 | 165 | validate_statistics_container.append_and_dump( 166 | iteration, validate_statistics) 167 | 168 | train_time = train_fin_time - train_bgn_time 169 | validate_time = time.time() - train_fin_time 170 | 171 | logging.info( 172 | 'Train time: {:.3f} s, validate time: {:.3f} s' 173 | ''.format(train_time, validate_time)) 174 | 175 
| train_bgn_time = time.time() 176 | 177 | # Save model 178 | if iteration % 1000 == 0 and iteration > 0: 179 | 180 | checkpoint = { 181 | 'iteration': iteration, 182 | 'model': model.state_dict(), 183 | 'optimizer': optimizer.state_dict()} 184 | 185 | checkpoint_path = os.path.join( 186 | checkpoints_dir, '{}_iterations.pth'.format(iteration)) 187 | 188 | torch.save(checkpoint, checkpoint_path) 189 | logging.info('Model saved to {}'.format(checkpoint_path)) 190 | 191 | # Reduce learning rate 192 | if reduce_lr and iteration % 200 == 0 and iteration > 0: 193 | for param_group in optimizer.param_groups: 194 | param_group['lr'] *= 0.9 195 | 196 | # Move data to GPU 197 | for key in batch_data_dict.keys(): 198 | batch_data_dict[key] = move_data_to_gpu(batch_data_dict[key], cuda) 199 | 200 | # Train 201 | model.train() 202 | batch_output_dict = model(batch_data_dict['feature']) 203 | loss = event_spatial_loss(batch_output_dict, batch_data_dict) 204 | 205 | # Backward 206 | optimizer.zero_grad() 207 | loss.backward() 208 | optimizer.step() 209 | 210 | # Stop learning 211 | if iteration == 5000: 212 | break 213 | 214 | iteration += 1 215 | 216 | 217 | def inference_validation(args): 218 | '''Inference validation data. 219 | 220 | Args: 221 | dataset_dir: string, directory of dataset 222 | workspace: string, directory of workspace 223 | audio_type: 'foa' | 'mic' 224 | holdout_fold: '1' | '2' | '3' | '4' | 'none', where 'none' represents 225 | summary and print results of all folds 1, 2, 3 and 4. 226 | model_type: string, e.g. 
'Cnn_9layers_AvgPooling' 227 | iteration: int, load model of this iteration 228 | batch_size: int 229 | cuda: bool 230 | visualize: bool 231 | mini_data: bool, set True for debugging on a small part of data 232 | ''' 233 | 234 | # Arugments & parameters 235 | dataset_dir = args.dataset_dir 236 | workspace = args.workspace 237 | audio_type = args.audio_type 238 | holdout_fold = args.holdout_fold 239 | model_type = args.model_type 240 | iteration = args.iteration 241 | batch_size = args.batch_size 242 | cuda = args.cuda and torch.cuda.is_available() 243 | visualize = args.visualize 244 | mini_data = args.mini_data 245 | filename = args.filename 246 | 247 | mel_bins = config.mel_bins 248 | frames_per_second = config.frames_per_second 249 | classes_num = config.classes_num 250 | 251 | # Paths 252 | if mini_data: 253 | prefix = 'minidata_' 254 | else: 255 | prefix = '' 256 | 257 | metadata_dir = os.path.join(dataset_dir, 'metadata_dev') 258 | 259 | submissions_dir = os.path.join(workspace, 'submissions', filename, 260 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 'dev', 261 | frames_per_second, mel_bins), model_type, 'iteration={}'.format(iteration)) 262 | create_folder(submissions_dir) 263 | 264 | logs_dir = os.path.join(args.workspace, 'logs', filename, args.mode, 265 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 'dev', 266 | frames_per_second, mel_bins), 'holdout_fold={}'.format(holdout_fold), 267 | model_type) 268 | create_logging(logs_dir, filemode='w') 269 | logging.info(args) 270 | 271 | # Inference and calculate metrics for a fold 272 | if holdout_fold != 'none': 273 | 274 | features_dir = os.path.join(workspace, 'features', 275 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 276 | 'dev', frames_per_second, mel_bins)) 277 | 278 | scalar_path = os.path.join(workspace, 'scalars', 279 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 280 | 'dev', frames_per_second, mel_bins), 'scalar.h5') 281 | 
        checkpoint_path = os.path.join(workspace, 'checkpoints', filename,
            '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type,
            'dev', frames_per_second, mel_bins), model_type,
            'holdout_fold={}'.format(holdout_fold),
            '{}_iterations.pth'.format(iteration))

        # Load scalar
        scalar = load_scalar(scalar_path)

        # Load model. map_location='cpu' lets a checkpoint saved on GPU be
        # loaded on a CPU-only machine; the model is moved to GPU below.
        Model = eval(model_type)
        model = Model(classes_num)
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])

        if cuda:
            model.cuda()

        # Data generator
        data_generator = DataGenerator(
            features_dir=features_dir,
            scalar=scalar,
            batch_size=batch_size,
            holdout_fold=holdout_fold)

        # Evaluator
        evaluator = Evaluator(
            model=model,
            data_generator=data_generator,
            cuda=cuda)

        # Calculate metrics
        data_type = 'validate'

        evaluator.evaluate(
            data_type=data_type,
            metadata_dir=metadata_dir,
            submissions_dir=submissions_dir,
            max_validate_num=None)

        # Visualize reference and predicted events, elevation and azimuth
        if visualize:
            evaluator.visualize(data_type=data_type)

    # Calculate metrics for all 4 folds
    else:
        prediction_names = os.listdir(submissions_dir)
        prediction_paths = [os.path.join(submissions_dir, name) for \
            name in prediction_names]

        metrics = calculate_metrics(metadata_dir=metadata_dir,
            prediction_paths=prediction_paths)

        logging.info('Metrics of {} files: '.format(len(prediction_names)))
        for key in metrics.keys():
            logging.info('    {:<20} {:.3f}'.format(key + ' :', metrics[key]))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Example of parser. ')
    subparsers = parser.add_subparsers(dest='mode')

    # Train
    parser_train = subparsers.add_parser('train')
    parser_train.add_argument('--dataset_dir', type=str, required=True, help='Directory of dataset.')
    parser_train.add_argument('--workspace', type=str, required=True, help='Directory of your workspace.')
    parser_train.add_argument('--audio_type', type=str, choices=['foa', 'mic'], required=True)
    parser_train.add_argument('--holdout_fold', type=str, choices=['1', '2', '3', '4', 'none'], required=True,
        help='Holdout fold. Set to none to train on all data without validation.')
    parser_train.add_argument('--model_type', type=str, required=True, help='E.g., Cnn_9layers_AvgPooling.')
    parser_train.add_argument('--batch_size', type=int, required=True)
    parser_train.add_argument('--cuda', action='store_true', default=False)
    parser_train.add_argument('--mini_data', action='store_true', default=False, help='Set True for debugging on a small part of data.')

    # Inference validation data
    parser_inference_validation = subparsers.add_parser('inference_validation')
    parser_inference_validation.add_argument('--dataset_dir', type=str, required=True, help='Directory of dataset.')
    parser_inference_validation.add_argument('--workspace', type=str, required=True, help='Directory of your workspace.')
    parser_inference_validation.add_argument('--audio_type', type=str, choices=['foa', 'mic'], required=True)
    parser_inference_validation.add_argument('--holdout_fold', type=str, choices=['1', '2', '3', '4', 'none'], required=True)
    parser_inference_validation.add_argument('--model_type', type=str, required=True, help='E.g., Cnn_9layers_AvgPooling.')
    parser_inference_validation.add_argument('--iteration', type=int, required=True, help='Load model of this iteration.')
    parser_inference_validation.add_argument('--batch_size', type=int, required=True)
    parser_inference_validation.add_argument('--cuda', action='store_true', default=False)
    parser_inference_validation.add_argument('--visualize', action='store_true', default=False, help='Visualize log mel spectrogram, prediction and reference')
    parser_inference_validation.add_argument('--mini_data', action='store_true', default=False, help='Set True for debugging on a small part of data.')

    # Parse arguments
    args = parser.parse_args()
    args.filename = get_filename(__file__)

    if args.mode == 'train':
        train(args)

    elif args.mode == 'inference_validation':
        inference_validation(args)

    else:
        raise Exception('Error argument!')
--------------------------------------------------------------------------------
/pytorch/models.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_utils import interpolate


def init_layer(layer, nonlinearity='leaky_relu'):
    """Initialize a Linear or Convolutional layer. """
    nn.init.kaiming_uniform_(layer.weight, nonlinearity=nonlinearity)

    if hasattr(layer, 'bias'):
        if layer.bias is not None:
            layer.bias.data.fill_(0.)


def init_bn(bn):
    """Initialize a Batchnorm layer. """

    bn.bias.data.fill_(0.)
    bn.running_mean.data.fill_(0.)
    bn.weight.data.fill_(1.)
    bn.running_var.data.fill_(1.)


class Cnn_5layers_AvgPooling(nn.Module):

    def __init__(self, classes_num):
        super(Cnn_5layers_AvgPooling, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=4, out_channels=64,
                               kernel_size=(5, 5), stride=(1, 1),
                               padding=(2, 2), bias=False)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128,
                               kernel_size=(5, 5), stride=(1, 1),
                               padding=(2, 2), bias=False)

        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256,
                               kernel_size=(5, 5), stride=(1, 1),
                               padding=(2, 2), bias=False)

        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512,
                               kernel_size=(5, 5), stride=(1, 1),
                               padding=(2, 2), bias=False)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)

        self.event_fc = nn.Linear(512, classes_num, bias=True)
        self.elevation_fc = nn.Linear(512, classes_num, bias=True)
        self.azimuth_fc = nn.Linear(512, classes_num, bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_layer(self.conv3)
        init_layer(self.conv4)
        init_layer(self.event_fc)
        init_layer(self.elevation_fc)
        init_layer(self.azimuth_fc)

        init_bn(self.bn1)
        init_bn(self.bn2)
        init_bn(self.bn3)
        init_bn(self.bn4)

    def forward(self, input):
        '''
        Input: (channels_num, batch_size, time_steps, freq_bins)'''

        interpolate_ratio = 8   # Three 2x poolings along time: 2 ** 3 = 8

        x = input.transpose(0, 1)
        '''(batch_size, channels_num, time_steps, freq_bins)'''

        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.avg_pool2d(x, kernel_size=(2, 2))

        x = F.relu_(self.bn2(self.conv2(x)))
        x = F.avg_pool2d(x, kernel_size=(2, 2))

        x = F.relu_(self.bn3(self.conv3(x)))
        x = F.avg_pool2d(x, kernel_size=(2, 2))

        x = F.relu_(self.bn4(self.conv4(x)))
        x = F.avg_pool2d(x, kernel_size=(1, 1))
        '''(batch_size, feature_maps, time_steps, freq_bins)'''

        x = torch.mean(x, dim=3)    # (batch_size, feature_maps, time_steps)
        x = x.transpose(1, 2)   # (batch_size, time_steps, feature_maps)

        event_output = torch.sigmoid(self.event_fc(x))  # (batch_size, time_steps, classes_num)
        elevation_output = self.elevation_fc(x)     # (batch_size, time_steps, classes_num)
        azimuth_output = self.azimuth_fc(x)     # (batch_size, time_steps, classes_num)

        # Interpolate
        event_output = interpolate(event_output, interpolate_ratio)
        elevation_output = interpolate(elevation_output, interpolate_ratio)
        azimuth_output = interpolate(azimuth_output, interpolate_ratio)

        output_dict = {
            'event': event_output,
            'elevation': elevation_output,
            'azimuth': azimuth_output}

        return output_dict


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):

        super(ConvBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)

        self.conv2 = nn.Conv2d(in_channels=out_channels,
                               out_channels=out_channels,
                               kernel_size=(3, 3), stride=(1, 1),
                               padding=(1, 1), bias=False)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weights()

    def init_weights(self):

        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):

        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == 'max':
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == 'avg':
            x = F.avg_pool2d(x, kernel_size=pool_size)
        else:
            raise Exception('Incorrect argument!')

        return x


class Cnn_9layers_AvgPooling(nn.Module):
    def __init__(self, classes_num):

        super(Cnn_9layers_AvgPooling, self).__init__()

        self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)

        self.event_fc = nn.Linear(512, classes_num, bias=True)
        self.elevation_fc = nn.Linear(512, classes_num, bias=True)
        self.azimuth_fc = nn.Linear(512, classes_num, bias=True)

        self.init_weights()

    def init_weights(self):

        init_layer(self.event_fc)
        init_layer(self.elevation_fc)
        init_layer(self.azimuth_fc)

    def forward(self, input):
        '''
        Input: (channels_num, batch_size, time_steps, freq_bins)'''

        interpolate_ratio = 8   # Three 2x poolings along time: 2 ** 3 = 8

        x = input.transpose(0, 1)
        '''(batch_size, channels_num, time_steps, freq_bins)'''

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = self.conv_block4(x, pool_size=(1, 1), pool_type='avg')

        x = torch.mean(x, dim=3)    # (batch_size, feature_maps, time_steps)
        x = x.transpose(1, 2)   # (batch_size, time_steps, feature_maps)

        event_output = torch.sigmoid(self.event_fc(x))  # (batch_size, time_steps, classes_num)
        elevation_output = self.elevation_fc(x)     # (batch_size, time_steps, classes_num)
        azimuth_output = self.azimuth_fc(x)     # (batch_size, time_steps, classes_num)

        # Interpolate
        event_output = interpolate(event_output, interpolate_ratio)
        elevation_output = interpolate(elevation_output, interpolate_ratio)
        azimuth_output = interpolate(azimuth_output, interpolate_ratio)

        output_dict = {
            'event': event_output,
            'elevation': elevation_output,
            'azimuth': azimuth_output}

        return output_dict


class Cnn_9layers_MaxPooling(nn.Module):
    def __init__(self, classes_num):

        super(Cnn_9layers_MaxPooling, self).__init__()

        self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)

        self.event_fc = nn.Linear(512, classes_num, bias=True)
        self.elevation_fc = nn.Linear(512, classes_num, bias=True)
        self.azimuth_fc = nn.Linear(512, classes_num, bias=True)

        self.init_weights()

    def init_weights(self):

        init_layer(self.event_fc)
        init_layer(self.elevation_fc)
        init_layer(self.azimuth_fc)

    def forward(self, input):
        '''
        Input: (channels_num, batch_size, time_steps, freq_bins)'''

        interpolate_ratio = 8   # Three 2x poolings along time: 2 ** 3 = 8

        x = input.transpose(0, 1)
        '''(batch_size, channels_num, time_steps, freq_bins)'''

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='max')
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='max')
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='max')
        x = self.conv_block4(x, pool_size=(1, 1), pool_type='max')

        x = torch.mean(x, dim=3)    # (batch_size, feature_maps, time_steps)
        x = x.transpose(1, 2)   # (batch_size, time_steps, feature_maps)

        event_output = torch.sigmoid(self.event_fc(x))  # (batch_size, time_steps, classes_num)
        elevation_output = self.elevation_fc(x)     # (batch_size, time_steps, classes_num)
        azimuth_output = self.azimuth_fc(x)     # (batch_size, time_steps, classes_num)

        # Interpolate
        event_output = interpolate(event_output, interpolate_ratio)
        elevation_output = interpolate(elevation_output, interpolate_ratio)
        azimuth_output = interpolate(azimuth_output, interpolate_ratio)

        output_dict = {
            'event': event_output,
            'elevation': elevation_output,
            'azimuth': azimuth_output}

        return output_dict


class Cnn_13layers_AvgPooling(nn.Module):
    def __init__(self, classes_num):

        super(Cnn_13layers_AvgPooling, self).__init__()

        self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.event_fc = nn.Linear(2048, classes_num, bias=True)
        self.elevation_fc = nn.Linear(2048, classes_num, bias=True)
        self.azimuth_fc = nn.Linear(2048, classes_num, bias=True)

        self.init_weights()

    def init_weights(self):

        init_layer(self.event_fc)
        init_layer(self.elevation_fc)
        init_layer(self.azimuth_fc)

    def forward(self, input):
        '''
        Input: (channels_num, batch_size, time_steps, freq_bins)'''

        interpolate_ratio = 32  # Five 2x poolings along time: 2 ** 5 = 32

        x = input.transpose(0, 1)
        '''(batch_size, channels_num, time_steps, freq_bins)'''

        x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
        x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
        x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
        x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
        x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
        x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')

        x = torch.mean(x, dim=3)    # (batch_size, feature_maps, time_steps)
        x = x.transpose(1, 2)   # (batch_size, time_steps, feature_maps)

        event_output = torch.sigmoid(self.event_fc(x))  # (batch_size, time_steps, classes_num)
        elevation_output = self.elevation_fc(x)     # (batch_size, time_steps, classes_num)
        azimuth_output = self.azimuth_fc(x)     # (batch_size, time_steps, classes_num)

        # Interpolate
        event_output = interpolate(event_output, interpolate_ratio)
        elevation_output = interpolate(elevation_output, interpolate_ratio)
        azimuth_output = interpolate(azimuth_output, interpolate_ratio)

        output_dict = {
            'event': event_output,
            'elevation': elevation_output,
            'azimuth': azimuth_output}

        return output_dict
--------------------------------------------------------------------------------
/pytorch/pytorch_utils.py:
--------------------------------------------------------------------------------
import torch


def move_data_to_gpu(x, cuda):
    if 'float' in str(x.dtype):
        x = torch.Tensor(x)
    elif 'int' in str(x.dtype):
        x = torch.LongTensor(x)
    else:
        raise Exception("Error!")

    if cuda:
        x = x.cuda()

    return x


def interpolate(x, ratio):
    '''Interpolate the prediction to have the same time_steps as the target.
    The time_steps mismatch is caused by the pooling layers of the CNN.

    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, ratio to upsample
    '''
    (batch_size, time_steps, classes_num) = x.shape
    upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1)
    upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num)
    return upsampled


def forward(model, generate_func, cuda, return_input=False,
    return_target=False):
    '''Forward data to model in mini-batch.

    Args:
      model: object
      generate_func: function
      cuda: bool
      return_input: bool
      return_target: bool

    Returns:
      list_dict, e.g.:
        [{'name': 'split1_ir0_ov1_7',
          'output_event': (1, frames_num, classes_num),
          'output_elevation': (1, frames_num, classes_num),
          'output_azimuth': (1, frames_num, classes_num),
          ...
          },
         ...]
    '''

    list_dict = []

    # Evaluate on mini-batch
    for (n, single_data_dict) in enumerate(generate_func):

        # Predict
        batch_feature = move_data_to_gpu(single_data_dict['feature'], cuda)

        with torch.no_grad():
            model.eval()
            batch_output_dict = model(batch_feature)

        output_dict = {
            'name': single_data_dict['name'],
            'output_event': batch_output_dict['event'].data.cpu().numpy(),
            'output_elevation': batch_output_dict['elevation'].data.cpu().numpy(),
            'output_azimuth': batch_output_dict['azimuth'].data.cpu().numpy()}

        if return_input:
            output_dict['feature'] = single_data_dict['feature']

        if return_target:
            output_dict['target_event'] = single_data_dict['event']
            output_dict['target_elevation'] = single_data_dict['elevation']
            output_dict['target_azimuth'] = single_data_dict['azimuth']

        list_dict.append(output_dict)

    return list_dict
--------------------------------------------------------------------------------
/runme.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# You need to modify this path to your downloaded dataset directory
DATASET_DIR='/vol/vssp/cvpnobackup/scratch_4weeks/qk00006/dcase2019/task3/dataset_root'

# You need to modify this path to your workspace to store features, models, etc.
6 | WORKSPACE='/vol/vssp/msos/qk/workspaces/dcase2019_task3' 7 | 8 | # Hyper-parameters 9 | GPU_ID=1 10 | DATA_TYPE='development' # 'development' | 'evaluation' 11 | AUDIO_TYPE='foa' # 'foa' | 'mic' 12 | MODEL_TYPE='Cnn_9layers_AvgPooling' 13 | BATCH_SIZE=32 14 | 15 | # Calculate feature 16 | python utils/features.py calculate_feature_for_each_audio_file --dataset_dir=$DATASET_DIR --workspace=$WORKSPACE --data_type=$DATA_TYPE --audio_type=$AUDIO_TYPE 17 | 18 | # Calculate scalar 19 | python utils/features.py calculate_scalar --workspace=$WORKSPACE --data_type=$DATA_TYPE --audio_type=$AUDIO_TYPE 20 | 21 | ############ Train and validate system on development dataset ############ 22 | for HOLDOUT_FOLD in '1' '2' '3' '4' 23 | do 24 | echo 'Holdout fold: '$HOLDOUT_FOLD 25 | 26 | # Train 27 | CUDA_VISIBLE_DEVICES=$GPU_ID python pytorch/main.py train --dataset_dir=$DATASET_DIR --workspace=$WORKSPACE --audio_type=$AUDIO_TYPE --holdout_fold=$HOLDOUT_FOLD --model_type=$MODEL_TYPE --batch_size=$BATCH_SIZE --cuda 28 | 29 | # Validate 30 | CUDA_VISIBLE_DEVICES=$GPU_ID python pytorch/main.py inference_validation --dataset_dir=$DATASET_DIR --workspace=$WORKSPACE --audio_type=$AUDIO_TYPE --holdout_fold=$HOLDOUT_FOLD --model_type=$MODEL_TYPE --iteration=5000 --batch_size=$BATCH_SIZE --cuda 31 | 32 | HOLDOUT_FOLD=$[$HOLDOUT_FOLD+1] 33 | done 34 | 35 | # Calculate metrics on all cross-validation folds 36 | HOLDOUT_FOLD=-1 37 | CUDA_VISIBLE_DEVICES=$GPU_ID python pytorch/main.py inference_validation --dataset_dir=$DATASET_DIR --workspace=$WORKSPACE --audio_type=$AUDIO_TYPE --holdout_fold=$HOLDOUT_FOLD --model_type=$MODEL_TYPE --iteration=5000 --batch_size=$BATCH_SIZE --cuda 38 | 39 | # Plot statistics 40 | python utils/plot_results.py --dataset_dir=$DATASET_DIR --workspace=$WORKSPACE --audio_type='foa' 41 | 42 | ############ END ############ 43 | -------------------------------------------------------------------------------- /utils/config.py: 
-------------------------------------------------------------------------------- 1 | sample_rate = 32000 2 | window_size = 1024 3 | hop_size = 500 # So that there are 64 frames per second 4 | mel_bins = 64 5 | fmin = 50 # Hz 6 | fmax = 14000 # Hz 7 | 8 | frames_per_second = sample_rate // hop_size 9 | time_steps = frames_per_second * 10 # 10-second log mel spectrogram as input 10 | submission_frames_per_second = 50 # DCASE2019 Task3 submission format 11 | 12 | # The label configuration is the same as https://github.com/sharathadavanne/seld-dcase2019 13 | labels = ['knock', 'drawer', 'clearthroat', 'phone', 'keysDrop', 'speech', 14 | 'keyboard', 'pageturn', 'cough', 'doorslam', 'laughter'] 15 | 16 | classes_num = len(labels) 17 | lb_to_idx = {lb: idx for idx, lb in enumerate(labels)} 18 | idx_to_lb = {idx: lb for idx, lb in enumerate(labels)} -------------------------------------------------------------------------------- /utils/data_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import csv 4 | import time 5 | import logging 6 | import os 7 | import glob 8 | import matplotlib.pyplot as plt 9 | import logging 10 | 11 | from utilities import scale 12 | import config 13 | 14 | 15 | class DataGenerator(object): 16 | 17 | def __init__(self, features_dir, scalar, batch_size, holdout_fold, seed=1234): 18 | '''Data generator for training and validation. 
19 | 20 | Args: 21 | features_dir: string, directory of features 22 | scalar: object, containing mean and std value 23 | batch_size: int 24 | holdout_fold: '1' | '2' | '3' | '4' | 'none', where 'none' indicates 25 | using all data without validation for training 26 | seed: int, random seed 27 | ''' 28 | 29 | self.scalar = scalar 30 | self.batch_size = batch_size 31 | self.random_state = np.random.RandomState(seed) 32 | 33 | self.frames_per_second = config.frames_per_second 34 | self.classes_num = config.classes_num 35 | self.lb_to_idx = config.lb_to_idx 36 | self.time_steps = config.time_steps 37 | 38 | # Load data 39 | load_time = time.time() 40 | 41 | feature_names = sorted(os.listdir(features_dir)) 42 | 43 | self.train_feature_names = [name for name in feature_names \ 44 | if 'split{}'.format(holdout_fold) not in name] 45 | 46 | self.validate_feature_names = [name for name in feature_names \ 47 | if 'split{}'.format(holdout_fold) in name] 48 | 49 | self.train_features_list = [] 50 | self.train_event_matrix_list = [] 51 | self.train_elevation_matrix_list = [] 52 | self.train_azimuth_matrix_list = [] 53 | self.train_index_array_list = [] 54 | frame_index = 0 55 | 56 | # Load training feature and targets 57 | for feature_name in self.train_feature_names: 58 | feature_path = os.path.join(features_dir, feature_name) 59 | 60 | (feature, event_matrix, elevation_matrix, azimuth_matrix) = \ 61 | self.load_hdf5(feature_path) 62 | 63 | frames_num = feature.shape[1] 64 | '''Number of frames of the log mel spectrogram of an audio 65 | recording. 
May be different from file to file''' 66 | 67 | index_array = np.arange(frame_index, frame_index + frames_num - self.time_steps) 68 | frame_index += frames_num 69 | 70 | # Append data 71 | self.train_features_list.append(feature) 72 | self.train_event_matrix_list.append(event_matrix) 73 | self.train_elevation_matrix_list.append(elevation_matrix) 74 | self.train_azimuth_matrix_list.append(azimuth_matrix) 75 | self.train_index_array_list.append(index_array) 76 | 77 | self.train_features = np.concatenate(self.train_features_list, axis=1) 78 | self.train_event_matrix = np.concatenate(self.train_event_matrix_list, axis=0) 79 | self.train_elevation_matrix = np.concatenate(self.train_elevation_matrix_list, axis=0) 80 | self.train_azimuth_matrix = np.concatenate(self.train_azimuth_matrix_list, axis=0) 81 | self.train_index_array = np.concatenate(self.train_index_array_list, axis=0) 82 | 83 | # Load validation feature and targets 84 | self.validate_features_list = [] 85 | self.validate_event_matrix_list = [] 86 | self.validate_elevation_matrix_list = [] 87 | self.validate_azimuth_matrix_list = [] 88 | 89 | for feature_name in self.validate_feature_names: 90 | feature_path = os.path.join(features_dir, feature_name) 91 | 92 | (feature, event_matrix, elevation_matrix, azimuth_matrix) = \ 93 | self.load_hdf5(feature_path) 94 | 95 | self.validate_features_list.append(feature) 96 | self.validate_event_matrix_list.append(event_matrix) 97 | self.validate_elevation_matrix_list.append(elevation_matrix) 98 | self.validate_azimuth_matrix_list.append(azimuth_matrix) 99 | 100 | logging.info('Load data time: {:.3f} s'.format(time.time() - load_time)) 101 | logging.info('Training audio num: {}'.format(len(self.train_feature_names))) 102 | logging.info('Validation audio num: {}'.format(len(self.validate_feature_names))) 103 | 104 | self.random_state.shuffle(self.train_index_array) 105 | self.pointer = 0 106 | 107 | def load_hdf5(self, feature_path): 108 | '''Load hdf5. 
109 | 110 | Args: 111 | feature_path: string 112 | 113 | Returns: 114 | feature: (channels_num, frames_num, freq_bins) 115 | eevnt_matrix: (frames_num, classes_num) 116 | elevation_matrix: (frames_num, classes_num) 117 | azimuth_matrix: (frames_num, classes_num) 118 | ''' 119 | 120 | with h5py.File(feature_path, 'r') as hf: 121 | feature = hf['feature'][:] 122 | events = [e.decode() for e in hf['target']['event'][:]] 123 | start_times = hf['target']['start_time'][:] 124 | end_times = hf['target']['end_time'][:] 125 | elevations = hf['target']['elevation'][:] 126 | azimuths = hf['target']['azimuth'][:] 127 | distances = hf['target']['distance'][:] 128 | 129 | frames_num = feature.shape[1] 130 | 131 | # Researve space data 132 | event_matrix = np.zeros((frames_num, self.classes_num)) 133 | elevation_matrix = np.zeros((frames_num, self.classes_num)) 134 | azimuth_matrix = np.zeros((frames_num, self.classes_num)) 135 | 136 | for n in range(len(events)): 137 | class_id = self.lb_to_idx[events[n]] 138 | start_frame = int(round(start_times[n] * self.frames_per_second)) 139 | end_frame = int(round(end_times[n] * self.frames_per_second)) + 1 140 | 141 | event_matrix[start_frame : end_frame, class_id] = 1 142 | elevation_matrix[start_frame : end_frame, class_id] = elevations[n] 143 | azimuth_matrix[start_frame : end_frame, class_id] = azimuths[n] 144 | 145 | return feature, event_matrix, elevation_matrix, azimuth_matrix 146 | 147 | def generate_train(self): 148 | '''Generate mini-batch data for training. 
149 | 150 | Returns: 151 | batch_data_dict: dict containing feature, event, elevation and azimuth 152 | ''' 153 | 154 | while True: 155 | # Reset pointer 156 | if self.pointer >= len(self.train_index_array): 157 | self.pointer = 0 158 | self.random_state.shuffle(self.train_index_array) 159 | 160 | # Get batch indexes 161 | batch_indexes = self.train_index_array[ 162 | self.pointer: self.pointer + self.batch_size] 163 | 164 | data_indexes = batch_indexes[:, None] + np.arange(self.time_steps) 165 | 166 | self.pointer += self.batch_size 167 | 168 | batch_feature = self.train_features[:, data_indexes] 169 | batch_event_matrix = self.train_event_matrix[data_indexes] 170 | batch_elevation_matrix = self.train_elevation_matrix[data_indexes] 171 | batch_azimuth_matrix = self.train_azimuth_matrix[data_indexes] 172 | 173 | # Transform data 174 | batch_feature = self.transform(batch_feature) 175 | 176 | batch_data_dict = { 177 | 'feature': batch_feature, 178 | 'event': batch_event_matrix, 179 | 'elevation': batch_elevation_matrix, 180 | 'azimuth': batch_azimuth_matrix} 181 | 182 | yield batch_data_dict 183 | 184 | def generate_validate(self, data_type, max_validate_num=None): 185 | '''Generate feature and targets of a single audio file. 
186 | 187 | Args: 188 | data_type: 'train' | 'validate' 189 | max_validate_num: None | int, maximum number of audio files to 190 | iterate over, to speed up evaluation 191 | 192 | Returns: 193 | batch_data_dict: dict containing feature, event, elevation and azimuth 194 | ''' 195 | 196 | if data_type == 'train': 197 | feature_names = self.train_feature_names 198 | features_list = self.train_features_list 199 | event_matrix_list = self.train_event_matrix_list 200 | elevation_matrix_list = self.train_elevation_matrix_list 201 | azimuth_matrix_list = self.train_azimuth_matrix_list 202 | 203 | elif data_type == 'validate': 204 | feature_names = self.validate_feature_names 205 | features_list = self.validate_features_list 206 | event_matrix_list = self.validate_event_matrix_list 207 | elevation_matrix_list = self.validate_elevation_matrix_list 208 | azimuth_matrix_list = self.validate_azimuth_matrix_list 209 | 210 | else: 211 | raise Exception('Incorrect argument!') 212 | 213 | validate_num = len(feature_names) 214 | 215 | for n in range(validate_num): 216 | if n == max_validate_num: 217 | break 218 | 219 | name = os.path.splitext(feature_names[n])[0] 220 | feature = features_list[n] 221 | event_matrix = event_matrix_list[n] 222 | elevation_matrix = elevation_matrix_list[n] 223 | azimuth_matrix = azimuth_matrix_list[n] 224 | 225 | feature = self.transform(feature) 226 | 227 | batch_data_dict = { 228 | 'name': name, 229 | 'feature': feature[:, None, :, :], # (channels_num, batch_size=1, frames_num, mel_bins) 230 | 'event': event_matrix[None, :, :], # (batch_size=1, frames_num, classes_num) 231 | 'elevation': elevation_matrix[None, :, :], # (batch_size=1, frames_num, classes_num) 232 | 'azimuth': azimuth_matrix[None, :, :] # (batch_size=1, frames_num, classes_num) 233 | } 234 | '''The None above indicates using an entire audio recording as 235 | input and batch_size=1 in inference''' 236 | 237 | yield batch_data_dict 238 | 239 | def transform(self, x): 240 | return scale(x, self.scalar['mean'], 
self.scalar['std']) 241 | -------------------------------------------------------------------------------- /utils/features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(1, os.path.join(sys.path[0], 'utils')) 4 | import numpy as np 5 | import pandas as pd 6 | import argparse 7 | import h5py 8 | import librosa 9 | from scipy import signal 10 | import matplotlib.pyplot as plt 11 | import time 12 | import csv 13 | import random 14 | 15 | from utilities import (read_multichannel_audio, create_folder, 16 | calculate_scalar_of_tensor) 17 | import config 18 | 19 | 20 | class LogMelExtractor(object): 21 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin, fmax): 22 | '''Log mel feature extractor. 23 | 24 | Args: 25 | sample_rate: int 26 | window_size: int 27 | hop_size: int 28 | mel_bins: int 29 | fmin: int, minimum frequency of mel filter banks 30 | fmax: int, maximum frequency of mel filter banks 31 | ''' 32 | 33 | self.window_size = window_size 34 | self.hop_size = hop_size 35 | self.window_func = np.hanning(window_size) 36 | 37 | self.melW = librosa.filters.mel( 38 | sr=sample_rate, 39 | n_fft=window_size, 40 | n_mels=mel_bins, 41 | fmin=fmin, 42 | fmax=fmax).T 43 | '''(n_fft // 2 + 1, mel_bins)''' 44 | 45 | def transform_multichannel(self, multichannel_audio): 46 | '''Extract feature of a multichannel audio file. 47 | 48 | Args: 49 | multichannel_audio: (samples, channels_num) 50 | 51 | Returns: 52 | feature: (channels_num, frames_num, freq_bins) 53 | ''' 54 | 55 | (samples, channels_num) = multichannel_audio.shape 56 | 57 | feature = np.array([self.transform_singlechannel( 58 | multichannel_audio[:, m]) for m in range(channels_num)]) 59 | 60 | return feature 61 | 62 | def transform_singlechannel(self, audio): 63 | '''Extract feature of a singlechannel audio file. 
64 | 65 | Args: 66 | audio: (samples,) 67 | 68 | Returns: 69 | feature: (frames_num, freq_bins) 70 | ''' 71 | 72 | window_size = self.window_size 73 | hop_size = self.hop_size 74 | window_func = self.window_func 75 | 76 | # Compute short-time Fourier transform 77 | stft_matrix = librosa.core.stft( 78 | y=audio, 79 | n_fft=window_size, 80 | hop_length=hop_size, 81 | window=window_func, 82 | center=True, 83 | dtype=np.complex64, 84 | pad_mode='reflect').T 85 | '''(N, n_fft // 2 + 1)''' 86 | 87 | # Mel spectrogram 88 | mel_spectrogram = np.dot(np.abs(stft_matrix) ** 2, self.melW) 89 | 90 | # Log mel spectrogram 91 | logmel_spectrogram = librosa.core.power_to_db( 92 | mel_spectrogram, ref=1.0, amin=1e-10, 93 | top_db=None) 94 | 95 | logmel_spectrogram = logmel_spectrogram.astype(np.float32) 96 | 97 | return logmel_spectrogram 98 | 99 | 100 | def calculate_feature_for_each_audio_file(args): 101 | '''Calculate feature for each audio file and write out to hdf5. 102 | 103 | Args: 104 | dataset_dir: string 105 | workspace: string 106 | data_type: 'development' | 'evaluation' 107 | audio_type: 'foa' | 'mic' 108 | mini_data: bool, set True for debugging on a small part of data 109 | ''' 110 | 111 | # Arguments & parameters 112 | dataset_dir = args.dataset_dir 113 | workspace = args.workspace 114 | data_type = args.data_type 115 | audio_type = args.audio_type 116 | mini_data = args.mini_data 117 | 118 | sample_rate = config.sample_rate 119 | window_size = config.window_size 120 | hop_size = config.hop_size 121 | mel_bins = config.mel_bins 122 | fmin = config.fmin 123 | fmax = config.fmax 124 | frames_per_second = config.frames_per_second 125 | 126 | # Paths 127 | if data_type == 'development': 128 | data_type = 'dev' 129 | 130 | elif data_type == 'evaluation': 131 | data_type = 'eva' 132 | raise Exception('Todo after evaluation data released. 
') 133 | 134 | if mini_data: 135 | prefix = 'minidata_' 136 | else: 137 | prefix = '' 138 | 139 | metas_dir = os.path.join(dataset_dir, 'metadata_{}'.format(data_type)) 140 | audios_dir = os.path.join(dataset_dir, '{}_{}'.format(audio_type, data_type)) 141 | 142 | features_dir = os.path.join(workspace, 'features', 143 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 144 | data_type, frames_per_second, mel_bins)) 145 | 146 | create_folder(features_dir) 147 | 148 | # Feature extractor 149 | feature_extractor = LogMelExtractor( 150 | sample_rate=sample_rate, 151 | window_size=window_size, 152 | hop_size=hop_size, 153 | mel_bins=mel_bins, 154 | fmin=fmin, 155 | fmax=fmax) 156 | 157 | # Extract features and targets 158 | meta_names = sorted(os.listdir(metas_dir)) 159 | 160 | if mini_data: 161 | mini_num = 10 162 | random_state = np.random.RandomState(1234) 163 | random_state.shuffle(meta_names) 164 | meta_names = meta_names[0 : mini_num] 165 | 166 | print('Extracting features of all audio files ...') 167 | extract_time = time.time() 168 | 169 | for (n, meta_name) in enumerate(meta_names): 170 | meta_path = os.path.join(metas_dir, meta_name) 171 | bare_name = os.path.splitext(meta_name)[0] 172 | audio_path = os.path.join(audios_dir, '{}.wav'.format(bare_name)) 173 | feature_path = os.path.join(features_dir, '{}.h5'.format(bare_name)) 174 | 175 | df = pd.read_csv(meta_path, sep=',') 176 | event_array = df['sound_event_recording'].values 177 | start_time_array = df['start_time'].values 178 | end_time_array = df['end_time'].values 179 | elevation_array = df['ele'].values 180 | azimuth_array = df['azi'].values 181 | distance_array = df['dist'].values 182 | 183 | # Read audio 184 | (multichannel_audio, _) = read_multichannel_audio( 185 | audio_path=audio_path, 186 | target_fs=sample_rate) 187 | 188 | # Extract feature 189 | feature = feature_extractor.transform_multichannel(multichannel_audio) 190 | 191 | with h5py.File(feature_path, 'w') as hf: 192 | 
hf.create_dataset('feature', data=feature, dtype=np.float32) 193 | 194 | hf.create_group('target') 195 | hf['target'].create_dataset('event', data=[e.encode() for e in event_array], dtype='S20') 196 | hf['target'].create_dataset('start_time', data=start_time_array, dtype=np.float32) 197 | hf['target'].create_dataset('end_time', data=end_time_array, dtype=np.float32) 198 | hf['target'].create_dataset('elevation', data=elevation_array, dtype=np.int32) 199 | hf['target'].create_dataset('azimuth', data=azimuth_array, dtype=np.int32) 200 | hf['target'].create_dataset('distance', data=distance_array, dtype=np.int32) 201 | 202 | print(n, feature_path, feature.shape) 203 | 204 | print('Extract features finished! {:.3f} s'.format(time.time() - extract_time)) 205 | 206 | 207 | def calculate_scalar(args): 208 | '''Calculate and write out the mean and standard deviation (scalar) of development features. 209 | 210 | Args: 211 | args: argparse.Namespace with the fields below 212 | workspace: string 213 | audio_type: 'foa' | 'mic' 214 | mini_data: bool, set True for debugging on a small part of data 215 | ''' 216 | 217 | # Arguments & parameters 218 | workspace = args.workspace 219 | audio_type = args.audio_type 220 | mini_data = args.mini_data 221 | data_type = 'dev' 222 | 223 | mel_bins = config.mel_bins 224 | frames_per_second = config.frames_per_second 225 | 226 | # Paths 227 | if mini_data: 228 | prefix = 'minidata_' 229 | else: 230 | prefix = '' 231 | 232 | features_dir = os.path.join(workspace, 'features', 233 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 234 | data_type, frames_per_second, mel_bins)) 235 | 236 | scalar_path = os.path.join(workspace, 'scalars', 237 | '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, audio_type, 238 | data_type, frames_per_second, mel_bins), 'scalar.h5') 239 | 240 | create_folder(os.path.dirname(scalar_path)) 241 | 242 | # Load data 243 | load_time = time.time() 244 | feature_names = os.listdir(features_dir) 245 | all_features = [] 246 | 247 | for feature_name in feature_names: 248 | 
feature_path = os.path.join(features_dir, feature_name) 249 | 250 | with h5py.File(feature_path, 'r') as hf: 251 | feature = hf['feature'][:] 252 | all_features.append(feature) 253 | 254 | print('Load feature time: {:.3f} s'.format(time.time() - load_time)) 255 | 256 | # Calculate scalar 257 | all_features = np.concatenate(all_features, axis=1) 258 | (mean, std) = calculate_scalar_of_tensor(all_features) 259 | 260 | with h5py.File(scalar_path, 'w') as hf: 261 | hf.create_dataset('mean', data=mean, dtype=np.float32) 262 | hf.create_dataset('std', data=std, dtype=np.float32) 263 | 264 | print('All features: {}'.format(all_features.shape)) 265 | print('mean: {}'.format(mean)) 266 | print('std: {}'.format(std)) 267 | print('Write out scalar to {}'.format(scalar_path)) 268 | 269 | 270 | if __name__ == '__main__': 271 | parser = argparse.ArgumentParser(description='') 272 | subparsers = parser.add_subparsers(dest='mode') 273 | 274 | # Calculate feature for each audio file 275 | parser_logmel = subparsers.add_parser('calculate_feature_for_each_audio_file') 276 | parser_logmel.add_argument('--dataset_dir', type=str, required=True, help='Directory of dataset.') 277 | parser_logmel.add_argument('--workspace', type=str, required=True, help='Directory of your workspace.') 278 | parser_logmel.add_argument('--data_type', type=str, required=True, choices=['development', 'evaluation']) 279 | parser_logmel.add_argument('--audio_type', type=str, required=True, choices=['foa', 'mic']) 280 | parser_logmel.add_argument('--mini_data', action='store_true', default=False, help='Set True for debugging on a small part of data.') 281 | 282 | # Calculate scalar 283 | parser_scalar = subparsers.add_parser('calculate_scalar') 284 | parser_scalar.add_argument('--workspace', type=str, required=True, help='Directory of your workspace.') 285 | parser_scalar.add_argument('--data_type', type=str, required=True, choices=['development'], help='Scalar is calculated on development set.') 286 | 
parser_scalar.add_argument('--audio_type', type=str, required=True, choices=['foa', 'mic']) 287 | parser_scalar.add_argument('--mini_data', action='store_true', default=False, help='Set True for debugging on a small part of data.') 288 | 289 | # Parse arguments 290 | args = parser.parse_args() 291 | 292 | if args.mode == 'calculate_feature_for_each_audio_file': 293 | calculate_feature_for_each_audio_file(args) 294 | 295 | elif args.mode == 'calculate_scalar': 296 | calculate_scalar(args) 297 | 298 | else: 299 | raise Exception('Incorrect arguments!') -------------------------------------------------------------------------------- /utils/plot_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import matplotlib.pyplot as plt 4 | import _pickle as cPickle 5 | import numpy as np 6 | 7 | import config 8 | 9 | 10 | def plot_results(args): 11 | 12 | # Arguments & parameters 13 | dataset_dir = args.dataset_dir 14 | workspace = args.workspace 15 | audio_type = args.audio_type 16 | 17 | filename = 'main' 18 | prefix = '' 19 | frames_per_second = config.frames_per_second 20 | mel_bins = config.mel_bins 21 | holdout_fold = 1 22 | max_plot_iteration = 5000 23 | 24 | iterations = np.arange(0, max_plot_iteration, 200) 25 | 26 | def _load_stat(model_type): 27 | validate_statistics_path = os.path.join(workspace, 'statistics', 28 | filename, '{}{}_{}_logmel_{}frames_{}melbins'.format(prefix, 29 | audio_type, 'dev', frames_per_second, mel_bins), 30 | 'holdout_fold={}'.format(holdout_fold), model_type, 31 | 'validate_statistics.pickle') 32 | 33 | statistics_list = cPickle.load(open(validate_statistics_path, 'rb')) 34 | 35 | sed_error_rate = np.array([statistics['sed_error_rate'] 36 | for statistics in statistics_list]) 37 | 38 | sed_f1_score = np.array([statistics['sed_f1_score'] 39 | for statistics in statistics_list]) 40 | 41 | doa_error = np.array([statistics['doa_error'] 42 | for statistics in 
statistics_list]) 43 | 44 | doa_frame_recall = np.array([statistics['doa_frame_recall'] 45 | for statistics in statistics_list]) 46 | 47 | seld_score = np.array([statistics['seld_score'] 48 | for statistics in statistics_list]) 49 | 50 | legend = '{}'.format(model_type) 51 | 52 | results = {'sed_error_rate': sed_error_rate, 53 | 'sed_f1_score': sed_f1_score, 'doa_error': doa_error, 54 | 'doa_frame_recall': doa_frame_recall, 'seld_score': seld_score, 55 | 'legend': legend} 56 | 57 | print('Model type: {}'.format(model_type)) 58 | print(' sed_error_rate: {:.3f}'.format(sed_error_rate[-1])) 59 | print(' sed_f1_score: {:.3f}'.format(sed_f1_score[-1])) 60 | print(' doa_error: {:.3f}'.format(doa_error[-1])) 61 | print(' doa_frame_recall: {:.3f}'.format(doa_frame_recall[-1])) 62 | print(' seld_score: {:.3f}'.format(seld_score[-1])) 63 | 64 | return results 65 | 66 | measure_keys = ['sed_error_rate', 'sed_f1_score', 'doa_error', 'doa_frame_recall'] 67 | 68 | fig, axs = plt.subplots(2, 2, figsize=(12, 8)) 69 | 70 | results_dict = {} 71 | results_dict['Cnn_5layers_AvgPooling'] = _load_stat('Cnn_5layers_AvgPooling') 72 | results_dict['Cnn_9layers_AvgPooling'] = _load_stat('Cnn_9layers_AvgPooling') 73 | results_dict['Cnn_9layers_MaxPooling'] = _load_stat('Cnn_9layers_MaxPooling') 74 | results_dict['Cnn_13layers_AvgPooling'] = _load_stat('Cnn_13layers_AvgPooling') 75 | 76 | for n, measure_key in enumerate(measure_keys): 77 | lines = [] 78 | 79 | row = n // 2 80 | col = n % 2 81 | 82 | for model_key in results_dict.keys(): 83 | line, = axs[row, col].plot(results_dict[model_key][measure_key], label=results_dict[model_key]['legend']) 84 | lines.append(line) 85 | 86 | axs[row, col].set_title(measure_key) 87 | axs[row, col].legend(handles=lines, loc=4) 88 | axs[row, col].set_ylim(0, 1.0) 89 | axs[row, col].set_xlabel('Iterations') 90 | axs[row, col].grid(color='b', linestyle='solid', linewidth=0.2) 91 | axs[row, col].xaxis.set_ticks(np.arange(0, len(iterations), len(iterations) // 
4)) 92 | axs[row, col].xaxis.set_ticklabels(np.arange(0, max_plot_iteration, max_plot_iteration // 4)) 93 | 94 | axs[1, 0].set_ylim(0, 100.) 95 | axs[0, 0].set_ylabel('sed_error_rate') 96 | axs[0, 1].set_ylabel('sed_f1_score') 97 | axs[1, 0].set_ylabel('doa_error') 98 | axs[1, 1].set_ylabel('doa_frame_recall') 99 | 100 | plt.tight_layout() 101 | plt.show() 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser(description='') 106 | parser.add_argument('--dataset_dir', type=str, required=True, help='Directory of dataset.') 107 | parser.add_argument('--workspace', type=str, required=True, help='Directory of your workspace.') 108 | parser.add_argument('--audio_type', type=str, choices=['foa', 'mic'], required=True) 109 | 110 | args = parser.parse_args() 111 | 112 | plot_results(args) -------------------------------------------------------------------------------- /utils/utilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(1, os.path.join(sys.path[0], '../evaluation_tools')) 4 | import numpy as np 5 | import soundfile 6 | import librosa 7 | import h5py 8 | from sklearn import metrics 9 | import logging 10 | import matplotlib.pyplot as plt 11 | 12 | import evaluation_metrics 13 | import cls_feature_class 14 | import config 15 | 16 | 17 | def create_folder(fd): 18 | if not os.path.exists(fd): 19 | os.makedirs(fd) 20 | 21 | 22 | def get_filename(path): 23 | path = os.path.realpath(path) 24 | name_ext = path.split('/')[-1] 25 | name = os.path.splitext(name_ext)[0] 26 | return name 27 | 28 | 29 | def create_logging(log_dir, filemode): 30 | 31 | create_folder(log_dir) 32 | i1 = 0 33 | 34 | while os.path.isfile(os.path.join(log_dir, '{:04d}.log'.format(i1))): 35 | i1 += 1 36 | 37 | log_path = os.path.join(log_dir, '{:04d}.log'.format(i1)) 38 | logging.basicConfig( 39 | level=logging.DEBUG, 40 | format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s 
%(message)s', 41 | datefmt='%a, %d %b %Y %H:%M:%S', 42 | filename=log_path, 43 | filemode=filemode) 44 | 45 | # Print to console 46 | console = logging.StreamHandler() 47 | console.setLevel(logging.INFO) 48 | formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') 49 | console.setFormatter(formatter) 50 | logging.getLogger('').addHandler(console) 51 | 52 | return logging 53 | 54 | 55 | def read_multichannel_audio(audio_path, target_fs=None): 56 | 57 | (multichannel_audio, fs) = soundfile.read(audio_path) 58 | '''(samples, channels_num)''' 59 | 60 | if target_fs is not None and fs != target_fs: 61 | (samples, channels_num) = multichannel_audio.shape 62 | 63 | multichannel_audio = np.array( 64 | [librosa.resample( 65 | multichannel_audio[:, i], 66 | orig_sr=fs, 67 | target_sr=target_fs) 68 | for i in range(channels_num)]).T 69 | '''(samples, channels_num)''' 70 | 71 | return multichannel_audio, (target_fs if target_fs is not None else fs) # sample rate of the returned audio 72 | 73 | 74 | def calculate_scalar_of_tensor(x): 75 | if x.ndim == 2: 76 | axis = 0 77 | elif x.ndim == 3: 78 | axis = (0, 1) 79 | else: raise ValueError('x must be a 2-D or 3-D array, got {}-D'.format(x.ndim)) 80 | mean = np.mean(x, axis=axis) 81 | std = np.std(x, axis=axis) 82 | 83 | return mean, std 84 | 85 | 86 | def load_scalar(scalar_path): 87 | with h5py.File(scalar_path, 'r') as hf: 88 | mean = hf['mean'][:] 89 | std = hf['std'][:] 90 | 91 | scalar = {'mean': mean, 'std': std} 92 | return scalar 93 | 94 | 95 | def scale(x, mean, std): 96 | return (x - mean) / std 97 | 98 | 99 | def inverse_scale(x, mean, std): 100 | return x * std + mean 101 | 102 | 103 | def resample_matrix(matrix, ratio): 104 | '''Nearest-neighbour resampling of a matrix along the time axis. 105 | 106 | Args: 107 | matrix: (time_steps, classes_num) 108 | ratio: float, target frames per second / source frames per second 109 | ''' 110 | new_len = int(round(ratio * matrix.shape[0])) 111 | new_matrix = np.zeros((new_len, matrix.shape[1])) 112 | 113 | for n in range(new_len): 114 | new_matrix[n] = matrix[min(int(round(n / ratio)), matrix.shape[0] - 1)] # clip: upsampling can round one past the end 115 | 116 | return new_matrix 117 | 118 | 119 | def calculate_metrics(metadata_dir, prediction_paths): 120 | 
'''Calculate metrics using official tool. This part of code is modified from: 121 | https://github.com/sharathadavanne/seld-dcase2019/blob/master/calculate_SELD_metrics.py 122 | 123 | Args: 124 | metadata_dir: string, directory of reference files. 125 | prediction_paths: list of string 126 | 127 | Returns: 128 | metrics: dict 129 | ''' 130 | 131 | # Load feature class 132 | feat_cls = cls_feature_class.FeatureClass() 133 | 134 | # Load evaluation metric class 135 | eval = evaluation_metrics.SELDMetrics( 136 | nb_frames_1s=feat_cls.nb_frames_1s(), data_gen=feat_cls) 137 | 138 | eval.reset() # Reset the evaluation metric parameters 139 | for prediction_path in prediction_paths: 140 | reference_path = os.path.join(metadata_dir, '{}.csv'.format( 141 | get_filename(prediction_path))) 142 | 143 | prediction_dict = evaluation_metrics.load_output_format_file(prediction_path) 144 | reference_dict = feat_cls.read_desc_file(reference_path) 145 | 146 | # Generate classification labels for SELD 147 | reference_tensor = feat_cls.get_clas_labels_for_file(reference_dict) 148 | prediction_tensor = evaluation_metrics.output_format_dict_to_classification_labels( 149 | prediction_dict, feat_cls) 150 | 151 | # Calculated SED and DOA scores 152 | eval.update_sed_scores(prediction_tensor.max(2), reference_tensor.max(2)) 153 | eval.update_doa_scores(prediction_tensor, reference_tensor) 154 | 155 | # Overall SED and DOA scores 156 | sed_error_rate, sed_f1_score = eval.compute_sed_scores() 157 | doa_error, doa_frame_recall = eval.compute_doa_scores() 158 | seld_score = evaluation_metrics.compute_seld_metric( 159 | [sed_error_rate, sed_f1_score], [doa_error, doa_frame_recall]) 160 | 161 | metrics = { 162 | 'sed_error_rate': sed_error_rate, 163 | 'sed_f1_score': sed_f1_score, 164 | 'doa_error': doa_error, 165 | 'doa_frame_recall': doa_frame_recall, 166 | 'seld_score': seld_score } 167 | 168 | return metrics 169 | 170 | 171 | def write_submission(list_dict, submissions_dir): 172 | '''Write 
predicted result to submission csv files. 173 | 174 | Args: 175 | list_dict: list of dict, e.g.: 176 | [{'name': 'split1_ir0_ov1_7', 177 | 'output_event': (1, frames_num, classes_num), 178 | 'output_elevation': (1, frames_num, classes_num), 179 | 'output_azimuth': (1, frames_num, classes_num), 180 | ... 181 | }, 182 | ...] 183 | submissions_dir: string, directory to write out submission files 184 | ''' 185 | 186 | frames_per_second = config.frames_per_second 187 | submission_frames_per_second = config.submission_frames_per_second 188 | 189 | for dict in list_dict: 190 | filename = '{}.csv'.format(dict['name']) 191 | filepath = os.path.join(submissions_dir, filename) 192 | 193 | event_matrix = dict['output_event'][0] 194 | elevation_matrix = dict['output_elevation'][0] 195 | azimuth_matrix = dict['output_azimuth'][0] 196 | 197 | # Resample predicted frames to submission format 198 | ratio = submission_frames_per_second / float(frames_per_second) 199 | resampled_event_matrix = resample_matrix(event_matrix, ratio) 200 | resampled_elevation_matrix = resample_matrix(elevation_matrix, ratio) 201 | resampled_azimuth_matrix = resample_matrix(azimuth_matrix, ratio) 202 | 203 | with open(filepath, 'w') as f: 204 | for n in range(resampled_event_matrix.shape[0]): 205 | for k in range(resampled_event_matrix.shape[1]): 206 | if resampled_event_matrix[n, k] > 0.5: 207 | elevation = int(resampled_elevation_matrix[n, k]) 208 | azimuth = int(resampled_azimuth_matrix[n, k]) 209 | f.write('{},{},{},{}\n'.format(n, k, azimuth, elevation)) 210 | 211 | logging.info(' Total {} files written to {}'.format(len(list_dict), submissions_dir)) 212 | --------------------------------------------------------------------------------
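The frame-level bookkeeping used across these files (rounding onset/offset times to frames in `load_hdf5`, then nearest-neighbour resampling in `resample_matrix` and thresholding at 0.5 in `write_submission`) can be sketched standalone as below. This is a minimal illustration, not code from the repository: the helper names `segments_to_event_matrix`, `nearest_resample` and `submission_lines` are hypothetical, and the 10 and 5 frames-per-second rates in the demo are illustrative rather than the values defined in `utils/config.py`, which is not shown here.

```python
import numpy as np


def segments_to_event_matrix(segments, frames_num, classes_num, frames_per_second):
    """Rasterise (class_id, start_s, end_s) segments into a binary
    (frames_num, classes_num) matrix, using the same rounding as load_hdf5."""
    event_matrix = np.zeros((frames_num, classes_num))
    for class_id, start_s, end_s in segments:
        start_frame = int(round(start_s * frames_per_second))
        # +1 makes the offset frame inclusive; slicing clips at frames_num
        end_frame = int(round(end_s * frames_per_second)) + 1
        event_matrix[start_frame:end_frame, class_id] = 1
    return event_matrix


def nearest_resample(matrix, ratio):
    """Nearest-neighbour resampling along the time axis (axis 0).

    ratio = target frames per second / source frames per second.
    """
    new_len = int(round(ratio * matrix.shape[0]))
    # Map every output frame back to its nearest source frame; the clip keeps
    # the last index in range when upsampling (ratio > 1).
    src = np.round(np.arange(new_len) / ratio).astype(int)
    src = np.clip(src, 0, matrix.shape[0] - 1)
    return matrix[src]


def submission_lines(event_matrix, azimuth_matrix, elevation_matrix, threshold=0.5):
    """Format active (frame, class) cells as 'frame,class,azimuth,elevation'
    strings, in the column order written by write_submission."""
    frames, classes = np.nonzero(event_matrix > threshold)
    return ['{},{},{},{}'.format(n, k, int(azimuth_matrix[n, k]),
                                 int(elevation_matrix[n, k]))
            for n, k in zip(frames, classes)]


if __name__ == '__main__':
    # Illustrative rates only: 10 frames/s labels resampled to 5 frames/s
    m = segments_to_event_matrix([(0, 0.0, 0.1), (2, 0.5, 1.0)],
                                 frames_num=12, classes_num=3, frames_per_second=10)
    r = nearest_resample(m, 0.5)
    azi = np.full(r.shape, 30.0)
    ele = np.full(r.shape, -10.0)
    print(submission_lines(r, azi, ele))
```

The vectorised `np.clip` guards an edge case of the loop-based approach: with round-half-to-even, `int(round(ratio * n))` can yield one more output frame than the source covers when upsampling (for example one source frame with `ratio=3.5` gives four output frames), and the last `n / ratio` would then round to an index one past the end.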