├── LICENSE
├── README.md
├── feature.py
├── images
│   └── CRNN_SED_DCASE2017_task3.jpg
├── metrics.py
├── requirements.txt
├── sed.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | -----------COPYRIGHT NOTICE STARTS WITH THIS LINE------------
2 | Copyright (c) 2017 Tampere University of Technology and its licensors
3 | All rights reserved.
4 |
5 | Permission is hereby granted, without written agreement and without
6 | license or royalty fees, to use and copy the code for the Multichannel
7 | Sound Event Detection using Convolutional Recurrent Neural Network
8 | method/architecture, present in the GitHub repository with the handle
9 | sed-crnn, (“Work”) described in the paper with title
10 | "Sound event detection using spatial features and convolutional
11 | recurrent neural network" (and available also from
12 | https://arxiv.org/abs/1706.02291) and composed of files with code in the
13 | Python programming language. This grant is only for experimental and
14 | non-commercial purposes, provided that the copyright notice in its entirety
15 | appear in all copies of this Work, and the original source of this Work,
16 | Audio Research Group, Lab. of Signal Processing at Tampere University
17 | of Technology, is acknowledged in any publication that reports research
18 | using this Work.
19 |
20 | Any commercial use of the Work or any part thereof is strictly prohibited.
21 | Commercial use include, but is not limited to:
22 | - selling or reproducing the Work
23 | - selling or distributing the results or content achieved by use of the Work
24 | - providing services by using the Work.
25 |
26 | IN NO EVENT SHALL TAMPERE UNIVERSITY OF TECHNOLOGY OR ITS LICENSORS BE LIABLE TO
27 | ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES
28 | ARISING OUT OF THE USE OF THIS WORK AND ITS DOCUMENTATION, EVEN IF TAMPERE
29 | UNIVERSITY OF TECHNOLOGY OR ITS LICENSORS HAS BEEN ADVISED OF THE POSSIBILITY
30 | OF SUCH DAMAGE.
31 |
32 | TAMPERE UNIVERSITY OF TECHNOLOGY AND ALL ITS LICENSORS SPECIFICALLY DISCLAIMS
33 | ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
34 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE WORK PROVIDED HEREUNDER
35 | IS ON AN "AS IS" BASIS, AND THE TAMPERE UNIVERSITY OF TECHNOLOGY HAS NO OBLIGATION
36 | TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
37 |
38 | -----------COPYRIGHT NOTICE ENDS WITH THIS LINE------------
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Single and multichannel sound event detection using convolutional recurrent neural network
2 | [Sound event detection (SED)](https://www.aane.in/research/computational-audio-scene-analysis-casa/sound-event-detection) is the task of recognizing sound events and their respective temporal start and end times in a recording. Sound events in real life do not always occur in isolation, but tend to overlap considerably with each other.
3 | Recognizing such overlapping sound events is referred to as polyphonic SED. Performing polyphonic SED using single-channel audio is a challenging task. These overlapping sound events can potentially be recognized better with multichannel audio.
4 | This repository supports both the single- and multichannel versions of polyphonic SED, referred to hereafter as SEDnet. You can read more about the [sound event detection literature here](https://www.aane.in/research/computational-audio-scene-analysis-casa/sound-event-detection).
5 |
6 | This method was first proposed in '[Sound event detection using spatial features and convolutional recurrent neural network](https://arxiv.org/abs/1706.02291 "Arxiv paper")'. It recently won the [DCASE 2017 real-life sound event detection task](https://goo.gl/8eqCg3 "Challenge webpage"). We are releasing a simple, vanilla version of the code here, without many frills.
7 |
8 | If you are using anything from this repository, please consider citing:
9 |
10 | >Sharath Adavanne, Pasi Pertila and Tuomas Virtanen, "Sound event detection using spatial features and convolutional recurrent neural network" in IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP 2017)
11 |
12 | A similar CRNN architecture has been used successfully for the different tasks and research challenges listed below. You can accordingly swap in a prediction layer suited to the task at hand.
13 |
14 | 1. Sound event detection
15 | - Sharath Adavanne, Pasi Pertila and Tuomas Virtanen, '[Sound event detection using spatial features and convolutional recurrent neural network](https://arxiv.org/abs/1706.02291 "Arxiv paper")' at IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP 2017)
16 | - Sharath Adavanne, Archontis Politis and Tuomas Virtanen, '[Multichannel sound event detection using 3D convolutional neural networks for learning inter-channel features](https://arxiv.org/abs/1801.09522 "Arxiv paper")' at International Joint Conference on Neural Networks (IJCNN 2018)
17 |
18 | 2. SED with weak labels
19 | - Sharath Adavanne and Tuomas Virtanen, '[Sound event detection using weakly labeled dataset with stacked convolutional and recurrent neural network](https://arxiv.org/abs/1710.02998 "Arxiv paper")' at Detection and Classification of Acoustic Scenes and Events (DCASE 2017)
20 |
21 | 3. Bird audio detection
22 | - Sharath Adavanne, Konstantinos Drossos, Emre Cakir and Tuomas Virtanen, '[Stacked convolutional and recurrent neural networks for bird audio detection](https://arxiv.org/abs/1706.02047 "Arxiv paper")' at European Signal Processing Conference (EUSIPCO 2017)
23 | - Emre Cakir, Sharath Adavanne, Giambattista Parascandolo, Konstantinos Drossos and Tuomas Virtanen, '[Convolutional recurrent neural networks for bird audio detection](https://arxiv.org/abs/1703.02317 "Arxiv paper")' at European Signal Processing Conference (EUSIPCO 2017)
24 |
25 | 4. Music emotion recognition
26 | - Miroslav Malik, Sharath Adavanne, Konstantinos Drossos, Tuomas Virtanen, Dasa Ticha, Roman Jarina, '[Stacked convolutional and recurrent neural networks for music emotion recognition](https://arxiv.org/abs/1706.02292 "Arxiv paper")', at Sound and Music Computing Conference (SMC 2017)
27 |
28 | ## More about SEDnet
29 | The proposed SEDnet is shown in the figure below. The input to the method is either single- or multichannel audio. The log mel-band energy feature is extracted from each channel of the input audio. These audio features are fed to a convolutional recurrent neural network that maps them to the activities of the sound event classes in the dataset. The output of the neural network lies in the continuous range [0, 1] for each sound event class and corresponds to the probability of that class being active in the frame. This continuous output is then thresholded to obtain the final binary decision of whether each sound event class is active or absent in a frame. In short, the method takes a sequence of frame-wise audio features as input and predicts the activity of the target sound event classes for each input frame.
30 |
31 |
32 | ![SEDnet: convolutional recurrent neural network for multichannel SED](images/CRNN_SED_DCASE2017_task3.jpg)
33 |
34 |
35 |
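For orientation, a minimal Keras sketch of such a CRNN is given below. This is not the exact `get_model()` from sed.py; the layer counts, pooling factors, and dropout values are illustrative placeholders.

```
from keras.layers import (Input, Conv2D, BatchNormalization, Activation, MaxPooling2D,
                          Dropout, Reshape, Bidirectional, GRU, TimeDistributed, Dense)
from keras.models import Model


def build_crnn_sketch(seq_len=256, nb_mel=40, nb_ch=2, nb_classes=6,
                      nb_filters=64, pool_sizes=(5, 2, 2), rnn_units=32):
    # Input: a sequence of frame-wise log mel-band energies, one plane per audio channel.
    spec = Input(shape=(seq_len, nb_mel, nb_ch))
    x = spec
    for pool in pool_sizes:
        x = Conv2D(nb_filters, (3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        # Pool only along the frequency axis so the frame (time) resolution is preserved.
        x = MaxPooling2D(pool_size=(1, pool))(x)
        x = Dropout(0.5)(x)
    # Collapse the remaining frequency and filter axes into one feature vector per frame.
    x = Reshape((seq_len, -1))(x)
    for _ in range(2):
        x = Bidirectional(GRU(rnn_units, return_sequences=True))(x)
    # One sigmoid per class and per frame: frame-wise class-activity probabilities in [0, 1].
    out = TimeDistributed(Dense(nb_classes, activation='sigmoid'))(x)
    model = Model(spec, out)
    model.compile(optimizer='adam', loss='binary_crossentropy')
    return model
```

At inference time the frame-wise probabilities are compared against a fixed threshold (the `posterior_thresh` variable in sed.py) to obtain the binary activity decisions described above.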
36 | ## Getting Started
37 |
38 | This repository is built around the DCASE 2017 task 3 dataset, and consists of four Python scripts.
39 | * The feature.py script extracts the features and labels, and normalizes the features of the training and test splits. Make sure you update the locations of the wav files, the evaluation setup, and the folder to write the features to before running it.
40 | * The sed.py script loads the normalized features and trains the SEDnet. Training stops when the segment-based (one-second) error rate metric (http://tut-arg.github.io/sed_eval/) stops improving.
41 | * The metrics.py script implements the core metrics from the sound event detection evaluation module http://tut-arg.github.io/sed_eval/ (a simplified sketch of these scores is given after this list).
42 | * The utils.py script has some utility functions.
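As a rough reference for the scores reported below, the segment-based error rate (ER) and F-score can be sketched as follows. This is a simplified illustration of the sed_eval segment-based definitions, not the exact code in metrics.py.

```
import numpy as np


def segment_based_scores(pred, gt, frames_in_1_sec=50):
    # pred, gt: binary matrices of shape (n_frames, n_classes).
    n_seg = pred.shape[0] // frames_in_1_sec
    tp = fp = fn = 0          # totals for the F-score
    subs = dels = ins = 0     # substitutions, deletions, insertions for the error rate
    n_ref = 0                 # number of active reference events over all segments
    for s in range(n_seg):
        seg = slice(s * frames_in_1_sec, (s + 1) * frames_in_1_sec)
        # A class counts as active in a segment if it is active in any frame of it.
        p = pred[seg].max(axis=0)
        g = gt[seg].max(axis=0)
        seg_tp = int(np.sum((p == 1) & (g == 1)))
        seg_fp = int(np.sum((p == 1) & (g == 0)))
        seg_fn = int(np.sum((p == 0) & (g == 1)))
        tp, fp, fn = tp + seg_tp, fp + seg_fp, fn + seg_fn
        subs += min(seg_fn, seg_fp)
        dels += max(0, seg_fn - seg_fp)
        ins += max(0, seg_fp - seg_fn)
        n_ref += int(g.sum())
    f_score = 2.0 * tp / (2.0 * tp + fp + fn + 1e-12)
    error_rate = (subs + dels + ins) / (n_ref + 1e-12)
    return error_rate, f_score
```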
43 |
44 | If you are only interested in the SEDnet model, then just check the `get_model()` function in the sed.py script.
45 |
46 |
47 | ### Prerequisites
48 |
49 | The requirements.txt file lists the libraries and the versions used. The Python scripts were written and tested with Python 3.7.3. You can install the requirements by running the following line:
50 |
51 | ```
52 | pip install -r requirements.txt
53 | ```
54 | ## Training the SEDnet on the development dataset of DCASE 2017
55 |
56 | * Download the dataset from https://zenodo.org/record/814831#.Ws2xO3VuYUE
57 | * Update the paths of the `audio/street/` and `evaluation_setup` folders of the dataset in the feature.py script. Also update the `feat_folder` variable with a local folder where the script can dump the extracted feature files. Set the flag `is_mono` to `True` for single-channel SED and to `False` for multichannel SED; since the dataset contains only binaural audio, setting `is_mono = False` trains the SEDnet on binaural audio (see the example settings after this list). Then run the script with `python feature.py`; this will save the features and labels of the training and test splits in the provided `feat_folder`.
58 | * In the sed.py script, update the `feat_folder` path to the one used in the feature.py script. Set the `is_mono` flag according to the single- or multichannel SED study and run the script with `python sed.py`. This trains on the default training splits of the dataset and evaluates the model on the test splits of all four folds.
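For illustration, the user-editable variables at the top of feature.py might be set as below. The folder paths are placeholders for your local setup; only `is_mono` and `feat_folder` are names taken from this README, the other variable names are hypothetical.

```
# Illustrative settings only -- adjust the paths to your local copy of the dataset.
is_mono = False          # True: single-channel SED, False: binaural (multichannel) SED

# Hypothetical variable names for the dataset folders mentioned above.
audio_folder = '/data/TUT-SED-2017/audio/street/'
evaluation_setup_folder = '/data/TUT-SED-2017/evaluation_setup/'

# Local folder where the extracted feature and label files will be written.
feat_folder = '/data/TUT-SED-2017/feat/'
```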
59 |
60 | The sound event detection metrics - error rate (ER) and F-score computed on one-second segments and averaged over the four folds - are as follows. Since the dataset is small, the results vary quite a bit between runs, hence we report the mean of five separate runs. An ideal SED method has an ER of 0 and an F-score of 1.
61 |
62 | | SEDnet mode    | ER   | F    |
63 | | ---            | ---  | ---  |
64 | | Single channel | 0.60 | 0.57 |
65 | | Multichannel   | 0.60 | 0.59 |
66 |
67 | The results differ from those in the original paper because the evaluation split of the dataset is not used here.
68 |
69 | ## License
70 |
71 | This repository is licensed under the TUT License - see the [LICENSE](LICENSE) file for details.
72 |
73 | ## Acknowledgments
74 |
75 | The research leading to these results has received funding from the European Research Council under the European Union's H2020 Framework Programme through ERC Grant Agreement 637422 EVERYSOUND.
76 |
--------------------------------------------------------------------------------
/feature.py:
--------------------------------------------------------------------------------
1 | import wave
2 | import numpy as np
3 | import utils
4 | import librosa
5 | from IPython import embed
6 | import os
7 | from sklearn import preprocessing
8 |
9 |
10 | def load_audio(filename, mono=True, fs=44100):
11 | """Load audio file into numpy array
12 | Supports 24-bit wav-format
13 |
14 | Taken from TUT-SED system: https://github.com/TUT-ARG/DCASE2016-baseline-system-python
15 |
16 | Parameters
17 | ----------
18 | filename: str
19 | Path to audio file
20 |
21 | mono : bool
22 | In case of multi-channel audio, channels are averaged into single channel.
23 | (Default value=True)
24 |
25 | fs : int > 0 [scalar]
26 | Target sample rate, if input audio does not fulfil this, audio is resampled.
27 | (Default value=44100)
28 |
29 | Returns
30 | -------
31 | audio_data : numpy.ndarray [shape=(signal_length, channel)]
32 | Audio
33 |
34 | sample_rate : integer
35 | Sample rate
36 |
37 | """
38 |
39 | file_base, file_extension = os.path.splitext(filename)
40 | if file_extension == '.wav':
41 | _audio_file = wave.open(filename)
42 |
43 | # Audio info
44 | sample_rate = _audio_file.getframerate()
45 | sample_width = _audio_file.getsampwidth()
46 | number_of_channels = _audio_file.getnchannels()
47 | number_of_frames = _audio_file.getnframes()
48 |
49 | # Read raw bytes
50 | data = _audio_file.readframes(number_of_frames)
51 | _audio_file.close()
52 |
53 | # Convert bytes based on sample_width
54 | num_samples, remainder = divmod(len(data), sample_width * number_of_channels)
55 | if remainder > 0:
56 | raise ValueError('The length of data is not a multiple of sample size * number of channels.')
57 | if sample_width > 4:
58 | raise ValueError('Sample size cannot be bigger than 4 bytes.')
59 |
60 | if sample_width == 3:
61 | # 24 bit audio
62 | a = np.empty((num_samples, number_of_channels, 4), dtype=np.uint8)
63 |             raw_bytes = np.frombuffer(data, dtype=np.uint8)
64 | a[:, :, :sample_width] = raw_bytes.reshape(-1, number_of_channels, sample_width)
65 | a[:, :, sample_width:] = (a[:, :, sample_width - 1:sample_width] >> 7) * 255
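            # The three little-endian bytes of each sample sit in the low bytes of a
            # 4-byte slot; the top byte is filled from the sign bit so that the int32
            # view below reproduces the signed 24-bit sample values.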
66 |             audio_data = a.view('<i4').reshape(a.shape[:2]).T
--------------------------------------------------------------------------------
/sed.py:
--------------------------------------------------------------------------------
161 |         pred_thresh = pred > posterior_thresh
162 | score_list = metrics.compute_scores(pred_thresh, Y_test, frames_in_1_sec=frames_1_sec)
163 |
164 | f1_overall_1sec_list[i] = score_list['f1_overall_1sec']
165 | er_overall_1sec_list[i] = score_list['er_overall_1sec']
166 | pat_cnt = pat_cnt + 1
167 |
168 | # Calculate confusion matrix
169 | test_pred_cnt = np.sum(pred_thresh, 2)
170 | Y_test_cnt = np.sum(Y_test, 2)
171 | conf_mat = confusion_matrix(Y_test_cnt.reshape(-1), test_pred_cnt.reshape(-1))
172 | conf_mat = conf_mat / (utils.eps + np.sum(conf_mat, 1)[:, None].astype('float'))
173 |
174 | if er_overall_1sec_list[i] < best_er:
175 | best_conf_mat = conf_mat
176 | best_er = er_overall_1sec_list[i]
177 | f1_for_best_er = f1_overall_1sec_list[i]
178 | model.save(os.path.join(__models_dir, '{}_fold_{}_model.h5'.format(__fig_name, fold)))
179 | best_epoch = i
180 | pat_cnt = 0
181 |
182 | print('tr Er : {}, val Er : {}, F1_overall : {}, ER_overall : {} Best ER : {}, best_epoch: {}'.format(
183 | tr_loss[i], val_loss[i], f1_overall_1sec_list[i], er_overall_1sec_list[i], best_er, best_epoch))
184 | plot_functions(nb_epoch, tr_loss, val_loss, f1_overall_1sec_list, er_overall_1sec_list, '_fold_{}'.format(fold))
185 | if pat_cnt > patience:
186 | break
187 | avg_er.append(best_er)
188 | avg_f1.append(f1_for_best_er)
189 |     print('saved model for the best_epoch: {} with best_er: {} f1_for_best_er: {}'.format(
190 |         best_epoch, best_er, f1_for_best_er))
191 | print('best_conf_mat: {}'.format(best_conf_mat))
192 | print('best_conf_mat_diag: {}'.format(np.diag(best_conf_mat)))
193 |
194 | print('\n\nMETRICS FOR ALL FOUR FOLDS: avg_er: {}, avg_f1: {}'.format(avg_er, avg_f1))
195 | print('MODEL AVERAGE OVER FOUR FOLDS: avg_er: {}, avg_f1: {}'.format(np.mean(avg_er), np.mean(avg_f1)))
196 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | eps = np.finfo(float).eps
5 |
6 |
7 | def create_folder(_fold_path):
8 | if not os.path.exists(_fold_path):
9 | os.makedirs(_fold_path)
10 |
11 |
12 | def reshape_3Dto2D(A):
13 | return A.reshape(A.shape[0] * A.shape[1], A.shape[2])
14 |
15 |
16 | def split_multi_channels(data, num_channels):
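    # Split the stacked per-channel features on the last axis into a separate channel
    # axis: (batch, time, num_channels * feat) -> (batch, num_channels, time, feat).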
17 | in_shape = data.shape
18 | if len(in_shape) == 3:
19 | hop = in_shape[2] // num_channels
20 | tmp = np.zeros((in_shape[0], num_channels, in_shape[1], hop))
21 | for i in range(num_channels):
22 | tmp[:, i, :, :] = data[:, :, i * hop:(i + 1) * hop]
23 | else:
24 | print("ERROR: The input should be a 3D matrix but it seems to have dimensions ", in_shape)
25 | exit()
26 | return tmp
27 |
28 |
29 | def split_in_seqs(data, subdivs):
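    # Truncate the first (frame) axis so it is divisible by subdivs, then reshape it
    # into non-overlapping sequences of length subdivs.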
30 | if len(data.shape) == 1:
31 |         if data.shape[0] % subdivs:
32 |             data = data[:-(data.shape[0] % subdivs)]
33 | data = data.reshape((data.shape[0] // subdivs, subdivs, 1))
34 | elif len(data.shape) == 2:
35 | if data.shape[0] % subdivs:
36 | data = data[:-(data.shape[0] % subdivs), :]
37 | data = data.reshape((data.shape[0] // subdivs, subdivs, data.shape[1]))
38 | elif len(data.shape) == 3:
39 | if data.shape[0] % subdivs:
40 | data = data[:-(data.shape[0] % subdivs), :, :]
41 | data = data.reshape((data.shape[0] // subdivs, subdivs, data.shape[1], data.shape[2]))
42 | return data
43 |
--------------------------------------------------------------------------------