├── .gitignore
├── LICENSE
├── README.md
├── requirements.txt
├── sample_cnn
│   ├── __init__.py
│   ├── data
│   │   ├── audio_processing.py
│   │   └── mtt
│   │       ├── annotation_processing.py
│   │       └── build_mtt.py
│   ├── evaluate.py
│   ├── keras_utils
│   │   └── tfrecord_model.py
│   ├── model.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── batch_inputs.py
│   │   └── evaluation.py
│   └── train.py
├── scripts
│   ├── build_mtt.sh.template
│   ├── evaluate.sh.template
│   └── train.sh.template
└── tests
    └── tfrecord_model_test_mnist.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *tmp*
2 | scripts/*.sh
3 |
4 | # Created by .ignore support plugin (hsz.mobi)
5 | ### Python template
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | env/
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *,cover
51 | .hypothesis/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # dotenv
84 | .env
85 |
86 | # virtualenv
87 | .venv/
88 | venv/
89 | ENV/
90 |
91 | # Spyder project settings
92 | .spyderproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 | ### JetBrains template
97 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
98 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
99 |
100 | # User-specific stuff:
101 | .idea/workspace.xml
102 | .idea/tasks.xml
103 |
104 | # Sensitive or high-churn files:
105 | .idea/dataSources/
106 | .idea/dataSources.ids
107 | .idea/dataSources.xml
108 | .idea/dataSources.local.xml
109 | .idea/sqlDataSources.xml
110 | .idea/dynamic.xml
111 | .idea/uiDesigner.xml
112 |
113 | # Gradle:
114 | .idea/gradle.xml
115 | .idea/libraries
116 |
117 | # Mongo Explorer plugin:
118 | .idea/mongoSettings.xml
119 |
120 | ## File-based project format:
121 | *.iws
122 |
123 | ## Plugin-specific files:
124 |
125 | # IntelliJ
126 | /out/
127 |
128 | # mpeltonen/sbt-idea plugin
129 | .idea_modules/
130 |
131 | # JIRA plugin
132 | atlassian-ide-plugin.xml
133 |
134 | # Crashlytics plugin (for Android Studio and IntelliJ)
135 | com_crashlytics_export_strings.xml
136 | crashlytics.properties
137 | crashlytics-build.properties
138 | fabric.properties
139 | ### Linux template
140 | *~
141 |
142 | # temporary files which can be created if a process still has a handle open of a deleted file
143 | .fuse_hidden*
144 |
145 | # KDE directory preferences
146 | .directory
147 |
148 | # Linux trash folder which might appear on any partition or disk
149 | .Trash-*
150 |
151 | # .nfs files are created when an open file is removed but is still being accessed
152 | .nfs*
153 | ### macOS template
154 | *.DS_Store
155 | .AppleDouble
156 | .LSOverride
157 |
158 | # Icon must end with two \r
159 | Icon
160 |
161 |
162 | # Thumbnails
163 | ._*
164 |
165 | # Files that might appear in the root of a volume
166 | .DocumentRevisions-V100
167 | .fseventsd
168 | .Spotlight-V100
169 | .TemporaryItems
170 | .Trashes
171 | .VolumeIcon.icns
172 | .com.apple.timemachine.donotpresent
173 |
174 | # Directories potentially created on remote AFP share
175 | .AppleDB
176 | .AppleDesktop
177 | Network Trash Folder
178 | Temporary Items
179 | .apdisk
180 | ### VirtualEnv template
181 | # Virtualenv
182 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
183 | .Python
184 | [Bb]in
185 | [Ii]nclude
186 | [Ll]ib
187 | [Ll]ib64
188 | [Ll]ocal
189 | #[Ss]cripts
190 | pyvenv.cfg
191 | .venv
192 | pip-selfcheck.json
193 |
194 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Tae-jun Kim
2 |
3 | Permission is hereby granted, free of charge, to any person
4 | obtaining a copy of this software and associated documentation
5 | files (the "Software"), to deal in the Software without
6 | restriction, including without limitation the rights to use,
7 | copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the
9 | Software is furnished to do so, subject to the following
10 | conditions:
11 |
12 | The above copyright notice and this permission notice shall be
13 | included in all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
17 | OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
18 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
19 | HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
20 | WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
22 | OTHER DEALINGS IN THE SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sample CNN
2 | ***A TensorFlow implementation of "Sample-level Deep Convolutional
3 | Neural Networks for Music Auto-tagging Using Raw Waveforms"***
4 |
5 | This is a [TensorFlow][1] implementation of "[*Sample-level Deep
6 | Convolutional Neural Networks for Music Auto-tagging Using Raw
7 | Waveforms*][10]" using [Keras][11]. This repository implements only the
8 | best model of the paper (the model described in Table 1, with m=3 and n=9).
9 |
10 |
11 | ## Table of contents
12 | * [Prerequisites](#prerequisites)
13 | * [Preparing MagnaTagATune (MTT) dataset](#preparing-magnatagatune-mtt-dataset)
14 | * [Preprocessing the MTT dataset](#preprocessing-the-mtt-dataset)
15 | * [Training a model from scratch](#training-a-model-from-scratch)
16 | * [Evaluating a model](#evaluating-a-model)
17 |
18 |
19 |
20 | ## Prerequisites
21 | * Python 3.5 and the required packages
22 | * `ffmpeg` (required for `madmom`)
23 |
24 | ### Installing required Python packages
25 | ```sh
26 | pip install -r requirements.txt
27 | pip install madmom
28 | ```
29 | The `madmom` package has an install-time dependency, so it should be
30 | installed after the packages in `requirements.txt`.
31 |
32 | This will install the required packages:
33 | * [tensorflow][1] **1.0.1** (there is a known issue with 1.1.0)
34 | * [keras][11]
35 | * [pandas][2]
36 | * [scikit-learn][3]
37 | * [madmom][4]
38 | * [numpy][5]
39 | * [scipy][6]
40 | * [cython][7]
41 | * [h5py][12]
42 |
43 | ### Installing ffmpeg
44 | `ffmpeg` is required for `madmom`.
45 |
46 | #### macOS (with Homebrew):
47 | ```sh
48 | brew install ffmpeg
49 | ```
50 |
51 | #### Ubuntu:
52 | ```sh
53 | add-apt-repository ppa:mc3man/trusty-media
54 | apt-get update
55 | apt-get dist-upgrade
56 | apt-get install ffmpeg
57 | ```
58 |
59 | #### CentOS:
60 | ```sh
61 | yum install epel-release
62 | rpm --import http://li.nux.ro/download/nux/RPM-GPG-KEY-nux.ro
63 | rpm -Uvh http://li.nux.ro/download/nux/dextop/el ... noarch.rpm
64 | yum install ffmpeg
65 | ```
66 |
67 |
68 |
69 | ## Preparing MagnaTagATune (MTT) dataset
70 | Download audio data and tag annotations from [here][8]. Then you should
71 | see 3 `.zip` files and 1 `.csv` file:
72 | ```sh
73 | mp3.zip.001
74 | mp3.zip.002
75 | mp3.zip.003
76 | annotations_final.csv
77 | ```
78 |
79 | To extract the `.zip` files, merge them and then unzip the result (as referenced [here][9]):
80 | ```sh
81 | cat mp3.zip.* > mp3_all.zip
82 | unzip mp3_all.zip
83 | ```
84 |
85 | You should see 16 directories named `0` to `f`. Typically, `0 ~ b` are
86 | used for training, `c` for validation, and `d ~ f` for testing.
87 |
88 | To make your life easier, place them in a directory as below:
89 | ```sh
90 | ├── annotations_final.csv
91 | └── raw
92 | ├── 0
93 | ├── 1
94 | ├── ...
95 | └── f
96 | ```
97 |
98 | We will call this directory `BASE_DIR`. Preparing the MTT dataset is done!
99 |
100 |
101 |
102 | ## Preprocessing the MTT dataset
103 | This section describes the required preprocessing of the MTT
104 | dataset. Note that this requires about `57G` of storage space.
105 |
106 | The preprocessing does the following:
107 | * Selects the top 50 tags in `annotations_final.csv`
108 | * Splits the dataset into training, validation, and test sets
109 | * Segments the raw audio files into segments of `59049` samples
110 | * Converts the segments to TFRecord format
111 |
112 | To run the preprocessing, copy a shell template and edit the copy:
113 | ```sh
114 | cp scripts/build_mtt.sh.template scripts/build_mtt.sh
115 | vi scripts/build_mtt.sh
116 | ```
117 |
118 | You should fill in the environment variables:
119 | * `BASE_DIR` the directory that contains the `annotations_final.csv`
120 |   file and the `raw` directory
121 | * `N_PROCESSES` the number of processes to use; the preprocessing uses
122 |   multiprocessing
123 | * `ENV_NAME` (optional) if you use `virtualenv` or `conda` to create a
124 |   separate environment, write your environment name
125 |
126 | Below is an example:
127 | ```sh
128 | BASE_DIR="/path/to/mtt/basedir"
129 | N_PROCESSES=4
130 | ENV_NAME="sample_cnn"
131 | ```
132 |
133 | And run it:
134 | ```sh
135 | ./scripts/build_mtt.sh
136 | ```
137 |
138 | The script will **automatically run a process in the background** and
139 | **tail the output** that the process prints. This will take from a few
140 | minutes to an hour, depending on your machine.
141 |
142 | The converted TFRecord files will be located in
143 | `${BASE_DIR}/tfrecord`. Now, the structure of your `BASE_DIR` should
144 | look like this:
145 | ```sh
146 | ├── annotations_final.csv
147 | ├── build_mtt.log
148 | ├── labels.txt
149 | ├── raw
150 | │ ├── 0
151 | │ ├── ...
152 | │ └── f
153 | └── tfrecord
154 | ├── test-000-of-036.seq.tfrecords
155 | ├── ...
156 | ├── test-035-of-036.seq.tfrecords
157 | ├── train-000-of-128.tfrecords
158 | ├── ...
159 | ├── train-127-of-128.tfrecords
160 | ├── val-000-of-012.seq.tfrecords
161 | ├── ...
162 | └── val-011-of-012.seq.tfrecords
163 | ```
164 |
165 |
166 |
167 | ## Training a model from scratch
168 | To train a model from scratch, copy a shell template and edit the
169 | copy as you did above:
170 | ```sh
171 | cp scripts/train.sh.template scripts/train.sh
172 | vi scripts/train.sh
173 | ```
174 |
175 | And fill in the environment variables:
176 | * `BASE_DIR` the directory that contains the `tfrecord` directory
177 | * `TRAIN_DIR` where to save your trained model and the summaries for
178 |   visualizing your training using TensorBoard
179 | * `ENV_NAME` (optional) if you use `virtualenv` or `conda` to create a
180 |   separate environment, write your environment name
181 |
182 | Below is an example:
183 | ```sh
184 | BASE_DIR="/path/to/mtt/basedir"
185 | TRAIN_DIR="/path/to/save/outputs"
186 | ENV_NAME="sample_cnn"
187 | ```
188 |
189 | Let's kick off the training:
190 | ```sh
191 | ./scripts/train.sh
192 | ```
193 |
194 | The script will **automatically run a process in the background** and
195 | **tail the output** that the process prints.
196 |
197 |
198 |
199 | ## Evaluating a model
200 | Copy the evaluation shell script template and edit the copy:
201 | ```sh
202 | cp scripts/evaluate.sh.template scripts/evaluate.sh
203 | vi scripts/evaluate.sh
204 | ```
205 |
206 | Fill in the environment variables:
207 | * `BASE_DIR` the directory that contains the `tfrecord` directory
208 | * `CHECKPOINT_DIR` where you saved your model (your `TRAIN_DIR` when training)
209 | * `ENV_NAME` (optional) if you use `virtualenv` or `conda` to create a
210 |   separate environment, write your environment name
211 |
212 | By default, the script evaluates the best model, not the latest one.
213 | If you want to evaluate the latest model, pass `--best=False` as an
214 | option.
215 |
216 | [1]: https://www.tensorflow.org/
217 | [2]: http://pandas.pydata.org/
218 | [3]: http://scikit-learn.org/
219 | [4]: https://madmom.readthedocs.io/en/latest/
220 | [5]: http://www.numpy.org/
221 | [6]: https://www.scipy.org
222 | [7]: http://cython.org/
223 | [8]: http://mirg.city.ac.uk/codeapps/the-magnatagatune-dataset
224 | [9]: https://github.com/keunwoochoi/magnatagatune-list
225 | [10]: https://arxiv.org/abs/1703.01789
226 | [11]: https://keras.io/
227 | [12]: http://www.h5py.org
228 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Cython==0.25.2
2 | numpy==1.12.1
3 | pandas==0.19.2
4 | scikit-learn==0.18.1
5 | scipy==0.19.0
6 | tensorflow-gpu==1.1.0
7 | h5py==2.7.1
8 |
--------------------------------------------------------------------------------
/sample_cnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tae-jun/sample-cnn/a8581d6aaa4233e981ffa83a74c7f88e1c450647/sample_cnn/__init__.py
--------------------------------------------------------------------------------
/sample_cnn/data/audio_processing.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | from madmom.audio.signal import Signal
5 |
6 | FLAGS = tf.flags.FLAGS
7 |
8 |
9 | def _int64_feature(value):
10 | """Wrapper for inserting int64 features into Example proto."""
11 | if not isinstance(value, list):
12 | value = [value]
13 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
14 |
15 |
16 | def _bytes_feature(value):
17 | """Wrapper for inserting bytes features into Example proto."""
18 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
19 |
20 |
21 | def _bytes_feature_list(values):
22 | """Wrapper for inserting a bytes FeatureList into a SequenceExample proto."""
23 | return tf.train.FeatureList(feature=[_bytes_feature(v) for v in values])
24 |
25 |
26 | def _segments_to_sequence_example(segments, labels):
27 | """Converts a list of segments to a SequenceExample proto.
28 |
29 | Args:
30 | segments: A list of segments.
31 | labels: A list of labels of the segments.
32 |
33 | Returns:
34 | A SequenceExample proto.
35 | """
36 | raw_segments = [segment.tostring() for segment in segments]
37 | raw_labels = np.array(labels, dtype=np.uint8).tostring()
38 |
39 | context = tf.train.Features(feature={
40 | 'raw_labels': _bytes_feature(raw_labels) # uint8 Tensor (50,)
41 | })
42 |
43 | feature_lists = tf.train.FeatureLists(feature_list={
44 | # list of float32 Tensor (59049,)
45 | 'raw_segments': _bytes_feature_list(raw_segments)
46 | })
47 |
48 | sequence_example = tf.train.SequenceExample(
49 | context=context, feature_lists=feature_lists)
50 |
51 | return sequence_example
52 |
53 |
54 | def _segment_to_example(segment, labels):
55 | """Converts a list of segments to a list of Example protos.
56 |
57 | Args:
58 | segments: A list of segments.
59 | labels: A list of labels of the segments.
60 |
61 | Returns:
62 | A list of Example protos.
63 | """
64 | # dtype of segment is float32
65 | raw_segment = segment.tostring()
66 | raw_labels = np.array(labels, dtype=np.uint8).tostring()
67 |
68 | example = tf.train.Example(features=tf.train.Features(feature={
69 | 'raw_labels': _bytes_feature(raw_labels), # uint8 Tensor (50,)
70 | 'raw_segment': _bytes_feature(raw_segment) # float32 Tensor (59049,)
71 | }))
72 |
73 | return example
74 |
75 |
76 | def _audio_to_segments(filename, sample_rate, n_samples, center=False):
77 | """Loads, and splits an audio into N segments.
78 |
79 | Args:
80 | filename: A path to the audio.
81 | sample_rate: Sampling rate of the audios. If the sampling rate is different
82 | with an audio's original sampling rate, then it re-samples the audio.
83 | n_samples: Number of samples one segment contains.
84 |
85 | Returns:
86 | A list of numpy arrays; segments.
87 | """
88 | # Load an audio file as a numpy array
89 | sig = Signal(filename, sample_rate=sample_rate, dtype=np.float32)
90 |
91 | total_samples = sig.shape[0]
92 | n_segment = total_samples // n_samples
93 |
94 |   if center:
95 |     # Take center samples (avoid an empty slice when the remainder is 0)
96 |     remainder = total_samples % n_samples
97 |     sig = sig[remainder // 2: remainder // 2 + n_segment * n_samples]
98 |
99 | # Split the signal into segments
100 | segments = [sig[i * n_samples:(i + 1) * n_samples] for i in range(n_segment)]
101 |
102 | return segments
103 |
104 |
105 | def audio_to_sequence_example(filename, labels, sample_rate, n_samples):
106 | """Converts an audio to a SequenceExample proto.
107 |
108 | Args:
109 | filename: A path to the audio.
110 | labels: A list of labels of the audio.
111 | sample_rate: Sampling rate of the audios. If the sampling rate is different
112 | with an audio's original sampling rate, then it re-samples the audio.
113 | n_samples: Number of samples one segment contains.
114 |
115 | Returns:
116 | A SequenceExample proto.
117 | """
118 | segments = _audio_to_segments(filename,
119 | sample_rate=sample_rate,
120 | n_samples=n_samples)
121 |
122 | sequence_example = _segments_to_sequence_example(segments, labels)
123 |
124 | return sequence_example
125 |
126 |
127 | def audio_to_examples(filename, labels, sample_rate, n_samples):
128 | """Converts an audio to a list of Example protos.
129 |
130 | Args:
131 | filename: A path to the audio.
132 | labels: A list of labels of the audio.
133 | sample_rate: Sampling rate of the audios. If the sampling rate is different
134 | with an audio's original sampling rate, then it re-samples the audio.
135 | n_samples: Number of samples one segment contains.
136 |
137 | Returns:
138 | A list of Example protos.
139 | """
140 | segments = _audio_to_segments(filename,
141 | sample_rate=sample_rate,
142 | n_samples=n_samples)
143 |
144 | examples = [_segment_to_example(segment, labels) for segment in segments]
145 |
146 | return examples
147 |
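148 | # A minimal usage sketch (illustrative only; the mp3 path and the all-zero
149 | # label vector below are hypothetical):
150 | #
151 | #   examples = audio_to_examples('/path/to/audio.mp3',
152 | #                                labels=[0] * 50,
153 | #                                sample_rate=22050,
154 | #                                n_samples=59049)
155 | #   writer = tf.python_io.TFRecordWriter('/tmp/audio.tfrecords')
156 | #   for example in examples:
157 | #     writer.write(example.SerializeToString())
158 | #   writer.close()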
--------------------------------------------------------------------------------
/sample_cnn/data/mtt/annotation_processing.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 | np.random.seed(0)
5 |
6 |
7 | def load_annotations(filename,
8 | n_top=50,
9 | n_audios_per_shard=100):
10 | """Reads annotation file, takes top N tags, and splits data samples.
11 |
12 | Results 54 (top50_tags + [clip_id, mp3_path, split, shard]) columns:
13 |
14 | ['choral', 'female voice', 'metal', 'country', 'weird', 'no voice',
15 | 'cello', 'harp', 'beats', 'female vocal', 'male voice', 'dance',
16 | 'new age', 'voice', 'choir', 'classic', 'man', 'solo', 'sitar', 'soft',
17 | 'pop', 'no vocal', 'male vocal', 'woman', 'flute', 'quiet', 'loud',
18 | 'harpsichord', 'no vocals', 'vocals', 'singing', 'male', 'opera',
19 | 'indian', 'female', 'synth', 'vocal', 'violin', 'beat', 'ambient',
20 | 'piano', 'fast', 'rock', 'electronic', 'drums', 'strings', 'techno',
21 | 'slow', 'classical', 'guitar', 'clip_id', 'mp3_path', 'split', 'shard']
22 |
23 | NOTE: This will exclude audios which have only zero-tags. Therefore, number of
24 | each split will be 15250 / 1529 / 4332 (training / validation / test).
25 |
26 | Args:
27 | filename: A path to annotation CSV file.
28 | n_top: Number of the most popular tags to take.
29 | n_audios_per_shard: Number of audios per shard.
30 |
31 | Returns:
32 | A DataFrame contains information of audios.
33 |
34 | Schema:
35 | : 0 or 1
36 | clip_id: clip_id of the original dataset
37 | mp3_path: A path to a mp3 audio file.
38 | split: A split of dataset (training / validation / test).
39 | The split is determined by its directory (0, 1, ... , f).
40 | First 12 directories (0 ~ b) are used for training,
41 | 1 (c) for validation, and 3 (d ~ f) for test.
42 | shard: A shard index of the audio.
43 | """
44 | df = pd.read_csv(filename, delimiter='\t')
45 |
46 | top50 = (df.drop(['clip_id', 'mp3_path'], axis=1)
47 | .sum()
48 | .sort_values()
49 | .tail(n_top)
50 | .index
51 | .tolist())
52 |
53 | df = df[top50 + ['clip_id', 'mp3_path']]
54 |
55 | # Exclude rows which only have zeros.
56 |   df = df[~(df.iloc[:, :n_top] == 0).all(axis=1)]
57 |
58 | def split_by_directory(mp3_path):
59 | directory = mp3_path.split('/')[0]
60 | part = int(directory, 16)
61 |
62 | if part in range(12):
63 | return 'train'
64 |     elif part == 12:
65 | return 'val'
66 | elif part in range(13, 16):
67 | return 'test'
68 |
69 | df['split'] = df['mp3_path'].apply(
70 | lambda mp3_path: split_by_directory(mp3_path))
71 |
72 | for split in ['train', 'val', 'test']:
73 | n_audios = sum(df['split'] == split)
74 | n_shards = n_audios // n_audios_per_shard
75 | n_remainders = n_audios % n_audios_per_shard
76 |
77 | shards = np.tile(np.arange(n_shards), n_audios_per_shard)
78 | shards = np.concatenate([shards, np.arange(n_remainders)])
79 | shards = np.random.permutation(shards)
80 |
81 | df.loc[df['split'] == split, 'shard'] = shards
82 |
83 | df['shard'] = df['shard'].astype(int)
84 |
85 | return df
86 |
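87 | # A minimal usage sketch (illustrative only; the path is a placeholder):
88 | #
89 | #   anno = load_annotations('/path/to/annotations_final.csv',
90 | #                           n_top=50,
91 | #                           n_audios_per_shard=100)
92 | #   print(anno['split'].value_counts())  # 15250 / 4332 / 1529
93 | #   print(anno.shape)                    # (21111, 54)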
--------------------------------------------------------------------------------
/sample_cnn/data/mtt/build_mtt.py:
--------------------------------------------------------------------------------
1 | """Converts MagnaTagATune dataset to TFRecord files. """
2 |
3 | import os
4 |
5 | import pandas as pd
6 | import tensorflow as tf
7 |
8 | from threading import Thread
9 | from queue import Queue
10 |
11 | from madmom.audio.signal import LoadAudioFileError
12 |
13 | from sample_cnn.data.audio_processing import (audio_to_sequence_example,
14 | audio_to_examples)
15 | from sample_cnn.data.mtt.annotation_processing import load_annotations
16 |
17 | tf.flags.DEFINE_string('data_dir', '', 'MagnaTagATune audio dataset directory.')
18 | tf.flags.DEFINE_string('annotation_file', '',
19 | 'A path to CSV annotation file which contains labels.')
20 | tf.flags.DEFINE_string('output_dir', '', 'Output data directory.')
21 | tf.flags.DEFINE_string('output_labels', '', 'Output label file.')
22 |
23 | tf.flags.DEFINE_integer('n_threads', 4,
24 | 'Number of threads to process audios in parallel.')
25 |
26 | tf.flags.DEFINE_integer('n_audios_per_shard', 100,
27 | 'Number of audios per shard.')
28 |
29 | # Audio processing flags
30 | tf.flags.DEFINE_integer('n_top', 50, 'Number of top N tags.')
31 | tf.flags.DEFINE_integer('sample_rate', 22050, 'Sample rate of audio.')
32 | tf.flags.DEFINE_integer('n_samples', 59049, 'Number of samples per segment.')
33 |
34 | FLAGS = tf.flags.FLAGS
35 |
36 |
37 | def _process_audio_files(args_queue):
38 | """Processes and saves audios as TFRecord files in one sub-process.
39 |
40 | Args:
41 | args_queue: A queue contains arguments which consist of:
42 | assigned_anno: A DataFrame which contains information about the audios
43 | that should be process in this sub-process.
44 | sample_rate: Sampling rate of the audios. If the sampling rate is different
45 | with an audio's original sampling rate, then it re-samples the audio.
46 | n_samples: Number of samples one segment contains.
47 | split: Dataset split which is one of 'train', 'val', or 'test'.
48 | shard: Shard index.
49 | n_shards: Number of the entire shards.
50 | """
51 | while not args_queue.empty():
52 | (assigned_anno, sample_rate,
53 | n_samples, split, shard, n_shards) = args_queue.get()
54 |
55 | is_test = (split == 'test')
56 |
57 | output_filename_format = ('{}-{:03d}-of-{:03d}.seq.tfrecords'
58 | if is_test else
59 | '{}-{:03d}-of-{:03d}.tfrecords')
60 | output_filename = output_filename_format.format(split, shard, n_shards)
61 | output_file_path = os.path.join(FLAGS.output_dir, output_filename)
62 |
63 | writer = tf.python_io.TFRecordWriter(output_file_path)
64 |
65 | for _, row in assigned_anno.iterrows():
66 | audio_path = os.path.join(FLAGS.data_dir, row['mp3_path'])
67 | labels = row[:FLAGS.n_top].tolist()
68 |
69 | try:
70 | if is_test:
71 | examples = [audio_to_sequence_example(audio_path, labels,
72 | sample_rate, n_samples)]
73 | else:
74 | examples = audio_to_examples(audio_path, labels, sample_rate,
75 | n_samples)
76 | except LoadAudioFileError:
77 |         # There are some broken mp3 files. Ignore them.
78 | print('Cannot load audio "{}". Ignore it.'.format(audio_path))
79 | continue
80 |
81 | for example in examples:
82 | writer.write(example.SerializeToString())
83 |
84 | writer.close()
85 |
86 | print('{} audios are written into "{}". {} shards left.'
87 | .format(len(assigned_anno), output_filename, args_queue.qsize()))
88 |
89 |
90 | def _process_dataset(anno, sample_rate, n_samples, n_threads):
91 | """Processes, and saves MagnaTagATune dataset using multi-processes.
92 |
93 | Args:
94 | anno: Annotation DataFrame contains tags, mp3_path, split, and shard.
95 | sample_rate: Sampling rate of the audios. If the sampling rate is different
96 | with an audio's original sampling rate, then it re-samples the audio.
97 | n_samples: Number of samples one segment contains.
98 | n_threads: Number of threads to process the dataset.
99 | """
100 | args_queue = Queue()
101 | split_and_shard_sets = pd.unique([tuple(x) for x in anno[['split', 'shard']].values])
102 |
103 | for split, shard in split_and_shard_sets:
104 | assigned_anno = anno[(anno['split'] == split) & (anno['shard'] == shard)]
105 | n_shards = anno[anno['split'] == split]['shard'].nunique()
106 |
107 | args = (assigned_anno, sample_rate, n_samples, split, shard, n_shards)
108 | args_queue.put(args)
109 |
110 |   if n_threads > 1:
111 |     threads = []
112 |     for _ in range(n_threads):
113 | thread = Thread(target=_process_audio_files, args=[args_queue])
114 | thread.start()
115 | threads.append(thread)
116 |
117 | for thread in threads:
118 | thread.join()
119 | else:
120 | _process_audio_files(args_queue)
121 |
122 |
123 | def _save_tags(tag_list, output_labels):
124 | """Saves a list of tags to a file.
125 |
126 | Args:
127 | tag_list: The list of tags.
128 | output_labels: A path to save the list.
129 | """
130 | with open(output_labels, 'w') as f:
131 | f.write('\n'.join(tag_list))
132 |
133 |
134 | def main(unused_argv):
135 | df = load_annotations(filename=FLAGS.annotation_file,
136 | n_top=FLAGS.n_top,
137 | n_audios_per_shard=FLAGS.n_audios_per_shard)
138 |
139 | if not tf.gfile.IsDirectory(FLAGS.output_dir):
140 | tf.logging.info('Creating output directory: %s', FLAGS.output_dir)
141 | tf.gfile.MakeDirs(FLAGS.output_dir)
142 |
143 | # Save top N tags
144 | tag_list = df.columns[:FLAGS.n_top].tolist()
145 | _save_tags(tag_list, FLAGS.output_labels)
146 | print('Top {} tags written to {}'.format(len(tag_list), FLAGS.output_labels))
147 |
148 | df_train = df[df['split'] == 'train']
149 | df_val = df[df['split'] == 'val']
150 | df_test = df[df['split'] == 'test']
151 |
152 | n_train = len(df_train)
153 | n_val = len(df_val)
154 | n_test = len(df_test)
155 | print('Number of songs for each split: {} / {} / {} '
156 | '(training / validation / test)'.format(n_train, n_val, n_test))
157 |
158 | n_train_shards = df_train['shard'].nunique()
159 | n_val_shards = df_val['shard'].nunique()
160 | n_test_shards = df_test['shard'].nunique()
161 | print('Number of shards for each split: {} / {} / {} '
162 | '(training / validation / test)'.format(n_train_shards,
163 | n_val_shards, n_test_shards))
164 |
165 | print('Start processing MagnaTagATune using {} threads'
166 | .format(FLAGS.n_threads))
167 | _process_dataset(df, FLAGS.sample_rate, FLAGS.n_samples, FLAGS.n_threads)
168 |
169 | print()
170 | print('Done.')
171 |
172 |
173 | if __name__ == '__main__':
174 | tf.app.run()
175 |
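176 | # A sketch of a direct invocation; this is roughly the command that a
177 | # filled-in scripts/build_mtt.sh would run (paths are placeholders):
178 | #
179 | #   python -m sample_cnn.data.mtt.build_mtt \
180 | #     --data_dir="${BASE_DIR}/raw" \
181 | #     --annotation_file="${BASE_DIR}/annotations_final.csv" \
182 | #     --output_dir="${BASE_DIR}/tfrecord" \
183 | #     --output_labels="${BASE_DIR}/labels.txt" \
184 | #     --n_threads=4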
--------------------------------------------------------------------------------
/sample_cnn/evaluate.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from sample_cnn.ops import evaluate
4 |
5 | tf.flags.DEFINE_string('input_file_pattern', '',
6 | 'File pattern of sharded TFRecord input files.')
7 | tf.flags.DEFINE_string('weights_path', '', 'Path to learned weights.')
8 |
9 | tf.flags.DEFINE_integer('n_examples', 4332, 'Number of examples to run.')
10 | tf.flags.DEFINE_integer('n_audios_per_shard', 100,
11 | 'Number of audios per shard.')
12 |
13 | tf.logging.set_verbosity(tf.logging.INFO)
14 |
15 | FLAGS = tf.flags.FLAGS
16 |
17 |
18 | def main(unused_argv):
19 | assert FLAGS.input_file_pattern, '--input_file_pattern is required'
20 | assert FLAGS.weights_path, '--weights_path is required'
21 |
22 | evaluate(FLAGS.input_file_pattern, FLAGS.weights_path,
23 | n_examples=FLAGS.n_examples,
24 | n_audios_per_shard=FLAGS.n_audios_per_shard)
25 |
26 |
27 | if __name__ == '__main__':
28 | tf.app.run()
29 |
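30 | # A sketch of a direct invocation (paths are placeholders; the weights file
31 | # name depends on how checkpoints were saved during training):
32 | #
33 | #   python -m sample_cnn.evaluate \
34 | #     --input_file_pattern="${BASE_DIR}/tfrecord/test-???-of-036.seq.tfrecords" \
35 | #     --weights_path="${TRAIN_DIR}/best_weights.h5"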
--------------------------------------------------------------------------------
/sample_cnn/keras_utils/tfrecord_model.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import copy
3 |
4 | import six
5 | import tensorflow as tf
6 | import numpy as np
7 |
8 | from keras.models import Model
9 | from keras import optimizers
10 | from keras import losses
11 | from keras import callbacks as cbks
12 | from keras import backend as K
13 | from keras import metrics as metrics_module
14 | from keras.engine.training import _collect_metrics
15 |
16 |
17 | class TFRecordModel(Model):
18 | def __init__(self, inputs, outputs, val_inputs=None, name=None):
19 | super(TFRecordModel, self).__init__(inputs, outputs, name=name)
20 |
21 | # Prepare val_inputs.
22 | if val_inputs is None:
23 | self.val_inputs = []
24 | elif isinstance(val_inputs, (list, tuple)):
25 | self.val_inputs = list(val_inputs) # Tensor or list of tensors.
26 | else:
27 | self.val_inputs = [val_inputs]
28 |
29 | # Prepare val_outputs.
30 | if val_inputs is None:
31 | self.val_outputs = []
32 | else:
33 | val_outputs = self(val_inputs)
34 | if isinstance(val_outputs, (list, tuple)):
35 | self.val_outputs = list(val_outputs) # Tensor or list of tensors.
36 | else:
37 | self.val_outputs = [val_outputs]
38 |
39 | def compile_tfrecord(self, optimizer, loss, y, metrics=None,
40 | y_val=None):
41 | """Configures the model for training.
42 |
43 | # Arguments
44 | optimizer: str (name of optimizer) or optimizer object.
45 | See [optimizers](/optimizers).
46 | loss: str (name of objective function) or objective function.
47 | See [losses](/losses).
48 | If the model has multiple outputs, you can use a different loss
49 | on each output by passing a dictionary or a list of losses.
50 | The loss value that will be minimized by the model
51 | will then be the sum of all individual losses.
52 | metrics: list of metrics to be evaluated by the model
53 | during training and testing.
54 | Typically you will use `metrics=['accuracy']`.
55 | To specify different metrics for different outputs of a
56 | multi-output model, you could also pass a dictionary,
57 | such as `metrics={'output_a': 'accuracy'}`.
58 |
59 | # Raises
60 | ValueError: In case of invalid arguments for
61 | `optimizer`, `loss`, `metrics` or `sample_weight_mode`.
62 | """
63 | loss = loss or {}
64 | self.optimizer = optimizers.get(optimizer)
65 | self.loss = loss
66 | self.sample_weight_mode = None
67 | self.loss_weights = None
68 | self.y_val = y_val
69 | self.constraints = None
70 |
71 | do_validation = bool(len(self.val_inputs) > 0)
72 | if do_validation and y_val is None:
73 | raise ValueError('When you use validation inputs, '
74 | 'you should provide y_val.')
75 |
76 | # Prepare loss functions.
77 | if isinstance(loss, dict):
78 | for name in loss:
79 | if name not in self.output_names:
80 | raise ValueError('Unknown entry in loss '
81 | 'dictionary: "' + name + '". '
82 | 'Only expected the following keys: ' +
83 | str(self.output_names))
84 | loss_functions = []
85 | for name in self.output_names:
86 | if name not in loss:
87 | warnings.warn('Output "' + name +
88 | '" missing from loss dictionary. '
89 | 'We assume this was done on purpose, '
90 | 'and we will not be expecting '
91 | 'any data to be passed to "' + name +
92 | '" during training.', stacklevel=2)
93 | loss_functions.append(losses.get(loss.get(name)))
94 | elif isinstance(loss, list):
95 | if len(loss) != len(self.outputs):
96 | raise ValueError('When passing a list as loss, '
97 | 'it should have one entry per model outputs. '
98 | 'The model has ' + str(len(self.outputs)) +
99 | ' outputs, but you passed loss=' +
100 | str(loss))
101 | loss_functions = [losses.get(l) for l in loss]
102 | else:
103 | loss_function = losses.get(loss)
104 | loss_functions = [loss_function for _ in range(len(self.outputs))]
105 | self.loss_functions = loss_functions
106 |
107 | # Prepare training targets of model.
108 | if isinstance(y, (list, tuple)):
109 | y = list(y) # Tensor or list of tensors.
110 | else:
111 | y = [y]
112 | self.targets = []
113 | for i in range(len(self.outputs)):
114 | target = y[i]
115 | self.targets.append(target)
116 |
117 | # Prepare validation targets of model.
118 | if isinstance(y_val, (list, tuple)):
119 | y_val = list(y_val) # Tensor or list of tensors.
120 | else:
121 | y_val = [y_val]
122 | self.y_val = y_val
123 | self.val_targets = []
124 | for i in range(len(self.val_outputs)):
125 | val_target = y_val[i]
126 | self.val_targets.append(val_target)
127 |
128 | # Prepare metrics.
129 | self.metrics = metrics
130 | self.metrics_names = ['loss']
131 | self.metrics_tensors = []
132 | self.val_metrics_names = ['loss']
133 | self.val_metrics_tensors = []
134 |
135 | # Compute total training loss.
136 | total_loss = None
137 | for i in range(len(self.outputs)):
138 | y_true = self.targets[i]
139 | y_pred = self.outputs[i]
140 | loss_function = loss_functions[i]
141 |       output_loss = K.mean(loss_function(y_true, y_pred))
142 |       if len(self.outputs) > 1:
143 |         self.metrics_tensors.append(output_loss)
144 |         self.metrics_names.append(self.output_names[i] + '_loss')
145 |       if total_loss is None:
146 |         total_loss = output_loss
147 |       else:
148 |         total_loss += output_loss
149 | if total_loss is None:
150 | if not self.losses:
151 | raise RuntimeError('The model cannot be compiled '
152 | 'because it has no loss to optimize.')
153 | else:
154 | total_loss = 0.
155 |
156 | # Compute total validation loss.
157 | val_total_loss = None
158 | for i in range(len(self.val_outputs)):
159 | y_true = self.val_targets[i]
160 | y_pred = self.val_outputs[i]
161 | loss_function = loss_functions[i]
162 | val_output_loss = K.mean(loss_function(y_true, y_pred))
163 | if len(self.outputs) > 1:
164 | self.val_metrics_tensors.append(val_output_loss)
165 | self.val_metrics_names.append(self.output_names[i] + '_val_loss')
166 | if val_total_loss is None:
167 | val_total_loss = val_output_loss
168 | else:
169 | val_total_loss += val_output_loss
170 | if val_total_loss is None:
171 | if not self.losses and do_validation:
172 | raise RuntimeError('The model cannot be compiled '
173 | 'because it has no loss to optimize.')
174 | else:
175 | val_total_loss = 0.
176 |
177 | # Add regularization penalties
178 | # and other layer-specific losses.
179 | for loss_tensor in self.losses:
180 | total_loss += loss_tensor
181 | val_total_loss += loss_tensor
182 |
183 | # List of same size as output_names.
184 | # contains tuples (metrics for output, names of metrics).
185 | nested_metrics = _collect_metrics(metrics, self.output_names)
186 |
187 | def append_metric(layer_num, metric_name, metric_tensor):
188 | """Helper function used in loop below."""
189 | if len(self.output_names) > 1:
190 | metric_name = self.output_layers[layer_num].name + '_' + metric_name
191 | self.metrics_names.append(metric_name)
192 | self.metrics_tensors.append(metric_tensor)
193 |
194 | for i in range(len(self.outputs)):
195 | y_true = self.targets[i]
196 | y_pred = self.outputs[i]
197 | output_metrics = nested_metrics[i]
198 | for metric in output_metrics:
199 | if metric == 'accuracy' or metric == 'acc':
200 | # custom handling of accuracy
201 | # (because of class mode duality)
202 | output_shape = self.internal_output_shapes[i]
203 | acc_fn = None
204 | if (output_shape[-1] == 1 or
205 | self.loss_functions[i] == losses.binary_crossentropy):
206 | # case: binary accuracy
207 | acc_fn = metrics_module.binary_accuracy
208 | elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
209 | # case: categorical accuracy with sparse targets
210 | acc_fn = metrics_module.sparse_categorical_accuracy
211 | else:
212 | acc_fn = metrics_module.categorical_accuracy
213 |
214 | append_metric(i, 'acc', K.mean(acc_fn(y_true, y_pred)))
215 | else:
216 | metric_fn = metrics_module.get(metric)
217 | metric_result = metric_fn(y_true, y_pred)
218 | metric_result = {
219 | metric_fn.__name__: metric_result
220 | }
221 | for name, tensor in six.iteritems(metric_result):
222 | append_metric(i, name, tensor)
223 |
224 | def append_val_metric(layer_num, metric_name, metric_tensor):
225 | """Helper function used in loop below."""
226 | if len(self.output_names) > 1:
227 | metric_name = self.output_layers[layer_num].name + '_val_' + metric_name
228 | self.val_metrics_names.append(metric_name)
229 | self.val_metrics_tensors.append(metric_tensor)
230 |
231 | for i in range(len(self.val_outputs)):
232 | y_true = self.val_targets[i]
233 | y_pred = self.val_outputs[i]
234 | output_metrics = nested_metrics[i]
235 | for metric in output_metrics:
236 | if metric == 'accuracy' or metric == 'acc':
237 | # custom handling of accuracy
238 | # (because of class mode duality)
239 | output_shape = self.internal_output_shapes[i]
240 | acc_fn = None
241 | if (output_shape[-1] == 1 or
242 | self.loss_functions[i] == losses.binary_crossentropy):
243 | # case: binary accuracy
244 | acc_fn = metrics_module.binary_accuracy
245 | elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:
246 | # case: categorical accuracy with sparse targets
247 | acc_fn = metrics_module.sparse_categorical_accuracy
248 | else:
249 | acc_fn = metrics_module.categorical_accuracy
250 |
251 | append_val_metric(i, 'acc', K.mean(acc_fn(y_true, y_pred)))
252 | else:
253 | metric_fn = metrics_module.get(metric)
254 | metric_result = metric_fn(y_true, y_pred)
255 | metric_result = {
256 | metric_fn.__name__: metric_result
257 | }
258 | for name, tensor in six.iteritems(metric_result):
259 | append_val_metric(i, name, tensor)
260 |
261 | # Prepare gradient updates and state updates.
262 | self.total_loss = total_loss
263 | self.val_total_loss = val_total_loss
264 |
265 | # Functions for train, test and predict will
266 | # be compiled lazily when required.
267 | # This saves time when the user is not using all functions.
268 | self.train_function = None
269 | self.val_function = None
270 | self.test_function = None
271 | self.predict_function = None
272 |
273 |     # Collect trainable weights and sort them deterministically.
274 | trainable_weights = self.trainable_weights
275 | # Sort weights by name.
276 | if trainable_weights:
277 | trainable_weights.sort(key=lambda x: x.name)
278 | self._collected_trainable_weights = trainable_weights
279 |
280 | def _make_tfrecord_train_function(self):
281 | if not hasattr(self, 'train_function'):
282 | raise RuntimeError('You must compile your model before using it.')
283 | if self.train_function is None:
284 | inputs = []
285 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
286 | inputs += [K.learning_phase()]
287 |
288 | training_updates = self.optimizer.get_updates(
289 | self._collected_trainable_weights,
290 | self.constraints,
291 | self.total_loss)
292 | updates = self.updates + training_updates
293 | # Gets loss and metrics. Updates weights at each call.
294 | self.train_function = K.function(inputs,
295 | [self.total_loss] + self.metrics_tensors,
296 | updates=updates)
297 |
298 | def _make_tfrecord_test_function(self):
299 | if not hasattr(self, 'test_function'):
300 | raise RuntimeError('You must compile your model before using it.')
301 | if self.test_function is None:
302 | inputs = []
303 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
304 | inputs += [K.learning_phase()]
305 | # Return loss and metrics, no gradient updates.
306 | # Does update the network states.
307 | self.test_function = K.function(inputs,
308 | [self.total_loss] + self.metrics_tensors,
309 | updates=self.state_updates)
310 |
311 | def _make_tfrecord_val_function(self):
312 | if not hasattr(self, 'val_function'):
313 | raise RuntimeError('You must compile your model before using it.')
314 | if self.val_function is None:
315 | inputs = []
316 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
317 | inputs += [K.learning_phase()]
318 | # Return loss and metrics, no gradient updates.
319 | # Does update the network states.
320 | self.val_function = K.function(
321 | inputs,
322 | [self.val_total_loss] + self.val_metrics_tensors,
323 | updates=self.state_updates)
324 |
325 | def fit_tfrecord(self, steps_per_epoch,
326 | epochs=1,
327 | verbose=1,
328 | callbacks=None,
329 | validation_steps=None,
330 | initial_epoch=0):
331 | epoch = initial_epoch
332 |
333 | self._make_tfrecord_train_function()
334 |
335 | do_validation = bool(len(self.val_inputs) > 0)
336 | if do_validation and not validation_steps:
337 | raise ValueError('When using a validation batch, '
338 | 'you must specify a value for '
339 | '`validation_steps`.')
340 |
341 | # Prepare display labels.
342 | out_labels = self._get_deduped_metrics_names()
343 |
344 | if do_validation:
345 | callback_metrics = copy.copy(out_labels) + ['val_' + n
346 | for n in out_labels]
347 | else:
348 | callback_metrics = copy.copy(out_labels)
349 |
350 | # prepare callbacks
351 | self.history = cbks.History()
352 | callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history]
353 | if verbose:
354 | callbacks += [cbks.ProgbarLogger(count_mode='steps')]
355 | callbacks = cbks.CallbackList(callbacks)
356 |
357 | # it's possible to callback a different model than self:
358 | if hasattr(self, 'callback_model') and self.callback_model:
359 | callback_model = self.callback_model
360 | else:
361 | callback_model = self
362 | callbacks.set_model(callback_model)
363 | callbacks.set_params({
364 | 'epochs': epochs,
365 | 'steps': steps_per_epoch,
366 | 'verbose': verbose,
367 | 'do_validation': do_validation,
368 | 'metrics': callback_metrics,
369 | })
370 | callbacks.on_train_begin()
371 |
372 | if do_validation:
373 | val_sample_weight = None
374 | for cbk in callbacks:
375 | cbk.validation_data = [self.val_inputs, self.y_val, val_sample_weight]
376 |
377 | try:
378 | sess = K.get_session()
379 | coord = tf.train.Coordinator()
380 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
381 |
382 | callback_model.stop_training = False
383 | while epoch < epochs:
384 | callbacks.on_epoch_begin(epoch)
385 | steps_done = 0
386 | batch_index = 0
387 | while steps_done < steps_per_epoch:
388 | # build batch logs
389 | batch_logs = {
390 | 'batch': batch_index,
391 | 'size': self.inputs[0].shape[0].value
392 | }
393 | callbacks.on_batch_begin(batch_index, batch_logs)
394 |
395 | if self.uses_learning_phase and not isinstance(K.learning_phase(),
396 | int):
397 | ins = [1.]
398 | else:
399 | ins = []
400 | outs = self.train_function(ins)
401 |
402 | if not isinstance(outs, list):
403 | outs = [outs]
404 | for l, o in zip(out_labels, outs):
405 | batch_logs[l] = o
406 |
407 | callbacks.on_batch_end(batch_index, batch_logs)
408 |
409 | # Construct epoch logs.
410 | epoch_logs = {}
411 | batch_index += 1
412 | steps_done += 1
413 |
414 | # Epoch finished.
415 | if steps_done >= steps_per_epoch and do_validation:
416 | val_outs = self._validate_tfrecord(steps=validation_steps)
417 | if not isinstance(val_outs, list):
418 | val_outs = [val_outs]
419 | # Same labels assumed.
420 | for l, o in zip(out_labels, val_outs):
421 | epoch_logs['val_' + l] = o
422 |
423 | callbacks.on_epoch_end(epoch, epoch_logs)
424 | epoch += 1
425 | if callback_model.stop_training:
426 | break
427 |
428 | finally:
429 | # TODO: If you close the queue, you can't open it again..
430 | # coord.request_stop()
431 | # coord.join(threads)
432 | pass
433 |
434 | callbacks.on_train_end()
435 | return self.history
436 |
437 | def _validate_tfrecord(self, steps):
438 | self._make_tfrecord_val_function()
439 |
440 | steps_done = 0
441 | all_outs = []
442 | batch_sizes = []
443 |
444 | while steps_done < steps:
445 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
446 | ins = [0.]
447 | else:
448 | ins = []
449 | outs = self.val_function(ins)
450 | if len(outs) == 1:
451 | outs = outs[0]
452 |
453 | batch_size = self.val_inputs[0].shape[0].value
454 | all_outs.append(outs)
455 |
456 | steps_done += 1
457 | batch_sizes.append(batch_size)
458 |
459 | if not isinstance(outs, list):
460 | return np.average(np.asarray(all_outs),
461 | weights=batch_sizes)
462 | else:
463 | averages = []
464 | for i in range(len(outs)):
465 | averages.append(np.average([out[i] for out in all_outs],
466 | weights=batch_sizes))
467 | return averages
468 |
469 | def evaluate_tfrecord(self, steps):
470 | """Evaluates the model on a data generator.
471 |
472 | The generator should return the same kind of data
473 | as accepted by `test_on_batch`.
474 |
475 | # Arguments
476 | x_batch:
477 | y_batch:
478 | steps: Total number of steps (batches of samples)
479 | to yield from `generator` before stopping.
480 | stop_queue_runners: If True, stop queue runners after evaluation.
481 |
482 | # Returns
483 | Scalar test loss (if the model has a single output and no metrics)
484 | or list of scalars (if the model has multiple outputs
485 | and/or metrics). The attribute `model.metrics_names` will give you
486 | the display labels for the scalar outputs.
487 |
488 | # Raises
489 | ValueError: In case the generator yields
490 | data in an invalid format.
491 | """
492 | self._make_tfrecord_test_function()
493 |
494 | steps_done = 0
495 | all_outs = []
496 | batch_sizes = []
497 |
498 | try:
499 | sess = K.get_session()
500 | coord = tf.train.Coordinator()
501 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
502 |
503 | while steps_done < steps:
504 |
505 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
506 | ins = [0.]
507 | else:
508 | ins = []
509 | outs = self.test_function(ins)
510 | if len(outs) == 1:
511 | outs = outs[0]
512 |
513 | batch_size = self.inputs[0].shape[0].value
514 | all_outs.append(outs)
515 |
516 | steps_done += 1
517 | batch_sizes.append(batch_size)
518 |
519 | finally:
520 | # TODO: If you close the queue, you can't open it again..
521 | # if stop_queue_runners:
522 | # coord.request_stop()
523 | # coord.join(threads)
524 | pass
525 |
526 | if not isinstance(outs, list):
527 | return np.average(np.asarray(all_outs),
528 | weights=batch_sizes)
529 | else:
530 | averages = []
531 | for i in range(len(outs)):
532 | averages.append(np.average([out[i] for out in all_outs],
533 | weights=batch_sizes))
534 | return averages
535 |
536 | def _make_tfrecord_predict_function(self):
537 | if not hasattr(self, 'predict_function'):
538 | self.predict_function = None
539 | if self.predict_function is None:
540 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
541 | inputs = [K.learning_phase()]
542 | else:
543 | inputs = []
544 | # Gets network outputs. Does not update weights.
545 | # Does update the network states.
546 | self.predict_function = K.function(inputs,
547 | self.outputs,
548 | updates=self.state_updates)
549 |
550 | def predict_tfrecord(self, x_batch):
551 | if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
552 | ins = [0.]
553 | else:
554 | ins = []
555 | self._make_tfrecord_predict_function()
556 |
557 | try:
558 | sess = K.get_session()
559 | coord = tf.train.Coordinator()
560 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
561 |
562 | outputs = self.predict_function(ins)
563 |
564 | finally:
565 | # TODO: If you close the queue, you can't open it again..
566 | # if stop_queue_runners:
567 | # coord.request_stop()
568 | # coord.join(threads)
569 | pass
570 |
571 | if len(outputs) == 1:
572 | return outputs[0]
573 | return outputs
574 |
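575 | # A minimal usage sketch. Unlike a plain Keras Model, inputs and targets are
576 | # TensorFlow tensors (e.g. queue-fed batches); `x_batch`, `y_batch`, and the
577 | # layer sizes below are hypothetical:
578 | #
579 | #   from keras.layers import Input, Dense
580 | #
581 | #   inputs = Input(tensor=x_batch)  # x_batch: a [batch, n_features] tensor
582 | #   outputs = Dense(50, activation='sigmoid')(inputs)
583 | #   model = TFRecordModel(inputs, outputs)
584 | #   model.compile_tfrecord(optimizer='sgd', loss='binary_crossentropy',
585 | #                          y=y_batch)  # y_batch: a [batch, 50] tensor
586 | #   model.fit_tfrecord(steps_per_epoch=100, epochs=10)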
--------------------------------------------------------------------------------
/sample_cnn/model.py:
--------------------------------------------------------------------------------
1 | from keras.layers import (Conv1D, MaxPool1D, BatchNormalization,
2 | Dense, Dropout, Activation, Flatten, Reshape, Input)
3 |
4 | from sample_cnn.keras_utils.tfrecord_model import TFRecordModel
5 |
6 |
7 | class SampleCNN(TFRecordModel):
8 | def __init__(self, segments,
9 | val_segments=None,
10 | n_outputs=50,
11 | activation='relu',
12 | kernel_initializer='he_uniform',
13 | dropout_rate=0.5,
14 | extra_inputs=None,
15 | extra_outputs=None):
16 | # 59049
17 | segments = Input(tensor=segments)
18 | net = Reshape([-1, 1])(segments)
19 | # 59049 X 1
20 | net = Conv1D(128, 3, strides=3, padding='valid',
21 | kernel_initializer=kernel_initializer)(net)
22 | net = BatchNormalization()(net)
23 | net = Activation(activation)(net)
24 | # 19683 X 128
25 | net = Conv1D(128, 3, padding='same',
26 | kernel_initializer=kernel_initializer)(net)
27 | net = BatchNormalization()(net)
28 | net = Activation(activation)(net)
29 | net = MaxPool1D(3)(net)
30 | # 6561 X 128
31 | net = Conv1D(128, 3, padding='same',
32 | kernel_initializer=kernel_initializer)(net)
33 | net = BatchNormalization()(net)
34 | net = Activation(activation)(net)
35 | net = MaxPool1D(3)(net)
36 | # 2187 X 128
37 | net = Conv1D(256, 3, padding='same',
38 | kernel_initializer=kernel_initializer)(net)
39 | net = BatchNormalization()(net)
40 | net = Activation(activation)(net)
41 | net = MaxPool1D(3)(net)
42 | # 729 X 256
43 | net = Conv1D(256, 3, padding='same',
44 | kernel_initializer=kernel_initializer)(net)
45 | net = BatchNormalization()(net)
46 | net = Activation(activation)(net)
47 | net = MaxPool1D(3)(net)
48 | # 243 X 256
49 | net = Conv1D(256, 3, padding='same',
50 | kernel_initializer=kernel_initializer)(net)
51 | net = BatchNormalization()(net)
52 | net = Activation(activation)(net)
53 | net = MaxPool1D(3)(net)
54 | # 81 X 256
55 | net = Conv1D(256, 3, padding='same',
56 | kernel_initializer=kernel_initializer)(net)
57 | net = BatchNormalization()(net)
58 | net = Activation(activation)(net)
59 | net = MaxPool1D(3)(net)
60 | # 27 X 256
61 | net = Conv1D(256, 3, padding='same',
62 | kernel_initializer=kernel_initializer)(net)
63 | net = BatchNormalization()(net)
64 | net = Activation(activation)(net)
65 | net = MaxPool1D(3)(net)
66 | # 9 X 256
67 | net = Conv1D(256, 3, padding='same',
68 | kernel_initializer=kernel_initializer)(net)
69 | net = BatchNormalization()(net)
70 | net = Activation(activation)(net)
71 | net = MaxPool1D(3)(net)
72 | # 3 X 256
73 | net = Conv1D(512, 3, padding='same',
74 | kernel_initializer=kernel_initializer)(net)
75 | net = BatchNormalization()(net)
76 | net = Activation(activation)(net)
77 | net = MaxPool1D(3)(net)
78 | # 1 X 512
79 | net = Conv1D(512, 1, padding='same',
80 | kernel_initializer=kernel_initializer)(net)
81 | net = BatchNormalization()(net)
82 | net = Activation(activation)(net)
83 | # 1 X 512
84 | net = Dropout(dropout_rate)(net)
85 | net = Flatten()(net)
86 |
87 |     predictions = Dense(units=n_outputs, activation='sigmoid')(net)
88 |
89 | if extra_inputs is None:
90 | extra_inputs = []
91 | if extra_outputs is None:
92 | extra_outputs = []
93 |
94 | if not isinstance(extra_inputs, list):
95 | extra_inputs = [extra_inputs]
96 | inputs = [segments] + extra_inputs
97 |
98 | if not isinstance(extra_outputs, list):
99 | extra_outputs = [extra_outputs]
100 |     outputs = [predictions] + extra_outputs
101 |
102 | super(SampleCNN, self).__init__(inputs=inputs,
103 | val_inputs=val_segments,
104 | outputs=outputs)
105 |
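106 | # A minimal usage sketch wiring queue-fed batch tensors into the model
107 | # (see train.py for the full training setup; values here are placeholders):
108 | #
109 | #   from sample_cnn.ops import batch_inputs
110 | #
111 | #   segment, label = batch_inputs('/path/to/tfrecord/train-*.tfrecords',
112 | #                                 batch_size=23, is_training=True,
113 | #                                 examples_per_shard=1000)
114 | #   model = SampleCNN(segments=segment)
115 | #   model.compile_tfrecord(optimizer='sgd', loss='binary_crossentropy',
116 | #                          y=label)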
--------------------------------------------------------------------------------
/sample_cnn/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from sample_cnn.ops.batch_inputs import *
2 | from sample_cnn.ops.evaluation import *
--------------------------------------------------------------------------------
/sample_cnn/ops/batch_inputs.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def _read_example(filename_queue, n_labels=50, n_samples=59049):
5 | reader = tf.TFRecordReader()
6 | _, serialized_example = reader.read(filename_queue)
7 | features = tf.parse_single_example(
8 | serialized_example,
9 | features={
10 | 'raw_labels': tf.FixedLenFeature([], tf.string),
11 | 'raw_segment': tf.FixedLenFeature([], tf.string)
12 | })
13 |
14 | segment = tf.decode_raw(features['raw_segment'], tf.float32)
15 | segment.set_shape([n_samples])
16 |
17 | labels = tf.decode_raw(features['raw_labels'], tf.uint8)
18 | labels.set_shape([n_labels])
19 | labels = tf.cast(labels, tf.float32)
20 |
21 | return segment, labels
22 |
23 |
24 | def _read_sequence_example(filename_queue,
25 | n_labels=50, n_samples=59049, n_segments=10):
26 | reader = tf.TFRecordReader()
27 | _, serialized_example = reader.read(filename_queue)
28 | context, sequence = tf.parse_single_sequence_example(
29 | serialized_example,
30 | context_features={
31 | 'raw_labels': tf.FixedLenFeature([], dtype=tf.string)
32 | },
33 | sequence_features={
34 | 'raw_segments': tf.FixedLenSequenceFeature([], dtype=tf.string)
35 | })
36 |
37 | segments = tf.decode_raw(sequence['raw_segments'], tf.float32)
38 | segments.set_shape([n_segments, n_samples])
39 |
40 | labels = tf.decode_raw(context['raw_labels'], tf.uint8)
41 | labels.set_shape([n_labels])
42 | labels = tf.cast(labels, tf.float32)
43 |
44 | return segments, labels
45 |
46 |
47 | def batch_inputs(file_pattern,
48 | batch_size,
49 | is_training,
50 | examples_per_shard,
51 | is_sequence=False,
52 | input_queue_capacity_factor=16,
53 | n_read_threads=4,
54 | shard_queue_name='filename_queue',
55 | example_queue_name='input_queue'):
56 | data_files = []
57 | for pattern in file_pattern.split(","):
58 | data_files.extend(tf.gfile.Glob(pattern))
59 |
60 | assert len(data_files) > 0, (
61 | 'Found no input files matching {}'.format(file_pattern))
62 |
63 | tf.logging.info('Prefetching values from %d files matching %s',
64 | len(data_files), file_pattern)
65 |
66 | if is_sequence:
67 | read_example_fn = _read_sequence_example
68 | else:
69 | read_example_fn = _read_example
70 |
71 | if is_training:
72 | filename_queue = tf.train.string_input_producer(
73 | data_files, shuffle=True, capacity=16, name=shard_queue_name)
74 |
75 | # examples_per_shard
76 | # = examples_per_song * n_training_songs / n_training_shards
77 | # = 10 * 15250 / 152
78 | # = 1003
79 | #
80 | # example_size = 59049 * 4bytes = 232KB
81 | #
82 | # queue_size
83 | # = examples_per_shard * input_queue_capacity_factor * example_size
84 | # = 1003 * 16 * 232KB = 3.7GB
85 | min_queue_examples = examples_per_shard * input_queue_capacity_factor
86 | capacity = min_queue_examples + 3 * batch_size
87 |
88 | example_list = [read_example_fn(filename_queue)
89 | for _ in range(n_read_threads)]
90 |
91 | segment, label = tf.train.shuffle_batch_join(
92 | example_list,
93 | batch_size=batch_size,
94 | capacity=capacity,
95 | min_after_dequeue=min_queue_examples,
96 | name='shuffle_' + example_queue_name,
97 | )
98 |
99 | return segment, label
100 |
101 | else:
102 | filename_queue = tf.train.string_input_producer(
103 | data_files, shuffle=False, capacity=1, name=shard_queue_name)
104 |
105 | segment, label = read_example_fn(filename_queue)
106 |
107 | capacity = examples_per_shard + 2 * batch_size
108 | segment_batch, label_batch = tf.train.batch(
109 | [segment, label],
110 | batch_size=batch_size,
111 | num_threads=1,
112 | capacity=capacity,
113 | name='fifo_' + example_queue_name)
114 |
115 | return segment_batch, label_batch
116 |
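117 | # A usage sketch for the evaluation path, where each element is a whole audio
118 | # of 10 segments stored as a SequenceExample (the pattern is a placeholder):
119 | #
120 | #   audio, labels = batch_inputs(
121 | #       file_pattern='/path/to/tfrecord/test-*.seq.tfrecords',
122 | #       batch_size=1,
123 | #       is_training=False,
124 | #       is_sequence=True,
125 | #       examples_per_shard=100)
126 | #   # audio: float32 [1, 10, 59049], labels: float32 [1, 50]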
--------------------------------------------------------------------------------
/sample_cnn/ops/evaluation.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | from keras.layers import Input
5 | from sklearn.metrics import roc_auc_score, log_loss
6 |
7 | from sample_cnn.ops import batch_inputs
8 | from sample_cnn.model import SampleCNN
9 |
10 |
11 | def evaluate(input_file_pattern, weights_path, n_examples,
12 | n_audios_per_shard=100, print_progress=True):
13 |   # audio: [1, 10, 59049]
14 | # labels: [1, 50]
15 | audio, labels = batch_inputs(
16 | file_pattern=input_file_pattern,
17 | batch_size=1,
18 | is_training=False,
19 | is_sequence=True,
20 | n_read_threads=1,
21 | examples_per_shard=n_audios_per_shard,
22 | shard_queue_name='filename_queue',
23 | example_queue_name='input_queue')
24 |
25 | segments = tf.squeeze(audio)
26 |
27 | labels = tf.squeeze(labels)
28 | labels = Input(tensor=labels)
29 |
30 | model = SampleCNN(segments=segments,
31 | extra_inputs=labels,
32 | extra_outputs=labels)
33 |
34 | print('Load weights from "{}".'.format(weights_path))
35 | model.load_weights(weights_path)
36 |
37 | n_classes = labels.shape[0].value
38 | all_y_pred = np.empty([0, n_classes], dtype=np.float32)
39 | all_y_true = np.empty([0, n_classes], dtype=np.float32)
40 |
41 | print('Start evaluation.')
42 | for i in range(n_examples):
43 | y_pred_segments, y_true = model.predict_tfrecord(segments)
44 |
45 | y_pred = np.mean(y_pred_segments, axis=0)
46 |
47 | y_pred = np.expand_dims(y_pred, 0)
48 | y_true = np.expand_dims(y_true, 0)
49 |
50 | all_y_pred = np.append(all_y_pred, y_pred, axis=0)
51 | all_y_true = np.append(all_y_true, y_true, axis=0)
52 |
53 |     if print_progress and i and i % max(n_examples // 100, 1) == 0:
54 | print('Evaluated [{:04d}/{:04d}].'.format(i + 1, n_examples))
55 |
56 | losses = []
57 | for i in range(n_classes):
58 | class_y_true = all_y_true[:, i]
59 | class_y_pred = all_y_pred[:, i]
60 | if np.sum(class_y_true) != 0:
61 | class_loss = log_loss(class_y_true, class_y_pred)
62 | losses.append(class_loss)
63 |
64 | loss = np.mean(losses)
65 |   print('@ binary cross-entropy loss: {}'.format(loss))
66 |
67 | roc_auc = roc_auc_score(all_y_true, all_y_pred, average='macro')
68 | print('@ ROC AUC score: {}'.format(roc_auc))
69 |
--------------------------------------------------------------------------------
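The per-class loss loop above skips tags that never occur, since `log_loss` is undefined when only one class is present; the same caveat applies to per-class ROC AUC. A small self-contained sketch of that masking, applied to both metrics, on synthetic data (all shapes and values here are made up for illustration):

    import numpy as np
    from sklearn.metrics import log_loss, roc_auc_score

    # Synthetic predictions for 4 audios and 3 tags; the last tag never occurs.
    y_true = np.array([[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 0]],
                      dtype=np.float32)
    y_pred = np.array([[.9, .2, .1], [.1, .8, .2], [.7, .6, .3], [.2, .1, .4]])

    # Keep only tags with at least one positive example.
    valid = y_true.sum(axis=0) > 0
    losses = [log_loss(y_true[:, i], y_pred[:, i]) for i in np.where(valid)[0]]
    auc = roc_auc_score(y_true[:, valid], y_pred[:, valid], average='macro')
    print('loss={:.4f}, AUC={:.4f}'.format(np.mean(losses), auc))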
/sample_cnn/train.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import re
4 |
5 | import tensorflow as tf
6 |
7 | from glob import glob
8 |
9 | from keras.callbacks import (TensorBoard, ModelCheckpoint,
10 | EarlyStopping, CSVLogger)
11 | from keras.optimizers import SGD
12 |
13 | from sample_cnn.model import SampleCNN
14 | from sample_cnn.ops import batch_inputs, evaluate
15 |
16 | tf.logging.set_verbosity(tf.logging.INFO)
17 |
18 | # Paths.
19 | tf.flags.DEFINE_string('train_input_file_pattern', '',
20 | 'File pattern of TFRecord input files.')
21 | tf.flags.DEFINE_string('val_input_file_pattern', '',
22 | 'File pattern of validation TFRecord input files.')
23 | tf.flags.DEFINE_string('test_input_file_pattern', '',
24 | 'File pattern of test TFRecord input files.')
25 | tf.flags.DEFINE_string('train_dir', '',
26 | 'Directory where to write event logs and checkpoints.')
27 | tf.flags.DEFINE_string('checkpoint_prefix', 'best_weights',
28 | 'Prefix of the checkpoint filename.')
29 |
30 | # Batch options.
31 | tf.flags.DEFINE_integer('batch_size', 23, 'Batch size.')
32 | tf.flags.DEFINE_integer('n_audios_per_shard', 100,
33 | 'Number of audios per shard.')
34 | tf.flags.DEFINE_integer('n_segments_per_audio', 10,
35 | 'Number of segments per audio.')
36 | tf.flags.DEFINE_integer('n_train_examples', 15250,
37 | 'Number of examples in training dataset.')
38 | tf.flags.DEFINE_integer('n_val_examples', 1529,
39 | 'Number of examples in validation dataset.')
40 | tf.flags.DEFINE_integer('n_test_examples', 4332,
41 | 'Number of examples in test dataset.')
42 |
43 | tf.flags.DEFINE_integer('n_read_threads', 4,
44 |                         'Number of example reader threads.')
44 |
45 | # Learning options.
46 | tf.flags.DEFINE_float('initial_learning_rate', 0.01, 'Initial learning rate.')
47 | tf.flags.DEFINE_float('momentum', 0.9, 'Momentum.')
48 | tf.flags.DEFINE_float('dropout_rate', 0.5, 'Dropout rate.')
49 | tf.flags.DEFINE_float('global_lr_decay', 0.2, 'Global learning rate decay.')
50 | tf.flags.DEFINE_float('local_lr_decay', 1e-6, 'Local learning rate decay.')
51 |
52 | # Training options.
53 | tf.flags.DEFINE_integer('patience', 3, 'Patience for early stopping.')
54 | tf.flags.DEFINE_integer('max_trains', 5, 'Maximum number of training stages.')
55 | tf.flags.DEFINE_integer('initial_stage', 0,
56 |                         'The stage at which to start training.')
57 |
58 | FLAGS = tf.app.flags.FLAGS
59 |
60 |
61 | def make_path(*paths):
62 | path = os.path.join(*[str(path) for path in paths])
63 | path = os.path.realpath(path)
64 | return path
65 |
66 |
67 | def calculate_steps(n_examples, n_segments, batch_size):
68 | steps = 1. * n_examples * n_segments
69 | steps = math.ceil(steps / batch_size)
70 | return steps
71 |
72 |
73 | def find_best_checkpoint(*dirs):
74 | best_checkpoint_path = None
75 | best_epoch = -1
76 | best_val_loss = 1e+10
77 |   for train_dir in dirs:
78 |     checkpoint_paths = glob('{}/{}*'.format(train_dir, FLAGS.checkpoint_prefix))
79 |     for checkpoint_path in checkpoint_paths:
80 |       epoch = int(re.findall(r'e\d+', checkpoint_path)[0][1:])
81 |       val_loss = float(re.findall(r'l\d\.\d+', checkpoint_path)[0][1:])
82 |
83 | if val_loss < best_val_loss:
84 | best_checkpoint_path = checkpoint_path
85 | best_epoch = epoch
86 | best_val_loss = val_loss
87 |
88 | return best_checkpoint_path, best_epoch, best_val_loss
89 |
90 |
91 | def train(initial_lr,
92 | stage_train_dir,
93 | checkpoint_path_to_load=None,
94 | initial_epoch=0):
95 | if not tf.gfile.Exists(stage_train_dir):
96 | tf.logging.info('Creating training directory: %s', stage_train_dir)
97 | tf.gfile.MakeDirs(stage_train_dir)
98 |
99 | x_train_batch, y_train_batch = batch_inputs(
100 | file_pattern=FLAGS.train_input_file_pattern,
101 | batch_size=FLAGS.batch_size,
102 | is_training=True,
103 | n_read_threads=FLAGS.n_read_threads,
104 | examples_per_shard=FLAGS.n_audios_per_shard * FLAGS.n_segments_per_audio,
105 | shard_queue_name='train_filename_queue',
106 | example_queue_name='train_input_queue')
107 |
108 | x_val_batch, y_val_batch = batch_inputs(
109 | file_pattern=FLAGS.val_input_file_pattern,
110 | batch_size=FLAGS.batch_size,
111 | is_training=False,
112 | n_read_threads=1,
113 | examples_per_shard=FLAGS.n_audios_per_shard * FLAGS.n_segments_per_audio,
114 | shard_queue_name='val_filename_queue',
115 | example_queue_name='val_input_queue')
116 |
117 | # Create a model.
118 | model = SampleCNN(segments=x_train_batch,
119 | val_segments=x_val_batch,
120 | dropout_rate=FLAGS.dropout_rate)
121 |
122 |   # Load weights from a checkpoint if one exists.
123 | if checkpoint_path_to_load:
124 | print('Load weights from "{}".'.format(checkpoint_path_to_load))
125 | model.load_weights(checkpoint_path_to_load)
126 |
127 |   # Set up the optimizer.
128 | optimizer = SGD(lr=initial_lr,
129 | momentum=FLAGS.momentum,
130 | decay=FLAGS.local_lr_decay,
131 | nesterov=True)
132 |
133 | # Compile the model.
134 | model.compile_tfrecord(y=y_train_batch,
135 | y_val=y_val_batch,
136 | loss='binary_crossentropy',
137 | optimizer=optimizer)
138 |
139 |   # Set up a TensorBoard callback.
140 | tensor_board = TensorBoard(log_dir=stage_train_dir)
141 |
142 |   # Use early stopping.
143 | early_stopping = EarlyStopping(monitor='val_loss', patience=FLAGS.patience)
144 |
145 |   # Set up a checkpointer.
146 | checkpoint_path = make_path(
147 | stage_train_dir,
148 | FLAGS.checkpoint_prefix + '-e{epoch:03d}-l{val_loss:.4f}.hdf5')
149 | checkpointer = ModelCheckpoint(
150 | filepath=checkpoint_path,
151 | monitor='val_loss',
152 | save_best_only=True)
153 |
154 |   # Set up a CSV logger.
155 | csv_logger = CSVLogger(filename=make_path(stage_train_dir, 'training.csv'),
156 | append=True)
157 |
158 |   # Kick off the training!
159 | train_steps = calculate_steps(n_examples=FLAGS.n_train_examples,
160 | n_segments=FLAGS.n_segments_per_audio,
161 | batch_size=FLAGS.batch_size)
162 | val_steps = calculate_steps(n_examples=FLAGS.n_val_examples,
163 | n_segments=FLAGS.n_segments_per_audio,
164 | batch_size=FLAGS.batch_size)
165 |
166 | model.fit_tfrecord(epochs=100,
167 | initial_epoch=initial_epoch,
168 | steps_per_epoch=train_steps,
169 | validation_steps=val_steps,
170 | callbacks=[tensor_board, early_stopping,
171 | checkpointer, csv_logger])
172 |
173 |   # At the end of the stage, evaluate on the test set.
174 |   best_ckpt_path, *_ = find_best_checkpoint(stage_train_dir)
175 |   print('End of stage. '
176 |         'Starting evaluation on the test set using checkpoint "{}".'
177 |         .format(best_ckpt_path))
178 |
179 | evaluate(input_file_pattern=FLAGS.test_input_file_pattern,
180 | weights_path=best_ckpt_path,
181 | n_examples=FLAGS.n_test_examples,
182 | n_audios_per_shard=FLAGS.n_audios_per_shard,
183 | print_progress=False)
184 |
185 |
186 | def main(unused_argv):
187 | assert FLAGS.train_dir, '--train_dir is required'
188 | assert FLAGS.train_input_file_pattern, '--train_input_file_pattern is required'
189 | assert FLAGS.val_input_file_pattern, '--val_input_file_pattern is required'
190 | assert FLAGS.test_input_file_pattern, '--test_input_file_pattern is required'
191 |
192 | # Print all flags.
193 | print('@@ Flags')
194 | for key, value in FLAGS.__flags.items():
195 | print('{}={}'.format(key, value))
196 |
197 | for i in range(FLAGS.initial_stage, FLAGS.max_trains):
198 | stage_train_dir = make_path(FLAGS.train_dir, i)
199 | previous_stage_train_dir = make_path(FLAGS.train_dir, i - 1)
200 | next_stage_train_dir = make_path(FLAGS.train_dir, i + 1)
201 |
202 |     # Skip this stage if the next stage's training directory already exists.
203 | if os.path.isdir(next_stage_train_dir):
204 | continue
205 |
206 |     # Set the initial learning rate for the stage.
207 | decay = FLAGS.global_lr_decay ** i
208 | learning_rate = FLAGS.initial_learning_rate * decay
209 |
210 | # Create a directory for the stage.
211 | os.makedirs(stage_train_dir, exist_ok=True)
212 |
213 |     # Find the best checkpoint to load weights from.
214 | (ckpt_path, ckpt_epoch, ckpt_val_loss) = find_best_checkpoint(
215 | stage_train_dir, previous_stage_train_dir)
216 |
217 | print('\n@@ Start training stage {:02d}: lr={}, train_dir={}'
218 | .format(i, learning_rate, stage_train_dir))
219 | if ckpt_path:
220 | print('Found a trained model: epoch={}, val_loss={}, path={}'
221 | .format(ckpt_epoch, ckpt_val_loss, ckpt_path))
222 | else:
223 | print('No trained model found.')
224 |
225 | train(initial_lr=learning_rate,
226 | stage_train_dir=stage_train_dir,
227 | checkpoint_path_to_load=ckpt_path,
228 | initial_epoch=ckpt_epoch + 1)
229 |
230 | print('\nDone.')
231 |
232 |
233 | if __name__ == '__main__':
234 | tf.app.run()
235 |
--------------------------------------------------------------------------------
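Two details of the staged training loop above are worth spelling out: each stage i shrinks the learning rate geometrically (initial_learning_rate * global_lr_decay ** i), and resuming relies on the epoch and validation loss that the checkpointer encodes into the filename (best_weights-e{epoch:03d}-l{val_loss:.4f}.hdf5). A small sketch of both, using the default flag values:

    import re

    # Parse a checkpoint name in the format written by ModelCheckpoint above.
    name = 'best_weights-e003-l0.1234.hdf5'
    epoch = int(re.findall(r'e\d+', name)[0][1:])           # -> 3
    val_loss = float(re.findall(r'l\d\.\d+', name)[0][1:])  # -> 0.1234

    # Per-stage learning rates with the default flags
    # (initial_learning_rate=0.01, global_lr_decay=0.2, max_trains=5).
    for i in range(5):
        print('stage {}: lr={:g}'.format(i, 0.01 * 0.2 ** i))
    # stage 0: lr=0.01 ... stage 4: lr=1.6e-05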
/scripts/build_mtt.sh.template:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | # Please fill in the blanks!
5 | BASE_DIR=
6 | N_THREADS=4
7 | ENV_NAME= # if you don't use an environment, just leave it blank.
8 |
9 |
10 | DATA_DIR="${BASE_DIR}/raw"
11 | ANNOTATION_FILE="${BASE_DIR}/annotations_final.csv"
12 | OUTPUT_DIR="${BASE_DIR}/tfrecord"
13 | OUTPUT_LABELS="${BASE_DIR}/labels.txt"
14 | LOG_FILE="${BASE_DIR}/build_mtt.log"
15 |
16 | if [ -n "${ENV_NAME}" ]; then
17 | source activate "${ENV_NAME}"
18 | fi
19 |
20 | export PYTHONPATH='.'
21 | export PYTHONUNBUFFERED=1
22 |
23 | nohup python sample_cnn/data/mtt/build_mtt.py \
24 | --data_dir="${DATA_DIR}" \
25 | --annotation_file="${ANNOTATION_FILE}" \
26 | --output_dir="${OUTPUT_DIR}" \
27 | --output_labels="${OUTPUT_LABELS}" \
28 | --n_threads="${N_THREADS}" > "${LOG_FILE}" &
29 |
30 | tail -F "${LOG_FILE}"
31 |
--------------------------------------------------------------------------------
/scripts/evaluate.sh.template:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | # Please fill in the blanks!
5 | BASE_DIR=
6 | WEIGHTS_PATH=
7 | ENV_NAME= # if you don't use an environment, just leave it blank.
8 |
9 |
10 | INPUT_FILE_PATTERN="${BASE_DIR}/tfrecord/test-???-of-???.seq.tfrecords"
11 |
12 | if [ -n "${ENV_NAME}" ]; then
13 | source activate "${ENV_NAME}"
14 | fi
15 |
16 | export PYTHONPATH='.'
17 | export PYTHONUNBUFFERED=1
18 |
19 | python sample_cnn/evaluate.py \
20 | --input_file_pattern="${INPUT_FILE_PATTERN}" \
21 | --weights_path="${WEIGHTS_PATH}"
22 |
--------------------------------------------------------------------------------
/scripts/train.sh.template:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | # Please fill in the blanks!
5 | BASE_DIR=
6 | TRAIN_DIR=
7 | ENV_NAME= # if you don't use an environment, just leave it blank.
8 |
9 |
10 | TRAIN_INPUT_FILE_PATTERN="${BASE_DIR}/tfrecord/train-???-of-???.tfrecords"
11 | VAL_INPUT_FILE_PATTERN="${BASE_DIR}/tfrecord/val-???-of-???.tfrecords"
12 | TEST_INPUT_FILE_PATTERN="${BASE_DIR}/tfrecord/test-???-of-???.seq.tfrecords"
13 |
14 | LOG_FILE="${TRAIN_DIR}/train.log"
15 |
16 | mkdir -p "${TRAIN_DIR}"
17 |
18 | if [ -n "${ENV_NAME}" ]; then
19 | source activate "${ENV_NAME}"
20 | fi
21 |
22 | export PYTHONPATH='.'
23 | export PYTHONUNBUFFERED=1
24 |
25 | nohup python sample_cnn/train.py \
26 | --train_input_file_pattern="${TRAIN_INPUT_FILE_PATTERN}" \
27 | --val_input_file_pattern="${VAL_INPUT_FILE_PATTERN}" \
28 | --test_input_file_pattern="${TEST_INPUT_FILE_PATTERN}" \
29 | --train_dir="${TRAIN_DIR}" >> "${LOG_FILE}" &
30 |
31 | tail -F "${LOG_FILE}"
32 |
--------------------------------------------------------------------------------
/tests/tfrecord_model_test_mnist.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 |
6 | from keras.layers import Input, Conv2D, Dense, Flatten
7 | from keras.datasets import mnist
8 | from keras.callbacks import ModelCheckpoint
9 |
10 | from sample_cnn.keras_utils.tfrecord_model import TFRecordModel
11 |
12 |
13 | def data_to_tfrecord(images, labels, filename):
14 | def _int64_feature(value):
15 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
16 |
17 | def _bytes_feature(value):
18 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
19 |
20 | """ Save data into TFRecord """
21 | if not os.path.isfile(filename):
22 | num_examples = images.shape[0]
23 |
24 | rows = images.shape[1]
25 | cols = images.shape[2]
26 | depth = images.shape[3]
27 |
28 | print('Writing', filename)
29 | writer = tf.python_io.TFRecordWriter(filename)
30 | for index in range(num_examples):
31 | image_raw = images[index].tostring()
32 | example = tf.train.Example(features=tf.train.Features(feature={
33 | 'height': _int64_feature(rows),
34 | 'width': _int64_feature(cols),
35 | 'depth': _int64_feature(depth),
36 | 'label': _int64_feature(int(labels[index])),
37 | 'image_raw': _bytes_feature(image_raw)}))
38 | writer.write(example.SerializeToString())
39 | writer.close()
40 | else:
41 |     print('TFRecord already exists.')
42 |
43 |
44 | def read_and_decode(filename, one_hot=True, n_classes=None):
45 | """ Return tensor to read from TFRecord """
46 | filename_queue = tf.train.string_input_producer([filename])
47 | reader = tf.TFRecordReader()
48 | _, serialized_example = reader.read(filename_queue)
49 | features = tf.parse_single_example(serialized_example,
50 | features={
51 | 'label': tf.FixedLenFeature([],
52 | tf.int64),
53 | 'image_raw': tf.FixedLenFeature([],
54 | tf.string),
55 | })
56 |   # You can add more image distortion here for the training data.
57 | img = tf.decode_raw(features['image_raw'], tf.uint8)
58 | img.set_shape([28 * 28])
59 | img = tf.reshape(img, [28, 28, 1])
60 | img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
61 |
62 | label = tf.cast(features['label'], tf.int32)
63 | if one_hot and n_classes:
64 | label = tf.one_hot(label, n_classes)
65 |
66 | return img, label
67 |
68 |
69 | def build_net(inputs):
70 | net = Conv2D(32, 3, strides=2, activation='relu')(inputs)
71 | net = Conv2D(32, 3, strides=2, activation='relu')(net)
72 | net = Flatten()(net)
73 | net = Dense(128, activation='relu')(net)
74 | net = Dense(10, activation='softmax')(net)
75 | return net
76 |
77 |
78 | if __name__ == '__main__':
79 | TMP_DIR = '_test_tmp_dir_'
80 | TRAIN_TFRECORD_PATH = TMP_DIR + '/mnist_train.tfrecords'
81 | VAL_TFRECORD_PATH = TMP_DIR + '/mnist_val.tfrecords'
82 | TEST_TFRECORD_PATH = TMP_DIR + '/mnist_test.tfrecords'
83 | WEIGHTS_PATH = TMP_DIR + '/weights.hdf5'
84 |
85 | batch_size = 64
86 | n_train = 50000
87 |
88 | # Create temp dir
89 | os.makedirs(TMP_DIR, exist_ok=True)
90 |
91 | (x_train, y_train), (x_test, y_test) = mnist.load_data()
92 |
93 | x_train = np.expand_dims(x_train, -1)
94 | x_test = np.expand_dims(x_test, -1)
95 |
96 | x_val = x_train[n_train:]
97 | y_val = y_train[n_train:]
98 |
99 | x_train = x_train[:n_train]
100 | y_train = y_train[:n_train]
101 |
102 | data_to_tfrecord(images=x_train,
103 | labels=y_train,
104 | filename=TRAIN_TFRECORD_PATH)
105 | data_to_tfrecord(images=x_val,
106 | labels=y_val,
107 | filename=VAL_TFRECORD_PATH)
108 | data_to_tfrecord(images=x_test,
109 | labels=y_test,
110 | filename=TEST_TFRECORD_PATH)
111 |
112 | x_train, y_train = read_and_decode(filename=TRAIN_TFRECORD_PATH,
113 | one_hot=True,
114 | n_classes=10)
115 | x_val, y_val = read_and_decode(filename=VAL_TFRECORD_PATH,
116 | one_hot=True,
117 | n_classes=10)
118 | x_test, y_test = read_and_decode(filename=TEST_TFRECORD_PATH,
119 | one_hot=True,
120 | n_classes=10)
121 |
122 | x_train_batch, y_train_batch = tf.train.shuffle_batch([x_train, y_train],
123 | batch_size=batch_size,
124 | capacity=2000,
125 | min_after_dequeue=1000,
126 | name='train_batch')
127 | x_val_batch, y_val_batch = tf.train.shuffle_batch([x_val, y_val],
128 | batch_size=batch_size,
129 | capacity=2000,
130 | min_after_dequeue=1000,
131 | name='val_batch')
132 | x_test_batch, y_test_batch = tf.train.shuffle_batch([x_test, y_test],
133 | batch_size=batch_size,
134 | capacity=2000,
135 | min_after_dequeue=1000,
136 | name='test_batch')
137 |
138 | x_train_input = Input(tensor=x_train_batch)
139 | x_val_input = Input(tensor=x_val_batch)
140 | x_test_input = Input(tensor=x_test_batch)
141 |   predictions = build_net(x_train_input)  # build_net ends in a softmax
142 |
143 |   model = TFRecordModel(inputs=x_train_input,
144 |                         outputs=predictions,
145 |                         val_inputs=x_val_input)
146 |
147 | model.compile_tfrecord(optimizer='adam',
148 | loss='categorical_crossentropy',
149 | metrics=['accuracy'],
150 | y=y_train_batch,
151 | y_val=y_val_batch)
152 |
153 | model.summary()
154 |
155 | model_checkpoint = ModelCheckpoint(
156 | filepath=TMP_DIR + '/best_weights.{epoch:02d}-{val_loss:.4f}.hdf5',
157 | save_best_only=True)
158 |
159 | model.fit_tfrecord(epochs=10,
160 | verbose=2,
161 | steps_per_epoch=n_train // batch_size + 1,
162 | validation_steps=10000 // batch_size + 1,
163 | callbacks=[model_checkpoint])
164 |
165 | model.save_weights(WEIGHTS_PATH)
166 | model.load_weights(WEIGHTS_PATH)
167 |
168 | test_outputs = model(x_test_input)
169 | test_model = TFRecordModel(inputs=x_test_input,
170 | outputs=test_outputs)
171 | test_model.compile_tfrecord(optimizer='adam',
172 | loss='categorical_crossentropy',
173 | metrics=['accuracy'],
174 | y=y_test_batch)
175 |
176 | loss, accuracy = test_model.evaluate_tfrecord(steps=10000 // batch_size + 1)
177 |
178 | print('accuracy={}, loss={}'.format(accuracy, loss))
179 | print('Done.')
180 |
--------------------------------------------------------------------------------
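A quick way to sanity-check what data_to_tfrecord wrote is to iterate the records directly. A minimal sketch using the TF1 record iterator, assuming the test above has been run so the file exists:

    import tensorflow as tf

    # Read back the first record written by the test above.
    path = '_test_tmp_dir_/mnist_train.tfrecords'
    record = next(tf.python_io.tf_record_iterator(path))
    example = tf.train.Example.FromString(record)
    features = example.features.feature
    print('label: ', features['label'].int64_list.value[0])
    print('height:', features['height'].int64_list.value[0])
    print('bytes: ', len(features['image_raw'].bytes_list.value[0]))  # 28*28*1 = 784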