├── cloudml-gpu.yaml ├── __init__.py ├── .github └── workflows │ └── pythonpackage.yml ├── models.py ├── docs ├── model_overview.md └── files_overview.md ├── CONTRIBUTING.md ├── feature_extractor ├── feature_extractor_test.py ├── README.md ├── feature_extractor.py └── extract_tfrecords_main.py ├── export_model_mediapipe.py ├── convert_prediction_from_json_to_csv.py ├── model_utils.py ├── losses.py ├── video_level_models.py ├── mean_average_precision_calculator.py ├── export_model.py ├── segment_eval_inference.py ├── utils.py ├── eval_util.py ├── average_precision_calculator.py ├── segment_label_ids.csv ├── frame_level_models.py ├── LICENSE ├── readers.py ├── README.md ├── eval.py ├── inference.py └── train.py /cloudml-gpu.yaml: -------------------------------------------------------------------------------- 1 | trainingInput: 2 | scaleTier: CUSTOM 3 | # https://cloud.google.com/ml-engine/docs/machine-types#machine_type_table 4 | masterType: standard_gpu 5 | runtimeVersion: "1.14" 6 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /.github/workflows/pythonpackage.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | max-parallel: 4 11 | matrix: 12 | python-version: [3.6] 13 | 14 | steps: 15 | - uses: actions/checkout@v1 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v1 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install tensorflow==1.14.0 six Pillow 24 | - name: Yapf Check 25 | run: | 26 | pip install yapf 27 | yapf --diff --style="{based_on_style: google, indent_width:2}" *.py 28 | - name: Test with nosetests 29 | run: | 30 | pip install -U nose 31 | nosetests 32 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Contains the base class for models.""" 15 | 16 | 17 | class BaseModel(object): 18 | """Inherit from this class when implementing new models.""" 19 | 20 | def create_model(self, unused_model_input, **unused_params): 21 | raise NotImplementedError() 22 | -------------------------------------------------------------------------------- /docs/model_overview.md: -------------------------------------------------------------------------------- 1 | # Overview of Models 2 | 3 | This sample code contains implementations of the models given in the 4 | [YouTube-8M technical report](https://arxiv.org/abs/1609.08675). 5 | 6 | ## Video-Level Models 7 | 8 | * `LogisticModel`: Linear projection of the output features into the label 9 | space, followed by a sigmoid function to convert logit values to 10 | probabilities. 11 | * `MoeModel`: A per-class softmax distribution over a configurable number of 12 | logistic classifiers. One of the classifiers in the mixture is not trained, 13 | and always predicts 0. 14 | 15 | ## Frame-Level Models 16 | 17 | * `LstmModel`: Processes the features for each frame using a multi-layered 18 | LSTM neural net. The final internal state of the LSTM is input to a 19 | video-level model for classification. Note that you will need to change the 20 | learning rate to 0.001 when using this model. 21 | * `DbofModel`: Projects the features for each frame into a higher dimensional 22 | 'clustering' space, pools across frames in that space, and then uses a 23 | video-level model to classify the now aggregated features. 24 | * `FrameLevelLogisticModel`: Equivalent to 'LogisticModel', but performs 25 | average-pooling on the fly over frame-level features rather than using 26 | pre-aggregated features. 27 | -------------------------------------------------------------------------------- /docs/files_overview.md: -------------------------------------------------------------------------------- 1 | # Overview of Files 2 | 3 | ## Training 4 | 5 | * `train.py`: The primary script for training models. 6 | * `losses.py`: Contains definitions for loss functions. 7 | * `models.py`: Contains the base class for defining a model. 8 | * `video_level_models.py`: Contains definitions for models that take 9 | aggregated features as input. 10 | * `frame_level_models.py`: Contains definitions for models that take frame- 11 | level features as input. 12 | * `model_util.py`: Contains functions that are of general utility for 13 | implementing models. 14 | * `export_model.py`: Provides a class to export a model during training for 15 | later use in batch prediction. 16 | * `readers.py`: Contains definitions for the Video dataset and Frame dataset 17 | readers. 18 | 19 | ## Evaluation 20 | 21 | * `eval.py`: The primary script for evaluating models. 22 | * `eval_util.py`: Provides a class that calculates all evaluation metrics. 23 | * `average_precision_calculator.py`: Functions for calculating average 24 | precision. 25 | * `mean_average_precision_calculator.py`: Functions for calculating mean 26 | average precision. 27 | * `segment_eval_inference.py`: The primary script to evaluate segment models 28 | with Kaggle metrics. 29 | 30 | ## Inference 31 | 32 | * `inference.py`: Generates an output CSV file containing predictions of the 33 | model over a set of videos. It optionally generates a tarred file of the 34 | model. 35 | 36 | ## Misc 37 | 38 | * `README.md`: This documentation. 
39 | * `utils.py`: Common functions. 40 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We are accepting patches and contributions to this project. To set expectations, 4 | this project is primarily intended to be a flexible starting point for 5 | researchers working with the YouTube-8M dataset. As such, we would like to keep 6 | it simple. We are more likely to accept small bug fixes and optimizations, and 7 | less likely to accept patches which add significant complexity. For the latter 8 | type of contribution, we recommend creating a Github fork of the project 9 | instead. 10 | 11 | If you would like to contribute, there are a few small guidelines you need to 12 | follow. 13 | 14 | ## Contributor License Agreement 15 | 16 | Contributions to any Google project must be accompanied by a Contributor License 17 | Agreement. This is necessary because you own the copyright to your changes, even 18 | after your contribution becomes part of this project. So this agreement simply 19 | gives us permission to use and redistribute your contributions as part of the 20 | project. Head over to to see your current 21 | agreements on file or to sign a new one. 22 | 23 | You generally only need to submit a CLA once, so if you've already submitted one 24 | (even if it was for a different project), you probably don't need to do it 25 | again. 26 | 27 | ## Code reviews 28 | 29 | All submissions, including submissions by project members, require review. We 30 | use GitHub pull requests for this purpose. Consult [GitHub Help] for more 31 | information on using pull requests. 32 | 33 | [GitHub Help]: https://help.github.com/articles/about-pull-requests/ 34 | -------------------------------------------------------------------------------- /feature_extractor/feature_extractor_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
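# Note: the GitHub Actions workflow above (.github/workflows/pythonpackage.yml) runs
# the repository's tests with nosetests; this test additionally expects a
# testdata/sports1m_frame.pkl pickle (relative to the working directory) holding an
# 'original' 2048-D frame feature and its expected 'pca' projection.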
14 | """Tests for feature_extractor.""" 15 | 16 | import json 17 | import os 18 | import feature_extractor 19 | import numpy 20 | from PIL import Image 21 | from six.moves import cPickle 22 | from tensorflow.python.platform import googletest 23 | 24 | 25 | def _FilePath(filename): 26 | return os.path.join('testdata', filename) 27 | 28 | 29 | def _MeanElementWiseDifference(a, b): 30 | """Calculates element-wise percent difference between two numpy matrices.""" 31 | difference = numpy.abs(a - b) 32 | denominator = numpy.maximum(numpy.abs(a), numpy.abs(b)) 33 | 34 | # We dont care if one is 0 and another is 0.01 35 | return (difference / (0.01 + denominator)).mean() 36 | 37 | 38 | class FeatureExtractorTest(googletest.TestCase): 39 | 40 | def setUp(self): 41 | self._extractor = feature_extractor.YouTube8MFeatureExtractor() 42 | 43 | def testPCAOnFeatureVector(self): 44 | sports_1m_test_data = cPickle.load(open(_FilePath('sports1m_frame.pkl'))) 45 | actual_pca = self._extractor.apply_pca(sports_1m_test_data['original']) 46 | expected_pca = sports_1m_test_data['pca'] 47 | self.assertLess(_MeanElementWiseDifference(actual_pca, expected_pca), 1e-5) 48 | 49 | 50 | if __name__ == '__main__': 51 | googletest.main() 52 | -------------------------------------------------------------------------------- /export_model_mediapipe.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | import numpy as np 3 | import tensorflow as tf 4 | from tensorflow import app 5 | from tensorflow import flags 6 | 7 | FLAGS = flags.FLAGS 8 | 9 | 10 | def main(unused_argv): 11 | # Get the input tensor names to be replaced. 12 | tf.reset_default_graph() 13 | meta_graph_location = FLAGS.checkpoint_file + ".meta" 14 | tf.train.import_meta_graph(meta_graph_location, clear_devices=True) 15 | 16 | input_tensor_name = tf.get_collection("input_batch_raw")[0].name 17 | num_frames_tensor_name = tf.get_collection("num_frames")[0].name 18 | 19 | # Create output graph. 20 | saver = tf.train.Saver() 21 | tf.reset_default_graph() 22 | 23 | input_feature_placeholder = tf.placeholder( 24 | tf.float32, shape=(None, None, 1152)) 25 | num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1)) 26 | 27 | saver = tf.train.import_meta_graph( 28 | meta_graph_location, 29 | input_map={ 30 | input_tensor_name: input_feature_placeholder, 31 | num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1) 32 | }, 33 | clear_devices=True) 34 | predictions_tensor = tf.get_collection("predictions")[0] 35 | 36 | with tf.Session() as sess: 37 | print("restoring variables from " + FLAGS.checkpoint_file) 38 | saver.restore(sess, FLAGS.checkpoint_file) 39 | tf.saved_model.simple_save( 40 | sess, 41 | FLAGS.output_dir, 42 | inputs={'rgb_and_audio': input_feature_placeholder, 43 | 'num_frames': num_frames_placeholder}, 44 | outputs={'predictions': predictions_tensor}) 45 | 46 | # Try running inference. 
47 | predictions = sess.run( 48 | [predictions_tensor], 49 | feed_dict={ 50 | input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32), 51 | num_frames_placeholder: np.array([[7]], dtype=np.int32)}) 52 | print('Test inference:', predictions) 53 | 54 | print('Model saved to ', FLAGS.output_dir) 55 | 56 | 57 | if __name__ == '__main__': 58 | flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.') 59 | flags.DEFINE_string('output_dir', None, 'SavedModel output directory.') 60 | app.run(main) 61 | -------------------------------------------------------------------------------- /convert_prediction_from_json_to_csv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utility to convert the output of batch prediction into a CSV submission. 15 | 16 | It converts the JSON files created by the command 17 | 'gcloud beta ml jobs submit prediction' into a CSV file ready for submission. 18 | """ 19 | 20 | import json 21 | import tensorflow as tf 22 | 23 | from builtins import range 24 | from tensorflow import app 25 | from tensorflow import flags 26 | from tensorflow import gfile 27 | from tensorflow import logging 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | if __name__ == "__main__": 32 | 33 | flags.DEFINE_string( 34 | "json_prediction_files_pattern", None, 35 | "Pattern specifying the list of JSON files that the command " 36 | "'gcloud beta ml jobs submit prediction' outputs. These files are " 37 | "located in the output path of the prediction command and are prefixed " 38 | "with 'prediction.results'.") 39 | flags.DEFINE_string( 40 | "csv_output_file", None, 41 | "The file to save the predictions converted to the CSV format.") 42 | 43 | 44 | def get_csv_header(): 45 | return "VideoId,LabelConfidencePairs\n" 46 | 47 | 48 | def to_csv_row(json_data): 49 | 50 | video_id = json_data["video_id"] 51 | 52 | class_indexes = json_data["class_indexes"] 53 | predictions = json_data["predictions"] 54 | 55 | if isinstance(video_id, list): 56 | video_id = video_id[0] 57 | class_indexes = class_indexes[0] 58 | predictions = predictions[0] 59 | 60 | if len(class_indexes) != len(predictions): 61 | raise ValueError( 62 | "The number of indexes (%s) and predictions (%s) must be equal." 
% 63 | (len(class_indexes), len(predictions))) 64 | 65 | return (video_id.decode("utf-8") + "," + 66 | " ".join("%i %f" % (class_indexes[i], predictions[i]) 67 | for i in range(len(class_indexes))) + "\n") 68 | 69 | 70 | def main(unused_argv): 71 | logging.set_verbosity(tf.logging.INFO) 72 | 73 | if not FLAGS.json_prediction_files_pattern: 74 | raise ValueError( 75 | "The flag --json_prediction_files_pattern must be specified.") 76 | 77 | if not FLAGS.csv_output_file: 78 | raise ValueError("The flag --csv_output_file must be specified.") 79 | 80 | logging.info("Looking for prediction files with pattern: %s", 81 | FLAGS.json_prediction_files_pattern) 82 | 83 | file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern) 84 | logging.info("Found files: %s", file_paths) 85 | 86 | logging.info("Writing submission file to: %s", FLAGS.csv_output_file) 87 | with gfile.Open(FLAGS.csv_output_file, "w+") as output_file: 88 | output_file.write(get_csv_header()) 89 | 90 | for file_path in file_paths: 91 | logging.info("processing file: %s", file_path) 92 | 93 | with gfile.Open(file_path) as input_file: 94 | 95 | for line in input_file: 96 | json_data = json.loads(line) 97 | output_file.write(to_csv_row(json_data)) 98 | 99 | output_file.flush() 100 | logging.info("done") 101 | 102 | 103 | if __name__ == "__main__": 104 | app.run() 105 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Contains a collection of util functions for model construction.""" 15 | import numpy 16 | import tensorflow as tf 17 | from tensorflow import logging 18 | from tensorflow import flags 19 | import tensorflow.contrib.slim as slim 20 | 21 | 22 | def SampleRandomSequence(model_input, num_frames, num_samples): 23 | """Samples a random sequence of frames of size num_samples. 
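  A contiguous window of num_samples frames is taken at a uniformly random start
  offset per example; e.g. with num_frames=120 and num_samples=20 the start index
  is drawn from [0, 100], and indices are clamped to num_frames - 1 for videos
  shorter than num_samples frames.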
24 | 25 | Args: 26 | model_input: A tensor of size batch_size x max_frames x feature_size 27 | num_frames: A tensor of size batch_size x 1 28 | num_samples: A scalar 29 | 30 | Returns: 31 | `model_input`: A tensor of size batch_size x num_samples x feature_size 32 | """ 33 | 34 | batch_size = tf.shape(model_input)[0] 35 | frame_index_offset = tf.tile(tf.expand_dims(tf.range(num_samples), 0), 36 | [batch_size, 1]) 37 | max_start_frame_index = tf.maximum(num_frames - num_samples, 0) 38 | start_frame_index = tf.cast( 39 | tf.multiply(tf.random_uniform([batch_size, 1]), 40 | tf.cast(max_start_frame_index + 1, tf.float32)), tf.int32) 41 | frame_index = tf.minimum(start_frame_index + frame_index_offset, 42 | tf.cast(num_frames - 1, tf.int32)) 43 | batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1), 44 | [1, num_samples]) 45 | index = tf.stack([batch_index, frame_index], 2) 46 | return tf.gather_nd(model_input, index) 47 | 48 | 49 | def SampleRandomFrames(model_input, num_frames, num_samples): 50 | """Samples a random set of frames of size num_samples. 51 | 52 | Args: 53 | model_input: A tensor of size batch_size x max_frames x feature_size 54 | num_frames: A tensor of size batch_size x 1 55 | num_samples: A scalar 56 | 57 | Returns: 58 | `model_input`: A tensor of size batch_size x num_samples x feature_size 59 | """ 60 | batch_size = tf.shape(model_input)[0] 61 | frame_index = tf.cast( 62 | tf.multiply(tf.random_uniform([batch_size, num_samples]), 63 | tf.tile(tf.cast(num_frames, tf.float32), [1, num_samples])), 64 | tf.int32) 65 | batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1), 66 | [1, num_samples]) 67 | index = tf.stack([batch_index, frame_index], 2) 68 | return tf.gather_nd(model_input, index) 69 | 70 | 71 | def FramePooling(frames, method, **unused_params): 72 | """Pools over the frames of a video. 73 | 74 | Args: 75 | frames: A tensor with shape [batch_size, num_frames, feature_size]. 76 | method: "average", "max", "attention", or "none". 77 | 78 | Returns: 79 | A tensor with shape [batch_size, feature_size] for average, max, or 80 | attention pooling. A tensor with shape [batch_size*num_frames, feature_size] 81 | for none pooling. 82 | 83 | Raises: 84 | ValueError: if method is other than "average", "max", "attention", or 85 | "none". 86 | """ 87 | if method == "average": 88 | return tf.reduce_mean(frames, 1) 89 | elif method == "max": 90 | return tf.reduce_max(frames, 1) 91 | elif method == "none": 92 | feature_size = frames.shape_as_list()[2] 93 | return tf.reshape(frames, [-1, feature_size]) 94 | else: 95 | raise ValueError("Unrecognized pooling method: %s" % method) 96 | -------------------------------------------------------------------------------- /feature_extractor/README.md: -------------------------------------------------------------------------------- 1 | # Please use the [MediaPipe YouTube8M feature extractor](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/youtube8m) which extracts both RGB and audio features instead. 2 | 3 | --- 4 | 5 | # YouTube8M Feature Extractor (DEPRECATED) 6 | 7 | This directory contains binary and library code that can extract YouTube8M 8 | features from images and videos. The code requires the Inception TensorFlow 9 | model ([tutorial](https://www.tensorflow.org/tutorials/image_recognition)) and 10 | our PCA matrix, as outlined in Section 3.3 of our 11 | [paper](https://arxiv.org/abs/1609.08675). 
The first time you use our code, it 12 | will **automatically** download the inception model (75 Megabytes, tensorflow 13 | [GraphDef proto](https://www.tensorflow.org/api_docs/python/tf/GraphDef), 14 | [download link](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz)) 15 | and the PCA matrix (25 Megabytes, Numpy arrays, 16 | [download link](http://data.yt8m.org/yt8m_pca.tgz)). 17 | 18 | ## Usage 19 | 20 | There are two ways to use this code: 21 | 22 | 1. Binary `extract_tfrecords_main.py` processes a CSV file of videos (and their 23 | labels) and outputs `tfrecord` file. Files created with this binary match 24 | the schema of YouTube-8M dataset files, and are therefore are compatible 25 | with our training starter code. You can also use the file for inference 26 | using your models that are pre-trained on YouTube-8M. 27 | 1. Library `feature_extractor.py` which can extract features from images. 28 | 29 | ### Using the Binary to create `tfrecords` from videos 30 | 31 | You can use binary `extract_tfrecords_main.py` to create `tfrecord` files. 32 | However, this binary assumes that you have OpenCV properly installed (see end of 33 | subsection). Assume that you have two videos `/path/to/vid1` and 34 | `/path/to/vid2`, respectively, with multi-integer labels of `(52, 3, 10)` and 35 | `(7, 67)`. To create `tfrecord` containing features and labels for those videos, 36 | you must first create a CSV file (e.g. on `/path/to/vid_dataset.csv`) with 37 | contents: 38 | 39 | /path/to/vid1,52;3;10 40 | /path/to/vid2,7;67 41 | 42 | Note that the CSV is comma-separated but the label-field is semi-colon separated 43 | to allow for multiple labels per video. 44 | 45 | Then, you can create the `tfrecord` by calling the binary: 46 | 47 | python extract_tfrecords_main.py --input_videos_csv /path/to/vid_dataset.csv \ 48 | --output_tfrecords_file /path/to/output.tfrecord 49 | 50 | Now, you can use the output file for training and/or inference using our starter 51 | code. 52 | 53 | `extract_tfrecords_main.py` requires OpenCV python bindings to be installed and 54 | linked with ffmpeg. In other words, running this command should print `True`: 55 | 56 | python -c 'import cv2; print cv2.VideoCapture().open("/path/to/some/video.mp4")' 57 | 58 | ### Using the library to extract features from images 59 | 60 | To extract our features from an image file `cropped_panda.jpg`, you can use this 61 | python code: 62 | 63 | ```python 64 | from PIL import Image 65 | import numpy 66 | 67 | # Instantiate extractor. Slow if called first time on your machine, as it 68 | # needs to download 100 MB. 69 | extractor = YouTube8MFeatureExtractor() 70 | 71 | image_file = os.path.join(extractor._model_dir, 'cropped_panda.jpg') 72 | 73 | im = numpy.array(Image.open(image_file)) 74 | features = extractor.extract_rgb_frame_features(im) 75 | ``` 76 | 77 | The constructor `extractor = YouTube8MFeatureExtractor()` will create a 78 | directory `~/yt8m/`, if it does not exist, and will download and untar the two 79 | model files (inception and PCA matrix). If you prefer, you can point our 80 | extractor to another directory as: 81 | 82 | ```python 83 | extractor = YouTube8MFeatureExtractor(model_dir="/path/to/yt8m_files") 84 | ``` 85 | 86 | You can also pre-populate your custom `"/path/to/yt8m_files"` by manually 87 | downloading (e.g. 
using `wget`) the URLs and un-tarring them, for example: 88 | 89 | ```bash 90 | mkdir -p /path/to/yt8m_files 91 | cd /path/to/yt8m_files 92 | 93 | wget http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 94 | wget http://data.yt8m.org/yt8m_pca.tgz 95 | 96 | tar zxvf inception-2015-12-05.tgz 97 | tar zxvf yt8m_pca.tgz 98 | ``` 99 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Provides definitions for non-regularized training or test losses.""" 15 | 16 | import tensorflow as tf 17 | 18 | 19 | class BaseLoss(object): 20 | """Inherit from this class when implementing new losses.""" 21 | 22 | def calculate_loss(self, unused_predictions, unused_labels, **unused_params): 23 | """Calculates the average loss of the examples in a mini-batch. 24 | 25 | Args: 26 | unused_predictions: a 2-d tensor storing the prediction scores, in which 27 | each row represents a sample in the mini-batch and each column 28 | represents a class. 29 | unused_labels: a 2-d tensor storing the labels, which has the same shape 30 | as the unused_predictions. The labels must be in the range of 0 and 1. 31 | unused_params: loss specific parameters. 32 | 33 | Returns: 34 | A scalar loss tensor. 35 | """ 36 | raise NotImplementedError() 37 | 38 | 39 | class CrossEntropyLoss(BaseLoss): 40 | """Calculate the cross entropy loss between the predictions and labels.""" 41 | 42 | def calculate_loss(self, 43 | predictions, 44 | labels, 45 | label_weights=None, 46 | **unused_params): 47 | with tf.name_scope("loss_xent"): 48 | epsilon = 1e-5 49 | float_labels = tf.cast(labels, tf.float32) 50 | cross_entropy_loss = float_labels * tf.math.log(predictions + epsilon) + ( 51 | 1 - float_labels) * tf.math.log(1 - predictions + epsilon) 52 | cross_entropy_loss = tf.negative(cross_entropy_loss) 53 | if label_weights is not None: 54 | cross_entropy_loss *= label_weights 55 | return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1)) 56 | 57 | 58 | class HingeLoss(BaseLoss): 59 | """Calculate the hinge loss between the predictions and labels. 60 | 61 | Note the subgradient is used in the backpropagation, and thus the optimization 62 | may converge slower. The predictions trained by the hinge loss are between -1 63 | and +1. 
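  As a worked example with the default b=1.0: a positive label maps to sign +1,
  so a prediction of 0.3 incurs a per-class loss of max(0, 1 - 0.3) = 0.7, while
  a negative label (sign -1) with the same prediction incurs max(0, 1 + 0.3) =
  1.3; the per-class losses are summed over classes and averaged over the batch.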
64 | """ 65 | 66 | def calculate_loss(self, predictions, labels, b=1.0, **unused_params): 67 | with tf.name_scope("loss_hinge"): 68 | float_labels = tf.cast(labels, tf.float32) 69 | all_zeros = tf.zeros(tf.shape(float_labels), dtype=tf.float32) 70 | all_ones = tf.ones(tf.shape(float_labels), dtype=tf.float32) 71 | sign_labels = tf.subtract(tf.scalar_mul(2, float_labels), all_ones) 72 | hinge_loss = tf.maximum( 73 | all_zeros, 74 | tf.scalar_mul(b, all_ones) - sign_labels * predictions) 75 | return tf.reduce_mean(tf.reduce_sum(hinge_loss, 1)) 76 | 77 | 78 | class SoftmaxLoss(BaseLoss): 79 | """Calculate the softmax loss between the predictions and labels. 80 | 81 | The function calculates the loss in the following way: first we feed the 82 | predictions to the softmax activation function and then we calculate 83 | the minus linear dot product between the logged softmax activations and the 84 | normalized ground truth label. 85 | 86 | It is an extension to the one-hot label. It allows for more than one positive 87 | labels for each sample. 88 | """ 89 | 90 | def calculate_loss(self, predictions, labels, **unused_params): 91 | with tf.name_scope("loss_softmax"): 92 | epsilon = 10e-8 93 | float_labels = tf.cast(labels, tf.float32) 94 | # l1 normalization (labels are no less than 0) 95 | label_rowsum = tf.maximum(tf.reduce_sum(float_labels, 1, keep_dims=True), 96 | epsilon) 97 | norm_float_labels = tf.div(float_labels, label_rowsum) 98 | softmax_outputs = tf.nn.softmax(predictions) 99 | softmax_loss = tf.negative( 100 | tf.reduce_sum(tf.multiply(norm_float_labels, tf.log(softmax_outputs)), 101 | 1)) 102 | return tf.reduce_mean(softmax_loss) 103 | -------------------------------------------------------------------------------- /video_level_models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Contains model definitions.""" 15 | import math 16 | 17 | import models 18 | import tensorflow as tf 19 | import utils 20 | 21 | from tensorflow import flags 22 | import tensorflow.contrib.slim as slim 23 | 24 | FLAGS = flags.FLAGS 25 | flags.DEFINE_integer( 26 | "moe_num_mixtures", 2, 27 | "The number of mixtures (excluding the dummy 'expert') used for MoeModel.") 28 | 29 | 30 | class LogisticModel(models.BaseModel): 31 | """Logistic model with L2 regularization.""" 32 | 33 | def create_model(self, 34 | model_input, 35 | vocab_size, 36 | l2_penalty=1e-8, 37 | **unused_params): 38 | """Creates a logistic model. 39 | 40 | Args: 41 | model_input: 'batch' x 'num_features' matrix of input features. 42 | vocab_size: The number of classes in the dataset. 43 | 44 | Returns: 45 | A dictionary with a tensor containing the probability predictions of the 46 | model in the 'predictions' key. The dimensions of the tensor are 47 | batch_size x num_classes. 
48 | """ 49 | output = slim.fully_connected( 50 | model_input, 51 | vocab_size, 52 | activation_fn=tf.nn.sigmoid, 53 | weights_regularizer=slim.l2_regularizer(l2_penalty)) 54 | return {"predictions": output} 55 | 56 | 57 | class MoeModel(models.BaseModel): 58 | """A softmax over a mixture of logistic models (with L2 regularization).""" 59 | 60 | def create_model(self, 61 | model_input, 62 | vocab_size, 63 | num_mixtures=None, 64 | l2_penalty=1e-8, 65 | **unused_params): 66 | """Creates a Mixture of (Logistic) Experts model. 67 | 68 | The model consists of a per-class softmax distribution over a 69 | configurable number of logistic classifiers. One of the classifiers in the 70 | mixture is not trained, and always predicts 0. 71 | 72 | Args: 73 | model_input: 'batch_size' x 'num_features' matrix of input features. 74 | vocab_size: The number of classes in the dataset. 75 | num_mixtures: The number of mixtures (excluding a dummy 'expert' that 76 | always predicts the non-existence of an entity). 77 | l2_penalty: How much to penalize the squared magnitudes of parameter 78 | values. 79 | 80 | Returns: 81 | A dictionary with a tensor containing the probability predictions of the 82 | model in the 'predictions' key. The dimensions of the tensor are 83 | batch_size x num_classes. 84 | """ 85 | num_mixtures = num_mixtures or FLAGS.moe_num_mixtures 86 | 87 | gate_activations = slim.fully_connected( 88 | model_input, 89 | vocab_size * (num_mixtures + 1), 90 | activation_fn=None, 91 | biases_initializer=None, 92 | weights_regularizer=slim.l2_regularizer(l2_penalty), 93 | scope="gates") 94 | expert_activations = slim.fully_connected( 95 | model_input, 96 | vocab_size * num_mixtures, 97 | activation_fn=None, 98 | weights_regularizer=slim.l2_regularizer(l2_penalty), 99 | scope="experts") 100 | 101 | gating_distribution = tf.nn.softmax( 102 | tf.reshape( 103 | gate_activations, 104 | [-1, num_mixtures + 1])) # (Batch * #Labels) x (num_mixtures + 1) 105 | expert_distribution = tf.nn.sigmoid( 106 | tf.reshape(expert_activations, 107 | [-1, num_mixtures])) # (Batch * #Labels) x num_mixtures 108 | 109 | final_probabilities_by_class_and_batch = tf.reduce_sum( 110 | gating_distribution[:, :num_mixtures] * expert_distribution, 1) 111 | final_probabilities = tf.reshape(final_probabilities_by_class_and_batch, 112 | [-1, vocab_size]) 113 | return {"predictions": final_probabilities} 114 | -------------------------------------------------------------------------------- /mean_average_precision_calculator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Calculate the mean average precision. 15 | 16 | It provides an interface for calculating mean average precision 17 | for an entire list or the top-n ranked items. 18 | 19 | Example usages: 20 | We first call the function accumulate many times to process parts of the ranked 21 | list. 
After processing all the parts, we call peek_map_at_n 22 | to calculate the mean average precision. 23 | 24 | ``` 25 | import random 26 | 27 | p = np.array([[random.random() for _ in xrange(50)] for _ in xrange(1000)]) 28 | a = np.array([[random.choice([0, 1]) for _ in xrange(50)] 29 | for _ in xrange(1000)]) 30 | 31 | # mean average precision for 50 classes. 32 | calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator( 33 | num_class=50) 34 | calculator.accumulate(p, a) 35 | aps = calculator.peek_map_at_n() 36 | ``` 37 | """ 38 | 39 | import average_precision_calculator 40 | 41 | 42 | class MeanAveragePrecisionCalculator(object): 43 | """This class is to calculate mean average precision.""" 44 | 45 | def __init__(self, num_class, filter_empty_classes=True, top_n=None): 46 | """Construct a calculator to calculate the (macro) average precision. 47 | 48 | Args: 49 | num_class: A positive Integer specifying the number of classes. 50 | filter_empty_classes: whether to filter classes without any positives. 51 | top_n: A positive Integer specifying the average precision at n, or None 52 | to use all provided data points. 53 | 54 | Raises: 55 | ValueError: An error occurred when num_class is not a positive integer; 56 | or the top_n_array is not a list of positive integers. 57 | """ 58 | if not isinstance(num_class, int) or num_class <= 1: 59 | raise ValueError("num_class must be a positive integer.") 60 | 61 | self._ap_calculators = [] # member of AveragePrecisionCalculator 62 | self._num_class = num_class # total number of classes 63 | self._filter_empty_classes = filter_empty_classes 64 | for _ in range(num_class): 65 | self._ap_calculators.append( 66 | average_precision_calculator.AveragePrecisionCalculator(top_n=top_n)) 67 | 68 | def accumulate(self, predictions, actuals, num_positives=None): 69 | """Accumulate the predictions and their ground truth labels. 70 | 71 | Args: 72 | predictions: A list of lists storing the prediction scores. The outer 73 | dimension corresponds to classes. 74 | actuals: A list of lists storing the ground truth labels. The dimensions 75 | should correspond to the predictions input. Any value larger than 0 will 76 | be treated as positives, otherwise as negatives. 77 | num_positives: If provided, it is a list of numbers representing the 78 | number of true positives for each class. If not provided, the number of 79 | true positives will be inferred from the 'actuals' array. 80 | 81 | Raises: 82 | ValueError: An error occurred when the shape of predictions and actuals 83 | does not match. 84 | """ 85 | if not num_positives: 86 | num_positives = [None for i in range(self._num_class)] 87 | 88 | calculators = self._ap_calculators 89 | for i in range(self._num_class): 90 | calculators[i].accumulate(predictions[i], actuals[i], num_positives[i]) 91 | 92 | def clear(self): 93 | for calculator in self._ap_calculators: 94 | calculator.clear() 95 | 96 | def is_empty(self): 97 | return ([calculator.heap_size for calculator in self._ap_calculators 98 | ] == [0 for _ in range(self._num_class)]) 99 | 100 | def peek_map_at_n(self): 101 | """Peek the non-interpolated mean average precision at n. 102 | 103 | Returns: 104 | An array of non-interpolated average precision at n (default 0) for each 105 | class. 
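  Classes that accumulated no positive labels are omitted from the returned list
  when filter_empty_classes is True (the default).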
106 | """ 107 | aps = [] 108 | for i in range(self._num_class): 109 | if (not self._filter_empty_classes or 110 | self._ap_calculators[i].num_accumulated_positives > 0): 111 | ap = self._ap_calculators[i].peek_ap_at_n() 112 | aps.append(ap) 113 | return aps 114 | -------------------------------------------------------------------------------- /export_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Utilities to export a model for batch prediction.""" 15 | 16 | import tensorflow as tf 17 | import tensorflow.contrib.slim as slim 18 | 19 | from tensorflow.python.saved_model import builder as saved_model_builder 20 | from tensorflow.python.saved_model import signature_constants 21 | from tensorflow.python.saved_model import signature_def_utils 22 | from tensorflow.python.saved_model import tag_constants 23 | from tensorflow.python.saved_model import utils as saved_model_utils 24 | 25 | _TOP_PREDICTIONS_IN_OUTPUT = 20 26 | 27 | 28 | class ModelExporter(object): 29 | 30 | def __init__(self, frame_features, model, reader): 31 | self.frame_features = frame_features 32 | self.model = model 33 | self.reader = reader 34 | 35 | with tf.Graph().as_default() as graph: 36 | self.inputs, self.outputs = self.build_inputs_and_outputs() 37 | self.graph = graph 38 | self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True) 39 | 40 | def export_model(self, model_dir, global_step_val, last_checkpoint): 41 | """Exports the model so that it can used for batch predictions.""" 42 | 43 | with self.graph.as_default(): 44 | with tf.Session() as session: 45 | session.run(tf.global_variables_initializer()) 46 | self.saver.restore(session, last_checkpoint) 47 | 48 | signature = signature_def_utils.build_signature_def( 49 | inputs=self.inputs, 50 | outputs=self.outputs, 51 | method_name=signature_constants.PREDICT_METHOD_NAME) 52 | 53 | signature_map = { 54 | signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature 55 | } 56 | 57 | model_builder = saved_model_builder.SavedModelBuilder(model_dir) 58 | model_builder.add_meta_graph_and_variables( 59 | session, 60 | tags=[tag_constants.SERVING], 61 | signature_def_map=signature_map, 62 | clear_devices=True) 63 | model_builder.save() 64 | 65 | def build_inputs_and_outputs(self): 66 | if self.frame_features: 67 | serialized_examples = tf.placeholder(tf.string, shape=(None,)) 68 | 69 | fn = lambda x: self.build_prediction_graph(x) 70 | video_id_output, top_indices_output, top_predictions_output = (tf.map_fn( 71 | fn, serialized_examples, dtype=(tf.string, tf.int32, tf.float32))) 72 | 73 | else: 74 | serialized_examples = tf.placeholder(tf.string, shape=(None,)) 75 | 76 | video_id_output, top_indices_output, top_predictions_output = ( 77 | self.build_prediction_graph(serialized_examples)) 78 | 79 | inputs = { 80 | "example_bytes": 81 | 
saved_model_utils.build_tensor_info(serialized_examples) 82 | } 83 | 84 | outputs = { 85 | "video_id": 86 | saved_model_utils.build_tensor_info(video_id_output), 87 | "class_indexes": 88 | saved_model_utils.build_tensor_info(top_indices_output), 89 | "predictions": 90 | saved_model_utils.build_tensor_info(top_predictions_output) 91 | } 92 | 93 | return inputs, outputs 94 | 95 | def build_prediction_graph(self, serialized_examples): 96 | input_data_dict = ( 97 | self.reader.prepare_serialized_examples(serialized_examples)) 98 | video_id = input_data_dict["video_ids"] 99 | model_input_raw = input_data_dict["video_matrix"] 100 | labels_batch = input_data_dict["labels"] 101 | num_frames = input_data_dict["num_frames"] 102 | 103 | feature_dim = len(model_input_raw.get_shape()) - 1 104 | model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) 105 | 106 | with tf.variable_scope("tower"): 107 | result = self.model.create_model(model_input, 108 | num_frames=num_frames, 109 | vocab_size=self.reader.num_classes, 110 | labels=labels_batch, 111 | is_training=False) 112 | 113 | for variable in slim.get_model_variables(): 114 | tf.summary.histogram(variable.op.name, variable) 115 | 116 | predictions = result["predictions"] 117 | 118 | top_predictions, top_indices = tf.nn.top_k(predictions, 119 | _TOP_PREDICTIONS_IN_OUTPUT) 120 | return video_id, top_indices, top_predictions 121 | -------------------------------------------------------------------------------- /feature_extractor/feature_extractor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Facilitates extracting YouTube8M features from RGB images.""" 15 | 16 | import os 17 | import sys 18 | import tarfile 19 | import numpy 20 | from six.moves import urllib 21 | import tensorflow as tf 22 | 23 | INCEPTION_TF_GRAPH = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 24 | YT8M_PCA_MAT = 'http://data.yt8m.org/yt8m_pca.tgz' 25 | MODEL_DIR = os.path.join(os.getenv('HOME'), 'yt8m') 26 | 27 | 28 | class YouTube8MFeatureExtractor(object): 29 | """Extracts YouTube8M features for RGB frames. 30 | 31 | First time constructing this class will create directory `yt8m` inside your 32 | home directory, and will download inception model (85 MB) and YouTube8M PCA 33 | matrix (15 MB). If you want to use another directory, then pass it to argument 34 | `model_dir` of constructor. 35 | 36 | If the model_dir exist and contains the necessary files, then files will be 37 | re-used without download. 38 | 39 | Usage Example: 40 | 41 | from PIL import Image 42 | import numpy 43 | 44 | # Instantiate extractor. Slow if called first time on your machine, as it 45 | # needs to download 100 MB. 
46 | extractor = YouTube8MFeatureExtractor() 47 | 48 | image_file = os.path.join(extractor._model_dir, 'cropped_panda.jpg') 49 | 50 | im = numpy.array(Image.open(image_file)) 51 | features = extractor.extract_rgb_frame_features(im) 52 | 53 | ** Note: OpenCV reverses the order of channels (i.e. orders channels as BGR 54 | instead of RGB). If you are using OpenCV, then you must do: 55 | 56 | im = im[:, :, ::-1] # Reverses order on last (i.e. channel) dimension. 57 | 58 | then call `extractor.extract_rgb_frame_features(im)` 59 | """ 60 | 61 | def __init__(self, model_dir=MODEL_DIR): 62 | # Create MODEL_DIR if not created. 63 | self._model_dir = model_dir 64 | if not os.path.exists(model_dir): 65 | os.makedirs(model_dir) 66 | 67 | # Load PCA Matrix. 68 | download_path = self._maybe_download(YT8M_PCA_MAT) 69 | pca_mean = os.path.join(self._model_dir, 'mean.npy') 70 | if not os.path.exists(pca_mean): 71 | tarfile.open(download_path, 'r:gz').extractall(model_dir) 72 | self._load_pca() 73 | 74 | # Load Inception Network 75 | download_path = self._maybe_download(INCEPTION_TF_GRAPH) 76 | inception_proto_file = os.path.join(self._model_dir, 77 | 'classify_image_graph_def.pb') 78 | if not os.path.exists(inception_proto_file): 79 | tarfile.open(download_path, 'r:gz').extractall(model_dir) 80 | self._load_inception(inception_proto_file) 81 | 82 | def extract_rgb_frame_features(self, frame_rgb, apply_pca=True): 83 | """Applies the YouTube8M feature extraction over an RGB frame. 84 | 85 | This passes `frame_rgb` to inception3 model, extracting hidden layer 86 | activations and passing it to the YouTube8M PCA transformation. 87 | 88 | Args: 89 | frame_rgb: numpy array of uint8 with shape (height, width, channels) where 90 | channels must be 3 (RGB), and height and weight can be anything, as the 91 | inception model will resize. 92 | apply_pca: If not set, PCA transformation will be skipped. 93 | 94 | Returns: 95 | Output of inception from `frame_rgb` (2048-D) and optionally passed into 96 | YouTube8M PCA transformation (1024-D). 97 | """ 98 | assert len(frame_rgb.shape) == 3 99 | assert frame_rgb.shape[2] == 3 # 3 channels (R, G, B) 100 | with self._inception_graph.as_default(): 101 | if apply_pca: 102 | frame_features = self.session.run( 103 | 'pca_final_feature:0', feed_dict={'DecodeJpeg:0': frame_rgb}) 104 | else: 105 | frame_features = self.session.run( 106 | 'pool_3/_reshape:0', feed_dict={'DecodeJpeg:0': frame_rgb}) 107 | frame_features = frame_features[0] 108 | return frame_features 109 | 110 | def apply_pca(self, frame_features): 111 | """Applies the YouTube8M PCA Transformation over `frame_features`. 112 | 113 | Args: 114 | frame_features: numpy array of floats, 2048 dimensional vector. 115 | 116 | Returns: 117 | 1024 dimensional vector as a numpy array. 118 | """ 119 | # Subtract mean 120 | feats = frame_features - self.pca_mean 121 | 122 | # Multiply by eigenvectors. 
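    # The mean-subtracted 2048-D feature is treated as a 1 x 2048 row vector and
    # projected onto the top 1024 eigenvectors (the 2048 x 1024 matrix loaded in
    # _load_pca), yielding the 1024-D representation that is whitened below.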
123 | feats = feats.reshape((1, 2048)).dot(self.pca_eigenvecs).reshape((1024,)) 124 | 125 | # Whiten 126 | feats /= numpy.sqrt(self.pca_eigenvals + 1e-4) 127 | return feats 128 | 129 | def _maybe_download(self, url): 130 | """Downloads `url` if not in `_model_dir`.""" 131 | filename = os.path.basename(url) 132 | download_path = os.path.join(self._model_dir, filename) 133 | if os.path.exists(download_path): 134 | return download_path 135 | 136 | def _progress(count, block_size, total_size): 137 | sys.stdout.write( 138 | '\r>> Downloading %s %.1f%%' % 139 | (filename, float(count * block_size) / float(total_size) * 100.0)) 140 | sys.stdout.flush() 141 | 142 | urllib.request.urlretrieve(url, download_path, _progress) 143 | statinfo = os.stat(download_path) 144 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 145 | return download_path 146 | 147 | def _load_inception(self, proto_file): 148 | graph_def = tf.GraphDef.FromString(open(proto_file, 'rb').read()) 149 | self._inception_graph = tf.Graph() 150 | with self._inception_graph.as_default(): 151 | _ = tf.import_graph_def(graph_def, name='') 152 | self.session = tf.Session() 153 | Frame_Features = self.session.graph.get_tensor_by_name( 154 | 'pool_3/_reshape:0') 155 | Pca_Mean = tf.constant(value=self.pca_mean, dtype=tf.float32) 156 | Pca_Eigenvecs = tf.constant(value=self.pca_eigenvecs, dtype=tf.float32) 157 | Pca_Eigenvals = tf.constant(value=self.pca_eigenvals, dtype=tf.float32) 158 | Feats = Frame_Features[0] - Pca_Mean 159 | Feats = tf.reshape( 160 | tf.matmul(tf.reshape(Feats, [1, 2048]), Pca_Eigenvecs), [ 161 | 1024, 162 | ]) 163 | tf.divide(Feats, tf.sqrt(Pca_Eigenvals + 1e-4), name='pca_final_feature') 164 | 165 | def _load_pca(self): 166 | self.pca_mean = numpy.load(os.path.join(self._model_dir, 'mean.npy'))[:, 0] 167 | self.pca_eigenvals = numpy.load( 168 | os.path.join(self._model_dir, 'eigenvals.npy'))[:1024, 0] 169 | self.pca_eigenvecs = numpy.load( 170 | os.path.join(self._model_dir, 'eigenvecs.npy')).T[:, :1024] 171 | -------------------------------------------------------------------------------- /segment_eval_inference.py: -------------------------------------------------------------------------------- 1 | """Eval mAP@N metric from inference file.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | from absl import app 8 | from absl import flags 9 | 10 | import mean_average_precision_calculator as map_calculator 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | flags.DEFINE_string( 15 | "eval_data_pattern", "", 16 | "File glob defining the evaluation dataset in tensorflow.SequenceExample " 17 | "format. The SequenceExamples are expected to have an 'rgb' byte array " 18 | "sequence feature as well as a 'labels' int64 context feature.") 19 | flags.DEFINE_string( 20 | "label_cache", "", 21 | "The path for the label cache file. Leave blank for not to cache.") 22 | flags.DEFINE_string("submission_file", "", 23 | "The segment submission file generated by inference.py.") 24 | flags.DEFINE_integer( 25 | "top_n", 0, 26 | "The cap per-class predictions by a maximum of N. Use 0 for not capping.") 27 | 28 | FLAGS = flags.FLAGS 29 | 30 | 31 | class Labels(object): 32 | """Contains the class to hold label objects. 33 | 34 | This class can serialize and de-serialize the groundtruths. 35 | The ground truth is in a mapping from (segment_id, class_id) -> label_score. 
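  For example, an entry {("VIDEO_ID:120", 23): 1.0} (illustrative ids) means the
  segment of that video starting at second 120 is a positive for class 23; the
  "<video_id>:<start_time>" key format matches what read_labels builds below.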
36 | """ 37 | 38 | def __init__(self, labels): 39 | """__init__ method.""" 40 | self._labels = labels 41 | 42 | @property 43 | def labels(self): 44 | """Return the ground truth mapping. See class docstring for details.""" 45 | return self._labels 46 | 47 | def to_file(self, file_name): 48 | """Materialize the GT mapping to file.""" 49 | with tf.gfile.Open(file_name, "w") as fobj: 50 | for k, v in self._labels.items(): 51 | seg_id, label = k 52 | line = "%s,%s,%s\n" % (seg_id, label, v) 53 | fobj.write(line) 54 | 55 | @classmethod 56 | def from_file(cls, file_name): 57 | """Read the GT mapping from cached file.""" 58 | labels = {} 59 | with tf.gfile.Open(file_name) as fobj: 60 | for line in fobj: 61 | line = line.strip().strip("\n") 62 | seg_id, label, score = line.split(",") 63 | labels[(seg_id, int(label))] = float(score) 64 | return cls(labels) 65 | 66 | 67 | def read_labels(data_pattern, cache_path=""): 68 | """Read labels from TFRecords. 69 | 70 | Args: 71 | data_pattern: the data pattern to the TFRecords. 72 | cache_path: the cache path for the label file. 73 | 74 | Returns: 75 | a Labels object. 76 | """ 77 | if cache_path: 78 | if tf.gfile.Exists(cache_path): 79 | tf.logging.info("Reading cached labels from %s..." % cache_path) 80 | return Labels.from_file(cache_path) 81 | tf.enable_eager_execution() 82 | data_paths = tf.gfile.Glob(data_pattern) 83 | ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50) 84 | context_features = { 85 | "id": tf.FixedLenFeature([], tf.string), 86 | "segment_labels": tf.VarLenFeature(tf.int64), 87 | "segment_start_times": tf.VarLenFeature(tf.int64), 88 | "segment_scores": tf.VarLenFeature(tf.float32) 89 | } 90 | 91 | def _parse_se_func(sequence_example): 92 | return tf.parse_single_sequence_example(sequence_example, 93 | context_features=context_features) 94 | 95 | ds = ds.map(_parse_se_func) 96 | rated_labels = {} 97 | tf.logging.info("Reading labels from TFRecords...") 98 | last_batch = 0 99 | batch_size = 5000 100 | for cxt_feature_val, _ in ds: 101 | video_id = cxt_feature_val["id"].numpy() 102 | segment_labels = cxt_feature_val["segment_labels"].values.numpy() 103 | segment_start_times = cxt_feature_val["segment_start_times"].values.numpy() 104 | segment_scores = cxt_feature_val["segment_scores"].values.numpy() 105 | for label, start_time, score in zip(segment_labels, segment_start_times, 106 | segment_scores): 107 | rated_labels[("%s:%d" % (video_id, start_time), label)] = score 108 | batch_id = len(rated_labels) // batch_size 109 | if batch_id != last_batch: 110 | tf.logging.info("%d examples processed.", len(rated_labels)) 111 | last_batch = batch_id 112 | tf.logging.info("Finish reading labels from TFRecords...") 113 | labels_obj = Labels(rated_labels) 114 | if cache_path: 115 | tf.logging.info("Caching labels to %s..." % cache_path) 116 | labels_obj.to_file(cache_path) 117 | return labels_obj 118 | 119 | 120 | def read_segment_predictions(file_path, labels, top_n=None): 121 | """Read segement predictions. 122 | 123 | Args: 124 | file_path: the submission file path. 125 | labels: a Labels object containing the eval labels. 126 | top_n: the per-class class capping. 127 | 128 | Returns: 129 | a segment prediction list for each classes. 130 | """ 131 | cls_preds = {} # A label_id to pred list mapping. 132 | with tf.gfile.Open(file_path) as fobj: 133 | tf.logging.info("Reading predictions from %s..." 
% file_path) 134 | for line in fobj: 135 | label_id, pred_ids_val = line.split(",") 136 | pred_ids = pred_ids_val.split(" ") 137 | if top_n: 138 | pred_ids = pred_ids[:top_n] 139 | pred_ids = [ 140 | pred_id for pred_id in pred_ids 141 | if (pred_id, int(label_id)) in labels.labels 142 | ] 143 | cls_preds[int(label_id)] = pred_ids 144 | if len(cls_preds) % 50 == 0: 145 | tf.logging.info("Processed %d classes..." % len(cls_preds)) 146 | tf.logging.info("Finish reading predictions.") 147 | return cls_preds 148 | 149 | 150 | def main(unused_argv): 151 | """Entry function of the script.""" 152 | if not FLAGS.submission_file: 153 | raise ValueError("You must input submission file.") 154 | eval_labels = read_labels(FLAGS.eval_data_pattern, 155 | cache_path=FLAGS.label_cache) 156 | tf.logging.info("Total rated segments: %d." % len(eval_labels.labels)) 157 | positive_counter = {} 158 | for k, v in eval_labels.labels.items(): 159 | _, label_id = k 160 | if v > 0: 161 | positive_counter[label_id] = positive_counter.get(label_id, 0) + 1 162 | 163 | seg_preds = read_segment_predictions(FLAGS.submission_file, 164 | eval_labels, 165 | top_n=FLAGS.top_n) 166 | map_cal = map_calculator.MeanAveragePrecisionCalculator(len(seg_preds)) 167 | seg_labels = [] 168 | seg_scored_preds = [] 169 | num_positives = [] 170 | for label_id in sorted(seg_preds): 171 | class_preds = seg_preds[label_id] 172 | seg_label = [eval_labels.labels[(pred, label_id)] for pred in class_preds] 173 | seg_labels.append(seg_label) 174 | seg_scored_pred = [] 175 | if class_preds: 176 | seg_scored_pred = [ 177 | float(x) / len(class_preds) for x in range(len(class_preds), 0, -1) 178 | ] 179 | seg_scored_preds.append(seg_scored_pred) 180 | num_positives.append(positive_counter[label_id]) 181 | map_cal.accumulate(seg_scored_preds, seg_labels, num_positives) 182 | map_at_n = np.mean(map_cal.peek_map_at_n()) 183 | tf.logging.info("Num classes: %d | mAP@%d: %.6f" % 184 | (len(seg_preds), FLAGS.top_n, map_at_n)) 185 | 186 | 187 | if __name__ == "__main__": 188 | app.run(main) 189 | -------------------------------------------------------------------------------- /feature_extractor/extract_tfrecords_main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Produces tfrecord files similar to the YouTube-8M dataset. 15 | 16 | It processes a CSV file containing lines like ",", where 17 | must be a path of a video, and must be an integer list 18 | joined with semi-colon ";". It processes all videos and outputs tfrecord file 19 | onto --output_tfrecords_file. 20 | 21 | It assumes that you have OpenCV installed and properly linked with ffmpeg (i.e. 22 | function `cv2.VideoCapture().open('/path/to/some/video')` should return True). 23 | 24 | The binary only processes the video stream (images) and not the audio stream. 
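For example (mirroring the usage shown in feature_extractor/README.md), a CSV
listing two videos with multi-integer labels could look like:

    /path/to/vid1,52;3;10
    /path/to/vid2,7;67

and would be processed with:

    python extract_tfrecords_main.py --input_videos_csv /path/to/vid_dataset.csv \
        --output_tfrecords_file /path/to/output.tfrecord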
25 | """ 26 | 27 | import csv 28 | import os 29 | import sys 30 | 31 | import cv2 32 | import feature_extractor 33 | import numpy 34 | import tensorflow as tf 35 | from tensorflow import app 36 | from tensorflow import flags 37 | 38 | FLAGS = flags.FLAGS 39 | 40 | # In OpenCV3.X, this is available as cv2.CAP_PROP_POS_MSEC 41 | # In OpenCV2.X, this is available as cv2.cv.CV_CAP_PROP_POS_MSEC 42 | CAP_PROP_POS_MSEC = 0 43 | 44 | if __name__ == '__main__': 45 | # Required flags for input and output. 46 | flags.DEFINE_string( 47 | 'output_tfrecords_file', None, 48 | 'File containing tfrecords will be written at this path.') 49 | flags.DEFINE_string( 50 | 'input_videos_csv', None, 51 | 'CSV file with lines ",", where ' 52 | ' must be a path of a video and ' 53 | 'must be an integer list joined with semi-colon ";"') 54 | # Optional flags. 55 | flags.DEFINE_string('model_dir', os.path.join(os.getenv('HOME'), 'yt8m'), 56 | 'Directory to store model files. It defaults to ~/yt8m') 57 | 58 | # The following flags are set to match the YouTube-8M dataset format. 59 | flags.DEFINE_integer('frames_per_second', 1, 60 | 'This many frames per second will be processed') 61 | flags.DEFINE_boolean( 62 | 'skip_frame_level_features', False, 63 | 'If set, frame-level features will not be written: only ' 64 | 'video-level features will be written with feature ' 65 | 'names mean_*') 66 | flags.DEFINE_string( 67 | 'labels_feature_key', 'labels', 68 | 'Labels will be written to context feature with this ' 69 | 'key, as int64 list feature.') 70 | flags.DEFINE_string( 71 | 'image_feature_key', 'rgb', 72 | 'Image features will be written to sequence feature with ' 73 | 'this key, as bytes list feature, with only one entry, ' 74 | 'containing quantized feature string.') 75 | flags.DEFINE_string( 76 | 'video_file_feature_key', 'id', 77 | 'Input will be written to context feature ' 78 | 'with this key, as bytes list feature, with only one ' 79 | 'entry, containing the file path of the video. This ' 80 | 'can be used for debugging but not for training or eval.') 81 | flags.DEFINE_boolean( 82 | 'insert_zero_audio_features', True, 83 | 'If set, inserts features with name "audio" to be 128-D ' 84 | 'zero vectors. This allows you to use YouTube-8M ' 85 | 'pre-trained model.') 86 | 87 | 88 | def frame_iterator(filename, every_ms=1000, max_num_frames=300): 89 | """Uses OpenCV to iterate over all frames of filename at a given frequency. 90 | 91 | Args: 92 | filename: Path to video file (e.g. mp4) 93 | every_ms: The duration (in milliseconds) to skip between frames. 94 | max_num_frames: Maximum number of frames to process, taken from the 95 | beginning of the video. 96 | 97 | Yields: 98 | RGB frame with shape (image height, image width, channels) 99 | """ 100 | video_capture = cv2.VideoCapture() 101 | if not video_capture.open(filename): 102 | print >> sys.stderr, 'Error: Cannot open video file ' + filename 103 | return 104 | last_ts = -99999 # The timestamp of last retrieved frame. 
105 | num_retrieved = 0 106 | 107 | while num_retrieved < max_num_frames: 108 | # Skip frames 109 | while video_capture.get(CAP_PROP_POS_MSEC) < every_ms + last_ts: 110 | if not video_capture.read()[0]: 111 | return 112 | 113 | last_ts = video_capture.get(CAP_PROP_POS_MSEC) 114 | has_frames, frame = video_capture.read() 115 | if not has_frames: 116 | break 117 | yield frame 118 | num_retrieved += 1 119 | 120 | 121 | def _int64_list_feature(int64_list): 122 | return tf.train.Feature(int64_list=tf.train.Int64List(value=int64_list)) 123 | 124 | 125 | def _bytes_feature(value): 126 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 127 | 128 | 129 | def _make_bytes(int_array): 130 | if bytes == str: # Python2 131 | return ''.join(map(chr, int_array)) 132 | else: 133 | return bytes(int_array) 134 | 135 | 136 | def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0): 137 | """Quantizes float32 `features` into string.""" 138 | assert features.dtype == 'float32' 139 | assert len(features.shape) == 1 # 1-D array 140 | features = numpy.clip(features, min_quantized_value, max_quantized_value) 141 | quantize_range = max_quantized_value - min_quantized_value 142 | features = (features - min_quantized_value) * (255.0 / quantize_range) 143 | features = [int(round(f)) for f in features] 144 | 145 | return _make_bytes(features) 146 | 147 | 148 | def main(unused_argv): 149 | extractor = feature_extractor.YouTube8MFeatureExtractor(FLAGS.model_dir) 150 | writer = tf.python_io.TFRecordWriter(FLAGS.output_tfrecords_file) 151 | total_written = 0 152 | total_error = 0 153 | for video_file, labels in csv.reader(open(FLAGS.input_videos_csv)): 154 | rgb_features = [] 155 | sum_rgb_features = None 156 | for rgb in frame_iterator( 157 | video_file, every_ms=1000.0 / FLAGS.frames_per_second): 158 | features = extractor.extract_rgb_frame_features(rgb[:, :, ::-1]) 159 | if sum_rgb_features is None: 160 | sum_rgb_features = features 161 | else: 162 | sum_rgb_features += features 163 | rgb_features.append(_bytes_feature(quantize(features))) 164 | 165 | if not rgb_features: 166 | print >> sys.stderr, 'Could not get features for ' + video_file 167 | total_error += 1 168 | continue 169 | 170 | mean_rgb_features = sum_rgb_features / len(rgb_features) 171 | 172 | # Create SequenceExample proto and write to output. 
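    # Context features carry per-video data: the labels, the input path (used
    # as the "id" field), and the mean-pooled float features. The feature
    # lists carry one quantized byte string per sampled frame.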
173 | feature_list = { 174 | FLAGS.image_feature_key: tf.train.FeatureList(feature=rgb_features), 175 | } 176 | context_features = { 177 | FLAGS.labels_feature_key: 178 | _int64_list_feature(sorted(map(int, labels.split(';')))), 179 | FLAGS.video_file_feature_key: 180 | _bytes_feature(_make_bytes(map(ord, video_file))), 181 | 'mean_' + FLAGS.image_feature_key: 182 | tf.train.Feature( 183 | float_list=tf.train.FloatList(value=mean_rgb_features)), 184 | } 185 | 186 | if FLAGS.insert_zero_audio_features: 187 | zero_vec = [0] * 128 188 | feature_list['audio'] = tf.train.FeatureList( 189 | feature=[_bytes_feature(_make_bytes(zero_vec))] * len(rgb_features)) 190 | context_features['mean_audio'] = tf.train.Feature( 191 | float_list=tf.train.FloatList(value=zero_vec)) 192 | 193 | if FLAGS.skip_frame_level_features: 194 | example = tf.train.SequenceExample( 195 | context=tf.train.Features(feature=context_features)) 196 | else: 197 | example = tf.train.SequenceExample( 198 | context=tf.train.Features(feature=context_features), 199 | feature_lists=tf.train.FeatureLists(feature_list=feature_list)) 200 | writer.write(example.SerializeToString()) 201 | total_written += 1 202 | 203 | writer.close() 204 | print('Successfully encoded %i out of %i videos' % 205 | (total_written, total_written + total_error)) 206 | 207 | 208 | if __name__ == '__main__': 209 | app.run(main) 210 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Contains a collection of util functions for training and evaluating.""" 15 | 16 | import numpy 17 | import tensorflow as tf 18 | from tensorflow import logging 19 | 20 | try: 21 | xrange # Python 2 22 | except NameError: 23 | xrange = range # Python 3 24 | 25 | 26 | def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2): 27 | """Dequantize the feature from the byte format to the float format. 28 | 29 | Args: 30 | feat_vector: the input 1-d vector. 31 | max_quantized_value: the maximum of the quantized value. 32 | min_quantized_value: the minimum of the quantized value. 33 | 34 | Returns: 35 | A float vector which has the same shape as feat_vector. 
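  For illustration, a minimal round trip against the quantization scheme used
  by the feature extractor (values are approximate because quantization keeps
  only 255 levels):

  ```
  import numpy
  feat = numpy.array([-1.5, 0.0, 0.7], dtype='float32')
  quantized = numpy.array([int(round((f + 2.0) * 255.0 / 4.0)) for f in feat])
  Dequantize(quantized)  # ~ array([-1.490, 0.016, 0.706])
  ```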
36 | """ 37 | assert max_quantized_value > min_quantized_value 38 | quantized_range = max_quantized_value - min_quantized_value 39 | scalar = quantized_range / 255.0 40 | bias = (quantized_range / 512.0) + min_quantized_value 41 | return feat_vector * scalar + bias 42 | 43 | 44 | def MakeSummary(name, value): 45 | """Creates a tf.Summary proto with the given name and value.""" 46 | summary = tf.Summary() 47 | val = summary.value.add() 48 | val.tag = str(name) 49 | val.simple_value = float(value) 50 | return summary 51 | 52 | 53 | def AddGlobalStepSummary(summary_writer, 54 | global_step_val, 55 | global_step_info_dict, 56 | summary_scope="Eval"): 57 | """Add the global_step summary to the Tensorboard. 58 | 59 | Args: 60 | summary_writer: Tensorflow summary_writer. 61 | global_step_val: a int value of the global step. 62 | global_step_info_dict: a dictionary of the evaluation metrics calculated for 63 | a mini-batch. 64 | summary_scope: Train or Eval. 65 | 66 | Returns: 67 | A string of this global_step summary 68 | """ 69 | this_hit_at_one = global_step_info_dict["hit_at_one"] 70 | this_perr = global_step_info_dict["perr"] 71 | this_loss = global_step_info_dict["loss"] 72 | examples_per_second = global_step_info_dict.get("examples_per_second", -1) 73 | 74 | summary_writer.add_summary( 75 | MakeSummary("GlobalStep/" + summary_scope + "_Hit@1", this_hit_at_one), 76 | global_step_val) 77 | summary_writer.add_summary( 78 | MakeSummary("GlobalStep/" + summary_scope + "_Perr", this_perr), 79 | global_step_val) 80 | summary_writer.add_summary( 81 | MakeSummary("GlobalStep/" + summary_scope + "_Loss", this_loss), 82 | global_step_val) 83 | 84 | if examples_per_second != -1: 85 | summary_writer.add_summary( 86 | MakeSummary("GlobalStep/" + summary_scope + "_Example_Second", 87 | examples_per_second), global_step_val) 88 | 89 | summary_writer.flush() 90 | info = ( 91 | "global_step {0} | Batch Hit@1: {1:.3f} | Batch PERR: {2:.3f} | Batch " 92 | "Loss: {3:.3f} | Examples_per_sec: {4:.3f}").format( 93 | global_step_val, this_hit_at_one, this_perr, this_loss, 94 | examples_per_second) 95 | return info 96 | 97 | 98 | def AddEpochSummary(summary_writer, 99 | global_step_val, 100 | epoch_info_dict, 101 | summary_scope="Eval"): 102 | """Add the epoch summary to the Tensorboard. 103 | 104 | Args: 105 | summary_writer: Tensorflow summary_writer. 106 | global_step_val: a int value of the global step. 107 | epoch_info_dict: a dictionary of the evaluation metrics calculated for the 108 | whole epoch. 109 | summary_scope: Train or Eval. 
110 | 111 | Returns: 112 | A string of this global_step summary 113 | """ 114 | epoch_id = epoch_info_dict["epoch_id"] 115 | avg_hit_at_one = epoch_info_dict["avg_hit_at_one"] 116 | avg_perr = epoch_info_dict["avg_perr"] 117 | avg_loss = epoch_info_dict["avg_loss"] 118 | aps = epoch_info_dict["aps"] 119 | gap = epoch_info_dict["gap"] 120 | mean_ap = numpy.mean(aps) 121 | 122 | summary_writer.add_summary( 123 | MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one), 124 | global_step_val) 125 | summary_writer.add_summary( 126 | MakeSummary("Epoch/" + summary_scope + "_Avg_Perr", avg_perr), 127 | global_step_val) 128 | summary_writer.add_summary( 129 | MakeSummary("Epoch/" + summary_scope + "_Avg_Loss", avg_loss), 130 | global_step_val) 131 | summary_writer.add_summary( 132 | MakeSummary("Epoch/" + summary_scope + "_MAP", mean_ap), global_step_val) 133 | summary_writer.add_summary( 134 | MakeSummary("Epoch/" + summary_scope + "_GAP", gap), global_step_val) 135 | summary_writer.flush() 136 | 137 | info = ("epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} " 138 | "| MAP: {3:.3f} | GAP: {4:.3f} | Avg_Loss: {5:3f} | num_classes: {6}" 139 | ).format(epoch_id, avg_hit_at_one, avg_perr, mean_ap, gap, avg_loss, 140 | len(aps)) 141 | return info 142 | 143 | 144 | def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes): 145 | """Extract the list of feature names and the dimensionality of each feature 146 | 147 | from string of comma separated values. 148 | 149 | Args: 150 | feature_names: string containing comma separated list of feature names 151 | feature_sizes: string containing comma separated list of feature sizes 152 | 153 | Returns: 154 | List of the feature names and list of the dimensionality of each feature. 155 | Elements in the first/second list are strings/integers. 156 | """ 157 | list_of_feature_names = [ 158 | feature_names.strip() for feature_names in feature_names.split(",") 159 | ] 160 | list_of_feature_sizes = [ 161 | int(feature_sizes) for feature_sizes in feature_sizes.split(",") 162 | ] 163 | if len(list_of_feature_names) != len(list_of_feature_sizes): 164 | logging.error("length of the feature names (=" + 165 | str(len(list_of_feature_names)) + ") != length of feature " 166 | "sizes (=" + str(len(list_of_feature_sizes)) + ")") 167 | 168 | return list_of_feature_names, list_of_feature_sizes 169 | 170 | 171 | def clip_gradient_norms(gradients_to_variables, max_norm): 172 | """Clips the gradients by the given value. 173 | 174 | Args: 175 | gradients_to_variables: A list of gradient to variable pairs (tuples). 176 | max_norm: the maximum norm value. 177 | 178 | Returns: 179 | A list of clipped gradient to variable pairs. 180 | """ 181 | clipped_grads_and_vars = [] 182 | for grad, var in gradients_to_variables: 183 | if grad is not None: 184 | if isinstance(grad, tf.IndexedSlices): 185 | tmp = tf.clip_by_norm(grad.values, max_norm) 186 | grad = tf.IndexedSlices(tmp, grad.indices, grad.dense_shape) 187 | else: 188 | grad = tf.clip_by_norm(grad, max_norm) 189 | clipped_grads_and_vars.append((grad, var)) 190 | return clipped_grads_and_vars 191 | 192 | 193 | def combine_gradients(tower_grads): 194 | """Calculate the combined gradient for each shared variable across all towers. 195 | 196 | Note that this function provides a synchronization point across all towers. 197 | 198 | Args: 199 | tower_grads: List of lists of (gradient, variable) tuples. The outer list is 200 | over individual gradients. 
The inner list is over the gradient calculation 201 | for each tower. 202 | 203 | Returns: 204 | List of pairs of (gradient, variable) where the gradient has been summed 205 | across all towers. 206 | """ 207 | filtered_grads = [ 208 | [x for x in grad_list if x[0] is not None] for grad_list in tower_grads 209 | ] 210 | final_grads = [] 211 | for i in xrange(len(filtered_grads[0])): 212 | grads = [filtered_grads[t][i] for t in xrange(len(filtered_grads))] 213 | grad = tf.stack([x[0] for x in grads], 0) 214 | grad = tf.reduce_sum(grad, 0) 215 | final_grads.append(( 216 | grad, 217 | filtered_grads[0][i][1], 218 | )) 219 | 220 | return final_grads 221 | -------------------------------------------------------------------------------- /eval_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Provides functions to help with evaluating models.""" 15 | import average_precision_calculator as ap_calculator 16 | import mean_average_precision_calculator as map_calculator 17 | import numpy 18 | from tensorflow.python.platform import gfile 19 | 20 | 21 | def flatten(l): 22 | """Merges a list of lists into a single list. """ 23 | return [item for sublist in l for item in sublist] 24 | 25 | 26 | def calculate_hit_at_one(predictions, actuals): 27 | """Performs a local (numpy) calculation of the hit at one. 28 | 29 | Args: 30 | predictions: Matrix containing the outputs of the model. Dimensions are 31 | 'batch' x 'num_classes'. 32 | actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x 33 | 'num_classes'. 34 | 35 | Returns: 36 | float: The average hit at one across the entire batch. 37 | """ 38 | top_prediction = numpy.argmax(predictions, 1) 39 | hits = actuals[numpy.arange(actuals.shape[0]), top_prediction] 40 | return numpy.average(hits) 41 | 42 | 43 | def calculate_precision_at_equal_recall_rate(predictions, actuals): 44 | """Performs a local (numpy) calculation of the PERR. 45 | 46 | Args: 47 | predictions: Matrix containing the outputs of the model. Dimensions are 48 | 'batch' x 'num_classes'. 49 | actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x 50 | 'num_classes'. 51 | 52 | Returns: 53 | float: The average precision at equal recall rate across the entire batch. 
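  As a small worked example (illustrative values), a single video with two
  ground-truth labels has its two highest-scored classes checked:

  ```
  p = numpy.array([[0.9, 0.8, 0.1, 0.05]])
  a = numpy.array([[1, 0, 1, 0]])
  calculate_precision_at_equal_recall_rate(p, a)  # -> 0.5
  ```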
54 | """ 55 | aggregated_precision = 0.0 56 | num_videos = actuals.shape[0] 57 | for row in numpy.arange(num_videos): 58 | num_labels = int(numpy.sum(actuals[row])) 59 | top_indices = numpy.argpartition(predictions[row], 60 | -num_labels)[-num_labels:] 61 | item_precision = 0.0 62 | for label_index in top_indices: 63 | if predictions[row][label_index] > 0: 64 | item_precision += actuals[row][label_index] 65 | item_precision /= top_indices.size 66 | aggregated_precision += item_precision 67 | aggregated_precision /= num_videos 68 | return aggregated_precision 69 | 70 | 71 | def calculate_gap(predictions, actuals, top_k=20): 72 | """Performs a local (numpy) calculation of the global average precision. 73 | 74 | Only the top_k predictions are taken for each of the videos. 75 | 76 | Args: 77 | predictions: Matrix containing the outputs of the model. Dimensions are 78 | 'batch' x 'num_classes'. 79 | actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x 80 | 'num_classes'. 81 | top_k: How many predictions to use per video. 82 | 83 | Returns: 84 | float: The global average precision. 85 | """ 86 | gap_calculator = ap_calculator.AveragePrecisionCalculator() 87 | sparse_predictions, sparse_labels, num_positives = top_k_by_class( 88 | predictions, actuals, top_k) 89 | gap_calculator.accumulate(flatten(sparse_predictions), flatten(sparse_labels), 90 | sum(num_positives)) 91 | return gap_calculator.peek_ap_at_n() 92 | 93 | 94 | def top_k_by_class(predictions, labels, k=20): 95 | """Extracts the top k predictions for each video, sorted by class. 96 | 97 | Args: 98 | predictions: A numpy matrix containing the outputs of the model. Dimensions 99 | are 'batch' x 'num_classes'. 100 | k: the top k non-zero entries to preserve in each prediction. 101 | 102 | Returns: 103 | A tuple (predictions,labels, true_positives). 'predictions' and 'labels' 104 | are lists of lists of floats. 'true_positives' is a list of scalars. The 105 | length of the lists are equal to the number of classes. The entries in the 106 | predictions variable are probability predictions, and 107 | the corresponding entries in the labels variable are the ground truth for 108 | those predictions. The entries in 'true_positives' are the number of true 109 | positives for each class in the ground truth. 110 | 111 | Raises: 112 | ValueError: An error occurred when the k is not a positive integer. 113 | """ 114 | if k <= 0: 115 | raise ValueError("k must be a positive integer.") 116 | k = min(k, predictions.shape[1]) 117 | num_classes = predictions.shape[1] 118 | prediction_triplets = [] 119 | for video_index in range(predictions.shape[0]): 120 | prediction_triplets.extend( 121 | top_k_triplets(predictions[video_index], labels[video_index], k)) 122 | out_predictions = [[] for _ in range(num_classes)] 123 | out_labels = [[] for _ in range(num_classes)] 124 | for triplet in prediction_triplets: 125 | out_predictions[triplet[0]].append(triplet[1]) 126 | out_labels[triplet[0]].append(triplet[2]) 127 | out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)] 128 | 129 | return out_predictions, out_labels, out_true_positives 130 | 131 | 132 | def top_k_triplets(predictions, labels, k=20): 133 | """Get the top_k for a 1-d numpy array. 
134 | 135 | Returns a sparse list of tuples in 136 | (prediction, class) format 137 | """ 138 | m = len(predictions) 139 | k = min(k, m) 140 | indices = numpy.argpartition(predictions, -k)[-k:] 141 | return [(index, predictions[index], labels[index]) for index in indices] 142 | 143 | 144 | class EvaluationMetrics(object): 145 | """A class to store the evaluation metrics.""" 146 | 147 | def __init__(self, num_class, top_k, top_n): 148 | """Construct an EvaluationMetrics object to store the evaluation metrics. 149 | 150 | Args: 151 | num_class: A positive integer specifying the number of classes. 152 | top_k: A positive integer specifying how many predictions are considered 153 | per video. 154 | top_n: A positive Integer specifying the average precision at n, or None 155 | to use all provided data points. 156 | 157 | Raises: 158 | ValueError: An error occurred when MeanAveragePrecisionCalculator cannot 159 | not be constructed. 160 | """ 161 | self.sum_hit_at_one = 0.0 162 | self.sum_perr = 0.0 163 | self.sum_loss = 0.0 164 | self.map_calculator = map_calculator.MeanAveragePrecisionCalculator( 165 | num_class, top_n=top_n) 166 | self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator() 167 | self.top_k = top_k 168 | self.num_examples = 0 169 | 170 | def accumulate(self, predictions, labels, loss): 171 | """Accumulate the metrics calculated locally for this mini-batch. 172 | 173 | Args: 174 | predictions: A numpy matrix containing the outputs of the model. 175 | Dimensions are 'batch' x 'num_classes'. 176 | labels: A numpy matrix containing the ground truth labels. Dimensions are 177 | 'batch' x 'num_classes'. 178 | loss: A numpy array containing the loss for each sample. 179 | 180 | Returns: 181 | dictionary: A dictionary storing the metrics for the mini-batch. 182 | 183 | Raises: 184 | ValueError: An error occurred when the shape of predictions and actuals 185 | does not match. 186 | """ 187 | batch_size = labels.shape[0] 188 | mean_hit_at_one = calculate_hit_at_one(predictions, labels) 189 | mean_perr = calculate_precision_at_equal_recall_rate(predictions, labels) 190 | mean_loss = numpy.mean(loss) 191 | 192 | # Take the top 20 predictions. 193 | sparse_predictions, sparse_labels, num_positives = top_k_by_class( 194 | predictions, labels, self.top_k) 195 | self.map_calculator.accumulate(sparse_predictions, sparse_labels, 196 | num_positives) 197 | self.global_ap_calculator.accumulate(flatten(sparse_predictions), 198 | flatten(sparse_labels), 199 | sum(num_positives)) 200 | 201 | self.num_examples += batch_size 202 | self.sum_hit_at_one += mean_hit_at_one * batch_size 203 | self.sum_perr += mean_perr * batch_size 204 | self.sum_loss += mean_loss * batch_size 205 | 206 | return {"hit_at_one": mean_hit_at_one, "perr": mean_perr, "loss": mean_loss} 207 | 208 | def get(self): 209 | """Calculate the evaluation metrics for the whole epoch. 210 | 211 | Raises: 212 | ValueError: If no examples were accumulated. 213 | 214 | Returns: 215 | dictionary: a dictionary storing the evaluation metrics for the epoch. The 216 | dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, and 217 | aps (default nan). 
218 | """ 219 | if self.num_examples <= 0: 220 | raise ValueError("total_sample must be positive.") 221 | avg_hit_at_one = self.sum_hit_at_one / self.num_examples 222 | avg_perr = self.sum_perr / self.num_examples 223 | avg_loss = self.sum_loss / self.num_examples 224 | 225 | aps = self.map_calculator.peek_map_at_n() 226 | gap = self.global_ap_calculator.peek_ap_at_n() 227 | 228 | epoch_info_dict = { 229 | "avg_hit_at_one": avg_hit_at_one, 230 | "avg_perr": avg_perr, 231 | "avg_loss": avg_loss, 232 | "aps": aps, 233 | "gap": gap 234 | } 235 | return epoch_info_dict 236 | 237 | def clear(self): 238 | """Clear the evaluation metrics and reset the EvaluationMetrics object.""" 239 | self.sum_hit_at_one = 0.0 240 | self.sum_perr = 0.0 241 | self.sum_loss = 0.0 242 | self.map_calculator.clear() 243 | self.global_ap_calculator.clear() 244 | self.num_examples = 0 245 | -------------------------------------------------------------------------------- /average_precision_calculator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Calculate or keep track of the interpolated average precision. 15 | 16 | It provides an interface for calculating interpolated average precision for an 17 | entire list or the top-n ranked items. For the definition of the 18 | (non-)interpolated average precision: 19 | http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf 20 | 21 | Example usages: 22 | 1) Use it as a static function call to directly calculate average precision for 23 | a short ranked list in the memory. 24 | 25 | ``` 26 | import random 27 | 28 | p = np.array([random.random() for _ in xrange(10)]) 29 | a = np.array([random.choice([0, 1]) for _ in xrange(10)]) 30 | 31 | ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a) 32 | ``` 33 | 34 | 2) Use it as an object for long ranked list that cannot be stored in memory or 35 | the case where partial predictions can be observed at a time (Tensorflow 36 | predictions). In this case, we first call the function accumulate many times 37 | to process parts of the ranked list. After processing all the parts, we call 38 | peek_interpolated_ap_at_n. 
39 | ``` 40 | p1 = np.array([random.random() for _ in xrange(5)]) 41 | a1 = np.array([random.choice([0, 1]) for _ in xrange(5)]) 42 | p2 = np.array([random.random() for _ in xrange(5)]) 43 | a2 = np.array([random.choice([0, 1]) for _ in xrange(5)]) 44 | 45 | # interpolated average precision at 10 using 1000 break points 46 | calculator = average_precision_calculator.AveragePrecisionCalculator(10) 47 | calculator.accumulate(p1, a1) 48 | calculator.accumulate(p2, a2) 49 | ap3 = calculator.peek_ap_at_n() 50 | ``` 51 | """ 52 | 53 | import heapq 54 | import random 55 | import numbers 56 | 57 | import numpy 58 | 59 | 60 | class AveragePrecisionCalculator(object): 61 | """Calculate the average precision and average precision at n.""" 62 | 63 | def __init__(self, top_n=None): 64 | """Construct an AveragePrecisionCalculator to calculate average precision. 65 | 66 | This class is used to calculate the average precision for a single label. 67 | 68 | Args: 69 | top_n: A positive Integer specifying the average precision at n, or None 70 | to use all provided data points. 71 | 72 | Raises: 73 | ValueError: An error occurred when the top_n is not a positive integer. 74 | """ 75 | if not ((isinstance(top_n, int) and top_n >= 0) or top_n is None): 76 | raise ValueError("top_n must be a positive integer or None.") 77 | 78 | self._top_n = top_n # average precision at n 79 | self._total_positives = 0 # total number of positives have seen 80 | self._heap = [] # max heap of (prediction, actual) 81 | 82 | @property 83 | def heap_size(self): 84 | """Gets the heap size maintained in the class.""" 85 | return len(self._heap) 86 | 87 | @property 88 | def num_accumulated_positives(self): 89 | """Gets the number of positive samples that have been accumulated.""" 90 | return self._total_positives 91 | 92 | def accumulate(self, predictions, actuals, num_positives=None): 93 | """Accumulate the predictions and their ground truth labels. 94 | 95 | After the function call, we may call peek_ap_at_n to actually calculate 96 | the average precision. 97 | Note predictions and actuals must have the same shape. 98 | 99 | Args: 100 | predictions: a list storing the prediction scores. 101 | actuals: a list storing the ground truth labels. Any value larger than 0 102 | will be treated as positives, otherwise as negatives. num_positives = If 103 | the 'predictions' and 'actuals' inputs aren't complete, then it's 104 | possible some true positives were missed in them. In that case, you can 105 | provide 'num_positives' in order to accurately track recall. 106 | 107 | Raises: 108 | ValueError: An error occurred when the format of the input is not the 109 | numpy 1-D array or the shape of predictions and actuals does not match. 
110 | """ 111 | if len(predictions) != len(actuals): 112 | raise ValueError("the shape of predictions and actuals does not match.") 113 | 114 | if num_positives is not None: 115 | if not isinstance(num_positives, numbers.Number) or num_positives < 0: 116 | raise ValueError( 117 | "'num_positives' was provided but it was a negative number.") 118 | 119 | if num_positives is not None: 120 | self._total_positives += num_positives 121 | else: 122 | self._total_positives += numpy.size( 123 | numpy.where(numpy.array(actuals) > 1e-5)) 124 | topk = self._top_n 125 | heap = self._heap 126 | 127 | for i in range(numpy.size(predictions)): 128 | if topk is None or len(heap) < topk: 129 | heapq.heappush(heap, (predictions[i], actuals[i])) 130 | else: 131 | if predictions[i] > heap[0][0]: # heap[0] is the smallest 132 | heapq.heappop(heap) 133 | heapq.heappush(heap, (predictions[i], actuals[i])) 134 | 135 | def clear(self): 136 | """Clear the accumulated predictions.""" 137 | self._heap = [] 138 | self._total_positives = 0 139 | 140 | def peek_ap_at_n(self): 141 | """Peek the non-interpolated average precision at n. 142 | 143 | Returns: 144 | The non-interpolated average precision at n (default 0). 145 | If n is larger than the length of the ranked list, 146 | the average precision will be returned. 147 | """ 148 | if self.heap_size <= 0: 149 | return 0 150 | predlists = numpy.array(list(zip(*self._heap))) 151 | 152 | ap = self.ap_at_n(predlists[0], 153 | predlists[1], 154 | n=self._top_n, 155 | total_num_positives=self._total_positives) 156 | return ap 157 | 158 | @staticmethod 159 | def ap(predictions, actuals): 160 | """Calculate the non-interpolated average precision. 161 | 162 | Args: 163 | predictions: a numpy 1-D array storing the sparse prediction scores. 164 | actuals: a numpy 1-D array storing the ground truth labels. Any value 165 | larger than 0 will be treated as positives, otherwise as negatives. 166 | 167 | Returns: 168 | The non-interpolated average precision at n. 169 | If n is larger than the length of the ranked list, 170 | the average precision will be returned. 171 | 172 | Raises: 173 | ValueError: An error occurred when the format of the input is not the 174 | numpy 1-D array or the shape of predictions and actuals does not match. 175 | """ 176 | return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None) 177 | 178 | @staticmethod 179 | def ap_at_n(predictions, actuals, n=20, total_num_positives=None): 180 | """Calculate the non-interpolated average precision. 181 | 182 | Args: 183 | predictions: a numpy 1-D array storing the sparse prediction scores. 184 | actuals: a numpy 1-D array storing the ground truth labels. Any value 185 | larger than 0 will be treated as positives, otherwise as negatives. 186 | n: the top n items to be considered in ap@n. 187 | total_num_positives : (optionally) you can specify the number of total 188 | positive in the list. If specified, it will be used in calculation. 189 | 190 | Returns: 191 | The non-interpolated average precision at n. 192 | If n is larger than the length of the ranked list, 193 | the average precision will be returned. 194 | 195 | Raises: 196 | ValueError: An error occurred when 197 | 1) the format of the input is not the numpy 1-D array; 198 | 2) the shape of predictions and actuals does not match; 199 | 3) the input n is not a positive integer. 
200 | """ 201 | if len(predictions) != len(actuals): 202 | raise ValueError("the shape of predictions and actuals does not match.") 203 | 204 | if n is not None: 205 | if not isinstance(n, int) or n <= 0: 206 | raise ValueError("n must be 'None' or a positive integer." 207 | " It was '%s'." % n) 208 | 209 | ap = 0.0 210 | 211 | predictions = numpy.array(predictions) 212 | actuals = numpy.array(actuals) 213 | 214 | # add a shuffler to avoid overestimating the ap 215 | predictions, actuals = AveragePrecisionCalculator._shuffle( 216 | predictions, actuals) 217 | sortidx = sorted(range(len(predictions)), 218 | key=lambda k: predictions[k], 219 | reverse=True) 220 | 221 | if total_num_positives is None: 222 | numpos = numpy.size(numpy.where(actuals > 0)) 223 | else: 224 | numpos = total_num_positives 225 | 226 | if numpos == 0: 227 | return 0 228 | 229 | if n is not None: 230 | numpos = min(numpos, n) 231 | delta_recall = 1.0 / numpos 232 | poscount = 0.0 233 | 234 | # calculate the ap 235 | r = len(sortidx) 236 | if n is not None: 237 | r = min(r, n) 238 | for i in range(r): 239 | if actuals[sortidx[i]] > 0: 240 | poscount += 1 241 | ap += poscount / (i + 1) * delta_recall 242 | return ap 243 | 244 | @staticmethod 245 | def _shuffle(predictions, actuals): 246 | random.seed(0) 247 | suffidx = random.sample(range(len(predictions)), len(predictions)) 248 | predictions = predictions[suffidx] 249 | actuals = actuals[suffidx] 250 | return predictions, actuals 251 | 252 | @staticmethod 253 | def _zero_one_normalize(predictions, epsilon=1e-7): 254 | """Normalize the predictions to the range between 0.0 and 1.0. 255 | 256 | For some predictions like SVM predictions, we need to normalize them before 257 | calculate the interpolated average precision. The normalization will not 258 | change the rank in the original list and thus won't change the average 259 | precision. 260 | 261 | Args: 262 | predictions: a numpy 1-D array storing the sparse prediction scores. 263 | epsilon: a small constant to avoid denominator being zero. 264 | 265 | Returns: 266 | The normalized prediction. 
267 | """ 268 | denominator = numpy.max(predictions) - numpy.min(predictions) 269 | ret = (predictions - numpy.min(predictions)) / numpy.max( 270 | denominator, epsilon) 271 | return ret 272 | -------------------------------------------------------------------------------- /segment_label_ids.csv: -------------------------------------------------------------------------------- 1 | Index 2 | 3 3 | 7 4 | 8 5 | 11 6 | 12 7 | 17 8 | 18 9 | 19 10 | 21 11 | 22 12 | 23 13 | 28 14 | 31 15 | 30 16 | 32 17 | 33 18 | 34 19 | 41 20 | 43 21 | 45 22 | 46 23 | 48 24 | 53 25 | 54 26 | 52 27 | 55 28 | 58 29 | 59 30 | 60 31 | 61 32 | 65 33 | 68 34 | 73 35 | 71 36 | 74 37 | 75 38 | 76 39 | 77 40 | 80 41 | 83 42 | 90 43 | 88 44 | 89 45 | 92 46 | 95 47 | 100 48 | 101 49 | 99 50 | 104 51 | 105 52 | 109 53 | 113 54 | 112 55 | 115 56 | 116 57 | 118 58 | 120 59 | 121 60 | 123 61 | 125 62 | 127 63 | 131 64 | 128 65 | 129 66 | 130 67 | 137 68 | 141 69 | 143 70 | 145 71 | 148 72 | 152 73 | 151 74 | 156 75 | 155 76 | 158 77 | 160 78 | 164 79 | 163 80 | 169 81 | 170 82 | 172 83 | 171 84 | 173 85 | 174 86 | 175 87 | 176 88 | 178 89 | 182 90 | 184 91 | 186 92 | 188 93 | 187 94 | 192 95 | 191 96 | 190 97 | 194 98 | 197 99 | 196 100 | 198 101 | 201 102 | 202 103 | 200 104 | 199 105 | 205 106 | 204 107 | 209 108 | 207 109 | 206 110 | 210 111 | 213 112 | 214 113 | 220 114 | 218 115 | 217 116 | 226 117 | 227 118 | 231 119 | 232 120 | 229 121 | 233 122 | 235 123 | 237 124 | 244 125 | 240 126 | 249 127 | 246 128 | 248 129 | 239 130 | 250 131 | 245 132 | 255 133 | 253 134 | 256 135 | 261 136 | 259 137 | 263 138 | 262 139 | 266 140 | 267 141 | 268 142 | 269 143 | 271 144 | 276 145 | 273 146 | 277 147 | 274 148 | 278 149 | 279 150 | 280 151 | 288 152 | 291 153 | 295 154 | 294 155 | 293 156 | 297 157 | 296 158 | 300 159 | 299 160 | 303 161 | 302 162 | 304 163 | 305 164 | 313 165 | 307 166 | 311 167 | 310 168 | 312 169 | 316 170 | 318 171 | 321 172 | 322 173 | 331 174 | 333 175 | 329 176 | 330 177 | 334 178 | 343 179 | 349 180 | 340 181 | 344 182 | 348 183 | 358 184 | 347 185 | 359 186 | 355 187 | 361 188 | 360 189 | 364 190 | 365 191 | 368 192 | 369 193 | 366 194 | 370 195 | 374 196 | 380 197 | 373 198 | 385 199 | 384 200 | 388 201 | 389 202 | 382 203 | 393 204 | 381 205 | 390 206 | 394 207 | 399 208 | 397 209 | 396 210 | 402 211 | 400 212 | 398 213 | 401 214 | 405 215 | 406 216 | 410 217 | 408 218 | 416 219 | 415 220 | 419 221 | 422 222 | 414 223 | 421 224 | 424 225 | 429 226 | 418 227 | 427 228 | 434 229 | 428 230 | 435 231 | 430 232 | 441 233 | 439 234 | 437 235 | 443 236 | 440 237 | 442 238 | 445 239 | 446 240 | 448 241 | 454 242 | 444 243 | 453 244 | 455 245 | 451 246 | 452 247 | 458 248 | 460 249 | 465 250 | 457 251 | 463 252 | 462 253 | 461 254 | 464 255 | 469 256 | 468 257 | 472 258 | 473 259 | 471 260 | 475 261 | 474 262 | 477 263 | 485 264 | 491 265 | 488 266 | 482 267 | 490 268 | 496 269 | 494 270 | 483 271 | 495 272 | 493 273 | 507 274 | 501 275 | 499 276 | 503 277 | 498 278 | 514 279 | 504 280 | 502 281 | 506 282 | 508 283 | 511 284 | 527 285 | 526 286 | 532 287 | 513 288 | 519 289 | 525 290 | 518 291 | 528 292 | 522 293 | 523 294 | 535 295 | 539 296 | 540 297 | 533 298 | 521 299 | 541 300 | 547 301 | 550 302 | 544 303 | 549 304 | 551 305 | 554 306 | 543 307 | 548 308 | 557 309 | 560 310 | 552 311 | 559 312 | 563 313 | 565 314 | 567 315 | 555 316 | 576 317 | 568 318 | 564 319 | 573 320 | 581 321 | 580 322 | 572 323 | 571 324 | 584 325 | 590 326 | 585 327 | 587 328 | 588 329 | 592 330 | 598 331 | 597 332 | 599 333 | 603 
334 | 600 335 | 604 336 | 605 337 | 614 338 | 602 339 | 610 340 | 608 341 | 611 342 | 612 343 | 613 344 | 617 345 | 620 346 | 607 347 | 624 348 | 627 349 | 625 350 | 631 351 | 629 352 | 638 353 | 632 354 | 634 355 | 644 356 | 641 357 | 642 358 | 646 359 | 652 360 | 647 361 | 637 362 | 661 363 | 635 364 | 658 365 | 648 366 | 663 367 | 668 368 | 664 369 | 656 370 | 666 371 | 671 372 | 683 373 | 675 374 | 669 375 | 676 376 | 667 377 | 691 378 | 685 379 | 673 380 | 688 381 | 702 382 | 684 383 | 679 384 | 694 385 | 686 386 | 689 387 | 680 388 | 693 389 | 703 390 | 697 391 | 698 392 | 692 393 | 705 394 | 706 395 | 712 396 | 711 397 | 709 398 | 710 399 | 726 400 | 713 401 | 721 402 | 720 403 | 715 404 | 717 405 | 730 406 | 728 407 | 723 408 | 716 409 | 722 410 | 718 411 | 732 412 | 724 413 | 736 414 | 725 415 | 742 416 | 727 417 | 735 418 | 740 419 | 748 420 | 738 421 | 746 422 | 751 423 | 749 424 | 752 425 | 754 426 | 760 427 | 763 428 | 756 429 | 758 430 | 766 431 | 764 432 | 757 433 | 780 434 | 767 435 | 769 436 | 771 437 | 786 438 | 785 439 | 781 440 | 787 441 | 778 442 | 783 443 | 792 444 | 791 445 | 795 446 | 788 447 | 805 448 | 802 449 | 801 450 | 793 451 | 796 452 | 804 453 | 803 454 | 797 455 | 814 456 | 813 457 | 789 458 | 808 459 | 818 460 | 816 461 | 817 462 | 811 463 | 820 464 | 826 465 | 829 466 | 824 467 | 821 468 | 825 469 | 822 470 | 835 471 | 833 472 | 843 473 | 823 474 | 827 475 | 830 476 | 832 477 | 837 478 | 852 479 | 844 480 | 841 481 | 812 482 | 847 483 | 862 484 | 869 485 | 860 486 | 838 487 | 870 488 | 846 489 | 858 490 | 854 491 | 880 492 | 876 493 | 857 494 | 859 495 | 877 496 | 871 497 | 855 498 | 875 499 | 861 500 | 867 501 | 892 502 | 898 503 | 888 504 | 884 505 | 887 506 | 891 507 | 906 508 | 900 509 | 878 510 | 885 511 | 883 512 | 901 513 | 903 514 | 907 515 | 930 516 | 897 517 | 914 518 | 917 519 | 910 520 | 905 521 | 909 522 | 933 523 | 932 524 | 922 525 | 913 526 | 923 527 | 931 528 | 911 529 | 937 530 | 918 531 | 955 532 | 915 533 | 944 534 | 952 535 | 945 536 | 948 537 | 946 538 | 970 539 | 974 540 | 958 541 | 925 542 | 979 543 | 942 544 | 965 545 | 975 546 | 950 547 | 982 548 | 940 549 | 973 550 | 962 551 | 972 552 | 957 553 | 984 554 | 983 555 | 964 556 | 1007 557 | 971 558 | 981 559 | 954 560 | 993 561 | 991 562 | 996 563 | 1005 564 | 1015 565 | 1009 566 | 995 567 | 986 568 | 1000 569 | 985 570 | 980 571 | 1016 572 | 1011 573 | 999 574 | 1002 575 | 994 576 | 1013 577 | 1010 578 | 992 579 | 1008 580 | 1036 581 | 1025 582 | 1012 583 | 990 584 | 1037 585 | 1040 586 | 1031 587 | 1019 588 | 1052 589 | 1001 590 | 1055 591 | 1032 592 | 1069 593 | 1058 594 | 1014 595 | 1023 596 | 1030 597 | 1061 598 | 1035 599 | 1034 600 | 1053 601 | 1045 602 | 1046 603 | 1067 604 | 1060 605 | 1049 606 | 1056 607 | 1074 608 | 1066 609 | 1044 610 | 1038 611 | 1073 612 | 1077 613 | 1068 614 | 1057 615 | 1072 616 | 1104 617 | 1083 618 | 1089 619 | 1087 620 | 1099 621 | 1076 622 | 1086 623 | 1098 624 | 1094 625 | 1095 626 | 1096 627 | 1101 628 | 1107 629 | 1105 630 | 1117 631 | 1093 632 | 1106 633 | 1122 634 | 1119 635 | 1103 636 | 1128 637 | 1120 638 | 1126 639 | 1102 640 | 1115 641 | 1124 642 | 1123 643 | 1131 644 | 1136 645 | 1144 646 | 1121 647 | 1137 648 | 1132 649 | 1133 650 | 1157 651 | 1134 652 | 1143 653 | 1159 654 | 1164 655 | 1155 656 | 1142 657 | 1150 658 | 1148 659 | 1161 660 | 1165 661 | 1147 662 | 1162 663 | 1152 664 | 1174 665 | 1160 666 | 1166 667 | 1190 668 | 1175 669 | 1167 670 | 1156 671 | 1180 672 | 1171 673 | 1179 674 | 1172 675 | 1186 676 | 1188 677 | 1201 678 | 
1177 679 | 1208 680 | 1183 681 | 1189 682 | 1192 683 | 1209 684 | 1214 685 | 1197 686 | 1168 687 | 1202 688 | 1205 689 | 1203 690 | 1199 691 | 1219 692 | 1217 693 | 1187 694 | 1206 695 | 1210 696 | 1241 697 | 1221 698 | 1218 699 | 1223 700 | 1236 701 | 1212 702 | 1237 703 | 1195 704 | 1216 705 | 1247 706 | 1234 707 | 1240 708 | 1257 709 | 1224 710 | 1243 711 | 1259 712 | 1242 713 | 1282 714 | 1222 715 | 1254 716 | 1227 717 | 1235 718 | 1269 719 | 1258 720 | 1290 721 | 1275 722 | 1262 723 | 1252 724 | 1248 725 | 1272 726 | 1246 727 | 1225 728 | 1245 729 | 1277 730 | 1298 731 | 1288 732 | 1271 733 | 1265 734 | 1286 735 | 1260 736 | 1266 737 | 1296 738 | 1280 739 | 1285 740 | 1293 741 | 1276 742 | 1287 743 | 1289 744 | 1261 745 | 1264 746 | 1295 747 | 1291 748 | 1283 749 | 1311 750 | 1303 751 | 1330 752 | 1315 753 | 1300 754 | 1333 755 | 1307 756 | 1325 757 | 1334 758 | 1316 759 | 1314 760 | 1317 761 | 1310 762 | 1329 763 | 1324 764 | 1339 765 | 1346 766 | 1342 767 | 1352 768 | 1321 769 | 1376 770 | 1366 771 | 1308 772 | 1345 773 | 1348 774 | 1386 775 | 1383 776 | 1372 777 | 1367 778 | 1400 779 | 1382 780 | 1375 781 | 1392 782 | 1380 783 | 1371 784 | 1393 785 | 1389 786 | 1353 787 | 1387 788 | 1374 789 | 1379 790 | 1381 791 | 1359 792 | 1360 793 | 1396 794 | 1399 795 | 1365 796 | 1424 797 | 1373 798 | 1411 799 | 1401 800 | 1397 801 | 1395 802 | 1412 803 | 1394 804 | 1368 805 | 1423 806 | 1391 807 | 1435 808 | 1409 809 | 1443 810 | 1402 811 | 1425 812 | 1415 813 | 1421 814 | 1426 815 | 1433 816 | 1420 817 | 1452 818 | 1436 819 | 1430 820 | 1408 821 | 1458 822 | 1429 823 | 1453 824 | 1454 825 | 1447 826 | 1472 827 | 1486 828 | 1468 829 | 1461 830 | 1467 831 | 1484 832 | 1457 833 | 1444 834 | 1450 835 | 1451 836 | 1459 837 | 1462 838 | 1449 839 | 1476 840 | 1470 841 | 1471 842 | 1498 843 | 1488 844 | 1442 845 | 1480 846 | 1456 847 | 1466 848 | 1505 849 | 1517 850 | 1464 851 | 1503 852 | 1490 853 | 1519 854 | 1481 855 | 1493 856 | 1463 857 | 1532 858 | 1487 859 | 1501 860 | 1500 861 | 1495 862 | 1509 863 | 1535 864 | 1506 865 | 1521 866 | 1580 867 | 1540 868 | 1502 869 | 1520 870 | 1496 871 | 1569 872 | 1515 873 | 1489 874 | 1507 875 | 1527 876 | 1545 877 | 1560 878 | 1510 879 | 1514 880 | 1526 881 | 1594 882 | 1511 883 | 1572 884 | 1548 885 | 1584 886 | 1556 887 | 1588 888 | 1628 889 | 1555 890 | 1568 891 | 1550 892 | 1622 893 | 1563 894 | 1603 895 | 1616 896 | 1576 897 | 1549 898 | 1537 899 | 1593 900 | 1618 901 | 1645 902 | 1624 903 | 1617 904 | 1634 905 | 1595 906 | 1597 907 | 1590 908 | 1632 909 | 1575 910 | 1559 911 | 1625 912 | 1615 913 | 1591 914 | 1630 915 | 1608 916 | 1621 917 | 1589 918 | 1646 919 | 1643 920 | 1652 921 | 1627 922 | 1611 923 | 1626 924 | 1613 925 | 1639 926 | 1655 927 | 1620 928 | 1602 929 | 1651 930 | 1653 931 | 1669 932 | 1638 933 | 1696 934 | 1649 935 | 1675 936 | 1660 937 | 1683 938 | 1666 939 | 1671 940 | 1703 941 | 1716 942 | 1637 943 | 1672 944 | 1676 945 | 1692 946 | 1711 947 | 1680 948 | 1641 949 | 1688 950 | 1708 951 | 1704 952 | 1690 953 | 1674 954 | 1718 955 | 1699 956 | 1723 957 | 1756 958 | 1700 959 | 1662 960 | 1715 961 | 1657 962 | 1733 963 | 1728 964 | 1670 965 | 1712 966 | 1685 967 | 1724 968 | 1735 969 | 1714 970 | 1730 971 | 1747 972 | 1656 973 | 1737 974 | 1705 975 | 1693 976 | 1713 977 | 1689 978 | 1753 979 | 1739 980 | 1721 981 | 1725 982 | 1749 983 | 1732 984 | 1743 985 | 1731 986 | 1767 987 | 1738 988 | 1831 989 | 1771 990 | 1726 991 | 1746 992 | 1776 993 | 1775 994 | 1799 995 | 1774 996 | 1780 997 | 1781 998 | 1769 999 | 1805 1000 | 1788 1001 | 
1801 1002 | -------------------------------------------------------------------------------- /frame_level_models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Contains a collection of models which operate on variable-length sequences.""" 15 | import math 16 | 17 | import model_utils as utils 18 | import models 19 | import tensorflow as tf 20 | from tensorflow import flags 21 | import tensorflow.contrib.slim as slim 22 | import video_level_models 23 | 24 | FLAGS = flags.FLAGS 25 | flags.DEFINE_integer("iterations", 30, "Number of frames per batch for DBoF.") 26 | flags.DEFINE_bool("dbof_add_batch_norm", True, 27 | "Adds batch normalization to the DBoF model.") 28 | flags.DEFINE_bool( 29 | "sample_random_frames", True, 30 | "If true samples random frames (for frame level models). If false, a random" 31 | "sequence of frames is sampled instead.") 32 | flags.DEFINE_integer("dbof_cluster_size", 8192, 33 | "Number of units in the DBoF cluster layer.") 34 | flags.DEFINE_integer("dbof_hidden_size", 1024, 35 | "Number of units in the DBoF hidden layer.") 36 | flags.DEFINE_string( 37 | "dbof_pooling_method", "max", 38 | "The pooling method used in the DBoF cluster layer. " 39 | "Choices are 'average' and 'max'.") 40 | flags.DEFINE_string( 41 | "dbof_activation", "sigmoid", 42 | "The nonlinear activation method for cluster and hidden dense layer, e.g., " 43 | "sigmoid, relu6, etc.") 44 | flags.DEFINE_string( 45 | "video_level_classifier_model", "MoeModel", 46 | "Some Frame-Level models can be decomposed into a " 47 | "generalized pooling operation followed by a " 48 | "classifier layer") 49 | flags.DEFINE_integer("lstm_cells", 1024, "Number of LSTM cells.") 50 | flags.DEFINE_integer("lstm_layers", 2, "Number of LSTM layers.") 51 | 52 | 53 | class FrameLevelLogisticModel(models.BaseModel): 54 | """Creates a logistic classifier over the aggregated frame-level features.""" 55 | 56 | def create_model(self, model_input, vocab_size, num_frames, **unused_params): 57 | """See base class. 58 | 59 | This class is intended to be an example for implementors of frame level 60 | models. If you want to train a model over averaged features it is more 61 | efficient to average them beforehand rather than on the fly. 62 | 63 | Args: 64 | model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of 65 | input features. 66 | vocab_size: The number of classes in the dataset. 67 | num_frames: A vector of length 'batch' which indicates the number of 68 | frames for each video (before padding). 69 | 70 | Returns: 71 | A dictionary with a tensor containing the probability predictions of the 72 | model in the 'predictions' key. The dimensions of the tensor are 73 | 'batch_size' x 'num_classes'. 
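    A minimal usage sketch (the 1152-dim frame features, i.e. 1024-dim rgb
    plus 128-dim audio, and the 3862-class vocabulary below are illustrative
    YouTube-8M-style values, not requirements):

    ```
    model = FrameLevelLogisticModel()
    frames = tf.placeholder(tf.float32, [None, 300, 1152])
    num_frames = tf.placeholder(tf.int32, [None])
    out = model.create_model(frames, vocab_size=3862, num_frames=num_frames)
    out["predictions"]  # Tensor of shape [batch_size, 3862]
    ```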
74 | """ 75 | num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32) 76 | feature_size = model_input.get_shape().as_list()[2] 77 | 78 | denominators = tf.reshape(tf.tile(num_frames, [1, feature_size]), 79 | [-1, feature_size]) 80 | avg_pooled = tf.reduce_sum(model_input, axis=[1]) / denominators 81 | 82 | output = slim.fully_connected(avg_pooled, 83 | vocab_size, 84 | activation_fn=tf.nn.sigmoid, 85 | weights_regularizer=slim.l2_regularizer(1e-8)) 86 | return {"predictions": output} 87 | 88 | 89 | class DbofModel(models.BaseModel): 90 | """Creates a Deep Bag of Frames model. 91 | 92 | The model projects the features for each frame into a higher dimensional 93 | 'clustering' space, pools across frames in that space, and then 94 | uses a configurable video-level model to classify the now aggregated features. 95 | 96 | The model will randomly sample either frames or sequences of frames during 97 | training to speed up convergence. 98 | """ 99 | 100 | ACT_FN_MAP = { 101 | "sigmoid": tf.nn.sigmoid, 102 | "relu6": tf.nn.relu6, 103 | } 104 | 105 | def create_model(self, 106 | model_input, 107 | vocab_size, 108 | num_frames, 109 | iterations=None, 110 | add_batch_norm=None, 111 | sample_random_frames=None, 112 | cluster_size=None, 113 | hidden_size=None, 114 | is_training=True, 115 | **unused_params): 116 | """See base class. 117 | 118 | Args: 119 | model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of 120 | input features. 121 | vocab_size: The number of classes in the dataset. 122 | num_frames: A vector of length 'batch' which indicates the number of 123 | frames for each video (before padding). 124 | iterations: the number of frames to be sampled. 125 | add_batch_norm: whether to add batch norm during training. 126 | sample_random_frames: whether to sample random frames or random sequences. 127 | cluster_size: the output neuron number of the cluster layer. 128 | hidden_size: the output neuron number of the hidden layer. 129 | is_training: whether to build the graph in training mode. 130 | 131 | Returns: 132 | A dictionary with a tensor containing the probability predictions of the 133 | model in the 'predictions' key. The dimensions of the tensor are 134 | 'batch_size' x 'num_classes'. 135 | """ 136 | iterations = iterations or FLAGS.iterations 137 | add_batch_norm = add_batch_norm or FLAGS.dbof_add_batch_norm 138 | random_frames = sample_random_frames or FLAGS.sample_random_frames 139 | cluster_size = cluster_size or FLAGS.dbof_cluster_size 140 | hidden1_size = hidden_size or FLAGS.dbof_hidden_size 141 | act_fn = self.ACT_FN_MAP.get(FLAGS.dbof_activation) 142 | assert act_fn is not None, ("dbof_activation is not valid: %s." 
% 143 | FLAGS.dbof_activation) 144 | 145 | num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32) 146 | if random_frames: 147 | model_input = utils.SampleRandomFrames(model_input, num_frames, 148 | iterations) 149 | else: 150 | model_input = utils.SampleRandomSequence(model_input, num_frames, 151 | iterations) 152 | max_frames = model_input.get_shape().as_list()[1] 153 | feature_size = model_input.get_shape().as_list()[2] 154 | reshaped_input = tf.reshape(model_input, [-1, feature_size]) 155 | tf.compat.v1.summary.histogram("input_hist", reshaped_input) 156 | 157 | if add_batch_norm: 158 | reshaped_input = slim.batch_norm(reshaped_input, 159 | center=True, 160 | scale=True, 161 | is_training=is_training, 162 | scope="input_bn") 163 | 164 | cluster_weights = tf.compat.v1.get_variable( 165 | "cluster_weights", [feature_size, cluster_size], 166 | initializer=tf.random_normal_initializer(stddev=1 / 167 | math.sqrt(feature_size))) 168 | tf.compat.v1.summary.histogram("cluster_weights", cluster_weights) 169 | activation = tf.matmul(reshaped_input, cluster_weights) 170 | if add_batch_norm: 171 | activation = slim.batch_norm(activation, 172 | center=True, 173 | scale=True, 174 | is_training=is_training, 175 | scope="cluster_bn") 176 | else: 177 | cluster_biases = tf.compat.v1.get_variable( 178 | "cluster_biases", [cluster_size], 179 | initializer=tf.random_normal_initializer(stddev=1 / 180 | math.sqrt(feature_size))) 181 | tf.compat.v1.summary.histogram("cluster_biases", cluster_biases) 182 | activation += cluster_biases 183 | activation = act_fn(activation) 184 | tf.compat.v1.summary.histogram("cluster_output", activation) 185 | 186 | activation = tf.reshape(activation, [-1, max_frames, cluster_size]) 187 | activation = utils.FramePooling(activation, FLAGS.dbof_pooling_method) 188 | 189 | hidden1_weights = tf.compat.v1.get_variable( 190 | "hidden1_weights", [cluster_size, hidden1_size], 191 | initializer=tf.random_normal_initializer(stddev=1 / 192 | math.sqrt(cluster_size))) 193 | tf.compat.v1.summary.histogram("hidden1_weights", hidden1_weights) 194 | activation = tf.matmul(activation, hidden1_weights) 195 | if add_batch_norm: 196 | activation = slim.batch_norm(activation, 197 | center=True, 198 | scale=True, 199 | is_training=is_training, 200 | scope="hidden1_bn") 201 | else: 202 | hidden1_biases = tf.compat.v1.get_variable( 203 | "hidden1_biases", [hidden1_size], 204 | initializer=tf.random_normal_initializer(stddev=0.01)) 205 | tf.compat.v1.summary.histogram("hidden1_biases", hidden1_biases) 206 | activation += hidden1_biases 207 | activation = act_fn(activation) 208 | tf.compat.v1.summary.histogram("hidden1_output", activation) 209 | 210 | aggregated_model = getattr(video_level_models, 211 | FLAGS.video_level_classifier_model) 212 | return aggregated_model().create_model(model_input=activation, 213 | vocab_size=vocab_size, 214 | **unused_params) 215 | 216 | 217 | class LstmModel(models.BaseModel): 218 | """Creates a model which uses a stack of LSTMs to represent the video.""" 219 | 220 | def create_model(self, model_input, vocab_size, num_frames, **unused_params): 221 | """See base class. 222 | 223 | Args: 224 | model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of 225 | input features. 226 | vocab_size: The number of classes in the dataset. 227 | num_frames: A vector of length 'batch' which indicates the number of 228 | frames for each video (before padding). 
229 | 230 | Returns: 231 | A dictionary with a tensor containing the probability predictions of the 232 | model in the 'predictions' key. The dimensions of the tensor are 233 | 'batch_size' x 'num_classes'. 234 | """ 235 | lstm_size = FLAGS.lstm_cells 236 | number_of_layers = FLAGS.lstm_layers 237 | 238 | stacked_lstm = tf.contrib.rnn.MultiRNNCell([ 239 | tf.contrib.rnn.BasicLSTMCell(lstm_size, forget_bias=1.0) 240 | for _ in range(number_of_layers) 241 | ]) 242 | 243 | _, state = tf.nn.dynamic_rnn(stacked_lstm, 244 | model_input, 245 | sequence_length=num_frames, 246 | dtype=tf.float32) 247 | 248 | aggregated_model = getattr(video_level_models, 249 | FLAGS.video_level_classifier_model) 250 | 251 | return aggregated_model().create_model(model_input=state[-1].h, 252 | vocab_size=vocab_size, 253 | **unused_params) 254 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /readers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Provides readers configured for different datasets.""" 15 | 16 | import tensorflow as tf 17 | import utils 18 | 19 | 20 | def resize_axis(tensor, axis, new_size, fill_value=0): 21 | """Truncates or pads a tensor to new_size on on a given axis. 22 | 23 | Truncate or extend tensor such that tensor.shape[axis] == new_size. If the 24 | size increases, the padding will be performed at the end, using fill_value. 25 | 26 | Args: 27 | tensor: The tensor to be resized. 28 | axis: An integer representing the dimension to be sliced. 29 | new_size: An integer or 0d tensor representing the new value for 30 | tensor.shape[axis]. 31 | fill_value: Value to use to fill any new entries in the tensor. Will be cast 32 | to the type of tensor. 33 | 34 | Returns: 35 | The resized tensor. 36 | """ 37 | tensor = tf.convert_to_tensor(tensor) 38 | shape = tf.unstack(tf.shape(tensor)) 39 | 40 | pad_shape = shape[:] 41 | pad_shape[axis] = tf.maximum(0, new_size - shape[axis]) 42 | 43 | shape[axis] = tf.minimum(shape[axis], new_size) 44 | shape = tf.stack(shape) 45 | 46 | resized = tf.concat([ 47 | tf.slice(tensor, tf.zeros_like(shape), shape), 48 | tf.fill(tf.stack(pad_shape), tf.cast(fill_value, tensor.dtype)) 49 | ], axis) 50 | 51 | # Update shape. 52 | new_shape = tensor.get_shape().as_list() # A copy is being made. 53 | new_shape[axis] = new_size 54 | resized.set_shape(new_shape) 55 | return resized 56 | 57 | 58 | class BaseReader(object): 59 | """Inherit from this class when implementing new readers.""" 60 | 61 | def prepare_reader(self, unused_filename_queue): 62 | """Create a thread for generating prediction and label tensors.""" 63 | raise NotImplementedError() 64 | 65 | 66 | class YT8MAggregatedFeatureReader(BaseReader): 67 | """Reads TFRecords of pre-aggregated Examples. 68 | 69 | The TFRecords must contain Examples with a sparse int64 'labels' feature and 70 | a fixed length float32 feature, obtained from the features in 'feature_name'. 71 | The float features are assumed to be an average of dequantized values. 72 | """ 73 | 74 | def __init__( # pylint: disable=dangerous-default-value 75 | self, 76 | num_classes=3862, 77 | feature_sizes=[1024, 128], 78 | feature_names=["mean_rgb", "mean_audio"]): 79 | """Construct a YT8MAggregatedFeatureReader. 80 | 81 | Args: 82 | num_classes: a positive integer for the number of classes. 83 | feature_sizes: positive integer(s) for the feature dimensions as a list. 84 | feature_names: the feature name(s) in the tensorflow record as a list. 85 | """ 86 | 87 | assert len(feature_names) == len(feature_sizes), ( 88 | "length of feature_names (={}) != length of feature_sizes (={})".format( 89 | len(feature_names), len(feature_sizes))) 90 | 91 | self.num_classes = num_classes 92 | self.feature_sizes = feature_sizes 93 | self.feature_names = feature_names 94 | 95 | def prepare_reader(self, filename_queue, batch_size=1024): 96 | """Creates a single reader thread for pre-aggregated YouTube 8M Examples. 97 | 98 | Args: 99 | filename_queue: A tensorflow queue of filename locations. 100 | batch_size: batch size used for feature output. 
101 | 102 | Returns: 103 | A dict of video indexes, features, labels, and frame counts. 104 | """ 105 | reader = tf.TFRecordReader() 106 | _, serialized_examples = reader.read_up_to(filename_queue, batch_size) 107 | 108 | tf.add_to_collection("serialized_examples", serialized_examples) 109 | return self.prepare_serialized_examples(serialized_examples) 110 | 111 | def prepare_serialized_examples(self, serialized_examples): 112 | """Parse a single video-level TF Example.""" 113 | # set the mapping from the fields to data types in the proto 114 | num_features = len(self.feature_names) 115 | assert num_features > 0, "self.feature_names is empty!" 116 | assert len(self.feature_names) == len(self.feature_sizes), \ 117 | "length of feature_names (={}) != length of feature_sizes (={})".format( 118 | len(self.feature_names), len(self.feature_sizes)) 119 | 120 | feature_map = { 121 | "id": tf.io.FixedLenFeature([], tf.string), 122 | "labels": tf.io.VarLenFeature(tf.int64) 123 | } 124 | for feature_index in range(num_features): 125 | feature_map[self.feature_names[feature_index]] = tf.FixedLenFeature( 126 | [self.feature_sizes[feature_index]], tf.float32) 127 | 128 | features = tf.parse_example(serialized_examples, features=feature_map) 129 | labels = tf.sparse_to_indicator(features["labels"], self.num_classes) 130 | labels.set_shape([None, self.num_classes]) 131 | concatenated_features = tf.concat( 132 | [features[feature_name] for feature_name in self.feature_names], 1) 133 | 134 | output_dict = { 135 | "video_ids": features["id"], 136 | "video_matrix": concatenated_features, 137 | "labels": labels, 138 | "num_frames": tf.ones([tf.shape(serialized_examples)[0]]) 139 | } 140 | 141 | return output_dict 142 | 143 | 144 | class YT8MFrameFeatureReader(BaseReader): 145 | """Reads TFRecords of SequenceExamples. 146 | 147 | The TFRecords must contain SequenceExamples with the sparse in64 'labels' 148 | context feature and a fixed length byte-quantized feature vector, obtained 149 | from the features in 'feature_names'. The quantized features will be mapped 150 | back into a range between min_quantized_value and max_quantized_value. 151 | """ 152 | 153 | def __init__( # pylint: disable=dangerous-default-value 154 | self, 155 | num_classes=3862, 156 | feature_sizes=[1024, 128], 157 | feature_names=["rgb", "audio"], 158 | max_frames=300, 159 | segment_labels=False, 160 | segment_size=5): 161 | """Construct a YT8MFrameFeatureReader. 162 | 163 | Args: 164 | num_classes: a positive integer for the number of classes. 165 | feature_sizes: positive integer(s) for the feature dimensions as a list. 166 | feature_names: the feature name(s) in the tensorflow record as a list. 167 | max_frames: the maximum number of frames to process. 168 | segment_labels: if we read segment labels instead. 169 | segment_size: the segment_size used for reading segments. 170 | """ 171 | 172 | assert len(feature_names) == len(feature_sizes), ( 173 | "length of feature_names (={}) != length of feature_sizes (={})".format( 174 | len(feature_names), len(feature_sizes))) 175 | 176 | self.num_classes = num_classes 177 | self.feature_sizes = feature_sizes 178 | self.feature_names = feature_names 179 | self.max_frames = max_frames 180 | self.segment_labels = segment_labels 181 | self.segment_size = segment_size 182 | 183 | def get_video_matrix(self, features, feature_size, max_frames, 184 | max_quantized_value, min_quantized_value): 185 | """Decodes features from an input string and quantizes it. 
186 | 187 | Args: 188 | features: raw feature values 189 | feature_size: length of each frame feature vector 190 | max_frames: number of frames (rows) in the output feature_matrix 191 | max_quantized_value: the maximum of the quantized value. 192 | min_quantized_value: the minimum of the quantized value. 193 | 194 | Returns: 195 | feature_matrix: matrix of all frame-features 196 | num_frames: number of frames in the sequence 197 | """ 198 | decoded_features = tf.reshape( 199 | tf.cast(tf.decode_raw(features, tf.uint8), tf.float32), 200 | [-1, feature_size]) 201 | 202 | num_frames = tf.minimum(tf.shape(decoded_features)[0], max_frames) 203 | feature_matrix = utils.Dequantize(decoded_features, max_quantized_value, 204 | min_quantized_value) 205 | feature_matrix = resize_axis(feature_matrix, 0, max_frames) 206 | return feature_matrix, num_frames 207 | 208 | def prepare_reader(self, 209 | filename_queue, 210 | max_quantized_value=2, 211 | min_quantized_value=-2): 212 | """Creates a single reader thread for YouTube8M SequenceExamples. 213 | 214 | Args: 215 | filename_queue: A tensorflow queue of filename locations. 216 | max_quantized_value: the maximum of the quantized value. 217 | min_quantized_value: the minimum of the quantized value. 218 | 219 | Returns: 220 | A dict of video indexes, video features, labels, and frame counts. 221 | """ 222 | reader = tf.TFRecordReader() 223 | _, serialized_example = reader.read(filename_queue) 224 | 225 | return self.prepare_serialized_examples(serialized_example, 226 | max_quantized_value, 227 | min_quantized_value) 228 | 229 | def prepare_serialized_examples(self, 230 | serialized_example, 231 | max_quantized_value=2, 232 | min_quantized_value=-2): 233 | """Parse single serialized SequenceExample from the TFRecords.""" 234 | 235 | # Read/parse frame/segment-level labels. 236 | context_features = { 237 | "id": tf.io.FixedLenFeature([], tf.string), 238 | } 239 | if self.segment_labels: 240 | context_features.update({ 241 | # There is no need to read end-time given we always assume the segment 242 | # has the same size. 243 | "segment_labels": tf.io.VarLenFeature(tf.int64), 244 | "segment_start_times": tf.io.VarLenFeature(tf.int64), 245 | "segment_scores": tf.io.VarLenFeature(tf.float32) 246 | }) 247 | else: 248 | context_features.update({"labels": tf.io.VarLenFeature(tf.int64)}) 249 | sequence_features = { 250 | feature_name: tf.io.FixedLenSequenceFeature([], dtype=tf.string) 251 | for feature_name in self.feature_names 252 | } 253 | contexts, features = tf.io.parse_single_sequence_example( 254 | serialized_example, 255 | context_features=context_features, 256 | sequence_features=sequence_features) 257 | 258 | # loads (potentially) different types of features and concatenates them 259 | num_features = len(self.feature_names) 260 | assert num_features > 0, "No feature selected: feature_names is empty!" 
261 | 262 | assert len(self.feature_names) == len(self.feature_sizes), ( 263 | "length of feature_names (={}) != length of feature_sizes (={})".format( 264 | len(self.feature_names), len(self.feature_sizes))) 265 | 266 | num_frames = -1 # the number of frames in the video 267 | feature_matrices = [None] * num_features # an array of different features 268 | for feature_index in range(num_features): 269 | feature_matrix, num_frames_in_this_feature = self.get_video_matrix( 270 | features[self.feature_names[feature_index]], 271 | self.feature_sizes[feature_index], self.max_frames, 272 | max_quantized_value, min_quantized_value) 273 | if num_frames == -1: 274 | num_frames = num_frames_in_this_feature 275 | 276 | feature_matrices[feature_index] = feature_matrix 277 | 278 | # cap the number of frames at self.max_frames 279 | num_frames = tf.minimum(num_frames, self.max_frames) 280 | 281 | # concatenate different features 282 | video_matrix = tf.concat(feature_matrices, 1) 283 | 284 | # Partition frame-level feature matrix to segment-level feature matrix. 285 | if self.segment_labels: 286 | start_times = contexts["segment_start_times"].values 287 | # Here we assume all the segments that started at the same start time has 288 | # the same segment_size. 289 | uniq_start_times, seg_idxs = tf.unique(start_times, 290 | out_idx=tf.dtypes.int64) 291 | # TODO(zhengxu): Ensure the segment_sizes are all same. 292 | segment_size = self.segment_size 293 | # Range gather matrix, e.g., [[0,1,2],[1,2,3]] for segment_size == 3. 294 | range_mtx = tf.expand_dims(uniq_start_times, axis=-1) + tf.expand_dims( 295 | tf.range(0, segment_size, dtype=tf.int64), axis=0) 296 | # Shape: [num_segment, segment_size, feature_dim]. 297 | batch_video_matrix = tf.gather_nd(video_matrix, 298 | tf.expand_dims(range_mtx, axis=-1)) 299 | num_segment = tf.shape(batch_video_matrix)[0] 300 | batch_video_ids = tf.reshape(tf.tile([contexts["id"]], [num_segment]), 301 | (num_segment,)) 302 | batch_frames = tf.reshape(tf.tile([segment_size], [num_segment]), 303 | (num_segment,)) 304 | 305 | # For segment labels, all labels are not exhausively rated. So we only 306 | # evaluate the rated labels. 307 | 308 | # Label indices for each segment, shape: [num_segment, 2]. 309 | label_indices = tf.stack([seg_idxs, contexts["segment_labels"].values], 310 | axis=-1) 311 | label_values = contexts["segment_scores"].values 312 | sparse_labels = tf.sparse.SparseTensor(label_indices, label_values, 313 | (num_segment, self.num_classes)) 314 | batch_labels = tf.sparse.to_dense(sparse_labels, validate_indices=False) 315 | 316 | sparse_label_weights = tf.sparse.SparseTensor( 317 | label_indices, tf.ones_like(label_values, dtype=tf.float32), 318 | (num_segment, self.num_classes)) 319 | batch_label_weights = tf.sparse.to_dense(sparse_label_weights, 320 | validate_indices=False) 321 | else: 322 | # Process video-level labels. 323 | label_indices = contexts["labels"].values 324 | sparse_labels = tf.sparse.SparseTensor( 325 | tf.expand_dims(label_indices, axis=-1), 326 | tf.ones_like(contexts["labels"].values, dtype=tf.bool), 327 | (self.num_classes,)) 328 | labels = tf.sparse.to_dense(sparse_labels, 329 | default_value=False, 330 | validate_indices=False) 331 | # convert to batch format. 
332 | batch_video_ids = tf.expand_dims(contexts["id"], 0) 333 | batch_video_matrix = tf.expand_dims(video_matrix, 0) 334 | batch_labels = tf.expand_dims(labels, 0) 335 | batch_frames = tf.expand_dims(num_frames, 0) 336 | batch_label_weights = None 337 | 338 | output_dict = { 339 | "video_ids": batch_video_ids, 340 | "video_matrix": batch_video_matrix, 341 | "labels": batch_labels, 342 | "num_frames": batch_frames, 343 | } 344 | if batch_label_weights is not None: 345 | output_dict["label_weights"] = batch_label_weights 346 | 347 | return output_dict 348 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YouTube-8M Tensorflow Starter Code 2 | 3 | This repo contains starter code for training and evaluating machine learning 4 | models over the [YouTube-8M](https://research.google.com/youtube8m/) dataset. 5 | This is the starter code for our 6 | [3rd Youtube8M Video Understanding Challenge on Kaggle](https://www.kaggle.com/c/youtube8m-2019) 7 | and part of the International Conference on Computer Vision (ICCV) 2019 selected 8 | workshop session. The code gives an end-to-end working example for reading the 9 | dataset, training a TensorFlow model, and evaluating the performance of the 10 | model. 11 | 12 | ## Table of Contents 13 | 14 | * [Running on Your Own Machine](#running-on-your-own-machine) 15 | * [Requirements](#requirements) 16 | * [Download Dataset Locally](#download-dataset-locally) 17 | * [Try the starter code](#try-the-starter-code) 18 | * [Train video-level model on frame-level features and inference at 19 | segment-level.](#train-video-level-model-on-frame-level-features-and-inference-at-segment-level) 20 | * [Tensorboard](#tensorboard) 21 | * [Using GPUs](#using-gpus) 22 | * [Running on Google's Cloud Machine Learning Platform](#running-on-googles-cloud-machine-learning-platform) 23 | * [Requirements](#requirements-1) 24 | * [Accessing Files on Google Cloud](#accessing-files-on-google-cloud) 25 | * [Testing Locally](#testing-locally) 26 | * [Training on the Cloud over Frame-Level Features](#training-on-the-cloud-over-frame-level-features) 27 | * [Evaluation and Inference](#evaluation-and-inference) 28 | * [Create Your Own Dataset Files](#create-your-own-dataset-files) 29 | * [Training without this Starter Code](#training-without-this-starter-code) 30 | * [Export Your Model for MediaPipe Inference](#export-your-model-for-mediapipe-inference) 31 | * [More Documents](#more-documents) 32 | * [About This Project](#about-this-project) 33 | 34 | ## Running on Your Own Machine 35 | 36 | ### Requirements 37 | 38 | The starter code requires Tensorflow. If you haven't installed it yet, follow 39 | the instructions on [tensorflow.org](https://www.tensorflow.org/install/). This 40 | code has been tested with Tensorflow 1.14. Going forward, we will continue to 41 | target the latest released version of Tensorflow. 42 | 43 | Please verify that you have Python 3.6+ and Tensorflow 1.14 or higher installed 44 | by running the following commands: 45 | 46 | ```sh 47 | python --version 48 | python -c 'import tensorflow as tf; print(tf.__version__)' 49 | ``` 50 | 51 | ### Download Dataset Locally 52 | 53 | Please see our 54 | [dataset website](https://research.google.com/youtube8m/download.html) for 55 | up-to-date download instructions. 
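For reference, the download page has historically driven the transfer with a small `download.py` helper and environment variables; the sketch below assumes that helper and uses example `partition`/`mirror` values, so treat the dataset website as authoritative if anything has changed:

```sh
# Hypothetical sketch: fetch the frame-level training shards into the layout
# assumed below. Verify the exact commands on the dataset download page.
mkdir -p ~/yt8m/2/frame/train && cd ~/yt8m/2/frame/train
curl data.yt8m.org/download.py | partition=2/frame/train mirror=us python
```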
56 | 57 | In this document, we assume you download all the frame-level feature dataset to 58 | `~/yt8m/2/frame` and segment-level validation/test dataset to `~/yt8m/3/frame`. 59 | So the structure should look like 60 | 61 | ``` 62 | ~/yt8m/ 63 | - ~/yt8m/2/frame/ 64 | - ~/yt8m/2/frame/train 65 | - ~/yt8m/3/frame/ 66 | - ~/yt8m/3/frame/test 67 | - ~/yt8m/3/frame/validate 68 | ``` 69 | 70 | ### Try the starter code 71 | 72 | Clone this git repo: `mkdir -p ~/yt8m/code && cd ~/yt8m/code && git clone 73 | https://github.com/google/youtube-8m.git` 74 | 75 | #### Train video-level model on frame-level features and inference at segment-level. 76 | 77 | Train using `train.py`, selecting a frame-level model (e.g. 78 | `FrameLevelLogisticModel`), and instructing the trainer to use 79 | `--frame_features`. In short, frame-level features are stored quantized, and this 80 | flag selects the frame-level reader, which dequantizes them. 81 | 82 | ```bash 83 | python train.py --frame_features --model=FrameLevelLogisticModel \ 84 | --feature_names='rgb,audio' --feature_sizes='1024,128' \ 85 | --train_data_pattern=${HOME}/yt8m/2/frame/train/train*.tfrecord \ 86 | --train_dir ~/yt8m/models/frame/sample_model --start_new_model 87 | ``` 88 | 89 | Evaluate the model by running 90 | 91 | ```bash 92 | python eval.py \ 93 | --eval_data_pattern=${HOME}/yt8m/3/frame/validate/validate*.tfrecord \ 94 | --train_dir ~/yt8m/models/frame/sample_model --segment_labels 95 | ``` 96 | 97 | This will report comprehensive metrics, e.g., gAP, mAP, etc., for your 98 | models. 99 | 100 | Produce a predictions CSV (`kaggle_solution.csv`) by running inference: 101 | 102 | ```bash 103 | python \ 104 | inference.py --train_dir ~/yt8m/models/frame/sample_model \ 105 | --output_file=$HOME/tmp/kaggle_solution.csv \ 106 | --input_data_pattern=${HOME}/yt8m/3/frame/test/test*.tfrecord --segment_labels 107 | ``` 108 | 109 | (Optional) If you wish to see how the models are evaluated in the Kaggle system, you 110 | can run inference on the validation set and then score the resulting CSV: 111 | 112 | ```bash 113 | python inference.py --train_dir ~/yt8m/models/frame/sample_model \ 114 | --output_file=$HOME/tmp/kaggle_solution_validation.csv \ 115 | --input_data_pattern=${HOME}/yt8m/3/frame/validate/validate*.tfrecord \ 116 | --segment_labels 117 | ``` 118 | 119 | ```bash 120 | python segment_eval_inference.py \ 121 | --eval_data_pattern=${HOME}/yt8m/3/frame/validate/validate*.tfrecord \ 122 | --label_cache=$HOME/tmp/validate.label_cache \ 123 | --submission_file=$HOME/tmp/kaggle_solution_validation.csv --top_n=100000 124 | ``` 125 | 126 | **NOTE**: This script can be slow the first time it runs, because it reads the 127 | TFRecord data and builds a label cache. Once the cache is built, subsequent 128 | evaluations are much faster. 129 | 130 | #### Tensorboard 131 | 132 | You can use Tensorboard to compare your frame-level or video-level models, like: 133 | 134 | ```sh 135 | MODELS_DIR=~/yt8m/models 136 | tensorboard --logdir frame:${MODELS_DIR}/frame 137 | ``` 138 | 139 | We find it useful to keep the tensorboard instance always running, as we train 140 | and evaluate different models. 141 | 142 | #### Using GPUs 143 | 144 | If your Tensorflow installation has GPU support, e.g., installed with `pip 145 | install tensorflow-gpu`, this code will make use of all of your compatible GPUs. 146 | You can verify your installation by running 147 | 148 | ``` 149 | python -c 'import tensorflow as tf; tf.Session()' 150 | ``` 151 | 152 | This will print out something like the following for each of your compatible 153 | GPUs. 
154 | 155 | ``` 156 | I tensorflow/core/common_runtime/gpu/gpu_init.cc:102] Found device 0 with properties: 157 | name: Tesla M40 158 | major: 5 minor: 2 memoryClockRate (GHz) 1.112 159 | pciBusID 0000:04:00.0 160 | Total memory: 11.25GiB 161 | Free memory: 11.09GiB 162 | ... 163 | ``` 164 | 165 | If at least one GPU was found, the forward and backward passes will be computed 166 | with the GPUs, whereas the CPU will be used primarily for the input and output 167 | pipelines. If you have multiple GPUs, the current default behavior is to use 168 | only one of them. 169 | 170 | 171 | ## Running on Google's Cloud Machine Learning Platform 172 | 173 | ### Requirements 174 | 175 | This option requires you to have an appropriately configured Google Cloud 176 | Platform account. To create and configure your account, please make sure you 177 | follow the instructions 178 | [here](https://cloud.google.com/ml/docs/how-tos/getting-set-up). 179 | 180 | Please also verify that you have Python 3.6+ and Tensorflow 1.14 or higher 181 | installed by running the following commands: 182 | 183 | ```sh 184 | python --version 185 | python -c 'import tensorflow as tf; print(tf.__version__)' 186 | ``` 187 | 188 | ### Accessing Files on Google Cloud 189 | 190 | You can browse the storage buckets you created on Google Cloud, for example, to 191 | access the trained models, prediction CSV files, etc. by visiting the 192 | [Google Cloud storage browser](https://console.cloud.google.com/storage/browser). 193 | 194 | Alternatively, you can use the 'gsutil' command to download the files directly. 195 | For example, to download the output of the inference code from the previous 196 | section to your local machine, run: 197 | 198 | ``` 199 | gsutil cp $BUCKET_NAME/${JOB_TO_EVAL}/predictions.csv . 200 | ``` 201 | 202 | ### Testing Locally 203 | 204 | All gcloud commands should be done from the directory *immediately above* the 205 | source code. You should be able to see the source code directory if you run 206 | 'ls'. 207 | 208 | As you are developing your own models, you will want to test them quickly to 209 | flush out simple problems without having to submit them to the cloud. 210 | 211 | Here is an example command line for frame-level training: 212 | 213 | ```sh 214 | gcloud ai-platform local train \ 215 | --package-path=youtube-8m --module-name=youtube-8m.train -- \ 216 | --train_data_pattern='gs://youtube8m-ml/2/frame/train/train*.tfrecord' \ 217 | --train_dir=/tmp/yt8m_train --frame_features --model=FrameLevelLogisticModel \ 218 | --feature_names='rgb,audio' --feature_sizes='1024,128' --start_new_model 219 | ``` 220 | 221 | ### Training on the Cloud over Frame-Level Features 222 | 223 | The following commands will train a model on Google Cloud over frame-level 224 | features. 225 | 226 | ```bash 227 | BUCKET_NAME=gs://${USER}_yt8m_train_bucket 228 | # (One Time) Create a storage bucket to store training logs and checkpoints. 229 | gsutil mb -l us-east1 $BUCKET_NAME 230 | # Submit the training job. 
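# Note: flags before the bare '--' in the command below configure gcloud and the Cloud job itself;
# everything after the '--' is passed through unchanged to the youtube-8m.train module.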
231 | JOB_NAME=yt8m_train_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug ai-platform jobs \ 232 | submit training $JOB_NAME \ 233 | --package-path=youtube-8m --module-name=youtube-8m.train \ 234 | --staging-bucket=$BUCKET_NAME --region=us-east1 \ 235 | --config=youtube-8m/cloudml-gpu.yaml \ 236 | -- --train_data_pattern='gs://youtube8m-ml/2/frame/train/train*.tfrecord' \ 237 | --frame_features --model=FrameLevelLogisticModel \ 238 | --feature_names='rgb,audio' --feature_sizes='1024,128' \ 239 | --train_dir=$BUCKET_NAME/yt8m_train_frame_level_logistic_model --start_new_model 240 | ``` 241 | 242 | In the 'gcloud' command above, the '--package-path' flag refers to the directory 243 | containing the 'train.py' script and, more generally, the python package which 244 | should be deployed to the cloud worker. The '--module-name' flag refers to the specific 245 | python script which should be executed (in this case the train module). 246 | 247 | It may take several minutes before the job starts running on Google Cloud. When 248 | it starts, you will see output like the following: 249 | 250 | ``` 251 | training step 270| Hit@1: 0.68 PERR: 0.52 Loss: 638.453 252 | training step 271| Hit@1: 0.66 PERR: 0.49 Loss: 635.537 253 | training step 272| Hit@1: 0.70 PERR: 0.52 Loss: 637.564 254 | ``` 255 | 256 | At this point you can disconnect your console by pressing "ctrl-c". The model 257 | will continue to train indefinitely in the Cloud. Later, you can check on its 258 | progress or halt the job by visiting the 259 | [Google Cloud ML Jobs console](https://console.cloud.google.com/ml/jobs). 260 | 261 | You can train many jobs at once and use tensorboard to compare their performance 262 | visually. 263 | 264 | ```sh 265 | tensorboard --logdir=$BUCKET_NAME --port=8080 266 | ``` 267 | 268 | Once tensorboard is running, you can access it at the following url: 269 | [http://localhost:8080](http://localhost:8080). If you are using Google Cloud 270 | Shell, you can instead click the Web Preview button on the upper left corner of 271 | the Cloud Shell window and select "Preview on port 8080". This will bring up a 272 | new browser tab with the Tensorboard view. 
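If you prefer the command line to the Cloud Console, the gcloud SDK can also report on or stop a job. A minimal sketch, reusing the `$JOB_NAME` from the submission step above (subcommand names may vary slightly between gcloud SDK versions):

```sh
# Show the job's current state and configuration.
gcloud ai-platform jobs describe $JOB_NAME
# Stream the job's logs to your terminal.
gcloud ai-platform jobs stream-logs $JOB_NAME
# Stop the job if you no longer need it.
gcloud ai-platform jobs cancel $JOB_NAME
```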
273 | 274 | ### Evaluation and Inference 275 | 276 | Here's how to evaluate a model on the validation dataset: 277 | 278 | ```sh 279 | JOB_TO_EVAL=yt8m_train_frame_level_logistic_model 280 | JOB_NAME=yt8m_eval_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug ai-platform jobs \ 281 | submit training $JOB_NAME \ 282 | --package-path=youtube-8m --module-name=youtube-8m.eval \ 283 | --staging-bucket=$BUCKET_NAME --region=us-east1 \ 284 | --config=youtube-8m/cloudml-gpu.yaml \ 285 | -- --eval_data_pattern='gs://youtube8m-ml/3/frame/validate/validate*.tfrecord' \ 286 | --frame_features --model=FrameLevelLogisticModel --feature_names='rgb,audio' \ 287 | --feature_sizes='1024,128' --train_dir=$BUCKET_NAME/${JOB_TO_EVAL} \ 288 | --segment_labels --run_once=True 289 | ``` 290 | 291 | And here's how to perform inference with a model on the test set: 292 | 293 | ```sh 294 | JOB_TO_EVAL=yt8m_train_frame_level_logistic_model 295 | JOB_NAME=yt8m_inference_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug ai-platform jobs \ 296 | submit training $JOB_NAME \ 297 | --package-path=youtube-8m --module-name=youtube-8m.inference \ 298 | --staging-bucket=$BUCKET_NAME --region=us-east1 \ 299 | --config=youtube-8m/cloudml-gpu.yaml \ 300 | -- --input_data_pattern='gs://youtube8m-ml/3/frame/test/test*.tfrecord' \ 301 | --train_dir=$BUCKET_NAME/${JOB_TO_EVAL} --segment_labels \ 302 | --output_file=$BUCKET_NAME/${JOB_TO_EVAL}/predictions.csv 303 | ``` 304 | 305 | Note the confusing use of 'training' in the above gcloud commands. Despite the 306 | name, the 'training' argument really just offers a cloud hosted 307 | python/tensorflow service. From the point of view of the Cloud Platform, there 308 | is no distinction between our training and inference jobs. The Cloud ML platform 309 | also offers specialized functionality for prediction with Tensorflow models, but 310 | discussing that is beyond the scope of this readme. 311 | 312 | Once these job starts executing you will see outputs similar to the following 313 | for the evaluation code: 314 | 315 | ``` 316 | examples_processed: 1024 | global_step 447044 | Batch Hit@1: 0.782 | Batch PERR: 0.637 | Batch Loss: 7.821 | Examples_per_sec: 834.658 317 | ``` 318 | 319 | and the following for the inference code: 320 | 321 | ``` 322 | num examples processed: 8192 elapsed seconds: 14.85 323 | ``` 324 | 325 | ## Export Your Model for MediaPipe Inference 326 | To run inference with your model in [MediaPipe inference 327 | demo](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/youtube8m#steps-to-run-the-youtube-8m-inference-graph-with-the-yt8m-dataset), you need to export your checkpoint to a SavedModel. 328 | 329 | Example command: 330 | ```sh 331 | python export_model_mediapipe.py --checkpoint_file ~/yt8m/models/frame/sample_model/inference_model/segment_inference_model --output_dir /tmp/mediapipe/saved_model/ 332 | ``` 333 | 334 | 335 | ## Create Your Own Dataset Files 336 | 337 | You can create your dataset files from your own videos. Our 338 | [feature extractor](./feature_extractor) code creates `tfrecord` files, 339 | identical to our dataset files. You can use our starter code to train on the 340 | `tfrecord` files output by the feature extractor. In addition, you can fine-tune 341 | your YouTube-8M models on your new dataset. 342 | 343 | ## Training without this Starter Code 344 | 345 | You are welcome to use our dataset without using our starter code. 
However, if 345 | you'd like to compete on Kaggle, then you must make sure that you are able to 346 | produce a prediction CSV file in the same format as the one produced by our `inference.py`. In particular, the 347 | [predictions CSV file](https://www.kaggle.com/c/youtube8m-2018#evaluation) must 348 | have two fields: `Class Id,Segment Ids`, where `Class Id` must be one of the class ids 349 | listed in `segment_label_ids.csv` and `Segment Ids` is a space-delimited list of 350 | `