├── .gitignore ├── .idea └── .gitignore ├── Classical_methods ├── play_with_spectograms.py └── train_svm_detector.py ├── LICENSE ├── README.md ├── analyze_spectogram.py ├── assets └── SED.png ├── dataset ├── common_config.py ├── dataset_utils.py ├── download_tau_sed_2019.py ├── spectogram │ ├── preprocess.py │ ├── spectogram_configs.py │ └── spectograms_dataset.py └── waveform │ ├── waveform_configs.py │ └── waveform_dataset.py ├── infer.py ├── main.py ├── models ├── spectogram_models.py └── waveform_models.py ├── train.py └── utils ├── common.py ├── metric_utils.py └── plot_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | .idea/ 4 | /outputs/ 5 | /training_dir/ 6 | /samples/ 7 | /inference_outputs/ 8 | /processed_datasets/ 9 | /Successfull_Train_dir/ 10 | *.pth 11 | Classical_methods/svm-classification.png 12 | Classical_methods/plots 13 | 14 | tmp_file.WAV 15 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /Classical_methods/play_with_spectograms.py: -------------------------------------------------------------------------------- 1 | from dataset.spectogram.spectograms_dataset import preprocess_film_clap_data, SpectogramDataset 2 | import numpy as np 3 | from scipy.linalg import eigh 4 | from sklearn.decomposition import PCA as skPCA 5 | from sklearn.manifold import TSNE 6 | from sklearn import svm 7 | 8 | import matplotlib 9 | import matplotlib.pyplot as plt 10 | matplotlib.use('TkAgg') 11 | 12 | if __name__ == '__main__': 13 | 14 | 15 | features_and_labels_dir, features_mean_std_file = preprocess_film_clap_data('../../data', 16 | preprocessed_mode="logMel", 17 | force_preprocess=False) 18 | 19 | dataset = SpectogramDataset(features_and_labels_dir, features_mean_std_file, 20 | augment_data=False, 21 | balance_classes=False, 22 | val_descriptor=0.2, 23 | preprocessed_mode="logMel") 24 | 25 | pos_frames = [] 26 | neg_frames = [] 27 | for idx in dataset.train_start_indices: 28 | features = dataset.train_features[0, idx] 29 | label = dataset.train_event_matrix[idx, 0] 30 | if label: 31 | pos_frames.append(features) 32 | else: 33 | neg_frames.append(features) 34 | pos_frames = np.array(pos_frames) 35 | neg_frames = np.array(neg_frames) 36 | neg_frames = neg_frames[np.random.randint(len(neg_frames), size=len(pos_frames)).tolist()] 37 | 38 | # pos_frames = pos_frames[np.random.randint(len(pos_frames), size=3000).tolist()] 39 | # neg_frames = neg_frames[np.random.randint(len(neg_frames), size=3000).tolist()] 40 | 41 | # neg_frames = random.sample(neg_frames, len(pos_frames)) 42 | labels = np.zeros(len(pos_frames) + len(neg_frames)) 43 | labels[:len(pos_frames)] = 1 44 | data = np.concatenate((pos_frames, neg_frames), axis=0) 45 | 46 | # pca = PCA(data.shape[1], 2) 47 | # pca.learn_encoder_decoder(data) 48 | # data_2d = pca.encode(data) 49 | 50 | # pca = skPCA(n_components=2) 51 | # pca.fit(data) 52 | # data_2d = pca.transform(data) 53 | 54 | # data_2d = TSNE(n_components=2, perplexity=40, n_iter=300).fit_transform(data) 55 | 56 | # plt.scatter(data_2d[:len(pos_frames),0], data_2d[:len(pos_frames),1], color='r', label='pos', alpha=0.5) 57 | # plt.scatter(data_2d[len(pos_frames):,0], 
data_2d[len(pos_frames):,1], color='b', label='neg', alpha=0.5) 58 | 59 | print("Classifying") 60 | classifier = svm.SVC(C=1, kernel="rbf") 61 | classifier.fit(data[:-100], labels[:-100]) 62 | predictions = classifier.predict(data[-100:]) 63 | 64 | accuracy = np.mean(predictions == labels[-100:]) 65 | print(accuracy) 66 | 67 | 68 | -------------------------------------------------------------------------------- /Classical_methods/train_svm_detector.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from sklearn import svm 4 | import os 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import random 8 | 9 | from dataset.dataset_utils import get_film_clap_paths_and_labels, read_multichannel_audio 10 | from dataset.spectogram.preprocess import multichannel_complex_to_log_mel 11 | from dataset.waveform.waveform_dataset import split_to_frames_with_hop_size 12 | from utils.metric_utils import calculate_metrics, f_score 13 | from utils.plot_utils import plot_sample_features 14 | 15 | import dataset.waveform.waveform_configs as cfg 16 | matplotlib.use('TkAgg') 17 | 18 | 19 | class SVM_detector: 20 | def __init__(self, soft_svm, recall_priority): 21 | self.soft_svm = soft_svm 22 | self.svm = svm.SVC(C=1, kernel="rbf", probability=soft_svm) 23 | self.recall_priority = recall_priority 24 | def learn(self, spectograms, event_matrices): 25 | data = np.concatenate(spectograms, axis=0) 26 | labels = np.concatenate(event_matrices, axis=0) 27 | sample_weights = labels * self.recall_priority + (1 - labels) 28 | print(f"Svm training on {len(data)} samples... ", end='') 29 | self.svm.fit(data, labels, sample_weight=sample_weights) 30 | print("Done") 31 | 32 | def predict(self, spectogram): 33 | result = np.zeros(spectogram.shape[0]) 34 | for i in range(spectogram.shape[0]): 35 | if self.soft_svm: 36 | result[i] = self.svm.predict_proba([spectogram[i]])[0,1] 37 | else: 38 | result[i] = self.svm.predict([spectogram[i]]) 39 | 40 | return result 41 | 42 | def save(self, path): 43 | with open(path, 'wb') as file: 44 | pickle.dump(self.svm, file) 45 | 46 | def load(self, path): 47 | if os.path.exists(path): 48 | with open(path, 'rb') as file: 49 | self.svm = pickle.load(file) 50 | 51 | def get_raw_data(): 52 | NFFT = 2**int(np.ceil(np.log2(cfg.frame_size))) 53 | 54 | audio_paths_labels_and_names = get_film_clap_paths_and_labels("../../data/FilmClap", time_margin=cfg.time_margin) 55 | 56 | features = [] 57 | label_sets = [] 58 | file_names = [] 59 | for i, (audio_path, start_times, end_times, audio_name) in enumerate(audio_paths_labels_and_names): 60 | assert "_".join(audio_name.split("_")[1:]) in audio_path 61 | waveform = read_multichannel_audio(audio_path, target_fs=cfg.working_sample_rate) 62 | waveform = waveform.T # -> (channels, samples) 63 | # Split wave form to overlapping frames and create labels for each 64 | frames, labels = split_to_frames_with_hop_size(waveform, start_times, end_times) 65 | frames = np.concatenate(frames, axis=0) 66 | frames *= np.hanning(frames.shape[1]) 67 | complex_spectogram = np.fft.rfft(frames, NFFT) 68 | mel_features = multichannel_complex_to_log_mel(complex_spectogram) 69 | 70 | features.append(mel_features) 71 | label_sets.append(np.array(labels)) 72 | file_names.append(audio_name) 73 | 74 | data = list(zip(features, label_sets, file_names)) 75 | return data 76 | 77 | def split_train_val(all_data): 78 | random.shuffle(all_data) 79 | features_list, event_matrix_list, file_names = 
zip(*all_data) 80 | features_list, event_matrix_list, file_names = list(features_list), list(event_matrix_list), list(file_names) 81 | 82 | # Split to train val 83 | val_amount = len(features_list) // 5 84 | train_features_list = features_list[val_amount:] 85 | train_event_matrix_list = event_matrix_list[val_amount:] 86 | train_file_names = file_names[val_amount:] 87 | 88 | val_features_list = features_list[:val_amount] 89 | val_event_matrix_list = event_matrix_list[:val_amount] 90 | val_file_names = file_names[:val_amount] 91 | 92 | return train_features_list, train_event_matrix_list, val_features_list, val_event_matrix_list, val_file_names 93 | 94 | def evaluate_model(model, eval_data): 95 | # Evaluate model 96 | recal_sets, precision_sets, APs, accs = [], [], [], [] 97 | for feature, event_mat, name in eval_data: 98 | pred = model.predict(feature) 99 | acc = np.mean(pred == event_mat) 100 | 101 | recals, precisions, AP = calculate_metrics(pred, event_mat) 102 | f1s = [f_score(r,p,1) for r,p in zip(recals, precisions)] 103 | print(f"{name} max f1 score: {np.max(f1s)}") 104 | recal_sets.append(recals) 105 | precision_sets.append(precisions) 106 | APs.append(AP) 107 | accs.append(acc) 108 | 109 | plot_sample_features(np.array([feature]), 110 | mode='spectogram', 111 | output=pred.reshape(-1,1), 112 | target=event_mat.reshape(-1,1), 113 | file_name=f"Acc:{acc:.2f}, AP: {AP:.2f}, f1: {np.max(f1s):.2f}", 114 | plot_path=f"plots/{name}-f1: {np.max(f1s):.2f}.png") 115 | 116 | recal_vals = np.mean(recal_sets, axis=0) 117 | precision_vals = np.mean(precision_sets, axis=0) 118 | MAP = np.sum(recal_vals[:-1] * (recal_vals[:-1] - recal_vals[1:])) 119 | 120 | plt.plot(recal_vals, precision_vals) 121 | plt.xticks([0, 0.25, 0.5, 0.75, 1]) 122 | plt.yticks([0, 0.25, 0.5, 0.75, 1]) 123 | plt.title(f"Validation AVG ROC" 124 | f"\nAP: {MAP:.2f}") 125 | plt.xlabel("Avg Recall") 126 | plt.ylabel("Avg Precision") 127 | plt.savefig("svm-classification.png") 128 | plt.clf() 129 | 130 | if __name__ == '__main__': 131 | # Load data 132 | all_data = get_raw_data() 133 | train_features_list, train_event_matrix_list, val_features_list, val_event_matrix_list, val_file_names = split_train_val(all_data) 134 | 135 | # Train model 136 | model = SVM_detector(soft_svm=True, recall_priority=10) 137 | model.learn(train_features_list, train_event_matrix_list) 138 | model.save("last_pickled_model.pkl") 139 | 140 | # Evaluate model 141 | evaluate_model(model, zip(val_features_list, val_event_matrix_list, val_file_names)) 142 | 143 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Train and evaluation of sound-event detector. 2 | 3 | # Train a CNN detector: 4 | Train 2d CNN on wavesound spectogram image or a 1d CNN on raw sound wave samples 5 | - run main.py 6 | 7 | # Train an SVM detector: 8 | Train an SVM on Spectogram columns or frames of raw sound wave samples 9 | - run main.py 10 | 11 | # Requirements 12 | - soundfile 13 | - librosa 14 | 15 | # Credits 16 | - Greatly inspired by https://github.com/qiuqiangkong/dcase2019_task3 17 | 18 | 19 | 20 | # Sound event detection example 21 | This image is an evaluation of a detector working on the spectogram domain: 22 | - Top: A spectogram of 60s sound file. 23 | - Middle: Event prediction for each frame (temporal segment of sound wave). Confidence values from 0 to 1. 24 | - Bottom: Event ground-truth for each frame. Confidence values from 0 to 1. 25 | -------------------------------------------------------------------------------- /analyze_spectogram.py: -------------------------------------------------------------------------------- 1 | from dataset.spectogram.preprocess import multichannel_stft, multichannel_complex_to_log_mel 2 | from dataset.dataset_utils import read_multichannel_audio 3 | from dataset.spectogram import spectogram_configs as cfg 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import soundfile 7 | import matplotlib 8 | matplotlib.use('TkAgg') 9 | 10 | if __name__ == '__main__': 11 | # audio_path = '/home/ariel/projects/sound/data/FilmClap/original/Meron/S005-S004T1.WAV' 12 | # audio_path = '/home/ariel/projects/sound/data/FilmClap/original/StillJames/2C-T001.WAV' 13 | # audio_path = '/home/ariel/projects/sound/data/FilmClap/original/JackRinger-05/161019_1233.wav' 14 | audio_path = '/home/ariel/projects/sound/data/FilmClap/original/StillJames/8D-T001.WAV' 15 | 16 | sec_start = 35.45 17 | sec_end = 35.65 18 | 19 | multichannel_waveform = read_multichannel_audio(audio_path=audio_path, target_fs=cfg.working_sample_rate) 20 | 21 | 22 | multichannel_waveform = multichannel_waveform[int(cfg.working_sample_rate * sec_start): int(cfg.working_sample_rate * sec_end)] 23 | soundfile.write("tmp_file.WAV", multichannel_waveform, cfg.working_sample_rate) 24 | feature = multichannel_stft(multichannel_waveform) 25 | feature = multichannel_complex_to_log_mel(feature) 26 | 27 | frames_num = feature.shape[1] 28 | tick_hop = max(1, frames_num // 20) 29 | xticks = np.concatenate((np.arange(0, frames_num - tick_hop, tick_hop), [frames_num])) 30 | xlabels = [f"{x / cfg.frames_per_second:.3f}s" for x in xticks] 31 | 32 | fig = plt.figure() 33 | ax = fig.add_subplot(211) 34 | ax.matshow(feature[0].T, origin='lower', cmap='jet') 35 | ax.set_xticks(xticks) 36 | ax.set_xticklabels(xlabels, rotation='vertical') 37 | ax.xaxis.set_ticks_position('bottom') 38 | 39 | ax = fig.add_subplot(212) 40 | signal = multichannel_waveform.mean(1) 41 | ax.plot(range(len(signal)), signal) 42 | 43 | 
ax.get_xaxis().set_visible(False) 44 | ax.get_yaxis().set_visible(False) 45 | plt.autoscale(tight=True) 46 | 47 | plt.show() -------------------------------------------------------------------------------- /assets/SED.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ariel415el/SoundEventDetection-Pytorch/22591921bb84aace01ae4605506fb95892935689/assets/SED.png -------------------------------------------------------------------------------- /dataset/common_config.py: -------------------------------------------------------------------------------- 1 | 2 | time_margin = 0.33 3 | working_sample_rate = 48000 4 | frame_size = int(working_sample_rate * time_margin * 2) 5 | hop_size = frame_size // 2 6 | audio_channels = 1 7 | min_event_percentage_in_positive_frame = 0.74 8 | frames_per_second = working_sample_rate // hop_size 9 | 10 | # Tau-SED details: 11 | # tau_sed_labels = ['knock', 'drawer', 'clearthroat', 'phone', 'keysDrop', 'speech', 12 | # 'keyboard', 'pageturn', 'cough', 'doorslam', 'laughter'] 13 | 14 | # tau_sed_labels = ['knock', 'keysDrop', 'doorslam'] 15 | tau_sed_labels = ['doorslam'] 16 | classes_num = len(tau_sed_labels) 17 | 18 | -------------------------------------------------------------------------------- /dataset/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from collections import defaultdict 4 | 5 | import librosa 6 | import numpy as np 7 | import pandas as pd 8 | import soundfile 9 | 10 | from dataset.spectogram import spectogram_configs as cfg 11 | 12 | 13 | def get_film_clap_paths_and_labels(data_root, time_margin=0.1): 14 | """ 15 | Parses the Film_clap raw data and collect audio file paths , start_times and end_times of claps 16 | """ 17 | result = [] 18 | num_claps = 0 19 | num_audio_files = 0 20 | files_per_film = defaultdict(lambda:0) 21 | path_to_label = json.load(open(os.path.join(data_root, 'paths_and_labels_fixed_Meron.txt'))) 22 | print("Collecting Film-clap dataset") 23 | for sound_path in path_to_label: 24 | soundfile_name = os.path.splitext(os.path.basename(sound_path))[0] 25 | film_name = os.path.basename(os.path.dirname(sound_path)) 26 | name = f"{film_name}_{soundfile_name}" 27 | evemt_centers_list = path_to_label[sound_path] 28 | assert os.path.exists(sound_path), sound_path 29 | start_times = [e - time_margin for e in evemt_centers_list] 30 | end_times = [e + time_margin for e in evemt_centers_list] 31 | result += [(sound_path, start_times, end_times, name)] 32 | num_claps += len(start_times) 33 | num_audio_files += 1 34 | files_per_film[film_name] += 1 35 | 36 | for film_name in files_per_film: 37 | print(f"\t- {film_name} has {files_per_film[film_name]}") 38 | print(f"\tFilm clap dataset contains {num_audio_files} audio files with {num_claps} clap incidents") 39 | return result 40 | 41 | 42 | def get_tau_sed_paths_and_labels(audio_dir, labels_data_dir): 43 | """ 44 | Parses the Tau_sed raw data and collect audio file paths, start_times and end_times of claps 45 | """ 46 | results = [] 47 | for audio_fname in os.listdir(audio_dir): 48 | bare_name = os.path.splitext(audio_fname)[0] 49 | 50 | audio_path = os.path.join(audio_dir, audio_fname) 51 | 52 | df = pd.read_csv(os.path.join(labels_data_dir, bare_name + ".csv"), sep=',') 53 | relevant_classes = [i for i in range(len(df['sound_event_recording'].values)) 54 | if df['sound_event_recording'].values[i] in cfg.tau_sed_labels] 55 | 56 | 
start_times, end_times = df['start_time'].values[relevant_classes], df['end_time'].values[relevant_classes] 57 | 58 | results += [(audio_path, start_times, end_times, bare_name)] 59 | 60 | return results 61 | 62 | 63 | def read_multichannel_audio(audio_path, target_fs=None): 64 | """ 65 | Read the audio samples in files and resample them to fit the desired sample ratre 66 | """ 67 | (multichannel_audio, sample_rate) = soundfile.read(audio_path) 68 | if len(multichannel_audio.shape) == 1: 69 | multichannel_audio = multichannel_audio.reshape(-1, 1) 70 | if multichannel_audio.shape[1] < cfg.audio_channels: 71 | print(multichannel_audio.shape[1]) 72 | multichannel_audio = np.repeat(multichannel_audio.mean(1).reshape(-1, 1), cfg.audio_channels, axis=1) 73 | elif cfg.audio_channels == 1: 74 | multichannel_audio = multichannel_audio.mean(1).reshape(-1, 1) 75 | elif multichannel_audio.shape[1] > cfg.audio_channels: 76 | multichannel_audio = multichannel_audio[:, :cfg.audio_channels] 77 | 78 | if target_fs is not None and sample_rate != target_fs: 79 | 80 | channels_num = multichannel_audio.shape[1] 81 | 82 | multichannel_audio = np.array( 83 | [librosa.resample(multichannel_audio[:, i], orig_sr=sample_rate, target_sr=target_fs) for i in range(channels_num)] 84 | ).T 85 | 86 | return multichannel_audio 87 | -------------------------------------------------------------------------------- /dataset/download_tau_sed_2019.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | from torchvision.datasets.utils import download_url 5 | 6 | 7 | def download_foa_data(data_dir, fold_name='eval'): 8 | urls = [ 9 | 'https://zenodo.org/record/2599196/files/foa_dev.z01?download=1', 10 | 'https://zenodo.org/record/2599196/files/foa_dev.z02?download=1', 11 | 'https://zenodo.org/record/2599196/files/foa_dev.zip?download=1', 12 | 'https://zenodo.org/record/2599196/files/metadata_dev.zip?download=1', 13 | 'https://zenodo.org/record/3377088/files/foa_eval.zip?download=1', 14 | 'https://zenodo.org/record/3377088/files/metadata_eval.zip?download=1', 15 | ] 16 | md5s = [ 17 | 'bd5b18a47a3ed96e80069baa6b221a5a', 18 | '5194ebf43ae095190ed78691ec9889b1', 19 | '2154ad0d9e1e45bfc933b39591b49206', 20 | 'c2e5c8b0ab430dfd76c497325171245d', 21 | '4a8ca8bfb69d7c154a56a672e3b635d5', 22 | 'a0ec7640284ade0744dfe299f7ba107b' 23 | ] 24 | names = [ 25 | 'foa_dev.z01', 26 | 'foa_dev.z02', 27 | 'foa_dev.zip', 28 | 'metadata_dev.zip', 29 | 'foa_eval.zip', 30 | 'metadata_eval.zip' 31 | ] 32 | 33 | if fold_name == 'eval': 34 | urls, md5s, names = urls[-2:], md5s[-2:], names[-2:] 35 | 36 | os.makedirs(data_dir, exist_ok=True) 37 | for url, md5, name in zip(urls, md5s, names): 38 | download_url(url, data_dir, md5=md5, filename=name) 39 | 40 | 41 | def extract_foa_data(data_dir, output_dir, fold_name='eval'): 42 | os.makedirs(output_dir, exist_ok=True) 43 | os.makedirs(data_dir, exist_ok=True) 44 | subprocess.call(["unzip", os.path.join(data_dir,'metadata_eval'), "-d", output_dir]) 45 | subprocess.call(["unzip", os.path.join(data_dir, 'foa_eval'), "-d", output_dir]) 46 | 47 | subprocess.call(f"cp -R {output_dir}/proj/asignal/DCASE2019/dataset/foa_eval -d {output_dir}/foa_eval".split(" ")) 48 | shutil.rmtree(f"{output_dir}/proj") 49 | 50 | if fold_name == 'train': 51 | subprocess.call(["unzip", os.path.join(data_dir, 'metadata_dev.zip'), "-d", output_dir]) 52 | subprocess.call(f"zip -s 0 {os.path.join(data_dir,'foa_dev.zip')} --out 
{os.path.join(data_dir,'unsplit_foa_dev.zip')}".split(" ")) 53 | subprocess.call(f"unzip {os.path.join(data_dir, 'unsplit_foa_dev.zip')} -d {output_dir}".split(" ")) 54 | 55 | 56 | def ensure_tau_data(data_dir, fold_name='eval'): 57 | zipped_data_dir = os.path.join(data_dir, 'zipped') 58 | extracted_data_dir = os.path.join(data_dir, 'raw') 59 | audio_dir = f"{extracted_data_dir}/foa_{fold_name}" 60 | meta_data_dir = f"{extracted_data_dir}/metadata_{fold_name}" 61 | 62 | # Download and extact data 63 | if not os.path.exists(zipped_data_dir): 64 | print("Downloading zipped data") 65 | download_foa_data(zipped_data_dir, fold_name) 66 | if not os.path.exists(audio_dir): 67 | print("Extracting raw data") 68 | extract_foa_data(zipped_data_dir, extracted_data_dir, fold_name) 69 | else: 70 | print("Using existing raw data") 71 | 72 | return audio_dir, meta_data_dir -------------------------------------------------------------------------------- /dataset/spectogram/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | import librosa 5 | import numpy as np 6 | import soundfile 7 | from tqdm import tqdm 8 | 9 | import dataset.spectogram.spectogram_configs as cfg 10 | from dataset.dataset_utils import read_multichannel_audio 11 | from utils.plot_utils import plot_sample_features 12 | 13 | MEL_FILTER_BANK_MATRIX = librosa.filters.mel( 14 | sr=cfg.working_sample_rate, 15 | n_fft=cfg.NFFT, 16 | n_mels=cfg.mel_bins, 17 | fmin=cfg.mel_min_freq, 18 | fmax=cfg.mel_max_freq).T 19 | 20 | 21 | def multichannel_stft(multichannel_signal): 22 | (samples, channels_num) = multichannel_signal.shape 23 | features = [] 24 | for c in range(channels_num): 25 | complex_spectogram = librosa.core.stft( 26 | y=multichannel_signal[:, c], 27 | n_fft=cfg.NFFT, 28 | win_length=cfg.frame_size, 29 | hop_length=cfg.hop_size, 30 | window=np.hanning(cfg.frame_size), 31 | center=True, 32 | dtype=np.complex64, 33 | pad_mode='reflect').T 34 | '''(N, n_fft // 2 + 1)''' 35 | features.append(complex_spectogram) 36 | return np.array(features) 37 | 38 | 39 | def multichannel_complex_to_log_mel(multichannel_complex_spectogram): 40 | multichannel_power_spectogram = np.abs(multichannel_complex_spectogram) ** 2 41 | multichannel_mel_spectogram = np.dot(multichannel_power_spectogram, MEL_FILTER_BANK_MATRIX) 42 | multichannel_logmel_spectogram = librosa.core.power_to_db(multichannel_mel_spectogram, 43 | ref=1.0, amin=1e-10, top_db=None).astype(np.float32) 44 | 45 | return multichannel_logmel_spectogram 46 | 47 | 48 | def calculate_scalar_of_tensor(x): 49 | if x.ndim == 2: 50 | axis = 0 51 | elif x.ndim == 3: 52 | axis = (0, 1) 53 | 54 | mean = np.mean(x, axis=axis) 55 | std = np.std(x, axis=axis) 56 | 57 | return mean, std 58 | 59 | 60 | def preprocess_data(audio_path_and_labels, output_dir, output_mean_std_file, preprocess_mode='logMel'): 61 | print("Preprocessing collected data") 62 | os.makedirs(output_dir, exist_ok=True) 63 | 64 | all_features = [] 65 | 66 | for (audio_path, start_times, end_times, audio_name) in tqdm(audio_path_and_labels): 67 | multichannel_waveform = read_multichannel_audio(audio_path=audio_path, target_fs=cfg.working_sample_rate) 68 | feature = multichannel_stft(multichannel_waveform) 69 | if preprocess_mode == 'logMel': 70 | feature = multichannel_complex_to_log_mel(feature) 71 | all_features.append(feature) 72 | 73 | output_path = os.path.join(output_dir, audio_name + f"_{preprocess_mode}_features_and_labels.pkl") 74 | with 
open(output_path, 'wb') as f: 75 | pickle.dump({'features': feature, 'start_times': start_times, 'end_times': end_times}, 76 | f) 77 | 78 | all_features = np.concatenate(all_features, axis=1) 79 | mean, std = calculate_scalar_of_tensor(all_features) 80 | with open(output_mean_std_file, 'wb') as f: 81 | pickle.dump({'mean': mean, 'std': std}, f) 82 | 83 | # Visualize single data sample 84 | (audio_path, start_times, end_times, audio_name) = random.choice(audio_path_and_labels) 85 | analyze_data_sample(audio_path, start_times, end_times, audio_name, 86 | os.path.join(os.path.dirname(output_mean_std_file), "data_sample.png")) 87 | 88 | 89 | def analyze_data_sample(audio_path, start_times, end_times, audio_name, plot_path): 90 | """ 91 | A debug function that plots a single sample and analyzes how the spectogram configuration affect the feature final size 92 | """ 93 | from dataset.spectogram.spectograms_dataset import create_event_matrix 94 | org_multichannel_audio, org_sample_rate = soundfile.read(audio_path) 95 | 96 | multichannel_audio = read_multichannel_audio(audio_path=audio_path, target_fs=cfg.working_sample_rate) 97 | feature = multichannel_stft(multichannel_audio) 98 | feature = multichannel_complex_to_log_mel(feature) # (channels, frames, mel_bins) 99 | event_matrix = create_event_matrix(feature.shape[1], start_times, end_times) 100 | plot_sample_features(feature, mode='spectogram', target=event_matrix, plot_path=plot_path, file_name=audio_name) 101 | 102 | signal_time = multichannel_audio.shape[0]/cfg.working_sample_rate 103 | FPS = cfg.working_sample_rate / cfg.hop_size 104 | print(f"Data sample analysis: {audio_name}") 105 | print(f"\tOriginal audio: {org_multichannel_audio.shape} sample_rate={org_sample_rate}") 106 | print(f"\tsingle channel audio: {multichannel_audio.shape}, sample_rate={cfg.working_sample_rate}") 107 | print(f"\tSignal time is (num_samples/sample_rate)={signal_time:.1f}s") 108 | print(f"\tSIFT FPS is (sample_rate/hop_size)={FPS}") 109 | print(f"\tTotal number of frames is (FPS*signal_time)={FPS*signal_time:.1f}") 110 | print(f"\tEach frame covers {cfg.frame_size} samples or {cfg.frame_size / cfg.working_sample_rate:.3f} seconds " 111 | f"padded into {cfg.NFFT} samples and allow ({cfg.NFFT}//2+1)={cfg.NFFT // 2 + 1} frequency bins") 112 | print(f"\tFeatures shape: {feature.shape}") 113 | 114 | 115 | -------------------------------------------------------------------------------- /dataset/spectogram/spectogram_configs.py: -------------------------------------------------------------------------------- 1 | from utils.common import human_format 2 | import numpy as np 3 | from dataset.common_config import * 4 | 5 | NFFT = 2**int(np.ceil(np.log2(frame_size))) # The size of the padded frames on which fft will actualy work. 
Set this to a power of two for faster preprocessing 6 | mel_bins = 64 # How much frames to stretch over the 7 | mel_min_freq = 20 # Hz first mel bin (minimal possible value 0) 8 | mel_max_freq = working_sample_rate // 2 # Hz last mel bin (maximal possible value sampling_rate / 2) 9 | 10 | train_crop_size = frames_per_second * 10 # 10-second log mel spectrogram as input 11 | 12 | 13 | cfg_descriptor = f"Spectogram_SaR-{human_format(working_sample_rate)}_FrS-{human_format(frame_size)}" \ 14 | f"_HoS-{human_format(hop_size)}_Mel-{mel_bins}_Ch-{audio_channels}" 15 | -------------------------------------------------------------------------------- /dataset/spectogram/spectograms_dataset.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import os 4 | import pickle 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | import dataset.spectogram.spectogram_configs 10 | import dataset.spectogram.spectogram_configs as cfg 11 | from dataset.dataset_utils import get_film_clap_paths_and_labels, get_tau_sed_paths_and_labels 12 | from dataset.download_tau_sed_2019 import ensure_tau_data 13 | from dataset.spectogram.preprocess import preprocess_data, multichannel_complex_to_log_mel 14 | from random import shuffle 15 | 16 | 17 | class SpectogramDataset(Dataset): 18 | def __init__(self, features_and_labels_dir, mean_std_file, val_descriptor, 19 | balance_classes=False, augment_data=False, preprocessed_mode='Complex'): 20 | """ 21 | This dataset loads crops of the entire concatenated features of the data 22 | Args: 23 | features_and_labels_dir: 24 | mean_std_file: mean and std of the saved features # TODO: currently these are different for Complex histograms 25 | val_descriptor: How to split the data; float for percentage and string for specifing substring in desired files 26 | balance_classes: Limit the number of crops with no event to match the number of crops with events 27 | augment_data: 1. Add noise. 2. 
Mix STFT spectograms of multiple samples before converting to LogMel 28 | preprocessed_mode: defines whether if the preprocess phase included converting to LogMel or only STFT 29 | """ 30 | assert preprocessed_mode in ['logMel', 'Complex'], "Spectogram type should be either logmel or complex" 31 | assert not (preprocessed_mode == 'logMel' and augment_data), "Can't perform augmentation in logMel spectograms" 32 | self.preprocessed_mode = preprocessed_mode 33 | self.augment_data = augment_data 34 | self.train_crop_size = cfg.train_crop_size 35 | 36 | # Load data mean and std 37 | d = pickle.load(open(mean_std_file, 'rb')) 38 | self.mean = d['mean'] 39 | self.std = d['std'] 40 | 41 | all_paths = [os.path.join(features_and_labels_dir, x) for x in os.listdir(features_and_labels_dir)] 42 | train_feature_paths, self.val_feature_paths = split_train_val(all_paths, val_descriptor) 43 | 44 | self.train_features, self.train_event_matrix, self.train_start_indices = _read_train_data_to_memory(train_feature_paths, 45 | cfg.train_crop_size, 46 | balance_classes) 47 | 48 | self.val_features_list, self.val_event_matrix_list = _read_validation_data_to_memory(self.val_feature_paths) 49 | 50 | print(f"Data generator initiated with {len(train_feature_paths)} train samples " 51 | f"totaling {len(self.train_event_matrix) / cfg.frames_per_second:.1f} seconds " 52 | f"and {len(self.val_feature_paths)} val samples " 53 | f"totaling {len(np.concatenate(self.val_event_matrix_list, axis=0)) / cfg.frames_per_second:.1f} seconds") 54 | 55 | def __len__(self): 56 | return len(self.train_start_indices) 57 | 58 | def __getitem__(self, idx): 59 | ''' 60 | Generate mini-batch data for training. 61 | Samples a start index and crops a self.train_crop_size long segment from the concatenated featues 62 | Returns: 63 | batch_data_dict: dict containing feature, event, elevation and azimuth 64 | ''' 65 | 66 | data_indexes = np.arange(self.train_crop_size) + self.train_start_indices[idx] 67 | 68 | features = self.train_features[:, data_indexes] 69 | event_matrix = self.train_event_matrix[data_indexes] 70 | 71 | if self.augment_data: 72 | feature, event_matrix = self.augment_mix_samples(features, event_matrix) 73 | feature, event_matrix = self.augment_add_noise(feature, event_matrix) 74 | 75 | # Transform data 76 | features = self.transform(features) 77 | 78 | return torch.from_numpy(features), torch.from_numpy(event_matrix) 79 | 80 | def get_validation_sampler(self, max_validate_num=None): 81 | feature_names = self.val_feature_paths 82 | features_list = self.val_features_list 83 | event_matrix_list = self.val_event_matrix_list 84 | 85 | validate_num = len(feature_names) 86 | 87 | for n in range(validate_num): 88 | if n == max_validate_num: 89 | break 90 | 91 | name = os.path.basename(os.path.splitext(feature_names[n])[0]) 92 | feature = features_list[n] 93 | event_matrix = event_matrix_list[n] 94 | 95 | feature = self.transform(feature) 96 | 97 | features = feature[None, :, :, :] # ( batch_size=1, channels_num, frames_num, mel_bins) 98 | event_matrix = event_matrix[None, :, :] # (batch_size=1, frames_num, mel_bins) 99 | '''The None above indicates using an entire audio recording as 100 | input and batch_size=1 in inference''' 101 | 102 | yield torch.from_numpy(features), torch.from_numpy(event_matrix), name 103 | 104 | def transform(self, x): 105 | x = (x - self.mean) / self.std 106 | 107 | if self.preprocessed_mode == 'logMel': 108 | return x 109 | else: # If the preprocessed spectograms are saved as raw complex spectograms 
transform them into logMel
110 |             return multichannel_complex_to_log_mel(x)
111 | 
112 |     def augment_add_noise(self, batch_feature, batch_event_matrix):
113 |         # TODO: these numbers were chosen for noise added to a waveform, not to a spectogram
114 |         r = np.random.rand()
115 |         if r > 0.5:
116 |             noise_var = 0.001 + (r + 0.5) * (0.005 - 0.001)
117 |             batch_feature += np.random.normal(0, noise_var, size=batch_feature.shape)
118 |         return batch_feature, batch_event_matrix
119 | 
120 |     def augment_mix_samples(self, feature, event_matrix):
121 |         """
122 |         Augment a sample by mixing its features and labels with other train samples
123 |         """
124 |         number_of_augmentations = np.random.choice([0, 1, 2, 3], 1, p=[0.6, 0.25, 0.1, 0.05])[0]
125 |         for i in range(number_of_augmentations):
126 |             random_pointer = np.random.randint(len(self.train_start_indices))
127 |             new_data_indexes = np.arange(self.train_crop_size) + self.train_start_indices[random_pointer]
128 |             new_feature = self.train_features[:, new_data_indexes]
129 |             new_event_matrix = self.train_event_matrix[new_data_indexes]
130 | 
131 |             feature += new_feature
132 |             event_matrix = np.maximum(event_matrix, new_event_matrix)
133 |         feature /= (number_of_augmentations + 1)
134 | 
135 |         return feature, event_matrix
136 | 
137 | 
138 | def _read_train_data_to_memory(train_feature_paths, crop_size, balance_classes=False):
139 |     """
140 |     Concatenates all spectograms so that random crops can be sampled from them by choosing
141 |     from a set of start indices.
142 |     """
143 |     # Load training features and targets
144 |     frame_index = 0
145 | 
146 |     train_features_list = []
147 |     train_event_matrix_list = []
148 |     train_index_with_event = []
149 |     train_index_empty = []
150 | 
151 |     for feature_path in train_feature_paths:
152 |         data = pickle.load(open(feature_path, 'rb'))
153 |         feature = data['features']
154 |         event_matrix = create_event_matrix(feature.shape[1], data['start_times'], data['end_times'])
155 | 
156 |         frames_num = feature.shape[1]
157 |         '''Number of frames of the (log mel / complex) spectrogram of an audio
158 |         recording.
May be different from file to file'''
159 | 
160 |         possible_start_indices = np.arange(frame_index, frame_index + frames_num - crop_size)
161 |         frame_index += frames_num
162 | 
163 |         # Append data
164 |         train_features_list.append(feature)
165 |         train_event_matrix_list.append(event_matrix)
166 | 
167 |         # Split start indices into crops that contain an event and crops that do not
168 |         indices_with_event = np.zeros(possible_start_indices.shape, dtype=bool)
169 |         for i in np.where(event_matrix > 0)[0]:
170 |             indices_with_event[i - crop_size: i] = True
171 |         train_index_with_event += possible_start_indices[np.where(indices_with_event)[0]].tolist()
172 |         train_index_empty += possible_start_indices[np.where(indices_with_event == False)[0]].tolist()
173 | 
174 |     train_features = np.concatenate(train_features_list, axis=1)
175 |     train_event_matrix = np.concatenate(train_event_matrix_list, axis=0)
176 | 
177 |     # Balance classes in train data
178 |     np.random.shuffle(train_index_with_event)
179 |     np.random.shuffle(train_index_empty)
180 |     if balance_classes:
181 |         size = min(len(train_index_with_event), len(train_index_empty))
182 |         train_index_with_event = train_index_with_event[:size]
183 |         train_index_empty = train_index_empty[:size]
184 |     train_start_indices = np.concatenate((train_index_empty, train_index_with_event))
185 |     np.random.shuffle(train_start_indices)
186 | 
187 |     return train_features, train_event_matrix, train_start_indices
188 | 
189 | 
190 | def _read_validation_data_to_memory(feature_paths):
191 |     # Load validation features and targets
192 |     features_list = []
193 |     event_matrix_list = []
194 |     for feature_path in feature_paths:
195 |         data = pickle.load(open(feature_path, 'rb'))
196 |         feature = data['features']
197 |         event_matrix = create_event_matrix(feature.shape[1], data['start_times'], data['end_times'])
198 | 
199 |         features_list.append(feature)
200 |         event_matrix_list.append(event_matrix)
201 | 
202 |     return features_list, event_matrix_list
203 | 
204 | 
205 | def create_event_matrix(frames_num, start_times, end_times):
206 |     """
207 |     Create a per-frame classification matrix with 1 at times specified by start/end times and 0 elsewhere
208 |     """
209 |     # Reserve space for the data
210 |     event_matrix = np.zeros((frames_num, cfg.classes_num))
211 | 
212 |     for n in range(len(start_times)):
213 |         start_frame = int(round(start_times[n] * cfg.frames_per_second))
214 |         end_frame = int(round(end_times[n] * cfg.frames_per_second)) + 1
215 | 
216 |         event_matrix[start_frame: end_frame] = 1
217 | 
218 |     return event_matrix
219 | 
220 | 
221 | def preprocess_tau_sed_data(data_dir, preprocess_mode, force_preprocess=False, fold_name='eval'):
222 |     """
223 |     Download, extract and preprocess the tau sed dataset
224 |     force_preprocess: Force the preprocess phase to repeat: useful in case you change the preprocess parameters
225 |     """
226 |     cfg.cfg_descriptor += f"_C-{'-'.join(cfg.tau_sed_labels)}"
227 | 
228 |     ambisonic_2019_data_dir = f"{data_dir}/Tau_sound_events_2019"
229 |     audio_dir, meta_data_dir = ensure_tau_data(ambisonic_2019_data_dir, fold_name=fold_name)
230 | 
231 |     processed_data_dir = os.path.join(ambisonic_2019_data_dir, 'processed', f"{dataset.spectogram.spectogram_configs.cfg_descriptor}")
232 |     features_and_labels_dir = f"{processed_data_dir}/{preprocess_mode}-features_and_labels_{fold_name}"
233 |     features_mean_std_file = f"{processed_data_dir}/{preprocess_mode}-features_mean_std_{fold_name}.pkl"
234 |     if not os.path.exists(features_and_labels_dir) or force_preprocess:
235 |         audio_paths_and_labels =
get_tau_sed_paths_and_labels(audio_dir, meta_data_dir) 236 | preprocess_data(audio_paths_and_labels, output_dir=features_and_labels_dir, 237 | output_mean_std_file=features_mean_std_file, preprocess_mode=preprocess_mode) 238 | else: 239 | print("Using existing mel features") 240 | return features_and_labels_dir, features_mean_std_file 241 | 242 | 243 | def preprocess_film_clap_data(data_dir, preprocessed_mode, force_preprocess=False): 244 | """ 245 | Preprocess and Creates a data generator for the film_clap dataset 246 | """ 247 | film_clap_dir = os.path.join(data_dir, 'FilmClap') 248 | audio_and_labels_dir = os.path.join(film_clap_dir) 249 | cfg.cfg_descriptor += f"_tm-{cfg.time_margin}" 250 | if not os.path.exists(film_clap_dir): 251 | raise Exception("You should get you own dataset...") 252 | features_and_labels_dir = f"{film_clap_dir}/processed/{dataset.spectogram.spectogram_configs.cfg_descriptor}/{preprocessed_mode}-features_and_labels" 253 | features_mean_std_file = f"{film_clap_dir}/processed/{dataset.spectogram.spectogram_configs.cfg_descriptor}/{preprocessed_mode}-features_mean_std.pkl" 254 | if not os.path.exists(features_and_labels_dir) or force_preprocess: 255 | print("preprocessing raw data") 256 | audio_paths_and_labels = get_film_clap_paths_and_labels(audio_and_labels_dir, time_margin=cfg.time_margin) 257 | preprocess_data(audio_paths_and_labels, output_dir=features_and_labels_dir, 258 | output_mean_std_file=features_mean_std_file, preprocess_mode=preprocessed_mode) 259 | else: 260 | print("Using existing mel features") 261 | return features_and_labels_dir, features_mean_std_file 262 | 263 | 264 | def split_train_val(feature_names, val_descriptor): 265 | # Split to train, test 266 | if type(val_descriptor) == float: 267 | shuffle(feature_names) 268 | val_split = int(len(feature_names) * val_descriptor) 269 | train_feature_names = feature_names[val_split:] 270 | validate_feature_names = feature_names[:val_split] 271 | else: 272 | train_feature_names = [] 273 | validate_feature_names = [] 274 | for name in feature_names: 275 | if val_descriptor in name: 276 | validate_feature_names.append(name) 277 | else: 278 | train_feature_names.append(name) 279 | 280 | return train_feature_names, validate_feature_names -------------------------------------------------------------------------------- /dataset/waveform/waveform_configs.py: -------------------------------------------------------------------------------- 1 | from utils.common import human_format 2 | from dataset.common_config import * 3 | 4 | cfg_descriptor = f"WaveForm_SaR-{human_format(working_sample_rate)}_FrS-{human_format(frame_size)}" \ 5 | f"_HoS-{human_format(hop_size)}_Ch-{audio_channels}" -------------------------------------------------------------------------------- /dataset/waveform/waveform_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from dataset.waveform import waveform_configs as cfg 7 | from dataset.dataset_utils import read_multichannel_audio 8 | 9 | 10 | def split_to_frames_with_hop_size(waveform, start_times, end_times): 11 | """ 12 | Splits the waveform to overlapping frames and taggs each frame if its covered up to some degree by an event 13 | """ 14 | frames = [] 15 | labels = [] 16 | half_frame_size = cfg.frame_size // 2 17 | for center in np.arange(half_frame_size, waveform.shape[1] - half_frame_size + 1, step=cfg.hop_size): 18 | frame = waveform[:, center - half_frame_size: center + 
half_frame_size] 19 | label = False 20 | for s, e in zip(start_times, end_times): 21 | min_sample = max(s * cfg.working_sample_rate, center - half_frame_size) 22 | max_sample = min(e * cfg.working_sample_rate, center + half_frame_size) 23 | coverage = (max_sample - min_sample) / cfg.frame_size 24 | label = label or coverage > cfg.min_event_percentage_in_positive_frame 25 | # label = np.any([t[0] * cfg.working_sample_rate - half_frame_size < center < t[1] * cfg.working_sample_rate + half_frame_size for t in 26 | # zip(start_times, end_times)]) 27 | # label = num_event_samples_in_frame / cfg.frame_size >= cfg.min_event_percentage_in_positive_frame 28 | frames.append(frame) 29 | labels.append(label) 30 | return frames, labels 31 | 32 | 33 | def get_start_indices_labesl(waveform_length, start_times, end_times): 34 | """ 35 | Returns: a waveform_length size boolean array where the ith entry says wheter or not a frame starting from the ith 36 | sample is covered by an event 37 | """ 38 | label = np.zeros(waveform_length) 39 | for start, end in zip(start_times, end_times): 40 | event_first_start_index = int(start * cfg.working_sample_rate - cfg.frame_size * (1 - cfg.min_event_percentage_in_positive_frame)) 41 | event_last_start_index = int(end * cfg.working_sample_rate - cfg.frame_size * cfg.min_event_percentage_in_positive_frame) 42 | label[event_first_start_index: event_last_start_index] = 1 43 | return label 44 | 45 | 46 | class WaveformDataset: 47 | """ 48 | This dataset allows training a detector on raw waveforms. 49 | It splits all waveforms to frames of a defined size with some overlap and tags gives them a tag of one of the classes 50 | or zero for no-event. 51 | """ 52 | def __init__(self, audio_paths_labels_and_names, val_descriptor=0.15, balance_classes=False, augment_data=False): 53 | self.balance_classes = balance_classes 54 | self.augment_data = augment_data 55 | 56 | print("WaveformDataset:") 57 | print("\t- Loading samples into memory... 
") 58 | train_audio_paths_labels_and_names, val_audio_paths_labels_and_names = split_train_val(audio_paths_labels_and_names, val_descriptor) 59 | 60 | self.long_waveform = [] 61 | self.all_start_indices_labels = [] 62 | self.possible_start_indices = [] 63 | frame_index = 0 64 | 65 | for i, (audio_path, start_times, end_times, audio_name) in enumerate(train_audio_paths_labels_and_names): 66 | waveform = read_multichannel_audio(audio_path, target_fs=cfg.working_sample_rate) 67 | waveform = waveform.T # -> (channels, samples) 68 | 69 | self.long_waveform.append(waveform) 70 | 71 | # restrict the starting indices so that random crop are not taken over two different waveforms 72 | possible_start_indices = np.arange(frame_index, frame_index + waveform.shape[1] - cfg.frame_size, dtype=np.uint32) 73 | self.possible_start_indices.append(possible_start_indices) 74 | frame_index += waveform.shape[1] 75 | 76 | # Store the correct label for each starting sample index of a frame 77 | label_per_start_index = get_start_indices_labesl(waveform.shape[1], start_times, end_times).astype(bool) 78 | self.all_start_indices_labels.append(label_per_start_index) 79 | 80 | self.long_waveform = np.concatenate(self.long_waveform, axis=1) 81 | self.all_start_indices_labels = np.concatenate(self.all_start_indices_labels) 82 | self.possible_start_indices = np.concatenate(self.possible_start_indices) 83 | 84 | np.random.shuffle(self.possible_start_indices) 85 | 86 | # Load val samples 87 | self.val_samples_sets = [] 88 | self.val_label_sets = [] 89 | self.val_file_names = [] 90 | for i, (audio_path, start_times, end_times, audio_name) in enumerate(val_audio_paths_labels_and_names): 91 | waveform = read_multichannel_audio(audio_path, target_fs=cfg.working_sample_rate) 92 | waveform = waveform.T # -> (channels, samples) 93 | # Split wave form to overlapping frames and create labels for each 94 | frames, labels = split_to_frames_with_hop_size(waveform, start_times, end_times) 95 | self.val_samples_sets.append(frames) 96 | self.val_label_sets.append(labels) 97 | self.val_file_names.append(audio_name) 98 | 99 | 100 | print(f"\t- Train split: {len(self.possible_start_indices)} overlapping fames. ~{100*np.sum(self.all_start_indices_labels==1)/len(self.possible_start_indices):.1f}% tagged as event") 101 | print(f"\t- Val split: {np.sum([ len(x) for x in self.val_label_sets])} frames. 
{np.sum([ np.sum(x) for x in self.val_label_sets])} tagged as event") 102 | 103 | def get_validation_sampler(self, max_validate_num): 104 | for i, (frames, labels, file_names) in enumerate(zip(self.val_samples_sets, self.val_label_sets, self.val_file_names)): 105 | if i > max_validate_num: 106 | break 107 | yield torch.tensor(frames), torch.tensor(labels), file_names 108 | 109 | def __len__(self): 110 | return len(self.possible_start_indices) 111 | 112 | def __getitem__(self, idx): 113 | start_index = self.possible_start_indices[idx] 114 | 115 | waveform = self.long_waveform[:, start_index + np.arange(cfg.frame_size)] 116 | label = self.all_start_indices_labels[start_index] 117 | 118 | if self.augment_data: 119 | waveform, label = self.augment_mix_samples(waveform, label) 120 | waveform, label = self.augment_add_noise(waveform, label) 121 | 122 | return waveform, label 123 | 124 | def augment_mix_samples(self, waveform, label): 125 | number_of_augmentations = np.random.choice([0, 1, 2, 3], 1, p=[0.5, 0.3, 0.15, 0.05])[0] 126 | for i in range(number_of_augmentations): 127 | random_start_idx = np.random.choice(self.possible_start_indices) 128 | waveform += self.long_waveform[:, random_start_idx + np.arange(cfg.frame_size)] 129 | label = max(label, self.all_start_indices_labels[random_start_idx]) 130 | waveform /= (number_of_augmentations + 1) 131 | return waveform, label 132 | 133 | def augment_add_noise(self, waveform, label): 134 | # TODO these number are fit to noise added to waveform and not spectogram 135 | r = np.random.rand() 136 | if r > 0.5: 137 | noise_var = 0.001 + (r + 0.5) * (0.005 - 0.001) 138 | waveform += np.random.normal(0, noise_var, size=waveform.shape) 139 | return waveform, label 140 | 141 | 142 | def split_train_val(tuples, val_descriptor): 143 | # Split to train, test 144 | if type(val_descriptor) == float: 145 | np.random.shuffle(tuples) 146 | val_split = int(len(tuples) * val_descriptor) 147 | train_tuples = tuples[val_split:] 148 | val_tuples = tuples[:val_split] 149 | else: 150 | train_tuples = [] 151 | val_tuples = [] 152 | for tuple in tuples: 153 | if val_descriptor in tuple[0]: 154 | val_tuples.append(tuple) 155 | else: 156 | train_tuples.append(tuple) 157 | 158 | return train_tuples, val_tuples 159 | 160 | if __name__ == '__main__': 161 | from dataset.waveform.waveform_dataset import WaveformDataset 162 | from dataset.dataset_utils import get_film_clap_paths_and_labels, get_tau_sed_paths_and_labels, \ 163 | read_multichannel_audio 164 | from dataset.download_tau_sed_2019 import ensure_tau_data 165 | # audio_dir, meta_data_dir = ensure_tau_data('/home/ariel/projects/sound/data/Tau_sound_events_2019', fold_name='eval') 166 | # audio_paths_labels_and_names = get_tau_sed_paths_and_labels(audio_dir, meta_data_dir) 167 | # dataset = WaveformDataset(audio_paths_labels_and_names) 168 | dataset = WaveformDataset(get_film_clap_paths_and_labels('/home/ariel/projects/sound/data/FilmClap', time_margin=cfg.time_margin), val_descriptor=0.0) 169 | import matplotlib.pyplot as plt 170 | import soundfile 171 | import matplotlib as mpl 172 | 173 | os.makedirs("debug", exist_ok=True) 174 | 175 | i = 0 176 | w = 0 177 | while w < 20: 178 | frame, label = dataset[i] 179 | i += 1 180 | if label: 181 | plt.plot(range(len(frame[0])), frame[0], c='r') 182 | plt.ylim(-0.5,0.5) 183 | plt.autoscale(tight=True) 184 | plt.savefig(os.path.join(f"debug/a_{i}.png")) 185 | plt.clf() 186 | w += 1 187 | 188 | # for frames, labels, filename in zip(dataset.val_samples_sets, dataset.val_label_sets, 
dataset.val_file_names): 189 | 190 | 191 | 192 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from models import * 4 | from dataset.spectogram import spectogram_configs as cfg 5 | from dataset.spectogram.preprocess import multichannel_stft, multichannel_complex_to_log_mel 6 | from dataset.dataset_utils import read_multichannel_audio 7 | from utils import plot_debug_image 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser(description='Example of parser. ') 11 | 12 | # Train 13 | parser.add_argument('audio_file', type=str) 14 | parser.add_argument('--ckpt', type=str, required=True) 15 | parser.add_argument('--outputs_dir', type=str, default='inference_outputs', help='Directory of your workspace.') 16 | parser.add_argument('--device', default='cuda:0', type=str) 17 | args = parser.parse_args() 18 | 19 | device = torch.device("cuda:0" if torch.cuda.is_available() and args.device == "cuda:0" else "cpu") 20 | 21 | model = Cnn_AvgPooling(cfg.classes_num).to(device) 22 | # checkpoint = torch.load(args.ckpt, map_location=device) 23 | # model.load_state_dict(checkpoint['model']) 24 | 25 | print("Preprocessing audio file..") 26 | 27 | multichannel_audio = read_multichannel_audio(audio_path=args.audio_file, target_fs=cfg.working_sample_rate) 28 | 29 | log_mel_features = multichannel_complex_to_log_mel(multichannel_stft(multichannel_audio))[0] 30 | 31 | print("Inference..") 32 | with torch.no_grad(): 33 | output_event = model(torch.from_numpy(log_mel_features).to(device).float().unsqueeze(1)) 34 | output_event = output_event.cpu() 35 | os.makedirs(args.outputs_dir, exist_ok=True) 36 | 37 | plot_debug_image(log_mel_features, output=output_event[0], plot_path=os.path.join(args.outputs_dir, f"{os.path.splitext(os.path.basename(args.audio_file))[0]}.png")) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from train import train 7 | from utils.common import WeightedBCE 8 | 9 | 10 | def get_spectogram_dataset_model_and_criterion(args): 11 | from dataset.spectogram.spectograms_dataset import preprocess_film_clap_data, SpectogramDataset, preprocess_tau_sed_data 12 | from dataset.spectogram import spectogram_configs as cfg 13 | from models.spectogram_models import Cnn_AvgPooling 14 | 15 | # Define the dataset 16 | if args.dataset_name.lower() == "tau": 17 | features_and_labels_dir, features_mean_std_file = preprocess_tau_sed_data(args.dataset_dir, 18 | fold_name='eval', 19 | preprocess_mode=args.preprocess_mode, 20 | force_preprocess=args.force_preprocess) 21 | elif args.dataset_name.lower() == "filmclap": 22 | features_and_labels_dir, features_mean_std_file = preprocess_film_clap_data(args.dataset_dir, 23 | preprocessed_mode=args.preprocess_mode, 24 | force_preprocess=args.force_preprocess) 25 | else: 26 | raise ValueError(f"Only tau and filmclap datasets are supported, '{args.dataset_name}' given") 27 | 28 | dataset = SpectogramDataset(features_and_labels_dir, features_mean_std_file, 29 | augment_data=args.augment_data, 30 | balance_classes=args.balance_classes, 31 | val_descriptor=args.val_descriptor, 32 | preprocessed_mode=args.preprocess_mode) 33 | 34 | # Define the model 35 | model = 
Cnn_AvgPooling(cfg.classes_num, model_config=[(32,2), (64,2), (128,2), (128,1)]) 36 | # model = MobileNetV1(cfg.classes_num) 37 | if args.ckpt != '': 38 | checkpoint = torch.load(args.ckpt, map_location=device) 39 | model.load_state_dict(checkpoint['model']) 40 | 41 | # define the crieterion 42 | criterion = WeightedBCE(recall_factor=args.recall_priority, multi_frame=True) 43 | 44 | full_descriptor = f"{args.preprocess_mode}-{cfg.cfg_descriptor}" 45 | 46 | return dataset, model, criterion, full_descriptor 47 | 48 | 49 | def get_waveform_dataset_and_model(args): 50 | from dataset.waveform.waveform_dataset import WaveformDataset 51 | from dataset.waveform.waveform_configs import cfg_descriptor, time_margin 52 | from models.waveform_models import M5 53 | from dataset.dataset_utils import get_film_clap_paths_and_labels, get_tau_sed_paths_and_labels 54 | from dataset.download_tau_sed_2019 import ensure_tau_data 55 | 56 | if args.dataset_name.lower() == "tau": 57 | audio_dir, meta_data_dir = ensure_tau_data(f"{args.dataset_dir}/Tau_sound_events_2019", fold_name='eval') 58 | audio_paths_labels_and_names = get_tau_sed_paths_and_labels(audio_dir, meta_data_dir) 59 | elif args.dataset_name.lower() == "filmclap": 60 | audio_paths_labels_and_names = get_film_clap_paths_and_labels(os.path.join(args.dataset_dir, 'FilmClap'), time_margin) 61 | else: 62 | raise ValueError(f"Only tau and filmclap datasets are supported, '{args.dataset_name}' given") 63 | 64 | dataset = WaveformDataset(audio_paths_labels_and_names, 65 | augment_data=args.augment_data, 66 | balance_classes=args.balance_classes, 67 | val_descriptor=args.val_descriptor 68 | ) 69 | model = M5(1) 70 | 71 | criterion = WeightedBCE(recall_factor=args.recall_priority, multi_frame=False) 72 | 73 | return dataset, model, criterion, cfg_descriptor 74 | 75 | 76 | def get_dataset_and_model(args): 77 | if args.train_features.lower() == "spectogram": 78 | return get_spectogram_dataset_model_and_criterion(args) 79 | elif args.train_features.lower() == "waveform": 80 | return get_waveform_dataset_and_model(args) 81 | else: 82 | raise ValueError(f"training features can be raw waveform or spectogram only, '{args.train_features}' given") 83 | 84 | 85 | if __name__ == '__main__': 86 | parser = argparse.ArgumentParser(description='Example of parser. 
') 87 | 88 | # Traininng 89 | parser.add_argument('--dataset_dir', type=str, default='../data', help='Directory of dataset.') 90 | parser.add_argument('--dataset_name', type=str, default='FilmClap', help='FilmClap or TAU') 91 | parser.add_argument('--train_features', type=str, default='Waveform', help='Spectogram or Waveform') 92 | 93 | # Spectogram only arguments 94 | parser.add_argument('--preprocess_mode', type=str, default='logMel', help='logMel or Complex; relevant only for Spectogram features') 95 | parser.add_argument('--force_preprocess', action='store_true', default=False, help='relevant only for Spectogram features') 96 | 97 | # Train 98 | parser.add_argument('--outputs_root', type=str, default='training_dir') 99 | parser.add_argument('--ckpt', type=str, default='') 100 | parser.add_argument('--val_descriptor', default=0.2, help='float for percentage string for specifying fold substring') 101 | parser.add_argument('--train_tag', type=str, default='') 102 | 103 | # Training tricks 104 | parser.add_argument('--augment_data', action='store_true', default=False) 105 | parser.add_argument('--balance_classes', action='store_true', default=False, 106 | help='Whether to make sure there is same number of samples with and without events') 107 | parser.add_argument('--recall_priority', type=float, default=5, help='priority factor for the bce loss') 108 | 109 | # Hyper parameters 110 | parser.add_argument('--batch_size', type=int, default=128) 111 | parser.add_argument('--lr', type=float, default=0.000001) 112 | parser.add_argument('--num_train_steps', type=int, default=100000) 113 | parser.add_argument('--log_freq', type=int, default=5000) 114 | 115 | # Infrastructure 116 | parser.add_argument('--device', default='cuda:0', type=str) 117 | parser.add_argument('--num_workers', default=12, type=int) 118 | 119 | args = parser.parse_args() 120 | 121 | device = torch.device("cuda:0" if torch.cuda.is_available() and args.device == "cuda:0" else "cpu") 122 | 123 | dataset, model, criterion, cfg_descriptor = get_dataset_and_model(args) 124 | 125 | dataloader = DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) 126 | 127 | model = model.to(device) 128 | model.model_description() 129 | 130 | train_name = f"{args.dataset_name}_cfg({cfg_descriptor}_b{args.batch_size}_lr{args.lr}_{args.train_tag}" 131 | if args.balance_classes: 132 | train_name += "_BC" 133 | if args.augment_data: 134 | train_name += "_AD" 135 | 136 | train(model, dataloader, criterion, 137 | num_steps=args.num_train_steps, 138 | outputs_dir=os.path.join(args.outputs_root, train_name), 139 | device=device, 140 | lr=args.lr, 141 | log_freq=args.log_freq) 142 | -------------------------------------------------------------------------------- /models/spectogram_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from dataset.spectogram.spectogram_configs import audio_channels, working_sample_rate, mel_bins, hop_size, classes_num 5 | from utils.common import count_parameters, human_format 6 | 7 | DEFAULT_CHANNEL_AND_POOL=[(64,2), (128,2), (256,2), (512,1)] 8 | 9 | def interpolate(x, ratio): 10 | ''' 11 | Upscales the 2nd axis of x by 'ratio', i.e Repeats each element in it 'ratio' times: 12 | In other words: Interpolate the prediction to have the same time_steps as the target. 13 | The time_steps mismatch is caused by maxpooling in CNN. 
14 | 15 | Args: 16 | x: (batch_size, time_steps, classes_num) 17 | ratio: int, ratio to upsample 18 | ''' 19 | (batch_size, time_steps, classes_num) = x.shape 20 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 21 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 22 | return upsampled 23 | 24 | 25 | def init_layer(layer, nonlinearity='leaky_relu'): 26 | """Initialize a Linear or Convolutional layer. """ 27 | nn.init.kaiming_uniform_(layer.weight, nonlinearity=nonlinearity) 28 | 29 | if hasattr(layer, 'bias'): 30 | if layer.bias is not None: 31 | layer.bias.data.fill_(0.) 32 | 33 | 34 | def init_bn(bn): 35 | """Initialize a Batchnorm layer. """ 36 | 37 | bn.bias.data.fill_(0.) 38 | bn.running_mean.data.fill_(0.) 39 | bn.weight.data.fill_(1.) 40 | bn.running_var.data.fill_(1.) 41 | 42 | class MobileNetV1(nn.Module): 43 | def __init__(self, classes_num): 44 | super(MobileNetV1, self).__init__() 45 | # self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False) 46 | self.bn0 = nn.BatchNorm2d(64) 47 | 48 | def conv_bn(inp, oup, stride): 49 | _layers = [ 50 | nn.Conv2d(inp, oup, 3, 1, 1, bias=False), 51 | nn.AvgPool2d(stride), 52 | nn.BatchNorm2d(oup), 53 | nn.ReLU(inplace=True) 54 | ] 55 | _layers = nn.Sequential(*_layers) 56 | init_layer(_layers[0]) 57 | init_bn(_layers[2]) 58 | return _layers 59 | 60 | def conv_dw(inp, oup, stride): 61 | _layers = [ 62 | nn.Conv2d(inp, inp, 3, 1, 1, groups=inp, bias=False), 63 | nn.AvgPool2d(stride), 64 | nn.BatchNorm2d(inp), 65 | nn.ReLU(inplace=True), 66 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 67 | nn.BatchNorm2d(oup), 68 | nn.ReLU(inplace=True) 69 | ] 70 | _layers = nn.Sequential(*_layers) 71 | init_layer(_layers[0]) 72 | init_bn(_layers[2]) 73 | init_layer(_layers[4]) 74 | init_bn(_layers[5]) 75 | return _layers 76 | self.num_pools = 3 77 | self.features = nn.Sequential( 78 | conv_bn(1, 32, 2), 79 | conv_dw( 32, 64, 1), 80 | conv_dw( 64, 128, 2), 81 | conv_dw(128, 128, 1), 82 | conv_dw(128, 256, 2), 83 | conv_dw(256, 256, 1), 84 | conv_dw(256, 512, 1), 85 | conv_dw(512, 512, 1), 86 | conv_dw(512, 512, 1), 87 | conv_dw(512, 512, 1), 88 | conv_dw(512, 512, 1), 89 | conv_dw(512, 1024, 1), 90 | conv_dw(1024, 1024, 1) 91 | ) 92 | self.fc1 = nn.Linear(1024, 1024, bias=True) 93 | self.fc_audioset = nn.Linear(1024, classes_num, bias=True) 94 | 95 | self.init_weights() 96 | 97 | def init_weights(self): 98 | init_bn(self.bn0) 99 | init_layer(self.fc1) 100 | init_layer(self.fc_audioset) 101 | 102 | def forward(self, x): 103 | """ 104 | Input: (batch_size, data_length)""" 105 | x = x.transpose(0, 1) # -> (batch_size, channels_num, times_steps, freq_bins) 106 | # x = x.transpose(1, 3) 107 | # x = self.bn0(x) 108 | # x = x.transpose(1, 3) 109 | 110 | x = self.features(x) # (batch_size, 512, time_steps / x, mel_bins / x) 111 | x = torch.mean(x, dim=3) # (batch_size, 512, time_steps / x) 112 | 113 | x = x.transpose(1, 2) # (batch_size, time_steps, 512) 114 | x = F.relu_(self.fc1(x)) # (batch_size, time_steps, 512) 115 | # embedding = F.dropout(x, p=0.5, training=self.training) 116 | 117 | event_output = torch.sigmoid(self.fc_audioset(x)) # (batch_size, time_steps, classes_num) 118 | 119 | # Interpolate 120 | event_output = interpolate(event_output, 2**self.num_pools) 121 | 122 | return event_output 123 | 124 | def model_description(self): 125 | print(f"\tMobileNetV1 has {human_format(count_parameters(self))} parameters") 126 | 127 | 128 | class ConvBlock(nn.Module): 129 | def __init__(self, 
in_channels, out_channels, pool_size=2): 130 | super(ConvBlock, self).__init__() 131 | self.pool_size = pool_size 132 | self.conv1 = nn.Conv2d(in_channels=in_channels, 133 | out_channels=out_channels, 134 | kernel_size=(3, 3), stride=(1, 1), 135 | padding=(1, 1), bias=False) 136 | 137 | self.conv2 = nn.Conv2d(in_channels=out_channels, 138 | out_channels=out_channels, 139 | kernel_size=(3, 3), stride=(1, 1), 140 | padding=(1, 1), bias=False) 141 | 142 | self.bn1 = nn.BatchNorm2d(out_channels) 143 | self.bn2 = nn.BatchNorm2d(out_channels) 144 | 145 | self.init_weights() 146 | 147 | def init_weights(self): 148 | init_layer(self.conv1) 149 | init_layer(self.conv2) 150 | init_bn(self.bn1) 151 | init_bn(self.bn2) 152 | 153 | def forward(self, input): 154 | x = input 155 | x = F.relu_(self.bn1(self.conv1(x))) 156 | x = F.relu_(self.bn2(self.conv2(x))) 157 | 158 | x = F.avg_pool2d(x, kernel_size=self.pool_size) 159 | 160 | return x 161 | 162 | 163 | class Cnn_AvgPooling(nn.Module): 164 | def __init__(self, classes_num, model_config=DEFAULT_CHANNEL_AND_POOL): 165 | super(Cnn_AvgPooling, self).__init__() 166 | self.model_config = model_config 167 | self.num_pools = 1 if model_config[0][1] == 2 else 1 168 | self.conv_blocks = [ConvBlock(in_channels=audio_channels, out_channels=model_config[0][0], pool_size=model_config[0][1])] 169 | for i in range(1, len(model_config)): 170 | pool_size = model_config[i][1] 171 | if pool_size == 2: 172 | self.num_pools += 1 173 | self.conv_blocks.append(ConvBlock(in_channels=model_config[i - 1][0], out_channels=model_config[i][0], pool_size=pool_size)) 174 | 175 | self.conv_blocks = torch.nn.Sequential(*self.conv_blocks) 176 | 177 | self.event_fc = nn.Linear(model_config[-1][0], classes_num, bias=True) 178 | 179 | self.init_weights() 180 | 181 | def init_weights(self): 182 | 183 | init_layer(self.event_fc) 184 | 185 | def forward(self, x): 186 | ''' 187 | Input: (batch_size, channels_num, times_steps, freq_bins)''' 188 | 189 | x = self.conv_blocks(x) 190 | 191 | # x.shape : (batch_size, channels_num, times_steps, freq_bins) 192 | 193 | x = torch.mean(x, dim=3) # (batch_size, channels_num, time_steps) 194 | x = x.transpose(1, 2) # (batch_size, time_steps, channels_num) 195 | 196 | # event_output = torch.sigmoid(self.event_fc(x)) # (batch_size, time_steps, classes_num) 197 | event_output = self.event_fc(x) # (batch_size, time_steps, classes_num) 198 | 199 | # Interpolate 200 | event_output = interpolate(event_output, 2**(self.num_pools)) 201 | 202 | return event_output 203 | 204 | def logits(self, x): 205 | return torch.sigmoid(self.forward(x)) 206 | 207 | def model_description(self): 208 | print("Model description") 209 | b = 'b' 210 | w = mel_bins 211 | h = 60 * working_sample_rate // hop_size 212 | c = audio_channels 213 | # dummy_input = torch.ones() 214 | print(f"\tInput: ({b}, {c}, {h}, {w})") 215 | for (c, k) in self.model_config: 216 | h = h // k 217 | w = w // k 218 | print(f"\tconv_block -> ({b}, {c}, {h}, {w})") 219 | 220 | print(f"\tmean(dim=3) -> ({b}, {c}, {h})") 221 | print(f"\ttranspose(1,2) -> ({b}, {h}, {c})") 222 | print(f"\tFC + sigmoid -> ({b}, {h}, {classes_num})") 223 | num_outputs = h 224 | h *= 2**(self.num_pools) 225 | num_frames = h 226 | frame_duration = hop_size / working_sample_rate 227 | print(f"\tinterpolate({2**(self.num_pools)})-> ({b}, {h}, {classes_num})") 228 | print(f"\tModel has {num_outputs} outputs before interpolation, each stands for {2**(self.num_pools)} frames or" 229 | f" {2**(self.num_pools)*frame_duration:.2f}s") 230 | 
print(f"\tModel has {human_format(count_parameters(self))} parameters") -------------------------------------------------------------------------------- /models/waveform_models.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Sequential 2 | from dataset.waveform.waveform_configs import frame_size, audio_channels 3 | import torch.nn as nn 4 | import torch 5 | 6 | from utils.common import count_parameters, human_format 7 | 8 | 9 | class M5(nn.Module): 10 | """ 11 | Model described in "VERY DEEP CONVOLUTIONAL NEURAL NETWORKS FOR RAW WAVEFORMS 12 | """ 13 | def __init__(self, classes_num): 14 | super(M5, self).__init__() 15 | self.conv_block1 = Sequential( 16 | nn.Conv1d(audio_channels, 64, kernel_size=79, stride=4, padding=39), 17 | nn.BatchNorm1d(64), 18 | nn.ReLU(), 19 | nn.MaxPool1d(4,4) 20 | ) 21 | 22 | self.conv_block2 = Sequential( 23 | nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1), 24 | nn.BatchNorm1d(64), 25 | nn.ReLU(), 26 | nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1), 27 | nn.BatchNorm1d(64), 28 | nn.ReLU(), 29 | nn.MaxPool1d(4, 4) 30 | ) 31 | self.conv_block3 = Sequential( 32 | nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1), 33 | nn.BatchNorm1d(64), 34 | nn.ReLU(), 35 | nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1), 36 | nn.BatchNorm1d(64), 37 | nn.ReLU(), 38 | nn.MaxPool1d(4, 4) 39 | ) 40 | self.conv_block4 = Sequential( 41 | nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1), 42 | nn.BatchNorm1d(128), 43 | nn.ReLU(), 44 | nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1), 45 | nn.BatchNorm1d(128), 46 | nn.ReLU(), 47 | nn.MaxPool1d(4, 4) 48 | ) 49 | self.conv_block5 = Sequential( 50 | nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1), 51 | nn.BatchNorm1d(256), 52 | nn.ReLU(), 53 | nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1), 54 | nn.BatchNorm1d(256), 55 | nn.ReLU(), 56 | ) 57 | self.fc = nn.Linear(256, classes_num) 58 | 59 | def forward(self, x): 60 | # x: (b, c, frame_size) 61 | x = self.conv_block1(x) # x: (b, 64, frame_size / 16) 62 | x = self.conv_block2(x) # x: (b, 64, frame_size / 64) 63 | x = self.conv_block3(x) # x: (b, 64, frame_size / 256) 64 | x = self.conv_block4(x) # x: (b, 64, frame_size / 1024) 65 | x = self.conv_block5(x) # x: (b, 64, frame_size / 1024) 66 | x = torch.mean(x, dim=2) 67 | x = self.fc(x) 68 | 69 | # x = torch.sigmoid(x) 70 | 71 | return x 72 | 73 | def model_description(self): 74 | print("Waveform model:") 75 | print(f"\t- Model has {human_format(count_parameters(self))} parameters") -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import torch 4 | from torch import optim 5 | from utils.common import ProgressPlotter 6 | from utils.metric_utils import calculate_metrics 7 | from utils.plot_utils import plot_sample_features 8 | from time import time 9 | import numpy as np 10 | 11 | 12 | def eval(model, dataloader, criterion, outputs_dir, iteration, device, limit_val_samples=None): 13 | losses = [] 14 | recal_sets, precision_sets, APs = [], [], [] 15 | debug_outputs = [] 16 | debug_targets = [] 17 | debug_inputs = [] 18 | debug_file_names = [] 19 | val_sampler = dataloader.dataset.get_validation_sampler(max_validate_num=limit_val_samples) 20 | for idx, (input, target, file_name) in enumerate(val_sampler): 21 | model.eval() 22 | with torch.no_grad(): 23 | model.eval() 24 
| output = model(input.to(device).float()).cpu() 25 | 26 | loss = criterion(output, target.float()) 27 | 28 | if len(input.shape) == 4: 29 | mode = 'Spectogram' 30 | # spectogram: (batch, channels, frames, bins), 31 | # output: (batch, frames, classes) 32 | # target: (batch, frames, classes) 33 | input = input[0] 34 | output = output[0] 35 | target = target[0] 36 | else: 37 | mode = 'Waveform' 38 | # waveform (frames, channels, wave_samples), 39 | # output: (frames, classes) 40 | # target: (frames) 41 | input = input.permute(1, 0, 2) 42 | target = target.reshape(-1,1) 43 | 44 | output_logits = torch.sigmoid(output).numpy() 45 | target = target.numpy() 46 | 47 | recal_vals, precision_vals, AP = calculate_metrics(output_logits, target) 48 | 49 | losses.append(loss.item()) 50 | recal_sets.append(recal_vals) 51 | precision_sets.append(precision_vals) 52 | APs.append(AP) 53 | 54 | debug_inputs.append(input) 55 | debug_outputs.append(output_logits) 56 | debug_targets.append(target) 57 | debug_file_names.append(file_name) 58 | 59 | # plot input, outputs and targets of worst and best samples by each metric 60 | for (metric_name, values, named_indices) in [ 61 | ("loss", losses, [('worst', -1), ('2-worst', -2), ('3-worst', -3), ('best', 0)]), 62 | ('AP', APs, [('worst', 0), ('best', -1)])]: 63 | indices = np.argsort(values) 64 | for (name, idx) in named_indices: 65 | val_sample_idx = indices[idx] 66 | plot_sample_features(debug_inputs[val_sample_idx], 67 | mode=mode, 68 | output=debug_outputs[val_sample_idx], 69 | target=debug_targets[val_sample_idx], 70 | file_name=debug_file_names[val_sample_idx] + f" {metric_name} {values[val_sample_idx]:.2f}", 71 | plot_path=os.path.join(outputs_dir, 'images', f"Iter-{iteration}", 72 | f"{metric_name}-{name}.png")) 73 | 74 | return losses, recal_sets, precision_sets, APs 75 | 76 | 77 | def train(model, data_loader, criterion, num_steps, lr, log_freq, outputs_dir, device): 78 | print("Training:") 79 | print("\t- Using device: ", device) 80 | lr_decay_freq = 200 81 | plotter = ProgressPlotter() 82 | os.makedirs(os.path.join(outputs_dir, 'checkpoints'), exist_ok=True) 83 | 84 | # Optimizer 85 | optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0., amsgrad=True) 86 | 87 | iterations = 0 88 | epoch = 0 89 | training_start_time = time() 90 | tqdm_bar = tqdm(total=num_steps) 91 | tqdm_bar.set_description("Waiting for information..") 92 | while iterations < num_steps: 93 | for (batch_features, event_labels) in data_loader: 94 | tqdm_bar.update() 95 | # forward 96 | model.train() 97 | batch_outputs = model(batch_features.to(device).float()) 98 | loss = criterion(batch_outputs, event_labels.to(device).float()) 99 | 100 | # Backward 101 | optimizer.zero_grad() 102 | loss.backward() 103 | optimizer.step() 104 | 105 | plotter.report_train_loss(loss.item()) 106 | iterations += 1 107 | 108 | if iterations % lr_decay_freq == 0: 109 | for param_group in optimizer.param_groups: 110 | param_group['lr'] *= 0.997 111 | 112 | if iterations % log_freq == 0: 113 | im_sec = iterations * data_loader.batch_size / (time() - training_start_time) 114 | tqdm_bar.set_description( 115 | f"epoch: {epoch}, step: {iterations}, loss: {loss.item():.2f}, im/sec: {im_sec:.1f}, lr: {optimizer.param_groups[0]['lr']:.8f}") 116 | 117 | val_losses, recal_sets, precision_sets, APs = eval(model, data_loader, criterion, outputs_dir, iteration=iterations, 118 | device=device, limit_val_samples=3) 119 | 120 | plotter.report_validation_metrics(val_losses, recal_sets, 
precision_sets, APs, iterations) 121 | plotter.plot(outputs_dir) 122 | 123 | checkpoint = { 124 | 'iterations': iterations, 125 | 'model': model.state_dict(), 126 | 'optimizer': optimizer.state_dict()} 127 | 128 | torch.save(checkpoint, os.path.join(outputs_dir, 'checkpoints', f"iteration_{iterations}.pth")) 129 | 130 | if iterations == num_steps: 131 | break 132 | epoch += 1 -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | 5 | from torch import tensor 6 | from torch.nn.functional import binary_cross_entropy_with_logits 7 | 8 | from utils.metric_utils import f_score 9 | 10 | 11 | class WeightedBCE: 12 | def __init__(self, recall_factor, multi_frame): 13 | self.recall_factor = tensor([recall_factor]) 14 | self.multi_frame = multi_frame 15 | 16 | def __call__(self, output, target): 17 | if self.multi_frame: 18 | # expected shape (batch_size, frames_num, classes_num) 19 | # Number of frames differ due to pooling on eve/odd number of frames 20 | N = min(output.shape[1], target.shape[1]) 21 | _output = output[:, :N] 22 | _target = target[:, :N] 23 | 24 | else: 25 | # expected shape (batch_size, classes_num) 26 | _output = output.reshape(-1) 27 | _target = target 28 | 29 | return binary_cross_entropy_with_logits(_output, _target, 30 | pos_weight=self.recall_factor.to(_output.device)) 31 | 32 | 33 | class ProgressPlotter: 34 | def __init__(self): 35 | self.train_buffer = [] 36 | self.train_avgs = [] 37 | self.val_avgs = [] 38 | self.f1_score_avgs = [] 39 | self.f5_score_avgs = [] 40 | self.AP_avgs = [] 41 | self.iterations = [] 42 | 43 | def report_train_loss(self, loss): 44 | self.train_buffer.append(loss) 45 | 46 | def report_validation_metrics(self, val_losses, recal_sets, precision_sets, APs, iteration): 47 | self.iterations.append(iteration) 48 | 49 | self.val_avgs.append(np.mean(val_losses)) 50 | self.AP_avgs.append(np.mean(APs)) 51 | self.last_recal_vals = np.mean(recal_sets, axis=0) 52 | self.last_precision_vals = np.mean(precision_sets, axis=0) 53 | f1_scores = f_score(self.last_precision_vals, self.last_recal_vals, precision_importance_factor=1) 54 | f5_scores = f_score(self.last_precision_vals, self.last_recal_vals, precision_importance_factor=5) 55 | self.f1_score_avgs.append(np.max(f1_scores)) 56 | self.f5_score_avgs.append(np.max(f5_scores)) 57 | 58 | def plot(self, outputs_dir): 59 | self.plot_train_eval_losses(os.path.join(outputs_dir, 'Training_loss.png')) 60 | self.plot_metrics(os.path.join(outputs_dir, 'Metrics.png')) 61 | self.plot_roc(os.path.join(outputs_dir, 'ROC_plots', f"Roc-iteration-{self.iterations[-1]}.png")) 62 | 63 | def plot_train_eval_losses(self, plot_path): 64 | self.train_avgs += [np.mean(self.train_buffer)] 65 | self.train_buffer = [] 66 | 67 | plt.plot(np.arange(len(self.train_avgs)), self.train_avgs, label='train', color='blue') 68 | plt.plot(np.arange(len(self.val_avgs)), self.val_avgs, label='validation', color='orange') 69 | x_indices = np.arange(0, len(self.iterations), max(len(self.iterations) // 5, 1)) 70 | plt.xticks(x_indices, np.array(self.iterations)[x_indices]) 71 | plt.xlabel("train step") 72 | plt.ylabel("loss") 73 | plt.legend() 74 | plt.savefig(plot_path) 75 | plt.clf() 76 | 77 | def plot_metrics(self, plot_path): 78 | plt.plot(np.arange(len(self.f1_score_avgs)), self.f1_score_avgs, color='blue', label='Max f1 scroe') 79 | 
plt.plot(np.arange(len(self.f5_score_avgs)), self.f5_score_avgs, color='green', label='Max f5 scroe') 80 | plt.plot(np.arange(len(self.AP_avgs)), self.AP_avgs, color='orange', label='Average precision') 81 | plt.title("Metrics") 82 | x_indices = np.arange(0, len(self.iterations), max(len(self.iterations) // 5, 1)) 83 | plt.xticks(x_indices, np.array(self.iterations)[x_indices]) 84 | plt.legend() 85 | plt.savefig(plot_path) 86 | plt.clf() 87 | 88 | def plot_roc(self, plot_path): 89 | os.makedirs(os.path.dirname(plot_path), exist_ok=True) 90 | plt.plot(self.last_recal_vals, self.last_precision_vals) 91 | plt.xticks([0, 0.25, 0.5, 0.75, 1]) 92 | plt.yticks([0, 0.25, 0.5, 0.75, 1]) 93 | MAP = np.sum(self.last_precision_vals[:-1] * (self.last_recal_vals[:-1] - self.last_recal_vals[1:])) 94 | plt.title(f"Validation AVG ROC" 95 | f"\nAP: {MAP:.2f}") 96 | plt.xlabel("Avg Recall") 97 | plt.ylabel("Avg Precision") 98 | plt.savefig(plot_path) 99 | plt.clf() 100 | 101 | 102 | def human_format(num): 103 | """ 104 | :param num: A number to print in a nice readable way. 105 | :return: A string representing this number in a readable way (e.g. 1000 --> 1K). 106 | """ 107 | magnitude = 0 108 | 109 | while abs(num) >= 1000: 110 | magnitude += 1 111 | num /= 1000.0 112 | 113 | return '%.1f%s' % (num, ['', 'K', 'M', 'G', 'T', 'P'][magnitude]) # add more suffices if you need them 114 | 115 | 116 | def count_parameters(model): 117 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 118 | -------------------------------------------------------------------------------- /utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def calculate_metrics(output, target): 5 | ths =np.arange(0.00, 1.05, 0.05) 6 | N = min(output.shape[0], target.shape[0]) 7 | T = target[:N] 8 | O = output[:N] 9 | recals = [] 10 | precisions = [] 11 | for th in ths: 12 | O_discrete = np.where(O > th, 1, 0) 13 | recall, prec = compute_recall_precision(O_discrete, T) 14 | recals.append(recall) 15 | precisions.append(prec) 16 | 17 | recals, precisions = np.array(recals), np.array(precisions) 18 | # from sklearn.metrics import average_precision_score 19 | # AP = average_precision_score(T.reshape(-1).astype(int), O.reshape(-1)) 20 | AP = np.sum(precisions[:-1] * (recals[:-1] - recals[1:])) 21 | return recals, precisions, AP 22 | 23 | 24 | def compute_recall_precision(O, T): 25 | TP = ((2 * T - O) == 1).sum() 26 | 27 | num_gt = T.sum() 28 | num_positives = O.sum() 29 | 30 | recall = float(TP) / float(num_gt) if num_gt > 0 else 1 31 | prec = float(TP) / float(num_positives) if num_positives > 0 else 1 32 | 33 | return recall, prec 34 | 35 | 36 | def f_score(recll, precision, precision_importance_factor=1): 37 | return (1+precision_importance_factor**2) * recll * precision / (precision_importance_factor**2 * recll + precision + 1e-9) -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from matplotlib import pyplot as plt 4 | from mpl_toolkits.axes_grid1 import make_axes_locatable 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | 8 | 9 | def plot_waveform(ax, waveform, sample_rate): 10 | # subsample for faster plotting 11 | ax.set_facecolor('k') 12 | new_sample_rate = sample_rate / 10 13 | new_waveform = waveform[::10] 14 | ax.plot(range(len(new_waveform)), new_waveform, c='r') 
15 | ax.margins(x=0) 16 | ax.set_title('Time', color='r') 17 | ax.set_ylabel('Amplitudes') 18 | 19 | xticks = np.arange(0, len(new_waveform), len(new_waveform) // 8) 20 | xlabels = [f"{x / new_sample_rate:.2f}s" for x in xticks] 21 | 22 | ax.set_xticks(xticks) 23 | ax.set_xticklabels(xlabels) 24 | ax.xaxis.set_ticks_position('bottom') 25 | 26 | 27 | def plot_spectogram(ax, spectpgram, frames_per_second): 28 | frames_num, mel_bins = spectpgram.shape 29 | colorbar = ax.matshow(spectpgram.T, origin='lower', aspect='auto', cmap='jet') 30 | ax.set_title('Log mel spectrogram', color='r') 31 | ax.set_ylabel('Mel bins') 32 | ax.set_yticks([0, mel_bins]) 33 | ax.set_yticklabels([0, mel_bins]) 34 | 35 | tick_hop = frames_num // 8 36 | xticks = np.concatenate((np.arange(0, frames_num - tick_hop, tick_hop), [frames_num])) 37 | xlabels = [f"frame {x}\n{x / frames_per_second:.1f}s" for x in xticks] 38 | 39 | ax.set_xticks(xticks) 40 | ax.set_xticklabels(xlabels) 41 | ax.xaxis.set_ticks_position('bottom') 42 | 43 | return colorbar 44 | 45 | 46 | def plot_classification_matrix(ax, mat, frames_per_second): 47 | frames_num = mat.shape[0] 48 | colorbar = ax.matshow(mat.T, origin='lower', aspect='auto', cmap='jet', vmin=0, vmax=1) 49 | tick_hop = frames_num // 8 50 | xticks = np.concatenate((np.arange(0, frames_num - tick_hop, tick_hop), [frames_num])) 51 | xlabels = [f"frame {x}\n{x / frames_per_second:.1f}s" for x in xticks] 52 | 53 | ax.set_xticks(xticks) 54 | ax.set_xticklabels(xlabels) 55 | ax.xaxis.set_ticks_position('bottom') 56 | 57 | return colorbar 58 | 59 | 60 | def add_colorbar_to_axs(fig, ax, colorbar): 61 | divider = make_axes_locatable(ax) 62 | cax = divider.append_axes('right', size='1%', pad=0.01) 63 | fig.colorbar(colorbar, cax=cax, orientation='vertical') 64 | 65 | 66 | def plot_sample_features(input, mode, output=None, target=None, file_name=None, plot_path=None): 67 | os.makedirs(os.path.dirname(plot_path), exist_ok=True) 68 | num_plots = 1 69 | if output is not None: 70 | num_plots += 1 71 | if target is not None: 72 | num_plots += 1 73 | 74 | fig, axs = plt.subplots(num_plots, 1, figsize=(20, 20)) 75 | plt.subplots_adjust(hspace=1) 76 | if file_name: 77 | fig.suptitle(f"Sample name: {file_name}") 78 | 79 | input = input.mean(0) # Mean over channels 80 | if mode.lower() == 'spectogram': 81 | from dataset.spectogram.spectogram_configs import frames_per_second 82 | colorbar = plot_spectogram(axs[0], input, frames_per_second) 83 | add_colorbar_to_axs(fig, axs[0], colorbar) 84 | else: # mode == 'Waveform 85 | from dataset.waveform.waveform_configs import working_sample_rate, hop_size 86 | frames_per_second = working_sample_rate // hop_size 87 | waveform = input[:,:hop_size].flatten() 88 | plot_waveform(axs[0], waveform, working_sample_rate) 89 | 90 | # shrink plot to fit labels plots with colorbar on the right 91 | divider = make_axes_locatable(axs[0]) 92 | cax = divider.append_axes('right', size='1%', pad=0.01) 93 | 94 | 95 | if output is not None: 96 | colorbar = plot_classification_matrix(axs[1], output, frames_per_second) 97 | axs[1].set_title("Predicted sound events", color='b') 98 | add_colorbar_to_axs(fig, axs[1], colorbar) 99 | 100 | if target is not None: 101 | idx = 1 if output is None else 2 102 | colorbar = plot_classification_matrix(axs[idx], target, frames_per_second) 103 | axs[idx].set_title(f"Reference sound events, marked frames: {int(target.sum())}", color='r') 104 | 105 | add_colorbar_to_axs(fig, axs[idx], colorbar) 106 | 107 | fig.tight_layout() 108 | 
plt.savefig(plot_path) 109 | # Close the saved figure and any other open figures, then force garbage 110 | # collection so that repeated calls do not accumulate figure memory. 111 | plt.close(fig) 112 | plt.close('all') 113 | import gc 114 | gc.collect() --------------------------------------------------------------------------------
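Usage sketch (illustrative only, not part of the repository): the snippet below fabricates random per-frame predictions, targets and a fake single-channel log-mel patch to exercise calculate_metrics and f_score from utils/metric_utils.py and plot_sample_features from utils/plot_utils.py. The sizes (256 frames, 64 mel bins, 1 class), the RNG seed and the output path are assumptions chosen for the example, not values taken from the project's configs; run it from the repository root so the imports resolve.

import numpy as np
from utils.metric_utils import calculate_metrics, f_score
from utils.plot_utils import plot_sample_features

n_frames, mel_bins, classes = 256, 64, 1   # assumed sizes, not the repo's config values
rng = np.random.default_rng(0)

# Sparse binary event labels and noisy "predictions" correlated with them
target = (rng.random((n_frames, classes)) > 0.9).astype(np.float32)
output = np.clip(target + 0.25 * rng.standard_normal((n_frames, classes)), 0.0, 1.0)

# Fake 1-channel log-mel features shaped (channels, frames, mel_bins), as expected in 'Spectogram' mode
features = rng.standard_normal((1, n_frames, mel_bins)).astype(np.float32)

recalls, precisions, AP = calculate_metrics(output, target)
f1_scores = f_score(recalls, precisions, precision_importance_factor=1)
print(f"AP: {AP:.2f} | best F1: {f1_scores.max():.2f}")

plot_sample_features(features, mode='Spectogram', output=output, target=target,
                     file_name='random_demo', plot_path='debug_plots/random_demo.png')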