├── .gitignore ├── .vscode └── launch.json ├── LICENSE ├── README.MD ├── audio ├── README.MD ├── audio_feature_extractor.py ├── audio_model.py ├── audio_records.py ├── audio_urban_preprocess.py └── audio_util.py ├── audio_inference_demo.py ├── audio_params.py ├── audio_train.py ├── conda.env.yml ├── data ├── records │ ├── urban_sound_test.tfrecords │ ├── urban_sound_train.tfrecords │ └── urban_sound_val.tfrecords ├── train │ └── urban_sound_train │ │ ├── audio_urban_model.ckpt.data-00000-of-00001 │ │ ├── audio_urban_model.ckpt.index │ │ └── audio_urban_model.ckpt.meta ├── vggish │ ├── README.MD │ └── vggish_pca_params.npz └── wav │ ├── 13230-0-0-1.wav │ ├── 16772-8-0-0.wav │ ├── 7389-1-1-0.wav │ └── README.MD └── vggish ├── README.MD ├── mel_features.py ├── vggish_inference_demo.py ├── vggish_input.py ├── vggish_params.py ├── vggish_postprocess.py ├── vggish_slim.py ├── vggish_smoke_test.py └── vggish_train_demo.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # tensorflow temp file 107 | tensorboard/ 108 | checkpoint 109 | # ignore: l0.68_audio_urban_model.ckpt.data-00000-of-00001 110 | l*.ckpt.* 111 | *.png 112 | vggish_model.ckpt 113 | deprecated* 114 | *.records 115 | .idea/ 116 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // 使用 IntelliSense 了解相关属性。 3 | // 悬停以查看现有属性的描述。 4 | // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: 当前文件", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "cwd": "${workspaceFolder}", 14 | "env": { 15 | "PYTHONPATH": "${workspaceFolder}" 
16 | }, 17 | } 18 | ] 19 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | ## Audio Classification 2 | 3 | Classify audio clips. In this repo, I train a model on the [UrbanSound8K][data-urban] dataset, 4 | and achieve about `80%` accuracy on the test set. 5 | 6 | A pre-trained model is available in [urban_sound_train][i-ckpt]; it was trained for `1000` epochs. 7 | 8 | ### Usage 9 | 10 | - [`audio_train.py`][i-train]: Train the audio model from scratch or restore it from a checkpoint. 11 | - [`audio_params.py`][i-params]: Configuration for training a model. 12 | - [`audio_inference_demo.py`][i-demo]: Demo for testing the trained model. 13 | - [`./audio/*`][i-audio]: Training dependencies: model, records, and dataset utilities. 14 | - [`./vggish/*`][i-vggish]: [VGGish][tool-vggish] dependencies for feature extraction. 15 | 16 | 17 | ### Env setup 18 | 19 | Conda is recommended; setup takes one line: `conda env create -f conda.env.yml` 20 | 21 | ### Train & Test 22 | 23 | 1. Configure parameters in `audio_params.py`. 24 | 2. Train the model: `python audio_train.py`. (It will **create tfrecords automatically** if they do not exist.) 25 | 3. Monitor training with TensorBoard: `tensorboard --logdir=./data/tensorboard` 26 | 4. Test the model: `python audio_inference_demo.py` (a condensed inference sketch is also included after `conda.env.yml` below). 27 | 28 | ### Tools 29 | 30 | - [TensorFlow: VGGish][tool-vggish] 31 | - [Google AudioSet][tool-as] 32 | - [VGGish model checkpoint][tool-as-ckpt] 33 | - [Embedding PCA parameters][tool-as-pca] 34 | - [pyAudioAnalysis][tool-pyaa] (Ref.) 35 | 36 | ### Dataset 37 | 38 | - [urban sound dataset][data-urban] 39 | 40 | ### Ref.
Blogs 41 | 42 | - [AudioSet: An ontology and human-labelled dataset for audio events][blog-as] 43 | - [CNN Architectures for Large-Scale Audio Classification][blog-accnn] 44 | 45 | [i-train]: ./audio_train.py 46 | [i-params]: ./audio_params.py 47 | [i-demo]: ./audio_inference_demo.py 48 | [i-audio]: ./audio 49 | [i-vggish]: ./vggish 50 | [i-ckpt]: ./data/train/urban_sound_train 51 | [tool-vggish]: https://github.com/tensorflow/models/tree/master/research/audioset 52 | [tool-pyaa]: https://github.com/tyiannak/pyAudioAnalysis 53 | [tool-as]: https://research.google.com/audioset/index.html 54 | [tool-as-ckpt]: https://storage.googleapis.com/audioset/vggish_model.ckpt 55 | [tool-as-pca]: https://storage.googleapis.com/audioset/vggish_pca_params.npz 56 | [data-urban]: https://serv.cusp.nyu.edu/projects/urbansounddataset/urbansound8k.html 57 | [blog-as]: https://research.google.com/pubs/pub45857.html 58 | [blog-accnn]: https://research.google.com/pubs/pub45611.html -------------------------------------------------------------------------------- /audio/README.MD: -------------------------------------------------------------------------------- 1 | ## Audio Classification dependencies 2 | 3 | Classify the audios 4 | 5 | ### Usage 6 | 7 | - [`audio_feature_extractor.py`][i-fe]: Extracting VGGish or mel features from wav files. 8 | - [`audio_model.py`][i-mo]: Define classification model. 9 | - [`audio_records.py`][i-rcd]: Parse records. 10 | - [`audio_urban_preprocess.py`][i-prep]: Preprocess for [Urban Sound Dataset][data-urban]. 11 | - [`audio_util.py`][i-u]: Util functions. 12 | 13 | 14 | [i-fe]: ./audio_feature_extractor.py 15 | [i-mo]: ./audio_model.py 16 | [i-rcd]: ./audio_records.py 17 | [i-prep]: ./audio_urban_preprocess.py 18 | [i-u]: ./audio_util.py 19 | [data-urban]: https://serv.cusp.nyu.edu/projects/urbansounddataset/ 20 | -------------------------------------------------------------------------------- /audio/audio_feature_extractor.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | sys.path.append('./vggish') 5 | sys.path.append('./audio') 6 | 7 | import os 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | 12 | from abc import ABC 13 | from abc import abstractmethod 14 | 15 | import vggish_slim 16 | import vggish_input 17 | import vggish_postprocess 18 | 19 | from audio_records import encodes_example 20 | from audio_util import maybe_create_directory 21 | from audio_params import NUM_VGGISH_FEATURE_PER_EXAMPLE 22 | 23 | 24 | class ExtractorBase(ABC): 25 | """Base class for Extractors""" 26 | def __init__(self): 27 | super(ExtractorBase, self).__init__() 28 | 29 | @abstractmethod 30 | def __enter__(self): 31 | return self 32 | 33 | @abstractmethod 34 | def __exit__(self, type, value, traceback): 35 | pass 36 | 37 | @abstractmethod 38 | def wavfile_to_features(self, wav_file): 39 | """Extract features from wav file.""" 40 | pass 41 | 42 | def create_records(self, record_path, wav_files, wav_labels): 43 | """Create TF Records from wav files and corresponding labels.""" 44 | record_dir = os.path.dirname(record_path) 45 | maybe_create_directory(record_dir) 46 | writer = tf.python_io.TFRecordWriter(record_path) 47 | N = len(wav_labels) 48 | for n, (wav_file, wav_label) in enumerate(zip(wav_files, wav_labels)): 49 | tf.logging.info('[{}/{}] Extracting VGGish feature:' 50 | ' label: {} - {}'.format(n, N, wav_label, wav_file)) 51 | 52 | features = self.wavfile_to_features(wav_file) 53 | 
54 | if NUM_VGGISH_FEATURE_PER_EXAMPLE > 1: 55 | if NUM_VGGISH_FEATURE_PER_EXAMPLE != features.shape[0]: 56 | tf.logging.warning('Invalid vggish features length:' 57 | ' label: {} - {}'.format(wav_label, wav_file)) 58 | continue 59 | f = features.reshape(-1) 60 | example = encodes_example(np.float64(f), np.int64(wav_label)) 61 | writer.write(example.SerializeToString()) 62 | else: 63 | num_features = features.shape[0] # one feature for one second 64 | if num_features == 0: 65 | tf.logging.warning('No vggish features:' 66 | ' label: {} - {}'.format(wav_label, wav_file)) 67 | continue 68 | cur_wav_labels = [wav_label] * num_features 69 | for (f, l) in zip(features, cur_wav_labels): 70 | example = encodes_example(np.float64(f), np.int64(l)) 71 | writer.write(example.SerializeToString()) 72 | writer.close() 73 | 74 | 75 | class MelExtractor(ExtractorBase): 76 | """Feature extractor that extracts mel features from a wav file.""" 77 | def __init__(self): 78 | super(MelExtractor, self).__init__() 79 | 80 | def __enter__(self): 81 | return self 82 | 83 | def __exit__(self, type, value, traceback): 84 | pass 85 | 86 | @staticmethod 87 | def wavfile_to_features(wav_file): 88 | assert os.path.exists(wav_file), '{} not exists!'.format(wav_file) 89 | mel_features = vggish_input.wavfile_to_examples(wav_file) 90 | return mel_features 91 | 92 | 93 | class VGGishExtractor(ExtractorBase): 94 | """Feature extractor that uses the VGGish model to extract features from a wav file.""" 95 | def __init__(self, checkpoint, pca_params, input_tensor_name, output_tensor_name): 96 | """Create a new Graph and a new Session for every VGGishExtractor object.""" 97 | super(VGGishExtractor, self).__init__() 98 | 99 | self.graph = tf.Graph() 100 | with self.graph.as_default(): 101 | vggish_slim.define_vggish_slim(training=False) 102 | 103 | sess_config = tf.ConfigProto(allow_soft_placement=True) 104 | sess_config.gpu_options.allow_growth = True 105 | self.sess = tf.Session(graph=self.graph, config=sess_config) 106 | vggish_slim.load_defined_vggish_slim_checkpoint(self.sess, checkpoint) 107 | 108 | # use the self.sess to init others 109 | self.input_tensor = self.graph.get_tensor_by_name(input_tensor_name) 110 | self.output_tensor = self.graph.get_tensor_by_name(output_tensor_name) 111 | 112 | # postprocessor 113 | self.postprocess = vggish_postprocess.Postprocessor(pca_params) 114 | 115 | def __enter__(self): 116 | return self 117 | 118 | def __exit__(self, type, value, traceback): 119 | self.close() 120 | 121 | def mel_to_vggish(self, mel_features): 122 | """Convert mel features to VGGish features.""" 123 | assert mel_features is not None, 'mel_features is None' 124 | # empty mel_features, skip 125 | if mel_features.shape[0]==0: 126 | return mel_features 127 | # Run inference and postprocessing.
128 | [embedding_batch] = self.sess.run([self.output_tensor], 129 | feed_dict={self.input_tensor: mel_features}) 130 | vggish_features = self.postprocess.postprocess(embedding_batch) 131 | return vggish_features 132 | 133 | def wavfile_to_features(self, wav_file): 134 | """Extract VGGish feature from wav file.""" 135 | assert os.path.exists(wav_file), '{} not exists!'.format(wav_file) 136 | mel_features = MelExtractor.wavfile_to_features(wav_file) 137 | return self.mel_to_vggish(mel_features) 138 | 139 | def close(self): 140 | self.sess.close() 141 | 142 | def main_test(): 143 | import audio_params 144 | from vggish import vggish_params 145 | import timeit 146 | from audio_util import urban_labels 147 | 148 | tf.get_logger().setLevel('INFO') 149 | 150 | wav_file = 'F:/3rd-datasets/UrbanSound8K-16bit/audio-classified/siren/90014-8-0-1.wav' 151 | wav_dir = 'F:/3rd-datasets/UrbanSound8K-16bit/audio-classified/siren' 152 | wav_filenames = os.listdir(wav_dir) 153 | wav_files = [os.path.join(wav_dir, wav_filename) for wav_filename in wav_filenames] 154 | wav_labels = urban_labels(wav_files) 155 | 156 | # test VGGishExtractor 157 | time_start = timeit.default_timer() 158 | with VGGishExtractor(audio_params.VGGISH_CHECKPOINT, 159 | audio_params.VGGISH_PCA_PARAMS, 160 | vggish_params.INPUT_TENSOR_NAME, 161 | vggish_params.OUTPUT_TENSOR_NAME) as ve: 162 | 163 | vggish_features = ve.wavfile_to_features(wav_file) 164 | print(vggish_features, vggish_features.shape) 165 | 166 | ve.create_records('./vggish_test.records', wav_files[:10], wav_labels[:10]) 167 | 168 | time_end = timeit.default_timer() 169 | # print('cost time: {}s, {}s/wav'.format((time_end-time_start), (time_end-time_start)/(i+1))) 170 | 171 | # test MelExtractor 172 | with MelExtractor() as me: 173 | mel_features = me.wavfile_to_features(wav_file) 174 | print(mel_features, mel_features.shape) 175 | me.create_records('./mel_test.records', wav_files[:10], wav_labels[:10]) 176 | 177 | 178 | def main_create_urban_tfr(): 179 | import timeit 180 | import natsort 181 | import audio_params 182 | import vggish_params 183 | from audio_util import train_test_val_split 184 | 185 | tf.get_logger().setLevel('INFO') 186 | 187 | def _listdir(d): 188 | return [os.path.join(d, f) for f in natsort.natsorted(os.listdir(d))] 189 | 190 | wav_dir = r"path/to/UrbanSound8K-16bit/audio-classified" 191 | tfr_dir = r"./data/tfrecords" 192 | 193 | wav_files = list() 194 | wav_labels = list() 195 | class_dict = dict() 196 | for idx, folder in enumerate(_listdir(wav_dir)): 197 | wavs = list(filter(lambda x: x.endswith('.wav'), _listdir(folder))) 198 | wav_files.extend(wavs) 199 | wav_labels.extend([idx] * len(wavs)) 200 | class_dict[idx] = os.path.basename(folder) 201 | print(f'class-id pair: {class_dict}') 202 | 203 | wav_file = wav_files[0] 204 | 205 | (X_train, Y_train), (X_test, Y_test), (X_val, Y_val) = train_test_val_split(wav_files, wav_labels, split=(.2, .1), shuffle=True) 206 | 207 | time_start = timeit.default_timer() 208 | with VGGishExtractor(audio_params.VGGISH_CHECKPOINT, 209 | audio_params.VGGISH_PCA_PARAMS, 210 | vggish_params.INPUT_TENSOR_NAME, 211 | vggish_params.OUTPUT_TENSOR_NAME) as ve: 212 | 213 | vggish_features = ve.wavfile_to_features(wav_file) 214 | print(vggish_features, vggish_features.shape) 215 | 216 | ve.create_records(os.path.join(tfr_dir, 'vggish.train.records'), X_train, Y_train) 217 | ve.create_records(os.path.join(tfr_dir, 'vggish.test.records'), X_test, Y_test) 218 | ve.create_records(os.path.join(tfr_dir, 'vggish.val.records'), 
X_val, Y_val) 219 | 220 | time_end = timeit.default_timer() 221 | print('cost time: {}s, {}s/wav'.format((time_end-time_start), (time_end-time_start)/len(wav_files))) 222 | 223 | if __name__ == '__main__': 224 | main_test() 225 | # main_create_urban_tfr() 226 | pass -------------------------------------------------------------------------------- /audio/audio_model.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # author: luuil@outlook.com 3 | 4 | r"""Defines the 'audio' model used to classify the VGGish features.""" 5 | 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | import audio_params as params 10 | 11 | slim = tf.contrib.slim 12 | 13 | def define_audio_slim(training=False): 14 | """Defines the audio TensorFlow model. 15 | 16 | All ops are created in the current default graph, under the scope 'audio/'. 17 | 18 | The input is a placeholder named 'audio/vggish_input' of type float32 and 19 | shape [batch_size, feature_size] where batch_size is variable and 20 | feature_size is constant, and feature_size represents a VGGish output feature. 21 | The output is an op named 'audio/prediction' which produces the activations of 22 | a NUM_CLASSES layer. 23 | 24 | Args: 25 | training: If true, all parameters are marked trainable. 26 | 27 | Returns: 28 | The op 'audio/logits'. 29 | """ 30 | with slim.arg_scope([slim.fully_connected], 31 | weights_initializer=tf.truncated_normal_initializer( 32 | stddev=params.INIT_STDDEV), 33 | biases_initializer=tf.zeros_initializer(), 34 | trainable=training),\ 35 | tf.variable_scope('audio'): 36 | vggish_input = tf.placeholder(tf.float32, 37 | shape=[None, params.NUM_FEATURES], 38 | name='vggish_input') 39 | # Add a fully connected layer with NUM_UNITS units 40 | fc = slim.fully_connected(vggish_input, params.NUM_UNITS) 41 | logits = slim.fully_connected(fc, params.NUM_CLASSES, 42 | activation_fn=None, scope='logits') 43 | tf.nn.softmax(logits, name='prediction') 44 | return logits 45 | 46 | 47 | def load_audio_slim_checkpoint(session, checkpoint_path): 48 | """Loads a pre-trained audio-compatible checkpoint. 49 | 50 | This function can be used as an initialization function (referred to as 51 | init_fn in TensorFlow documentation) which is called in a Session after 52 | initializating all variables. When used as an init_fn, this will load 53 | a pre-trained checkpoint that is compatible with the audio model 54 | definition. Only variables defined by audio will be loaded. 55 | 56 | Args: 57 | session: an active TensorFlow session. 58 | checkpoint_path: path to a file containing a checkpoint that is 59 | compatible with the audio model definition. 60 | """ 61 | 62 | # Get the list of names of all audio variables that exist in 63 | # the checkpoint (i.e., all inference-mode audio variables). 64 | with tf.Graph().as_default(): 65 | define_audio_slim(training=False) 66 | audio_var_names = [v.name for v in tf.global_variables()] 67 | 68 | # Get list of variables from exist graph which passed by session 69 | with session.graph.as_default(): 70 | global_variables = tf.global_variables() 71 | 72 | # Get the list of all currently existing variables that match 73 | # the list of variable names we just computed. 74 | audio_vars = [v for v in global_variables if v.name in audio_var_names] 75 | 76 | # Use a Saver to restore just the variables selected above. 
77 | saver = tf.train.Saver(audio_vars, name='audio_load_pretrained', 78 | write_version=1) 79 | saver.restore(session, checkpoint_path) -------------------------------------------------------------------------------- /audio/audio_records.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # author: luuil@outlook.com 3 | 4 | r"""Records related functions.""" 5 | 6 | from __future__ import print_function 7 | 8 | import sys 9 | sys.path.append('..') 10 | 11 | import tensorflow as tf 12 | 13 | from audio_params import AUDIO_FEATURE_NAME 14 | from audio_params import AUDIO_LABEL_NAME 15 | 16 | 17 | def encodes_example(feature, label): 18 | """Encodes to TF Example 19 | 20 | Args: 21 | feature: feature to encode 22 | label: label corresponding to feature 23 | 24 | Returns: 25 | tf.Example object 26 | """ 27 | def _bytes_feature(value): 28 | """Creates a TensorFlow Record Feature with value as a byte array. 29 | """ 30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 31 | 32 | def _int64_feature(value): 33 | """Creates a TensorFlow Record Feature with value as a 64 bit integer. 34 | """ 35 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 36 | 37 | features = {AUDIO_FEATURE_NAME: _bytes_feature(feature.tobytes()), 38 | AUDIO_LABEL_NAME: _int64_feature(label)} 39 | return tf.train.Example(features=tf.train.Features(feature=features)) 40 | 41 | 42 | def parse_example(example, shape=None): 43 | """Parse TF Example""" 44 | keys_to_feature = { AUDIO_FEATURE_NAME: tf.FixedLenFeature([], tf.string), 45 | AUDIO_LABEL_NAME: tf.FixedLenFeature([], tf.int64)} 46 | raw_parsed_example = tf.parse_single_example(example, features=keys_to_feature) 47 | feature = tf.decode_raw(raw_parsed_example[AUDIO_FEATURE_NAME], tf.float64) 48 | label = tf.cast(raw_parsed_example[AUDIO_LABEL_NAME], tf.int32) 49 | feature = tf.cast(feature, tf.float32) 50 | if shape is not None: 51 | feature = tf.reshape(feature, shape) 52 | return feature, label 53 | 54 | 55 | class RecordsParser(object): 56 | """Parse TF Records and return Iterator.""" 57 | def __init__(self, records_files, num_classes, feature_shape): 58 | super(RecordsParser, self).__init__() 59 | self.dataset = tf.data.TFRecordDataset(filenames=records_files) 60 | self.shape = feature_shape 61 | self.num_classes = num_classes 62 | 63 | def iterator(self, is_onehot=True, is_shuffle=False, batch_size=64, buffer_size=512): 64 | parse_func = lambda example: parse_example(example, shape=self.shape) 65 | dataset = self.dataset.map(parse_func) # Parse the record into tensors. 66 | # Only go through the data once with no repeat. 67 | num_repeats = 1 68 | if is_shuffle: 69 | # If training then read a buffer of the given size and randomly shuffle it. 70 | dataset = dataset.shuffle(buffer_size=buffer_size) 71 | dataset = dataset.repeat(num_repeats) # Repeat the input indefinitely. 
72 | if is_onehot: 73 | onehot_func = lambda feature, label: (feature, 74 | tf.one_hot(label, self.num_classes)) 75 | dataset = dataset.map(onehot_func) 76 | 77 | dataset = dataset.batch(batch_size) 78 | iterator = dataset.make_initializable_iterator() 79 | batch = iterator.get_next() 80 | return iterator, batch 81 | 82 | 83 | if __name__ == '__main__': 84 | from audio_params import TF_RECORDS_VAL 85 | from audio_params import NUM_CLASSES 86 | from os.path import join as pjoin 87 | 88 | rp = RecordsParser([pjoin('..', TF_RECORDS_VAL)], NUM_CLASSES, feature_shape=None) 89 | iterator, data_batch = rp.iterator(is_onehot=True, batch_size=64) 90 | 91 | with tf.Session() as sess: 92 | sess.run(iterator.initializer) 93 | predicted = [] 94 | groundtruth = [] 95 | while True: 96 | try: 97 | features, labels = sess.run(data_batch) 98 | except tf.errors.OutOfRangeError: 99 | break 100 | print(features, features.shape) 101 | print(labels, labels.shape) -------------------------------------------------------------------------------- /audio/audio_urban_preprocess.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # author: luuil@outlook.com 3 | 4 | """Converting wav file's bits. 5 | 6 | Such as, convert `PCM_24` to `PCM_16` 7 | """ 8 | 9 | import os 10 | import soundfile # for convert wav file 11 | # import urban_sound_params 12 | from shutil import copyfile 13 | from audio_util import maybe_create_directory 14 | from audio_util import urban_labels 15 | 16 | 17 | def maybe_copy_file(src, dst): 18 | if not os.path.exists(dst): 19 | print('{} => {}'.format(src, dst)) 20 | copyfile(src, dst) 21 | 22 | 23 | 24 | def convert_wav(src_wav, dst_wav, subtype='PCM_16'): 25 | """Converting wav file's bits. 26 | 27 | Such as, convert `PCM_24` to `PCM_16` 28 | """ 29 | assert os.path.exists(src_wav), "{} not exists!".format(src_wav) 30 | data, sr = soundfile.read(src_wav) 31 | soundfile.write(dst_wav, data, sr, subtype=subtype) 32 | 33 | def convert_urban_pcm24_to_pcm16(): 34 | """Convert urban sound codec from PCM_24 to PCM_16.""" 35 | src_dir = ['/data1/data/UrbanSound8K/audio/fold{:d}'.format(i+1) for i in range(10)] 36 | dst_dir = ['/data1/data/UrbanSound8K-16bit/audio/fold{:d}'.format(i+1) for i in range(10)] 37 | converted_wav_paths = [] 38 | for dsrc, ddst in zip(src_dir, dst_dir): 39 | maybe_create_directory(ddst) 40 | wav_files = filter(lambda FP: FP if FP.endswith('.wav') else None, 41 | [FP for FP in os.listdir(dsrc)]) 42 | for wav_file in wav_files: 43 | src_wav, dst_wav = os.path.join(dsrc, wav_file), os.path.join(ddst, wav_file) 44 | convert_wav(src_wav, dst_wav, subtype='PCM_16') 45 | converted_wav_paths.append(dst_wav) 46 | print('converted count:', len(converted_wav_paths)) 47 | print(converted_wav_paths, len(converted_wav_paths)) 48 | 49 | 50 | def arange_urban_sound_file_by_class(): 51 | """Arange urban sound file by it's class.""" 52 | def _listdir(d): 53 | return [os.path.join(d, f) for f in sorted(os.listdir(d))] 54 | 55 | src_path = '/data1/data/UrbanSound8K-16bit/audio' 56 | dst_dir = '/data1/data/UrbanSound8K-16bit/audio-classfied' 57 | 58 | src_paths = list() 59 | for d in _listdir(src_path): 60 | wavs = filter(lambda x: x.endswith('.wav'), _listdir(d)) 61 | src_paths.extend(list(wavs)) 62 | 63 | CLASSES = [ 64 | 'air conditioner', 65 | 'car horn', 66 | 'children playing', 67 | 'dog bark', 68 | 'drilling', 69 | 'engine idling', 70 | 'gun shot', 71 | 'jackhammer', 72 | 'siren', 73 | 'street music'] 74 | CLASSES_STRIPED = [c.replace(' 
', '_') for c in CLASSES] 75 | for src in src_paths: 76 | lbl = urban_labels([src])[0] 77 | dst = '{dir}/{label}'.format(dir=dst_dir, label=CLASSES_STRIPED[lbl]) 78 | maybe_create_directory(dst) 79 | maybe_copy_file(src, '{dst}/{name}'.format(dst=dst, name=os.path.split(src)[-1])) 80 | 81 | 82 | if __name__ == '__main__': 83 | convert_urban_pcm24_to_pcm16() 84 | arange_urban_sound_file_by_class() 85 | pass 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /audio/audio_util.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # author: luuil@outlook.com 3 | 4 | r"""Util functions.""" 5 | 6 | from __future__ import print_function 7 | 8 | import os 9 | import numpy as np 10 | 11 | 12 | def is_exists(path): 13 | if not os.path.exists(path): 14 | print('Not exists: {}'.format(path)) 15 | return False 16 | return True 17 | 18 | 19 | def maybe_create_directory(dirname): 20 | """Check directory exists or create it.""" 21 | if not os.path.exists(dirname): 22 | os.makedirs(dirname) 23 | 24 | 25 | def maybe_download(url, dst_dir): 26 | """Download file. 27 | 28 | If the file not exist then download it. 29 | 30 | Args: 31 | url: Web location of the file. 32 | 33 | Returns: 34 | path to downloaded file. 35 | """ 36 | import urllib.request 37 | maybe_create_directory(dst_dir) 38 | filename = url.split('/')[-1] 39 | filepath = os.path.join(dst_dir, filename) 40 | if not os.path.exists(filepath): 41 | def _progress(count, block_size, total_size): 42 | sys.stdout.write('\r>> Downloading %s %.1f%%' % 43 | (filename, 44 | float(count * block_size) / float(total_size) * 100.0)) 45 | sys.stdout.flush() 46 | 47 | filepath, _ = urllib.request.urlretrieve(url, filepath, _progress) 48 | print() 49 | statinfo = os.stat(filepath) 50 | print('Successfully downloaded:', filename, statinfo.st_size, 'bytes.') 51 | return filepath 52 | 53 | 54 | def maybe_download_and_extract(url, dst_dir): 55 | """Download and extract model tar file. 56 | 57 | If the pretrained model we're using doesn't already exist, this function 58 | downloads it from the TensorFlow.org website and unpacks it into a directory. 59 | 60 | Args: 61 | url: Web location of the tar file containing the pretrained model. 62 | dst_dir: Destination directory to save downloaded and extracted file. 63 | 64 | Returns: 65 | None. 66 | """ 67 | import tarfile 68 | filepath =maybe_download(url, dst_dir) 69 | tarfile.open(filepath, 'r:gz').extractall(dst_dir) 70 | 71 | 72 | def urban_labels(fpaths): 73 | """urban sound dataset labels.""" 74 | urban_label = lambda path: int(os.path.split(path)[-1].split('-')[1]) 75 | return [urban_label(p) for p in fpaths] 76 | 77 | 78 | def train_test_val_split(X, Y, split=(0.2, 0.1), shuffle=True): 79 | """Split dataset into train/val/test subsets by 70:20:10(default). 80 | 81 | Args: 82 | X: List of data. 83 | Y: List of labels corresponding to data. 84 | split: Tuple of split ratio in `test:val` order. 85 | shuffle: Bool of shuffle or not. 86 | 87 | Returns: 88 | Three dataset in `train:test:val` order. 89 | """ 90 | from sklearn.model_selection import train_test_split 91 | assert len(X) == len(Y), 'The length of X and Y must be consistent.' 
92 | X_train, X_test_val, Y_train, Y_test_val = train_test_split(X, Y, 93 | test_size=(split[0]+split[1]), shuffle=shuffle) 94 | X_test, X_val, Y_test, Y_val = train_test_split(X_test_val, Y_test_val, 95 | test_size=split[1]/(split[0]+split[1]), shuffle=False) 96 | return (X_train, Y_train), (X_test, Y_test), (X_val, Y_val) 97 | 98 | 99 | def calculate_flops(graph): 100 | """Calculate floating point operations with specified `graph`. 101 | 102 | Print to stdout an analysis of the number of floating point operations in the 103 | model broken down by individual operations. 104 | """ 105 | tf.profiler.profile(graph=graph, 106 | options=tf.profiler.ProfileOptionBuilder.float_operation(), cmd='scope') 107 | 108 | 109 | if __name__ == '__main__': 110 | 111 | X, y = np.arange(20).reshape((10, 2)), np.arange(10) 112 | print(X) 113 | print(y) 114 | tr, te, vl = train_test_val_split(X, y, shuffle=True) 115 | print(tr) 116 | print(te) 117 | print(vl) 118 | 119 | import sys 120 | sys.path.append('..') 121 | sys.path.append('../vggish') 122 | from audio_model import define_audio_slim 123 | from vggish_slim import define_vggish_slim 124 | import tensorflow as tf 125 | with tf.Graph().as_default() as graph: 126 | # define_vggish_slim(training=False) 127 | define_audio_slim(training=False) 128 | calculate_flops(graph) 129 | pass -------------------------------------------------------------------------------- /audio_inference_demo.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # author: luuil@outlook.com 3 | 4 | r"""Test on audio model.""" 5 | 6 | from __future__ import print_function 7 | import sys 8 | sys.path.append('./audio') 9 | 10 | import tensorflow as tf 11 | import numpy as np 12 | from os.path import join as pjoin 13 | from sklearn.metrics import accuracy_score 14 | 15 | import audio_model 16 | import audio_params 17 | import audio_util as util 18 | from audio_feature_extractor import VGGishExtractor 19 | from audio_records import RecordsParser 20 | 21 | 22 | NUM_VGGISH_FEATURE_PER_EXAMPLE = audio_params.NUM_VGGISH_FEATURE_PER_EXAMPLE 23 | CKPT_DIR = audio_params.AUDIO_CHECKPOINT_DIR 24 | CKPT_NAME = audio_params.AUDIO_CHECKPOINT_NAME 25 | META = pjoin(CKPT_DIR, audio_params.AUDIO_TRAIN_NAME, '{ckpt}.meta'.format(ckpt=CKPT_NAME)) 26 | CKPT = pjoin(CKPT_DIR, audio_params.AUDIO_TRAIN_NAME, CKPT_NAME) 27 | 28 | VGGISH_CKPT = audio_params.VGGISH_CHECKPOINT 29 | VGGISH_PCA = audio_params.VGGISH_PCA_PARAMS 30 | 31 | SESS_CONFIG = tf.ConfigProto(allow_soft_placement=True) 32 | SESS_CONFIG.gpu_options.allow_growth = True 33 | 34 | def _restore_from_meta_and_ckpt(sess, meta, ckpt): 35 | """Restore graph from meta file and variables from ckpt file.""" 36 | saver = tf.train.import_meta_graph(meta) 37 | saver.restore(sess, ckpt) 38 | 39 | 40 | def _restore_from_defined_and_ckpt(sess, ckpt): 41 | """Restore graph from define and variables from ckpt file.""" 42 | with sess.graph.as_default(): 43 | audio_model.define_audio_slim(training=False) 44 | audio_model.load_audio_slim_checkpoint(sess, ckpt) 45 | 46 | def inference_wav(wav_file: str, label: int): 47 | """Test audio model on a wav file.""" 48 | graph = tf.Graph() 49 | with tf.Session(graph=graph, config=SESS_CONFIG) as sess: 50 | with VGGishExtractor(VGGISH_CKPT, 51 | VGGISH_PCA, 52 | audio_params.VGGISH_INPUT_TENSOR_NAME, 53 | audio_params.VGGISH_OUTPUT_TENSOR_NAME) as ve: 54 | vggish_features = ve.wavfile_to_features(wav_file) 55 | assert vggish_features is not None 56 | 57 | if 
NUM_VGGISH_FEATURE_PER_EXAMPLE > 1: 58 | vggish_features = vggish_features.reshape(1, -1) 59 | 60 | # restore graph 61 | # _restore_from_meta_and_ckpt(sess, META, CKPT) 62 | _restore_from_defined_and_ckpt(sess, CKPT) 63 | 64 | # get input and output tensor 65 | # graph = tf.get_default_graph() 66 | inputs = graph.get_tensor_by_name(audio_params.AUDIO_INPUT_TENSOR_NAME) 67 | outputs = graph.get_tensor_by_name(audio_params.AUDIO_OUTPUT_TENSOR_NAME) 68 | 69 | predictions = sess.run(outputs, feed_dict={inputs: vggish_features}) # [num_features, num_class] 70 | 71 | # voting 72 | predictions = np.mean(predictions, axis=0) 73 | label_pred = np.argmax(predictions) 74 | prob = predictions[label_pred] * 100 75 | 76 | print('\n'*3) 77 | print(f'{dict(zip(range(len(predictions)), predictions))}') 78 | print(f'true label: {label}') 79 | print(f'predict label: {label_pred}({prob:.03f}%)') 80 | print('\n'*3) 81 | 82 | 83 | def inference_on_test(): 84 | """Test audio model on test dataset.""" 85 | graph = tf.Graph() 86 | with tf.Session(graph=graph, config=SESS_CONFIG) as sess: 87 | rp = RecordsParser([audio_params.TF_RECORDS_TEST], 88 | audio_params.NUM_CLASSES, feature_shape=None) 89 | test_iterator, test_batch = rp.iterator(is_onehot=True, batch_size=1) 90 | 91 | 92 | # restore graph: 2 ways to restore, both will working 93 | # _restore_from_meta_and_ckpt(sess, META, CKPT) 94 | _restore_from_defined_and_ckpt(sess, CKPT) 95 | 96 | # get input and output tensor 97 | # graph = tf.get_default_graph() 98 | inputs = graph.get_tensor_by_name(audio_params.AUDIO_INPUT_TENSOR_NAME) 99 | outputs = graph.get_tensor_by_name(audio_params.AUDIO_OUTPUT_TENSOR_NAME) 100 | 101 | sess.run(test_iterator.initializer) 102 | predicted = [] 103 | groundtruth = [] 104 | while True: 105 | try: 106 | # feature: [batch_size, num_features] 107 | # label: [batch_size, num_classes] 108 | te_features, te_labels = sess.run(test_batch) 109 | except tf.errors.OutOfRangeError: 110 | break 111 | predictions = sess.run(outputs, feed_dict={inputs: te_features}) 112 | predicted.extend(np.argmax(predictions, 1)) 113 | groundtruth.extend(np.argmax(te_labels, 1)) 114 | # print(te_features.shape, te_labels, te_labels.shape) 115 | 116 | right = accuracy_score(groundtruth, predicted, normalize=False) # True: return prob 117 | print('all: {}, right: {}, wrong: {}, acc: {}'.format( 118 | len(predicted), right, len(predicted) - right, right/(len(predicted)))) 119 | 120 | 121 | if __name__ == '__main__': 122 | tf.logging.set_verbosity(tf.logging.INFO) 123 | inference_wav('./data/wav/16772-8-0-0.wav', 8) 124 | # inference_on_test() 125 | 126 | -------------------------------------------------------------------------------- /audio_params.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # author: luuil@outlook.com 3 | 4 | r"""Global parameters for the audio model. 5 | 6 | See audio_model.py for more information. 7 | """ 8 | 9 | from os.path import join as pjoin 10 | 11 | # Training 12 | AUDIO_TRAIN_NAME = 'urban_sound_train' # train name 13 | NUM_EPOCHS = 2000 14 | BATCH_SIZE = 128 15 | TENSORBOARD_DIR = './data/tensorboard' # Tensorboard 16 | 17 | # Path to UrbanSound8K 18 | WAV_FILE_PARENT_DIR = '/data1/data/UrbanSound8K-16bit/audio-classfied' 19 | NUM_VGGISH_FEATURE_PER_EXAMPLE = 1 20 | 21 | # Architectural constants. 22 | EMBEDDING_SIZE = 128 * NUM_VGGISH_FEATURE_PER_EXAMPLE # Size of embedding layer. 
23 | NUM_FEATURES = EMBEDDING_SIZE 24 | NUM_CLASSES = 10 25 | 26 | 27 | # Hyperparameters used in training. 28 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. 29 | LEARNING_RATE = 1e-5 # Learning rate for the Adam optimizer. 30 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. 31 | NUM_UNITS = 10 # hidden units 32 | 33 | 34 | # Names of ops, tensors, and features. 35 | AUDIO_INPUT_OP_NAME = 'audio/vggish_input' 36 | AUDIO_INPUT_TENSOR_NAME = AUDIO_INPUT_OP_NAME + ':0' 37 | AUDIO_OUTPUT_OP_NAME = 'audio/prediction' 38 | AUDIO_OUTPUT_TENSOR_NAME = AUDIO_OUTPUT_OP_NAME + ':0' 39 | 40 | 41 | # Checkpoint 42 | AUDIO_CHECKPOINT_DIR = './data/train' 43 | AUDIO_CHECKPOINT_NAME = 'audio_urban_model.ckpt' 44 | AUDIO_CHECKPOINT = pjoin(AUDIO_CHECKPOINT_DIR, AUDIO_TRAIN_NAME, AUDIO_CHECKPOINT_NAME) 45 | 46 | 47 | # Records 48 | AUDIO_FEATURE_NAME = 'feature' 49 | AUDIO_LABEL_NAME = 'label' 50 | 51 | TF_RECORDS_TRAIN_NAME = 'urban_sound_train.tfrecords' 52 | TF_RECORDS_TEST_NAME = 'urban_sound_test.tfrecords' 53 | TF_RECORDS_VAL_NAME = 'urban_sound_val.tfrecords' 54 | 55 | TF_RECORDS_DIR = './data/records' 56 | TF_RECORDS_TRAIN = pjoin(TF_RECORDS_DIR, TF_RECORDS_TRAIN_NAME) 57 | TF_RECORDS_TEST = pjoin(TF_RECORDS_DIR, TF_RECORDS_TEST_NAME) 58 | TF_RECORDS_VAL = pjoin(TF_RECORDS_DIR, TF_RECORDS_VAL_NAME) 59 | 60 | 61 | # Vggish 62 | VGGISH_CHECKPOINT_DIR = './data/vggish' 63 | VGGISH_CHECKPOINT_NAME = 'vggish_model.ckpt' 64 | VGGISH_PCA_PARAMS_NAME = 'vggish_pca_params.npz' 65 | VGGISH_CHECKPOINT = pjoin(VGGISH_CHECKPOINT_DIR, VGGISH_CHECKPOINT_NAME) 66 | VGGISH_PCA_PARAMS = pjoin(VGGISH_CHECKPOINT_DIR, VGGISH_PCA_PARAMS_NAME) 67 | 68 | VGGISH_INPUT_OP_NAME = 'vggish/input_features' 69 | VGGISH_INPUT_TENSOR_NAME = VGGISH_INPUT_OP_NAME + ':0' 70 | VGGISH_OUTPUT_OP_NAME = 'vggish/embedding' 71 | VGGISH_OUTPUT_TENSOR_NAME = VGGISH_OUTPUT_OP_NAME + ':0' -------------------------------------------------------------------------------- /audio_train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # author: luuil@outlook.com 3 | 4 | r"""Train audio model.""" 5 | 6 | from __future__ import print_function 7 | import sys 8 | sys.path.append('./audio') 9 | 10 | import os 11 | import numpy as np 12 | import natsort 13 | import shutil 14 | 15 | import matplotlib 16 | matplotlib.use('Agg') 17 | import matplotlib.pyplot as plt 18 | # get_ipython().run_line_magic('matplotlib', 'inline') 19 | 20 | import tensorflow as tf 21 | from tensorflow.python.platform import gfile 22 | 23 | from sklearn.metrics import accuracy_score 24 | 25 | import audio_params as params 26 | import audio_util as util 27 | from audio_records import RecordsParser 28 | from audio_model import define_audio_slim 29 | from audio_feature_extractor import VGGishExtractor 30 | 31 | tf.logging.set_verbosity(tf.logging.DEBUG) 32 | 33 | flags = tf.app.flags 34 | 35 | flags.DEFINE_string( 36 | 'vggish_ckpt_dir', params.VGGISH_CHECKPOINT_DIR, 37 | 'Path to the VGGish checkpoint file.') 38 | 39 | flags.DEFINE_string( 40 | 'audio_ckpt_dir', params.AUDIO_CHECKPOINT_DIR, 41 | 'Path to the audio checkpoint file.') 42 | 43 | flags.DEFINE_string( 44 | 'train_name', params.AUDIO_TRAIN_NAME, 45 | 'Directory name for audio checkpoint file to save, i.e. 
audio checkpoint' 46 | 'file will save to `audio_ckpt_dir/train_name`.') 47 | 48 | flags.DEFINE_string( 49 | 'wavfile_parent_dir', params.WAV_FILE_PARENT_DIR, 50 | "Path to wav file's parent directory, each subdirectory is a class of files.") 51 | 52 | flags.DEFINE_string( 53 | 'records_dir', params.TF_RECORDS_DIR, 54 | "Path to the TF records file's parent directory.") 55 | 56 | flags.DEFINE_bool( 57 | 'restore_if_possible', True, 58 | "Restore variables from checkpoint if checkpoint is exists.") 59 | 60 | FLAGS = flags.FLAGS 61 | 62 | MAX_NUM_PER_CLASS = 2 ** 27 - 1 # ~134M 63 | 64 | 65 | train_records_path = os.path.join(FLAGS.records_dir, 66 | params.TF_RECORDS_TRAIN_NAME) 67 | 68 | test_records_path = os.path.join(FLAGS.records_dir, 69 | params.TF_RECORDS_TEST_NAME) 70 | 71 | val_records_path = os.path.join(FLAGS.records_dir, 72 | params.TF_RECORDS_VAL_NAME) 73 | 74 | vggish_ckpt_path = os.path.join(FLAGS.vggish_ckpt_dir, 75 | params.VGGISH_CHECKPOINT_NAME) 76 | 77 | vggish_pca_path = os.path.join(FLAGS.vggish_ckpt_dir, 78 | params.VGGISH_PCA_PARAMS_NAME) 79 | 80 | tensorboard_dir = os.path.join(params.TENSORBOARD_DIR, 81 | FLAGS.train_name) 82 | 83 | audio_ckpt_dir = os.path.join(FLAGS.audio_ckpt_dir, 84 | FLAGS.train_name) 85 | 86 | util.maybe_create_directory(tensorboard_dir) 87 | util.maybe_create_directory(audio_ckpt_dir) 88 | 89 | # backup params 90 | shutil.copy(os.path.join(os.path.dirname(__file__), 'audio_params.py'), audio_ckpt_dir) 91 | 92 | def _add_triaining_graph(): 93 | with tf.Graph().as_default() as graph: 94 | logits = define_audio_slim(training=True) 95 | tf.summary.histogram('logits', logits) 96 | # define training subgraph 97 | with tf.variable_scope('train'): 98 | labels = tf.placeholder(tf.float32, 99 | shape=[None, params.NUM_CLASSES], name='labels') 100 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2( 101 | logits=logits, labels=labels, name='cross_entropy') 102 | loss = tf.reduce_mean(cross_entropy, name='loss_op') 103 | tf.summary.scalar('loss', loss) 104 | # training 105 | global_step = tf.Variable(0, name='global_step', trainable=False, 106 | collections=[tf.GraphKeys.GLOBAL_VARIABLES, 107 | tf.GraphKeys.GLOBAL_STEP]) 108 | optimizer = tf.train.AdamOptimizer( 109 | learning_rate=params.LEARNING_RATE, 110 | epsilon=params.ADAM_EPSILON) 111 | optimizer.minimize(loss, global_step=global_step, name='train_op') 112 | return graph 113 | 114 | 115 | def _check_vggish_ckpt_exists(): 116 | """check VGGish checkpoint exists or not.""" 117 | util.maybe_create_directory(FLAGS.vggish_ckpt_dir) 118 | if not util.is_exists(vggish_ckpt_path): 119 | url = 'https://storage.googleapis.com/audioset/vggish_model.ckpt' 120 | util.maybe_download(url, params.VGGISH_CHECKPOINT_DIR) 121 | if not util.is_exists(vggish_pca_path): 122 | url = 'https://storage.googleapis.com/audioset/vggish_pca_params.npz' 123 | util.maybe_download(url, params.VGGISH_CHECKPOINT_DIR) 124 | 125 | 126 | def _wav_files_and_labels(): 127 | """Get wav files path and labels as a dict object. 
128 | 129 | Args: 130 | None 131 | Returns: 132 | result = { label:wav_file_list } 133 | """ 134 | if not util.is_exists(FLAGS.wavfile_parent_dir): 135 | tf.logging.error("Can not find wav files at: {}, or you can download one at " 136 | "https://serv.cusp.nyu.edu/projects/urbansounddataset.".format( 137 | FLAGS.wavfile_parent_dir)) 138 | exit(1) 139 | 140 | 141 | sub_dirs = [x[0] for x in gfile.Walk(FLAGS.wavfile_parent_dir)] 142 | sub_dirs = natsort.natsorted(sub_dirs) 143 | sub_dirs = sub_dirs[1:] # The root directory comes first, so skip it. 144 | 145 | wav_files = [] 146 | wav_labels = [] 147 | for label_idx, sub_dir in enumerate(sub_dirs): 148 | extensions = ['wav'] 149 | file_list = [] 150 | dir_name = os.path.basename(sub_dir) 151 | if dir_name == FLAGS.wavfile_parent_dir: 152 | continue 153 | if dir_name[0] == '.': 154 | continue 155 | tf.logging.info("Looking for wavs in '" + dir_name + "'") 156 | for extension in extensions: 157 | file_glob = os.path.join(FLAGS.wavfile_parent_dir, dir_name, '*.' + extension) 158 | file_list.extend(gfile.Glob(file_glob)) 159 | if not file_list: 160 | tf.logging.warning('No files found') 161 | continue 162 | if len(file_list) < 20: 163 | tf.logging.warning('WARNING: Folder has less than 20 wavs,' 164 | 'which may cause issues.') 165 | elif len(file_list) > MAX_NUM_PER_CLASS: 166 | tf.logging.warning( 167 | 'WARNING: Folder {} has more than {} wavs. Some wavs will ' 168 | 'never be selected.'.format(dir_name, MAX_NUM_PER_CLASS)) 169 | # label_name = re.sub(r'[^a-z0-9]+', ' ', dir_name.lower()) 170 | wav_files.extend(file_list) 171 | wav_labels.extend([label_idx]*len(file_list)) 172 | assert len(wav_files) == len(wav_labels), \ 173 | 'Length of wav files and wav labels should be in consistent.' 174 | return wav_files, wav_labels 175 | 176 | def _create_records(): 177 | """Create audio `train`, `test` and `val` records file.""" 178 | tf.logging.info("Create records..") 179 | util.maybe_create_directory(FLAGS.records_dir) 180 | _check_vggish_ckpt_exists() 181 | wav_files, wav_labels = _wav_files_and_labels() 182 | tf.logging.info('Possible labels: {}'.format(set(wav_labels))) 183 | train, test, val = util.train_test_val_split(wav_files, wav_labels) 184 | with VGGishExtractor(vggish_ckpt_path, 185 | vggish_pca_path, 186 | params.VGGISH_INPUT_TENSOR_NAME, 187 | params.VGGISH_OUTPUT_TENSOR_NAME) as ve: 188 | 189 | train_x, train_y = train 190 | ve.create_records(train_records_path, train_x, train_y) 191 | 192 | test_x, test_y = test 193 | ve.create_records(test_records_path, test_x, test_y) 194 | 195 | val_x, val_y = val 196 | ve.create_records(val_records_path, val_x, val_y) 197 | tf.logging.info('Dataset size: Train-{} Test-{} Val-{}'.format( 198 | len(train_y), len(test_y), len(val_y))) 199 | 200 | def _get_records_iterator(records_path, batch_size): 201 | """Get records iterator""" 202 | if not util.is_exists(records_path): 203 | _create_records() 204 | rp = RecordsParser([records_path], params.NUM_CLASSES, feature_shape=None) 205 | return rp.iterator(is_onehot=True, batch_size=batch_size) 206 | 207 | 208 | def _add_scalar_summary(writer, tag, value, step): 209 | scalar_summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 210 | writer.add_summary(scalar_summary, step) 211 | 212 | 213 | def main(_): 214 | 215 | # initialize all log data containers: 216 | train_loss_per_epoch = [] 217 | val_loss_per_epoch = [] 218 | # initialize a list containing the 5 best val losses (is used to tell when to 219 | # save a model checkpoint): 220 
| best_epoch_losses = [1000, 1000, 1000, 1000, 1000] 221 | 222 | sess_config = tf.ConfigProto(allow_soft_placement=True) 223 | sess_config.gpu_options.allow_growth = True 224 | 225 | with tf.Session(graph=_add_triaining_graph(), config=sess_config) as sess: 226 | train_iterator, train_batch = _get_records_iterator(train_records_path, 227 | batch_size=params.BATCH_SIZE) 228 | val_iterator, val_batch = _get_records_iterator(val_records_path, batch_size=128) 229 | test_iterator, test_batch = _get_records_iterator(test_records_path, batch_size=128) 230 | 231 | # op and tensors 232 | features_tensor = sess.graph.get_tensor_by_name(params.AUDIO_INPUT_TENSOR_NAME) 233 | output_tensor = sess.graph.get_tensor_by_name(params.AUDIO_OUTPUT_TENSOR_NAME) 234 | labels_tensor = sess.graph.get_tensor_by_name('train/labels:0') 235 | global_step_tensor = sess.graph.get_tensor_by_name('train/global_step:0') 236 | loss_tensor = sess.graph.get_tensor_by_name('train/loss_op:0') 237 | train_op = sess.graph.get_operation_by_name('train/train_op') 238 | 239 | summary_op = tf.summary.merge_all() 240 | summary_writer = tf.summary.FileWriter(tensorboard_dir, graph=sess.graph) 241 | saver = tf.train.Saver() 242 | 243 | init = tf.global_variables_initializer() 244 | sess.run(init) 245 | 246 | 247 | checkpoint_path = os.path.join(audio_ckpt_dir, params.AUDIO_CHECKPOINT_NAME) 248 | if util.is_exists(checkpoint_path+'.meta') and FLAGS.restore_if_possible: 249 | saver.restore(sess, checkpoint_path) 250 | 251 | # training and validation loop 252 | for epoch in range(params.NUM_EPOCHS): 253 | 254 | # training loop 255 | train_batch_losses = [] 256 | sess.run(train_iterator.initializer) 257 | while True: 258 | try: 259 | # feature: [batch_size, num_features] 260 | # label: [batch_size, num_classes] 261 | tr_features, tr_labels = sess.run(train_batch) 262 | except tf.errors.OutOfRangeError: 263 | break 264 | [num_steps, loss, summaries, _] = sess.run([global_step_tensor, loss_tensor, summary_op, train_op], 265 | feed_dict={features_tensor: tr_features, labels_tensor: tr_labels}) 266 | train_batch_losses.append(loss) 267 | summary_writer.add_summary(summaries, num_steps) 268 | print('Epoch {}/{}, Step {}: train loss {}'.format(epoch, params.NUM_EPOCHS, num_steps, loss)) 269 | 270 | # compute the train epoch loss: 271 | train_epoch_loss = np.mean(train_batch_losses) 272 | # save the train epoch loss: 273 | train_loss_per_epoch.append(train_epoch_loss) 274 | print("train epoch loss: %g" % train_epoch_loss) 275 | 276 | 277 | # validation loop 278 | val_batch_losses = [] 279 | sess.run(val_iterator.initializer) 280 | while True: 281 | try: 282 | val_features, val_labels = sess.run(val_batch) 283 | except tf.errors.OutOfRangeError: 284 | break 285 | [prediction, loss] = sess.run( 286 | [output_tensor, loss_tensor], 287 | feed_dict={features_tensor: val_features, labels_tensor: val_labels}) 288 | val_batch_losses.append(loss) 289 | # print('predict shape:', prediction.shape) 290 | # print("Example val loss: {:.5f}".format(loss)) 291 | val_loss = np.mean(val_batch_losses) 292 | val_loss_per_epoch.append(val_loss) 293 | print("validation loss: %g" % val_loss) 294 | _add_scalar_summary(summary_writer, 'train/val_loss', val_loss, num_steps) # add to summary 295 | 296 | # testing loop 297 | predicted = [] 298 | groundtruth = [] 299 | sess.run(test_iterator.initializer) 300 | while True: 301 | try: 302 | te_features, te_labels = sess.run(test_batch) 303 | except tf.errors.OutOfRangeError: 304 | break 305 | predictions = 
sess.run(output_tensor, feed_dict={features_tensor: te_features, labels_tensor: te_labels}) 306 | predicted.extend(np.argmax(predictions, axis=1)) 307 | groundtruth.extend(np.argmax(te_labels, axis=1)) 308 | test_acc = accuracy_score(groundtruth, predicted, normalize=True) 309 | print(f"test_acc: {test_acc}") 310 | _add_scalar_summary(summary_writer, 'train/test_acc', test_acc, num_steps) # add to summary 311 | 312 | 313 | if val_loss < max(best_epoch_losses): # (if val loss is among the 5 best so far:) 314 | # save the model weights to disk: 315 | checkpoint_path2 = os.path.join(audio_ckpt_dir, 316 | 'l{loss:.2f}_{name}'.format(loss=val_loss, name=params.AUDIO_CHECKPOINT_NAME)) 317 | saver.save(sess, checkpoint_path) 318 | saver.save(sess, checkpoint_path2) 319 | print("checkpoint saved in file: %s" % checkpoint_path) 320 | 321 | # update the top 5 val losses by replacing the current worst of them: 322 | index = best_epoch_losses.index(max(best_epoch_losses)) 323 | best_epoch_losses[index] = val_loss 324 | 325 | # plot the training loss vs epoch and save to disk: 326 | plt.figure(1) 327 | plt.plot(train_loss_per_epoch, "k^-") 328 | # plt.plot(train_loss_per_epoch, "k") 329 | plt.ylabel("loss") 330 | plt.xlabel("epoch") 331 | plt.title("training loss per epoch") 332 | plt.savefig("%s/train_loss_per_epoch.png" % audio_ckpt_dir) 333 | # plt.show() 334 | 335 | # plot the val loss vs epoch and save to disk: 336 | plt.figure(2) 337 | plt.plot(val_loss_per_epoch, "k^-") 338 | # plt.plot(val_loss_per_epoch, "k") 339 | plt.ylabel("loss") 340 | plt.xlabel("epoch") 341 | plt.title("validation loss per epoch") 342 | plt.savefig("%s/val_loss_per_epoch.png" % audio_ckpt_dir) 343 | # plt.show() 344 | 345 | if __name__ == '__main__': 346 | tf.app.run() -------------------------------------------------------------------------------- /conda.env.yml: -------------------------------------------------------------------------------- 1 | name: tf_audio_classify 2 | channels: 3 | - defaults 4 | dependencies: 5 | - ca-certificates=2021.10.26=haa95532_2 6 | - certifi=2021.10.8=py37haa95532_0 7 | - openssl=1.1.1l=h2bbff1b_0 8 | - pip=21.2.4=py37haa95532_0 9 | - python=3.7.11=h6244533_0 10 | - setuptools=58.0.4=py37haa95532_0 11 | - sqlite=3.36.0=h2bbff1b_0 12 | - vc=14.2=h21ff451_1 13 | - vs2015_runtime=14.27.29016=h5e58377_2 14 | - wheel=0.37.0=pyhd3eb1b0_1 15 | - wincertstore=0.2=py37haa95532_2 16 | - pip: 17 | - absl-py==1.0.0 18 | - astor==0.8.1 19 | - cached-property==1.5.2 20 | - cffi==1.15.0 21 | - cycler==0.11.0 22 | - fonttools==4.28.2 23 | - gast==0.5.3 24 | - grpcio==1.42.0 25 | - h5py==3.6.0 26 | - importlib-metadata==4.8.2 27 | - joblib==1.1.0 28 | - keras-applications==1.0.8 29 | - keras-preprocessing==1.1.2 30 | - kiwisolver==1.3.2 31 | - llvmlite==0.37.0 32 | - markdown==3.3.6 33 | - matplotlib==3.5.0 34 | - mock==4.0.3 35 | - natsort==8.0.0 36 | - numba==0.54.1 37 | - numpy==1.20.3 38 | - packaging==21.3 39 | - pillow==8.4.0 40 | - protobuf==3.19.1 41 | - pycparser==2.21 42 | - pyparsing==3.0.6 43 | - python-dateutil==2.8.2 44 | - resampy==0.2.2 45 | - scikit-learn==1.0.1 46 | - scipy==1.7.3 47 | - setuptools-scm==6.3.2 48 | - six==1.16.0 49 | - sklearn==0.0 50 | - soundfile==0.10.3.post1 51 | - tensorboard==1.13.1 52 | - tensorflow-estimator==1.13.0 53 | - tensorflow-gpu==1.13.1 54 | - termcolor==1.1.0 55 | - threadpoolctl==3.0.0 56 | - tomli==1.2.2 57 | - typing-extensions==4.0.0 58 | - werkzeug==2.0.2 59 | - zipp==3.6.0 60 | --------------------------------------------------------------------------------
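The three .tfrecords files below hold the VGGish embeddings and labels written by _create_records() in audio_train.py. Their exact feature keys are defined elsewhere in the audio/ package (VGGishExtractor.create_records and RecordsParser), so the minimal sketch below does not assume them: it simply walks one of the record files with the TensorFlow 1.13 API pinned above and prints whatever keys each serialized example actually contains. If the records were written as SequenceExamples rather than Examples, parse with tf.train.SequenceExample instead.

import tensorflow as tf

# Hypothetical inspection snippet; the path is one of the record files shipped in this repo.
RECORDS_PATH = 'data/records/urban_sound_test.tfrecords'

for i, serialized in enumerate(tf.python_io.tf_record_iterator(RECORDS_PATH)):
    example = tf.train.Example()
    example.ParseFromString(serialized)
    # Print the keys actually present instead of assuming their names.
    print('record', i, 'feature keys:', sorted(example.features.feature.keys()))
    if i >= 2:  # a few records are enough for a sanity check
        break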
/data/records/urban_sound_test.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luuil/Tensorflow-Audio-Classification/c3678e38c90b5ef0bbf0f9243cd5dcedd5fee7c6/data/records/urban_sound_test.tfrecords -------------------------------------------------------------------------------- /data/records/urban_sound_train.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luuil/Tensorflow-Audio-Classification/c3678e38c90b5ef0bbf0f9243cd5dcedd5fee7c6/data/records/urban_sound_train.tfrecords -------------------------------------------------------------------------------- /data/records/urban_sound_val.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luuil/Tensorflow-Audio-Classification/c3678e38c90b5ef0bbf0f9243cd5dcedd5fee7c6/data/records/urban_sound_val.tfrecords -------------------------------------------------------------------------------- /data/train/urban_sound_train/audio_urban_model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luuil/Tensorflow-Audio-Classification/c3678e38c90b5ef0bbf0f9243cd5dcedd5fee7c6/data/train/urban_sound_train/audio_urban_model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /data/train/urban_sound_train/audio_urban_model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luuil/Tensorflow-Audio-Classification/c3678e38c90b5ef0bbf0f9243cd5dcedd5fee7c6/data/train/urban_sound_train/audio_urban_model.ckpt.index -------------------------------------------------------------------------------- /data/train/urban_sound_train/audio_urban_model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luuil/Tensorflow-Audio-Classification/c3678e38c90b5ef0bbf0f9243cd5dcedd5fee7c6/data/train/urban_sound_train/audio_urban_model.ckpt.meta -------------------------------------------------------------------------------- /data/vggish/README.MD: -------------------------------------------------------------------------------- 1 | ## VGGish checkpoint 2 | 3 | The [train script][i-train] will download it for you automatically if 4 | it can't find one.
But you can also download manually and put it 5 | in this directory: 6 | 7 | - [vggish_model.ckpt][vggish-ckpt] 8 | - [vggish_pca_params.npz][vggish-pca] 9 | 10 | [i-train]: ../../audio_train.py 11 | [vggish-ckpt]: https://storage.googleapis.com/audioset/vggish_model.ckpt 12 | [vggish-pca]: https://storage.googleapis.com/audioset/vggish_pca_params.npz -------------------------------------------------------------------------------- /data/vggish/vggish_pca_params.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luuil/Tensorflow-Audio-Classification/c3678e38c90b5ef0bbf0f9243cd5dcedd5fee7c6/data/vggish/vggish_pca_params.npz -------------------------------------------------------------------------------- /data/wav/13230-0-0-1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luuil/Tensorflow-Audio-Classification/c3678e38c90b5ef0bbf0f9243cd5dcedd5fee7c6/data/wav/13230-0-0-1.wav -------------------------------------------------------------------------------- /data/wav/16772-8-0-0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luuil/Tensorflow-Audio-Classification/c3678e38c90b5ef0bbf0f9243cd5dcedd5fee7c6/data/wav/16772-8-0-0.wav -------------------------------------------------------------------------------- /data/wav/7389-1-1-0.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luuil/Tensorflow-Audio-Classification/c3678e38c90b5ef0bbf0f9243cd5dcedd5fee7c6/data/wav/7389-1-1-0.wav -------------------------------------------------------------------------------- /data/wav/README.MD: -------------------------------------------------------------------------------- 1 | ## Test file from UrbanSound8K 2 | 3 | Details see below. 4 | 5 | ----------------- 6 | 7 | UrbanSound8K 8 | ============ 9 | 10 | Created By 11 | ---------- 12 | 13 | Justin Salamon*^, Christopher Jacoby* and Juan Pablo Bello* 14 | * Music and Audio Research Lab (MARL), New York University, USA 15 | ^ Center for Urban Science and Progress (CUSP), New York University, USA 16 | http://serv.cusp.nyu.edu/projects/urbansounddataset 17 | http://marl.smusic.nyu.edu/ 18 | http://cusp.nyu.edu/ 19 | 20 | Version 1.0 21 | 22 | 23 | Description 24 | ----------- 25 | 26 | This dataset contains 8732 labeled sound excerpts (<=4s) of urban sounds from 10 classes: air_conditioner, car_horn, 27 | children_playing, dog_bark, drilling, engine_idling, gun_shot, jackhammer, siren, and street_music. The classes are 28 | drawn from the urban sound taxonomy described in the following article, which also includes a detailed description of 29 | the dataset and how it was compiled: 30 | 31 | J. Salamon, C. Jacoby and J. P. Bello, "A Dataset and Taxonomy for Urban Sound Research", 32 | 22nd ACM International Conference on Multimedia, Orlando USA, Nov. 2014. 33 | 34 | All excerpts are taken from field recordings uploaded to www.freesound.org. The files are pre-sorted into ten folds 35 | (folders named fold1-fold10) to help in the reproduction of and comparison with the automatic classification results 36 | reported in the article above. 37 | 38 | In addition to the sound excerpts, a CSV file containing metadata about each excerpt is also provided. 39 | 40 | 41 | Audio Files Included 42 | -------------------- 43 | 44 | 8732 audio files of urban sounds (see description above) in WAV format. 
The sampling rate, bit depth, and number of 45 | channels are the same as those of the original file uploaded to Freesound (and hence may vary from file to file). 46 | 47 | 48 | Meta-data Files Included 49 | ------------------------ 50 | 51 | UrbanSound8k.csv 52 | 53 | This file contains meta-data information about every audio file in the dataset. This includes: 54 | 55 | * slice_file_name: 56 | The name of the audio file. The name takes the following format: [fsID]-[classID]-[occurrenceID]-[sliceID].wav, where: 57 | [fsID] = the Freesound ID of the recording from which this excerpt (slice) is taken 58 | [classID] = a numeric identifier of the sound class (see description of classID below for further details) 59 | [occurrenceID] = a numeric identifier to distinguish different occurrences of the sound within the original recording 60 | [sliceID] = a numeric identifier to distinguish different slices taken from the same occurrence 61 | 62 | * fsID: 63 | The Freesound ID of the recording from which this excerpt (slice) is taken 64 | 65 | * start 66 | The start time of the slice in the original Freesound recording 67 | 68 | * end: 69 | The end time of slice in the original Freesound recording 70 | 71 | * salience: 72 | A (subjective) salience rating of the sound. 1 = foreground, 2 = background. 73 | 74 | * fold: 75 | The fold number (1-10) to which this file has been allocated. 76 | 77 | * classID: 78 | A numeric identifier of the sound class: 79 | 0 = air_conditioner 80 | 1 = car_horn 81 | 2 = children_playing 82 | 3 = dog_bark 83 | 4 = drilling 84 | 5 = engine_idling 85 | 6 = gun_shot 86 | 7 = jackhammer 87 | 8 = siren 88 | 9 = street_music 89 | 90 | * class: 91 | The class name: air_conditioner, car_horn, children_playing, dog_bark, drilling, engine_idling, gun_shot, jackhammer, 92 | siren, street_music. 93 | 94 | 95 | Please Acknowledge UrbanSound8K in Academic Research 96 | ---------------------------------------------------- 97 | 98 | When UrbanSound8K is used for academic research, we would highly appreciate it if scientific publications of works 99 | partly based on the UrbanSound8K dataset cite the following publication: 100 | 101 | J. Salamon, C. Jacoby and J. P. Bello, "A Dataset and Taxonomy for Urban Sound Research", 102 | 22nd ACM International Conference on Multimedia, Orlando USA, Nov. 2014. 103 | 104 | The creation of this dataset was supported by a seed grant by NYU's Center for Urban Science and Progress (CUSP). 105 | 106 | 107 | Conditions of Use 108 | ----------------- 109 | 110 | Dataset compiled by Justin Salamon, Christopher Jacoby and Juan Pablo Bello. All files are excerpts of recordings 111 | uploaded to www.freesound.org. Please see FREESOUNDCREDITS.txt for an attribution list. 112 | 113 | The UrbanSound8K dataset is offered free of charge for non-commercial use only under the terms of the Creative Commons 114 | Attribution Noncommercial License (by-nc), version 3.0: http://creativecommons.org/licenses/by-nc/3.0/ 115 | 116 | The dataset and its contents are made available on an "as is" basis and without warranties of any kind, including 117 | without limitation satisfactory quality and conformity, merchantability, fitness for a particular purpose, accuracy or 118 | completeness, or absence of errors. Subject to any liability that may not be excluded or limited by law, NYU is not 119 | liable for, and expressly excludes, all liability for loss or damage however and whenever caused to anyone by any use of 120 | the UrbanSound8K dataset or any part of it. 
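(The sketch below is not part of the UrbanSound8K documentation; it only illustrates the slice_file_name convention described above, using the classID table listed earlier and the three excerpts bundled in this repository under data/wav/.)

import os

# classID table copied from the UrbanSound8K description above.
URBAN_SOUND_CLASSES = {
    0: 'air_conditioner', 1: 'car_horn', 2: 'children_playing', 3: 'dog_bark',
    4: 'drilling', 5: 'engine_idling', 6: 'gun_shot', 7: 'jackhammer',
    8: 'siren', 9: 'street_music',
}

def class_of(wav_name):
    """Parse [fsID]-[classID]-[occurrenceID]-[sliceID].wav and return its class name."""
    class_id = int(os.path.splitext(os.path.basename(wav_name))[0].split('-')[1])
    return URBAN_SOUND_CLASSES[class_id]

for name in ['13230-0-0-1.wav', '16772-8-0-0.wav', '7389-1-1-0.wav']:
    print(name, '->', class_of(name))  # air_conditioner, siren, car_horn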
121 | 122 | 123 | Feedback 124 | -------- 125 | 126 | Please help us improve UrbanSound8K by sending your feedback to: justin.salamon@nyu.edu or justin.salamon@gmail.com 127 | In case of a problem report please include as many details as possible. 128 | -------------------------------------------------------------------------------- /vggish/README.MD: -------------------------------------------------------------------------------- 1 | ## VGGish 2 | 3 | Refer to [Vggish][tool-vggish] 4 | 5 | [tool-vggish]: https://github.com/tensorflow/models/tree/master/research/audioset 6 | -------------------------------------------------------------------------------- /vggish/mel_features.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Defines routines to compute mel spectrogram features from audio waveform.""" 17 | 18 | import numpy as np 19 | 20 | 21 | def frame(data, window_length, hop_length): 22 | """Convert array into a sequence of successive possibly overlapping frames. 23 | 24 | An n-dimensional array of shape (num_samples, ...) is converted into an 25 | (n+1)-D array of shape (num_frames, window_length, ...), where each frame 26 | starts hop_length points after the preceding one. 27 | 28 | This is accomplished using stride_tricks, so the original data is not 29 | copied. However, there is no zero-padding, so any incomplete frames at the 30 | end are not included. 31 | 32 | Args: 33 | data: np.array of dimension N >= 1. 34 | window_length: Number of samples in each frame. 35 | hop_length: Advance (in samples) between each window. 36 | 37 | Returns: 38 | (N+1)-D np.array with as many rows as there are complete frames that can be 39 | extracted. 40 | """ 41 | num_samples = data.shape[0] 42 | num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length)) 43 | shape = (num_frames, window_length) + data.shape[1:] 44 | strides = (data.strides[0] * hop_length,) + data.strides 45 | return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides) 46 | 47 | 48 | def periodic_hann(window_length): 49 | """Calculate a "periodic" Hann window. 50 | 51 | The classic Hann window is defined as a raised cosine that starts and 52 | ends on zero, and where every value appears twice, except the middle 53 | point for an odd-length window. Matlab calls this a "symmetric" window 54 | and np.hanning() returns it. However, for Fourier analysis, this 55 | actually represents just over one cycle of a period N-1 cosine, and 56 | thus is not compactly expressed on a length-N Fourier basis. Instead, 57 | it's better to use a raised cosine that ends just before the final 58 | zero value - i.e. a complete cycle of a period-N cosine. Matlab 59 | calls this a "periodic" window. This routine calculates it. 
60 | 61 | Args: 62 | window_length: The number of points in the returned window. 63 | 64 | Returns: 65 | A 1D np.array containing the periodic hann window. 66 | """ 67 | return 0.5 - (0.5 * np.cos(2 * np.pi / window_length * 68 | np.arange(window_length))) 69 | 70 | 71 | def stft_magnitude(signal, fft_length, 72 | hop_length=None, 73 | window_length=None): 74 | """Calculate the short-time Fourier transform magnitude. 75 | 76 | Args: 77 | signal: 1D np.array of the input time-domain signal. 78 | fft_length: Size of the FFT to apply. 79 | hop_length: Advance (in samples) between each frame passed to FFT. 80 | window_length: Length of each block of samples to pass to FFT. 81 | 82 | Returns: 83 | 2D np.array where each row contains the magnitudes of the fft_length/2+1 84 | unique values of the FFT for the corresponding frame of input samples. 85 | """ 86 | frames = frame(signal, window_length, hop_length) 87 | # Apply frame window to each frame. We use a periodic Hann (cosine of period 88 | # window_length) instead of the symmetric Hann of np.hanning (period 89 | # window_length-1). 90 | window = periodic_hann(window_length) 91 | windowed_frames = frames * window 92 | return np.abs(np.fft.rfft(windowed_frames, int(fft_length))) 93 | 94 | 95 | # Mel spectrum constants and functions. 96 | _MEL_BREAK_FREQUENCY_HERTZ = 700.0 97 | _MEL_HIGH_FREQUENCY_Q = 1127.0 98 | 99 | 100 | def hertz_to_mel(frequencies_hertz): 101 | """Convert frequencies to mel scale using HTK formula. 102 | 103 | Args: 104 | frequencies_hertz: Scalar or np.array of frequencies in hertz. 105 | 106 | Returns: 107 | Object of same size as frequencies_hertz containing corresponding values 108 | on the mel scale. 109 | """ 110 | return _MEL_HIGH_FREQUENCY_Q * np.log( 111 | 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)) 112 | 113 | 114 | def spectrogram_to_mel_matrix(num_mel_bins=20, 115 | num_spectrogram_bins=129, 116 | audio_sample_rate=8000, 117 | lower_edge_hertz=125.0, 118 | upper_edge_hertz=3800.0): 119 | """Return a matrix that can post-multiply spectrogram rows to make mel. 120 | 121 | Returns a np.array matrix A that can be used to post-multiply a matrix S of 122 | spectrogram values (STFT magnitudes) arranged as frames x bins to generate a 123 | "mel spectrogram" M of frames x num_mel_bins. M = S A. 124 | 125 | The classic HTK algorithm exploits the complementarity of adjacent mel bands 126 | to multiply each FFT bin by only one mel weight, then add it, with positive 127 | and negative signs, to the two adjacent mel bands to which that bin 128 | contributes. Here, by expressing this operation as a matrix multiply, we go 129 | from num_fft multiplies per frame (plus around 2*num_fft adds) to around 130 | num_fft^2 multiplies and adds. However, because these are all presumably 131 | accomplished in a single call to np.dot(), it's not clear which approach is 132 | faster in Python. The matrix multiplication has the attraction of being more 133 | general and flexible, and much easier to read. 134 | 135 | Args: 136 | num_mel_bins: How many bands in the resulting mel spectrum. This is 137 | the number of columns in the output matrix. 138 | num_spectrogram_bins: How many bins there are in the source spectrogram 139 | data, which is understood to be fft_size/2 + 1, i.e. the spectrogram 140 | only contains the nonredundant FFT bins. 141 | audio_sample_rate: Samples per second of the audio at the input to the 142 | spectrogram. 
We need this to figure out the actual frequencies for 143 | each spectrogram bin, which dictates how they are mapped into mel. 144 | lower_edge_hertz: Lower bound on the frequencies to be included in the mel 145 | spectrum. This corresponds to the lower edge of the lowest triangular 146 | band. 147 | upper_edge_hertz: The desired top edge of the highest frequency band. 148 | 149 | Returns: 150 | An np.array with shape (num_spectrogram_bins, num_mel_bins). 151 | 152 | Raises: 153 | ValueError: if frequency edges are incorrectly ordered or out of range. 154 | """ 155 | nyquist_hertz = audio_sample_rate / 2. 156 | if lower_edge_hertz < 0.0: 157 | raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz) 158 | if lower_edge_hertz >= upper_edge_hertz: 159 | raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % 160 | (lower_edge_hertz, upper_edge_hertz)) 161 | if upper_edge_hertz > nyquist_hertz: 162 | raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % 163 | (upper_edge_hertz, nyquist_hertz)) 164 | spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins) 165 | spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz) 166 | # The i'th mel band (starting from i=1) has center frequency 167 | # band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge 168 | # band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in 169 | # the band_edges_mel arrays. 170 | band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz), 171 | hertz_to_mel(upper_edge_hertz), num_mel_bins + 2) 172 | # Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins 173 | # of spectrogram values. 174 | mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins)) 175 | for i in range(num_mel_bins): 176 | lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3] 177 | # Calculate lower and upper slopes for every spectrogram bin. 178 | # Line segments are linear in the *mel* domain, not hertz. 179 | lower_slope = ((spectrogram_bins_mel - lower_edge_mel) / 180 | (center_mel - lower_edge_mel)) 181 | upper_slope = ((upper_edge_mel - spectrogram_bins_mel) / 182 | (upper_edge_mel - center_mel)) 183 | # .. then intersect them with each other and zero. 184 | mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope, 185 | upper_slope)) 186 | # HTK excludes the spectrogram DC bin; make sure it always gets a zero 187 | # coefficient. 188 | mel_weights_matrix[0, :] = 0.0 189 | return mel_weights_matrix 190 | 191 | 192 | def log_mel_spectrogram(data, 193 | audio_sample_rate=8000, 194 | log_offset=0.0, 195 | window_length_secs=0.025, 196 | hop_length_secs=0.010, 197 | **kwargs): 198 | """Convert waveform to a log magnitude mel-frequency spectrogram. 199 | 200 | Args: 201 | data: 1D np.array of waveform data. 202 | audio_sample_rate: The sampling rate of data. 203 | log_offset: Add this to values when taking log to avoid -Infs. 204 | window_length_secs: Duration of each window to analyze. 205 | hop_length_secs: Advance between successive analysis windows. 206 | **kwargs: Additional arguments to pass to spectrogram_to_mel_matrix. 207 | 208 | Returns: 209 | 2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank 210 | magnitudes for successive frames. 
211 | """ 212 | window_length_samples = int(round(audio_sample_rate * window_length_secs)) 213 | hop_length_samples = int(round(audio_sample_rate * hop_length_secs)) 214 | fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) 215 | spectrogram = stft_magnitude( 216 | data, 217 | fft_length=fft_length, 218 | hop_length=hop_length_samples, 219 | window_length=window_length_samples) 220 | mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix( 221 | num_spectrogram_bins=spectrogram.shape[1], 222 | audio_sample_rate=audio_sample_rate, **kwargs)) 223 | return np.log(mel_spectrogram + log_offset) 224 | -------------------------------------------------------------------------------- /vggish/vggish_inference_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""A simple demonstration of running VGGish in inference mode. 17 | 18 | This is intended as a toy example that demonstrates how the various building 19 | blocks (feature extraction, model definition and loading, postprocessing) work 20 | together in an inference context. 21 | 22 | A WAV file (assumed to contain signed 16-bit PCM samples) is read in, converted 23 | into log mel spectrogram examples, fed into VGGish, the raw embedding output is 24 | whitened and quantized, and the postprocessed embeddings are optionally written 25 | in a SequenceExample to a TFRecord file (using the same format as the embedding 26 | features released in AudioSet). 27 | 28 | Usage: 29 | # Run a WAV file through the model and print the embeddings. The model 30 | # checkpoint is loaded from vggish_model.ckpt and the PCA parameters are 31 | # loaded from vggish_pca_params.npz in the current directory. 32 | $ python vggish_inference_demo.py --wav_file /path/to/a/wav/file 33 | 34 | # Run a WAV file through the model and also write the embeddings to 35 | # a TFRecord file. The model checkpoint and PCA parameters are explicitly 36 | # passed in as well. 37 | $ python vggish_inference_demo.py --wav_file /path/to/a/wav/file \ 38 | --tfrecord_file /path/to/tfrecord/file \ 39 | --checkpoint /path/to/model/checkpoint \ 40 | --pca_params /path/to/pca/params 41 | 42 | # Run a built-in input (a sine wav) through the model and print the 43 | # embeddings. Associated model files are read from the current directory. 
44 | $ python vggish_inference_demo.py 45 | """ 46 | 47 | from __future__ import print_function 48 | 49 | import numpy as np 50 | from scipy.io import wavfile 51 | import six 52 | import tensorflow as tf 53 | 54 | import vggish_input 55 | import vggish_params 56 | import vggish_postprocess 57 | import vggish_slim 58 | 59 | flags = tf.app.flags 60 | 61 | flags.DEFINE_string( 62 | 'wav_file', '/data1/data/UrbanSound8K-16bit/audio/fold1/99180-9-0-7.wav', 63 | 'Path to a wav file. Should contain signed 16-bit PCM samples. ' 64 | 'If none is provided, a synthetic sound is used.') 65 | 66 | flags.DEFINE_string( 67 | 'checkpoint', 'vggish_model.ckpt', 68 | 'Path to the VGGish checkpoint file.') 69 | 70 | flags.DEFINE_string( 71 | 'pca_params', 'vggish_pca_params.npz', 72 | 'Path to the VGGish PCA parameters file.') 73 | 74 | flags.DEFINE_string( 75 | 'tfrecord_file', None, 76 | 'Path to a TFRecord file where embeddings will be written.') 77 | 78 | FLAGS = flags.FLAGS 79 | 80 | 81 | def main(_): 82 | # In this simple example, we run the examples from a single audio file through 83 | # the model. If none is provided, we generate a synthetic input. 84 | if FLAGS.wav_file: 85 | wav_file = FLAGS.wav_file 86 | else: 87 | # Write a WAV of a sine wav into an in-memory file object. 88 | num_secs = 5 89 | freq = 1000 90 | sr = 44100 91 | t = np.linspace(0, num_secs, int(num_secs * sr)) 92 | x = np.sin(2 * np.pi * freq * t) 93 | # Convert to signed 16-bit samples. 94 | samples = np.clip(x * 32768, -32768, 32767).astype(np.int16) 95 | wav_file = six.BytesIO() 96 | wavfile.write(wav_file, sr, samples) 97 | wav_file.seek(0) 98 | examples_batch = vggish_input.wavfile_to_examples(wav_file) 99 | print(examples_batch, examples_batch.shape) 100 | 101 | # Prepare a postprocessor to munge the model embeddings. 102 | pproc = vggish_postprocess.Postprocessor(FLAGS.pca_params) 103 | 104 | # If needed, prepare a record writer to store the postprocessed embeddings. 105 | writer = tf.python_io.TFRecordWriter( 106 | FLAGS.tfrecord_file) if FLAGS.tfrecord_file else None 107 | 108 | with tf.Graph().as_default(), tf.Session() as sess: 109 | # Define the model in inference mode, load the checkpoint, and 110 | # locate input and output tensors. 111 | vggish_slim.define_vggish_slim(training=False) 112 | vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint) 113 | features_tensor = sess.graph.get_tensor_by_name( 114 | vggish_params.INPUT_TENSOR_NAME) 115 | embedding_tensor = sess.graph.get_tensor_by_name( 116 | vggish_params.OUTPUT_TENSOR_NAME) 117 | 118 | # Run inference and postprocessing. 119 | [embedding_batch] = sess.run([embedding_tensor], 120 | feed_dict={features_tensor: examples_batch}) 121 | print(embedding_batch) 122 | postprocessed_batch = pproc.postprocess(embedding_batch) 123 | print(postprocessed_batch, postprocessed_batch.shape) 124 | 125 | # Write the postprocessed embeddings as a SequenceExample, in a similar 126 | # format as the features released in AudioSet. Each row of the batch of 127 | # embeddings corresponds to roughly a second of audio (96 10ms frames), and 128 | # the rows are written as a sequence of bytes-valued features, where each 129 | # feature value contains the 128 bytes of the whitened quantized embedding. 
130 | seq_example = tf.train.SequenceExample( 131 | feature_lists=tf.train.FeatureLists( 132 | feature_list={ 133 | vggish_params.AUDIO_EMBEDDING_FEATURE_NAME: 134 | tf.train.FeatureList( 135 | feature=[ 136 | tf.train.Feature( 137 | bytes_list=tf.train.BytesList( 138 | value=[embedding.tobytes()])) 139 | for embedding in postprocessed_batch 140 | ] 141 | ) 142 | } 143 | ) 144 | ) 145 | print(seq_example) 146 | if writer: 147 | writer.write(seq_example.SerializeToString()) 148 | 149 | if writer: 150 | writer.close() 151 | 152 | if __name__ == '__main__': 153 | tf.app.run() 154 | -------------------------------------------------------------------------------- /vggish/vggish_input.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Compute input examples for VGGish from audio waveform.""" 17 | 18 | import numpy as np 19 | import resampy 20 | from scipy.io import wavfile 21 | 22 | import mel_features 23 | import vggish_params 24 | import six 25 | 26 | 27 | def waveform_to_examples(data, sample_rate): 28 | """Converts audio waveform into an array of examples for VGGish. 29 | 30 | Args: 31 | data: np.array of either one dimension (mono) or two dimensions 32 | (multi-channel, with the outer dimension representing channels). 33 | Each sample is generally expected to lie in the range [-1.0, +1.0], 34 | although this is not required. 35 | sample_rate: Sample rate of data. 36 | 37 | Returns: 38 | 3-D np.array of shape [num_examples, num_frames, num_bands] which represents 39 | a sequence of examples, each of which contains a patch of log mel 40 | spectrogram, covering num_frames frames of audio and num_bands mel frequency 41 | bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. 42 | """ 43 | # Convert to mono. 44 | if len(data.shape) > 1: 45 | data = np.mean(data, axis=1) 46 | # Resample to the rate assumed by VGGish. 47 | if sample_rate != vggish_params.SAMPLE_RATE: 48 | data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) 49 | 50 | # Compute log mel spectrogram features. 51 | log_mel = mel_features.log_mel_spectrogram( 52 | data, 53 | audio_sample_rate=vggish_params.SAMPLE_RATE, 54 | log_offset=vggish_params.LOG_OFFSET, 55 | window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS, 56 | hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS, 57 | num_mel_bins=vggish_params.NUM_MEL_BINS, 58 | lower_edge_hertz=vggish_params.MEL_MIN_HZ, 59 | upper_edge_hertz=vggish_params.MEL_MAX_HZ) 60 | 61 | # Frame features into examples. 
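  # With the default vggish_params values this framing works out to:
  #   features_sample_rate = 1 / 0.010 s STFT hop = 100 feature frames per second,
  #   example_window_length = round(0.96 * 100) = 96 frames (vggish_params.NUM_FRAMES),
  #   example_hop_length = 96 frames as well, so successive examples do not overlap.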
62 | features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS 63 | example_window_length = int(round( 64 | vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate)) 65 | example_hop_length = int(round( 66 | vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate)) 67 | log_mel_examples = mel_features.frame( 68 | log_mel, 69 | window_length=example_window_length, 70 | hop_length=example_hop_length) 71 | return log_mel_examples 72 | 73 | 74 | def wavfile_to_examples(wav_file): 75 | """Convenience wrapper around waveform_to_examples() for a common WAV format. 76 | 77 | Args: 78 | wav_file: String path to a file, or a file-like object. The file 79 | is assumed to contain WAV audio data with signed 16-bit PCM samples. 80 | 81 | Returns: 82 | See waveform_to_examples. 83 | """ 84 | sr, wav_data = wavfile.read(wav_file) 85 | assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype 86 | samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] 87 | return waveform_to_examples(samples, sr) 88 | -------------------------------------------------------------------------------- /vggish/vggish_params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Global parameters for the VGGish model. 17 | 18 | See vggish_slim.py for more information. 19 | """ 20 | 21 | from os.path import join as pjoin 22 | 23 | # Architectural constants. 24 | NUM_FRAMES = 96 # Frames in input mel-spectrogram patch. 25 | NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch. 26 | EMBEDDING_SIZE = 128 # Size of embedding layer. 27 | 28 | # Hyperparameters used in feature and example generation. 29 | SAMPLE_RATE = 16000 30 | STFT_WINDOW_LENGTH_SECONDS = 0.025 31 | STFT_HOP_LENGTH_SECONDS = 0.010 32 | NUM_MEL_BINS = NUM_BANDS 33 | MEL_MIN_HZ = 125 34 | MEL_MAX_HZ = 7500 35 | LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram. 36 | EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames 37 | EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap. 38 | 39 | # Parameters used for embedding postprocessing. 40 | PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors' 41 | PCA_MEANS_NAME = 'pca_means' 42 | QUANTIZE_MIN_VAL = -2.0 43 | QUANTIZE_MAX_VAL = +2.0 44 | 45 | # Hyperparameters used in training. 46 | INIT_STDDEV = 0.01 # Standard deviation used to initialize weights. 47 | LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer. 48 | ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer. 49 | 50 | # Names of ops, tensors, and features. 
51 | INPUT_OP_NAME = 'vggish/input_features' 52 | INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0' 53 | OUTPUT_OP_NAME = 'vggish/embedding' 54 | OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0' 55 | AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding' 56 | 57 | # Checkpoint 58 | CHECKPOINT_DIR = './data/vggish' 59 | CHECKPOINT_NAME = 'vggish_model.ckpt' 60 | PCA_PARAMS_NAME = 'vggish_pca_params.npz' 61 | CHECKPOINT = pjoin(CHECKPOINT_DIR, CHECKPOINT_NAME) 62 | PCA_PARAMS = pjoin(CHECKPOINT_DIR, PCA_PARAMS_NAME) -------------------------------------------------------------------------------- /vggish/vggish_postprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Post-process embeddings from VGGish.""" 17 | 18 | import numpy as np 19 | 20 | import vggish_params 21 | 22 | 23 | class Postprocessor(object): 24 | """Post-processes VGGish embeddings. 25 | 26 | The initial release of AudioSet included 128-D VGGish embeddings for each 27 | segment of AudioSet. These released embeddings were produced by applying 28 | a PCA transformation (technically, a whitening transform is included as well) 29 | and 8-bit quantization to the raw embedding output from VGGish, in order to 30 | stay compatible with the YouTube-8M project which provides visual embeddings 31 | in the same format for a large set of YouTube videos. This class implements 32 | the same PCA (with whitening) and quantization transformations. 33 | """ 34 | 35 | def __init__(self, pca_params_npz_path): 36 | """Constructs a postprocessor. 37 | 38 | Args: 39 | pca_params_npz_path: Path to a NumPy-format .npz file that 40 | contains the PCA parameters used in postprocessing. 41 | """ 42 | params = np.load(pca_params_npz_path) 43 | self._pca_matrix = params[vggish_params.PCA_EIGEN_VECTORS_NAME] 44 | # Load means into a column vector for easier broadcasting later. 45 | self._pca_means = params[vggish_params.PCA_MEANS_NAME].reshape(-1, 1) 46 | assert self._pca_matrix.shape == ( 47 | vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE), ( 48 | 'Bad PCA matrix shape: %r' % (self._pca_matrix.shape,)) 49 | assert self._pca_means.shape == (vggish_params.EMBEDDING_SIZE, 1), ( 50 | 'Bad PCA means shape: %r' % (self._pca_means.shape,)) 51 | 52 | def postprocess(self, embeddings_batch): 53 | """Applies postprocessing to a batch of embeddings. 54 | 55 | Args: 56 | embeddings_batch: An nparray of shape [batch_size, embedding_size] 57 | containing output from the embedding layer of VGGish. 58 | 59 | Returns: 60 | An nparray of the same shape as the input but of type uint8, 61 | containing the PCA-transformed and quantized version of the input. 
62 | """ 63 | assert len(embeddings_batch.shape) == 2, ( 64 | 'Expected 2-d batch, got %r' % (embeddings_batch.shape,)) 65 | assert embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE, ( 66 | 'Bad batch shape: %r' % (embeddings_batch.shape,)) 67 | 68 | # Apply PCA. 69 | # - Embeddings come in as [batch_size, embedding_size]. 70 | # - Transpose to [embedding_size, batch_size]. 71 | # - Subtract pca_means column vector from each column. 72 | # - Premultiply by PCA matrix of shape [output_dims, input_dims] 73 | # where both are are equal to embedding_size in our case. 74 | # - Transpose result back to [batch_size, embedding_size]. 75 | pca_applied = np.dot(self._pca_matrix, 76 | (embeddings_batch.T - self._pca_means)).T 77 | 78 | # Quantize by: 79 | # - clipping to [min, max] range 80 | clipped_embeddings = np.clip( 81 | pca_applied, vggish_params.QUANTIZE_MIN_VAL, 82 | vggish_params.QUANTIZE_MAX_VAL) 83 | # - convert to 8-bit in range [0.0, 255.0] 84 | quantized_embeddings = ( 85 | (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) * 86 | (255.0 / 87 | (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL))) 88 | # - cast 8-bit float to uint8 89 | quantized_embeddings = quantized_embeddings.astype(np.uint8) 90 | 91 | return quantized_embeddings 92 | -------------------------------------------------------------------------------- /vggish/vggish_slim.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Defines the 'VGGish' model used to generate AudioSet embedding features. 17 | 18 | The public AudioSet release (https://research.google.com/audioset/download.html) 19 | includes 128-D features extracted from the embedding layer of a VGG-like model 20 | that was trained on a large Google-internal YouTube dataset. Here we provide 21 | a TF-Slim definition of the same model, without any dependences on libraries 22 | internal to Google. We call it 'VGGish'. 23 | 24 | Note that we only define the model up to the embedding layer, which is the 25 | penultimate layer before the final classifier layer. We also provide various 26 | hyperparameter values (in vggish_params.py) that were used to train this model 27 | internally. 28 | 29 | For comparison, here is TF-Slim's VGG definition: 30 | https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py 31 | """ 32 | 33 | import tensorflow as tf 34 | import vggish_params as params 35 | 36 | slim = tf.contrib.slim 37 | 38 | 39 | def define_vggish_slim(training=False): 40 | """Defines the VGGish TensorFlow model. 41 | 42 | All ops are created in the current default graph, under the scope 'vggish/'. 
43 | 44 | The input is a placeholder named 'vggish/input_features' of type float32 and 45 | shape [batch_size, num_frames, num_bands] where batch_size is variable and 46 | num_frames and num_bands are constants, and [num_frames, num_bands] represents 47 | a log-mel-scale spectrogram patch covering num_bands frequency bands and 48 | num_frames time frames (where each frame step is usually 10ms). This is 49 | produced by computing the stabilized log(mel-spectrogram + params.LOG_OFFSET). 50 | The output is an op named 'vggish/embedding' which produces the activations of 51 | a 128-D embedding layer, which is usually the penultimate layer when used as 52 | part of a full model with a final classifier layer. 53 | 54 | Args: 55 | training: If true, all parameters are marked trainable. 56 | 57 | Returns: 58 | The op 'vggish/embeddings'. 59 | """ 60 | # Defaults: 61 | # - All weights are initialized to N(0, INIT_STDDEV). 62 | # - All biases are initialized to 0. 63 | # - All activations are ReLU. 64 | # - All convolutions are 3x3 with stride 1 and SAME padding. 65 | # - All max-pools are 2x2 with stride 2 and SAME padding. 66 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 67 | weights_initializer=tf.truncated_normal_initializer( 68 | stddev=params.INIT_STDDEV), 69 | biases_initializer=tf.zeros_initializer(), 70 | activation_fn=tf.nn.relu, 71 | trainable=training), \ 72 | slim.arg_scope([slim.conv2d], 73 | kernel_size=[3, 3], stride=1, padding='SAME'), \ 74 | slim.arg_scope([slim.max_pool2d], 75 | kernel_size=[2, 2], stride=2, padding='SAME'), \ 76 | tf.variable_scope('vggish'): 77 | # Input: a batch of 2-D log-mel-spectrogram patches. 78 | features = tf.placeholder( 79 | tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS), 80 | name='input_features') 81 | # Reshape to 4-D so that we can convolve a batch with conv2d(). 82 | net = tf.reshape(features, [-1, params.NUM_FRAMES, params.NUM_BANDS, 1]) 83 | 84 | # The VGG stack of alternating convolutions and max-pools. 85 | net = slim.conv2d(net, 64, scope='conv1') 86 | net = slim.max_pool2d(net, scope='pool1') 87 | net = slim.conv2d(net, 128, scope='conv2') 88 | net = slim.max_pool2d(net, scope='pool2') 89 | net = slim.repeat(net, 2, slim.conv2d, 256, scope='conv3') 90 | net = slim.max_pool2d(net, scope='pool3') 91 | net = slim.repeat(net, 2, slim.conv2d, 512, scope='conv4') 92 | net = slim.max_pool2d(net, scope='pool4') 93 | 94 | # Flatten before entering fully-connected layers 95 | net = slim.flatten(net) 96 | net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1') 97 | # The embedding layer. 98 | net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2') 99 | return tf.identity(net, name='embedding') 100 | 101 | 102 | def load_vggish_slim_checkpoint(session, checkpoint_path): 103 | """Loads a pre-trained VGGish-compatible checkpoint. 104 | 105 | This function can be used as an initialization function (referred to as 106 | init_fn in TensorFlow documentation) which is called in a Session after 107 | initializating all variables. When used as an init_fn, this will load 108 | a pre-trained checkpoint that is compatible with the VGGish model 109 | definition. Only variables defined by VGGish will be loaded. 110 | 111 | Args: 112 | session: an active TensorFlow session. 113 | checkpoint_path: path to a file containing a checkpoint that is 114 | compatible with the VGGish model definition. 
115 | """ 116 | # Get the list of names of all VGGish variables that exist in 117 | # the checkpoint (i.e., all inference-mode VGGish variables). 118 | with tf.Graph().as_default(): 119 | define_vggish_slim(training=False) 120 | vggish_var_names = [v.name for v in tf.global_variables()] 121 | 122 | # Get the list of all currently existing variables that match 123 | # the list of variable names we just computed. 124 | vggish_vars = [v for v in tf.global_variables() if v.name in vggish_var_names] 125 | 126 | # Use a Saver to restore just the variables selected above. 127 | saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained', 128 | write_version=1) 129 | saver.restore(session, checkpoint_path) 130 | 131 | 132 | def load_defined_vggish_slim_checkpoint(session, checkpoint_path): 133 | """Loads a pre-trained VGGish-compatible checkpoint. 134 | 135 | This function can be used as an initialization function (referred to as 136 | init_fn in TensorFlow documentation) which is called in a Session after 137 | initializating all variables. When used as an init_fn, this will load 138 | a pre-trained checkpoint that is compatible with the VGGish model 139 | definition. Only variables defined by VGGish will be loaded. 140 | 141 | Args: 142 | session: an active TensorFlow session with an exist default graph 143 | checkpoint_path: path to a file containing a checkpoint that is 144 | compatible with the VGGish model definition. 145 | """ 146 | # Get the list of names of all VGGish variables that exist in 147 | # the checkpoint (i.e., all inference-mode VGGish variables). 148 | with tf.Graph().as_default(): 149 | define_vggish_slim(training=False) 150 | vggish_var_names = [v.name for v in tf.global_variables()] 151 | 152 | # Get list of variables from exist graph which passed by session 153 | with session.graph.as_default(): 154 | global_variables = tf.global_variables() 155 | 156 | # Get the list of all currently existing variables that match 157 | # the list of variable names we just computed. 158 | vggish_vars = [v for v in global_variables if v.name in vggish_var_names] 159 | 160 | # Use a Saver to restore just the variables selected above. 161 | saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained', 162 | write_version=1) 163 | saver.restore(session, checkpoint_path) 164 | -------------------------------------------------------------------------------- /vggish/vggish_smoke_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A smoke test for VGGish. 17 | 18 | This is a simple smoke test of a local install of VGGish and its associated 19 | downloaded files. 
We create a synthetic sound, extract log mel spectrogram 20 | features, run them through VGGish, post-process the embedding ouputs, and 21 | check some simple statistics of the results, allowing for variations that 22 | might occur due to platform/version differences in the libraries we use. 23 | 24 | Usage: 25 | - Download the VGGish checkpoint and PCA parameters into the same directory as 26 | the VGGish source code. If you keep them elsewhere, update the checkpoint_path 27 | and pca_params_path variables below. 28 | - Run: 29 | $ python vggish_smoke_test.py 30 | """ 31 | 32 | from __future__ import print_function 33 | 34 | import numpy as np 35 | import tensorflow as tf 36 | 37 | import vggish_input 38 | import vggish_params 39 | import vggish_postprocess 40 | import vggish_slim 41 | 42 | print('\nTesting your install of VGGish\n') 43 | 44 | # Paths to downloaded VGGish files. 45 | checkpoint_path = 'vggish_model.ckpt' 46 | pca_params_path = 'vggish_pca_params.npz' 47 | 48 | # Relative tolerance of errors in mean and standard deviation of embeddings. 49 | rel_error = 0.1 # Up to 10% 50 | 51 | # Generate a 1 kHz sine wave at 44.1 kHz (we use a high sampling rate 52 | # to test resampling to 16 kHz during feature extraction). 53 | num_secs = 3 54 | freq = 1000 55 | sr = 44100 56 | t = np.linspace(0, num_secs, int(num_secs * sr)) 57 | x = np.sin(2 * np.pi * freq * t) 58 | 59 | # Produce a batch of log mel spectrogram examples. 60 | input_batch = vggish_input.waveform_to_examples(x, sr) 61 | 62 | print(input_batch, input_batch.shape, type(input_batch)) 63 | print(input_batch[0][0][0], input_batch[0][0][0].shape, type(input_batch[0][0][0])) 64 | 65 | print('Log Mel Spectrogram example: ', input_batch[0]) 66 | np.testing.assert_equal( 67 | input_batch.shape, 68 | [num_secs, vggish_params.NUM_FRAMES, vggish_params.NUM_BANDS]) 69 | print('input_batch shape:', input_batch.shape) 70 | 71 | # Define VGGish, load the checkpoint, and run the batch through the model to 72 | # produce embeddings. 73 | with tf.Graph().as_default(), tf.Session() as sess: 74 | vggish_slim.define_vggish_slim() 75 | vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint_path) 76 | 77 | features_tensor = sess.graph.get_tensor_by_name( 78 | vggish_params.INPUT_TENSOR_NAME) 79 | embedding_tensor = sess.graph.get_tensor_by_name( 80 | vggish_params.OUTPUT_TENSOR_NAME) 81 | [embedding_batch] = sess.run([embedding_tensor], 82 | feed_dict={features_tensor: input_batch}) 83 | print('VGGish embedding: ', embedding_batch[0]) 84 | print('VGGish embedding shape: ', embedding_batch.shape) 85 | expected_embedding_mean = 0.131 86 | expected_embedding_std = 0.238 87 | np.testing.assert_allclose( 88 | [np.mean(embedding_batch), np.std(embedding_batch)], 89 | [expected_embedding_mean, expected_embedding_std], 90 | rtol=rel_error) 91 | 92 | # Postprocess the results to produce whitened quantized embeddings. 
93 | pproc = vggish_postprocess.Postprocessor(pca_params_path) 94 | postprocessed_batch = pproc.postprocess(embedding_batch) 95 | print('Postprocessed VGGish embedding: ', postprocessed_batch[0]) 96 | print('Postprocessed VGGish embedding shape: ', postprocessed_batch.shape) 97 | 98 | expected_postprocessed_mean = 123.0 99 | expected_postprocessed_std = 75.0 100 | np.testing.assert_allclose( 101 | [np.mean(postprocessed_batch), np.std(postprocessed_batch)], 102 | [expected_postprocessed_mean, expected_postprocessed_std], 103 | rtol=rel_error) 104 | 105 | print('\nLooks Good To Me!\n') 106 | -------------------------------------------------------------------------------- /vggish/vggish_train_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | r"""A simple demonstration of running VGGish in training mode. 17 | 18 | This is intended as a toy example that demonstrates how to use the VGGish model 19 | definition within a larger model that adds more layers on top, and then train 20 | the larger model. If you let VGGish train as well, then this allows you to 21 | fine-tune the VGGish model parameters for your application. If you don't let 22 | VGGish train, then you use VGGish as a feature extractor for the layers above 23 | it. 24 | 25 | For this toy task, we are training a classifier to distinguish between three 26 | classes: sine waves, constant signals, and white noise. We generate synthetic 27 | waveforms from each of these classes, convert into shuffled batches of log mel 28 | spectrogram examples with associated labels, and feed the batches into a model 29 | that includes VGGish at the bottom and a couple of additional layers on top. We 30 | also plumb in labels that are associated with the examples, which feed a label 31 | loss used for training. 32 | 33 | Usage: 34 | # Run training for 100 steps using a model checkpoint in the default 35 | # location (vggish_model.ckpt in the current directory). Allow VGGish 36 | # to get fine-tuned. 37 | $ python vggish_train_demo.py --num_batches 100 38 | 39 | # Same as before but run for fewer steps and don't change VGGish parameters 40 | # and use a checkpoint in a different location 41 | $ python vggish_train_demo.py --num_batches 50 \ 42 | --train_vggish=False \ 43 | --checkpoint /path/to/model/checkpoint 44 | """ 45 | 46 | from __future__ import print_function 47 | 48 | from random import shuffle 49 | 50 | import numpy as np 51 | import tensorflow as tf 52 | 53 | import vggish_input 54 | import vggish_params 55 | import vggish_slim 56 | 57 | flags = tf.app.flags 58 | slim = tf.contrib.slim 59 | 60 | flags.DEFINE_integer( 61 | 'num_batches', 30, 62 | 'Number of batches of examples to feed into the model. 
Each batch is of ' 63 | 'variable size and contains shuffled examples of each class of audio.') 64 | 65 | flags.DEFINE_boolean( 66 | 'train_vggish', True, 67 | 'If Frue, allow VGGish parameters to change during training, thus ' 68 | 'fine-tuning VGGish. If False, VGGish parameters are fixed, thus using ' 69 | 'VGGish as a fixed feature extractor.') 70 | 71 | flags.DEFINE_string( 72 | 'checkpoint', 'vggish_model.ckpt', 73 | 'Path to the VGGish checkpoint file.') 74 | 75 | FLAGS = flags.FLAGS 76 | 77 | _NUM_CLASSES = 3 78 | 79 | 80 | def _get_examples_batch(): 81 | """Returns a shuffled batch of examples of all audio classes. 82 | 83 | Note that this is just a toy function because this is a simple demo intended 84 | to illustrate how the training code might work. 85 | 86 | Returns: 87 | a tuple (features, labels) where features is a NumPy array of shape 88 | [batch_size, num_frames, num_bands] where the batch_size is variable and 89 | each row is a log mel spectrogram patch of shape [num_frames, num_bands] 90 | suitable for feeding VGGish, while labels is a NumPy array of shape 91 | [batch_size, num_classes] where each row is a multi-hot label vector that 92 | provides the labels for corresponding rows in features. 93 | """ 94 | # Make a waveform for each class. 95 | num_seconds = 5 96 | sr = 44100 # Sampling rate. 97 | t = np.linspace(0, num_seconds, int(num_seconds * sr)) # Time axis. 98 | # Random sine wave. 99 | freq = np.random.uniform(100, 1000) 100 | sine = np.sin(2 * np.pi * freq * t) 101 | # Random constant signal. 102 | magnitude = np.random.uniform(-1, 1) 103 | const = magnitude * t 104 | # White noise. 105 | noise = np.random.normal(-1, 1, size=t.shape) 106 | 107 | # Make examples of each signal and corresponding labels. 108 | # Sine is class index 0, Const class index 1, Noise class index 2. 109 | sine_examples = vggish_input.waveform_to_examples(sine, sr) 110 | sine_labels = np.array([[1, 0, 0]] * sine_examples.shape[0]) 111 | const_examples = vggish_input.waveform_to_examples(const, sr) 112 | const_labels = np.array([[0, 1, 0]] * const_examples.shape[0]) 113 | noise_examples = vggish_input.waveform_to_examples(noise, sr) 114 | noise_labels = np.array([[0, 0, 1]] * noise_examples.shape[0]) 115 | 116 | # Shuffle (example, label) pairs across all classes. 117 | all_examples = np.concatenate((sine_examples, const_examples, noise_examples)) 118 | print('all_examples shape:', all_examples.shape) 119 | all_labels = np.concatenate((sine_labels, const_labels, noise_labels)) 120 | print('all_labels shape:', all_labels.shape) 121 | labeled_examples = list(zip(all_examples, all_labels)) 122 | shuffle(labeled_examples) 123 | 124 | # Separate and return the features and labels. 125 | features = [example for (example, _) in labeled_examples] 126 | labels = [label for (_, label) in labeled_examples] 127 | return (features, labels) 128 | 129 | 130 | def main(_): 131 | with tf.Graph().as_default(), tf.Session() as sess: 132 | # Define VGGish. 133 | embeddings = vggish_slim.define_vggish_slim(FLAGS.train_vggish) 134 | 135 | # Define a shallow classification model and associated training ops on top 136 | # of VGGish. 137 | with tf.variable_scope('mymodel'): 138 | # Add a fully connected layer with 100 units. 139 | num_units = 100 140 | fc = slim.fully_connected(embeddings, num_units) 141 | 142 | # Add a classifier layer at the end, consisting of parallel logistic 143 | # classifiers, one per class. This allows for multi-class tasks. 
144 | logits = slim.fully_connected( 145 | fc, _NUM_CLASSES, activation_fn=None, scope='logits') 146 | tf.sigmoid(logits, name='prediction') 147 | 148 | # Add training ops. 149 | with tf.variable_scope('train'): 150 | global_step = tf.Variable( 151 | 0, name='global_step', trainable=False, 152 | collections=[tf.GraphKeys.GLOBAL_VARIABLES, 153 | tf.GraphKeys.GLOBAL_STEP]) 154 | 155 | # Labels are assumed to be fed as a batch multi-hot vectors, with 156 | # a 1 in the position of each positive class label, and 0 elsewhere. 157 | labels = tf.placeholder( 158 | tf.float32, shape=(None, _NUM_CLASSES), name='labels') 159 | 160 | # Cross-entropy label loss. 161 | xent = tf.nn.sigmoid_cross_entropy_with_logits( 162 | logits=logits, labels=labels, name='xent') 163 | loss = tf.reduce_mean(xent, name='loss_op') 164 | tf.summary.scalar('loss', loss) 165 | 166 | # We use the same optimizer and hyperparameters as used to train VGGish. 167 | optimizer = tf.train.AdamOptimizer( 168 | learning_rate=vggish_params.LEARNING_RATE, 169 | epsilon=vggish_params.ADAM_EPSILON) 170 | optimizer.minimize(loss, global_step=global_step, name='train_op') 171 | 172 | # Initialize all variables in the model, and then load the pre-trained 173 | # VGGish checkpoint. 174 | sess.run(tf.global_variables_initializer()) 175 | vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint) 176 | 177 | # Locate all the tensors and ops we need for the training loop. 178 | features_tensor = sess.graph.get_tensor_by_name( 179 | vggish_params.INPUT_TENSOR_NAME) 180 | labels_tensor = sess.graph.get_tensor_by_name('mymodel/train/labels:0') 181 | global_step_tensor = sess.graph.get_tensor_by_name( 182 | 'mymodel/train/global_step:0') 183 | loss_tensor = sess.graph.get_tensor_by_name('mymodel/train/loss_op:0') 184 | train_op = sess.graph.get_operation_by_name('mymodel/train/train_op') 185 | 186 | # The training loop. 187 | for _ in range(FLAGS.num_batches): 188 | (features, labels) = _get_examples_batch() 189 | [num_steps, loss, _] = sess.run( 190 | [global_step_tensor, loss_tensor, train_op], 191 | feed_dict={features_tensor: features, labels_tensor: labels}) 192 | print('Step %d: loss %g' % (num_steps, loss)) 193 | 194 | if __name__ == '__main__': 195 | tf.app.run() 196 | --------------------------------------------------------------------------------
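A closing usage note on vggish_train_demo.py: everything the loop above trains lives in the same session, so the sigmoid outputs defined under the 'mymodel' scope can be queried directly once training finishes. The lines below are a minimal, hypothetical sketch (not part of the original demo) of such a check; they assume they are placed inside main() right after the training loop, where sess, features_tensor, np and _get_examples_batch() are all in scope.

    # Evaluate the freshly trained classifier on one more synthetic batch.
    prediction_tensor = sess.graph.get_tensor_by_name('mymodel/prediction:0')
    (eval_features, eval_labels) = _get_examples_batch()
    predictions = sess.run(prediction_tensor,
                           feed_dict={features_tensor: eval_features})
    # predictions has shape [batch_size, _NUM_CLASSES]; compare argmax to the labels.
    batch_accuracy = np.mean(
        np.argmax(predictions, axis=1) == np.argmax(np.array(eval_labels), axis=1))
    print('Post-training batch accuracy: %.3f' % batch_accuracy)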