├── cloudml-gpu.yaml ├── cloudml-4gpu.yaml ├── cloudml-gpu-distributed.yaml ├── __init__.py ├── models.py ├── CONTRIBUTING.md ├── feature_extractor ├── feature_extractor_test.py ├── README.md ├── feature_extractor.py └── extract_tfrecords_main.py ├── convert_prediction_from_json_to_csv.py ├── model_utils.py ├── losses.py ├── video_level_models.py ├── mean_average_precision_calculator.py ├── export_model.py ├── utils.py ├── inference.py ├── eval_util.py ├── frame_level_models.py ├── average_precision_calculator.py ├── readers.py ├── LICENSE ├── eval.py ├── README.md └── train.py /cloudml-gpu.yaml: -------------------------------------------------------------------------------- 1 | trainingInput: 2 | scaleTier: CUSTOM 3 | masterType: standard_gpu 4 | runtimeVersion: "1.0" 5 | -------------------------------------------------------------------------------- /cloudml-4gpu.yaml: -------------------------------------------------------------------------------- 1 | trainingInput: 2 | scaleTier: CUSTOM 3 | masterType: complex_model_m_gpu 4 | runtimeVersion: "1.0" 5 | -------------------------------------------------------------------------------- /cloudml-gpu-distributed.yaml: -------------------------------------------------------------------------------- 1 | trainingInput: 2 | runtimeVersion: "1.0" 3 | scaleTier: CUSTOM 4 | masterType: standard_gpu 5 | workerCount: 2 6 | workerType: standard_gpu 7 | parameterServerCount: 2 8 | parameterServerType: standard 9 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Contains the base class for models.""" 16 | 17 | class BaseModel(object): 18 | """Inherit from this class when implementing new models.""" 19 | 20 | def create_model(self, unused_model_input, **unused_params): 21 | raise NotImplementedError() 22 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We are accepting patches and contributions to this project. To set expectations, 4 | this project is primarily intended to be a flexible starting point for 5 | researchers working with the YouTube-8M dataset. As such, we would like to keep 6 | it simple. We are more likely to accept small bug fixes and optimizations, and 7 | less likely to accept patches which add significant complexity. For the latter 8 | type of contribution, we recommend creating a Github fork of the project 9 | instead. 10 | 11 | If you would like to contribute, there are a few small guidelines you need to 12 | follow. 13 | 14 | ## Contributor License Agreement 15 | 16 | Contributions to any Google project must be accompanied by a Contributor License 17 | Agreement. This is necessary because you own the copyright to your changes, even 18 | after your contribution becomes part of this project. So this agreement simply 19 | gives us permission to use and redistribute your contributions as part of the 20 | project. Head over to to see your current 21 | agreements on file or to sign a new one. 22 | 23 | You generally only need to submit a CLA once, so if you've already submitted one 24 | (even if it was for a different project), you probably don't need to do it 25 | again. 26 | 27 | ## Code reviews 28 | 29 | All submissions, including submissions by project members, require review. We 30 | use GitHub pull requests for this purpose. Consult [GitHub Help] for more 31 | information on using pull requests. 32 | 33 | [GitHub Help]: https://help.github.com/articles/about-pull-requests/ 34 | -------------------------------------------------------------------------------- /feature_extractor/feature_extractor_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
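# --- Illustrative sketch (not part of the original repository) --------------
# models.py above only defines the BaseModel contract: subclasses implement
# create_model() and return a dict carrying a "predictions" tensor. A minimal,
# hypothetical video-level model might look like the following
# (MyLogisticModel is an invented name; a real reference implementation is
# LogisticModel in video_level_models.py, shown later in this repository):

import models
import tensorflow as tf
import tensorflow.contrib.slim as slim


class MyLogisticModel(models.BaseModel):

  def create_model(self, model_input, vocab_size, **unused_params):
    # model_input: [batch_size, num_features]; output: [batch_size, vocab_size].
    output = slim.fully_connected(
        model_input, vocab_size, activation_fn=tf.nn.sigmoid)
    return {"predictions": output}
# -----------------------------------------------------------------------------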
14 | """Tests for feature_extractor.""" 15 | 16 | import pickle 17 | import json 18 | import os 19 | import feature_extractor 20 | import numpy 21 | from PIL import Image 22 | from tensorflow.python.platform import googletest 23 | 24 | 25 | def _FilePath(filename): 26 | return os.path.join('testdata', filename) 27 | 28 | 29 | def _MeanElementWiseDifference(a, b): 30 | """Calculates element-wise percent difference between two numpy matrices.""" 31 | difference = numpy.abs(a - b) 32 | denominator = numpy.maximum(numpy.abs(a), numpy.abs(b)) 33 | 34 | # We don't care if one is 0 and another is 0.01 35 | return (difference / (0.01 + denominator)).mean() 36 | 37 | 38 | class FeatureExtractorTest(googletest.TestCase): 39 | 40 | def setUp(self): 41 | self._extractor = feature_extractor.YouTube8MFeatureExtractor() 42 | 43 | def testPCAOnFeatureVector(self): 44 | sports_1m_test_data = pickle.load(open(_FilePath('sports1m_frame.pkl'), 'rb')) 45 | actual_pca = self._extractor.apply_pca(sports_1m_test_data['original']) 46 | expected_pca = sports_1m_test_data['pca'] 47 | self.assertLess(_MeanElementWiseDifference(actual_pca, expected_pca), 1e-5) 48 | 49 | 50 | if __name__ == '__main__': 51 | googletest.main() 52 | -------------------------------------------------------------------------------- /convert_prediction_from_json_to_csv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility to convert the output of batch prediction into a CSV submission. 16 | 17 | It converts the JSON files created by the command 18 | 'gcloud beta ml jobs submit prediction' into a CSV file ready for submission. 19 | """ 20 | 21 | import json 22 | import tensorflow as tf 23 | 24 | from builtins import range 25 | from tensorflow import app 26 | from tensorflow import flags 27 | from tensorflow import gfile 28 | from tensorflow import logging 29 | 30 | 31 | FLAGS = flags.FLAGS 32 | 33 | if __name__ == '__main__': 34 | 35 | flags.DEFINE_string( 36 | "json_prediction_files_pattern", None, 37 | "Pattern specifying the list of JSON files that the command " 38 | "'gcloud beta ml jobs submit prediction' outputs. 
These files are " 39 | "located in the output path of the prediction command and are prefixed " 40 | "with 'prediction.results'.") 41 | flags.DEFINE_string( 42 | "csv_output_file", None, 43 | "The file to save the predictions converted to the CSV format.") 44 | 45 | 46 | def get_csv_header(): 47 | return "VideoId,LabelConfidencePairs\n" 48 | 49 | def to_csv_row(json_data): 50 | 51 | video_id = json_data["video_id"] 52 | 53 | class_indexes = json_data["class_indexes"] 54 | predictions = json_data["predictions"] 55 | 56 | if isinstance(video_id, list): 57 | video_id = video_id[0] 58 | class_indexes = class_indexes[0] 59 | predictions = predictions[0] 60 | 61 | if len(class_indexes) != len(predictions): 62 | raise ValueError( 63 | "The number of indexes (%s) and predictions (%s) must be equal." 64 | % (len(class_indexes), len(predictions))) 65 | 66 | return (video_id.decode('utf-8') + "," + " ".join("%i %f" % 67 | (class_indexes[i], predictions[i]) 68 | for i in range(len(class_indexes))) + "\n") 69 | 70 | def main(unused_argv): 71 | logging.set_verbosity(tf.logging.INFO) 72 | 73 | if not FLAGS.json_prediction_files_pattern: 74 | raise ValueError( 75 | "The flag --json_prediction_files_pattern must be specified.") 76 | 77 | if not FLAGS.csv_output_file: 78 | raise ValueError("The flag --csv_output_file must be specified.") 79 | 80 | logging.info("Looking for prediction files with pattern: %s", 81 | FLAGS.json_prediction_files_pattern) 82 | 83 | file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern) 84 | logging.info("Found files: %s", file_paths) 85 | 86 | logging.info("Writing submission file to: %s", FLAGS.csv_output_file) 87 | with gfile.Open(FLAGS.csv_output_file, "w+") as output_file: 88 | output_file.write(get_csv_header()) 89 | 90 | for file_path in file_paths: 91 | logging.info("processing file: %s", file_path) 92 | 93 | with gfile.Open(file_path) as input_file: 94 | 95 | for line in input_file: 96 | json_data = json.loads(line) 97 | output_file.write(to_csv_row(json_data)) 98 | 99 | output_file.flush() 100 | logging.info("done") 101 | 102 | if __name__ == "__main__": 103 | app.run() 104 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Contains a collection of util functions for model construction. 16 | """ 17 | import numpy 18 | import tensorflow as tf 19 | from tensorflow import logging 20 | from tensorflow import flags 21 | import tensorflow.contrib.slim as slim 22 | 23 | def SampleRandomSequence(model_input, num_frames, num_samples): 24 | """Samples a random sequence of frames of size num_samples. 
25 | 26 | Args: 27 | model_input: A tensor of size batch_size x max_frames x feature_size 28 | num_frames: A tensor of size batch_size x 1 29 | num_samples: A scalar 30 | 31 | Returns: 32 | `model_input`: A tensor of size batch_size x num_samples x feature_size 33 | """ 34 | 35 | batch_size = tf.shape(model_input)[0] 36 | frame_index_offset = tf.tile( 37 | tf.expand_dims(tf.range(num_samples), 0), [batch_size, 1]) 38 | max_start_frame_index = tf.maximum(num_frames - num_samples, 0) 39 | start_frame_index = tf.cast( 40 | tf.multiply( 41 | tf.random_uniform([batch_size, 1]), 42 | tf.cast(max_start_frame_index + 1, tf.float32)), tf.int32) 43 | frame_index = tf.minimum(start_frame_index + frame_index_offset, 44 | tf.cast(num_frames - 1, tf.int32)) 45 | batch_index = tf.tile( 46 | tf.expand_dims(tf.range(batch_size), 1), [1, num_samples]) 47 | index = tf.stack([batch_index, frame_index], 2) 48 | return tf.gather_nd(model_input, index) 49 | 50 | 51 | def SampleRandomFrames(model_input, num_frames, num_samples): 52 | """Samples a random set of frames of size num_samples. 53 | 54 | Args: 55 | model_input: A tensor of size batch_size x max_frames x feature_size 56 | num_frames: A tensor of size batch_size x 1 57 | num_samples: A scalar 58 | 59 | Returns: 60 | `model_input`: A tensor of size batch_size x num_samples x feature_size 61 | """ 62 | batch_size = tf.shape(model_input)[0] 63 | frame_index = tf.cast( 64 | tf.multiply( 65 | tf.random_uniform([batch_size, num_samples]), 66 | tf.tile(tf.cast(num_frames, tf.float32), [1, num_samples])), tf.int32) 67 | batch_index = tf.tile( 68 | tf.expand_dims(tf.range(batch_size), 1), [1, num_samples]) 69 | index = tf.stack([batch_index, frame_index], 2) 70 | return tf.gather_nd(model_input, index) 71 | 72 | def FramePooling(frames, method, **unused_params): 73 | """Pools over the frames of a video. 74 | 75 | Args: 76 | frames: A tensor with shape [batch_size, num_frames, feature_size]. 77 | method: "average", "max", "attention", or "none". 78 | Returns: 79 | A tensor with shape [batch_size, feature_size] for average, max, or 80 | attention pooling. A tensor with shape [batch_size*num_frames, feature_size] 81 | for none pooling. 82 | 83 | Raises: 84 | ValueError: if method is other than "average", "max", "attention", or 85 | "none". 86 | """ 87 | if method == "average": 88 | return tf.reduce_mean(frames, 1) 89 | elif method == "max": 90 | return tf.reduce_max(frames, 1) 91 | elif method == "none": 92 | feature_size = frames.shape_as_list()[2] 93 | return tf.reshape(frames, [-1, feature_size]) 94 | else: 95 | raise ValueError("Unrecognized pooling method: %s" % method) 96 | -------------------------------------------------------------------------------- /feature_extractor/README.md: -------------------------------------------------------------------------------- 1 | # YouTube8M Feature Extractor 2 | This directory contains binary and library code that can extract YouTube8M 3 | features from images and videos. 4 | The code requires the Inception TensorFlow model ([tutorial](https://www.tensorflow.org/tutorials/image_recognition)) and our PCA matrix, as 5 | outlined in Section 3.3 of our [paper](https://arxiv.org/abs/1609.08675). 
The 6 | first time you use our code, it will **automatically** download the inception 7 | model (75 Megabytes, tensorflow [GraphDef proto](https://www.tensorflow.org/api_docs/python/tf/GraphDef), 8 | [download link](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz)) 9 | and the PCA matrix (25 Megabytes, Numpy arrays, 10 | [download link](http://data.yt8m.org/yt8m_pca.tgz)). 11 | 12 | ## Usage 13 | 14 | There are two ways to use this code: 15 | 16 | 1. Binary `extract_tfrecords_main.py` processes a CSV file of videos (and their 17 | labels) and outputs `tfrecord` file. Files created with this binary match 18 | the schema of YouTube-8M dataset files, and are therefore are compatible 19 | with our training starter code. You can also use the file for inference 20 | using your models that are pre-trained on YouTube-8M. 21 | 1. Library `feature_extractor.py` which can extract features from images. 22 | 23 | 24 | ### Using the Binary to create `tfrecords` from videos 25 | 26 | You can use binary `extract_tfrecords_main.py` to create `tfrecord` files. 27 | However, this binary assumes that you have OpenCV properly installed (see end 28 | of subsection). Assume that you have two videos `/path/to/vid1` and 29 | `/path/to/vid2`, respectively, with multi-integer labels of `(52, 3, 10)` and 30 | `(7, 67)`. To create `tfrecord` containing features and labels for those videos, 31 | you must first create a CSV file (e.g. on `/path/to/vid_dataset.csv`) with 32 | contents: 33 | 34 | /path/to/vid1,52;3;10 35 | /path/to/vid2,7;67 36 | 37 | Note that the CSV is comma-separated but the label-field is semi-colon separated 38 | to allow for multiple labels per video. 39 | 40 | Then, you can create the `tfrecord` by calling the binary: 41 | 42 | python extract_tfrecords_main.py --input /path/to/vid_dataset.csv \ 43 | --output_tfrecords_file /path/to/output.tfrecord 44 | 45 | Now, you can use the output file for training and/or inference using our starter 46 | code. 47 | 48 | `extract_tfrecords_main.py` requires OpenCV python bindings to be 49 | installed and linked with ffmpeg. In other words, running this command should 50 | print `True`: 51 | 52 | python -c 'import cv2; print cv2.VideoCapture().open("/path/to/some/video.mp4")' 53 | 54 | 55 | ### Using the library to extract features from images 56 | 57 | To extract our features from an image file `cropped_panda.jpg`, you can use 58 | this python code: 59 | 60 | ```python 61 | from PIL import Image 62 | import numpy 63 | 64 | # Instantiate extractor. Slow if called first time on your machine, as it 65 | # needs to download 100 MB. 66 | extractor = YouTube8MFeatureExtractor() 67 | 68 | image_file = os.path.join(extractor._model_dir, 'cropped_panda.jpg') 69 | 70 | im = numpy.array(Image.open(image_file)) 71 | features = extractor.extract_rgb_frame_features(im) 72 | ``` 73 | 74 | The constructor `extractor = YouTube8MFeatureExtractor()` will create a 75 | directory `~/yt8m/`, if it does not exist, and will download and untar the two 76 | model files (inception and PCA matrix). If you prefer, you can point our 77 | extractor to another directory as: 78 | 79 | ```python 80 | extractor = YouTube8MFeatureExtractor(model_dir="/path/to/yt8m_files") 81 | ``` 82 | 83 | You can also pre-populate your custom `"/path/to/yt8m_files"` by manually 84 | downloading (e.g. 
using `wget`) the URLs and un-tarring them, for example: 85 | 86 | ```bash 87 | mkdir -p /path/to/yt8m_files 88 | cd /path/to/yt8m_files 89 | 90 | wget http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 91 | wget http://data.yt8m.org/yt8m_pca.tgz 92 | 93 | tar zxvf inception-2015-12-05.tgz 94 | tar zxvf yt8m_pca.tgz 95 | ``` 96 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Provides definitions for non-regularized training or test losses.""" 16 | 17 | import tensorflow as tf 18 | 19 | 20 | class BaseLoss(object): 21 | """Inherit from this class when implementing new losses.""" 22 | 23 | def calculate_loss(self, unused_predictions, unused_labels, **unused_params): 24 | """Calculates the average loss of the examples in a mini-batch. 25 | 26 | Args: 27 | unused_predictions: a 2-d tensor storing the prediction scores, in which 28 | each row represents a sample in the mini-batch and each column 29 | represents a class. 30 | unused_labels: a 2-d tensor storing the labels, which has the same shape 31 | as the unused_predictions. The labels must be in the range of 0 and 1. 32 | unused_params: loss specific parameters. 33 | 34 | Returns: 35 | A scalar loss tensor. 36 | """ 37 | raise NotImplementedError() 38 | 39 | 40 | class CrossEntropyLoss(BaseLoss): 41 | """Calculate the cross entropy loss between the predictions and labels. 42 | """ 43 | 44 | def calculate_loss(self, predictions, labels, **unused_params): 45 | with tf.name_scope("loss_xent"): 46 | epsilon = 10e-6 47 | float_labels = tf.cast(labels, tf.float32) 48 | cross_entropy_loss = float_labels * tf.log(predictions + epsilon) + ( 49 | 1 - float_labels) * tf.log(1 - predictions + epsilon) 50 | cross_entropy_loss = tf.negative(cross_entropy_loss) 51 | return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1)) 52 | 53 | 54 | class HingeLoss(BaseLoss): 55 | """Calculate the hinge loss between the predictions and labels. 56 | 57 | Note the subgradient is used in the backpropagation, and thus the optimization 58 | may converge slower. The predictions trained by the hinge loss are between -1 59 | and +1. 
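  For labels in {0, 1}, the implementation below first maps them to sign
  labels in {-1, +1} and then averages, over the batch, the per-example sum of
  max(0, b - sign_label * prediction) across classes.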
60 | """ 61 | 62 | def calculate_loss(self, predictions, labels, b=1.0, **unused_params): 63 | with tf.name_scope("loss_hinge"): 64 | float_labels = tf.cast(labels, tf.float32) 65 | all_zeros = tf.zeros(tf.shape(float_labels), dtype=tf.float32) 66 | all_ones = tf.ones(tf.shape(float_labels), dtype=tf.float32) 67 | sign_labels = tf.subtract(tf.scalar_mul(2, float_labels), all_ones) 68 | hinge_loss = tf.maximum( 69 | all_zeros, tf.scalar_mul(b, all_ones) - sign_labels * predictions) 70 | return tf.reduce_mean(tf.reduce_sum(hinge_loss, 1)) 71 | 72 | 73 | class SoftmaxLoss(BaseLoss): 74 | """Calculate the softmax loss between the predictions and labels. 75 | 76 | The function calculates the loss in the following way: first we feed the 77 | predictions to the softmax activation function and then we calculate 78 | the minus linear dot product between the logged softmax activations and the 79 | normalized ground truth label. 80 | 81 | It is an extension to the one-hot label. It allows for more than one positive 82 | labels for each sample. 83 | """ 84 | 85 | def calculate_loss(self, predictions, labels, **unused_params): 86 | with tf.name_scope("loss_softmax"): 87 | epsilon = 10e-8 88 | float_labels = tf.cast(labels, tf.float32) 89 | # l1 normalization (labels are no less than 0) 90 | label_rowsum = tf.maximum( 91 | tf.reduce_sum(float_labels, 1, keep_dims=True), 92 | epsilon) 93 | norm_float_labels = tf.div(float_labels, label_rowsum) 94 | softmax_outputs = tf.nn.softmax(predictions) 95 | softmax_loss = tf.negative(tf.reduce_sum( 96 | tf.multiply(norm_float_labels, tf.log(softmax_outputs)), 1)) 97 | return tf.reduce_mean(softmax_loss) 98 | -------------------------------------------------------------------------------- /video_level_models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Contains model definitions.""" 16 | import math 17 | 18 | import models 19 | import tensorflow as tf 20 | import utils 21 | 22 | from tensorflow import flags 23 | import tensorflow.contrib.slim as slim 24 | 25 | FLAGS = flags.FLAGS 26 | flags.DEFINE_integer( 27 | "moe_num_mixtures", 2, 28 | "The number of mixtures (excluding the dummy 'expert') used for MoeModel.") 29 | 30 | class LogisticModel(models.BaseModel): 31 | """Logistic model with L2 regularization.""" 32 | 33 | def create_model(self, model_input, vocab_size, l2_penalty=1e-8, **unused_params): 34 | """Creates a logistic model. 35 | 36 | Args: 37 | model_input: 'batch' x 'num_features' matrix of input features. 38 | vocab_size: The number of classes in the dataset. 39 | 40 | Returns: 41 | A dictionary with a tensor containing the probability predictions of the 42 | model in the 'predictions' key. 
The dimensions of the tensor are 43 | batch_size x num_classes.""" 44 | output = slim.fully_connected( 45 | model_input, vocab_size, activation_fn=tf.nn.sigmoid, 46 | weights_regularizer=slim.l2_regularizer(l2_penalty)) 47 | return {"predictions": output} 48 | 49 | class MoeModel(models.BaseModel): 50 | """A softmax over a mixture of logistic models (with L2 regularization).""" 51 | 52 | def create_model(self, 53 | model_input, 54 | vocab_size, 55 | num_mixtures=None, 56 | l2_penalty=1e-8, 57 | **unused_params): 58 | """Creates a Mixture of (Logistic) Experts model. 59 | 60 | The model consists of a per-class softmax distribution over a 61 | configurable number of logistic classifiers. One of the classifiers in the 62 | mixture is not trained, and always predicts 0. 63 | 64 | Args: 65 | model_input: 'batch_size' x 'num_features' matrix of input features. 66 | vocab_size: The number of classes in the dataset. 67 | num_mixtures: The number of mixtures (excluding a dummy 'expert' that 68 | always predicts the non-existence of an entity). 69 | l2_penalty: How much to penalize the squared magnitudes of parameter 70 | values. 71 | Returns: 72 | A dictionary with a tensor containing the probability predictions of the 73 | model in the 'predictions' key. The dimensions of the tensor are 74 | batch_size x num_classes. 75 | """ 76 | num_mixtures = num_mixtures or FLAGS.moe_num_mixtures 77 | 78 | gate_activations = slim.fully_connected( 79 | model_input, 80 | vocab_size * (num_mixtures + 1), 81 | activation_fn=None, 82 | biases_initializer=None, 83 | weights_regularizer=slim.l2_regularizer(l2_penalty), 84 | scope="gates") 85 | expert_activations = slim.fully_connected( 86 | model_input, 87 | vocab_size * num_mixtures, 88 | activation_fn=None, 89 | weights_regularizer=slim.l2_regularizer(l2_penalty), 90 | scope="experts") 91 | 92 | gating_distribution = tf.nn.softmax(tf.reshape( 93 | gate_activations, 94 | [-1, num_mixtures + 1])) # (Batch * #Labels) x (num_mixtures + 1) 95 | expert_distribution = tf.nn.sigmoid(tf.reshape( 96 | expert_activations, 97 | [-1, num_mixtures])) # (Batch * #Labels) x num_mixtures 98 | 99 | final_probabilities_by_class_and_batch = tf.reduce_sum( 100 | gating_distribution[:, :num_mixtures] * expert_distribution, 1) 101 | final_probabilities = tf.reshape(final_probabilities_by_class_and_batch, 102 | [-1, vocab_size]) 103 | return {"predictions": final_probabilities} 104 | -------------------------------------------------------------------------------- /mean_average_precision_calculator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Calculate the mean average precision. 16 | 17 | It provides an interface for calculating mean average precision 18 | for an entire list or the top-n ranked items. 
19 | 20 | Example usages: 21 | We first call the function accumulate many times to process parts of the ranked 22 | list. After processing all the parts, we call peek_map_at_n 23 | to calculate the mean average precision. 24 | 25 | ``` 26 | import random 27 | 28 | p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)]) 29 | a = np.array([[random.choice([0, 1]) for _ in xrange(50)] 30 | for _ in xrange(1000)]) 31 | 32 | # mean average precision for 50 classes. 33 | calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator( 34 | num_class=50) 35 | calculator.accumulate(p, a) 36 | aps = calculator.peek_map_at_n() 37 | ``` 38 | """ 39 | 40 | import numpy 41 | import average_precision_calculator 42 | 43 | 44 | class MeanAveragePrecisionCalculator(object): 45 | """This class is to calculate mean average precision. 46 | """ 47 | 48 | def __init__(self, num_class): 49 | """Construct a calculator to calculate the (macro) average precision. 50 | 51 | Args: 52 | num_class: A positive Integer specifying the number of classes. 53 | top_n_array: A list of positive integers specifying the top n for each 54 | class. The top n in each class will be used to calculate its average 55 | precision at n. 56 | The size of the array must be num_class. 57 | 58 | Raises: 59 | ValueError: An error occurred when num_class is not a positive integer; 60 | or the top_n_array is not a list of positive integers. 61 | """ 62 | if not isinstance(num_class, int) or num_class <= 1: 63 | raise ValueError("num_class must be a positive integer.") 64 | 65 | self._ap_calculators = [] # member of AveragePrecisionCalculator 66 | self._num_class = num_class # total number of classes 67 | for i in range(num_class): 68 | self._ap_calculators.append( 69 | average_precision_calculator.AveragePrecisionCalculator()) 70 | 71 | def accumulate(self, predictions, actuals, num_positives=None): 72 | """Accumulate the predictions and their ground truth labels. 73 | 74 | Args: 75 | predictions: A list of lists storing the prediction scores. The outer 76 | dimension corresponds to classes. 77 | actuals: A list of lists storing the ground truth labels. The dimensions 78 | should correspond to the predictions input. Any value 79 | larger than 0 will be treated as positives, otherwise as negatives. 80 | num_positives: If provided, it is a list of numbers representing the 81 | number of true positives for each class. If not provided, the number of 82 | true positives will be inferred from the 'actuals' array. 83 | 84 | Raises: 85 | ValueError: An error occurred when the shape of predictions and actuals 86 | does not match. 87 | """ 88 | if not num_positives: 89 | num_positives = [None for i in predictions.shape[1]] 90 | 91 | calculators = self._ap_calculators 92 | for i in range(len(predictions)): 93 | calculators[i].accumulate(predictions[i], actuals[i], num_positives[i]) 94 | 95 | def clear(self): 96 | for calculator in self._ap_calculators: 97 | calculator.clear() 98 | 99 | def is_empty(self): 100 | return ([calculator.heap_size for calculator in self._ap_calculators] == 101 | [0 for _ in range(self._num_class)]) 102 | 103 | def peek_map_at_n(self): 104 | """Peek the non-interpolated mean average precision at n. 105 | 106 | Returns: 107 | An array of non-interpolated average precision at n (default 0) for each 108 | class. 
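      The macro mean average precision is simply the mean of this list, e.g.
      numpy.mean(calculator.peek_map_at_n()).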
109 | """ 110 | aps = [self._ap_calculators[i].peek_ap_at_n() 111 | for i in range(self._num_class)] 112 | return aps 113 | -------------------------------------------------------------------------------- /export_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities to export a model for batch prediction.""" 15 | 16 | import tensorflow as tf 17 | import tensorflow.contrib.slim as slim 18 | 19 | from tensorflow.python.saved_model import builder as saved_model_builder 20 | from tensorflow.python.saved_model import signature_constants 21 | from tensorflow.python.saved_model import signature_def_utils 22 | from tensorflow.python.saved_model import tag_constants 23 | from tensorflow.python.saved_model import utils as saved_model_utils 24 | 25 | _TOP_PREDICTIONS_IN_OUTPUT = 20 26 | 27 | class ModelExporter(object): 28 | 29 | def __init__(self, frame_features, model, reader): 30 | self.frame_features = frame_features 31 | self.model = model 32 | self.reader = reader 33 | 34 | with tf.Graph().as_default() as graph: 35 | self.inputs, self.outputs = self.build_inputs_and_outputs() 36 | self.graph = graph 37 | self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True) 38 | 39 | def export_model(self, model_dir, global_step_val, last_checkpoint): 40 | """Exports the model so that it can used for batch predictions.""" 41 | 42 | with self.graph.as_default(): 43 | with tf.Session() as session: 44 | session.run(tf.global_variables_initializer()) 45 | self.saver.restore(session, last_checkpoint) 46 | 47 | signature = signature_def_utils.build_signature_def( 48 | inputs=self.inputs, 49 | outputs=self.outputs, 50 | method_name=signature_constants.PREDICT_METHOD_NAME) 51 | 52 | signature_map = {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 53 | signature} 54 | 55 | model_builder = saved_model_builder.SavedModelBuilder(model_dir) 56 | model_builder.add_meta_graph_and_variables(session, 57 | tags=[tag_constants.SERVING], 58 | signature_def_map=signature_map, 59 | clear_devices=True) 60 | model_builder.save() 61 | 62 | def build_inputs_and_outputs(self): 63 | if self.frame_features: 64 | serialized_examples = tf.placeholder(tf.string, shape=(None,)) 65 | 66 | fn = lambda x: self.build_prediction_graph(x) 67 | video_id_output, top_indices_output, top_predictions_output = ( 68 | tf.map_fn(fn, serialized_examples, 69 | dtype=(tf.string, tf.int32, tf.float32))) 70 | 71 | else: 72 | serialized_examples = tf.placeholder(tf.string, shape=(None,)) 73 | 74 | video_id_output, top_indices_output, top_predictions_output = ( 75 | self.build_prediction_graph(serialized_examples)) 76 | 77 | inputs = {"example_bytes": 78 | saved_model_utils.build_tensor_info(serialized_examples)} 79 | 80 | outputs = { 81 | "video_id": saved_model_utils.build_tensor_info(video_id_output), 82 | "class_indexes": 
saved_model_utils.build_tensor_info(top_indices_output), 83 | "predictions": saved_model_utils.build_tensor_info(top_predictions_output)} 84 | 85 | return inputs, outputs 86 | 87 | def build_prediction_graph(self, serialized_examples): 88 | video_id, model_input_raw, labels_batch, num_frames = ( 89 | self.reader.prepare_serialized_examples(serialized_examples)) 90 | 91 | feature_dim = len(model_input_raw.get_shape()) - 1 92 | model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) 93 | 94 | with tf.variable_scope("tower"): 95 | result = self.model.create_model( 96 | model_input, 97 | num_frames=num_frames, 98 | vocab_size=self.reader.num_classes, 99 | labels=labels_batch, 100 | is_training=False) 101 | 102 | for variable in slim.get_model_variables(): 103 | tf.summary.histogram(variable.op.name, variable) 104 | 105 | predictions = result["predictions"] 106 | 107 | top_predictions, top_indices = tf.nn.top_k(predictions, 108 | _TOP_PREDICTIONS_IN_OUTPUT) 109 | return video_id, top_indices, top_predictions 110 | -------------------------------------------------------------------------------- /feature_extractor/feature_extractor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Facilitates extracting YouTube8M features from RGB images.""" 15 | 16 | import os 17 | import sys 18 | import tarfile 19 | import numpy 20 | from six.moves import urllib 21 | import tensorflow as tf 22 | 23 | INCEPTION_TF_GRAPH = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 24 | YT8M_PCA_MAT = 'http://data.yt8m.org/yt8m_pca.tgz' 25 | MODEL_DIR = os.path.join(os.getenv('HOME'), 'yt8m') 26 | 27 | 28 | class YouTube8MFeatureExtractor(object): 29 | """Extracts YouTube8M features for RGB frames. 30 | 31 | First time constructing this class will create directory `yt8m` inside your 32 | home directory, and will download inception model (85 MB) and YouTube8M PCA 33 | matrix (15 MB). If you want to use another directory, then pass it to argument 34 | `model_dir` of constructor. 35 | 36 | If the model_dir exist and contains the necessary files, then files will be 37 | re-used without download. 38 | 39 | Usage Example: 40 | 41 | from PIL import Image 42 | import numpy 43 | 44 | # Instantiate extractor. Slow if called first time on your machine, as it 45 | # needs to download 100 MB. 46 | extractor = YouTube8MFeatureExtractor() 47 | 48 | image_file = os.path.join(extractor._model_dir, 'cropped_panda.jpg') 49 | 50 | im = numpy.array(Image.open(image_file)) 51 | features = extractor.extract_rgb_frame_features(im) 52 | 53 | ** Note: OpenCV reverses the order of channels (i.e. orders channels as BGR 54 | instead of RGB). If you are using OpenCV, then you must do: 55 | 56 | im = im[:, :, ::-1] # Reverses order on last (i.e. channel) dimension. 
57 | 58 | then call `extractor.extract_rgb_frame_features(im)` 59 | """ 60 | 61 | def __init__(self, model_dir=MODEL_DIR): 62 | # Create MODEL_DIR if not created. 63 | self._model_dir = model_dir 64 | if not os.path.exists(model_dir): 65 | os.makedirs(model_dir) 66 | 67 | # Load Inception Network 68 | download_path = self._maybe_download(INCEPTION_TF_GRAPH) 69 | inception_proto_file = os.path.join( 70 | self._model_dir, 'classify_image_graph_def.pb') 71 | if not os.path.exists(inception_proto_file): 72 | tarfile.open(download_path, 'r:gz').extractall(model_dir) 73 | self._load_inception(inception_proto_file) 74 | 75 | # Load PCA Matrix. 76 | download_path = self._maybe_download(YT8M_PCA_MAT) 77 | pca_mean = os.path.join(self._model_dir, 'mean.npy') 78 | if not os.path.exists(pca_mean): 79 | tarfile.open(download_path, 'r:gz').extractall(model_dir) 80 | self._load_pca() 81 | 82 | def extract_rgb_frame_features(self, frame_rgb, apply_pca=True): 83 | """Applies the YouTube8M feature extraction over an RGB frame. 84 | 85 | This passes `frame_rgb` to inception3 model, extracting hidden layer 86 | activations and passing it to the YouTube8M PCA transformation. 87 | 88 | Args: 89 | frame_rgb: numpy array of uint8 with shape (height, width, channels) where 90 | channels must be 3 (RGB), and height and weight can be anything, as the 91 | inception model will resize. 92 | apply_pca: If not set, PCA transformation will be skipped. 93 | 94 | Returns: 95 | Output of inception from `frame_rgb` (2048-D) and optionally passed into 96 | YouTube8M PCA transformation (1024-D). 97 | """ 98 | assert len(frame_rgb.shape) == 3 99 | assert frame_rgb.shape[2] == 3 # 3 channels (R, G, B) 100 | with self._inception_graph.as_default(): 101 | frame_features = self.session.run('pool_3/_reshape:0', 102 | feed_dict={'DecodeJpeg:0': frame_rgb}) 103 | frame_features = frame_features[0] # Unbatch. 104 | 105 | if apply_pca: 106 | frame_features = self.apply_pca(frame_features) 107 | 108 | return frame_features 109 | 110 | def apply_pca(self, frame_features): 111 | """Applies the YouTube8M PCA Transformation over `frame_features`. 112 | 113 | Args: 114 | frame_features: numpy array of floats, 2048 dimensional vector. 115 | 116 | Returns: 117 | 1024 dimensional vector as a numpy array. 118 | """ 119 | # Subtract mean 120 | feats = frame_features - self.pca_mean 121 | 122 | # Multiply by eigenvectors. 
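    # Given how _load_pca builds the matrices, pca_eigenvecs is expected to
    # have shape (2048, 1024): the (1, 2048) row vector below is projected to
    # (1, 1024) and then flattened to a 1024-D vector.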
123 | feats = feats.reshape((1, 2048)).dot(self.pca_eigenvecs).reshape((1024,)) 124 | 125 | # Whiten 126 | feats /= numpy.sqrt(self.pca_eigenvals + 1e-4) 127 | return feats 128 | 129 | def _maybe_download(self, url): 130 | """Downloads `url` if not in `_model_dir`.""" 131 | filename = os.path.basename(url) 132 | download_path = os.path.join(self._model_dir, filename) 133 | if os.path.exists(download_path): 134 | return download_path 135 | 136 | def _progress(count, block_size, total_size): 137 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 138 | filename, float(count * block_size) / float(total_size) * 100.0)) 139 | sys.stdout.flush() 140 | urllib.request.urlretrieve(url, download_path, _progress) 141 | statinfo = os.stat(download_path) 142 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 143 | return download_path 144 | 145 | def _load_inception(self, proto_file): 146 | graph_def = tf.GraphDef.FromString(open(proto_file, 'rb').read()) 147 | self._inception_graph = tf.Graph() 148 | with self._inception_graph.as_default(): 149 | _ = tf.import_graph_def(graph_def, name='') 150 | self.session = tf.Session() 151 | 152 | def _load_pca(self): 153 | self.pca_mean = numpy.load( 154 | os.path.join(self._model_dir, 'mean.npy'))[:, 0] 155 | self.pca_eigenvals = numpy.load( 156 | os.path.join(self._model_dir, 'eigenvals.npy'))[:1024, 0] 157 | self.pca_eigenvecs = numpy.load( 158 | os.path.join(self._model_dir, 'eigenvecs.npy')).T[:, :1024] 159 | -------------------------------------------------------------------------------- /feature_extractor/extract_tfrecords_main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Produces tfrecord files similar to the YouTube-8M dataset. 15 | 16 | It processes a CSV file containing lines like ",", where 17 | must be a path of a video, and must be an integer list 18 | joined with semi-colon ";". It processes all videos and outputs tfrecord file 19 | onto --output_tfrecords_file. 20 | 21 | It assumes that you have OpenCV installed and properly linked with ffmpeg (i.e. 22 | function `cv2.VideoCapture().open('/path/to/some/video')` should return True). 23 | 24 | The binary only processes the video stream (images) and not the audio stream. 25 | """ 26 | 27 | import csv 28 | import os 29 | import sys 30 | 31 | import cv2 32 | import feature_extractor 33 | import numpy 34 | import tensorflow as tf 35 | from tensorflow import app 36 | from tensorflow import flags 37 | 38 | FLAGS = flags.FLAGS 39 | 40 | # In OpenCV3.X, this is available as cv2.CAP_PROP_POS_MSEC 41 | # In OpenCV2.X, this is available as cv2.cv.CV_CAP_PROP_POS_MSEC 42 | CAP_PROP_POS_MSEC = 0 43 | 44 | if __name__ == '__main__': 45 | # Required flags for input and output. 
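  # A typical invocation might look like (paths are placeholders):
  #   python extract_tfrecords_main.py \
  #     --input_videos_csv=/path/to/vid_dataset.csv \
  #     --output_tfrecords_file=/path/to/output.tfrecord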
46 | flags.DEFINE_string('output_tfrecords_file', None, 47 | 'File containing tfrecords will be written at this path.') 48 | flags.DEFINE_string('input_videos_csv', None, 49 | 'CSV file with lines ",", where ' 50 | ' must be a path of a video and ' 51 | 'must be an integer list joined with semi-colon ";"') 52 | # Optional flags. 53 | flags.DEFINE_string('model_dir', os.path.join(os.getenv('HOME'), 'yt8m'), 54 | 'Directory to store model files. It defaults to ~/yt8m') 55 | 56 | # The following flags are set to match the YouTube-8M dataset format. 57 | flags.DEFINE_integer('frames_per_second', 1, 58 | 'This many frames per second will be processed') 59 | flags.DEFINE_string('labels_feature_key', 'labels', 60 | 'Labels will be written to context feature with this ' 61 | 'key, as int64 list feature.') 62 | flags.DEFINE_string('image_feature_key', 'rgb', 63 | 'Image features will be written to sequence feature with ' 64 | 'this key, as bytes list feature, with only one entry, ' 65 | 'containing quantized feature string.') 66 | flags.DEFINE_string('video_file_key_feature_key', 'video_id', 67 | 'Input will be written to context feature ' 68 | 'with this key, as bytes list feature, with only one ' 69 | 'entry, containing the file path of the video. This ' 70 | 'can be used for debugging but not for training or eval.') 71 | flags.DEFINE_boolean('insert_zero_audio_features', True, 72 | 'If set, inserts features with name "audio" to be 128-D ' 73 | 'zero vectors. This allows you to use YouTube-8M ' 74 | 'pre-trained model.') 75 | 76 | 77 | def frame_iterator(filename, every_ms=1000, max_num_frames=300): 78 | """Uses OpenCV to iterate over all frames of filename at a given frequency. 79 | 80 | Args: 81 | filename: Path to video file (e.g. mp4) 82 | every_ms: The duration (in milliseconds) to skip between frames. 83 | max_num_frames: Maximum number of frames to process, taken from the 84 | beginning of the video. 85 | 86 | Yields: 87 | RGB frame with shape (image height, image width, channels) 88 | """ 89 | video_capture = cv2.VideoCapture() 90 | if not video_capture.open(filename): 91 | print >> sys.stderr, 'Error: Cannot open video file ' + filename 92 | return 93 | last_ts = -99999 # The timestamp of last retrieved frame. 
94 | num_retrieved = 0 95 | 96 | while num_retrieved < max_num_frames: 97 | # Skip frames 98 | while video_capture.get(CAP_PROP_POS_MSEC) < every_ms + last_ts: 99 | if not video_capture.read()[0]: 100 | return 101 | 102 | last_ts = video_capture.get(CAP_PROP_POS_MSEC) 103 | has_frames, frame = video_capture.read() 104 | if not has_frames: 105 | break 106 | yield frame 107 | num_retrieved += 1 108 | 109 | 110 | def _int64_list_feature(int64_list): 111 | return tf.train.Feature(int64_list=tf.train.Int64List(value=int64_list)) 112 | 113 | 114 | def _bytes_feature(value): 115 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 116 | 117 | 118 | def _make_bytes(int_array): 119 | if bytes == str: # Python2 120 | return ''.join(map(chr, int_array)) 121 | else: 122 | return bytes(int_array) 123 | 124 | 125 | def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0): 126 | """Quantizes float32 `features` into string.""" 127 | assert features.dtype == 'float32' 128 | assert len(features.shape) == 1 # 1-D array 129 | features = numpy.clip(features, min_quantized_value, max_quantized_value) 130 | quantize_range = max_quantized_value - min_quantized_value 131 | features = (features - min_quantized_value) * (255.0 / quantize_range) 132 | features = [int(round(f)) for f in features] 133 | 134 | return _make_bytes(features) 135 | 136 | 137 | def main(unused_argv): 138 | extractor = feature_extractor.YouTube8MFeatureExtractor(FLAGS.model_dir) 139 | writer = tf.python_io.TFRecordWriter(FLAGS.output_tfrecords_file) 140 | total_written = 0 141 | total_error = 0 142 | for video_file, labels in csv.reader(open(FLAGS.input_videos_csv)): 143 | rgb_features = [] 144 | for rgb in frame_iterator( 145 | video_file, every_ms=1000.0/FLAGS.frames_per_second): 146 | features = extractor.extract_rgb_frame_features(rgb[:, :, ::-1]) 147 | rgb_features.append(_bytes_feature(quantize(features))) 148 | 149 | if not rgb_features: 150 | print >> sys.stderr, 'Could not get features for ' + video_file 151 | total_error += 1 152 | continue 153 | 154 | # Create SequenceExample proto and write to output. 155 | feature_list = { 156 | FLAGS.image_feature_key: tf.train.FeatureList(feature=rgb_features), 157 | } 158 | if FLAGS.insert_zero_audio_features: 159 | feature_list['audio'] = tf.train.FeatureList( 160 | feature=[_bytes_feature(_make_bytes([0] * 128))] * len(rgb_features)) 161 | 162 | example = tf.train.SequenceExample( 163 | context=tf.train.Features(feature={ 164 | FLAGS.labels_feature_key: 165 | _int64_list_feature(sorted(map(int, labels.split(';')))), 166 | FLAGS.video_file_key_feature_key: 167 | _bytes_feature(_make_bytes(map(ord, video_file))), 168 | }), 169 | feature_lists=tf.train.FeatureLists(feature_list=feature_list)) 170 | writer.write(example.SerializeToString()) 171 | total_written += 1 172 | 173 | writer.close() 174 | print('Successfully encoded %i out of %i videos' % ( 175 | total_written, total_written + total_error)) 176 | 177 | 178 | if __name__ == '__main__': 179 | app.run(main) 180 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Contains a collection of util functions for training and evaluating. 16 | """ 17 | 18 | import numpy 19 | import tensorflow as tf 20 | from tensorflow import logging 21 | 22 | try: 23 | xrange # Python 2 24 | except NameError: 25 | xrange = range # Python 3 26 | 27 | 28 | def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2): 29 | """Dequantize the feature from the byte format to the float format. 30 | 31 | Args: 32 | feat_vector: the input 1-d vector. 33 | max_quantized_value: the maximum of the quantized value. 34 | min_quantized_value: the minimum of the quantized value. 35 | 36 | Returns: 37 | A float vector which has the same shape as feat_vector. 38 | """ 39 | assert max_quantized_value > min_quantized_value 40 | quantized_range = max_quantized_value - min_quantized_value 41 | scalar = quantized_range / 255.0 42 | bias = (quantized_range / 512.0) + min_quantized_value 43 | return feat_vector * scalar + bias 44 | 45 | 46 | def MakeSummary(name, value): 47 | """Creates a tf.Summary proto with the given name and value.""" 48 | summary = tf.Summary() 49 | val = summary.value.add() 50 | val.tag = str(name) 51 | val.simple_value = float(value) 52 | return summary 53 | 54 | 55 | def AddGlobalStepSummary(summary_writer, 56 | global_step_val, 57 | global_step_info_dict, 58 | summary_scope="Eval"): 59 | """Add the global_step summary to the Tensorboard. 60 | 61 | Args: 62 | summary_writer: Tensorflow summary_writer. 63 | global_step_val: a int value of the global step. 64 | global_step_info_dict: a dictionary of the evaluation metrics calculated for 65 | a mini-batch. 66 | summary_scope: Train or Eval. 67 | 68 | Returns: 69 | A string of this global_step summary 70 | """ 71 | this_hit_at_one = global_step_info_dict["hit_at_one"] 72 | this_perr = global_step_info_dict["perr"] 73 | this_loss = global_step_info_dict["loss"] 74 | examples_per_second = global_step_info_dict.get("examples_per_second", -1) 75 | 76 | summary_writer.add_summary( 77 | MakeSummary("GlobalStep/" + summary_scope + "_Hit@1", this_hit_at_one), 78 | global_step_val) 79 | summary_writer.add_summary( 80 | MakeSummary("GlobalStep/" + summary_scope + "_Perr", this_perr), 81 | global_step_val) 82 | summary_writer.add_summary( 83 | MakeSummary("GlobalStep/" + summary_scope + "_Loss", this_loss), 84 | global_step_val) 85 | 86 | if examples_per_second != -1: 87 | summary_writer.add_summary( 88 | MakeSummary("GlobalStep/" + summary_scope + "_Example_Second", 89 | examples_per_second), global_step_val) 90 | 91 | summary_writer.flush() 92 | info = ("global_step {0} | Batch Hit@1: {1:.3f} | Batch PERR: {2:.3f} | Batch Loss: {3:.3f} " 93 | "| Examples_per_sec: {4:.3f}").format( 94 | global_step_val, this_hit_at_one, this_perr, this_loss, 95 | examples_per_second) 96 | return info 97 | 98 | 99 | def AddEpochSummary(summary_writer, 100 | global_step_val, 101 | epoch_info_dict, 102 | summary_scope="Eval"): 103 | """Add the epoch summary to the Tensorboard. 104 | 105 | Args: 106 | summary_writer: Tensorflow summary_writer. 
107 | global_step_val: a int value of the global step. 108 | epoch_info_dict: a dictionary of the evaluation metrics calculated for the 109 | whole epoch. 110 | summary_scope: Train or Eval. 111 | 112 | Returns: 113 | A string of this global_step summary 114 | """ 115 | epoch_id = epoch_info_dict["epoch_id"] 116 | avg_hit_at_one = epoch_info_dict["avg_hit_at_one"] 117 | avg_perr = epoch_info_dict["avg_perr"] 118 | avg_loss = epoch_info_dict["avg_loss"] 119 | aps = epoch_info_dict["aps"] 120 | gap = epoch_info_dict["gap"] 121 | mean_ap = numpy.mean(aps) 122 | 123 | summary_writer.add_summary( 124 | MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one), 125 | global_step_val) 126 | summary_writer.add_summary( 127 | MakeSummary("Epoch/" + summary_scope + "_Avg_Perr", avg_perr), 128 | global_step_val) 129 | summary_writer.add_summary( 130 | MakeSummary("Epoch/" + summary_scope + "_Avg_Loss", avg_loss), 131 | global_step_val) 132 | summary_writer.add_summary( 133 | MakeSummary("Epoch/" + summary_scope + "_MAP", mean_ap), 134 | global_step_val) 135 | summary_writer.add_summary( 136 | MakeSummary("Epoch/" + summary_scope + "_GAP", gap), 137 | global_step_val) 138 | summary_writer.flush() 139 | 140 | info = ("epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} " 141 | "| MAP: {3:.3f} | GAP: {4:.3f} | Avg_Loss: {5:3f}").format( 142 | epoch_id, avg_hit_at_one, avg_perr, mean_ap, gap, avg_loss) 143 | return info 144 | 145 | def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes): 146 | """Extract the list of feature names and the dimensionality of each feature 147 | from string of comma separated values. 148 | 149 | Args: 150 | feature_names: string containing comma separated list of feature names 151 | feature_sizes: string containing comma separated list of feature sizes 152 | 153 | Returns: 154 | List of the feature names and list of the dimensionality of each feature. 155 | Elements in the first/second list are strings/integers. 156 | """ 157 | list_of_feature_names = [ 158 | feature_names.strip() for feature_names in feature_names.split(',')] 159 | list_of_feature_sizes = [ 160 | int(feature_sizes) for feature_sizes in feature_sizes.split(',')] 161 | if len(list_of_feature_names) != len(list_of_feature_sizes): 162 | logging.error("length of the feature names (=" + 163 | str(len(list_of_feature_names)) + ") != length of feature " 164 | "sizes (=" + str(len(list_of_feature_sizes)) + ")") 165 | 166 | return list_of_feature_names, list_of_feature_sizes 167 | 168 | def clip_gradient_norms(gradients_to_variables, max_norm): 169 | """Clips the gradients by the given value. 170 | 171 | Args: 172 | gradients_to_variables: A list of gradient to variable pairs (tuples). 173 | max_norm: the maximum norm value. 174 | 175 | Returns: 176 | A list of clipped gradient to variable pairs. 177 | """ 178 | clipped_grads_and_vars = [] 179 | for grad, var in gradients_to_variables: 180 | if grad is not None: 181 | if isinstance(grad, tf.IndexedSlices): 182 | tmp = tf.clip_by_norm(grad.values, max_norm) 183 | grad = tf.IndexedSlices(tmp, grad.indices, grad.dense_shape) 184 | else: 185 | grad = tf.clip_by_norm(grad, max_norm) 186 | clipped_grads_and_vars.append((grad, var)) 187 | return clipped_grads_and_vars 188 | 189 | def combine_gradients(tower_grads): 190 | """Calculate the combined gradient for each shared variable across all towers. 191 | 192 | Note that this function provides a synchronization point across all towers. 
193 | 194 | Args: 195 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 196 | is over individual gradients. The inner list is over the gradient 197 | calculation for each tower. 198 | Returns: 199 | List of pairs of (gradient, variable) where the gradient has been summed 200 | across all towers. 201 | """ 202 | filtered_grads = [[x for x in grad_list if x[0] is not None] for grad_list in tower_grads] 203 | final_grads = [] 204 | for i in xrange(len(filtered_grads[0])): 205 | grads = [filtered_grads[t][i] for t in xrange(len(filtered_grads))] 206 | grad = tf.stack([x[0] for x in grads], 0) 207 | grad = tf.reduce_sum(grad, 0) 208 | final_grads.append((grad, filtered_grads[0][i][1],)) 209 | 210 | return final_grads 211 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Binary for generating predictions over a set of videos.""" 16 | 17 | import os 18 | import time 19 | 20 | import numpy 21 | import tensorflow as tf 22 | 23 | from tensorflow import app 24 | from tensorflow import flags 25 | from tensorflow import gfile 26 | from tensorflow import logging 27 | 28 | import eval_util 29 | import losses 30 | import readers 31 | import utils 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | if __name__ == '__main__': 36 | flags.DEFINE_string("train_dir", "/tmp/yt8m_model/", 37 | "The directory to load the model files from.") 38 | flags.DEFINE_string("checkpoint_file", "", 39 | "If provided, this specific checkpoint file will be " 40 | "used for inference. Otherwise, the latest checkpoint " 41 | "from the train_dir' argument will be used instead.") 42 | flags.DEFINE_string("output_file", "", 43 | "The file to save the predictions to.") 44 | flags.DEFINE_string( 45 | "input_data_pattern", "", 46 | "File glob defining the evaluation dataset in tensorflow.SequenceExample " 47 | "format. The SequenceExamples are expected to have an 'rgb' byte array " 48 | "sequence feature as well as a 'labels' int64 context feature.") 49 | 50 | # Model flags. 51 | flags.DEFINE_bool( 52 | "frame_features", False, 53 | "If set, then --input_data_pattern must be frame-level features. " 54 | "Otherwise, --input_data_pattern must be aggregated video-level " 55 | "features. The model must also be set appropriately (i.e. to read 3D " 56 | "batches VS 4D batches.") 57 | flags.DEFINE_integer( 58 | "batch_size", 8192, 59 | "How many examples to process per batch.") 60 | flags.DEFINE_string("feature_names", "mean_rgb", "Name of the feature " 61 | "to use for training.") 62 | flags.DEFINE_string("feature_sizes", "1024", "Length of the feature vectors.") 63 | 64 | 65 | # Other flags. 
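combine_gradients reduces to summing gradient i across towers for every shared variable; only the stacking and reduction are TensorFlow ops. A minimal numpy sketch of that reduction, with made-up gradients and variable names:

```python
import numpy as np

# Two towers, each holding (gradient, variable) pairs for the same variables.
tower_grads = [
    [(np.array([1.0, 2.0]), "weights"), (np.array([0.5]), "bias")],
    [(np.array([3.0, 4.0]), "weights"), (np.array([1.5]), "bias")],
]

combined = []
for i in range(len(tower_grads[0])):
    grads = [tower[i][0] for tower in tower_grads]          # gradient i from every tower
    combined.append((np.sum(grads, axis=0), tower_grads[0][i][1]))

print(combined)  # [(array([4., 6.]), 'weights'), (array([2.]), 'bias')]
```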
66 | flags.DEFINE_integer("num_readers", 1, 67 | "How many threads to use for reading input files.") 68 | flags.DEFINE_integer("top_k", 20, 69 | "How many predictions to output per video.") 70 | 71 | def format_lines(video_ids, predictions, top_k): 72 | batch_size = len(video_ids) 73 | for video_index in range(batch_size): 74 | top_indices = numpy.argpartition(predictions[video_index], -top_k)[-top_k:] 75 | line = [(class_index, predictions[video_index][class_index]) 76 | for class_index in top_indices] 77 | line = sorted(line, key=lambda p: -p[1]) 78 | yield video_ids[video_index].decode('utf-8') + "," + " ".join("%i %f" % pair 79 | for pair in line) + "\n" 80 | 81 | 82 | def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1): 83 | """Creates the section of the graph which reads the input data. 84 | 85 | Args: 86 | reader: A class which parses the input data. 87 | data_pattern: A 'glob' style path to the data files. 88 | batch_size: How many examples to process at a time. 89 | num_readers: How many I/O threads to use. 90 | 91 | Returns: 92 | A tuple containing the features tensor, labels tensor, and optionally a 93 | tensor containing the number of frames per video. The exact dimensions 94 | depend on the reader being used. 95 | 96 | Raises: 97 | IOError: If no files matching the given pattern were found. 98 | """ 99 | with tf.name_scope("input"): 100 | files = gfile.Glob(data_pattern) 101 | if not files: 102 | raise IOError("Unable to find input files. data_pattern='" + 103 | data_pattern + "'") 104 | logging.info("number of input files: " + str(len(files))) 105 | filename_queue = tf.train.string_input_producer( 106 | files, num_epochs=1, shuffle=False) 107 | examples_and_labels = [reader.prepare_reader(filename_queue) 108 | for _ in range(num_readers)] 109 | 110 | video_id_batch, video_batch, unused_labels, num_frames_batch = ( 111 | tf.train.batch_join(examples_and_labels, 112 | batch_size=batch_size, 113 | allow_smaller_final_batch=True, 114 | enqueue_many=True)) 115 | return video_id_batch, video_batch, num_frames_batch 116 | 117 | def inference(reader, checkpoint_file, train_dir, data_pattern, out_file_location, batch_size, top_k): 118 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess, gfile.Open(out_file_location, "w+") as out_file: 119 | video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(reader, data_pattern, batch_size) 120 | if checkpoint_file: 121 | if not gfile.Exists(checkpoint_file + ".meta"): 122 | logging.fatal("Unable to find checkpoint file at provided location '%s'" % checkpoint_file) 123 | latest_checkpoint = checkpoint_file 124 | else: 125 | latest_checkpoint = tf.train.latest_checkpoint(train_dir) 126 | if latest_checkpoint is None: 127 | raise Exception("unable to find a checkpoint at location: %s" % train_dir) 128 | else: 129 | meta_graph_location = latest_checkpoint + ".meta" 130 | logging.info("loading meta-graph: " + meta_graph_location) 131 | saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True) 132 | logging.info("restoring variables from " + latest_checkpoint) 133 | saver.restore(sess, latest_checkpoint) 134 | input_tensor = tf.get_collection("input_batch_raw")[0] 135 | num_frames_tensor = tf.get_collection("num_frames")[0] 136 | predictions_tensor = tf.get_collection("predictions")[0] 137 | 138 | # Workaround for num_epochs issue. 
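format_lines turns one row of the prediction matrix into a CSV line of the top_k (class, score) pairs, highest score first. The numpy sketch below reproduces that selection and formatting for a single hypothetical video id:

```python
import numpy as np

predictions = np.array([0.10, 0.80, 0.05, 0.60])
top_k = 2

# argpartition finds the top_k scores without a full sort, as in format_lines.
top_indices = np.argpartition(predictions, -top_k)[-top_k:]
pairs = sorted(((i, predictions[i]) for i in top_indices), key=lambda p: -p[1])
print("video123," + " ".join("%i %f" % pair for pair in pairs))
# video123,1 0.800000 3 0.600000
```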
139 | def set_up_init_ops(variables): 140 | init_op_list = [] 141 | for variable in list(variables): 142 | if "train_input" in variable.name: 143 | init_op_list.append(tf.assign(variable, 1)) 144 | variables.remove(variable) 145 | init_op_list.append(tf.variables_initializer(variables)) 146 | return init_op_list 147 | 148 | sess.run(set_up_init_ops(tf.get_collection_ref( 149 | tf.GraphKeys.LOCAL_VARIABLES))) 150 | 151 | coord = tf.train.Coordinator() 152 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 153 | num_examples_processed = 0 154 | start_time = time.time() 155 | out_file.write("VideoId,LabelConfidencePairs\n") 156 | 157 | try: 158 | while not coord.should_stop(): 159 | video_id_batch_val, video_batch_val,num_frames_batch_val = sess.run([video_id_batch, video_batch, num_frames_batch]) 160 | predictions_val, = sess.run([predictions_tensor], feed_dict={input_tensor: video_batch_val, num_frames_tensor: num_frames_batch_val}) 161 | now = time.time() 162 | num_examples_processed += len(video_batch_val) 163 | num_classes = predictions_val.shape[1] 164 | logging.info("num examples processed: " + str(num_examples_processed) + " elapsed seconds: " + "{0:.2f}".format(now-start_time)) 165 | for line in format_lines(video_id_batch_val, predictions_val, top_k): 166 | out_file.write(line) 167 | out_file.flush() 168 | 169 | 170 | except tf.errors.OutOfRangeError: 171 | logging.info('Done with inference. The output file was written to ' + out_file_location) 172 | finally: 173 | coord.request_stop() 174 | 175 | coord.join(threads) 176 | sess.close() 177 | 178 | 179 | def main(unused_argv): 180 | logging.set_verbosity(tf.logging.INFO) 181 | 182 | # convert feature_names and feature_sizes to lists of values 183 | feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( 184 | FLAGS.feature_names, FLAGS.feature_sizes) 185 | 186 | if FLAGS.frame_features: 187 | reader = readers.YT8MFrameFeatureReader(feature_names=feature_names, 188 | feature_sizes=feature_sizes) 189 | else: 190 | reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names, 191 | feature_sizes=feature_sizes) 192 | 193 | if FLAGS.output_file is "": 194 | raise ValueError("'output_file' was not specified. " 195 | "Unable to continue with inference.") 196 | 197 | if FLAGS.input_data_pattern is "": 198 | raise ValueError("'input_data_pattern' was not specified. " 199 | "Unable to continue with inference.") 200 | 201 | inference(reader, FLAGS.checkpoint_file, FLAGS.train_dir, FLAGS.input_data_pattern, 202 | FLAGS.output_file, FLAGS.batch_size, FLAGS.top_k) 203 | 204 | 205 | if __name__ == "__main__": 206 | app.run() 207 | -------------------------------------------------------------------------------- /eval_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Provides functions to help with evaluating models.""" 16 | import datetime 17 | import numpy 18 | 19 | from tensorflow.python.platform import gfile 20 | 21 | import mean_average_precision_calculator as map_calculator 22 | import average_precision_calculator as ap_calculator 23 | 24 | def flatten(l): 25 | """ Merges a list of lists into a single list. """ 26 | return [item for sublist in l for item in sublist] 27 | 28 | def calculate_hit_at_one(predictions, actuals): 29 | """Performs a local (numpy) calculation of the hit at one. 30 | 31 | Args: 32 | predictions: Matrix containing the outputs of the model. 33 | Dimensions are 'batch' x 'num_classes'. 34 | actuals: Matrix containing the ground truth labels. 35 | Dimensions are 'batch' x 'num_classes'. 36 | 37 | Returns: 38 | float: The average hit at one across the entire batch. 39 | """ 40 | top_prediction = numpy.argmax(predictions, 1) 41 | hits = actuals[numpy.arange(actuals.shape[0]), top_prediction] 42 | return numpy.average(hits) 43 | 44 | 45 | def calculate_precision_at_equal_recall_rate(predictions, actuals): 46 | """Performs a local (numpy) calculation of the PERR. 47 | 48 | Args: 49 | predictions: Matrix containing the outputs of the model. 50 | Dimensions are 'batch' x 'num_classes'. 51 | actuals: Matrix containing the ground truth labels. 52 | Dimensions are 'batch' x 'num_classes'. 53 | 54 | Returns: 55 | float: The average precision at equal recall rate across the entire batch. 56 | """ 57 | aggregated_precision = 0.0 58 | num_videos = actuals.shape[0] 59 | for row in numpy.arange(num_videos): 60 | num_labels = int(numpy.sum(actuals[row])) 61 | top_indices = numpy.argpartition(predictions[row], 62 | -num_labels)[-num_labels:] 63 | item_precision = 0.0 64 | for label_index in top_indices: 65 | if predictions[row][label_index] > 0: 66 | item_precision += actuals[row][label_index] 67 | item_precision /= top_indices.size 68 | aggregated_precision += item_precision 69 | aggregated_precision /= num_videos 70 | return aggregated_precision 71 | 72 | def calculate_gap(predictions, actuals, top_k=20): 73 | """Performs a local (numpy) calculation of the global average precision. 74 | 75 | Only the top_k predictions are taken for each of the videos. 76 | 77 | Args: 78 | predictions: Matrix containing the outputs of the model. 79 | Dimensions are 'batch' x 'num_classes'. 80 | actuals: Matrix containing the ground truth labels. 81 | Dimensions are 'batch' x 'num_classes'. 82 | top_k: How many predictions to use per video. 83 | 84 | Returns: 85 | float: The global average precision. 86 | """ 87 | gap_calculator = ap_calculator.AveragePrecisionCalculator() 88 | sparse_predictions, sparse_labels, num_positives = top_k_by_class(predictions, actuals, top_k) 89 | gap_calculator.accumulate(flatten(sparse_predictions), flatten(sparse_labels), sum(num_positives)) 90 | return gap_calculator.peek_ap_at_n() 91 | 92 | 93 | def top_k_by_class(predictions, labels, k=20): 94 | """Extracts the top k predictions for each video, sorted by class. 95 | 96 | Args: 97 | predictions: A numpy matrix containing the outputs of the model. 98 | Dimensions are 'batch' x 'num_classes'. 99 | k: the top k non-zero entries to preserve in each prediction. 100 | 101 | Returns: 102 | A tuple (predictions,labels, true_positives). 'predictions' and 'labels' 103 | are lists of lists of floats. 'true_positives' is a list of scalars. The 104 | length of the lists are equal to the number of classes. 
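Hit@1 and PERR are plain numpy reductions over the batch. A toy batch of two videos and three classes (the numbers are made up) makes the definitions concrete; the PERR loop is a slightly simplified version of calculate_precision_at_equal_recall_rate:

```python
import numpy as np

predictions = np.array([[0.8, 0.1, 0.6],
                        [0.2, 0.9, 0.3]])
actuals = np.array([[1, 0, 1],
                    [0, 0, 1]])

# Hit@1: fraction of videos whose top-scoring class is a true label.
top_prediction = np.argmax(predictions, 1)
print(np.average(actuals[np.arange(2), top_prediction]))  # 0.5

# PERR (simplified): precision when retrieving as many classes per video
# as that video has true labels.
perr = 0.0
for row in range(2):
    num_labels = int(actuals[row].sum())
    top = np.argpartition(predictions[row], -num_labels)[-num_labels:]
    perr += actuals[row][top].sum() / num_labels
print(perr / 2)  # 0.5
```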
The entries in the 105 | predictions variable are probability predictions, and 106 | the corresponding entries in the labels variable are the ground truth for 107 | those predictions. The entries in 'true_positives' are the number of true 108 | positives for each class in the ground truth. 109 | 110 | Raises: 111 | ValueError: An error occurred when the k is not a positive integer. 112 | """ 113 | if k <= 0: 114 | raise ValueError("k must be a positive integer.") 115 | k = min(k, predictions.shape[1]) 116 | num_classes = predictions.shape[1] 117 | prediction_triplets= [] 118 | for video_index in range(predictions.shape[0]): 119 | prediction_triplets.extend(top_k_triplets(predictions[video_index],labels[video_index], k)) 120 | out_predictions = [[] for v in range(num_classes)] 121 | out_labels = [[] for v in range(num_classes)] 122 | for triplet in prediction_triplets: 123 | out_predictions[triplet[0]].append(triplet[1]) 124 | out_labels[triplet[0]].append(triplet[2]) 125 | out_true_positives = [numpy.sum(labels[:,i]) for i in range(num_classes)] 126 | 127 | return out_predictions, out_labels, out_true_positives 128 | 129 | def top_k_triplets(predictions, labels, k=20): 130 | """Get the top_k for a 1-d numpy array. Returns a sparse list of tuples in 131 | (prediction, class) format""" 132 | m = len(predictions) 133 | k = min(k, m) 134 | indices = numpy.argpartition(predictions, -k)[-k:] 135 | return [(index, predictions[index], labels[index]) for index in indices] 136 | 137 | class EvaluationMetrics(object): 138 | """A class to store the evaluation metrics.""" 139 | 140 | def __init__(self, num_class, top_k): 141 | """Construct an EvaluationMetrics object to store the evaluation metrics. 142 | 143 | Args: 144 | num_class: A positive integer specifying the number of classes. 145 | top_k: A positive integer specifying how many predictions are considered per video. 146 | 147 | Raises: 148 | ValueError: An error occurred when MeanAveragePrecisionCalculator cannot 149 | not be constructed. 150 | """ 151 | self.sum_hit_at_one = 0.0 152 | self.sum_perr = 0.0 153 | self.sum_loss = 0.0 154 | self.map_calculator = map_calculator.MeanAveragePrecisionCalculator(num_class) 155 | self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator() 156 | self.top_k = top_k 157 | self.num_examples = 0 158 | 159 | def accumulate(self, predictions, labels, loss): 160 | """Accumulate the metrics calculated locally for this mini-batch. 161 | 162 | Args: 163 | predictions: A numpy matrix containing the outputs of the model. 164 | Dimensions are 'batch' x 'num_classes'. 165 | labels: A numpy matrix containing the ground truth labels. 166 | Dimensions are 'batch' x 'num_classes'. 167 | loss: A numpy array containing the loss for each sample. 168 | 169 | Returns: 170 | dictionary: A dictionary storing the metrics for the mini-batch. 171 | 172 | Raises: 173 | ValueError: An error occurred when the shape of predictions and actuals 174 | does not match. 175 | """ 176 | batch_size = labels.shape[0] 177 | mean_hit_at_one = calculate_hit_at_one(predictions, labels) 178 | mean_perr = calculate_precision_at_equal_recall_rate(predictions, labels) 179 | mean_loss = numpy.mean(loss) 180 | 181 | # Take the top 20 predictions. 
182 | sparse_predictions, sparse_labels, num_positives = top_k_by_class(predictions, labels, self.top_k) 183 | self.map_calculator.accumulate(sparse_predictions, sparse_labels, num_positives) 184 | self.global_ap_calculator.accumulate(flatten(sparse_predictions), flatten(sparse_labels), sum(num_positives)) 185 | 186 | self.num_examples += batch_size 187 | self.sum_hit_at_one += mean_hit_at_one * batch_size 188 | self.sum_perr += mean_perr * batch_size 189 | self.sum_loss += mean_loss * batch_size 190 | 191 | return {"hit_at_one": mean_hit_at_one, "perr": mean_perr, "loss": mean_loss} 192 | 193 | def get(self): 194 | """Calculate the evaluation metrics for the whole epoch. 195 | 196 | Raises: 197 | ValueError: If no examples were accumulated. 198 | 199 | Returns: 200 | dictionary: a dictionary storing the evaluation metrics for the epoch. The 201 | dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, and 202 | aps (default nan). 203 | """ 204 | if self.num_examples <= 0: 205 | raise ValueError("total_sample must be positive.") 206 | avg_hit_at_one = self.sum_hit_at_one / self.num_examples 207 | avg_perr = self.sum_perr / self.num_examples 208 | avg_loss = self.sum_loss / self.num_examples 209 | 210 | aps = self.map_calculator.peek_map_at_n() 211 | gap = self.global_ap_calculator.peek_ap_at_n() 212 | 213 | epoch_info_dict = {} 214 | return {"avg_hit_at_one": avg_hit_at_one, "avg_perr": avg_perr, 215 | "avg_loss": avg_loss, "aps": aps, "gap": gap} 216 | 217 | def clear(self): 218 | """Clear the evaluation metrics and reset the EvaluationMetrics object.""" 219 | self.sum_hit_at_one = 0.0 220 | self.sum_perr = 0.0 221 | self.sum_loss = 0.0 222 | self.map_calculator.clear() 223 | self.global_ap_calculator.clear() 224 | self.num_examples = 0 225 | -------------------------------------------------------------------------------- /frame_level_models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Contains a collection of models which operate on variable-length sequences. 16 | """ 17 | import math 18 | 19 | import models 20 | import video_level_models 21 | import tensorflow as tf 22 | import model_utils as utils 23 | 24 | import tensorflow.contrib.slim as slim 25 | from tensorflow import flags 26 | 27 | FLAGS = flags.FLAGS 28 | flags.DEFINE_integer("iterations", 30, 29 | "Number of frames per batch for DBoF.") 30 | flags.DEFINE_bool("dbof_add_batch_norm", True, 31 | "Adds batch normalization to the DBoF model.") 32 | flags.DEFINE_bool( 33 | "sample_random_frames", True, 34 | "If true samples random frames (for frame level models). 
If false, a random" 35 | "sequence of frames is sampled instead.") 36 | flags.DEFINE_integer("dbof_cluster_size", 8192, 37 | "Number of units in the DBoF cluster layer.") 38 | flags.DEFINE_integer("dbof_hidden_size", 1024, 39 | "Number of units in the DBoF hidden layer.") 40 | flags.DEFINE_string("dbof_pooling_method", "max", 41 | "The pooling method used in the DBoF cluster layer. " 42 | "Choices are 'average' and 'max'.") 43 | flags.DEFINE_string("video_level_classifier_model", "MoeModel", 44 | "Some Frame-Level models can be decomposed into a " 45 | "generalized pooling operation followed by a " 46 | "classifier layer") 47 | flags.DEFINE_integer("lstm_cells", 1024, "Number of LSTM cells.") 48 | flags.DEFINE_integer("lstm_layers", 2, "Number of LSTM layers.") 49 | 50 | class FrameLevelLogisticModel(models.BaseModel): 51 | 52 | def create_model(self, model_input, vocab_size, num_frames, **unused_params): 53 | """Creates a model which uses a logistic classifier over the average of the 54 | frame-level features. 55 | 56 | This class is intended to be an example for implementors of frame level 57 | models. If you want to train a model over averaged features it is more 58 | efficient to average them beforehand rather than on the fly. 59 | 60 | Args: 61 | model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of 62 | input features. 63 | vocab_size: The number of classes in the dataset. 64 | num_frames: A vector of length 'batch' which indicates the number of 65 | frames for each video (before padding). 66 | 67 | Returns: 68 | A dictionary with a tensor containing the probability predictions of the 69 | model in the 'predictions' key. The dimensions of the tensor are 70 | 'batch_size' x 'num_classes'. 71 | """ 72 | num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32) 73 | feature_size = model_input.get_shape().as_list()[2] 74 | 75 | denominators = tf.reshape( 76 | tf.tile(num_frames, [1, feature_size]), [-1, feature_size]) 77 | avg_pooled = tf.reduce_sum(model_input, 78 | axis=[1]) / denominators 79 | 80 | output = slim.fully_connected( 81 | avg_pooled, vocab_size, activation_fn=tf.nn.sigmoid, 82 | weights_regularizer=slim.l2_regularizer(1e-8)) 83 | return {"predictions": output} 84 | 85 | class DbofModel(models.BaseModel): 86 | """Creates a Deep Bag of Frames model. 87 | 88 | The model projects the features for each frame into a higher dimensional 89 | 'clustering' space, pools across frames in that space, and then 90 | uses a configurable video-level model to classify the now aggregated features. 91 | 92 | The model will randomly sample either frames or sequences of frames during 93 | training to speed up convergence. 94 | 95 | Args: 96 | model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of 97 | input features. 98 | vocab_size: The number of classes in the dataset. 99 | num_frames: A vector of length 'batch' which indicates the number of 100 | frames for each video (before padding). 101 | 102 | Returns: 103 | A dictionary with a tensor containing the probability predictions of the 104 | model in the 'predictions' key. The dimensions of the tensor are 105 | 'batch_size' x 'num_classes'. 
106 | """ 107 | 108 | def create_model(self, 109 | model_input, 110 | vocab_size, 111 | num_frames, 112 | iterations=None, 113 | add_batch_norm=None, 114 | sample_random_frames=None, 115 | cluster_size=None, 116 | hidden_size=None, 117 | is_training=True, 118 | **unused_params): 119 | iterations = iterations or FLAGS.iterations 120 | add_batch_norm = add_batch_norm or FLAGS.dbof_add_batch_norm 121 | random_frames = sample_random_frames or FLAGS.sample_random_frames 122 | cluster_size = cluster_size or FLAGS.dbof_cluster_size 123 | hidden1_size = hidden_size or FLAGS.dbof_hidden_size 124 | 125 | num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32) 126 | if random_frames: 127 | model_input = utils.SampleRandomFrames(model_input, num_frames, 128 | iterations) 129 | else: 130 | model_input = utils.SampleRandomSequence(model_input, num_frames, 131 | iterations) 132 | max_frames = model_input.get_shape().as_list()[1] 133 | feature_size = model_input.get_shape().as_list()[2] 134 | reshaped_input = tf.reshape(model_input, [-1, feature_size]) 135 | tf.summary.histogram("input_hist", reshaped_input) 136 | 137 | if add_batch_norm: 138 | reshaped_input = slim.batch_norm( 139 | reshaped_input, 140 | center=True, 141 | scale=True, 142 | is_training=is_training, 143 | scope="input_bn") 144 | 145 | cluster_weights = tf.get_variable("cluster_weights", 146 | [feature_size, cluster_size], 147 | initializer = tf.random_normal_initializer(stddev=1 / math.sqrt(feature_size))) 148 | tf.summary.histogram("cluster_weights", cluster_weights) 149 | activation = tf.matmul(reshaped_input, cluster_weights) 150 | if add_batch_norm: 151 | activation = slim.batch_norm( 152 | activation, 153 | center=True, 154 | scale=True, 155 | is_training=is_training, 156 | scope="cluster_bn") 157 | else: 158 | cluster_biases = tf.get_variable("cluster_biases", 159 | [cluster_size], 160 | initializer = tf.random_normal(stddev=1 / math.sqrt(feature_size))) 161 | tf.summary.histogram("cluster_biases", cluster_biases) 162 | activation += cluster_biases 163 | activation = tf.nn.relu6(activation) 164 | tf.summary.histogram("cluster_output", activation) 165 | 166 | activation = tf.reshape(activation, [-1, max_frames, cluster_size]) 167 | activation = utils.FramePooling(activation, FLAGS.dbof_pooling_method) 168 | 169 | hidden1_weights = tf.get_variable("hidden1_weights", 170 | [cluster_size, hidden1_size], 171 | initializer=tf.random_normal_initializer(stddev=1 / math.sqrt(cluster_size))) 172 | tf.summary.histogram("hidden1_weights", hidden1_weights) 173 | activation = tf.matmul(activation, hidden1_weights) 174 | if add_batch_norm: 175 | activation = slim.batch_norm( 176 | activation, 177 | center=True, 178 | scale=True, 179 | is_training=is_training, 180 | scope="hidden1_bn") 181 | else: 182 | hidden1_biases = tf.get_variable("hidden1_biases", 183 | [hidden1_size], 184 | initializer = tf.random_normal_initializer(stddev=0.01)) 185 | tf.summary.histogram("hidden1_biases", hidden1_biases) 186 | activation += hidden1_biases 187 | activation = tf.nn.relu6(activation) 188 | tf.summary.histogram("hidden1_output", activation) 189 | 190 | aggregated_model = getattr(video_level_models, 191 | FLAGS.video_level_classifier_model) 192 | return aggregated_model().create_model( 193 | model_input=activation, 194 | vocab_size=vocab_size, 195 | **unused_params) 196 | 197 | class LstmModel(models.BaseModel): 198 | 199 | def create_model(self, model_input, vocab_size, num_frames, **unused_params): 200 | """Creates a model which uses a stack of LSTMs 
to represent the video. 201 | 202 | Args: 203 | model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of 204 | input features. 205 | vocab_size: The number of classes in the dataset. 206 | num_frames: A vector of length 'batch' which indicates the number of 207 | frames for each video (before padding). 208 | 209 | Returns: 210 | A dictionary with a tensor containing the probability predictions of the 211 | model in the 'predictions' key. The dimensions of the tensor are 212 | 'batch_size' x 'num_classes'. 213 | """ 214 | lstm_size = FLAGS.lstm_cells 215 | number_of_layers = FLAGS.lstm_layers 216 | 217 | stacked_lstm = tf.contrib.rnn.MultiRNNCell( 218 | [ 219 | tf.contrib.rnn.BasicLSTMCell( 220 | lstm_size, forget_bias=1.0) 221 | for _ in range(number_of_layers) 222 | ]) 223 | 224 | loss = 0.0 225 | 226 | outputs, state = tf.nn.dynamic_rnn(stacked_lstm, model_input, 227 | sequence_length=num_frames, 228 | dtype=tf.float32) 229 | 230 | aggregated_model = getattr(video_level_models, 231 | FLAGS.video_level_classifier_model) 232 | 233 | return aggregated_model().create_model( 234 | model_input=state[-1].h, 235 | vocab_size=vocab_size, 236 | **unused_params) 237 | -------------------------------------------------------------------------------- /average_precision_calculator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Calculate or keep track of the interpolated average precision. 16 | 17 | It provides an interface for calculating interpolated average precision for an 18 | entire list or the top-n ranked items. For the definition of the 19 | (non-)interpolated average precision: 20 | http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf 21 | 22 | Example usages: 23 | 1) Use it as a static function call to directly calculate average precision for 24 | a short ranked list in the memory. 25 | 26 | ``` 27 | import random 28 | 29 | p = np.array([random.random() for _ in xrange(10)]) 30 | a = np.array([random.choice([0, 1]) for _ in xrange(10)]) 31 | 32 | ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a) 33 | ``` 34 | 35 | 2) Use it as an object for long ranked list that cannot be stored in memory or 36 | the case where partial predictions can be observed at a time (Tensorflow 37 | predictions). In this case, we first call the function accumulate many times 38 | to process parts of the ranked list. After processing all the parts, we call 39 | peek_interpolated_ap_at_n. 
40 | ``` 41 | p1 = np.array([random.random() for _ in xrange(5)]) 42 | a1 = np.array([random.choice([0, 1]) for _ in xrange(5)]) 43 | p2 = np.array([random.random() for _ in xrange(5)]) 44 | a2 = np.array([random.choice([0, 1]) for _ in xrange(5)]) 45 | 46 | # interpolated average precision at 10 using 1000 break points 47 | calculator = average_precision_calculator.AveragePrecisionCalculator(10) 48 | calculator.accumulate(p1, a1) 49 | calculator.accumulate(p2, a2) 50 | ap3 = calculator.peek_ap_at_n() 51 | ``` 52 | """ 53 | 54 | import heapq 55 | import random 56 | import numbers 57 | 58 | import numpy 59 | 60 | 61 | class AveragePrecisionCalculator(object): 62 | """Calculate the average precision and average precision at n.""" 63 | 64 | def __init__(self, top_n=None): 65 | """Construct an AveragePrecisionCalculator to calculate average precision. 66 | 67 | This class is used to calculate the average precision for a single label. 68 | 69 | Args: 70 | top_n: A positive Integer specifying the average precision at n, or 71 | None to use all provided data points. 72 | 73 | Raises: 74 | ValueError: An error occurred when the top_n is not a positive integer. 75 | """ 76 | if not ((isinstance(top_n, int) and top_n >= 0) or top_n is None): 77 | raise ValueError("top_n must be a positive integer or None.") 78 | 79 | self._top_n = top_n # average precision at n 80 | self._total_positives = 0 # total number of positives have seen 81 | self._heap = [] # max heap of (prediction, actual) 82 | 83 | @property 84 | def heap_size(self): 85 | """Gets the heap size maintained in the class.""" 86 | return len(self._heap) 87 | 88 | @property 89 | def num_accumulated_positives(self): 90 | """Gets the number of positive samples that have been accumulated.""" 91 | return self._total_positives 92 | 93 | def accumulate(self, predictions, actuals, num_positives=None): 94 | """Accumulate the predictions and their ground truth labels. 95 | 96 | After the function call, we may call peek_ap_at_n to actually calculate 97 | the average precision. 98 | Note predictions and actuals must have the same shape. 99 | 100 | Args: 101 | predictions: a list storing the prediction scores. 102 | actuals: a list storing the ground truth labels. Any value 103 | larger than 0 will be treated as positives, otherwise as negatives. 104 | num_positives = If the 'predictions' and 'actuals' inputs aren't complete, 105 | then it's possible some true positives were missed in them. In that case, 106 | you can provide 'num_positives' in order to accurately track recall. 107 | 108 | Raises: 109 | ValueError: An error occurred when the format of the input is not the 110 | numpy 1-D array or the shape of predictions and actuals does not match. 
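Because a full ranked list may not fit in memory, accumulate keeps only the top_n highest-scoring (prediction, label) pairs in a min-heap. A standalone sketch of that bookkeeping; heapq.heapreplace is equivalent to the pop-then-push the calculator performs:

```python
import heapq

top_n = 3
heap = []  # min-heap of (prediction, label); heap[0] is the smallest kept score
for pred, label in [(0.2, 0), (0.9, 1), (0.5, 0), (0.7, 1), (0.1, 0)]:
    if len(heap) < top_n:
        heapq.heappush(heap, (pred, label))
    elif pred > heap[0][0]:
        heapq.heapreplace(heap, (pred, label))

print(sorted(heap, reverse=True))  # [(0.9, 1), (0.7, 1), (0.5, 0)]
```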
111 | """ 112 | if len(predictions) != len(actuals): 113 | raise ValueError("the shape of predictions and actuals does not match.") 114 | 115 | if not num_positives is None: 116 | if not isinstance(num_positives, numbers.Number) or num_positives < 0: 117 | raise ValueError("'num_positives' was provided but it wan't a nonzero number.") 118 | 119 | if not num_positives is None: 120 | self._total_positives += num_positives 121 | else: 122 | self._total_positives += numpy.size(numpy.where(actuals > 0)) 123 | topk = self._top_n 124 | heap = self._heap 125 | 126 | for i in range(numpy.size(predictions)): 127 | if topk is None or len(heap) < topk: 128 | heapq.heappush(heap, (predictions[i], actuals[i])) 129 | else: 130 | if predictions[i] > heap[0][0]: # heap[0] is the smallest 131 | heapq.heappop(heap) 132 | heapq.heappush(heap, (predictions[i], actuals[i])) 133 | 134 | def clear(self): 135 | """Clear the accumulated predictions.""" 136 | self._heap = [] 137 | self._total_positives = 0 138 | 139 | def peek_ap_at_n(self): 140 | """Peek the non-interpolated average precision at n. 141 | 142 | Returns: 143 | The non-interpolated average precision at n (default 0). 144 | If n is larger than the length of the ranked list, 145 | the average precision will be returned. 146 | """ 147 | if self.heap_size <= 0: 148 | return 0 149 | predlists = numpy.array(list(zip(*self._heap))) 150 | 151 | ap = self.ap_at_n(predlists[0], 152 | predlists[1], 153 | n=self._top_n, 154 | total_num_positives=self._total_positives) 155 | return ap 156 | 157 | @staticmethod 158 | def ap(predictions, actuals): 159 | """Calculate the non-interpolated average precision. 160 | 161 | Args: 162 | predictions: a numpy 1-D array storing the sparse prediction scores. 163 | actuals: a numpy 1-D array storing the ground truth labels. Any value 164 | larger than 0 will be treated as positives, otherwise as negatives. 165 | 166 | Returns: 167 | The non-interpolated average precision at n. 168 | If n is larger than the length of the ranked list, 169 | the average precision will be returned. 170 | 171 | Raises: 172 | ValueError: An error occurred when the format of the input is not the 173 | numpy 1-D array or the shape of predictions and actuals does not match. 174 | """ 175 | return AveragePrecisionCalculator.ap_at_n(predictions, 176 | actuals, 177 | n=None) 178 | 179 | @staticmethod 180 | def ap_at_n(predictions, actuals, n=20, total_num_positives=None): 181 | """Calculate the non-interpolated average precision. 182 | 183 | Args: 184 | predictions: a numpy 1-D array storing the sparse prediction scores. 185 | actuals: a numpy 1-D array storing the ground truth labels. Any value 186 | larger than 0 will be treated as positives, otherwise as negatives. 187 | n: the top n items to be considered in ap@n. 188 | total_num_positives : (optionally) you can specify the number of total 189 | positive 190 | in the list. If specified, it will be used in calculation. 191 | 192 | Returns: 193 | The non-interpolated average precision at n. 194 | If n is larger than the length of the ranked list, 195 | the average precision will be returned. 196 | 197 | Raises: 198 | ValueError: An error occurred when 199 | 1) the format of the input is not the numpy 1-D array; 200 | 2) the shape of predictions and actuals does not match; 201 | 3) the input n is not a positive integer. 
202 | """ 203 | if len(predictions) != len(actuals): 204 | raise ValueError("the shape of predictions and actuals does not match.") 205 | 206 | if n is not None: 207 | if not isinstance(n, int) or n <= 0: 208 | raise ValueError("n must be 'None' or a positive integer." 209 | " It was '%s'." % n) 210 | 211 | ap = 0.0 212 | 213 | predictions = numpy.array(predictions) 214 | actuals = numpy.array(actuals) 215 | 216 | # add a shuffler to avoid overestimating the ap 217 | predictions, actuals = AveragePrecisionCalculator._shuffle(predictions, 218 | actuals) 219 | sortidx = sorted( 220 | range(len(predictions)), 221 | key=lambda k: predictions[k], 222 | reverse=True) 223 | 224 | if total_num_positives is None: 225 | numpos = numpy.size(numpy.where(actuals > 0)) 226 | else: 227 | numpos = total_num_positives 228 | 229 | if numpos == 0: 230 | return 0 231 | 232 | if n is not None: 233 | numpos = min(numpos, n) 234 | delta_recall = 1.0 / numpos 235 | poscount = 0.0 236 | 237 | # calculate the ap 238 | r = len(sortidx) 239 | if n is not None: 240 | r = min(r, n) 241 | for i in range(r): 242 | if actuals[sortidx[i]] > 0: 243 | poscount += 1 244 | ap += poscount / (i + 1) * delta_recall 245 | return ap 246 | 247 | @staticmethod 248 | def _shuffle(predictions, actuals): 249 | random.seed(0) 250 | suffidx = random.sample(range(len(predictions)), len(predictions)) 251 | predictions = predictions[suffidx] 252 | actuals = actuals[suffidx] 253 | return predictions, actuals 254 | 255 | @staticmethod 256 | def _zero_one_normalize(predictions, epsilon=1e-7): 257 | """Normalize the predictions to the range between 0.0 and 1.0. 258 | 259 | For some predictions like SVM predictions, we need to normalize them before 260 | calculate the interpolated average precision. The normalization will not 261 | change the rank in the original list and thus won't change the average 262 | precision. 263 | 264 | Args: 265 | predictions: a numpy 1-D array storing the sparse prediction scores. 266 | epsilon: a small constant to avoid denominator being zero. 267 | 268 | Returns: 269 | The normalized prediction. 270 | """ 271 | denominator = numpy.max(predictions) - numpy.min(predictions) 272 | ret = (predictions - numpy.min(predictions)) / numpy.max(denominator, 273 | epsilon) 274 | return ret 275 | -------------------------------------------------------------------------------- /readers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Provides readers configured for different datasets.""" 16 | 17 | import tensorflow as tf 18 | import utils 19 | 20 | from tensorflow import logging 21 | def resize_axis(tensor, axis, new_size, fill_value=0): 22 | """Truncates or pads a tensor to new_size on on a given axis. 23 | 24 | Truncate or extend tensor such that tensor.shape[axis] == new_size. 
If the 25 | size increases, the padding will be performed at the end, using fill_value. 26 | 27 | Args: 28 | tensor: The tensor to be resized. 29 | axis: An integer representing the dimension to be sliced. 30 | new_size: An integer or 0d tensor representing the new value for 31 | tensor.shape[axis]. 32 | fill_value: Value to use to fill any new entries in the tensor. Will be 33 | cast to the type of tensor. 34 | 35 | Returns: 36 | The resized tensor. 37 | """ 38 | tensor = tf.convert_to_tensor(tensor) 39 | shape = tf.unstack(tf.shape(tensor)) 40 | 41 | pad_shape = shape[:] 42 | pad_shape[axis] = tf.maximum(0, new_size - shape[axis]) 43 | 44 | shape[axis] = tf.minimum(shape[axis], new_size) 45 | shape = tf.stack(shape) 46 | 47 | resized = tf.concat([ 48 | tf.slice(tensor, tf.zeros_like(shape), shape), 49 | tf.fill(tf.stack(pad_shape), tf.cast(fill_value, tensor.dtype)) 50 | ], axis) 51 | 52 | # Update shape. 53 | new_shape = tensor.get_shape().as_list() # A copy is being made. 54 | new_shape[axis] = new_size 55 | resized.set_shape(new_shape) 56 | return resized 57 | 58 | class BaseReader(object): 59 | """Inherit from this class when implementing new readers.""" 60 | 61 | def prepare_reader(self, unused_filename_queue): 62 | """Create a thread for generating prediction and label tensors.""" 63 | raise NotImplementedError() 64 | 65 | 66 | class YT8MAggregatedFeatureReader(BaseReader): 67 | """Reads TFRecords of pre-aggregated Examples. 68 | 69 | The TFRecords must contain Examples with a sparse int64 'labels' feature and 70 | a fixed length float32 feature, obtained from the features in 'feature_name'. 71 | The float features are assumed to be an average of dequantized values. 72 | """ 73 | 74 | def __init__(self, 75 | num_classes=4716, 76 | feature_sizes=[1024], 77 | feature_names=["mean_inc3"]): 78 | """Construct a YT8MAggregatedFeatureReader. 79 | 80 | Args: 81 | num_classes: a positive integer for the number of classes. 82 | feature_sizes: positive integer(s) for the feature dimensions as a list. 83 | feature_names: the feature name(s) in the tensorflow record as a list. 84 | """ 85 | 86 | assert len(feature_names) == len(feature_sizes), \ 87 | "length of feature_names (={}) != length of feature_sizes (={})".format( \ 88 | len(feature_names), len(feature_sizes)) 89 | 90 | self.num_classes = num_classes 91 | self.feature_sizes = feature_sizes 92 | self.feature_names = feature_names 93 | 94 | def prepare_reader(self, filename_queue, batch_size=1024): 95 | """Creates a single reader thread for pre-aggregated YouTube 8M Examples. 96 | 97 | Args: 98 | filename_queue: A tensorflow queue of filename locations. 99 | 100 | Returns: 101 | A tuple of video indexes, features, labels, and padding data. 102 | """ 103 | reader = tf.TFRecordReader() 104 | _, serialized_examples = reader.read_up_to(filename_queue, batch_size) 105 | 106 | tf.add_to_collection("serialized_examples", serialized_examples) 107 | return self.prepare_serialized_examples(serialized_examples) 108 | 109 | def prepare_serialized_examples(self, serialized_examples): 110 | # set the mapping from the fields to data types in the proto 111 | num_features = len(self.feature_names) 112 | assert num_features > 0, "self.feature_names is empty!" 
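resize_axis gives every video a fixed number of rows by truncating long feature matrices and zero-padding short ones. A numpy sketch with the same truncate-or-pad behaviour, shown on a hypothetical 5-frame matrix:

```python
import numpy as np

def resize_axis_np(array, axis, new_size, fill_value=0):
    # Same truncate-or-pad semantics as readers.resize_axis, in numpy.
    array = np.asarray(array)
    current = array.shape[axis]
    if current >= new_size:
        index = [slice(None)] * array.ndim
        index[axis] = slice(0, new_size)
        return array[tuple(index)]
    pad = [(0, 0)] * array.ndim
    pad[axis] = (0, new_size - current)
    return np.pad(array, pad, constant_values=fill_value)

frames = np.ones((5, 3))                     # 5 frames, 3 features each
print(resize_axis_np(frames, 0, 2).shape)    # (2, 3)  truncated
print(resize_axis_np(frames, 0, 8).shape)    # (8, 3)  zero-padded at the end
```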
113 | assert len(self.feature_names) == len(self.feature_sizes), \ 114 | "length of feature_names (={}) != length of feature_sizes (={})".format( \ 115 | len(self.feature_names), len(self.feature_sizes)) 116 | 117 | feature_map = {"video_id": tf.FixedLenFeature([], tf.string), 118 | "labels": tf.VarLenFeature(tf.int64)} 119 | for feature_index in range(num_features): 120 | feature_map[self.feature_names[feature_index]] = tf.FixedLenFeature( 121 | [self.feature_sizes[feature_index]], tf.float32) 122 | 123 | features = tf.parse_example(serialized_examples, features=feature_map) 124 | labels = tf.sparse_to_indicator(features["labels"], self.num_classes) 125 | labels.set_shape([None, self.num_classes]) 126 | concatenated_features = tf.concat([ 127 | features[feature_name] for feature_name in self.feature_names], 1) 128 | 129 | return features["video_id"], concatenated_features, labels, tf.ones([tf.shape(serialized_examples)[0]]) 130 | 131 | class YT8MFrameFeatureReader(BaseReader): 132 | """Reads TFRecords of SequenceExamples. 133 | 134 | The TFRecords must contain SequenceExamples with the sparse in64 'labels' 135 | context feature and a fixed length byte-quantized feature vector, obtained 136 | from the features in 'feature_names'. The quantized features will be mapped 137 | back into a range between min_quantized_value and max_quantized_value. 138 | """ 139 | 140 | def __init__(self, 141 | num_classes=4716, 142 | feature_sizes=[1024], 143 | feature_names=["inc3"], 144 | max_frames=300): 145 | """Construct a YT8MFrameFeatureReader. 146 | 147 | Args: 148 | num_classes: a positive integer for the number of classes. 149 | feature_sizes: positive integer(s) for the feature dimensions as a list. 150 | feature_names: the feature name(s) in the tensorflow record as a list. 151 | max_frames: the maximum number of frames to process. 152 | """ 153 | 154 | assert len(feature_names) == len(feature_sizes), \ 155 | "length of feature_names (={}) != length of feature_sizes (={})".format( \ 156 | len(feature_names), len(feature_sizes)) 157 | 158 | self.num_classes = num_classes 159 | self.feature_sizes = feature_sizes 160 | self.feature_names = feature_names 161 | self.max_frames = max_frames 162 | 163 | def get_video_matrix(self, 164 | features, 165 | feature_size, 166 | max_frames, 167 | max_quantized_value, 168 | min_quantized_value): 169 | """Decodes features from an input string and quantizes it. 170 | 171 | Args: 172 | features: raw feature values 173 | feature_size: length of each frame feature vector 174 | max_frames: number of frames (rows) in the output feature_matrix 175 | max_quantized_value: the maximum of the quantized value. 176 | min_quantized_value: the minimum of the quantized value. 177 | 178 | Returns: 179 | feature_matrix: matrix of all frame-features 180 | num_frames: number of frames in the sequence 181 | """ 182 | decoded_features = tf.reshape( 183 | tf.cast(tf.decode_raw(features, tf.uint8), tf.float32), 184 | [-1, feature_size]) 185 | 186 | num_frames = tf.minimum(tf.shape(decoded_features)[0], max_frames) 187 | feature_matrix = utils.Dequantize(decoded_features, 188 | max_quantized_value, 189 | min_quantized_value) 190 | feature_matrix = resize_axis(feature_matrix, 0, max_frames) 191 | return feature_matrix, num_frames 192 | 193 | def prepare_reader(self, 194 | filename_queue, 195 | max_quantized_value=2, 196 | min_quantized_value=-2): 197 | """Creates a single reader thread for YouTube8M SequenceExamples. 
198 | 199 | Args: 200 | filename_queue: A tensorflow queue of filename locations. 201 | max_quantized_value: the maximum of the quantized value. 202 | min_quantized_value: the minimum of the quantized value. 203 | 204 | Returns: 205 | A tuple of video indexes, video features, labels, and padding data. 206 | """ 207 | reader = tf.TFRecordReader() 208 | _, serialized_example = reader.read(filename_queue) 209 | 210 | return self.prepare_serialized_examples(serialized_example, 211 | max_quantized_value, min_quantized_value) 212 | 213 | def prepare_serialized_examples(self, serialized_example, 214 | max_quantized_value=2, min_quantized_value=-2): 215 | 216 | contexts, features = tf.parse_single_sequence_example( 217 | serialized_example, 218 | context_features={"video_id": tf.FixedLenFeature( 219 | [], tf.string), 220 | "labels": tf.VarLenFeature(tf.int64)}, 221 | sequence_features={ 222 | feature_name : tf.FixedLenSequenceFeature([], dtype=tf.string) 223 | for feature_name in self.feature_names 224 | }) 225 | 226 | # read ground truth labels 227 | labels = (tf.cast( 228 | tf.sparse_to_dense(contexts["labels"].values, (self.num_classes,), 1, 229 | validate_indices=False), 230 | tf.bool)) 231 | 232 | # loads (potentially) different types of features and concatenates them 233 | num_features = len(self.feature_names) 234 | assert num_features > 0, "No feature selected: feature_names is empty!" 235 | 236 | assert len(self.feature_names) == len(self.feature_sizes), \ 237 | "length of feature_names (={}) != length of feature_sizes (={})".format( \ 238 | len(self.feature_names), len(self.feature_sizes)) 239 | 240 | num_frames = -1 # the number of frames in the video 241 | feature_matrices = [None] * num_features # an array of different features 242 | for feature_index in range(num_features): 243 | feature_matrix, num_frames_in_this_feature = self.get_video_matrix( 244 | features[self.feature_names[feature_index]], 245 | self.feature_sizes[feature_index], 246 | self.max_frames, 247 | max_quantized_value, 248 | min_quantized_value) 249 | if num_frames == -1: 250 | num_frames = num_frames_in_this_feature 251 | else: 252 | tf.assert_equal(num_frames, num_frames_in_this_feature) 253 | 254 | feature_matrices[feature_index] = feature_matrix 255 | 256 | # cap the number of frames at self.max_frames 257 | num_frames = tf.minimum(num_frames, self.max_frames) 258 | 259 | # concatenate different features 260 | video_matrix = tf.concat(feature_matrices, 1) 261 | 262 | # convert to batch format. 263 | # TODO: Do proper batch reads to remove the IO bottleneck. 264 | batch_video_ids = tf.expand_dims(contexts["video_id"], 0) 265 | batch_video_matrix = tf.expand_dims(video_matrix, 0) 266 | batch_labels = tf.expand_dims(labels, 0) 267 | batch_frames = tf.expand_dims(num_frames, 0) 268 | 269 | return batch_video_ids, batch_video_matrix, batch_labels, batch_frames 270 | 271 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Binary for evaluating Tensorflow models on the YouTube-8M dataset.""" 15 | 16 | import time 17 | 18 | import eval_util 19 | import losses 20 | import frame_level_models 21 | import video_level_models 22 | import readers 23 | import tensorflow as tf 24 | from tensorflow import app 25 | from tensorflow import flags 26 | from tensorflow import gfile 27 | from tensorflow import logging 28 | import utils 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | if __name__ == "__main__": 33 | # Dataset flags. 34 | flags.DEFINE_string("train_dir", "/tmp/yt8m_model/", 35 | "The directory to load the model files from. " 36 | "The tensorboard metrics files are also saved to this " 37 | "directory.") 38 | flags.DEFINE_string( 39 | "eval_data_pattern", "", 40 | "File glob defining the evaluation dataset in tensorflow.SequenceExample " 41 | "format. The SequenceExamples are expected to have an 'rgb' byte array " 42 | "sequence feature as well as a 'labels' int64 context feature.") 43 | flags.DEFINE_string("feature_names", "mean_rgb", "Name of the feature " 44 | "to use for training.") 45 | flags.DEFINE_string("feature_sizes", "1024", "Length of the feature vectors.") 46 | flags.DEFINE_integer("num_classes", 4716, "Number of classes in dataset.") 47 | 48 | # Model flags. 49 | flags.DEFINE_bool( 50 | "frame_features", False, 51 | "If set, then --eval_data_pattern must be frame-level features. " 52 | "Otherwise, --eval_data_pattern must be aggregated video-level " 53 | "features. The model must also be set appropriately (i.e. to read 3D " 54 | "batches VS 4D batches.") 55 | flags.DEFINE_string( 56 | "model", "LogisticModel", 57 | "Which architecture to use for the model. Options include 'Logistic', " 58 | "'SingleMixtureMoe', and 'TwoLayerSigmoid'. See aggregated_models.py and " 59 | "frame_level_models.py for the model definitions.") 60 | flags.DEFINE_integer("batch_size", 1024, 61 | "How many examples to process per batch.") 62 | flags.DEFINE_string("label_loss", "CrossEntropyLoss", 63 | "Loss computed on validation data") 64 | 65 | # Other flags. 
66 | flags.DEFINE_integer("num_readers", 8, 67 | "How many threads to use for reading input files.") 68 | flags.DEFINE_boolean("run_once", False, "Whether to run eval only once.") 69 | flags.DEFINE_integer("top_k", 20, "How many predictions to output per video.") 70 | 71 | 72 | def find_class_by_name(name, modules): 73 | """Searches the provided modules for the named class and returns it.""" 74 | modules = [getattr(module, name, None) for module in modules] 75 | return next(a for a in modules if a) 76 | 77 | 78 | def get_input_evaluation_tensors(reader, 79 | data_pattern, 80 | batch_size=1024, 81 | num_readers=1): 82 | """Creates the section of the graph which reads the evaluation data. 83 | 84 | Args: 85 | reader: A class which parses the training data. 86 | data_pattern: A 'glob' style path to the data files. 87 | batch_size: How many examples to process at a time. 88 | num_readers: How many I/O threads to use. 89 | 90 | Returns: 91 | A tuple containing the features tensor, labels tensor, and optionally a 92 | tensor containing the number of frames per video. The exact dimensions 93 | depend on the reader being used. 94 | 95 | Raises: 96 | IOError: If no files matching the given pattern were found. 97 | """ 98 | logging.info("Using batch size of " + str(batch_size) + " for evaluation.") 99 | with tf.name_scope("eval_input"): 100 | files = gfile.Glob(data_pattern) 101 | if not files: 102 | raise IOError("Unable to find the evaluation files.") 103 | logging.info("number of evaluation files: " + str(len(files))) 104 | filename_queue = tf.train.string_input_producer( 105 | files, shuffle=False, num_epochs=1) 106 | eval_data = [ 107 | reader.prepare_reader(filename_queue) for _ in range(num_readers) 108 | ] 109 | return tf.train.batch_join( 110 | eval_data, 111 | batch_size=batch_size, 112 | capacity=3 * batch_size, 113 | allow_smaller_final_batch=True, 114 | enqueue_many=True) 115 | 116 | 117 | def build_graph(reader, 118 | model, 119 | eval_data_pattern, 120 | label_loss_fn, 121 | batch_size=1024, 122 | num_readers=1): 123 | """Creates the Tensorflow graph for evaluation. 124 | 125 | Args: 126 | reader: The data file reader. It should inherit from BaseReader. 127 | model: The core model (e.g. logistic or neural net). It should inherit 128 | from BaseModel. 129 | eval_data_pattern: glob path to the evaluation data files. 130 | label_loss_fn: What kind of loss to apply to the model. It should inherit 131 | from BaseLoss. 132 | batch_size: How many examples to process at a time. 133 | num_readers: How many threads to use for I/O operations. 134 | """ 135 | 136 | global_step = tf.Variable(0, trainable=False, name="global_step") 137 | video_id_batch, model_input_raw, labels_batch, num_frames = get_input_evaluation_tensors( # pylint: disable=g-line-too-long 138 | reader, 139 | eval_data_pattern, 140 | batch_size=batch_size, 141 | num_readers=num_readers) 142 | tf.summary.histogram("model_input_raw", model_input_raw) 143 | 144 | feature_dim = len(model_input_raw.get_shape()) - 1 145 | 146 | # Normalize input features. 
147 | model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) 148 | 149 | with tf.variable_scope("tower"): 150 | result = model.create_model(model_input, 151 | num_frames=num_frames, 152 | vocab_size=reader.num_classes, 153 | labels=labels_batch, 154 | is_training=False) 155 | predictions = result["predictions"] 156 | tf.summary.histogram("model_activations", predictions) 157 | if "loss" in result.keys(): 158 | label_loss = result["loss"] 159 | else: 160 | label_loss = label_loss_fn.calculate_loss(predictions, labels_batch) 161 | 162 | tf.add_to_collection("global_step", global_step) 163 | tf.add_to_collection("loss", label_loss) 164 | tf.add_to_collection("predictions", predictions) 165 | tf.add_to_collection("input_batch", model_input) 166 | tf.add_to_collection("video_id_batch", video_id_batch) 167 | tf.add_to_collection("num_frames", num_frames) 168 | tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32)) 169 | tf.add_to_collection("summary_op", tf.summary.merge_all()) 170 | 171 | 172 | def evaluation_loop(video_id_batch, prediction_batch, label_batch, loss, 173 | summary_op, saver, summary_writer, evl_metrics, 174 | last_global_step_val): 175 | """Run the evaluation loop once. 176 | 177 | Args: 178 | video_id_batch: a tensor of video ids mini-batch. 179 | prediction_batch: a tensor of predictions mini-batch. 180 | label_batch: a tensor of label_batch mini-batch. 181 | loss: a tensor of loss for the examples in the mini-batch. 182 | summary_op: a tensor which runs the tensorboard summary operations. 183 | saver: a tensorflow saver to restore the model. 184 | summary_writer: a tensorflow summary_writer 185 | evl_metrics: an EvaluationMetrics object. 186 | last_global_step_val: the global step used in the previous evaluation. 187 | 188 | Returns: 189 | The global_step used in the latest model. 190 | """ 191 | 192 | global_step_val = -1 193 | with tf.Session() as sess: 194 | latest_checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir) 195 | if latest_checkpoint: 196 | logging.info("Loading checkpoint for eval: " + latest_checkpoint) 197 | # Restores from checkpoint 198 | saver.restore(sess, latest_checkpoint) 199 | # Assuming model_checkpoint_path looks something like: 200 | # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it. 201 | global_step_val = latest_checkpoint.split("/")[-1].split("-")[-1] 202 | else: 203 | logging.info("No checkpoint file found.") 204 | return global_step_val 205 | 206 | if global_step_val == last_global_step_val: 207 | logging.info("skip this checkpoint global_step_val=%s " 208 | "(same as the previous one).", global_step_val) 209 | return global_step_val 210 | 211 | sess.run([tf.local_variables_initializer()]) 212 | 213 | # Start the queue runners. 214 | fetches = [video_id_batch, prediction_batch, label_batch, loss, summary_op] 215 | coord = tf.train.Coordinator() 216 | try: 217 | threads = [] 218 | for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): 219 | threads.extend(qr.create_threads( 220 | sess, coord=coord, daemon=True, 221 | start=True)) 222 | logging.info("enter eval_once loop global_step_val = %s. 
", 223 | global_step_val) 224 | 225 | evl_metrics.clear() 226 | 227 | examples_processed = 0 228 | while not coord.should_stop(): 229 | batch_start_time = time.time() 230 | _, predictions_val, labels_val, loss_val, summary_val = sess.run( 231 | fetches) 232 | seconds_per_batch = time.time() - batch_start_time 233 | example_per_second = labels_val.shape[0] / seconds_per_batch 234 | examples_processed += labels_val.shape[0] 235 | 236 | iteration_info_dict = evl_metrics.accumulate(predictions_val, 237 | labels_val, loss_val) 238 | iteration_info_dict["examples_per_second"] = example_per_second 239 | 240 | iterinfo = utils.AddGlobalStepSummary( 241 | summary_writer, 242 | global_step_val, 243 | iteration_info_dict, 244 | summary_scope="Eval") 245 | logging.info("examples_processed: %d | %s", examples_processed, 246 | iterinfo) 247 | 248 | except tf.errors.OutOfRangeError as e: 249 | logging.info( 250 | "Done with batched inference. Now calculating global performance " 251 | "metrics.") 252 | # calculate the metrics for the entire epoch 253 | epoch_info_dict = evl_metrics.get() 254 | epoch_info_dict["epoch_id"] = global_step_val 255 | 256 | summary_writer.add_summary(summary_val, global_step_val) 257 | epochinfo = utils.AddEpochSummary( 258 | summary_writer, 259 | global_step_val, 260 | epoch_info_dict, 261 | summary_scope="Eval") 262 | logging.info(epochinfo) 263 | evl_metrics.clear() 264 | except Exception as e: # pylint: disable=broad-except 265 | logging.info("Unexpected exception: " + str(e)) 266 | coord.request_stop(e) 267 | 268 | coord.request_stop() 269 | coord.join(threads, stop_grace_period_secs=10) 270 | 271 | return global_step_val 272 | 273 | 274 | def evaluate(): 275 | tf.set_random_seed(0) # for reproducibility 276 | with tf.Graph().as_default(): 277 | # convert feature_names and feature_sizes to lists of values 278 | feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes( 279 | FLAGS.feature_names, FLAGS.feature_sizes) 280 | num_classes = FLAGS.num_classes 281 | 282 | if FLAGS.frame_features: 283 | reader = readers.YT8MFrameFeatureReader( 284 | num_classes=num_classes, 285 | feature_names=feature_names,feature_sizes=feature_sizes) 286 | else: 287 | reader = readers.YT8MAggregatedFeatureReader( 288 | num_classes=num_classes, 289 | feature_names=feature_names, feature_sizes=feature_sizes) 290 | 291 | model = find_class_by_name(FLAGS.model, 292 | [frame_level_models, video_level_models])() 293 | label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])() 294 | 295 | if FLAGS.eval_data_pattern is "": 296 | raise IOError("'eval_data_pattern' was not specified. 
" + 297 | "Nothing to evaluate.") 298 | 299 | build_graph( 300 | reader=reader, 301 | model=model, 302 | eval_data_pattern=FLAGS.eval_data_pattern, 303 | label_loss_fn=label_loss_fn, 304 | num_readers=FLAGS.num_readers, 305 | batch_size=FLAGS.batch_size) 306 | logging.info("built evaluation graph") 307 | video_id_batch = tf.get_collection("video_id_batch")[0] 308 | prediction_batch = tf.get_collection("predictions")[0] 309 | label_batch = tf.get_collection("labels")[0] 310 | loss = tf.get_collection("loss")[0] 311 | summary_op = tf.get_collection("summary_op")[0] 312 | 313 | saver = tf.train.Saver(tf.global_variables()) 314 | summary_writer = tf.summary.FileWriter( 315 | FLAGS.train_dir, graph=tf.get_default_graph()) 316 | 317 | evl_metrics = eval_util.EvaluationMetrics(reader.num_classes, FLAGS.top_k) 318 | 319 | last_global_step_val = -1 320 | while True: 321 | last_global_step_val = evaluation_loop(video_id_batch, prediction_batch, 322 | label_batch, loss, summary_op, 323 | saver, summary_writer, evl_metrics, 324 | last_global_step_val) 325 | if FLAGS.run_once: 326 | break 327 | 328 | 329 | def main(unused_argv): 330 | logging.set_verbosity(tf.logging.INFO) 331 | print("tensorflow version: %s" % tf.__version__) 332 | evaluate() 333 | 334 | 335 | if __name__ == "__main__": 336 | app.run() 337 | 338 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YouTube-8M Tensorflow Starter Code 2 | 3 | This repo contains starter code for training and evaluating machine learning 4 | models over the [YouTube-8M](https://research.google.com/youtube8m/) dataset. 5 | The code gives an end-to-end working example for reading the dataset, training a 6 | TensorFlow model, and evaluating the performance of the model. Out of the box, 7 | you can train several [model architectures](#overview-of-models) over either 8 | frame-level or video-level features. The code can easily be extended to train 9 | your own custom-defined models. 10 | 11 | It is possible to train and evaluate on YouTube-8M in two ways: on Google Cloud 12 | or on your own machine. This README provides instructions for both. 
13 | 14 | ## Table of Contents 15 | * [Running on Google's Cloud Machine Learning Platform](#running-on-googles-cloud-machine-learning-platform) 16 | * [Requirements](#requirements) 17 | * [Testing Locally](#testing-locally) 18 | * [Training on the Cloud over Video-Level Features](#training-on-video-level-features) 19 | * [Evaluation and Inference](#evaluation-and-inference) 20 | * [Accessing Files on Google Cloud](#accessing-files-on-google-cloud) 21 | * [Using Frame-Level Features](#using-frame-level-features) 22 | * [Using Audio Features](#using-audio-features) 23 | * [Using Larger Machine Types](#using-larger-machine-types) 24 | * [Running on Your Own Machine](#running-on-your-own-machine) 25 | * [Requirements](#requirements-1) 26 | * [Training on Video-Level Features](#training-on-video-level-features-1) 27 | * [Evaluation and Inference](#evaluation-and-inference-1) 28 | * [Using Frame-Level Features](#using-frame-level-features-1) 29 | * [Using Audio Features](#using-audio-features-1) 30 | * [Using GPUs](#using-gpus) 31 | * [Ground-Truth Label Files](#ground-truth-label-files) 32 | * [Overview of Models](#overview-of-models) 33 | * [Video-Level Models](#video-level-models) 34 | * [Frame-Level Models](#frame-level-models) 35 | * [Create Your Own Dataset Files](#create-your-own-dataset-files) 36 | * [Overview of Files](#overview-of-files) 37 | * [Training](#training) 38 | * [Evaluation](#evaluation) 39 | * [Inference](#inference) 40 | * [Misc](#misc) 41 | * [About This Project](#about-this-project) 42 | 43 | ## Running on Google's Cloud Machine Learning Platform 44 | 45 | ### Requirements 46 | 47 | This option requires you to have an appropriately configured Google Cloud 48 | Platform account. To create and configure your account, please make sure you 49 | follow the instructions [here](https://cloud.google.com/ml/docs/how-tos/getting-set-up). 50 | If you are participating in the Google Cloud & YouTube-8M Video Understanding 51 | Challenge hosted on [kaggle](https://www.kaggle.com/c/youtube8m), see [these instructions](https://www.kaggle.com/c/youtube8m#getting-started-with-google-cloud) instead. 52 | 53 | Please also verify that you have Python 2.7+ and Tensorflow 1.0.0 or higher 54 | installed by running the following commands: 55 | 56 | ```sh 57 | python --version 58 | python -c 'import tensorflow as tf; print(tf.__version__)' 59 | ``` 60 | 61 | ### Testing Locally 62 | All gcloud commands should be done from the directory *immediately above* the 63 | source code. You should be able to see the source code directory if you 64 | run 'ls'. 65 | 66 | As you are developing your own models, you will want to test them 67 | quickly to flush out simple problems without having to submit them to the cloud. 68 | You can use the `gcloud beta ml local` set of commands for that. 69 | 70 | Here is an example command line for video-level training: 71 | 72 | ```sh 73 | gcloud ml-engine local train \ 74 | --package-path=youtube-8m --module-name=youtube-8m.train -- \ 75 | --train_data_pattern='gs://youtube8m-ml/1/video_level/train/train*.tfrecord' \ 76 | --train_dir=/tmp/yt8m_train --model=LogisticModel --start_new_model 77 | ``` 78 | 79 | You might want to download some training shards locally to speed things up and 80 | allow you to work offline. The command below will copy 10 out of the 4096 81 | training data files to the current directory. 82 | 83 | ```sh 84 | # Downloads 55MB of data. 85 | gsutil cp gs://us.data.yt8m.org/1/video_level/train/traina[0-9].tfrecord . 
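# (Optional) The ten shards above are enough for local smoke tests. The full
# video-level training split is roughly 31GB; if you want all of it, a plain
# gsutil wildcard copy works (the -m flag parallelizes the transfer):
# gsutil -m cp gs://us.data.yt8m.org/1/video_level/train/train*.tfrecord .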
86 | ``` 87 | Once you download the files, you can point the job to them using the 88 | 'train_data_pattern' argument (i.e. instead of pointing to the "gs://..." 89 | files, you point to the local files). 90 | 91 | Once your model is working locally, you can scale up on the Cloud, 92 | which is described below. 93 | 94 | ### Training on the Cloud over Video-Level Features 95 | 96 | The following commands will train a model on Google Cloud 97 | over video-level features. 98 | 99 | ```sh 100 | BUCKET_NAME=gs://${USER}_yt8m_train_bucket 101 | # (One Time) Create a storage bucket to store training logs and checkpoints. 102 | gsutil mb -l us-east1 $BUCKET_NAME 103 | # Submit the training job. 104 | JOB_NAME=yt8m_train_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug ml-engine jobs \ 105 | submit training $JOB_NAME \ 106 | --package-path=youtube-8m --module-name=youtube-8m.train \ 107 | --staging-bucket=$BUCKET_NAME --region=us-east1 \ 108 | --config=youtube-8m/cloudml-gpu.yaml \ 109 | -- --train_data_pattern='gs://youtube8m-ml-us-east1/1/video_level/train/train*.tfrecord' \ 110 | --model=LogisticModel \ 111 | --train_dir=$BUCKET_NAME/yt8m_train_video_level_logistic_model 112 | ``` 113 | 114 | In the 'gcloud' command above, the 'package-path' flag refers to the directory 115 | containing the 'train.py' script and, more generally, the Python package which 116 | should be deployed to the cloud worker. The 'module-name' flag refers to the specific 117 | Python script which should be executed (in this case, the train module). 118 | 119 | It may take several minutes before the job starts running on Google Cloud. 120 | When it starts, you will see output like the following: 121 | 122 | ``` 123 | training step 270| Hit@1: 0.68 PERR: 0.52 Loss: 638.453 124 | training step 271| Hit@1: 0.66 PERR: 0.49 Loss: 635.537 125 | training step 272| Hit@1: 0.70 PERR: 0.52 Loss: 637.564 126 | ``` 127 | 128 | At this point you can disconnect your console by pressing "ctrl-c". The 129 | model will continue to train indefinitely in the Cloud. Later, you can check 130 | on its progress or halt the job by visiting the 131 | [Google Cloud ML Jobs console](https://console.cloud.google.com/ml/jobs). 132 | 133 | You can train many jobs at once and use tensorboard to compare their performance 134 | visually. 135 | 136 | ```sh 137 | tensorboard --logdir=$BUCKET_NAME --port=8080 138 | ``` 139 | 140 | Once tensorboard is running, you can access it at the following url: 141 | [http://localhost:8080](http://localhost:8080). 142 | If you are using Google Cloud Shell, you can instead click the Web Preview button 143 | on the upper left corner of the Cloud Shell window and select "Preview on port 8080". 144 | This will bring up a new browser tab with the Tensorboard view.
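If you prefer to check on the job from the terminal instead of the web console, the gcloud CLI also exposes job-management subcommands. The exact subcommands available depend on your gcloud version, so treat the following as a sketch; it assumes `JOB_NAME` is still set from the submission step above:

```sh
# Show the current state of the job (e.g. QUEUED, RUNNING, SUCCEEDED).
gcloud ml-engine jobs describe $JOB_NAME
# Stream the job's logs to your terminal.
gcloud ml-engine jobs stream-logs $JOB_NAME
# Stop the job once you no longer need it.
gcloud ml-engine jobs cancel $JOB_NAME
```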
145 | 146 | ### Evaluation and Inference 147 | Here's how to evaluate a model on the validation dataset: 148 | 149 | ```sh 150 | JOB_TO_EVAL=yt8m_train_video_level_logistic_model 151 | JOB_NAME=yt8m_eval_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug ml-engine jobs \ 152 | submit training $JOB_NAME \ 153 | --package-path=youtube-8m --module-name=youtube-8m.eval \ 154 | --staging-bucket=$BUCKET_NAME --region=us-east1 \ 155 | --config=youtube-8m/cloudml-gpu.yaml \ 156 | -- --eval_data_pattern='gs://youtube8m-ml-us-east1/1/video_level/validate/validate*.tfrecord' \ 157 | --model=LogisticModel \ 158 | --train_dir=$BUCKET_NAME/${JOB_TO_EVAL} --run_once=True 159 | ``` 160 | 161 | And here's how to perform inference with a model on the test set: 162 | 163 | ```sh 164 | JOB_TO_EVAL=yt8m_train_video_level_logistic_model 165 | JOB_NAME=yt8m_inference_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug ml-engine jobs \ 166 | submit training $JOB_NAME \ 167 | --package-path=youtube-8m --module-name=youtube-8m.inference \ 168 | --staging-bucket=$BUCKET_NAME --region=us-east1 \ 169 | --config=youtube-8m/cloudml-gpu.yaml \ 170 | -- --input_data_pattern='gs://youtube8m-ml/1/video_level/test/test*.tfrecord' \ 171 | --train_dir=$BUCKET_NAME/${JOB_TO_EVAL} \ 172 | --output_file=$BUCKET_NAME/${JOB_TO_EVAL}/predictions.csv 173 | ``` 174 | 175 | Note the confusing use of 'training' in the above gcloud commands. Despite the 176 | name, the 'training' argument really just offers a cloud-hosted 177 | python/tensorflow service. From the point of view of the Cloud Platform, there 178 | is no distinction between our training and inference jobs. The Cloud ML platform 179 | also offers specialized functionality for prediction with 180 | Tensorflow models, but discussing that is beyond the scope of this readme. 181 | 182 | Once these jobs start executing, you will see output similar to the 183 | following for the evaluation code: 184 | 185 | ``` 186 | examples_processed: 1024 | global_step 447044 | Batch Hit@1: 0.782 | Batch PERR: 0.637 | Batch Loss: 7.821 | Examples_per_sec: 834.658 187 | ``` 188 | 189 | and the following for the inference code: 190 | 191 | ``` 192 | num examples processed: 8192 elapsed seconds: 14.85 193 | ``` 194 | 195 | ### Accessing Files on Google Cloud 196 | 197 | You can browse the storage buckets you created on Google Cloud, for example, to 198 | access the trained models, prediction CSV files, etc. by visiting the 199 | [Google Cloud storage browser](https://console.cloud.google.com/storage/browser). 200 | 201 | Alternatively, you can use the 'gsutil' command to download the files directly. 202 | For example, to download the output of the inference code from the previous 203 | section to your local machine, run: 204 | 205 | 206 | ``` 207 | gsutil cp $BUCKET_NAME/${JOB_TO_EVAL}/predictions.csv . 208 | ``` 209 | 210 | ### Using Frame-Level Features 211 | 212 | Append 213 | ```sh 214 | --frame_features=True --model=FrameLevelLogisticModel --feature_names="rgb" \ 215 | --feature_sizes="1024" --batch_size=128 \ 216 | --train_dir=$BUCKET_NAME/yt8m_train_frame_level_logistic_model 217 | ``` 218 | 219 | to the 'gcloud' commands given above, and change 'video_level' in paths to 220 | 'frame_level'.
Here is a sample command to kick off a frame-level job: 221 | 222 | ```sh 223 | JOB_NAME=yt8m_train_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug ml-engine jobs \ 224 | submit training $JOB_NAME \ 225 | --package-path=youtube-8m --module-name=youtube-8m.train \ 226 | --staging-bucket=$BUCKET_NAME --region=us-east1 \ 227 | --config=youtube-8m/cloudml-gpu.yaml \ 228 | -- --train_data_pattern='gs://youtube8m-ml-us-east1/1/frame_level/train/train*.tfrecord' \ 229 | --frame_features=True --model=FrameLevelLogisticModel --feature_names="rgb" \ 230 | --feature_sizes="1024" --batch_size=128 \ 231 | --train_dir=$BUCKET_NAME/yt8m_train_frame_level_logistic_model 232 | ``` 233 | 234 | The 'FrameLevelLogisticModel' is designed to provide equivalent results to a 235 | logistic model trained over the video-level features. Please look at the 236 | 'video_level_models.py' or 'frame_level_models.py' files to see how to implement 237 | your own models. 238 | 239 | 240 | ### Using Audio Features 241 | 242 | The feature files (both Frame-Level and Video-Level) contain two sets of 243 | features: 1) visual and 2) audio. The code defaults to using the visual 244 | features only, but it is possible to use audio features instead of (or in addition to) 245 | visual features. To specify the (combination of) features to use, you must set 246 | the `--feature_names` and `--feature_sizes` flags. The visual and audio features are 247 | called 'rgb' and 'audio' and have 1024 and 128 dimensions, respectively. 248 | The two flags each take a comma-separated list of values as a string. For example, to 249 | use audio-visual Video-Level features, the flags must be set as follows: 250 | 251 | ``` 252 | --feature_names="mean_rgb, mean_audio" --feature_sizes="1024, 128" 253 | ``` 254 | 255 | Similarly, to use audio-visual Frame-Level features, use: 256 | 257 | ``` 258 | --feature_names="rgb, audio" --feature_sizes="1024, 128" 259 | ``` 260 | 261 | **NOTE:** Make sure the set of features and the order in which they appear in the 262 | lists provided to the two flags above match. Also, the order must match when 263 | running training, evaluation, or inference. 264 | 265 | ### Using Larger Machine Types 266 | 267 | Some complex frame-level models can take as long as a week to converge when 268 | using only one GPU. You can train these models more quickly by using more 269 | powerful machine types which have additional GPUs. To use a configuration with 270 | 4 GPUs, replace the argument to `--config` with `youtube-8m/cloudml-4gpu.yaml`. 271 | Be careful with this argument, as it will also increase the rate you are charged 272 | by a factor of 4. 273 | 274 | ## Running on Your Own Machine 275 | 276 | ### Requirements 277 | 278 | The starter code requires Tensorflow. If you haven't installed it yet, follow 279 | the instructions on [tensorflow.org](https://www.tensorflow.org/install/). 280 | This code has been tested with Tensorflow 1.0.0. Going forward, we will continue 281 | to target the latest released version of Tensorflow. 282 | 283 | Please verify that you have Python 2.7+ and Tensorflow 1.0.0 or higher 284 | installed by running the following commands: 285 | 286 | ```sh 287 | python --version 288 | python -c 'import tensorflow as tf; print(tf.__version__)' 289 | ``` 290 | 291 | You can find complete instructions for downloading the dataset on the 292 | [YouTube-8M website](https://research.google.com/youtube8m/download.html). 293 | We recommend downloading the smaller video-level features dataset first when 294 | getting started.
To do that, run: 295 | 296 | ``` 297 | mkdir -p features; cd features 298 | curl data.yt8m.org/download.py | partition=1/video_level/train mirror=us python 299 | ``` 300 | 301 | This will download the full set of video level features, which takes up 31GB 302 | of space. 303 | If you are located outside of North America, you should change the flag 'mirror' 304 | to 'eu' for Europe or 'asia' for Asia to speed up the transfer of the files. 305 | 306 | Change 'train' to 'validate'/'test' and re-run the command to download the 307 | other splits of the dataset. 308 | 309 | Change 'video_level' to 'frame_level' to download the frame-level features. The 310 | complete frame-level features take about 1.71TB of space. You can set the 311 | environment variable 'shard' to 'm,n' to download only m/n-th of the data. For 312 | example, to download 1/100-th of the frame-level features from the training set, 313 | run: 314 | 315 | ``` 316 | curl data.yt8m.org/download.py | shard=1,100 partition=1/frame_level/train mirror=us python 317 | ``` 318 | 319 | ### Training on Video-Level Features 320 | 321 | To start training a logistic model on the video-level features, run 322 | 323 | ```sh 324 | MODEL_DIR=/tmp/yt8m 325 | python train.py --train_data_pattern='/path/to/features/train*.tfrecord' --model=LogisticModel --train_dir=$MODEL_DIR/video_level_logistic_model 326 | ``` 327 | 328 | Since the dataset is sharded into 4096 individual files, we use a wildcard (\*) 329 | to represent all of those files. 330 | 331 | By default, the training code will frequently write _checkpoint_ files (i.e. 332 | values of all trainable parameters, at the current training iteration). These 333 | will be written to the `--train_dir`. If you re-use a `--train_dir`, the trainer 334 | will first restore the latest checkpoint written in that directory. This only 335 | works if the architecture of the checkpoint matches the graph created by the 336 | training code. If you are in active development/debugging phase, consider 337 | adding `--start_new_model` flag to your run configuration. 338 | 339 | ### Evaluation and Inference 340 | 341 | To evaluate the model, run 342 | 343 | ```sh 344 | python eval.py --eval_data_pattern='/path/to/features/validate*.tfrecord' --model=LogisticModel --train_dir=$MODEL_DIR/video_level_logistic_model --run_once=True 345 | ``` 346 | 347 | As the model is training or evaluating, you can view the results on tensorboard 348 | by running 349 | 350 | ```sh 351 | tensorboard --logdir=$MODEL_DIR 352 | ``` 353 | 354 | and navigating to http://localhost:6006 in your web browser. 355 | 356 | When you are happy with your model, you can generate a csv file of predictions 357 | from it by running 358 | 359 | ```sh 360 | python inference.py --output_file=$MODEL_DIR/video_level_logistic_model/predictions.csv --input_data_pattern='/path/to/features/test*.tfrecord' --train_dir=$MODEL_DIR/video_level_logistic_model 361 | ``` 362 | 363 | This will output the top 20 predicted labels from the model for every example 364 | to 'predictions.csv'. 365 | 366 | ### Using Frame-Level Features 367 | 368 | Follow the same instructions as above, appending 369 | `--frame_features=True --model=FrameLevelLogisticModel --feature_names="rgb" 370 | --feature_sizes="1024" --train_dir=$MODEL_DIR/frame_level_logistic_model` 371 | for the 'train.py', 'eval.py', and 'inference.py' scripts. 372 | 373 | The 'FrameLevelLogisticModel' is designed to provide equivalent results to a 374 | logistic model trained over the video-level features. 
Please look at the 375 | 'models.py' file to see how to implement your own models. 376 | 377 | ### Using Audio Features 378 | 379 | See [Using Audio Features](#using-audio-features) section above. 380 | 381 | ### Using GPUs 382 | 383 | If your Tensorflow installation has GPU support, this code will make use of all 384 | of your compatible GPUs. You can verify your installation by running 385 | 386 | ``` 387 | python -c 'import tensorflow as tf; tf.Session()' 388 | ``` 389 | 390 | This will print out something like the following for each of your compatible 391 | GPUs. 392 | 393 | ``` 394 | I tensorflow/core/common_runtime/gpu/gpu_init.cc:102] Found device 0 with properties: 395 | name: Tesla M40 396 | major: 5 minor: 2 memoryClockRate (GHz) 1.112 397 | pciBusID 0000:04:00.0 398 | Total memory: 11.25GiB 399 | Free memory: 11.09GiB 400 | ... 401 | ``` 402 | 403 | If at least one GPU was found, the forward and backward passes will be computed 404 | with the GPUs, whereas the CPU will be used primarily for the input and output 405 | pipelines. If you have multiple GPUs, each of them will be given a full batch 406 | of examples, and the resulting gradients will be summed together before being 407 | applied. This will increase your effective batch size. For example, if you set 408 | `batch_size=128` and you have 4 GPUs, this will result in 512 examples being 409 | evaluated every training step. 410 | 411 | ### Ground-Truth Label Files 412 | 413 | We also provide CSV files containing the ground-truth label information of the 414 | 'train' and 'validation' partitions of the dataset. These files can be 415 | downloaded using 'gsutil' command: 416 | 417 | ``` 418 | gsutil cp gs://us.data.yt8m.org/1/ground_truth_labels/train_labels.csv /destination/folder/ 419 | gsutil cp gs://us.data.yt8m.org/1/ground_truth_labels/validate_labels.csv /destination/folder/ 420 | ``` 421 | 422 | or directly using the following links: 423 | 424 | * [http://us.data.yt8m.org/1/ground_truth_labels/train_labels.csv](http://us.data.yt8m.org/1/ground_truth_labels/train_labels.csv) 425 | * [http://us.data.yt8m.org/1/ground_truth_labels/validate_labels.csv](http://us.data.yt8m.org/1/ground_truth_labels/validate_labels.csv) 426 | 427 | Each line in the files starts with the video id and is followed by the list of 428 | ground-truth labels corresponding to that video. For example, for a video with 429 | id 'VIDEO_ID' and two labels 'LABEL1' and 'LABEL2' we store the following line: 430 | 431 | ``` 432 | VIDEO_ID,LABEL1 LABEL2 433 | ``` 434 | 435 | ## Overview of Models 436 | 437 | This sample code contains implementations of the models given in the 438 | [YouTube-8M technical report](https://arxiv.org/abs/1609.08675). 439 | 440 | ### Video-Level Models 441 | * `LogisticModel`: Linear projection of the output features into the label 442 | space, followed by a sigmoid function to convert logit 443 | values to probabilities. 444 | * `MoeModel`: A per-class softmax distribution over a configurable number of 445 | logistic classifiers. One of the classifiers in the mixture 446 | is not trained, and always predicts 0. 447 | 448 | ### Frame-Level Models 449 | * `LstmModel`: Processes the features for each frame using a multi-layered 450 | LSTM neural net. The final internal state of the LSTM 451 | is input to a video-level model for classification. Note that 452 | you will need to change the learning rate to 0.001 when using 453 | this model. 
454 | * `DbofModel`: Projects the features for each frame into a higher dimensional 455 | 'clustering' space, pools across frames in that space, and then 456 | uses a video-level model to classify the now aggregated features. 457 | * `FrameLevelLogisticModel`: Equivalent to 'LogisticModel', but performs 458 | average-pooling on the fly over frame-level 459 | features rather than using pre-aggregated features. 460 | 461 | ## Create Your Own Dataset Files 462 | You can create dataset files from your own videos. Our 463 | [feature extractor](./feature_extractor) code creates `tfrecord` 464 | files, identical to our dataset files. You can use our starter code to train on 465 | the `tfrecord` files output by the feature extractor. In addition, you can 466 | fine-tune your YouTube-8M models on your new dataset. 467 | 468 | ## Overview of Files 469 | 470 | ### Training 471 | * `train.py`: The primary script for training models. 472 | * `losses.py`: Contains definitions for loss functions. 473 | * `models.py`: Contains the base class for defining a model. 474 | * `video_level_models.py`: Contains definitions for models that take 475 | aggregated features as input. 476 | * `frame_level_models.py`: Contains definitions for models that take frame- 477 | level features as input. 478 | * `model_utils.py`: Contains functions that are of general utility for 479 | implementing models. 480 | * `export_model.py`: Provides a class to export a model during training 481 | for later use in batch prediction. 482 | * `readers.py`: Contains definitions for the Video dataset and Frame 483 | dataset readers. 484 | 485 | ### Evaluation 486 | * `eval.py`: The primary script for evaluating models. 487 | * `eval_util.py`: Provides a class that calculates all evaluation metrics. 488 | * `average_precision_calculator.py`: Functions for calculating 489 | average precision. 490 | * `mean_average_precision_calculator.py`: Functions for calculating mean 491 | average precision. 492 | 493 | ### Inference 494 | * `inference.py`: Generates an output file containing predictions of 495 | the model over a set of videos. 496 | 497 | ### Misc 498 | * `README.md`: This documentation. 499 | * `utils.py`: Common functions. 500 | * `convert_prediction_from_json_to_csv.py`: Converts the JSON output of 501 | batch prediction into a CSV file for submission. 502 | 503 | ## About This Project 504 | This project is meant to help people quickly get started working with the 505 | [YouTube-8M](https://research.google.com/youtube8m/) dataset. 506 | This is not an official Google product. 507 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | """Binary for training Tensorflow models on the YouTube-8M dataset.""" 15 | 16 | import json 17 | import os 18 | import time 19 | 20 | import eval_util 21 | import export_model 22 | import losses 23 | import frame_level_models 24 | import video_level_models 25 | import readers 26 | import tensorflow as tf 27 | import tensorflow.contrib.slim as slim 28 | from tensorflow import app 29 | from tensorflow import flags 30 | from tensorflow import gfile 31 | from tensorflow import logging 32 | from tensorflow.python.client import device_lib 33 | import utils 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | if __name__ == "__main__": 38 | # Dataset flags. 39 | flags.DEFINE_string("train_dir", "/tmp/yt8m_model/", 40 | "The directory to save the model files in.") 41 | flags.DEFINE_string( 42 | "train_data_pattern", "", 43 | "File glob for the training dataset. If the files refer to Frame Level " 44 | "features (i.e. tensorflow.SequenceExample), then set --reader_type " 45 | "format. The (Sequence)Examples are expected to have 'rgb' byte array " 46 | "sequence feature as well as a 'labels' int64 context feature.") 47 | flags.DEFINE_string("feature_names", "mean_rgb", "Name of the feature " 48 | "to use for training.") 49 | flags.DEFINE_string("feature_sizes", "1024", "Length of the feature vectors.") 50 | flags.DEFINE_integer("num_classes", 4716, "Number of classes in dataset.") 51 | 52 | # Model flags. 53 | flags.DEFINE_bool( 54 | "frame_features", False, 55 | "If set, then --train_data_pattern must be frame-level features. " 56 | "Otherwise, --train_data_pattern must be aggregated video-level " 57 | "features. The model must also be set appropriately (i.e. to read 3D " 58 | "batches VS 4D batches.") 59 | flags.DEFINE_string( 60 | "model", "LogisticModel", 61 | "Which architecture to use for the model. Models are defined " 62 | "in models.py.") 63 | flags.DEFINE_bool( 64 | "start_new_model", False, 65 | "If set, this will not resume from a checkpoint and will instead create a" 66 | " new model instance.") 67 | 68 | # Training flags. 69 | flags.DEFINE_integer("batch_size", 1024, 70 | "How many examples to process per batch for training.") 71 | flags.DEFINE_string("label_loss", "CrossEntropyLoss", 72 | "Which loss function to use for training the model.") 73 | flags.DEFINE_float( 74 | "regularization_penalty", 1.0, 75 | "How much weight to give to the regularization loss (the label loss has " 76 | "a weight of 1).") 77 | flags.DEFINE_float("base_learning_rate", 0.01, 78 | "Which learning rate to start with.") 79 | flags.DEFINE_float("learning_rate_decay", 0.95, 80 | "Learning rate decay factor to be applied every " 81 | "learning_rate_decay_examples.") 82 | flags.DEFINE_float("learning_rate_decay_examples", 4000000, 83 | "Multiply current learning rate by learning_rate_decay " 84 | "every learning_rate_decay_examples.") 85 | flags.DEFINE_integer("num_epochs", 5, 86 | "How many passes to make over the dataset before " 87 | "halting training.") 88 | flags.DEFINE_integer("max_steps", None, 89 | "The maximum number of iterations of the training loop.") 90 | flags.DEFINE_integer("export_model_steps", 1000, 91 | "The period, in number of steps, with which the model " 92 | "is exported for batch prediction.") 93 | 94 | # Other flags. 
95 | flags.DEFINE_integer("num_readers", 8, 96 | "How many threads to use for reading input files.") 97 | flags.DEFINE_string("optimizer", "AdamOptimizer", 98 | "What optimizer class to use.") 99 | flags.DEFINE_float("clip_gradient_norm", 1.0, "Norm to clip gradients to.") 100 | flags.DEFINE_bool( 101 | "log_device_placement", False, 102 | "Whether to write the device on which every op will run into the " 103 | "logs on startup.") 104 | 105 | def validate_class_name(flag_value, category, modules, expected_superclass): 106 | """Checks that the given string matches a class of the expected type. 107 | 108 | Args: 109 | flag_value: A string naming the class to instantiate. 110 | category: A string used further describe the class in error messages 111 | (e.g. 'model', 'reader', 'loss'). 112 | modules: A list of modules to search for the given class. 113 | expected_superclass: A class that the given class should inherit from. 114 | 115 | Raises: 116 | FlagsError: If the given class could not be found or if the first class 117 | found with that name doesn't inherit from the expected superclass. 118 | 119 | Returns: 120 | True if a class was found that matches the given constraints. 121 | """ 122 | candidates = [getattr(module, flag_value, None) for module in modules] 123 | for candidate in candidates: 124 | if not candidate: 125 | continue 126 | if not issubclass(candidate, expected_superclass): 127 | raise flags.FlagsError("%s '%s' doesn't inherit from %s." % 128 | (category, flag_value, 129 | expected_superclass.__name__)) 130 | return True 131 | raise flags.FlagsError("Unable to find %s '%s'." % (category, flag_value)) 132 | 133 | def get_input_data_tensors(reader, 134 | data_pattern, 135 | batch_size=1000, 136 | num_epochs=None, 137 | num_readers=1): 138 | """Creates the section of the graph which reads the training data. 139 | 140 | Args: 141 | reader: A class which parses the training data. 142 | data_pattern: A 'glob' style path to the data files. 143 | batch_size: How many examples to process at a time. 144 | num_epochs: How many passes to make over the training data. Set to 'None' 145 | to run indefinitely. 146 | num_readers: How many I/O threads to use. 147 | 148 | Returns: 149 | A tuple containing the features tensor, labels tensor, and optionally a 150 | tensor containing the number of frames per video. The exact dimensions 151 | depend on the reader being used. 152 | 153 | Raises: 154 | IOError: If no files matching the given pattern were found. 155 | """ 156 | logging.info("Using batch size of " + str(batch_size) + " for training.") 157 | with tf.name_scope("train_input"): 158 | files = gfile.Glob(data_pattern) 159 | if not files: 160 | raise IOError("Unable to find training files. 
data_pattern='" + 161 | data_pattern + "'.") 162 | logging.info("Number of training files: %s.", str(len(files))) 163 | filename_queue = tf.train.string_input_producer( 164 | files, num_epochs=num_epochs, shuffle=True) 165 | training_data = [ 166 | reader.prepare_reader(filename_queue) for _ in range(num_readers) 167 | ] 168 | 169 | return tf.train.shuffle_batch_join( 170 | training_data, 171 | batch_size=batch_size, 172 | capacity=batch_size * 5, 173 | min_after_dequeue=batch_size, 174 | allow_smaller_final_batch=True, 175 | enqueue_many=True) 176 | 177 | 178 | def find_class_by_name(name, modules): 179 | """Searches the provided modules for the named class and returns it.""" 180 | modules = [getattr(module, name, None) for module in modules] 181 | return next(a for a in modules if a) 182 | 183 | def build_graph(reader, 184 | model, 185 | train_data_pattern, 186 | label_loss_fn=losses.CrossEntropyLoss(), 187 | batch_size=1000, 188 | base_learning_rate=0.01, 189 | learning_rate_decay_examples=1000000, 190 | learning_rate_decay=0.95, 191 | optimizer_class=tf.train.AdamOptimizer, 192 | clip_gradient_norm=1.0, 193 | regularization_penalty=1, 194 | num_readers=1, 195 | num_epochs=None): 196 | """Creates the Tensorflow graph. 197 | 198 | This will only be called once in the life of 199 | a training model, because after the graph is created the model will be 200 | restored from a meta graph file rather than being recreated. 201 | 202 | Args: 203 | reader: The data file reader. It should inherit from BaseReader. 204 | model: The core model (e.g. logistic or neural net). It should inherit 205 | from BaseModel. 206 | train_data_pattern: glob path to the training data files. 207 | label_loss_fn: What kind of loss to apply to the model. It should inherit 208 | from BaseLoss. 209 | batch_size: How many examples to process at a time. 210 | base_learning_rate: What learning rate to initialize the optimizer with. 211 | optimizer_class: Which optimization algorithm to use. 212 | clip_gradient_norm: Magnitude of the gradient to clip to. 213 | regularization_penalty: How much weight to give the regularization loss 214 | compared to the label loss. 215 | num_readers: How many threads to use for I/O operations. 216 | num_epochs: How many passes to make over the data. 'None' means an 217 | unlimited number of passes. 218 | """ 219 | 220 | global_step = tf.Variable(0, trainable=False, name="global_step") 221 | 222 | local_device_protos = device_lib.list_local_devices() 223 | gpus = [x.name for x in local_device_protos if x.device_type == 'GPU'] 224 | num_gpus = len(gpus) 225 | 226 | if num_gpus > 0: 227 | logging.info("Using the following GPUs to train: " + str(gpus)) 228 | num_towers = num_gpus 229 | device_string = '/gpu:%d' 230 | else: 231 | logging.info("No GPUs found. 
Training on CPU.") 232 | num_towers = 1 233 | device_string = '/cpu:%d' 234 | 235 | learning_rate = tf.train.exponential_decay( 236 | base_learning_rate, 237 | global_step * batch_size * num_towers, 238 | learning_rate_decay_examples, 239 | learning_rate_decay, 240 | staircase=True) 241 | tf.summary.scalar('learning_rate', learning_rate) 242 | 243 | optimizer = optimizer_class(learning_rate) 244 | unused_video_id, model_input_raw, labels_batch, num_frames = ( 245 | get_input_data_tensors( 246 | reader, 247 | train_data_pattern, 248 | batch_size=batch_size * num_towers, 249 | num_readers=num_readers, 250 | num_epochs=num_epochs)) 251 | tf.summary.histogram("model/input_raw", model_input_raw) 252 | 253 | feature_dim = len(model_input_raw.get_shape()) - 1 254 | 255 | model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) 256 | 257 | tower_inputs = tf.split(model_input, num_towers) 258 | tower_labels = tf.split(labels_batch, num_towers) 259 | tower_num_frames = tf.split(num_frames, num_towers) 260 | tower_gradients = [] 261 | tower_predictions = [] 262 | tower_label_losses = [] 263 | tower_reg_losses = [] 264 | for i in range(num_towers): 265 | # For some reason these 'with' statements can't be combined onto the same 266 | # line. They have to be nested. 267 | with tf.device(device_string % i): 268 | with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)): 269 | with (slim.arg_scope([slim.model_variable, slim.variable], device="/cpu:0" if num_gpus!=1 else "/gpu:0")): 270 | result = model.create_model( 271 | tower_inputs[i], 272 | num_frames=tower_num_frames[i], 273 | vocab_size=reader.num_classes, 274 | labels=tower_labels[i]) 275 | for variable in slim.get_model_variables(): 276 | tf.summary.histogram(variable.op.name, variable) 277 | 278 | predictions = result["predictions"] 279 | tower_predictions.append(predictions) 280 | 281 | if "loss" in result.keys(): 282 | label_loss = result["loss"] 283 | else: 284 | label_loss = label_loss_fn.calculate_loss(predictions, tower_labels[i]) 285 | 286 | if "regularization_loss" in result.keys(): 287 | reg_loss = result["regularization_loss"] 288 | else: 289 | reg_loss = tf.constant(0.0) 290 | 291 | reg_losses = tf.losses.get_regularization_losses() 292 | if reg_losses: 293 | reg_loss += tf.add_n(reg_losses) 294 | 295 | tower_reg_losses.append(reg_loss) 296 | 297 | # Adds update_ops (e.g., moving average updates in batch normalization) as 298 | # a dependency to the train_op. 299 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 300 | if "update_ops" in result.keys(): 301 | update_ops += result["update_ops"] 302 | if update_ops: 303 | with tf.control_dependencies(update_ops): 304 | barrier = tf.no_op(name="gradient_barrier") 305 | with tf.control_dependencies([barrier]): 306 | label_loss = tf.identity(label_loss) 307 | 308 | tower_label_losses.append(label_loss) 309 | 310 | # Incorporate the L2 weight penalties etc. 
311 | final_loss = regularization_penalty * reg_loss + label_loss 312 | gradients = optimizer.compute_gradients(final_loss, 313 | colocate_gradients_with_ops=False) 314 | tower_gradients.append(gradients) 315 | label_loss = tf.reduce_mean(tf.stack(tower_label_losses)) 316 | tf.summary.scalar("label_loss", label_loss) 317 | if regularization_penalty != 0: 318 | reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses)) 319 | tf.summary.scalar("reg_loss", reg_loss) 320 | merged_gradients = utils.combine_gradients(tower_gradients) 321 | 322 | if clip_gradient_norm > 0: 323 | with tf.name_scope('clip_grads'): 324 | merged_gradients = utils.clip_gradient_norms(merged_gradients, clip_gradient_norm) 325 | 326 | train_op = optimizer.apply_gradients(merged_gradients, global_step=global_step) 327 | 328 | tf.add_to_collection("global_step", global_step) 329 | tf.add_to_collection("loss", label_loss) 330 | tf.add_to_collection("predictions", tf.concat(tower_predictions, 0)) 331 | tf.add_to_collection("input_batch_raw", model_input_raw) 332 | tf.add_to_collection("input_batch", model_input) 333 | tf.add_to_collection("num_frames", num_frames) 334 | tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32)) 335 | tf.add_to_collection("train_op", train_op) 336 | 337 | 338 | class Trainer(object): 339 | """A Trainer to train a Tensorflow graph.""" 340 | 341 | def __init__(self, cluster, task, train_dir, model, reader, model_exporter, 342 | log_device_placement=True, max_steps=None, 343 | export_model_steps=1000): 344 | """"Creates a Trainer. 345 | 346 | Args: 347 | cluster: A tf.train.ClusterSpec if the execution is distributed. 348 | None otherwise. 349 | task: A TaskSpec describing the job type and the task index. 350 | """ 351 | 352 | self.cluster = cluster 353 | self.task = task 354 | self.is_master = (task.type == "master" and task.index == 0) 355 | self.train_dir = train_dir 356 | self.config = tf.ConfigProto( 357 | allow_soft_placement=True,log_device_placement=log_device_placement) 358 | self.model = model 359 | self.reader = reader 360 | self.model_exporter = model_exporter 361 | self.max_steps = max_steps 362 | self.max_steps_reached = False 363 | self.export_model_steps = export_model_steps 364 | self.last_model_export_step = 0 365 | 366 | # if self.is_master and self.task.index > 0: 367 | # raise StandardError("%s: Only one replica of master expected", 368 | # task_as_string(self.task)) 369 | 370 | def run(self, start_new_model=False): 371 | """Performs training on the currently defined Tensorflow graph. 372 | 373 | Returns: 374 | A tuple of the training Hit@1 and the training PERR. 
375 | """ 376 | if self.is_master and start_new_model: 377 | self.remove_training_directory(self.train_dir) 378 | 379 | target, device_fn = self.start_server_if_distributed() 380 | 381 | meta_filename = self.get_meta_filename(start_new_model, self.train_dir) 382 | 383 | with tf.Graph().as_default() as graph: 384 | 385 | if meta_filename: 386 | saver = self.recover_model(meta_filename) 387 | 388 | with tf.device(device_fn): 389 | if not meta_filename: 390 | saver = self.build_model(self.model, self.reader) 391 | 392 | global_step = tf.get_collection("global_step")[0] 393 | loss = tf.get_collection("loss")[0] 394 | predictions = tf.get_collection("predictions")[0] 395 | labels = tf.get_collection("labels")[0] 396 | train_op = tf.get_collection("train_op")[0] 397 | init_op = tf.global_variables_initializer() 398 | 399 | sv = tf.train.Supervisor( 400 | graph, 401 | logdir=self.train_dir, 402 | init_op=init_op, 403 | is_chief=self.is_master, 404 | global_step=global_step, 405 | save_model_secs=15 * 60, 406 | save_summaries_secs=120, 407 | saver=saver) 408 | 409 | logging.info("%s: Starting managed session.", task_as_string(self.task)) 410 | with sv.managed_session(target, config=self.config) as sess: 411 | try: 412 | logging.info("%s: Entering training loop.", task_as_string(self.task)) 413 | while (not sv.should_stop()) and (not self.max_steps_reached): 414 | batch_start_time = time.time() 415 | _, global_step_val, loss_val, predictions_val, labels_val = sess.run( 416 | [train_op, global_step, loss, predictions, labels]) 417 | seconds_per_batch = time.time() - batch_start_time 418 | examples_per_second = labels_val.shape[0] / seconds_per_batch 419 | 420 | if self.max_steps and self.max_steps <= global_step_val: 421 | self.max_steps_reached = True 422 | 423 | if self.is_master and global_step_val % 10 == 0 and self.train_dir: 424 | eval_start_time = time.time() 425 | hit_at_one = eval_util.calculate_hit_at_one(predictions_val, labels_val) 426 | perr = eval_util.calculate_precision_at_equal_recall_rate(predictions_val, 427 | labels_val) 428 | gap = eval_util.calculate_gap(predictions_val, labels_val) 429 | eval_end_time = time.time() 430 | eval_time = eval_end_time - eval_start_time 431 | 432 | logging.info("training step " + str(global_step_val) + " | Loss: " + ("%.2f" % loss_val) + 433 | " Examples/sec: " + ("%.2f" % examples_per_second) + " | Hit@1: " + 434 | ("%.2f" % hit_at_one) + " PERR: " + ("%.2f" % perr) + 435 | " GAP: " + ("%.2f" % gap)) 436 | 437 | sv.summary_writer.add_summary( 438 | utils.MakeSummary("model/Training_Hit@1", hit_at_one), 439 | global_step_val) 440 | sv.summary_writer.add_summary( 441 | utils.MakeSummary("model/Training_Perr", perr), global_step_val) 442 | sv.summary_writer.add_summary( 443 | utils.MakeSummary("model/Training_GAP", gap), global_step_val) 444 | sv.summary_writer.add_summary( 445 | utils.MakeSummary("global_step/Examples/Second", 446 | examples_per_second), global_step_val) 447 | sv.summary_writer.flush() 448 | 449 | # Exporting the model every x steps 450 | time_to_export = ((self.last_model_export_step == 0) or 451 | (global_step_val - self.last_model_export_step 452 | >= self.export_model_steps)) 453 | 454 | if self.is_master and time_to_export: 455 | self.export_model(global_step_val, sv.saver, sv.save_path, sess) 456 | self.last_model_export_step = global_step_val 457 | else: 458 | logging.info("training step " + str(global_step_val) + " | Loss: " + 459 | ("%.2f" % loss_val) + " Examples/sec: " + ("%.2f" % examples_per_second)) 460 | except 
460 |       except tf.errors.OutOfRangeError:
461 |         logging.info("%s: Done training -- epoch limit reached.",
462 |                      task_as_string(self.task))
463 |
464 |     logging.info("%s: Exited training loop.", task_as_string(self.task))
465 |     sv.Stop()
466 |
467 |   def export_model(self, global_step_val, saver, save_path, session):
468 |
469 |     # If the model has already been exported at this step, return.
470 |     if global_step_val == self.last_model_export_step:
471 |       return
472 |
473 |     last_checkpoint = saver.save(session, save_path, global_step_val)
474 |
475 |     model_dir = "{0}/export/step_{1}".format(self.train_dir, global_step_val)
476 |     logging.info("%s: Exporting the model at step %s to %s.",
477 |                  task_as_string(self.task), global_step_val, model_dir)
478 |
479 |     self.model_exporter.export_model(
480 |         model_dir=model_dir,
481 |         global_step_val=global_step_val,
482 |         last_checkpoint=last_checkpoint)
483 |
484 |   def start_server_if_distributed(self):
485 |     """Starts a server if the execution is distributed."""
486 |
487 |     if self.cluster:
488 |       logging.info("%s: Starting trainer within cluster %s.",
489 |                    task_as_string(self.task), self.cluster.as_dict())
490 |       server = start_server(self.cluster, self.task)
491 |       target = server.target
492 |       device_fn = tf.train.replica_device_setter(
493 |           ps_device="/job:ps",
494 |           worker_device="/job:%s/task:%d" % (self.task.type, self.task.index),
495 |           cluster=self.cluster)
496 |     else:
497 |       target = ""
498 |       device_fn = ""
499 |     return (target, device_fn)
500 |
501 |   def remove_training_directory(self, train_dir):
502 |     """Removes the training directory."""
503 |     try:
504 |       logging.info(
505 |           "%s: Removing existing train directory.",
506 |           task_as_string(self.task))
507 |       gfile.DeleteRecursively(train_dir)
508 |     except:
509 |       logging.error(
510 |           "%s: Failed to delete directory " + train_dir +
511 |           " when starting a new model. Please delete it manually and" +
512 |           " try again.", task_as_string(self.task))
513 |
514 |   def get_meta_filename(self, start_new_model, train_dir):
515 |     if start_new_model:
516 |       logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
517 |                    task_as_string(self.task))
518 |       return None
519 |
520 |     latest_checkpoint = tf.train.latest_checkpoint(train_dir)
521 |     if not latest_checkpoint:
522 |       logging.info("%s: No checkpoint file found. Building a new model.",
523 |                    task_as_string(self.task))
524 |       return None
525 |
526 |     meta_filename = latest_checkpoint + ".meta"
527 |     if not gfile.Exists(meta_filename):
528 |       logging.info("%s: No meta graph file found. Building a new model.",
529 |                    task_as_string(self.task))
530 |       return None
531 |     else:
532 |       return meta_filename
533 |
534 |   def recover_model(self, meta_filename):
535 |     logging.info("%s: Restoring from meta graph file %s",
536 |                  task_as_string(self.task), meta_filename)
537 |     return tf.train.import_meta_graph(meta_filename)
538 |
539 |   def build_model(self, model, reader):
540 |     """Find the model and build the graph."""
541 |
542 |     label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
543 |     optimizer_class = find_class_by_name(FLAGS.optimizer, [tf.train])
544 |
545 |     build_graph(reader=reader,
546 |                 model=model,
547 |                 optimizer_class=optimizer_class,
548 |                 clip_gradient_norm=FLAGS.clip_gradient_norm,
549 |                 train_data_pattern=FLAGS.train_data_pattern,
550 |                 label_loss_fn=label_loss_fn,
551 |                 base_learning_rate=FLAGS.base_learning_rate,
552 |                 learning_rate_decay=FLAGS.learning_rate_decay,
553 |                 learning_rate_decay_examples=FLAGS.learning_rate_decay_examples,
554 |                 regularization_penalty=FLAGS.regularization_penalty,
555 |                 num_readers=FLAGS.num_readers,
556 |                 batch_size=FLAGS.batch_size,
557 |                 num_epochs=FLAGS.num_epochs)
558 |
559 |     return tf.train.Saver(max_to_keep=0, keep_checkpoint_every_n_hours=0.25)
560 |
561 |
562 | def get_reader():
563 |   # Convert feature_names and feature_sizes to lists of values.
564 |   feature_names, feature_sizes = utils.GetListOfFeatureNamesAndSizes(
565 |       FLAGS.feature_names, FLAGS.feature_sizes)
566 |   num_classes = FLAGS.num_classes
567 |
568 |   if FLAGS.frame_features:
569 |     reader = readers.YT8MFrameFeatureReader(
570 |         num_classes=num_classes,
571 |         feature_names=feature_names, feature_sizes=feature_sizes)
572 |   else:
573 |     reader = readers.YT8MAggregatedFeatureReader(
574 |         num_classes=num_classes,
575 |         feature_names=feature_names, feature_sizes=feature_sizes)
576 |
577 |   return reader
578 |
579 |
580 | class ParameterServer(object):
581 |   """A parameter server to serve variables in a distributed execution."""
582 |
583 |   def __init__(self, cluster, task):
584 |     """Creates a ParameterServer.
585 |
586 |     Args:
587 |       cluster: A tf.train.ClusterSpec if the execution is distributed.
588 |         None otherwise.
589 |       task: A TaskSpec describing the job type and the task index.
590 |     """
591 |
592 |     self.cluster = cluster
593 |     self.task = task
594 |
595 |   def run(self):
596 |     """Starts the parameter server."""
597 |
598 |     logging.info("%s: Starting parameter server within cluster %s.",
599 |                  task_as_string(self.task), self.cluster.as_dict())
600 |     server = start_server(self.cluster, self.task)
601 |     server.join()
602 |
603 |
604 | def start_server(cluster, task):
605 |   """Creates a Server.
606 |
607 |   Args:
608 |     cluster: A tf.train.ClusterSpec if the execution is distributed.
609 |       None otherwise.
610 |     task: A TaskSpec describing the job type and the task index.
611 |   """
612 |
613 |   if not task.type:
614 |     raise ValueError("%s: The task type must be specified." %
615 |                      task_as_string(task))
616 |   if task.index is None:
617 |     raise ValueError("%s: The task index must be specified." %
618 |                      task_as_string(task))
619 |
620 |   # Create and start a server.
621 |   return tf.train.Server(
622 |       tf.train.ClusterSpec(cluster),
623 |       protocol="grpc",
624 |       job_name=task.type,
625 |       task_index=task.index)
626 |
627 | def task_as_string(task):
628 |   return "/job:%s/task:%s" % (task.type, task.index)
629 |
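# Note: main() below configures each process from the TF_CONFIG environment
# variable (set by Cloud ML Engine, or by whatever launches the job); only its
# "cluster" and "task" entries are read. Purely as an illustration, with made-up
# host:port values, json.loads would hand a worker replica of a two-worker,
# two-parameter-server job something like:
#
#   {
#     "cluster": {
#       "master": ["master-0:2222"],
#       "worker": ["worker-0:2222", "worker-1:2222"],
#       "ps": ["ps-0:2222", "ps-1:2222"]
#     },
#     "task": {"type": "worker", "index": 1}
#   }
#
# When TF_CONFIG is unset, cluster stays None and the task falls back to
# {"type": "master", "index": 0}, so the same entry point runs single-machine
# training.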
630 | def main(unused_argv):
631 |   # Load the environment.
632 |   env = json.loads(os.environ.get("TF_CONFIG", "{}"))
633 |
634 |   # Load the cluster data from the environment.
635 |   cluster_data = env.get("cluster", None)
636 |   cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
637 |
638 |   # Load the task data from the environment.
639 |   task_data = env.get("task", None) or {"type": "master", "index": 0}
640 |   task = type("TaskSpec", (object,), task_data)
641 |
642 |   # Logging the version.
643 |   logging.set_verbosity(tf.logging.INFO)
644 |   logging.info("%s: Tensorflow version: %s.",
645 |                task_as_string(task), tf.__version__)
646 |
647 |   # Dispatch to a master, a worker, or a parameter server.
648 |   if not cluster or task.type == "master" or task.type == "worker":
649 |     model = find_class_by_name(FLAGS.model,
650 |                                [frame_level_models, video_level_models])()
651 |
652 |     reader = get_reader()
653 |
654 |     model_exporter = export_model.ModelExporter(
655 |         frame_features=FLAGS.frame_features,
656 |         model=model,
657 |         reader=reader)
658 |
659 |     Trainer(cluster, task, FLAGS.train_dir, model, reader, model_exporter,
660 |             FLAGS.log_device_placement, FLAGS.max_steps,
661 |             FLAGS.export_model_steps).run(start_new_model=FLAGS.start_new_model)
662 |
663 |   elif task.type == "ps":
664 |     ParameterServer(cluster, task).run()
665 |   else:
666 |     raise ValueError("%s: Invalid task_type: %s." %
667 |                      (task_as_string(task), task.type))
668 |
669 | if __name__ == "__main__":
670 |   app.run()
671 |
--------------------------------------------------------------------------------
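With TF_CONFIG absent, the dispatch in main() reduces to ordinary single-machine training. Purely as a sketch of that path (not additional repository code; flag values still come from the command line), main() effectively runs:

  model = find_class_by_name(FLAGS.model,
                             [frame_level_models, video_level_models])()
  reader = get_reader()
  model_exporter = export_model.ModelExporter(
      frame_features=FLAGS.frame_features, model=model, reader=reader)
  task = type("TaskSpec", (object,), {"type": "master", "index": 0})
  Trainer(None, task, FLAGS.train_dir, model, reader, model_exporter,
          FLAGS.log_device_placement, FLAGS.max_steps,
          FLAGS.export_model_steps).run(start_new_model=FLAGS.start_new_model)

In a distributed run, TF_CONFIG instead supplies a real cluster spec: master and worker replicas take the same Trainer path against shared parameter servers, while "ps" tasks run ParameterServer.run().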