├── .gitignore ├── LICENSE ├── README.md ├── requirements.txt ├── sample_cnn ├── __init__.py ├── data │ ├── audio_processing.py │ └── mtt │ │ ├── annotation_processing.py │ │ └── build_mtt.py ├── evaluate.py ├── keras_utils │ └── tfrecord_model.py ├── model.py ├── ops │ ├── __init__.py │ ├── batch_inputs.py │ └── evaluation.py └── train.py ├── scripts ├── build_mtt.sh.template ├── evaluate.sh.template └── train.sh.template └── tests └── tfrecord_model_test_mnist.py /.gitignore: -------------------------------------------------------------------------------- 1 | *tmp* 2 | scripts/*.sh 3 | 4 | # Created by .ignore support plugin (hsz.mobi) 5 | ### Python template 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *,cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv/ 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | ### JetBrains template 97 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 98 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 99 | 100 | # User-specific stuff: 101 | .idea/workspace.xml 102 | .idea/tasks.xml 103 | 104 | # Sensitive or high-churn files: 105 | .idea/dataSources/ 106 | .idea/dataSources.ids 107 | .idea/dataSources.xml 108 | .idea/dataSources.local.xml 109 | .idea/sqlDataSources.xml 110 | .idea/dynamic.xml 111 | .idea/uiDesigner.xml 112 | 113 | # Gradle: 114 | .idea/gradle.xml 115 | .idea/libraries 116 | 117 | # Mongo Explorer plugin: 118 | .idea/mongoSettings.xml 119 | 120 | ## File-based project format: 121 | *.iws 122 | 123 | ## Plugin-specific files: 124 | 125 | # IntelliJ 126 | /out/ 127 | 128 | # mpeltonen/sbt-idea plugin 129 | .idea_modules/ 130 | 131 | # JIRA plugin 132 | atlassian-ide-plugin.xml 133 | 134 | # Crashlytics plugin (for Android Studio and IntelliJ) 135 | com_crashlytics_export_strings.xml 136 | crashlytics.properties 137 | crashlytics-build.properties 138 | fabric.properties 139 | ### Linux template 140 | *~ 141 | 142 | # temporary files which can be created if a process still has a handle open of a deleted file 143 | .fuse_hidden* 144 | 145 | # KDE directory preferences 146 | .directory 147 | 148 | # Linux trash folder which might appear on any 
partition or disk 149 | .Trash-* 150 | 151 | # .nfs files are created when an open file is removed but is still being accessed 152 | .nfs* 153 | ### macOS template 154 | *.DS_Store 155 | .AppleDouble 156 | .LSOverride 157 | 158 | # Icon must end with two \r 159 | Icon 160 | 161 | 162 | # Thumbnails 163 | ._* 164 | 165 | # Files that might appear in the root of a volume 166 | .DocumentRevisions-V100 167 | .fseventsd 168 | .Spotlight-V100 169 | .TemporaryItems 170 | .Trashes 171 | .VolumeIcon.icns 172 | .com.apple.timemachine.donotpresent 173 | 174 | # Directories potentially created on remote AFP share 175 | .AppleDB 176 | .AppleDesktop 177 | Network Trash Folder 178 | Temporary Items 179 | .apdisk 180 | ### VirtualEnv template 181 | # Virtualenv 182 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 183 | .Python 184 | [Bb]in 185 | [Ii]nclude 186 | [Ll]ib 187 | [Ll]ib64 188 | [Ll]ocal 189 | #[Ss]cripts 190 | pyvenv.cfg 191 | .venv 192 | pip-selfcheck.json 193 | 194 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Tae-jun Kim 2 | 3 | Permission is hereby granted, free of charge, to any person 4 | obtaining a copy of this software and associated documentation 5 | files (the "Software"), to deal in the Software without 6 | restriction, including without limitation the rights to use, 7 | copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the 9 | Software is furnished to do so, subject to the following 10 | conditions: 11 | 12 | The above copyright notice and this permission notice shall be 13 | included in all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 17 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 19 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 20 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 21 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sample CNN 2 | ***A TensorFlow implementation of "Sample-level Deep Convolutional 3 | Neural Networks for Music Auto-tagging Using Raw Waveforms"*** 4 | 5 | This is a [TensorFlow][1] implementation of "[*Sample-level Deep 6 | Convolutional Neural Networks for Music Auto-tagging Using Raw 7 | Waveforms*][10]" using [Keras][11]. This repository only implements the 8 | best model of the paper. 
(the model described in Table 1 of the paper; m=3, n=9)
9 | 
10 | 
11 | ## Table of contents
12 | * [Prerequisites](#prerequisites)
13 | * [Preparing MagnaTagATune (MTT) Dataset](#preparing-mtt)
14 | * [Preprocessing the MTT dataset](#preprocessing)
15 | * [Training a model from scratch](#training)
16 | * [Evaluating a model](#evaluating)
17 | 
18 | 
19 | 
20 | ## Prerequisites
21 | * Python 3.5 and the required packages
22 | * `ffmpeg` (required for `madmom`)
23 | 
24 | ### Installing required Python packages
25 | ```sh
26 | pip install -r requirements.txt
27 | pip install madmom
28 | ```
29 | The `madmom` package has an install-time dependency, so it should be
30 | installed after the packages in `requirements.txt`.
31 | 
32 | This will install the required packages:
33 | * [tensorflow][1] **1.0.1** (has an issue on 1.1.0)
34 | * [keras][11]
35 | * [pandas][2]
36 | * [scikit-learn][3]
37 | * [madmom][4]
38 | * [numpy][5]
39 | * [scipy][6]
40 | * [cython][7]
41 | * [h5py][12]
42 | 
43 | ### Installing ffmpeg
44 | `ffmpeg` is required for `madmom`.
45 | 
46 | #### MacOS (with Homebrew):
47 | ```sh
48 | brew install ffmpeg
49 | ```
50 | 
51 | #### Ubuntu:
52 | ```sh
53 | add-apt-repository ppa:mc3man/trusty-media
54 | apt-get update
55 | apt-get dist-upgrade
56 | apt-get install ffmpeg
57 | ```
58 | 
59 | #### CentOS:
60 | ```sh
61 | yum install epel-release
62 | rpm --import http://li.nux.ro/download/nux/RPM-GPG-KEY-nux.ro
63 | rpm -Uvh http://li.nux.ro/download/nux/dextop/el ... noarch.rpm
64 | yum install ffmpeg
65 | ```
66 | 
67 | 
68 | 
69 | ## Preparing MagnaTagATune (MTT) dataset
70 | Download the audio data and tag annotations from [here][8]. You should
71 | then have 3 `.zip` files and 1 `.csv` file:
72 | ```sh
73 | mp3.zip.001
74 | mp3.zip.002
75 | mp3.zip.003
76 | annotations_final.csv
77 | ```
78 | 
79 | To unzip the `.zip` files, merge and unzip them (referenced [here][9]):
80 | ```sh
81 | cat mp3.zip.* > mp3_all.zip
82 | unzip mp3_all.zip
83 | ```
84 | 
85 | You should see 16 directories named `0` to `f`. Typically, `0 ~ b` are
86 | used for training, `c` for validation, and `d ~ f` for testing.
87 | 
88 | To make your life easier, place them in a directory as below:
89 | ```sh
90 | ├── annotations_final.csv
91 | └── raw
92 |    ├── 0
93 |    ├── 1
94 |    ├── ...
95 |    └── f
96 | ```
97 | 
98 | We will call this directory `BASE_DIR`. Preparing the MTT dataset is done!
99 | 
100 | 
101 | 
102 | ## Preprocessing the MTT dataset
103 | This section describes the required preprocessing for the MTT
104 | dataset. Note that this requires about 57 GB of storage space.
105 | 
106 | The preprocessing does the following (a record-inspection sketch follows this list):
107 | * Selects the top 50 tags in `annotations_final.csv`
108 | * Splits the dataset into training, validation, and test sets
109 | * Segments the raw audio files into chunks of `59049` samples
110 | * Converts the segments to TFRecord format
111 | 
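Each training record stores one raw segment and its 50 binary labels. After the build described below has finished, you can sanity-check a generated shard with a minimal sketch like this (assuming the TF 1.x API pinned in `requirements.txt`; the shard path is a placeholder):

```python
import numpy as np
import tensorflow as tf

path = '/path/to/mtt/basedir/tfrecord/train-000-of-128.tfrecords'

for record in tf.python_io.tf_record_iterator(path):
    example = tf.train.Example.FromString(record)
    features = example.features.feature
    segment = np.frombuffer(features['raw_segment'].bytes_list.value[0],
                            dtype=np.float32)
    labels = np.frombuffer(features['raw_labels'].bytes_list.value[0],
                           dtype=np.uint8)
    print(segment.shape, labels.shape)  # expected: (59049,) (50,)
    break
```
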
112 | To run the preprocessing, copy the shell script template and edit the copy:
113 | ```sh
114 | cp scripts/build_mtt.sh.template scripts/build_mtt.sh
115 | vi scripts/build_mtt.sh
116 | ```
117 | 
118 | You should fill in the environment variables:
119 | * `BASE_DIR`: the directory that contains the `annotations_final.csv` file
120 |   and the `raw` directory
121 | * `N_PROCESSES`: number of processes to use; the preprocessing runs in
122 |   parallel
123 | * `ENV_NAME`: (optional) if you use `virtualenv` or `conda` to create a
124 |   separate environment, write your environment name
125 | 
126 | Below is an example:
127 | ```sh
128 | BASE_DIR="/path/to/mtt/basedir"
129 | N_PROCESSES=4
130 | ENV_NAME="sample_cnn"
131 | ```
132 | 
133 | Then run it:
134 | ```sh
135 | ./scripts/build_mtt.sh
136 | ```
137 | 
138 | The script will **automatically run a process in the background** and
139 | **tail the output** that the process prints. This will take from a few
140 | minutes to an hour, depending on your machine.
141 | 
142 | The converted TFRecord files will be located in
143 | `${BASE_DIR}/tfrecord`. Your `BASE_DIR` should now be structured like
144 | this:
145 | ```sh
146 | ├── annotations_final.csv
147 | ├── build_mtt.log
148 | ├── labels.txt
149 | ├── raw
150 | │   ├── 0
151 | │   ├── ...
152 | │   └── f
153 | └── tfrecord
154 |    ├── test-000-of-036.seq.tfrecords
155 |    ├── ...
156 |    ├── test-035-of-036.seq.tfrecords
157 |    ├── train-000-of-128.tfrecords
158 |    ├── ...
159 |    ├── train-127-of-128.tfrecords
160 |    ├── val-000-of-012.seq.tfrecords
161 |    ├── ...
162 |    └── val-011-of-012.seq.tfrecords
163 | ```
164 | 
165 | 
166 | 
167 | ## Training a model from scratch
168 | To train a model from scratch, copy the shell script template and edit
169 | the copy, just as above:
170 | ```sh
171 | cp scripts/train.sh.template scripts/train.sh
172 | vi scripts/train.sh
173 | ```
174 | 
175 | Then fill in the environment variables:
176 | * `BASE_DIR`: the directory that contains the `tfrecord` directory
177 | * `TRAIN_DIR`: where to save your trained model and the summaries used to
178 |   visualize training with TensorBoard
179 | * `ENV_NAME`: (optional) if you use `virtualenv` or `conda` to create a
180 |   separate environment, write your environment name
181 | 
182 | Below is an example:
183 | ```sh
184 | BASE_DIR="/path/to/mtt/basedir"
185 | TRAIN_DIR="/path/to/save/outputs"
186 | ENV_NAME="sample_cnn"
187 | ```
188 | 
189 | Now kick off the training:
190 | ```sh
191 | ./scripts/train.sh
192 | ```
193 | 
194 | The script will **automatically run a process in the background** and
195 | **tail the output** that the process prints.
196 | 
197 | 
198 | 
199 | ## Evaluating a model
200 | Copy the evaluation shell script template and edit the copy:
201 | ```sh
202 | cp scripts/evaluate.sh.template scripts/evaluate.sh
203 | vi scripts/evaluate.sh
204 | ```
205 | 
206 | Fill in the environment variables:
207 | * `BASE_DIR`: the directory that contains the `tfrecord` directory
208 | * `CHECKPOINT_DIR`: where you saved your model (`TRAIN_DIR` when training)
209 | * `ENV_NAME`: (optional) if you use `virtualenv` or `conda` to create a
210 |   separate environment, write your environment name
211 | 
212 | By default, the script evaluates the best model, not the latest one. If
213 | you want to evaluate the latest model instead, pass `--best=False` as an
214 | option.
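Evaluation is done at the song level: the model predicts tags for each of the 10 segments of a song, and the segment predictions are averaged into one score vector (see `sample_cnn/ops/evaluation.py`). A self-contained sketch of that averaging step, on dummy data:

```python
import numpy as np

# 10 segment-level predictions for one song, 50 tag scores each
# (dummy data standing in for the model outputs).
y_pred_segments = np.random.rand(10, 50).astype(np.float32)

# The song-level prediction is the mean over the segments.
y_pred_song = y_pred_segments.mean(axis=0)  # shape: (50,)
```
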
215 | 
216 | [1]: https://www.tensorflow.org/
217 | [2]: http://pandas.pydata.org/
218 | [3]: http://scikit-learn.org/
219 | [4]: https://madmom.readthedocs.io/en/latest/
220 | [5]: http://www.numpy.org/
221 | [6]: https://www.scipy.org/
222 | [7]: http://cython.org/
223 | [8]: http://mirg.city.ac.uk/codeapps/the-magnatagatune-dataset
224 | [9]: https://github.com/keunwoochoi/magnatagatune-list
225 | [10]: https://arxiv.org/abs/1703.01789
226 | [11]: https://keras.io/
227 | [12]: http://www.h5py.org
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Cython==0.25.2
2 | numpy==1.12.1
3 | pandas==0.19.2
4 | scikit-learn==0.18.1
5 | scipy==0.19.0
6 | tensorflow-gpu==1.1.0
7 | h5py==2.7.1
--------------------------------------------------------------------------------
/sample_cnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tae-jun/sample-cnn/a8581d6aaa4233e981ffa83a74c7f88e1c450647/sample_cnn/__init__.py
--------------------------------------------------------------------------------
/sample_cnn/data/audio_processing.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | 
4 | from madmom.audio.signal import Signal
5 | 
6 | FLAGS = tf.flags.FLAGS
7 | 
8 | 
9 | def _int64_feature(value):
10 |   """Wrapper for inserting int64 features into Example proto."""
11 |   if not isinstance(value, list):
12 |     value = [value]
13 |   return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
14 | 
15 | 
16 | def _bytes_feature(value):
17 |   """Wrapper for inserting bytes features into Example proto."""
18 |   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
19 | 
20 | 
21 | def _bytes_feature_list(values):
22 |   """Wrapper for inserting a bytes FeatureList into a SequenceExample proto."""
23 |   return tf.train.FeatureList(feature=[_bytes_feature(v) for v in values])
24 | 
25 | 
26 | def _segments_to_sequence_example(segments, labels):
27 |   """Converts a list of segments to a SequenceExample proto.
28 | 
29 |   Args:
30 |     segments: A list of segments.
31 |     labels: A list of labels of the segments.
32 | 
33 |   Returns:
34 |     A SequenceExample proto.
35 |   """
36 |   raw_segments = [segment.tostring() for segment in segments]
37 |   raw_labels = np.array(labels, dtype=np.uint8).tostring()
38 | 
39 |   context = tf.train.Features(feature={
40 |     'raw_labels': _bytes_feature(raw_labels)  # uint8 Tensor (50,)
41 |   })
42 | 
43 |   feature_lists = tf.train.FeatureLists(feature_list={
44 |     # list of float32 Tensor (59049,)
45 |     'raw_segments': _bytes_feature_list(raw_segments)
46 |   })
47 | 
48 |   sequence_example = tf.train.SequenceExample(
49 |     context=context, feature_lists=feature_lists)
50 | 
51 |   return sequence_example
52 | 
53 | 
54 | def _segment_to_example(segment, labels):
55 |   """Converts a single segment to an Example proto.
56 | 
57 |   Args:
58 |     segment: A segment; a float32 numpy array.
59 |     labels: A list of labels of the segment.
60 | 
61 |   Returns:
62 |     An Example proto.
63 |   """
64 |   # dtype of segment is float32
65 |   raw_segment = segment.tostring()
66 |   raw_labels = np.array(labels, dtype=np.uint8).tostring()
67 | 
68 |   example = tf.train.Example(features=tf.train.Features(feature={
69 |     'raw_labels': _bytes_feature(raw_labels),  # uint8 Tensor (50,)
70 |     'raw_segment': _bytes_feature(raw_segment)  # float32 Tensor (59049,)
71 |   }))
72 | 
73 |   return example
74 | 
75 | 
76 | def _audio_to_segments(filename, sample_rate, n_samples, center=False):
77 |   """Loads an audio file and splits it into N segments.
78 | 
79 |   Args:
80 |     filename: A path to the audio.
81 |     sample_rate: Sampling rate of the audios. If this differs from the
82 |       audio's original sampling rate, the audio is re-sampled.
83 |     n_samples: Number of samples one segment contains.
84 |     center: If True, drop extra samples evenly at both ends so that the
85 |       segments are taken from the center of the audio.
86 | 
87 |   Returns:
88 |     A list of numpy arrays; segments.
89 |   """
90 |   # Load an audio file as a numpy array
91 |   sig = Signal(filename, sample_rate=sample_rate, dtype=np.float32)
92 | 
93 |   total_samples = sig.shape[0]
94 |   n_segment = total_samples // n_samples
95 | 
96 |   if center:
97 |     # Take center samples. This slice keeps the signal intact when the
98 |     # audio divides evenly into segments (remainder == 0).
99 |     remainder = total_samples % n_samples
100 |     sig = sig[remainder // 2: remainder // 2 + n_segment * n_samples]
101 | 
102 |   # Split the signal into segments
103 |   segments = [sig[i * n_samples:(i + 1) * n_samples] for i in range(n_segment)]
104 | 
105 |   return segments
106 | 
107 | 
108 | def audio_to_sequence_example(filename, labels, sample_rate, n_samples):
109 |   """Converts an audio to a SequenceExample proto.
110 | 
111 |   Args:
112 |     filename: A path to the audio.
113 |     labels: A list of labels of the audio.
114 |     sample_rate: Sampling rate of the audios. If this differs from the
115 |       audio's original sampling rate, the audio is re-sampled.
116 |     n_samples: Number of samples one segment contains.
117 | 
118 |   Returns:
119 |     A SequenceExample proto.
120 |   """
121 |   segments = _audio_to_segments(filename,
122 |                                 sample_rate=sample_rate,
123 |                                 n_samples=n_samples)
124 | 
125 |   sequence_example = _segments_to_sequence_example(segments, labels)
126 | 
127 |   return sequence_example
128 | 
129 | 
130 | def audio_to_examples(filename, labels, sample_rate, n_samples):
131 |   """Converts an audio to a list of Example protos.
132 | 
133 |   Args:
134 |     filename: A path to the audio.
135 |     labels: A list of labels of the audio.
136 |     sample_rate: Sampling rate of the audios. If this differs from the
137 |       audio's original sampling rate, the audio is re-sampled.
138 |     n_samples: Number of samples one segment contains.
139 | 
140 |   Returns:
141 |     A list of Example protos.
142 |   """
143 |   segments = _audio_to_segments(filename,
144 |                                 sample_rate=sample_rate,
145 |                                 n_samples=n_samples)
146 | 
147 |   examples = [_segment_to_example(segment, labels) for segment in segments]
148 | 
149 |   return examples
150 | 
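A hypothetical round trip through the helpers above, decoding the first segment of one converted mp3 back into a numpy array (the mp3 path is a placeholder; this requires `madmom` and `ffmpeg`):

```python
import numpy as np

from sample_cnn.data.audio_processing import audio_to_examples

examples = audio_to_examples('song.mp3', labels=[0] * 50,
                             sample_rate=22050, n_samples=59049)

raw = examples[0].features.feature['raw_segment'].bytes_list.value[0]
segment = np.frombuffer(raw, dtype=np.float32)
print(len(examples), segment.shape)  # n_segments, (59049,)
```
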
--------------------------------------------------------------------------------
/sample_cnn/data/mtt/annotation_processing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | 
4 | np.random.seed(0)
5 | 
6 | 
7 | def load_annotations(filename,
8 |                      n_top=50,
9 |                      n_audios_per_shard=100):
10 |   """Reads the annotation file, takes the top N tags, and splits the data.
11 | 
12 |   The result has 54 columns (top-50 tags + [clip_id, mp3_path, split, shard]):
13 | 
14 |   ['choral', 'female voice', 'metal', 'country', 'weird', 'no voice',
15 |    'cello', 'harp', 'beats', 'female vocal', 'male voice', 'dance',
16 |    'new age', 'voice', 'choir', 'classic', 'man', 'solo', 'sitar', 'soft',
17 |    'pop', 'no vocal', 'male vocal', 'woman', 'flute', 'quiet', 'loud',
18 |    'harpsichord', 'no vocals', 'vocals', 'singing', 'male', 'opera',
19 |    'indian', 'female', 'synth', 'vocal', 'violin', 'beat', 'ambient',
20 |    'piano', 'fast', 'rock', 'electronic', 'drums', 'strings', 'techno',
21 |    'slow', 'classical', 'guitar', 'clip_id', 'mp3_path', 'split', 'shard']
22 | 
23 |   NOTE: This excludes audios that have only zero tags. Therefore, the number
24 |   of audios in each split is 15250 / 1529 / 4332 (training / validation / test).
25 | 
26 |   Args:
27 |     filename: A path to the annotation CSV file.
28 |     n_top: Number of the most popular tags to take.
29 |     n_audios_per_shard: Number of audios per shard.
30 | 
31 |   Returns:
32 |     A DataFrame containing information about the audios.
33 | 
34 |     Schema:
35 |       <tag>: 0 or 1
36 |       clip_id: clip_id of the original dataset.
37 |       mp3_path: A path to an mp3 audio file.
38 |       split: A split of the dataset (training / validation / test).
39 |         The split is determined by the audio's directory (0, 1, ..., f).
40 |         The first 12 directories (0 ~ b) are used for training,
41 |         1 (c) for validation, and 3 (d ~ f) for test.
42 |       shard: A shard index of the audio.
43 |   """
44 |   df = pd.read_csv(filename, delimiter='\t')
45 | 
46 |   top50 = (df.drop(['clip_id', 'mp3_path'], axis=1)
47 |            .sum()
48 |            .sort_values()
49 |            .tail(n_top)
50 |            .index
51 |            .tolist())
52 | 
53 |   df = df[top50 + ['clip_id', 'mp3_path']]
54 | 
55 |   # Exclude rows which only have zeros.
56 |   df = df[~(df.iloc[:, :n_top] == 0).all(axis=1)]
57 | 
58 |   def split_by_directory(mp3_path):
59 |     directory = mp3_path.split('/')[0]
60 |     part = int(directory, 16)
61 | 
62 |     if part in range(12):
63 |       return 'train'
64 |     elif part == 12:
65 |       return 'val'
66 |     elif part in range(13, 16):
67 |       return 'test'
68 | 
69 |   df['split'] = df['mp3_path'].apply(
70 |     lambda mp3_path: split_by_directory(mp3_path))
71 | 
72 |   for split in ['train', 'val', 'test']:
73 |     n_audios = sum(df['split'] == split)
74 |     n_shards = n_audios // n_audios_per_shard
75 |     n_remainders = n_audios % n_audios_per_shard
76 | 
77 |     shards = np.tile(np.arange(n_shards), n_audios_per_shard)
78 |     shards = np.concatenate([shards, np.arange(n_remainders)])
79 |     shards = np.random.permutation(shards)
80 | 
81 |     df.loc[df['split'] == split, 'shard'] = shards
82 | 
83 |   df['shard'] = df['shard'].astype(int)
84 | 
85 |   return df
86 | 
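A hypothetical usage sketch for the module above, assuming `annotations_final.csv` sits in the working directory (the expected split sizes come from the docstring):

```python
from sample_cnn.data.mtt.annotation_processing import load_annotations

df = load_annotations('annotations_final.csv')

print(df['split'].value_counts())  # train 15250, test 4332, val 1529
print(df.columns[:50].tolist())    # the top-50 tag names
print(df[['mp3_path', 'split', 'shard']].head())
```
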
""" 2 | 3 | import os 4 | 5 | import pandas as pd 6 | import tensorflow as tf 7 | 8 | from threading import Thread 9 | from queue import Queue 10 | 11 | from madmom.audio.signal import LoadAudioFileError 12 | 13 | from sample_cnn.data.audio_processing import (audio_to_sequence_example, 14 | audio_to_examples) 15 | from sample_cnn.data.mtt.annotation_processing import load_annotations 16 | 17 | tf.flags.DEFINE_string('data_dir', '', 'MagnaTagATune audio dataset directory.') 18 | tf.flags.DEFINE_string('annotation_file', '', 19 | 'A path to CSV annotation file which contains labels.') 20 | tf.flags.DEFINE_string('output_dir', '', 'Output data directory.') 21 | tf.flags.DEFINE_string('output_labels', '', 'Output label file.') 22 | 23 | tf.flags.DEFINE_integer('n_threads', 4, 24 | 'Number of threads to process audios in parallel.') 25 | 26 | tf.flags.DEFINE_integer('n_audios_per_shard', 100, 27 | 'Number of audios per shard.') 28 | 29 | # Audio processing flags 30 | tf.flags.DEFINE_integer('n_top', 50, 'Number of top N tags.') 31 | tf.flags.DEFINE_integer('sample_rate', 22050, 'Sample rate of audio.') 32 | tf.flags.DEFINE_integer('n_samples', 59049, 'Number of samples per segment.') 33 | 34 | FLAGS = tf.flags.FLAGS 35 | 36 | 37 | def _process_audio_files(args_queue): 38 | """Processes and saves audios as TFRecord files in one sub-process. 39 | 40 | Args: 41 | args_queue: A queue contains arguments which consist of: 42 | assigned_anno: A DataFrame which contains information about the audios 43 | that should be process in this sub-process. 44 | sample_rate: Sampling rate of the audios. If the sampling rate is different 45 | with an audio's original sampling rate, then it re-samples the audio. 46 | n_samples: Number of samples one segment contains. 47 | split: Dataset split which is one of 'train', 'val', or 'test'. 48 | shard: Shard index. 49 | n_shards: Number of the entire shards. 50 | """ 51 | while not args_queue.empty(): 52 | (assigned_anno, sample_rate, 53 | n_samples, split, shard, n_shards) = args_queue.get() 54 | 55 | is_test = (split == 'test') 56 | 57 | output_filename_format = ('{}-{:03d}-of-{:03d}.seq.tfrecords' 58 | if is_test else 59 | '{}-{:03d}-of-{:03d}.tfrecords') 60 | output_filename = output_filename_format.format(split, shard, n_shards) 61 | output_file_path = os.path.join(FLAGS.output_dir, output_filename) 62 | 63 | writer = tf.python_io.TFRecordWriter(output_file_path) 64 | 65 | for _, row in assigned_anno.iterrows(): 66 | audio_path = os.path.join(FLAGS.data_dir, row['mp3_path']) 67 | labels = row[:FLAGS.n_top].tolist() 68 | 69 | try: 70 | if is_test: 71 | examples = [audio_to_sequence_example(audio_path, labels, 72 | sample_rate, n_samples)] 73 | else: 74 | examples = audio_to_examples(audio_path, labels, sample_rate, 75 | n_samples) 76 | except LoadAudioFileError: 77 | # There are some broken mp3 files. Ignore it. 78 | print('Cannot load audio "{}". Ignore it.'.format(audio_path)) 79 | continue 80 | 81 | for example in examples: 82 | writer.write(example.SerializeToString()) 83 | 84 | writer.close() 85 | 86 | print('{} audios are written into "{}". {} shards left.' 87 | .format(len(assigned_anno), output_filename, args_queue.qsize())) 88 | 89 | 90 | def _process_dataset(anno, sample_rate, n_samples, n_threads): 91 | """Processes, and saves MagnaTagATune dataset using multi-processes. 92 | 93 | Args: 94 | anno: Annotation DataFrame contains tags, mp3_path, split, and shard. 95 | sample_rate: Sampling rate of the audios. 
96 |       an audio's original sampling rate, the audio is re-sampled.
97 |     n_samples: Number of samples one segment contains.
98 |     n_threads: Number of threads to process the dataset.
99 |   """
100 |   args_queue = Queue()
101 |   split_and_shard_sets = pd.unique([tuple(x) for x in anno[['split', 'shard']].values])
102 | 
103 |   for split, shard in split_and_shard_sets:
104 |     assigned_anno = anno[(anno['split'] == split) & (anno['shard'] == shard)]
105 |     n_shards = anno[anno['split'] == split]['shard'].nunique()
106 | 
107 |     args = (assigned_anno, sample_rate, n_samples, split, shard, n_shards)
108 |     args_queue.put(args)
109 | 
110 |   if n_threads > 1:
111 |     threads = []
112 |     for _ in range(n_threads):
113 |       thread = Thread(target=_process_audio_files, args=[args_queue])
114 |       thread.start()
115 |       threads.append(thread)
116 | 
117 |     for thread in threads:
118 |       thread.join()
119 |   else:
120 |     _process_audio_files(args_queue)
121 | 
122 | 
123 | def _save_tags(tag_list, output_labels):
124 |   """Saves a list of tags to a file.
125 | 
126 |   Args:
127 |     tag_list: The list of tags.
128 |     output_labels: A path to save the list to.
129 |   """
130 |   with open(output_labels, 'w') as f:
131 |     f.write('\n'.join(tag_list))
132 | 
133 | 
134 | def main(unused_argv):
135 |   df = load_annotations(filename=FLAGS.annotation_file,
136 |                         n_top=FLAGS.n_top,
137 |                         n_audios_per_shard=FLAGS.n_audios_per_shard)
138 | 
139 |   if not tf.gfile.IsDirectory(FLAGS.output_dir):
140 |     tf.logging.info('Creating output directory: %s', FLAGS.output_dir)
141 |     tf.gfile.MakeDirs(FLAGS.output_dir)
142 | 
143 |   # Save the top N tags.
144 |   tag_list = df.columns[:FLAGS.n_top].tolist()
145 |   _save_tags(tag_list, FLAGS.output_labels)
146 |   print('Top {} tags written to {}'.format(len(tag_list), FLAGS.output_labels))
147 | 
148 |   df_train = df[df['split'] == 'train']
149 |   df_val = df[df['split'] == 'val']
150 |   df_test = df[df['split'] == 'test']
151 | 
152 |   n_train = len(df_train)
153 |   n_val = len(df_val)
154 |   n_test = len(df_test)
155 |   print('Number of songs for each split: {} / {} / {} '
156 |         '(training / validation / test)'.format(n_train, n_val, n_test))
157 | 
158 |   n_train_shards = df_train['shard'].nunique()
159 |   n_val_shards = df_val['shard'].nunique()
160 |   n_test_shards = df_test['shard'].nunique()
161 |   print('Number of shards for each split: {} / {} / {} '
162 |         '(training / validation / test)'.format(n_train_shards,
163 |                                                  n_val_shards, n_test_shards))
164 | 
165 |   print('Start processing MagnaTagATune using {} threads'
166 |         .format(FLAGS.n_threads))
167 |   _process_dataset(df, FLAGS.sample_rate, FLAGS.n_samples, FLAGS.n_threads)
168 | 
169 |   print()
170 |   print('Done.')
171 | 
172 | 
173 | if __name__ == '__main__':
174 |   tf.app.run()
--------------------------------------------------------------------------------
/sample_cnn/evaluate.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | from sample_cnn.ops import evaluate
4 | 
5 | tf.flags.DEFINE_string('input_file_pattern', '',
6 |                        'File pattern of sharded TFRecord input files.')
7 | tf.flags.DEFINE_string('weights_path', '', 'Path to learned weights.')
8 | 
9 | tf.flags.DEFINE_integer('n_examples', 4332, 'Number of examples to run.')
10 | tf.flags.DEFINE_integer('n_audios_per_shard', 100,
11 |                         'Number of audios per shard.')
12 | 
13 | tf.logging.set_verbosity(tf.logging.INFO)
14 | 
15 | FLAGS = tf.flags.FLAGS
16 | 
17 | 
18 | def main(unused_argv):
19 |   assert
FLAGS.input_file_pattern, '--input_file_pattern is required' 20 | assert FLAGS.weights_path, '--weights_path is required' 21 | 22 | evaluate(FLAGS.input_file_pattern, FLAGS.weights_path, 23 | n_examples=FLAGS.n_examples, 24 | n_audios_per_shard=FLAGS.n_audios_per_shard) 25 | 26 | 27 | if __name__ == '__main__': 28 | tf.app.run() 29 | -------------------------------------------------------------------------------- /sample_cnn/keras_utils/tfrecord_model.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import copy 3 | 4 | import six 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | from keras.models import Model 9 | from keras import optimizers 10 | from keras import losses 11 | from keras import callbacks as cbks 12 | from keras import backend as K 13 | from keras import metrics as metrics_module 14 | from keras.engine.training import _collect_metrics 15 | 16 | 17 | class TFRecordModel(Model): 18 | def __init__(self, inputs, outputs, val_inputs=None, name=None): 19 | super(TFRecordModel, self).__init__(inputs, outputs, name=name) 20 | 21 | # Prepare val_inputs. 22 | if val_inputs is None: 23 | self.val_inputs = [] 24 | elif isinstance(val_inputs, (list, tuple)): 25 | self.val_inputs = list(val_inputs) # Tensor or list of tensors. 26 | else: 27 | self.val_inputs = [val_inputs] 28 | 29 | # Prepare val_outputs. 30 | if val_inputs is None: 31 | self.val_outputs = [] 32 | else: 33 | val_outputs = self(val_inputs) 34 | if isinstance(val_outputs, (list, tuple)): 35 | self.val_outputs = list(val_outputs) # Tensor or list of tensors. 36 | else: 37 | self.val_outputs = [val_outputs] 38 | 39 | def compile_tfrecord(self, optimizer, loss, y, metrics=None, 40 | y_val=None): 41 | """Configures the model for training. 42 | 43 | # Arguments 44 | optimizer: str (name of optimizer) or optimizer object. 45 | See [optimizers](/optimizers). 46 | loss: str (name of objective function) or objective function. 47 | See [losses](/losses). 48 | If the model has multiple outputs, you can use a different loss 49 | on each output by passing a dictionary or a list of losses. 50 | The loss value that will be minimized by the model 51 | will then be the sum of all individual losses. 52 | metrics: list of metrics to be evaluated by the model 53 | during training and testing. 54 | Typically you will use `metrics=['accuracy']`. 55 | To specify different metrics for different outputs of a 56 | multi-output model, you could also pass a dictionary, 57 | such as `metrics={'output_a': 'accuracy'}`. 58 | 59 | # Raises 60 | ValueError: In case of invalid arguments for 61 | `optimizer`, `loss`, `metrics` or `sample_weight_mode`. 62 | """ 63 | loss = loss or {} 64 | self.optimizer = optimizers.get(optimizer) 65 | self.loss = loss 66 | self.sample_weight_mode = None 67 | self.loss_weights = None 68 | self.y_val = y_val 69 | self.constraints = None 70 | 71 | do_validation = bool(len(self.val_inputs) > 0) 72 | if do_validation and y_val is None: 73 | raise ValueError('When you use validation inputs, ' 74 | 'you should provide y_val.') 75 | 76 | # Prepare loss functions. 77 | if isinstance(loss, dict): 78 | for name in loss: 79 | if name not in self.output_names: 80 | raise ValueError('Unknown entry in loss ' 81 | 'dictionary: "' + name + '". 
' 82 | 'Only expected the following keys: ' + 83 | str(self.output_names)) 84 | loss_functions = [] 85 | for name in self.output_names: 86 | if name not in loss: 87 | warnings.warn('Output "' + name + 88 | '" missing from loss dictionary. ' 89 | 'We assume this was done on purpose, ' 90 | 'and we will not be expecting ' 91 | 'any data to be passed to "' + name + 92 | '" during training.', stacklevel=2) 93 | loss_functions.append(losses.get(loss.get(name))) 94 | elif isinstance(loss, list): 95 | if len(loss) != len(self.outputs): 96 | raise ValueError('When passing a list as loss, ' 97 | 'it should have one entry per model outputs. ' 98 | 'The model has ' + str(len(self.outputs)) + 99 | ' outputs, but you passed loss=' + 100 | str(loss)) 101 | loss_functions = [losses.get(l) for l in loss] 102 | else: 103 | loss_function = losses.get(loss) 104 | loss_functions = [loss_function for _ in range(len(self.outputs))] 105 | self.loss_functions = loss_functions 106 | 107 | # Prepare training targets of model. 108 | if isinstance(y, (list, tuple)): 109 | y = list(y) # Tensor or list of tensors. 110 | else: 111 | y = [y] 112 | self.targets = [] 113 | for i in range(len(self.outputs)): 114 | target = y[i] 115 | self.targets.append(target) 116 | 117 | # Prepare validation targets of model. 118 | if isinstance(y_val, (list, tuple)): 119 | y_val = list(y_val) # Tensor or list of tensors. 120 | else: 121 | y_val = [y_val] 122 | self.y_val = y_val 123 | self.val_targets = [] 124 | for i in range(len(self.val_outputs)): 125 | val_target = y_val[i] 126 | self.val_targets.append(val_target) 127 | 128 | # Prepare metrics. 129 | self.metrics = metrics 130 | self.metrics_names = ['loss'] 131 | self.metrics_tensors = [] 132 | self.val_metrics_names = ['loss'] 133 | self.val_metrics_tensors = [] 134 | 135 | # Compute total training loss. 136 | total_loss = None 137 | for i in range(len(self.outputs)): 138 | y_true = self.targets[i] 139 | y_pred = self.outputs[i] 140 | loss_function = loss_functions[i] 141 | val_output_loss = K.mean(loss_function(y_true, y_pred)) 142 | if len(self.outputs) > 1: 143 | self.metrics_tensors.append(val_output_loss) 144 | self.metrics_names.append(self.output_names[i] + '_loss') 145 | if total_loss is None: 146 | total_loss = val_output_loss 147 | else: 148 | total_loss += val_output_loss 149 | if total_loss is None: 150 | if not self.losses: 151 | raise RuntimeError('The model cannot be compiled ' 152 | 'because it has no loss to optimize.') 153 | else: 154 | total_loss = 0. 155 | 156 | # Compute total validation loss. 157 | val_total_loss = None 158 | for i in range(len(self.val_outputs)): 159 | y_true = self.val_targets[i] 160 | y_pred = self.val_outputs[i] 161 | loss_function = loss_functions[i] 162 | val_output_loss = K.mean(loss_function(y_true, y_pred)) 163 | if len(self.outputs) > 1: 164 | self.val_metrics_tensors.append(val_output_loss) 165 | self.val_metrics_names.append(self.output_names[i] + '_val_loss') 166 | if val_total_loss is None: 167 | val_total_loss = val_output_loss 168 | else: 169 | val_total_loss += val_output_loss 170 | if val_total_loss is None: 171 | if not self.losses and do_validation: 172 | raise RuntimeError('The model cannot be compiled ' 173 | 'because it has no loss to optimize.') 174 | else: 175 | val_total_loss = 0. 176 | 177 | # Add regularization penalties 178 | # and other layer-specific losses. 179 | for loss_tensor in self.losses: 180 | total_loss += loss_tensor 181 | val_total_loss += loss_tensor 182 | 183 | # List of same size as output_names. 
184 | # contains tuples (metrics for output, names of metrics). 185 | nested_metrics = _collect_metrics(metrics, self.output_names) 186 | 187 | def append_metric(layer_num, metric_name, metric_tensor): 188 | """Helper function used in loop below.""" 189 | if len(self.output_names) > 1: 190 | metric_name = self.output_layers[layer_num].name + '_' + metric_name 191 | self.metrics_names.append(metric_name) 192 | self.metrics_tensors.append(metric_tensor) 193 | 194 | for i in range(len(self.outputs)): 195 | y_true = self.targets[i] 196 | y_pred = self.outputs[i] 197 | output_metrics = nested_metrics[i] 198 | for metric in output_metrics: 199 | if metric == 'accuracy' or metric == 'acc': 200 | # custom handling of accuracy 201 | # (because of class mode duality) 202 | output_shape = self.internal_output_shapes[i] 203 | acc_fn = None 204 | if (output_shape[-1] == 1 or 205 | self.loss_functions[i] == losses.binary_crossentropy): 206 | # case: binary accuracy 207 | acc_fn = metrics_module.binary_accuracy 208 | elif self.loss_functions[i] == losses.sparse_categorical_crossentropy: 209 | # case: categorical accuracy with sparse targets 210 | acc_fn = metrics_module.sparse_categorical_accuracy 211 | else: 212 | acc_fn = metrics_module.categorical_accuracy 213 | 214 | append_metric(i, 'acc', K.mean(acc_fn(y_true, y_pred))) 215 | else: 216 | metric_fn = metrics_module.get(metric) 217 | metric_result = metric_fn(y_true, y_pred) 218 | metric_result = { 219 | metric_fn.__name__: metric_result 220 | } 221 | for name, tensor in six.iteritems(metric_result): 222 | append_metric(i, name, tensor) 223 | 224 | def append_val_metric(layer_num, metric_name, metric_tensor): 225 | """Helper function used in loop below.""" 226 | if len(self.output_names) > 1: 227 | metric_name = self.output_layers[layer_num].name + '_val_' + metric_name 228 | self.val_metrics_names.append(metric_name) 229 | self.val_metrics_tensors.append(metric_tensor) 230 | 231 | for i in range(len(self.val_outputs)): 232 | y_true = self.val_targets[i] 233 | y_pred = self.val_outputs[i] 234 | output_metrics = nested_metrics[i] 235 | for metric in output_metrics: 236 | if metric == 'accuracy' or metric == 'acc': 237 | # custom handling of accuracy 238 | # (because of class mode duality) 239 | output_shape = self.internal_output_shapes[i] 240 | acc_fn = None 241 | if (output_shape[-1] == 1 or 242 | self.loss_functions[i] == losses.binary_crossentropy): 243 | # case: binary accuracy 244 | acc_fn = metrics_module.binary_accuracy 245 | elif self.loss_functions[i] == losses.sparse_categorical_crossentropy: 246 | # case: categorical accuracy with sparse targets 247 | acc_fn = metrics_module.sparse_categorical_accuracy 248 | else: 249 | acc_fn = metrics_module.categorical_accuracy 250 | 251 | append_val_metric(i, 'acc', K.mean(acc_fn(y_true, y_pred))) 252 | else: 253 | metric_fn = metrics_module.get(metric) 254 | metric_result = metric_fn(y_true, y_pred) 255 | metric_result = { 256 | metric_fn.__name__: metric_result 257 | } 258 | for name, tensor in six.iteritems(metric_result): 259 | append_val_metric(i, name, tensor) 260 | 261 | # Prepare gradient updates and state updates. 262 | self.total_loss = total_loss 263 | self.val_total_loss = val_total_loss 264 | 265 | # Functions for train, test and predict will 266 | # be compiled lazily when required. 267 | # This saves time when the user is not using all functions. 
268 | self.train_function = None 269 | self.val_function = None 270 | self.test_function = None 271 | self.predict_function = None 272 | 273 | # Collected trainable weights and sort them deterministically. 274 | trainable_weights = self.trainable_weights 275 | # Sort weights by name. 276 | if trainable_weights: 277 | trainable_weights.sort(key=lambda x: x.name) 278 | self._collected_trainable_weights = trainable_weights 279 | 280 | def _make_tfrecord_train_function(self): 281 | if not hasattr(self, 'train_function'): 282 | raise RuntimeError('You must compile your model before using it.') 283 | if self.train_function is None: 284 | inputs = [] 285 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int): 286 | inputs += [K.learning_phase()] 287 | 288 | training_updates = self.optimizer.get_updates( 289 | self._collected_trainable_weights, 290 | self.constraints, 291 | self.total_loss) 292 | updates = self.updates + training_updates 293 | # Gets loss and metrics. Updates weights at each call. 294 | self.train_function = K.function(inputs, 295 | [self.total_loss] + self.metrics_tensors, 296 | updates=updates) 297 | 298 | def _make_tfrecord_test_function(self): 299 | if not hasattr(self, 'test_function'): 300 | raise RuntimeError('You must compile your model before using it.') 301 | if self.test_function is None: 302 | inputs = [] 303 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int): 304 | inputs += [K.learning_phase()] 305 | # Return loss and metrics, no gradient updates. 306 | # Does update the network states. 307 | self.test_function = K.function(inputs, 308 | [self.total_loss] + self.metrics_tensors, 309 | updates=self.state_updates) 310 | 311 | def _make_tfrecord_val_function(self): 312 | if not hasattr(self, 'val_function'): 313 | raise RuntimeError('You must compile your model before using it.') 314 | if self.val_function is None: 315 | inputs = [] 316 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int): 317 | inputs += [K.learning_phase()] 318 | # Return loss and metrics, no gradient updates. 319 | # Does update the network states. 320 | self.val_function = K.function( 321 | inputs, 322 | [self.val_total_loss] + self.val_metrics_tensors, 323 | updates=self.state_updates) 324 | 325 | def fit_tfrecord(self, steps_per_epoch, 326 | epochs=1, 327 | verbose=1, 328 | callbacks=None, 329 | validation_steps=None, 330 | initial_epoch=0): 331 | epoch = initial_epoch 332 | 333 | self._make_tfrecord_train_function() 334 | 335 | do_validation = bool(len(self.val_inputs) > 0) 336 | if do_validation and not validation_steps: 337 | raise ValueError('When using a validation batch, ' 338 | 'you must specify a value for ' 339 | '`validation_steps`.') 340 | 341 | # Prepare display labels. 
342 | out_labels = self._get_deduped_metrics_names() 343 | 344 | if do_validation: 345 | callback_metrics = copy.copy(out_labels) + ['val_' + n 346 | for n in out_labels] 347 | else: 348 | callback_metrics = copy.copy(out_labels) 349 | 350 | # prepare callbacks 351 | self.history = cbks.History() 352 | callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history] 353 | if verbose: 354 | callbacks += [cbks.ProgbarLogger(count_mode='steps')] 355 | callbacks = cbks.CallbackList(callbacks) 356 | 357 | # it's possible to callback a different model than self: 358 | if hasattr(self, 'callback_model') and self.callback_model: 359 | callback_model = self.callback_model 360 | else: 361 | callback_model = self 362 | callbacks.set_model(callback_model) 363 | callbacks.set_params({ 364 | 'epochs': epochs, 365 | 'steps': steps_per_epoch, 366 | 'verbose': verbose, 367 | 'do_validation': do_validation, 368 | 'metrics': callback_metrics, 369 | }) 370 | callbacks.on_train_begin() 371 | 372 | if do_validation: 373 | val_sample_weight = None 374 | for cbk in callbacks: 375 | cbk.validation_data = [self.val_inputs, self.y_val, val_sample_weight] 376 | 377 | try: 378 | sess = K.get_session() 379 | coord = tf.train.Coordinator() 380 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 381 | 382 | callback_model.stop_training = False 383 | while epoch < epochs: 384 | callbacks.on_epoch_begin(epoch) 385 | steps_done = 0 386 | batch_index = 0 387 | while steps_done < steps_per_epoch: 388 | # build batch logs 389 | batch_logs = { 390 | 'batch': batch_index, 391 | 'size': self.inputs[0].shape[0].value 392 | } 393 | callbacks.on_batch_begin(batch_index, batch_logs) 394 | 395 | if self.uses_learning_phase and not isinstance(K.learning_phase(), 396 | int): 397 | ins = [1.] 398 | else: 399 | ins = [] 400 | outs = self.train_function(ins) 401 | 402 | if not isinstance(outs, list): 403 | outs = [outs] 404 | for l, o in zip(out_labels, outs): 405 | batch_logs[l] = o 406 | 407 | callbacks.on_batch_end(batch_index, batch_logs) 408 | 409 | # Construct epoch logs. 410 | epoch_logs = {} 411 | batch_index += 1 412 | steps_done += 1 413 | 414 | # Epoch finished. 415 | if steps_done >= steps_per_epoch and do_validation: 416 | val_outs = self._validate_tfrecord(steps=validation_steps) 417 | if not isinstance(val_outs, list): 418 | val_outs = [val_outs] 419 | # Same labels assumed. 420 | for l, o in zip(out_labels, val_outs): 421 | epoch_logs['val_' + l] = o 422 | 423 | callbacks.on_epoch_end(epoch, epoch_logs) 424 | epoch += 1 425 | if callback_model.stop_training: 426 | break 427 | 428 | finally: 429 | # TODO: If you close the queue, you can't open it again.. 430 | # coord.request_stop() 431 | # coord.join(threads) 432 | pass 433 | 434 | callbacks.on_train_end() 435 | return self.history 436 | 437 | def _validate_tfrecord(self, steps): 438 | self._make_tfrecord_val_function() 439 | 440 | steps_done = 0 441 | all_outs = [] 442 | batch_sizes = [] 443 | 444 | while steps_done < steps: 445 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int): 446 | ins = [0.] 
447 |       else:
448 |         ins = []
449 |       outs = self.val_function(ins)
450 |       if len(outs) == 1:
451 |         outs = outs[0]
452 | 
453 |       batch_size = self.val_inputs[0].shape[0].value
454 |       all_outs.append(outs)
455 | 
456 |       steps_done += 1
457 |       batch_sizes.append(batch_size)
458 | 
459 |     if not isinstance(outs, list):
460 |       return np.average(np.asarray(all_outs),
461 |                         weights=batch_sizes)
462 |     else:
463 |       averages = []
464 |       for i in range(len(outs)):
465 |         averages.append(np.average([out[i] for out in all_outs],
466 |                                    weights=batch_sizes))
467 |       return averages
468 | 
469 |   def evaluate_tfrecord(self, steps):
470 |     """Evaluates the model on data read from the TFRecord queues.
471 | 
472 |     Unlike `Model.evaluate`, this takes no input arrays: the model
473 |     is evaluated on the input tensors it was constructed with.
474 | 
475 |     # Arguments
476 |         steps: Total number of steps (batches of samples)
477 |             to evaluate before stopping. The input queues created
478 |             by `batch_inputs` loop over the data indefinitely, so
479 |             evaluation stops only after this many batches have
480 |             been processed.
481 | 
482 |     # Returns
483 |         Scalar test loss (if the model has a single output and no metrics)
484 |         or list of scalars (if the model has multiple outputs
485 |         and/or metrics). The attribute `model.metrics_names` will give you
486 |         the display labels for the scalar outputs.
487 | 
488 |     # Raises
489 |         RuntimeError: If the model has not been compiled with
490 |             `compile_tfrecord` before calling this method.
491 |     """
492 |     self._make_tfrecord_test_function()
493 | 
494 |     steps_done = 0
495 |     all_outs = []
496 |     batch_sizes = []
497 | 
498 |     try:
499 |       sess = K.get_session()
500 |       coord = tf.train.Coordinator()
501 |       threads = tf.train.start_queue_runners(sess=sess, coord=coord)
502 | 
503 |       while steps_done < steps:
504 | 
505 |         if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
506 |           ins = [0.]
507 |         else:
508 |           ins = []
509 |         outs = self.test_function(ins)
510 |         if len(outs) == 1:
511 |           outs = outs[0]
512 | 
513 |         batch_size = self.inputs[0].shape[0].value
514 |         all_outs.append(outs)
515 | 
516 |         steps_done += 1
517 |         batch_sizes.append(batch_size)
518 | 
519 |     finally:
520 |       # TODO: If you close the queue, you can't open it again..
521 |       # if stop_queue_runners:
522 |       #   coord.request_stop()
523 |       #   coord.join(threads)
524 |       pass
525 | 
526 |     if not isinstance(outs, list):
527 |       return np.average(np.asarray(all_outs),
528 |                         weights=batch_sizes)
529 |     else:
530 |       averages = []
531 |       for i in range(len(outs)):
532 |         averages.append(np.average([out[i] for out in all_outs],
533 |                                    weights=batch_sizes))
534 |       return averages
535 | 
536 |   def _make_tfrecord_predict_function(self):
537 |     if not hasattr(self, 'predict_function'):
538 |       self.predict_function = None
539 |     if self.predict_function is None:
540 |       if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
541 |         inputs = [K.learning_phase()]
542 |       else:
543 |         inputs = []
544 |       # Gets network outputs. Does not update weights.
545 |       # Does update the network states.
546 |       self.predict_function = K.function(inputs,
547 |                                          self.outputs,
548 |                                          updates=self.state_updates)
549 | 
550 |   def predict_tfrecord(self, x_batch):
551 |     if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
552 |       ins = [0.]
553 |     else:
554 |       ins = []
555 |     self._make_tfrecord_predict_function()
556 | 
557 |     try:
558 |       sess = K.get_session()
559 |       coord = tf.train.Coordinator()
560 |       threads = tf.train.start_queue_runners(sess=sess, coord=coord)
561 | 
562 |       outputs = self.predict_function(ins)
563 | 
564 |     finally:
565 |       # TODO: If you close the queue, you can't open it again..
566 | # if stop_queue_runners: 567 | # coord.request_stop() 568 | # coord.join(threads) 569 | pass 570 | 571 | if len(outputs) == 1: 572 | return outputs[0] 573 | return outputs 574 | -------------------------------------------------------------------------------- /sample_cnn/model.py: -------------------------------------------------------------------------------- 1 | from keras.layers import (Conv1D, MaxPool1D, BatchNormalization, 2 | Dense, Dropout, Activation, Flatten, Reshape, Input) 3 | 4 | from sample_cnn.keras_utils.tfrecord_model import TFRecordModel 5 | 6 | 7 | class SampleCNN(TFRecordModel): 8 | def __init__(self, segments, 9 | val_segments=None, 10 | n_outputs=50, 11 | activation='relu', 12 | kernel_initializer='he_uniform', 13 | dropout_rate=0.5, 14 | extra_inputs=None, 15 | extra_outputs=None): 16 | # 59049 17 | segments = Input(tensor=segments) 18 | net = Reshape([-1, 1])(segments) 19 | # 59049 X 1 20 | net = Conv1D(128, 3, strides=3, padding='valid', 21 | kernel_initializer=kernel_initializer)(net) 22 | net = BatchNormalization()(net) 23 | net = Activation(activation)(net) 24 | # 19683 X 128 25 | net = Conv1D(128, 3, padding='same', 26 | kernel_initializer=kernel_initializer)(net) 27 | net = BatchNormalization()(net) 28 | net = Activation(activation)(net) 29 | net = MaxPool1D(3)(net) 30 | # 6561 X 128 31 | net = Conv1D(128, 3, padding='same', 32 | kernel_initializer=kernel_initializer)(net) 33 | net = BatchNormalization()(net) 34 | net = Activation(activation)(net) 35 | net = MaxPool1D(3)(net) 36 | # 2187 X 128 37 | net = Conv1D(256, 3, padding='same', 38 | kernel_initializer=kernel_initializer)(net) 39 | net = BatchNormalization()(net) 40 | net = Activation(activation)(net) 41 | net = MaxPool1D(3)(net) 42 | # 729 X 256 43 | net = Conv1D(256, 3, padding='same', 44 | kernel_initializer=kernel_initializer)(net) 45 | net = BatchNormalization()(net) 46 | net = Activation(activation)(net) 47 | net = MaxPool1D(3)(net) 48 | # 243 X 256 49 | net = Conv1D(256, 3, padding='same', 50 | kernel_initializer=kernel_initializer)(net) 51 | net = BatchNormalization()(net) 52 | net = Activation(activation)(net) 53 | net = MaxPool1D(3)(net) 54 | # 81 X 256 55 | net = Conv1D(256, 3, padding='same', 56 | kernel_initializer=kernel_initializer)(net) 57 | net = BatchNormalization()(net) 58 | net = Activation(activation)(net) 59 | net = MaxPool1D(3)(net) 60 | # 27 X 256 61 | net = Conv1D(256, 3, padding='same', 62 | kernel_initializer=kernel_initializer)(net) 63 | net = BatchNormalization()(net) 64 | net = Activation(activation)(net) 65 | net = MaxPool1D(3)(net) 66 | # 9 X 256 67 | net = Conv1D(256, 3, padding='same', 68 | kernel_initializer=kernel_initializer)(net) 69 | net = BatchNormalization()(net) 70 | net = Activation(activation)(net) 71 | net = MaxPool1D(3)(net) 72 | # 3 X 256 73 | net = Conv1D(512, 3, padding='same', 74 | kernel_initializer=kernel_initializer)(net) 75 | net = BatchNormalization()(net) 76 | net = Activation(activation)(net) 77 | net = MaxPool1D(3)(net) 78 | # 1 X 512 79 | net = Conv1D(512, 1, padding='same', 80 | kernel_initializer=kernel_initializer)(net) 81 | net = BatchNormalization()(net) 82 | net = Activation(activation)(net) 83 | # 1 X 512 84 | net = Dropout(dropout_rate)(net) 85 | net = Flatten()(net) 86 | 87 | logits = Dense(units=n_outputs, activation='sigmoid')(net) 88 | 89 | if extra_inputs is None: 90 | extra_inputs = [] 91 | if extra_outputs is None: 92 | extra_outputs = [] 93 | 94 | if not isinstance(extra_inputs, list): 95 | extra_inputs = [extra_inputs] 96 | 
inputs = [segments] + extra_inputs
97 | 
98 |     if not isinstance(extra_outputs, list):
99 |       extra_outputs = [extra_outputs]
100 |     outputs = [logits] + extra_outputs
101 | 
102 |     super(SampleCNN, self).__init__(inputs=inputs,
103 |                                     val_inputs=val_segments,
104 |                                     outputs=outputs)
105 | 
--------------------------------------------------------------------------------
/sample_cnn/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from sample_cnn.ops.batch_inputs import *
2 | from sample_cnn.ops.evaluation import *
--------------------------------------------------------------------------------
/sample_cnn/ops/batch_inputs.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | 
4 | def _read_example(filename_queue, n_labels=50, n_samples=59049):
5 |   reader = tf.TFRecordReader()
6 |   _, serialized_example = reader.read(filename_queue)
7 |   features = tf.parse_single_example(
8 |     serialized_example,
9 |     features={
10 |       'raw_labels': tf.FixedLenFeature([], tf.string),
11 |       'raw_segment': tf.FixedLenFeature([], tf.string)
12 |     })
13 | 
14 |   segment = tf.decode_raw(features['raw_segment'], tf.float32)
15 |   segment.set_shape([n_samples])
16 | 
17 |   labels = tf.decode_raw(features['raw_labels'], tf.uint8)
18 |   labels.set_shape([n_labels])
19 |   labels = tf.cast(labels, tf.float32)
20 | 
21 |   return segment, labels
22 | 
23 | 
24 | def _read_sequence_example(filename_queue,
25 |                            n_labels=50, n_samples=59049, n_segments=10):
26 |   reader = tf.TFRecordReader()
27 |   _, serialized_example = reader.read(filename_queue)
28 |   context, sequence = tf.parse_single_sequence_example(
29 |     serialized_example,
30 |     context_features={
31 |       'raw_labels': tf.FixedLenFeature([], dtype=tf.string)
32 |     },
33 |     sequence_features={
34 |       'raw_segments': tf.FixedLenSequenceFeature([], dtype=tf.string)
35 |     })
36 | 
37 |   segments = tf.decode_raw(sequence['raw_segments'], tf.float32)
38 |   segments.set_shape([n_segments, n_samples])
39 | 
40 |   labels = tf.decode_raw(context['raw_labels'], tf.uint8)
41 |   labels.set_shape([n_labels])
42 |   labels = tf.cast(labels, tf.float32)
43 | 
44 |   return segments, labels
45 | 
46 | 
47 | def batch_inputs(file_pattern,
48 |                  batch_size,
49 |                  is_training,
50 |                  examples_per_shard,
51 |                  is_sequence=False,
52 |                  input_queue_capacity_factor=16,
53 |                  n_read_threads=4,
54 |                  shard_queue_name='filename_queue',
55 |                  example_queue_name='input_queue'):
56 |   data_files = []
57 |   for pattern in file_pattern.split(","):
58 |     data_files.extend(tf.gfile.Glob(pattern))
59 | 
60 |   assert len(data_files) > 0, (
61 |     'Found no input files matching {}'.format(file_pattern))
62 | 
63 |   tf.logging.info('Prefetching values from %d files matching %s',
64 |                   len(data_files), file_pattern)
65 | 
66 |   if is_sequence:
67 |     read_example_fn = _read_sequence_example
68 |   else:
69 |     read_example_fn = _read_example
70 | 
71 |   if is_training:
72 |     filename_queue = tf.train.string_input_producer(
73 |       data_files, shuffle=True, capacity=16, name=shard_queue_name)
74 | 
75 |     # examples_per_shard
76 |     #   = examples_per_song * n_training_songs / n_training_shards
77 |     #   = 10 * 15250 / 152
78 |     #   = 1003
79 |     #
80 |     # example_size = 59049 * 4bytes = 232KB
81 |     #
82 |     # queue_size
83 |     #   = examples_per_shard * input_queue_capacity_factor * example_size
84 |     #   = 1003 * 16 * 232KB = 3.7GB
85 |     min_queue_examples = examples_per_shard * input_queue_capacity_factor
86 |     capacity = min_queue_examples + 3 * batch_size
87 | 
88 |     example_list = [read_example_fn(filename_queue)
89 |                     for _ in range(n_read_threads)]
90 | 
91 |     segment, label = tf.train.shuffle_batch_join(
92 |       example_list,
93 |       batch_size=batch_size,
94 |       capacity=capacity,
95 |       min_after_dequeue=min_queue_examples,
96 |       name='shuffle_' + example_queue_name,
97 |     )
98 | 
99 |     return segment, label
100 | 
101 |   else:
102 |     filename_queue = tf.train.string_input_producer(
103 |       data_files, shuffle=False, capacity=1, name=shard_queue_name)
104 | 
105 |     segment, label = read_example_fn(filename_queue)
106 | 
107 |     capacity = examples_per_shard + 2 * batch_size
108 |     segment_batch, label_batch = tf.train.batch(
109 |       [segment, label],
110 |       batch_size=batch_size,
111 |       num_threads=1,
112 |       capacity=capacity,
113 |       name='fifo_' + example_queue_name)
114 | 
115 |     return segment_batch, label_batch
116 | 
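A hypothetical smoke test for the training pipeline above (TF 1.x queue runners; the file pattern is a placeholder, and a small `input_queue_capacity_factor` keeps the shuffle-queue warm-up short):

```python
import tensorflow as tf

from sample_cnn.ops import batch_inputs

segment, label = batch_inputs('/path/to/tfrecord/train-*.tfrecords',
                              batch_size=23,
                              is_training=True,
                              examples_per_shard=1000,
                              input_queue_capacity_factor=1)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    x, y = sess.run([segment, label])
    print(x.shape, y.shape)  # expected: (23, 59049) (23, 50)
    coord.request_stop()
    coord.join(threads)
```
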
--------------------------------------------------------------------------------
/sample_cnn/ops/evaluation.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | 
4 | from keras.layers import Input
5 | from sklearn.metrics import roc_auc_score, log_loss
6 | 
7 | from sample_cnn.ops import batch_inputs
8 | from sample_cnn.model import SampleCNN
9 | 
10 | 
11 | def evaluate(input_file_pattern, weights_path, n_examples,
12 |              n_audios_per_shard=100, print_progress=True):
13 |   # audio: [1, 10, 59049]
14 |   # labels: [1, 50]
15 |   audio, labels = batch_inputs(
16 |     file_pattern=input_file_pattern,
17 |     batch_size=1,
18 |     is_training=False,
19 |     is_sequence=True,
20 |     n_read_threads=1,
21 |     examples_per_shard=n_audios_per_shard,
22 |     shard_queue_name='filename_queue',
23 |     example_queue_name='input_queue')
24 | 
25 |   segments = tf.squeeze(audio)
26 | 
27 |   labels = tf.squeeze(labels)
28 |   labels = Input(tensor=labels)
29 | 
30 |   model = SampleCNN(segments=segments,
31 |                     extra_inputs=labels,
32 |                     extra_outputs=labels)
33 | 
34 |   print('Load weights from "{}".'.format(weights_path))
35 |   model.load_weights(weights_path)
36 | 
37 |   n_classes = labels.shape[0].value
38 |   all_y_pred = np.empty([0, n_classes], dtype=np.float32)
39 |   all_y_true = np.empty([0, n_classes], dtype=np.float32)
40 | 
41 |   print('Start evaluation.')
42 |   for i in range(n_examples):
43 |     y_pred_segments, y_true = model.predict_tfrecord(segments)
44 | 
45 |     y_pred = np.mean(y_pred_segments, axis=0)
46 | 
47 |     y_pred = np.expand_dims(y_pred, 0)
48 |     y_true = np.expand_dims(y_true, 0)
49 | 
50 |     all_y_pred = np.append(all_y_pred, y_pred, axis=0)
51 |     all_y_true = np.append(all_y_true, y_true, axis=0)
52 | 
53 |     if print_progress and i % (n_examples // 100) == 0 and i:
54 |       print('Evaluated [{:04d}/{:04d}].'.format(i + 1, n_examples))
55 | 
56 |   losses = []
57 |   for i in range(n_classes):
58 |     class_y_true = all_y_true[:, i]
59 |     class_y_pred = all_y_pred[:, i]
60 |     if np.sum(class_y_true) != 0:
61 |       class_loss = log_loss(class_y_true, class_y_pred)
62 |       losses.append(class_loss)
63 | 
64 |   loss = np.mean(losses)
65 |   print('@ binary cross entropy loss: {}'.format(loss))
66 | 
67 |   roc_auc = roc_auc_score(all_y_true, all_y_pred, average='macro')
68 |   print('@ ROC AUC score: {}'.format(roc_auc))
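A hypothetical direct call of the function above, mirroring the flags defined in `sample_cnn/evaluate.py` (both paths are placeholders; the weights filename follows the `best_weights-e{epoch}-l{val_loss}.hdf5` pattern used by `train.py`):

```python
from sample_cnn.ops import evaluate

evaluate(input_file_pattern='/path/to/tfrecord/test-*.seq.tfrecords',
         weights_path='/path/to/train_dir/best_weights-e010-l0.1234.hdf5',
         n_examples=4332,
         n_audios_per_shard=100)
```
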
10 | EarlyStopping, CSVLogger) 11 | from keras.optimizers import SGD 12 | 13 | from sample_cnn.model import SampleCNN 14 | from sample_cnn.ops import batch_inputs, evaluate 15 | 16 | tf.logging.set_verbosity(tf.logging.INFO) 17 | 18 | # Paths. 19 | tf.flags.DEFINE_string('train_input_file_pattern', '', 20 | 'File pattern of TFRecord input files.') 21 | tf.flags.DEFINE_string('val_input_file_pattern', '', 22 | 'File pattern of validation TFRecord input files.') 23 | tf.flags.DEFINE_string('test_input_file_pattern', '', 24 | 'File pattern of test TFRecord input files.') 25 | tf.flags.DEFINE_string('train_dir', '', 26 | 'Directory where to write event logs and checkpoints.') 27 | tf.flags.DEFINE_string('checkpoint_prefix', 'best_weights', 28 | 'Prefix of the checkpoint filename.') 29 | 30 | # Batch options. 31 | tf.flags.DEFINE_integer('batch_size', 23, 'Batch size.') 32 | tf.flags.DEFINE_integer('n_audios_per_shard', 100, 33 | 'Number of audios per shard.') 34 | tf.flags.DEFINE_integer('n_segments_per_audio', 10, 35 | 'Number of segments per audio.') 36 | tf.flags.DEFINE_integer('n_train_examples', 15250, 37 | 'Number of examples in training dataset.') 38 | tf.flags.DEFINE_integer('n_val_examples', 1529, 39 | 'Number of examples in validation dataset.') 40 | tf.flags.DEFINE_integer('n_test_examples', 4332, 41 | 'Number of examples in test dataset.') 42 | 43 | tf.flags.DEFINE_integer('n_read_threads', 4, 'Number of example reader threads.') 44 | 45 | # Learning options. 46 | tf.flags.DEFINE_float('initial_learning_rate', 0.01, 'Initial learning rate.') 47 | tf.flags.DEFINE_float('momentum', 0.9, 'Momentum.') 48 | tf.flags.DEFINE_float('dropout_rate', 0.5, 'Dropout rate.') 49 | tf.flags.DEFINE_float('global_lr_decay', 0.2, 'Global learning rate decay.') 50 | tf.flags.DEFINE_float('local_lr_decay', 1e-6, 'Local learning rate decay.') 51 | 52 | # Training options. 53 | tf.flags.DEFINE_integer('patience', 3, 'Patience (in epochs) for early stopping.') 54 | tf.flags.DEFINE_integer('max_trains', 5, 'Maximum number of training stages.') 55 | tf.flags.DEFINE_integer('initial_stage', 0, 56 | 'The stage at which to start training.') 57 | 58 | FLAGS = tf.app.flags.FLAGS 59 | 60 | 61 | def make_path(*paths): 62 | path = os.path.join(*[str(path) for path in paths]) 63 | path = os.path.realpath(path) 64 | return path 65 | 66 | 67 | def calculate_steps(n_examples, n_segments, batch_size): 68 | steps = 1. 
* n_examples * n_segments 69 | steps = math.ceil(steps / batch_size) 70 | return steps 71 | 72 | 73 | def find_best_checkpoint(*dirs): 74 | best_checkpoint_path = None 75 | best_epoch = -1 76 | best_val_loss = 1e+10 77 | for dir in dirs: 78 | checkpoint_paths = glob('{}/{}*'.format(dir, FLAGS.checkpoint_prefix)) 79 | for checkpoint_path in checkpoint_paths: 80 | epoch = int(re.findall('e\d+', checkpoint_path)[0][1:]) 81 | val_loss = float(re.findall('l\d\.\d+', checkpoint_path)[0][1:]) 82 | 83 | if val_loss < best_val_loss: 84 | best_checkpoint_path = checkpoint_path 85 | best_epoch = epoch 86 | best_val_loss = val_loss 87 | 88 | return best_checkpoint_path, best_epoch, best_val_loss 89 | 90 | 91 | def train(initial_lr, 92 | stage_train_dir, 93 | checkpoint_path_to_load=None, 94 | initial_epoch=0): 95 | if not tf.gfile.Exists(stage_train_dir): 96 | tf.logging.info('Creating training directory: %s', stage_train_dir) 97 | tf.gfile.MakeDirs(stage_train_dir) 98 | 99 | x_train_batch, y_train_batch = batch_inputs( 100 | file_pattern=FLAGS.train_input_file_pattern, 101 | batch_size=FLAGS.batch_size, 102 | is_training=True, 103 | n_read_threads=FLAGS.n_read_threads, 104 | examples_per_shard=FLAGS.n_audios_per_shard * FLAGS.n_segments_per_audio, 105 | shard_queue_name='train_filename_queue', 106 | example_queue_name='train_input_queue') 107 | 108 | x_val_batch, y_val_batch = batch_inputs( 109 | file_pattern=FLAGS.val_input_file_pattern, 110 | batch_size=FLAGS.batch_size, 111 | is_training=False, 112 | n_read_threads=1, 113 | examples_per_shard=FLAGS.n_audios_per_shard * FLAGS.n_segments_per_audio, 114 | shard_queue_name='val_filename_queue', 115 | example_queue_name='val_input_queue') 116 | 117 | # Create a model. 118 | model = SampleCNN(segments=x_train_batch, 119 | val_segments=x_val_batch, 120 | dropout_rate=FLAGS.dropout_rate) 121 | 122 | # Load weights from a checkpoint if exists. 123 | if checkpoint_path_to_load: 124 | print('Load weights from "{}".'.format(checkpoint_path_to_load)) 125 | model.load_weights(checkpoint_path_to_load) 126 | 127 | # Setup an optimizer. 128 | optimizer = SGD(lr=initial_lr, 129 | momentum=FLAGS.momentum, 130 | decay=FLAGS.local_lr_decay, 131 | nesterov=True) 132 | 133 | # Compile the model. 134 | model.compile_tfrecord(y=y_train_batch, 135 | y_val=y_val_batch, 136 | loss='binary_crossentropy', 137 | optimizer=optimizer) 138 | 139 | # Setup a TensorBoard callback. 140 | tensor_board = TensorBoard(log_dir=stage_train_dir) 141 | 142 | # Use early stopping mechanism. 143 | early_stopping = EarlyStopping(monitor='val_loss', patience=FLAGS.patience) 144 | 145 | # Setup a checkpointer. 146 | checkpoint_path = make_path( 147 | stage_train_dir, 148 | FLAGS.checkpoint_prefix + '-e{epoch:03d}-l{val_loss:.4f}.hdf5') 149 | checkpointer = ModelCheckpoint( 150 | filepath=checkpoint_path, 151 | monitor='val_loss', 152 | save_best_only=True) 153 | 154 | # Setup a CSV logger. 155 | csv_logger = CSVLogger(filename=make_path(stage_train_dir, 'training.csv'), 156 | append=True) 157 | 158 | # Kick-off the training! 
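# For intuition, with the default flags above one epoch sweeps every segment of
# every training clip once (a worked example, not additional configuration):
#   train_steps = ceil(15250 * 10 / 23) = 6631
#   val_steps   = ceil(1529  * 10 / 23) = 665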
159 | train_steps = calculate_steps(n_examples=FLAGS.n_train_examples, 160 | n_segments=FLAGS.n_segments_per_audio, 161 | batch_size=FLAGS.batch_size) 162 | val_steps = calculate_steps(n_examples=FLAGS.n_val_examples, 163 | n_segments=FLAGS.n_segments_per_audio, 164 | batch_size=FLAGS.batch_size) 165 | 166 | model.fit_tfrecord(epochs=100, 167 | initial_epoch=initial_epoch, 168 | steps_per_epoch=train_steps, 169 | validation_steps=val_steps, 170 | callbacks=[tensor_board, early_stopping, 171 | checkpointer, csv_logger]) 172 | 173 | # The end of the stage. Evaluate on the test set. 174 | best_ckpt_path, *_ = find_best_checkpoint(stage_train_dir) 175 | print('The end of the stage. ' 176 | 'Start evaluation on the test set using checkpoint "{}".' 177 | .format(best_ckpt_path)) 178 | 179 | evaluate(input_file_pattern=FLAGS.test_input_file_pattern, 180 | weights_path=best_ckpt_path, 181 | n_examples=FLAGS.n_test_examples, 182 | n_audios_per_shard=FLAGS.n_audios_per_shard, 183 | print_progress=False) 184 | 185 | 186 | def main(unused_argv): 187 | assert FLAGS.train_dir, '--train_dir is required' 188 | assert FLAGS.train_input_file_pattern, '--train_input_file_pattern is required' 189 | assert FLAGS.val_input_file_pattern, '--val_input_file_pattern is required' 190 | assert FLAGS.test_input_file_pattern, '--test_input_file_pattern is required' 191 | 192 | # Print all flags. 193 | print('@@ Flags') 194 | for key, value in FLAGS.__flags.items(): 195 | print('{}={}'.format(key, value)) 196 | 197 | for i in range(FLAGS.initial_stage, FLAGS.max_trains): 198 | stage_train_dir = make_path(FLAGS.train_dir, i) 199 | previous_stage_train_dir = make_path(FLAGS.train_dir, i - 1) 200 | next_stage_train_dir = make_path(FLAGS.train_dir, i + 1) 201 | 202 | # Skip this stage if a training directory for the next stage already exists. 203 | if os.path.isdir(next_stage_train_dir): 204 | continue 205 | 206 | # Set up the initial learning rate for the stage. 207 | decay = FLAGS.global_lr_decay ** i 208 | learning_rate = FLAGS.initial_learning_rate * decay 209 | 210 | # Create a directory for the stage. 211 | os.makedirs(stage_train_dir, exist_ok=True) 212 | 213 | # Find the best checkpoint from which to load weights. 214 | (ckpt_path, ckpt_epoch, ckpt_val_loss) = find_best_checkpoint( 215 | stage_train_dir, previous_stage_train_dir) 216 | 217 | print('\n@@ Start training stage {:02d}: lr={}, train_dir={}' 218 | .format(i, learning_rate, stage_train_dir)) 219 | if ckpt_path: 220 | print('Found a trained model: epoch={}, val_loss={}, path={}' 221 | .format(ckpt_epoch, ckpt_val_loss, ckpt_path)) 222 | else: 223 | print('No trained model found.') 224 | 225 | train(initial_lr=learning_rate, 226 | stage_train_dir=stage_train_dir, 227 | checkpoint_path_to_load=ckpt_path, 228 | initial_epoch=ckpt_epoch + 1) 229 | 230 | print('\nDone.') 231 | 232 | 233 | if __name__ == '__main__': 234 | tf.app.run() 235 | -------------------------------------------------------------------------------- /scripts/build_mtt.sh.template: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # Please fill in the blanks! 5 | BASE_DIR= 6 | N_THREADS=4 7 | ENV_NAME= # if you don't use an environment, just leave it blank.
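# Example fill-in (hypothetical values; `source activate` below assumes a
# conda environment):
#   BASE_DIR=/data/magnatagatune
#   N_THREADS=4
#   ENV_NAME=sample_cnn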
8 | 9 | 10 | DATA_DIR="${BASE_DIR}/raw" 11 | ANNOTATION_FILE="${BASE_DIR}/annotations_final.csv" 12 | OUTPUT_DIR="${BASE_DIR}/tfrecord" 13 | OUTPUT_LABELS="${BASE_DIR}/labels.txt" 14 | LOG_FILE="${BASE_DIR}/build_mtt.log" 15 | 16 | if [ -n "${ENV_NAME}" ]; then 17 | source activate "${ENV_NAME}" 18 | fi 19 | 20 | export PYTHONPATH='.' 21 | export PYTHONUNBUFFERED=1 22 | 23 | nohup python sample_cnn/data/mtt/build_mtt.py \ 24 | --data_dir="${DATA_DIR}" \ 25 | --annotation_file="${ANNOTATION_FILE}" \ 26 | --output_dir="${OUTPUT_DIR}" \ 27 | --output_labels="${OUTPUT_LABELS}" \ 28 | --n_threads="${N_THREADS}" > "${LOG_FILE}" & 29 | 30 | tail -F "${LOG_FILE}" 31 | -------------------------------------------------------------------------------- /scripts/evaluate.sh.template: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # Please fill in the blanks! 5 | BASE_DIR= 6 | WEIGHTS_PATH= 7 | ENV_NAME= # if you don't use an environment, just leave it blank. 8 | 9 | 10 | INPUT_FILE_PATTERN="${BASE_DIR}/tfrecord/test-???-of-???.seq.tfrecords" 11 | 12 | if [ -n "${ENV_NAME}" ]; then 13 | source activate "${ENV_NAME}" 14 | fi 15 | 16 | export PYTHONPATH='.' 17 | export PYTHONUNBUFFERED=1 18 | 19 | python sample_cnn/evaluate.py \ 20 | --input_file_pattern="${INPUT_FILE_PATTERN}" \ 21 | --weights_path="${WEIGHTS_PATH}" 22 | -------------------------------------------------------------------------------- /scripts/train.sh.template: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # Please fill in the blanks! 5 | BASE_DIR= 6 | TRAIN_DIR= 7 | ENV_NAME= # if you don't use an environment, just leave it blank. 8 | 9 | 10 | TRAIN_INPUT_FILE_PATTERN="${BASE_DIR}/tfrecord/train-???-of-???.tfrecords" 11 | VAL_INPUT_FILE_PATTERN="${BASE_DIR}/tfrecord/val-???-of-???.tfrecords" 12 | TEST_INPUT_FILE_PATTERN="${BASE_DIR}/tfrecord/test-???-of-???.seq.tfrecords" 13 | 14 | LOG_FILE="${TRAIN_DIR}/train.log" 15 | 16 | mkdir -p "${TRAIN_DIR}" 17 | 18 | if [ -n "${ENV_NAME}" ]; then 19 | source activate "${ENV_NAME}" 20 | fi 21 | 22 | export PYTHONPATH='.' 
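# PYTHONPATH='.' lets `python sample_cnn/train.py` import the `sample_cnn`
# package from the repository root; PYTHONUNBUFFERED=1 disables stdout
# buffering so the `tail -F` below shows training progress in real time.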
23 | export PYTHONUNBUFFERED=1 24 | 25 | nohup python sample_cnn/train.py \ 26 | --train_input_file_pattern="${TRAIN_INPUT_FILE_PATTERN}" \ 27 | --val_input_file_pattern="${VAL_INPUT_FILE_PATTERN}" \ 28 | --test_input_file_pattern="${TEST_INPUT_FILE_PATTERN}" \ 29 | --train_dir="${TRAIN_DIR}" >> "${LOG_FILE}" & 30 | 31 | tail -F "${LOG_FILE}" 32 | -------------------------------------------------------------------------------- /tests/tfrecord_model_test_mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | from keras.layers import Input, Conv2D, Dense, Flatten 7 | from keras.datasets import mnist 8 | from keras.callbacks import ModelCheckpoint 9 | 10 | from sample_cnn.keras_utils.tfrecord_model import TFRecordModel 11 | 12 | 13 | def data_to_tfrecord(images, labels, filename): 14 | def _int64_feature(value): 15 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 16 | 17 | def _bytes_feature(value): 18 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 19 | 20 | # Save the data into a TFRecord file (skipped if the file already exists). 21 | if not os.path.isfile(filename): 22 | num_examples = images.shape[0] 23 | 24 | rows = images.shape[1] 25 | cols = images.shape[2] 26 | depth = images.shape[3] 27 | 28 | print('Writing', filename) 29 | writer = tf.python_io.TFRecordWriter(filename) 30 | for index in range(num_examples): 31 | image_raw = images[index].tostring() 32 | example = tf.train.Example(features=tf.train.Features(feature={ 33 | 'height': _int64_feature(rows), 34 | 'width': _int64_feature(cols), 35 | 'depth': _int64_feature(depth), 36 | 'label': _int64_feature(int(labels[index])), 37 | 'image_raw': _bytes_feature(image_raw)})) 38 | writer.write(example.SerializeToString()) 39 | writer.close() 40 | else: 41 | print('TFRecord already exists:', filename) 42 | 43 | 44 | def read_and_decode(filename, one_hot=True, n_classes=None): 45 | """ Return tensors that read from a TFRecord file """ 46 | filename_queue = tf.train.string_input_producer([filename]) 47 | reader = tf.TFRecordReader() 48 | _, serialized_example = reader.read(filename_queue) 49 | features = tf.parse_single_example(serialized_example, 50 | features={ 51 | 'label': tf.FixedLenFeature([], 52 | tf.int64), 53 | 'image_raw': tf.FixedLenFeature([], 54 | tf.string), 55 | }) 56 | # More image distortion (augmentation) could be applied here for training data. 57 | img = tf.decode_raw(features['image_raw'], tf.uint8) 58 | img.set_shape([28 * 28]) 59 | img = tf.reshape(img, [28, 28, 1]) 60 | img = tf.cast(img, tf.float32) * (1. 
/ 255) - 0.5 61 | 62 | label = tf.cast(features['label'], tf.int32) 63 | if one_hot and n_classes: 64 | label = tf.one_hot(label, n_classes) 65 | 66 | return img, label 67 | 68 | 69 | def build_net(inputs): 70 | net = Conv2D(32, 3, strides=2, activation='relu')(inputs) 71 | net = Conv2D(32, 3, strides=2, activation='relu')(net) 72 | net = Flatten()(net) 73 | net = Dense(128, activation='relu')(net) 74 | net = Dense(10, activation='softmax')(net) 75 | return net 76 | 77 | 78 | if __name__ == '__main__': 79 | TMP_DIR = '_test_tmp_dir_' 80 | TRAIN_TFRECORD_PATH = TMP_DIR + '/mnist_train.tfrecords' 81 | VAL_TFRECORD_PATH = TMP_DIR + '/mnist_val.tfrecords' 82 | TEST_TFRECORD_PATH = TMP_DIR + '/mnist_test.tfrecords' 83 | WEIGHTS_PATH = TMP_DIR + '/weights.hdf5' 84 | 85 | batch_size = 64 86 | n_train = 50000 87 | 88 | # Create temp dir 89 | os.makedirs(TMP_DIR, exist_ok=True) 90 | 91 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 92 | 93 | x_train = np.expand_dims(x_train, -1) 94 | x_test = np.expand_dims(x_test, -1) 95 | 96 | x_val = x_train[n_train:] 97 | y_val = y_train[n_train:] 98 | 99 | x_train = x_train[:n_train] 100 | y_train = y_train[:n_train] 101 | 102 | data_to_tfrecord(images=x_train, 103 | labels=y_train, 104 | filename=TRAIN_TFRECORD_PATH) 105 | data_to_tfrecord(images=x_val, 106 | labels=y_val, 107 | filename=VAL_TFRECORD_PATH) 108 | data_to_tfrecord(images=x_test, 109 | labels=y_test, 110 | filename=TEST_TFRECORD_PATH) 111 | 112 | x_train, y_train = read_and_decode(filename=TRAIN_TFRECORD_PATH, 113 | one_hot=True, 114 | n_classes=10) 115 | x_val, y_val = read_and_decode(filename=VAL_TFRECORD_PATH, 116 | one_hot=True, 117 | n_classes=10) 118 | x_test, y_test = read_and_decode(filename=TEST_TFRECORD_PATH, 119 | one_hot=True, 120 | n_classes=10) 121 | 122 | x_train_batch, y_train_batch = tf.train.shuffle_batch([x_train, y_train], 123 | batch_size=batch_size, 124 | capacity=2000, 125 | min_after_dequeue=1000, 126 | name='train_batch') 127 | x_val_batch, y_val_batch = tf.train.shuffle_batch([x_val, y_val], 128 | batch_size=batch_size, 129 | capacity=2000, 130 | min_after_dequeue=1000, 131 | name='val_batch') 132 | x_test_batch, y_test_batch = tf.train.shuffle_batch([x_test, y_test], 133 | batch_size=batch_size, 134 | capacity=2000, 135 | min_after_dequeue=1000, 136 | name='test_batch') 137 | 138 | x_train_input = Input(tensor=x_train_batch) 139 | x_val_input = Input(tensor=x_val_batch) 140 | x_test_input = Input(tensor=x_test_batch) 141 | logits = build_net(x_train_input) 142 | 143 | model = TFRecordModel(inputs=x_train_input, 144 | outputs=logits, 145 | val_inputs=x_val_input) 146 | 147 | model.compile_tfrecord(optimizer='adam', 148 | loss='categorical_crossentropy', 149 | metrics=['accuracy'], 150 | y=y_train_batch, 151 | y_val=y_val_batch) 152 | 153 | model.summary() 154 | 155 | model_checkpoint = ModelCheckpoint( 156 | filepath=TMP_DIR + '/best_weights.{epoch:02d}-{val_loss:.4f}.hdf5', 157 | save_best_only=True) 158 | 159 | model.fit_tfrecord(epochs=10, 160 | verbose=2, 161 | steps_per_epoch=n_train // batch_size + 1, 162 | validation_steps=10000 // batch_size + 1, 163 | callbacks=[model_checkpoint]) 164 | 165 | model.save_weights(WEIGHTS_PATH) 166 | model.load_weights(WEIGHTS_PATH) 167 | 168 | test_outputs = model(x_test_input) 169 | test_model = TFRecordModel(inputs=x_test_input, 170 | outputs=test_outputs) 171 | test_model.compile_tfrecord(optimizer='adam', 172 | loss='categorical_crossentropy', 173 | metrics=['accuracy'], 174 | y=y_test_batch) 175 | 176 | loss, 
accuracy = test_model.evaluate_tfrecord(steps=10000 // batch_size + 1) 177 | 178 | print('accuracy={}, loss={}'.format(accuracy, loss)) 179 | print('Done.') 180 | --------------------------------------------------------------------------------
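A quick way to smoke-test TFRecordModel end-to-end is to run the MNIST test above from the repository root; a minimal invocation sketch, mirroring the PYTHONPATH convention of the shell templates:

    PYTHONPATH=. python tests/tfrecord_model_test_mnist.py

The script downloads MNIST via keras.datasets, writes train/val/test .tfrecords files into _test_tmp_dir_, trains for 10 epochs with checkpointing, and prints the final test loss and accuracy.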