├── cloudml-gpu.yaml
├── __init__.py
├── .github
│   └── workflows
│       └── pythonpackage.yml
├── models.py
├── docs
│   ├── model_overview.md
│   └── files_overview.md
├── CONTRIBUTING.md
├── feature_extractor
│   ├── feature_extractor_test.py
│   ├── README.md
│   ├── feature_extractor.py
│   └── extract_tfrecords_main.py
├── export_model_mediapipe.py
├── convert_prediction_from_json_to_csv.py
├── model_utils.py
├── losses.py
├── video_level_models.py
├── mean_average_precision_calculator.py
├── export_model.py
├── segment_eval_inference.py
├── utils.py
├── eval_util.py
├── average_precision_calculator.py
├── segment_label_ids.csv
├── frame_level_models.py
├── LICENSE
├── readers.py
├── README.md
├── eval.py
├── inference.py
└── train.py
/cloudml-gpu.yaml:
--------------------------------------------------------------------------------
1 | trainingInput:
2 | scaleTier: CUSTOM
3 | # https://cloud.google.com/ml-engine/docs/machine-types#machine_type_table
4 | masterType: standard_gpu
5 | runtimeVersion: "1.14"
6 |
--------------------------------------------------------------------------------
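
The scale tier and runtime version above are picked up by Cloud ML Engine when a training job is submitted with `--config`. A hypothetical submission is sketched below; the bucket, job name, region, data pattern, and `train.py` flag values are placeholders rather than values taken from this listing.

```bash
BUCKET_NAME=gs://${USER}_yt8m_train_bucket
JOB_NAME=yt8m_train_$(date +%Y%m%d_%H%M%S)

gcloud --verbosity=debug ml-engine jobs submit training $JOB_NAME \
  --package-path=youtube-8m --module-name=youtube-8m.train \
  --staging-bucket=$BUCKET_NAME --region=us-east1 \
  --config=youtube-8m/cloudml-gpu.yaml \
  -- --train_data_pattern='gs://path/to/frame/train*.tfrecord' \
     --model=FrameLevelLogisticModel --train_dir=$BUCKET_NAME/yt8m_train
```
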
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/.github/workflows/pythonpackage.yml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 | strategy:
10 | max-parallel: 4
11 | matrix:
12 | python-version: [3.6]
13 |
14 | steps:
15 | - uses: actions/checkout@v1
16 | - name: Set up Python ${{ matrix.python-version }}
17 | uses: actions/setup-python@v1
18 | with:
19 | python-version: ${{ matrix.python-version }}
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install tensorflow==1.14.0 six Pillow
24 | - name: Yapf Check
25 | run: |
26 | pip install yapf
27 | yapf --diff --style="{based_on_style: google, indent_width:2}" *.py
28 | - name: Test with nosetests
29 | run: |
30 | pip install -U nose
31 | nosetests
32 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Contains the base class for models."""
15 |
16 |
17 | class BaseModel(object):
18 | """Inherit from this class when implementing new models."""
19 |
20 | def create_model(self, unused_model_input, **unused_params):
21 | raise NotImplementedError()
22 |
--------------------------------------------------------------------------------
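
New models plug in by subclassing `BaseModel` and implementing `create_model`. A minimal sketch follows; the `vocab_size` argument and the `{"predictions": ...}` return convention follow `video_level_models.py`, and the constant output is purely illustrative.

```python
import tensorflow as tf

import models


class ConstantPredictionModel(models.BaseModel):
  """Toy model that predicts probability 0.5 for every class."""

  def create_model(self, model_input, vocab_size, **unused_params):
    batch_size = tf.shape(model_input)[0]
    predictions = 0.5 * tf.ones([batch_size, vocab_size])
    return {"predictions": predictions}
```
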
/docs/model_overview.md:
--------------------------------------------------------------------------------
1 | # Overview of Models
2 |
3 | This sample code contains implementations of the models given in the
4 | [YouTube-8M technical report](https://arxiv.org/abs/1609.08675).
5 |
6 | ## Video-Level Models
7 |
8 | * `LogisticModel`: Linear projection of the output features into the label
9 | space, followed by a sigmoid function to convert logit values to
10 | probabilities.
11 | * `MoeModel`: A per-class softmax distribution over a configurable number of
12 | logistic classifiers. One of the classifiers in the mixture is not trained,
13 | and always predicts 0.
14 |
15 | ## Frame-Level Models
16 |
17 | * `LstmModel`: Processes the features for each frame using a multi-layered
18 | LSTM neural net. The final internal state of the LSTM is input to a
19 | video-level model for classification. Note that you will need to change the
20 | learning rate to 0.001 when using this model.
21 | * `DbofModel`: Projects the features for each frame into a higher dimensional
22 | 'clustering' space, pools across frames in that space, and then uses a
23 | video-level model to classify the now aggregated features.
24 | * `FrameLevelLogisticModel`: Equivalent to 'LogisticModel', but performs
25 | average-pooling on the fly over frame-level features rather than using
26 | pre-aggregated features.
27 |
--------------------------------------------------------------------------------
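
As a concrete illustration of the `FrameLevelLogisticModel` idea above (average-pooling frame features on the fly, then a sigmoid classifier), here is a minimal sketch. It is not the repository's actual `frame_level_models.py` implementation, and the masking of padded frames is an assumption about how the pooling is done.

```python
import tensorflow as tf
import tensorflow.contrib.slim as slim


def frame_level_logistic_sketch(model_input, vocab_size, num_frames,
                                l2_penalty=1e-8):
  """model_input: [batch, max_frames, feature_size]; num_frames: [batch, 1]."""
  max_frames = tf.shape(model_input)[1]
  # Zero out padded frames so they do not contribute to the average.
  mask = tf.sequence_mask(tf.reshape(num_frames, [-1]), max_frames,
                          dtype=tf.float32)
  summed = tf.reduce_sum(model_input * tf.expand_dims(mask, 2), axis=1)
  denominators = tf.maximum(tf.reduce_sum(mask, axis=1, keepdims=True), 1.0)
  avg_pooled = summed / denominators
  output = slim.fully_connected(
      avg_pooled, vocab_size, activation_fn=tf.nn.sigmoid,
      weights_regularizer=slim.l2_regularizer(l2_penalty))
  return {"predictions": output}
```
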
/docs/files_overview.md:
--------------------------------------------------------------------------------
1 | # Overview of Files
2 |
3 | ## Training
4 |
5 | * `train.py`: The primary script for training models.
6 | * `losses.py`: Contains definitions for loss functions.
7 | * `models.py`: Contains the base class for defining a model.
8 | * `video_level_models.py`: Contains definitions for models that take
9 | aggregated features as input.
10 | * `frame_level_models.py`: Contains definitions for models that take frame-
11 | level features as input.
12 | * `model_utils.py`: Contains functions that are of general utility for
13 | implementing models.
14 | * `export_model.py`: Provides a class to export a model during training for
15 | later use in batch prediction.
16 | * `readers.py`: Contains definitions for the Video dataset and Frame dataset
17 | readers.
18 |
19 | ## Evaluation
20 |
21 | * `eval.py`: The primary script for evaluating models.
22 | * `eval_util.py`: Provides a class that calculates all evaluation metrics.
23 | * `average_precision_calculator.py`: Functions for calculating average
24 | precision.
25 | * `mean_average_precision_calculator.py`: Functions for calculating mean
26 | average precision.
27 | * `segment_eval_inference.py`: The primary script to evaluate segment models
28 | with Kaggle metrics.
29 |
30 | ## Inference
31 |
32 | * `inference.py`: Generates an output CSV file containing predictions of the
33 | model over a set of videos. It optionally generates a tarred file of the
34 | model.
35 |
36 | ## Misc
37 |
38 | * `README.md`: This documentation.
39 | * `utils.py`: Common functions.
40 |
--------------------------------------------------------------------------------
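
A rough end-to-end flow through these scripts looks like the following. The flag names mirror the starter code's root README, but since `train.py`, `eval.py`, and `inference.py` are not reproduced in this listing, treat the exact flags and values as illustrative.

```bash
# Train a frame-level model (paths are placeholders).
python train.py --frame_features --model=FrameLevelLogisticModel \
  --feature_names='rgb,audio' --feature_sizes='1024,128' \
  --train_data_pattern='/path/to/frame/train*.tfrecord' --train_dir=/tmp/yt8m_model

# Evaluate the latest checkpoint.
python eval.py --eval_data_pattern='/path/to/frame/validate*.tfrecord' \
  --train_dir=/tmp/yt8m_model --run_once

# Write a CSV of predictions over the test set.
python inference.py --input_data_pattern='/path/to/frame/test*.tfrecord' \
  --train_dir=/tmp/yt8m_model --output_file=/tmp/predictions.csv
```
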
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | We are accepting patches and contributions to this project. To set expectations,
4 | this project is primarily intended to be a flexible starting point for
5 | researchers working with the YouTube-8M dataset. As such, we would like to keep
6 | it simple. We are more likely to accept small bug fixes and optimizations, and
7 | less likely to accept patches which add significant complexity. For the latter
8 | type of contribution, we recommend creating a Github fork of the project
9 | instead.
10 |
11 | If you would like to contribute, there are a few small guidelines you need to
12 | follow.
13 |
14 | ## Contributor License Agreement
15 |
16 | Contributions to any Google project must be accompanied by a Contributor License
17 | Agreement. This is necessary because you own the copyright to your changes, even
18 | after your contribution becomes part of this project. So this agreement simply
19 | gives us permission to use and redistribute your contributions as part of the
20 | project. Head over to <https://cla.developers.google.com/> to see your current
21 | agreements on file or to sign a new one.
22 |
23 | You generally only need to submit a CLA once, so if you've already submitted one
24 | (even if it was for a different project), you probably don't need to do it
25 | again.
26 |
27 | ## Code reviews
28 |
29 | All submissions, including submissions by project members, require review. We
30 | use GitHub pull requests for this purpose. Consult [GitHub Help] for more
31 | information on using pull requests.
32 |
33 | [GitHub Help]: https://help.github.com/articles/about-pull-requests/
34 |
--------------------------------------------------------------------------------
/feature_extractor/feature_extractor_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for feature_extractor."""
15 |
16 | import json
17 | import os
18 | import feature_extractor
19 | import numpy
20 | from PIL import Image
21 | from six.moves import cPickle
22 | from tensorflow.python.platform import googletest
23 |
24 |
25 | def _FilePath(filename):
26 | return os.path.join('testdata', filename)
27 |
28 |
29 | def _MeanElementWiseDifference(a, b):
30 | """Calculates element-wise percent difference between two numpy matrices."""
31 | difference = numpy.abs(a - b)
32 | denominator = numpy.maximum(numpy.abs(a), numpy.abs(b))
33 |
34 |   # We don't care if one is 0 and the other is 0.01
35 | return (difference / (0.01 + denominator)).mean()
36 |
37 |
38 | class FeatureExtractorTest(googletest.TestCase):
39 |
40 | def setUp(self):
41 | self._extractor = feature_extractor.YouTube8MFeatureExtractor()
42 |
43 | def testPCAOnFeatureVector(self):
44 |     sports_1m_test_data = cPickle.load(open(_FilePath('sports1m_frame.pkl'), 'rb'))
45 | actual_pca = self._extractor.apply_pca(sports_1m_test_data['original'])
46 | expected_pca = sports_1m_test_data['pca']
47 | self.assertLess(_MeanElementWiseDifference(actual_pca, expected_pca), 1e-5)
48 |
49 |
50 | if __name__ == '__main__':
51 | googletest.main()
52 |
--------------------------------------------------------------------------------
/export_model_mediapipe.py:
--------------------------------------------------------------------------------
1 | # Lint as: python3
2 | import numpy as np
3 | import tensorflow as tf
4 | from tensorflow import app
5 | from tensorflow import flags
6 |
7 | FLAGS = flags.FLAGS
8 |
9 |
10 | def main(unused_argv):
11 | # Get the input tensor names to be replaced.
12 | tf.reset_default_graph()
13 | meta_graph_location = FLAGS.checkpoint_file + ".meta"
14 | tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
15 |
16 | input_tensor_name = tf.get_collection("input_batch_raw")[0].name
17 | num_frames_tensor_name = tf.get_collection("num_frames")[0].name
18 |
19 | # Create output graph.
20 | saver = tf.train.Saver()
21 | tf.reset_default_graph()
22 |
23 | input_feature_placeholder = tf.placeholder(
24 | tf.float32, shape=(None, None, 1152))
25 | num_frames_placeholder = tf.placeholder(tf.int32, shape=(None, 1))
26 |
27 | saver = tf.train.import_meta_graph(
28 | meta_graph_location,
29 | input_map={
30 | input_tensor_name: input_feature_placeholder,
31 | num_frames_tensor_name: tf.squeeze(num_frames_placeholder, axis=1)
32 | },
33 | clear_devices=True)
34 | predictions_tensor = tf.get_collection("predictions")[0]
35 |
36 | with tf.Session() as sess:
37 | print("restoring variables from " + FLAGS.checkpoint_file)
38 | saver.restore(sess, FLAGS.checkpoint_file)
39 | tf.saved_model.simple_save(
40 | sess,
41 | FLAGS.output_dir,
42 | inputs={'rgb_and_audio': input_feature_placeholder,
43 | 'num_frames': num_frames_placeholder},
44 | outputs={'predictions': predictions_tensor})
45 |
46 | # Try running inference.
47 | predictions = sess.run(
48 | [predictions_tensor],
49 | feed_dict={
50 | input_feature_placeholder: np.zeros((3, 7, 1152), dtype=np.float32),
51 | num_frames_placeholder: np.array([[7]], dtype=np.int32)})
52 | print('Test inference:', predictions)
53 |
54 | print('Model saved to ', FLAGS.output_dir)
55 |
56 |
57 | if __name__ == '__main__':
58 | flags.DEFINE_string('checkpoint_file', None, 'Path to the checkpoint file.')
59 | flags.DEFINE_string('output_dir', None, 'SavedModel output directory.')
60 | app.run(main)
61 |
--------------------------------------------------------------------------------
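
Both flags defined above are required at run time; a typical invocation and a quick inspection of the exported SavedModel might look like this (paths are placeholders, and `saved_model_cli` ships with TensorFlow):

```bash
python export_model_mediapipe.py \
  --checkpoint_file=/path/to/train_dir/inference_model/model.ckpt-100000 \
  --output_dir=/tmp/mediapipe_saved_model

# The exported signature has inputs 'rgb_and_audio' and 'num_frames' and
# output 'predictions', as wired up in the script above.
saved_model_cli show --dir /tmp/mediapipe_saved_model --all
```
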
/convert_prediction_from_json_to_csv.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utility to convert the output of batch prediction into a CSV submission.
15 |
16 | It converts the JSON files created by the command
17 | 'gcloud beta ml jobs submit prediction' into a CSV file ready for submission.
18 | """
19 |
20 | import json
21 | import tensorflow as tf
22 |
23 | from builtins import range
24 | from tensorflow import app
25 | from tensorflow import flags
26 | from tensorflow import gfile
27 | from tensorflow import logging
28 |
29 | FLAGS = flags.FLAGS
30 |
31 | if __name__ == "__main__":
32 |
33 | flags.DEFINE_string(
34 | "json_prediction_files_pattern", None,
35 | "Pattern specifying the list of JSON files that the command "
36 | "'gcloud beta ml jobs submit prediction' outputs. These files are "
37 | "located in the output path of the prediction command and are prefixed "
38 | "with 'prediction.results'.")
39 | flags.DEFINE_string(
40 | "csv_output_file", None,
41 | "The file to save the predictions converted to the CSV format.")
42 |
43 |
44 | def get_csv_header():
45 | return "VideoId,LabelConfidencePairs\n"
46 |
47 |
48 | def to_csv_row(json_data):
49 |
50 | video_id = json_data["video_id"]
51 |
52 | class_indexes = json_data["class_indexes"]
53 | predictions = json_data["predictions"]
54 |
55 | if isinstance(video_id, list):
56 | video_id = video_id[0]
57 | class_indexes = class_indexes[0]
58 | predictions = predictions[0]
59 |
60 | if len(class_indexes) != len(predictions):
61 | raise ValueError(
62 | "The number of indexes (%s) and predictions (%s) must be equal." %
63 | (len(class_indexes), len(predictions)))
64 |
65 | return (video_id.decode("utf-8") + "," +
66 | " ".join("%i %f" % (class_indexes[i], predictions[i])
67 | for i in range(len(class_indexes))) + "\n")
68 |
69 |
70 | def main(unused_argv):
71 | logging.set_verbosity(tf.logging.INFO)
72 |
73 | if not FLAGS.json_prediction_files_pattern:
74 | raise ValueError(
75 | "The flag --json_prediction_files_pattern must be specified.")
76 |
77 | if not FLAGS.csv_output_file:
78 | raise ValueError("The flag --csv_output_file must be specified.")
79 |
80 | logging.info("Looking for prediction files with pattern: %s",
81 | FLAGS.json_prediction_files_pattern)
82 |
83 | file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern)
84 | logging.info("Found files: %s", file_paths)
85 |
86 | logging.info("Writing submission file to: %s", FLAGS.csv_output_file)
87 | with gfile.Open(FLAGS.csv_output_file, "w+") as output_file:
88 | output_file.write(get_csv_header())
89 |
90 | for file_path in file_paths:
91 | logging.info("processing file: %s", file_path)
92 |
93 | with gfile.Open(file_path) as input_file:
94 |
95 | for line in input_file:
96 | json_data = json.loads(line)
97 | output_file.write(to_csv_row(json_data))
98 |
99 | output_file.flush()
100 | logging.info("done")
101 |
102 |
103 | if __name__ == "__main__":
104 | app.run()
105 |
--------------------------------------------------------------------------------
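
To make the conversion concrete, here is a hypothetical prediction record and the CSV row `to_csv_row` produces for it; the values are made up, and the one-element list wrapping reflects the `isinstance(video_id, list)` branch above.

```python
# Hypothetical record, shaped like one line of a prediction.results-* file.
record = {
    "video_id": [b"AB12"],
    "class_indexes": [[0, 5]],
    "predictions": [[0.98, 0.42]],
}
# to_csv_row(record) returns:
#   "AB12,0 0.980000 5 0.420000\n"
# which is written under the header "VideoId,LabelConfidencePairs".
```
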
/model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Contains a collection of util functions for model construction."""
15 | import numpy
16 | import tensorflow as tf
17 | from tensorflow import logging
18 | from tensorflow import flags
19 | import tensorflow.contrib.slim as slim
20 |
21 |
22 | def SampleRandomSequence(model_input, num_frames, num_samples):
23 | """Samples a random sequence of frames of size num_samples.
24 |
25 | Args:
26 | model_input: A tensor of size batch_size x max_frames x feature_size
27 | num_frames: A tensor of size batch_size x 1
28 | num_samples: A scalar
29 |
30 | Returns:
31 | `model_input`: A tensor of size batch_size x num_samples x feature_size
32 | """
33 |
34 | batch_size = tf.shape(model_input)[0]
35 | frame_index_offset = tf.tile(tf.expand_dims(tf.range(num_samples), 0),
36 | [batch_size, 1])
37 | max_start_frame_index = tf.maximum(num_frames - num_samples, 0)
38 | start_frame_index = tf.cast(
39 | tf.multiply(tf.random_uniform([batch_size, 1]),
40 | tf.cast(max_start_frame_index + 1, tf.float32)), tf.int32)
41 | frame_index = tf.minimum(start_frame_index + frame_index_offset,
42 | tf.cast(num_frames - 1, tf.int32))
43 | batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1),
44 | [1, num_samples])
45 | index = tf.stack([batch_index, frame_index], 2)
46 | return tf.gather_nd(model_input, index)
47 |
48 |
49 | def SampleRandomFrames(model_input, num_frames, num_samples):
50 | """Samples a random set of frames of size num_samples.
51 |
52 | Args:
53 | model_input: A tensor of size batch_size x max_frames x feature_size
54 | num_frames: A tensor of size batch_size x 1
55 | num_samples: A scalar
56 |
57 | Returns:
58 | `model_input`: A tensor of size batch_size x num_samples x feature_size
59 | """
60 | batch_size = tf.shape(model_input)[0]
61 | frame_index = tf.cast(
62 | tf.multiply(tf.random_uniform([batch_size, num_samples]),
63 | tf.tile(tf.cast(num_frames, tf.float32), [1, num_samples])),
64 | tf.int32)
65 | batch_index = tf.tile(tf.expand_dims(tf.range(batch_size), 1),
66 | [1, num_samples])
67 | index = tf.stack([batch_index, frame_index], 2)
68 | return tf.gather_nd(model_input, index)
69 |
70 |
71 | def FramePooling(frames, method, **unused_params):
72 | """Pools over the frames of a video.
73 |
74 | Args:
75 | frames: A tensor with shape [batch_size, num_frames, feature_size].
76 |     method: "average", "max", or "none".
77 |
78 |   Returns:
79 |     A tensor with shape [batch_size, feature_size] for average or max
80 |     pooling. A tensor with shape [batch_size*num_frames, feature_size] for
81 |     none pooling.
82 |
83 |   Raises:
84 |     ValueError: if method is other than "average", "max", or
85 |     "none".
86 | """
87 | if method == "average":
88 | return tf.reduce_mean(frames, 1)
89 | elif method == "max":
90 | return tf.reduce_max(frames, 1)
91 | elif method == "none":
92 | feature_size = frames.shape_as_list()[2]
93 | return tf.reshape(frames, [-1, feature_size])
94 | else:
95 | raise ValueError("Unrecognized pooling method: %s" % method)
96 |
--------------------------------------------------------------------------------
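
A small usage sketch for the samplers above; the shapes are illustrative, and both helpers return a fixed number of frames per video regardless of each video's true length.

```python
import tensorflow as tf

import model_utils

# Illustrative shapes: 2 videos, up to 300 frames, 1152-D features per frame.
model_input = tf.random_uniform([2, 300, 1152])
num_frames = tf.constant([[250], [37]], dtype=tf.int32)

sampled = model_utils.SampleRandomFrames(model_input, num_frames, 30)
sequence = model_utils.SampleRandomSequence(model_input, num_frames, 30)
# Both tensors have shape [2, 30, 1152].
```
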
/feature_extractor/README.md:
--------------------------------------------------------------------------------
1 | # Please use the [MediaPipe YouTube8M feature extractor](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/youtube8m) which extracts both RGB and audio features instead.
2 |
3 | ---
4 |
5 | # YouTube8M Feature Extractor (DEPRECATED)
6 |
7 | This directory contains binary and library code that can extract YouTube8M
8 | features from images and videos. The code requires the Inception TensorFlow
9 | model ([tutorial](https://www.tensorflow.org/tutorials/image_recognition)) and
10 | our PCA matrix, as outlined in Section 3.3 of our
11 | [paper](https://arxiv.org/abs/1609.08675). The first time you use our code, it
12 | will **automatically** download the inception model (75 Megabytes, tensorflow
13 | [GraphDef proto](https://www.tensorflow.org/api_docs/python/tf/GraphDef),
14 | [download link](http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz))
15 | and the PCA matrix (25 Megabytes, Numpy arrays,
16 | [download link](http://data.yt8m.org/yt8m_pca.tgz)).
17 |
18 | ## Usage
19 |
20 | There are two ways to use this code:
21 |
22 | 1. Binary `extract_tfrecords_main.py` processes a CSV file of videos (and their
23 |    labels) and outputs a `tfrecord` file. Files created with this binary match
24 |    the schema of the YouTube-8M dataset files, and are therefore compatible
25 |    with our training starter code. You can also use the file for inference
26 |    with models that are pre-trained on YouTube-8M.
27 | 1. Library `feature_extractor.py` which can extract features from images.
28 |
29 | ### Using the Binary to create `tfrecords` from videos
30 |
31 | You can use binary `extract_tfrecords_main.py` to create `tfrecord` files.
32 | However, this binary assumes that you have OpenCV properly installed (see end of
33 | subsection). Assume that you have two videos `/path/to/vid1` and
34 | `/path/to/vid2`, respectively, with multi-integer labels of `(52, 3, 10)` and
35 | `(7, 67)`. To create `tfrecord` containing features and labels for those videos,
36 | you must first create a CSV file (e.g. on `/path/to/vid_dataset.csv`) with
37 | contents:
38 |
39 | /path/to/vid1,52;3;10
40 | /path/to/vid2,7;67
41 |
42 | Note that the CSV is comma-separated but the label-field is semi-colon separated
43 | to allow for multiple labels per video.
44 |
45 | Then, you can create the `tfrecord` by calling the binary:
46 |
47 | python extract_tfrecords_main.py --input_videos_csv /path/to/vid_dataset.csv \
48 | --output_tfrecords_file /path/to/output.tfrecord
49 |
50 | Now, you can use the output file for training and/or inference using our starter
51 | code.
52 |
53 | `extract_tfrecords_main.py` requires OpenCV python bindings to be installed and
54 | linked with ffmpeg. In other words, running this command should print `True`:
55 |
56 |     python -c 'import cv2; print(cv2.VideoCapture().open("/path/to/some/video.mp4"))'
57 |
58 | ### Using the library to extract features from images
59 |
60 | To extract our features from an image file `cropped_panda.jpg`, you can use this
61 | python code:
62 |
63 | ```python
64 | from PIL import Image
65 | import numpy
66 |
67 | # Instantiate extractor. Slow if called first time on your machine, as it
68 | # needs to download 100 MB.
69 | extractor = YouTube8MFeatureExtractor()
70 |
71 | image_file = os.path.join(extractor._model_dir, 'cropped_panda.jpg')
72 |
73 | im = numpy.array(Image.open(image_file))
74 | features = extractor.extract_rgb_frame_features(im)
75 | ```
76 |
77 | The constructor `extractor = YouTube8MFeatureExtractor()` will create a
78 | directory `~/yt8m/`, if it does not exist, and will download and untar the two
79 | model files (inception and PCA matrix). If you prefer, you can point our
80 | extractor to another directory as:
81 |
82 | ```python
83 | extractor = YouTube8MFeatureExtractor(model_dir="/path/to/yt8m_files")
84 | ```
85 |
86 | You can also pre-populate your custom `"/path/to/yt8m_files"` by manually
87 | downloading (e.g. using `wget`) the URLs and un-tarring them, for example:
88 |
89 | ```bash
90 | mkdir -p /path/to/yt8m_files
91 | cd /path/to/yt8m_files
92 |
93 | wget http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
94 | wget http://data.yt8m.org/yt8m_pca.tgz
95 |
96 | tar zxvf inception-2015-12-05.tgz
97 | tar zxvf yt8m_pca.tgz
98 | ```
99 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Provides definitions for non-regularized training or test losses."""
15 |
16 | import tensorflow as tf
17 |
18 |
19 | class BaseLoss(object):
20 | """Inherit from this class when implementing new losses."""
21 |
22 | def calculate_loss(self, unused_predictions, unused_labels, **unused_params):
23 | """Calculates the average loss of the examples in a mini-batch.
24 |
25 | Args:
26 | unused_predictions: a 2-d tensor storing the prediction scores, in which
27 | each row represents a sample in the mini-batch and each column
28 | represents a class.
29 | unused_labels: a 2-d tensor storing the labels, which has the same shape
30 | as the unused_predictions. The labels must be in the range of 0 and 1.
31 | unused_params: loss specific parameters.
32 |
33 | Returns:
34 | A scalar loss tensor.
35 | """
36 | raise NotImplementedError()
37 |
38 |
39 | class CrossEntropyLoss(BaseLoss):
40 | """Calculate the cross entropy loss between the predictions and labels."""
41 |
42 | def calculate_loss(self,
43 | predictions,
44 | labels,
45 | label_weights=None,
46 | **unused_params):
47 | with tf.name_scope("loss_xent"):
48 | epsilon = 1e-5
49 | float_labels = tf.cast(labels, tf.float32)
50 | cross_entropy_loss = float_labels * tf.math.log(predictions + epsilon) + (
51 | 1 - float_labels) * tf.math.log(1 - predictions + epsilon)
52 | cross_entropy_loss = tf.negative(cross_entropy_loss)
53 | if label_weights is not None:
54 | cross_entropy_loss *= label_weights
55 | return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1))
56 |
57 |
58 | class HingeLoss(BaseLoss):
59 | """Calculate the hinge loss between the predictions and labels.
60 |
61 | Note the subgradient is used in the backpropagation, and thus the optimization
62 |   may converge more slowly. The predictions trained by the hinge loss are
63 |   between -1 and +1.
64 | """
65 |
66 | def calculate_loss(self, predictions, labels, b=1.0, **unused_params):
67 | with tf.name_scope("loss_hinge"):
68 | float_labels = tf.cast(labels, tf.float32)
69 | all_zeros = tf.zeros(tf.shape(float_labels), dtype=tf.float32)
70 | all_ones = tf.ones(tf.shape(float_labels), dtype=tf.float32)
71 | sign_labels = tf.subtract(tf.scalar_mul(2, float_labels), all_ones)
72 | hinge_loss = tf.maximum(
73 | all_zeros,
74 | tf.scalar_mul(b, all_ones) - sign_labels * predictions)
75 | return tf.reduce_mean(tf.reduce_sum(hinge_loss, 1))
76 |
77 |
78 | class SoftmaxLoss(BaseLoss):
79 | """Calculate the softmax loss between the predictions and labels.
80 |
81 | The function calculates the loss in the following way: first we feed the
82 | predictions to the softmax activation function and then we calculate
83 |   the negative dot product between the log of the softmax activations and the
84 |   normalized ground truth label.
85 |
86 |   It is an extension of the one-hot label, allowing more than one positive
87 |   label for each sample.
88 | """
89 |
90 | def calculate_loss(self, predictions, labels, **unused_params):
91 | with tf.name_scope("loss_softmax"):
92 | epsilon = 10e-8
93 | float_labels = tf.cast(labels, tf.float32)
94 | # l1 normalization (labels are no less than 0)
95 | label_rowsum = tf.maximum(tf.reduce_sum(float_labels, 1, keep_dims=True),
96 | epsilon)
97 | norm_float_labels = tf.div(float_labels, label_rowsum)
98 | softmax_outputs = tf.nn.softmax(predictions)
99 | softmax_loss = tf.negative(
100 | tf.reduce_sum(tf.multiply(norm_float_labels, tf.log(softmax_outputs)),
101 | 1))
102 | return tf.reduce_mean(softmax_loss)
103 |
--------------------------------------------------------------------------------
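
A quick numeric check of the `CrossEntropyLoss` formula above, done with NumPy on toy values for a single example with two classes:

```python
import numpy as np

predictions = np.array([[0.9, 0.2]])  # sigmoid outputs: one example, two classes
labels = np.array([[1.0, 0.0]])
epsilon = 1e-5

cross_entropy = -(labels * np.log(predictions + epsilon) +
                  (1 - labels) * np.log(1 - predictions + epsilon))
loss = cross_entropy.sum(axis=1).mean()
# cross_entropy is roughly [[0.105, 0.223]], so loss is about 0.328, matching
# tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1)) in CrossEntropyLoss.
```
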
/video_level_models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Contains model definitions."""
15 | import math
16 |
17 | import models
18 | import tensorflow as tf
19 | import utils
20 |
21 | from tensorflow import flags
22 | import tensorflow.contrib.slim as slim
23 |
24 | FLAGS = flags.FLAGS
25 | flags.DEFINE_integer(
26 | "moe_num_mixtures", 2,
27 | "The number of mixtures (excluding the dummy 'expert') used for MoeModel.")
28 |
29 |
30 | class LogisticModel(models.BaseModel):
31 | """Logistic model with L2 regularization."""
32 |
33 | def create_model(self,
34 | model_input,
35 | vocab_size,
36 | l2_penalty=1e-8,
37 | **unused_params):
38 | """Creates a logistic model.
39 |
40 | Args:
41 | model_input: 'batch' x 'num_features' matrix of input features.
42 | vocab_size: The number of classes in the dataset.
43 |
44 | Returns:
45 | A dictionary with a tensor containing the probability predictions of the
46 | model in the 'predictions' key. The dimensions of the tensor are
47 | batch_size x num_classes.
48 | """
49 | output = slim.fully_connected(
50 | model_input,
51 | vocab_size,
52 | activation_fn=tf.nn.sigmoid,
53 | weights_regularizer=slim.l2_regularizer(l2_penalty))
54 | return {"predictions": output}
55 |
56 |
57 | class MoeModel(models.BaseModel):
58 | """A softmax over a mixture of logistic models (with L2 regularization)."""
59 |
60 | def create_model(self,
61 | model_input,
62 | vocab_size,
63 | num_mixtures=None,
64 | l2_penalty=1e-8,
65 | **unused_params):
66 | """Creates a Mixture of (Logistic) Experts model.
67 |
68 | The model consists of a per-class softmax distribution over a
69 | configurable number of logistic classifiers. One of the classifiers in the
70 | mixture is not trained, and always predicts 0.
71 |
72 | Args:
73 | model_input: 'batch_size' x 'num_features' matrix of input features.
74 | vocab_size: The number of classes in the dataset.
75 | num_mixtures: The number of mixtures (excluding a dummy 'expert' that
76 | always predicts the non-existence of an entity).
77 | l2_penalty: How much to penalize the squared magnitudes of parameter
78 | values.
79 |
80 | Returns:
81 | A dictionary with a tensor containing the probability predictions of the
82 | model in the 'predictions' key. The dimensions of the tensor are
83 | batch_size x num_classes.
84 | """
85 | num_mixtures = num_mixtures or FLAGS.moe_num_mixtures
86 |
87 | gate_activations = slim.fully_connected(
88 | model_input,
89 | vocab_size * (num_mixtures + 1),
90 | activation_fn=None,
91 | biases_initializer=None,
92 | weights_regularizer=slim.l2_regularizer(l2_penalty),
93 | scope="gates")
94 | expert_activations = slim.fully_connected(
95 | model_input,
96 | vocab_size * num_mixtures,
97 | activation_fn=None,
98 | weights_regularizer=slim.l2_regularizer(l2_penalty),
99 | scope="experts")
100 |
101 | gating_distribution = tf.nn.softmax(
102 | tf.reshape(
103 | gate_activations,
104 | [-1, num_mixtures + 1])) # (Batch * #Labels) x (num_mixtures + 1)
105 | expert_distribution = tf.nn.sigmoid(
106 | tf.reshape(expert_activations,
107 | [-1, num_mixtures])) # (Batch * #Labels) x num_mixtures
108 |
109 | final_probabilities_by_class_and_batch = tf.reduce_sum(
110 | gating_distribution[:, :num_mixtures] * expert_distribution, 1)
111 | final_probabilities = tf.reshape(final_probabilities_by_class_and_batch,
112 | [-1, vocab_size])
113 | return {"predictions": final_probabilities}
114 |
--------------------------------------------------------------------------------
/mean_average_precision_calculator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Calculate the mean average precision.
15 |
16 | It provides an interface for calculating mean average precision
17 | for an entire list or the top-n ranked items.
18 |
19 | Example usages:
20 | We first call the function accumulate many times to process parts of the ranked
21 | list. After processing all the parts, we call peek_map_at_n
22 | to calculate the mean average precision.
23 |
24 | ```
25 | import random
26 | import numpy as np
27 |
28 | p = np.array([[random.random() for _ in range(50)] for _ in range(1000)])
29 | a = np.array([[random.choice([0, 1]) for _ in range(50)] for _ in range(1000)])
30 |
31 | # mean average precision for 50 classes.
32 | calculator = mean_average_precision_calculator.MeanAveragePrecisionCalculator(
33 | num_class=50)
34 | calculator.accumulate(p, a)
35 | aps = calculator.peek_map_at_n()
36 | ```
37 | """
38 |
39 | import average_precision_calculator
40 |
41 |
42 | class MeanAveragePrecisionCalculator(object):
43 | """This class is to calculate mean average precision."""
44 |
45 | def __init__(self, num_class, filter_empty_classes=True, top_n=None):
46 | """Construct a calculator to calculate the (macro) average precision.
47 |
48 | Args:
49 | num_class: A positive Integer specifying the number of classes.
50 | filter_empty_classes: whether to filter classes without any positives.
51 | top_n: A positive Integer specifying the average precision at n, or None
52 | to use all provided data points.
53 |
54 | Raises:
55 | ValueError: An error occurred when num_class is not a positive integer;
56 |         or top_n is neither None nor a positive integer.
57 | """
58 | if not isinstance(num_class, int) or num_class <= 1:
59 | raise ValueError("num_class must be a positive integer.")
60 |
61 | self._ap_calculators = [] # member of AveragePrecisionCalculator
62 | self._num_class = num_class # total number of classes
63 | self._filter_empty_classes = filter_empty_classes
64 | for _ in range(num_class):
65 | self._ap_calculators.append(
66 | average_precision_calculator.AveragePrecisionCalculator(top_n=top_n))
67 |
68 | def accumulate(self, predictions, actuals, num_positives=None):
69 | """Accumulate the predictions and their ground truth labels.
70 |
71 | Args:
72 | predictions: A list of lists storing the prediction scores. The outer
73 | dimension corresponds to classes.
74 | actuals: A list of lists storing the ground truth labels. The dimensions
75 | should correspond to the predictions input. Any value larger than 0 will
76 | be treated as positives, otherwise as negatives.
77 | num_positives: If provided, it is a list of numbers representing the
78 | number of true positives for each class. If not provided, the number of
79 | true positives will be inferred from the 'actuals' array.
80 |
81 | Raises:
82 | ValueError: An error occurred when the shape of predictions and actuals
83 | does not match.
84 | """
85 | if not num_positives:
86 | num_positives = [None for i in range(self._num_class)]
87 |
88 | calculators = self._ap_calculators
89 | for i in range(self._num_class):
90 | calculators[i].accumulate(predictions[i], actuals[i], num_positives[i])
91 |
92 | def clear(self):
93 | for calculator in self._ap_calculators:
94 | calculator.clear()
95 |
96 | def is_empty(self):
97 | return ([calculator.heap_size for calculator in self._ap_calculators
98 | ] == [0 for _ in range(self._num_class)])
99 |
100 | def peek_map_at_n(self):
101 | """Peek the non-interpolated mean average precision at n.
102 |
103 | Returns:
104 | An array of non-interpolated average precision at n (default 0) for each
105 | class.
106 | """
107 | aps = []
108 | for i in range(self._num_class):
109 | if (not self._filter_empty_classes or
110 | self._ap_calculators[i].num_accumulated_positives > 0):
111 | ap = self._ap_calculators[i].peek_ap_at_n()
112 | aps.append(ap)
113 | return aps
114 |
--------------------------------------------------------------------------------
/export_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities to export a model for batch prediction."""
15 |
16 | import tensorflow as tf
17 | import tensorflow.contrib.slim as slim
18 |
19 | from tensorflow.python.saved_model import builder as saved_model_builder
20 | from tensorflow.python.saved_model import signature_constants
21 | from tensorflow.python.saved_model import signature_def_utils
22 | from tensorflow.python.saved_model import tag_constants
23 | from tensorflow.python.saved_model import utils as saved_model_utils
24 |
25 | _TOP_PREDICTIONS_IN_OUTPUT = 20
26 |
27 |
28 | class ModelExporter(object):
29 |
30 | def __init__(self, frame_features, model, reader):
31 | self.frame_features = frame_features
32 | self.model = model
33 | self.reader = reader
34 |
35 | with tf.Graph().as_default() as graph:
36 | self.inputs, self.outputs = self.build_inputs_and_outputs()
37 | self.graph = graph
38 | self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True)
39 |
40 | def export_model(self, model_dir, global_step_val, last_checkpoint):
41 |     """Exports the model so that it can be used for batch predictions."""
42 |
43 | with self.graph.as_default():
44 | with tf.Session() as session:
45 | session.run(tf.global_variables_initializer())
46 | self.saver.restore(session, last_checkpoint)
47 |
48 | signature = signature_def_utils.build_signature_def(
49 | inputs=self.inputs,
50 | outputs=self.outputs,
51 | method_name=signature_constants.PREDICT_METHOD_NAME)
52 |
53 | signature_map = {
54 | signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
55 | }
56 |
57 | model_builder = saved_model_builder.SavedModelBuilder(model_dir)
58 | model_builder.add_meta_graph_and_variables(
59 | session,
60 | tags=[tag_constants.SERVING],
61 | signature_def_map=signature_map,
62 | clear_devices=True)
63 | model_builder.save()
64 |
65 | def build_inputs_and_outputs(self):
66 | if self.frame_features:
67 | serialized_examples = tf.placeholder(tf.string, shape=(None,))
68 |
69 | fn = lambda x: self.build_prediction_graph(x)
70 | video_id_output, top_indices_output, top_predictions_output = (tf.map_fn(
71 | fn, serialized_examples, dtype=(tf.string, tf.int32, tf.float32)))
72 |
73 | else:
74 | serialized_examples = tf.placeholder(tf.string, shape=(None,))
75 |
76 | video_id_output, top_indices_output, top_predictions_output = (
77 | self.build_prediction_graph(serialized_examples))
78 |
79 | inputs = {
80 | "example_bytes":
81 | saved_model_utils.build_tensor_info(serialized_examples)
82 | }
83 |
84 | outputs = {
85 | "video_id":
86 | saved_model_utils.build_tensor_info(video_id_output),
87 | "class_indexes":
88 | saved_model_utils.build_tensor_info(top_indices_output),
89 | "predictions":
90 | saved_model_utils.build_tensor_info(top_predictions_output)
91 | }
92 |
93 | return inputs, outputs
94 |
95 | def build_prediction_graph(self, serialized_examples):
96 | input_data_dict = (
97 | self.reader.prepare_serialized_examples(serialized_examples))
98 | video_id = input_data_dict["video_ids"]
99 | model_input_raw = input_data_dict["video_matrix"]
100 | labels_batch = input_data_dict["labels"]
101 | num_frames = input_data_dict["num_frames"]
102 |
103 | feature_dim = len(model_input_raw.get_shape()) - 1
104 | model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
105 |
106 | with tf.variable_scope("tower"):
107 | result = self.model.create_model(model_input,
108 | num_frames=num_frames,
109 | vocab_size=self.reader.num_classes,
110 | labels=labels_batch,
111 | is_training=False)
112 |
113 | for variable in slim.get_model_variables():
114 | tf.summary.histogram(variable.op.name, variable)
115 |
116 | predictions = result["predictions"]
117 |
118 | top_predictions, top_indices = tf.nn.top_k(predictions,
119 | _TOP_PREDICTIONS_IN_OUTPUT)
120 | return video_id, top_indices, top_predictions
121 |
--------------------------------------------------------------------------------
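
`ModelExporter` is driven by the training loop (in `train.py`, which is not shown here). A hedged sketch of how it might be constructed and invoked follows; the reader and model class names are assumptions, since `readers.py` and `frame_level_models.py` are not reproduced in this listing.

```python
# Hypothetical wiring; class names and paths are illustrative only.
import export_model
import frame_level_models
import readers

exporter = export_model.ModelExporter(
    frame_features=True,
    model=frame_level_models.FrameLevelLogisticModel(),
    reader=readers.YT8MFrameFeatureReader())
exporter.export_model(model_dir="/tmp/yt8m_model/export/step_100000",
                      global_step_val=100000,
                      last_checkpoint="/tmp/yt8m_model/model.ckpt-100000")
```
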
/feature_extractor/feature_extractor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Facilitates extracting YouTube8M features from RGB images."""
15 |
16 | import os
17 | import sys
18 | import tarfile
19 | import numpy
20 | from six.moves import urllib
21 | import tensorflow as tf
22 |
23 | INCEPTION_TF_GRAPH = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
24 | YT8M_PCA_MAT = 'http://data.yt8m.org/yt8m_pca.tgz'
25 | MODEL_DIR = os.path.join(os.getenv('HOME'), 'yt8m')
26 |
27 |
28 | class YouTube8MFeatureExtractor(object):
29 | """Extracts YouTube8M features for RGB frames.
30 |
31 |   Constructing this class for the first time will create the directory `yt8m`
32 |   inside your home directory and download the Inception model (85 MB) and the
33 |   YouTube8M PCA matrix (15 MB). If you want to use another directory, pass it
34 |   to the constructor's `model_dir` argument.
35 |
36 |   If `model_dir` exists and contains the necessary files, they will be re-used
37 |   without downloading.
38 |
39 | Usage Example:
40 |
41 | from PIL import Image
42 | import numpy
43 |
44 | # Instantiate extractor. Slow if called first time on your machine, as it
45 | # needs to download 100 MB.
46 | extractor = YouTube8MFeatureExtractor()
47 |
48 | image_file = os.path.join(extractor._model_dir, 'cropped_panda.jpg')
49 |
50 | im = numpy.array(Image.open(image_file))
51 | features = extractor.extract_rgb_frame_features(im)
52 |
53 | ** Note: OpenCV reverses the order of channels (i.e. orders channels as BGR
54 | instead of RGB). If you are using OpenCV, then you must do:
55 |
56 | im = im[:, :, ::-1] # Reverses order on last (i.e. channel) dimension.
57 |
58 | then call `extractor.extract_rgb_frame_features(im)`
59 | """
60 |
61 | def __init__(self, model_dir=MODEL_DIR):
62 | # Create MODEL_DIR if not created.
63 | self._model_dir = model_dir
64 | if not os.path.exists(model_dir):
65 | os.makedirs(model_dir)
66 |
67 | # Load PCA Matrix.
68 | download_path = self._maybe_download(YT8M_PCA_MAT)
69 | pca_mean = os.path.join(self._model_dir, 'mean.npy')
70 | if not os.path.exists(pca_mean):
71 | tarfile.open(download_path, 'r:gz').extractall(model_dir)
72 | self._load_pca()
73 |
74 | # Load Inception Network
75 | download_path = self._maybe_download(INCEPTION_TF_GRAPH)
76 | inception_proto_file = os.path.join(self._model_dir,
77 | 'classify_image_graph_def.pb')
78 | if not os.path.exists(inception_proto_file):
79 | tarfile.open(download_path, 'r:gz').extractall(model_dir)
80 | self._load_inception(inception_proto_file)
81 |
82 | def extract_rgb_frame_features(self, frame_rgb, apply_pca=True):
83 | """Applies the YouTube8M feature extraction over an RGB frame.
84 |
85 | This passes `frame_rgb` to inception3 model, extracting hidden layer
86 | activations and passing it to the YouTube8M PCA transformation.
87 |
88 | Args:
89 | frame_rgb: numpy array of uint8 with shape (height, width, channels) where
90 |         channels must be 3 (RGB), and height and width can be anything, as the
91 | inception model will resize.
92 | apply_pca: If not set, PCA transformation will be skipped.
93 |
94 | Returns:
95 | Output of inception from `frame_rgb` (2048-D) and optionally passed into
96 | YouTube8M PCA transformation (1024-D).
97 | """
98 | assert len(frame_rgb.shape) == 3
99 | assert frame_rgb.shape[2] == 3 # 3 channels (R, G, B)
100 | with self._inception_graph.as_default():
101 | if apply_pca:
102 | frame_features = self.session.run(
103 | 'pca_final_feature:0', feed_dict={'DecodeJpeg:0': frame_rgb})
104 | else:
105 | frame_features = self.session.run(
106 | 'pool_3/_reshape:0', feed_dict={'DecodeJpeg:0': frame_rgb})
107 | frame_features = frame_features[0]
108 | return frame_features
109 |
110 | def apply_pca(self, frame_features):
111 | """Applies the YouTube8M PCA Transformation over `frame_features`.
112 |
113 | Args:
114 | frame_features: numpy array of floats, 2048 dimensional vector.
115 |
116 | Returns:
117 | 1024 dimensional vector as a numpy array.
118 | """
119 | # Subtract mean
120 | feats = frame_features - self.pca_mean
121 |
122 | # Multiply by eigenvectors.
123 | feats = feats.reshape((1, 2048)).dot(self.pca_eigenvecs).reshape((1024,))
124 |
125 | # Whiten
126 | feats /= numpy.sqrt(self.pca_eigenvals + 1e-4)
127 | return feats
128 |
129 | def _maybe_download(self, url):
130 | """Downloads `url` if not in `_model_dir`."""
131 | filename = os.path.basename(url)
132 | download_path = os.path.join(self._model_dir, filename)
133 | if os.path.exists(download_path):
134 | return download_path
135 |
136 | def _progress(count, block_size, total_size):
137 | sys.stdout.write(
138 | '\r>> Downloading %s %.1f%%' %
139 | (filename, float(count * block_size) / float(total_size) * 100.0))
140 | sys.stdout.flush()
141 |
142 | urllib.request.urlretrieve(url, download_path, _progress)
143 | statinfo = os.stat(download_path)
144 |     print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
145 | return download_path
146 |
147 | def _load_inception(self, proto_file):
148 | graph_def = tf.GraphDef.FromString(open(proto_file, 'rb').read())
149 | self._inception_graph = tf.Graph()
150 | with self._inception_graph.as_default():
151 | _ = tf.import_graph_def(graph_def, name='')
152 | self.session = tf.Session()
153 | Frame_Features = self.session.graph.get_tensor_by_name(
154 | 'pool_3/_reshape:0')
155 | Pca_Mean = tf.constant(value=self.pca_mean, dtype=tf.float32)
156 | Pca_Eigenvecs = tf.constant(value=self.pca_eigenvecs, dtype=tf.float32)
157 | Pca_Eigenvals = tf.constant(value=self.pca_eigenvals, dtype=tf.float32)
158 | Feats = Frame_Features[0] - Pca_Mean
159 | Feats = tf.reshape(
160 | tf.matmul(tf.reshape(Feats, [1, 2048]), Pca_Eigenvecs), [
161 | 1024,
162 | ])
163 | tf.divide(Feats, tf.sqrt(Pca_Eigenvals + 1e-4), name='pca_final_feature')
164 |
165 | def _load_pca(self):
166 | self.pca_mean = numpy.load(os.path.join(self._model_dir, 'mean.npy'))[:, 0]
167 | self.pca_eigenvals = numpy.load(
168 | os.path.join(self._model_dir, 'eigenvals.npy'))[:1024, 0]
169 | self.pca_eigenvecs = numpy.load(
170 | os.path.join(self._model_dir, 'eigenvecs.npy')).T[:, :1024]
171 |
--------------------------------------------------------------------------------
/segment_eval_inference.py:
--------------------------------------------------------------------------------
1 | """Eval mAP@N metric from inference file."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | from absl import app
8 | from absl import flags
9 |
10 | import mean_average_precision_calculator as map_calculator
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | flags.DEFINE_string(
15 | "eval_data_pattern", "",
16 | "File glob defining the evaluation dataset in tensorflow.SequenceExample "
17 | "format. The SequenceExamples are expected to have an 'rgb' byte array "
18 | "sequence feature as well as a 'labels' int64 context feature.")
19 | flags.DEFINE_string(
20 | "label_cache", "",
21 |     "The path for the label cache file. Leave blank to disable caching.")
22 | flags.DEFINE_string("submission_file", "",
23 | "The segment submission file generated by inference.py.")
24 | flags.DEFINE_integer(
25 | "top_n", 0,
26 |     "Cap per-class predictions at a maximum of N. Use 0 for no capping.")
27 |
28 | FLAGS = flags.FLAGS
29 |
30 |
31 | class Labels(object):
32 |   """Holds the ground truth labels for segments.
33 |
34 |   This class can serialize and de-serialize the ground truth.
35 | The ground truth is in a mapping from (segment_id, class_id) -> label_score.
36 | """
37 |
38 | def __init__(self, labels):
39 | """__init__ method."""
40 | self._labels = labels
41 |
42 | @property
43 | def labels(self):
44 | """Return the ground truth mapping. See class docstring for details."""
45 | return self._labels
46 |
47 | def to_file(self, file_name):
48 | """Materialize the GT mapping to file."""
49 | with tf.gfile.Open(file_name, "w") as fobj:
50 | for k, v in self._labels.items():
51 | seg_id, label = k
52 | line = "%s,%s,%s\n" % (seg_id, label, v)
53 | fobj.write(line)
54 |
55 | @classmethod
56 | def from_file(cls, file_name):
57 | """Read the GT mapping from cached file."""
58 | labels = {}
59 | with tf.gfile.Open(file_name) as fobj:
60 | for line in fobj:
61 | line = line.strip().strip("\n")
62 | seg_id, label, score = line.split(",")
63 | labels[(seg_id, int(label))] = float(score)
64 | return cls(labels)
65 |
66 |
67 | def read_labels(data_pattern, cache_path=""):
68 | """Read labels from TFRecords.
69 |
70 | Args:
71 | data_pattern: the data pattern to the TFRecords.
72 | cache_path: the cache path for the label file.
73 |
74 | Returns:
75 | a Labels object.
76 | """
77 | if cache_path:
78 | if tf.gfile.Exists(cache_path):
79 | tf.logging.info("Reading cached labels from %s..." % cache_path)
80 | return Labels.from_file(cache_path)
81 | tf.enable_eager_execution()
82 | data_paths = tf.gfile.Glob(data_pattern)
83 | ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50)
84 | context_features = {
85 | "id": tf.FixedLenFeature([], tf.string),
86 | "segment_labels": tf.VarLenFeature(tf.int64),
87 | "segment_start_times": tf.VarLenFeature(tf.int64),
88 | "segment_scores": tf.VarLenFeature(tf.float32)
89 | }
90 |
91 | def _parse_se_func(sequence_example):
92 | return tf.parse_single_sequence_example(sequence_example,
93 | context_features=context_features)
94 |
95 | ds = ds.map(_parse_se_func)
96 | rated_labels = {}
97 | tf.logging.info("Reading labels from TFRecords...")
98 | last_batch = 0
99 | batch_size = 5000
100 | for cxt_feature_val, _ in ds:
101 | video_id = cxt_feature_val["id"].numpy()
102 | segment_labels = cxt_feature_val["segment_labels"].values.numpy()
103 | segment_start_times = cxt_feature_val["segment_start_times"].values.numpy()
104 | segment_scores = cxt_feature_val["segment_scores"].values.numpy()
105 | for label, start_time, score in zip(segment_labels, segment_start_times,
106 | segment_scores):
107 | rated_labels[("%s:%d" % (video_id, start_time), label)] = score
108 | batch_id = len(rated_labels) // batch_size
109 | if batch_id != last_batch:
110 | tf.logging.info("%d examples processed.", len(rated_labels))
111 | last_batch = batch_id
112 | tf.logging.info("Finish reading labels from TFRecords...")
113 | labels_obj = Labels(rated_labels)
114 | if cache_path:
115 | tf.logging.info("Caching labels to %s..." % cache_path)
116 | labels_obj.to_file(cache_path)
117 | return labels_obj
118 |
119 |
120 | def read_segment_predictions(file_path, labels, top_n=None):
121 |   """Read segment predictions.
122 |
123 | Args:
124 | file_path: the submission file path.
125 | labels: a Labels object containing the eval labels.
126 |     top_n: the per-class prediction capping.
127 |
128 | Returns:
129 |     a segment prediction list for each class.
130 | """
131 | cls_preds = {} # A label_id to pred list mapping.
132 | with tf.gfile.Open(file_path) as fobj:
133 | tf.logging.info("Reading predictions from %s..." % file_path)
134 | for line in fobj:
135 | label_id, pred_ids_val = line.split(",")
136 | pred_ids = pred_ids_val.split(" ")
137 | if top_n:
138 | pred_ids = pred_ids[:top_n]
139 | pred_ids = [
140 | pred_id for pred_id in pred_ids
141 | if (pred_id, int(label_id)) in labels.labels
142 | ]
143 | cls_preds[int(label_id)] = pred_ids
144 | if len(cls_preds) % 50 == 0:
145 | tf.logging.info("Processed %d classes..." % len(cls_preds))
146 | tf.logging.info("Finish reading predictions.")
147 | return cls_preds
148 |
149 |
150 | def main(unused_argv):
151 | """Entry function of the script."""
152 | if not FLAGS.submission_file:
153 | raise ValueError("You must input submission file.")
154 | eval_labels = read_labels(FLAGS.eval_data_pattern,
155 | cache_path=FLAGS.label_cache)
156 | tf.logging.info("Total rated segments: %d." % len(eval_labels.labels))
157 | positive_counter = {}
158 | for k, v in eval_labels.labels.items():
159 | _, label_id = k
160 | if v > 0:
161 | positive_counter[label_id] = positive_counter.get(label_id, 0) + 1
162 |
163 | seg_preds = read_segment_predictions(FLAGS.submission_file,
164 | eval_labels,
165 | top_n=FLAGS.top_n)
166 | map_cal = map_calculator.MeanAveragePrecisionCalculator(len(seg_preds))
167 | seg_labels = []
168 | seg_scored_preds = []
169 | num_positives = []
170 | for label_id in sorted(seg_preds):
171 | class_preds = seg_preds[label_id]
172 | seg_label = [eval_labels.labels[(pred, label_id)] for pred in class_preds]
173 | seg_labels.append(seg_label)
174 | seg_scored_pred = []
175 | if class_preds:
176 | seg_scored_pred = [
177 | float(x) / len(class_preds) for x in range(len(class_preds), 0, -1)
178 | ]
179 | seg_scored_preds.append(seg_scored_pred)
180 | num_positives.append(positive_counter[label_id])
181 | map_cal.accumulate(seg_scored_preds, seg_labels, num_positives)
182 | map_at_n = np.mean(map_cal.peek_map_at_n())
183 | tf.logging.info("Num classes: %d | mAP@%d: %.6f" %
184 | (len(seg_preds), FLAGS.top_n, map_at_n))
185 |
186 |
187 | if __name__ == "__main__":
188 | app.run(main)
189 |
--------------------------------------------------------------------------------
/feature_extractor/extract_tfrecords_main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Produces tfrecord files similar to the YouTube-8M dataset.
15 |
16 | It processes a CSV file containing lines like "<video_file>,<labels>", where
17 | <video_file> must be a path of a video, and <labels> must be an integer list
18 | joined with semi-colon ";". It processes all videos and writes a tfrecord file
19 | to --output_tfrecords_file.
20 |
21 | It assumes that you have OpenCV installed and properly linked with ffmpeg (i.e.
22 | function `cv2.VideoCapture().open('/path/to/some/video')` should return True).
23 |
24 | The binary only processes the video stream (images) and not the audio stream.
25 | """
26 |
27 | import csv
28 | import os
29 | import sys
30 |
31 | import cv2
32 | import feature_extractor
33 | import numpy
34 | import tensorflow as tf
35 | from tensorflow import app
36 | from tensorflow import flags
37 |
38 | FLAGS = flags.FLAGS
39 |
40 | # In OpenCV3.X, this is available as cv2.CAP_PROP_POS_MSEC
41 | # In OpenCV2.X, this is available as cv2.cv.CV_CAP_PROP_POS_MSEC
42 | CAP_PROP_POS_MSEC = 0
43 |
44 | if __name__ == '__main__':
45 | # Required flags for input and output.
46 | flags.DEFINE_string(
47 | 'output_tfrecords_file', None,
48 | 'File containing tfrecords will be written at this path.')
49 | flags.DEFINE_string(
50 | 'input_videos_csv', None,
51 |       'CSV file with lines "<video_file>,<labels>", where '
52 |       '<video_file> must be a path of a video and <labels> '
53 |       'must be an integer list joined with semi-colon ";"'
54 | # Optional flags.
55 | flags.DEFINE_string('model_dir', os.path.join(os.getenv('HOME'), 'yt8m'),
56 | 'Directory to store model files. It defaults to ~/yt8m')
57 |
58 | # The following flags are set to match the YouTube-8M dataset format.
59 | flags.DEFINE_integer('frames_per_second', 1,
60 | 'This many frames per second will be processed')
61 | flags.DEFINE_boolean(
62 | 'skip_frame_level_features', False,
63 | 'If set, frame-level features will not be written: only '
64 | 'video-level features will be written with feature '
65 | 'names mean_*')
66 | flags.DEFINE_string(
67 | 'labels_feature_key', 'labels',
68 | 'Labels will be written to context feature with this '
69 | 'key, as int64 list feature.')
70 | flags.DEFINE_string(
71 | 'image_feature_key', 'rgb',
72 | 'Image features will be written to sequence feature with '
73 | 'this key, as bytes list feature, with only one entry, '
74 | 'containing quantized feature string.')
75 | flags.DEFINE_string(
76 | 'video_file_feature_key', 'id',
77 | 'Input will be written to context feature '
78 | 'with this key, as bytes list feature, with only one '
79 | 'entry, containing the file path of the video. This '
80 | 'can be used for debugging but not for training or eval.')
81 | flags.DEFINE_boolean(
82 | 'insert_zero_audio_features', True,
83 | 'If set, inserts features with name "audio" to be 128-D '
84 | 'zero vectors. This allows you to use YouTube-8M '
85 | 'pre-trained model.')
86 |
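# Illustrative input CSV line (hypothetical path and labels), matching the
# format described in the module docstring: a video path, a comma, and
# semi-colon-joined integer label ids.
#
#   /tmp/example.mp4,52;1037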
87 |
88 | def frame_iterator(filename, every_ms=1000, max_num_frames=300):
89 | """Uses OpenCV to iterate over all frames of filename at a given frequency.
90 |
91 | Args:
92 | filename: Path to video file (e.g. mp4)
93 | every_ms: The duration (in milliseconds) to skip between frames.
94 | max_num_frames: Maximum number of frames to process, taken from the
95 | beginning of the video.
96 |
97 | Yields:
98 | RGB frame with shape (image height, image width, channels)
99 | """
100 | video_capture = cv2.VideoCapture()
101 | if not video_capture.open(filename):
102 |     sys.stderr.write('Error: Cannot open video file %s\n' % filename)
103 | return
104 | last_ts = -99999 # The timestamp of last retrieved frame.
105 | num_retrieved = 0
106 |
107 | while num_retrieved < max_num_frames:
108 | # Skip frames
109 | while video_capture.get(CAP_PROP_POS_MSEC) < every_ms + last_ts:
110 | if not video_capture.read()[0]:
111 | return
112 |
113 | last_ts = video_capture.get(CAP_PROP_POS_MSEC)
114 | has_frames, frame = video_capture.read()
115 | if not has_frames:
116 | break
117 | yield frame
118 | num_retrieved += 1
119 |
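# A minimal usage sketch (illustrative only; the path is hypothetical):
#
#   num_decoded = 0
#   for frame in frame_iterator('/tmp/example.mp4', every_ms=1000.0):
#     num_decoded += 1  # each frame is a height x width x channels array
#   print('Decoded %d frames' % num_decoded)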
120 |
121 | def _int64_list_feature(int64_list):
122 | return tf.train.Feature(int64_list=tf.train.Int64List(value=int64_list))
123 |
124 |
125 | def _bytes_feature(value):
126 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
127 |
128 |
129 | def _make_bytes(int_array):
130 | if bytes == str: # Python2
131 | return ''.join(map(chr, int_array))
132 | else:
133 | return bytes(int_array)
134 |
135 |
136 | def quantize(features, min_quantized_value=-2.0, max_quantized_value=2.0):
137 | """Quantizes float32 `features` into string."""
138 | assert features.dtype == 'float32'
139 | assert len(features.shape) == 1 # 1-D array
140 | features = numpy.clip(features, min_quantized_value, max_quantized_value)
141 | quantize_range = max_quantized_value - min_quantized_value
142 | features = (features - min_quantized_value) * (255.0 / quantize_range)
143 | features = [int(round(f)) for f in features]
144 |
145 | return _make_bytes(features)
146 |
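# Rough round-trip sketch (illustrative only; assumes the repository root is
# on sys.path so that utils.py is importable): quantize() maps float32
# features in [-2, 2] onto byte values 0..255, and utils.Dequantize()
# approximately inverts the mapping.
#
#   import numpy as np
#   import utils
#   feats = np.random.uniform(-2, 2, 1024).astype('float32')
#   recovered = utils.Dequantize(
#       np.frombuffer(quantize(feats), dtype=np.uint8).astype('float32'))
#   assert np.max(np.abs(recovered - feats)) < 2 * (4.0 / 255)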
147 |
148 | def main(unused_argv):
149 | extractor = feature_extractor.YouTube8MFeatureExtractor(FLAGS.model_dir)
150 | writer = tf.python_io.TFRecordWriter(FLAGS.output_tfrecords_file)
151 | total_written = 0
152 | total_error = 0
153 | for video_file, labels in csv.reader(open(FLAGS.input_videos_csv)):
154 | rgb_features = []
155 | sum_rgb_features = None
156 | for rgb in frame_iterator(
157 | video_file, every_ms=1000.0 / FLAGS.frames_per_second):
158 | features = extractor.extract_rgb_frame_features(rgb[:, :, ::-1])
159 | if sum_rgb_features is None:
160 | sum_rgb_features = features
161 | else:
162 | sum_rgb_features += features
163 | rgb_features.append(_bytes_feature(quantize(features)))
164 |
165 | if not rgb_features:
166 |       sys.stderr.write('Could not get features for %s\n' % video_file)
167 | total_error += 1
168 | continue
169 |
170 | mean_rgb_features = sum_rgb_features / len(rgb_features)
171 |
172 | # Create SequenceExample proto and write to output.
173 | feature_list = {
174 | FLAGS.image_feature_key: tf.train.FeatureList(feature=rgb_features),
175 | }
176 | context_features = {
177 | FLAGS.labels_feature_key:
178 | _int64_list_feature(sorted(map(int, labels.split(';')))),
179 | FLAGS.video_file_feature_key:
180 | _bytes_feature(_make_bytes(map(ord, video_file))),
181 | 'mean_' + FLAGS.image_feature_key:
182 | tf.train.Feature(
183 | float_list=tf.train.FloatList(value=mean_rgb_features)),
184 | }
185 |
186 | if FLAGS.insert_zero_audio_features:
187 | zero_vec = [0] * 128
188 | feature_list['audio'] = tf.train.FeatureList(
189 | feature=[_bytes_feature(_make_bytes(zero_vec))] * len(rgb_features))
190 | context_features['mean_audio'] = tf.train.Feature(
191 | float_list=tf.train.FloatList(value=zero_vec))
192 |
193 | if FLAGS.skip_frame_level_features:
194 | example = tf.train.SequenceExample(
195 | context=tf.train.Features(feature=context_features))
196 | else:
197 | example = tf.train.SequenceExample(
198 | context=tf.train.Features(feature=context_features),
199 | feature_lists=tf.train.FeatureLists(feature_list=feature_list))
200 | writer.write(example.SerializeToString())
201 | total_written += 1
202 |
203 | writer.close()
204 | print('Successfully encoded %i out of %i videos' %
205 | (total_written, total_written + total_error))
206 |
207 |
208 | if __name__ == '__main__':
209 | app.run(main)
210 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Contains a collection of util functions for training and evaluating."""
15 |
16 | import numpy
17 | import tensorflow as tf
18 | from tensorflow import logging
19 |
20 | try:
21 | xrange # Python 2
22 | except NameError:
23 | xrange = range # Python 3
24 |
25 |
26 | def Dequantize(feat_vector, max_quantized_value=2, min_quantized_value=-2):
27 | """Dequantize the feature from the byte format to the float format.
28 |
29 | Args:
30 | feat_vector: the input 1-d vector.
31 | max_quantized_value: the maximum of the quantized value.
32 | min_quantized_value: the minimum of the quantized value.
33 |
34 | Returns:
35 | A float vector which has the same shape as feat_vector.
36 | """
37 | assert max_quantized_value > min_quantized_value
38 | quantized_range = max_quantized_value - min_quantized_value
39 | scalar = quantized_range / 255.0
40 | bias = (quantized_range / 512.0) + min_quantized_value
41 | return feat_vector * scalar + bias
42 |
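# Worked example (illustrative only): with the default range [-2, 2], the byte
# value 0 maps back to roughly the minimum of the range and 255 to roughly the
# maximum:
#
#   Dequantize(numpy.array([0.0, 255.0]))  # -> approx. [-1.992, 2.008]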
43 |
44 | def MakeSummary(name, value):
45 | """Creates a tf.Summary proto with the given name and value."""
46 | summary = tf.Summary()
47 | val = summary.value.add()
48 | val.tag = str(name)
49 | val.simple_value = float(value)
50 | return summary
51 |
52 |
53 | def AddGlobalStepSummary(summary_writer,
54 | global_step_val,
55 | global_step_info_dict,
56 | summary_scope="Eval"):
57 | """Add the global_step summary to the Tensorboard.
58 |
59 | Args:
60 | summary_writer: Tensorflow summary_writer.
61 |     global_step_val: an int value of the global step.
62 | global_step_info_dict: a dictionary of the evaluation metrics calculated for
63 | a mini-batch.
64 | summary_scope: Train or Eval.
65 |
66 | Returns:
67 | A string of this global_step summary
68 | """
69 | this_hit_at_one = global_step_info_dict["hit_at_one"]
70 | this_perr = global_step_info_dict["perr"]
71 | this_loss = global_step_info_dict["loss"]
72 | examples_per_second = global_step_info_dict.get("examples_per_second", -1)
73 |
74 | summary_writer.add_summary(
75 | MakeSummary("GlobalStep/" + summary_scope + "_Hit@1", this_hit_at_one),
76 | global_step_val)
77 | summary_writer.add_summary(
78 | MakeSummary("GlobalStep/" + summary_scope + "_Perr", this_perr),
79 | global_step_val)
80 | summary_writer.add_summary(
81 | MakeSummary("GlobalStep/" + summary_scope + "_Loss", this_loss),
82 | global_step_val)
83 |
84 | if examples_per_second != -1:
85 | summary_writer.add_summary(
86 | MakeSummary("GlobalStep/" + summary_scope + "_Example_Second",
87 | examples_per_second), global_step_val)
88 |
89 | summary_writer.flush()
90 | info = (
91 | "global_step {0} | Batch Hit@1: {1:.3f} | Batch PERR: {2:.3f} | Batch "
92 | "Loss: {3:.3f} | Examples_per_sec: {4:.3f}").format(
93 | global_step_val, this_hit_at_one, this_perr, this_loss,
94 | examples_per_second)
95 | return info
96 |
97 |
98 | def AddEpochSummary(summary_writer,
99 | global_step_val,
100 | epoch_info_dict,
101 | summary_scope="Eval"):
102 | """Add the epoch summary to the Tensorboard.
103 |
104 | Args:
105 | summary_writer: Tensorflow summary_writer.
106 |     global_step_val: an int value of the global step.
107 | epoch_info_dict: a dictionary of the evaluation metrics calculated for the
108 | whole epoch.
109 | summary_scope: Train or Eval.
110 |
111 | Returns:
112 | A string of this global_step summary
113 | """
114 | epoch_id = epoch_info_dict["epoch_id"]
115 | avg_hit_at_one = epoch_info_dict["avg_hit_at_one"]
116 | avg_perr = epoch_info_dict["avg_perr"]
117 | avg_loss = epoch_info_dict["avg_loss"]
118 | aps = epoch_info_dict["aps"]
119 | gap = epoch_info_dict["gap"]
120 | mean_ap = numpy.mean(aps)
121 |
122 | summary_writer.add_summary(
123 | MakeSummary("Epoch/" + summary_scope + "_Avg_Hit@1", avg_hit_at_one),
124 | global_step_val)
125 | summary_writer.add_summary(
126 | MakeSummary("Epoch/" + summary_scope + "_Avg_Perr", avg_perr),
127 | global_step_val)
128 | summary_writer.add_summary(
129 | MakeSummary("Epoch/" + summary_scope + "_Avg_Loss", avg_loss),
130 | global_step_val)
131 | summary_writer.add_summary(
132 | MakeSummary("Epoch/" + summary_scope + "_MAP", mean_ap), global_step_val)
133 | summary_writer.add_summary(
134 | MakeSummary("Epoch/" + summary_scope + "_GAP", gap), global_step_val)
135 | summary_writer.flush()
136 |
137 | info = ("epoch/eval number {0} | Avg_Hit@1: {1:.3f} | Avg_PERR: {2:.3f} "
138 |           "| MAP: {3:.3f} | GAP: {4:.3f} | Avg_Loss: {5:.3f} | num_classes: {6}"
139 | ).format(epoch_id, avg_hit_at_one, avg_perr, mean_ap, gap, avg_loss,
140 | len(aps))
141 | return info
142 |
143 |
144 | def GetListOfFeatureNamesAndSizes(feature_names, feature_sizes):
145 |   """Extract the list of feature names and the dimensionality of each
146 |   feature from a string of comma-separated values.
147 |
148 |
149 | Args:
150 | feature_names: string containing comma separated list of feature names
151 | feature_sizes: string containing comma separated list of feature sizes
152 |
153 | Returns:
154 | List of the feature names and list of the dimensionality of each feature.
155 | Elements in the first/second list are strings/integers.
156 | """
157 | list_of_feature_names = [
158 | feature_names.strip() for feature_names in feature_names.split(",")
159 | ]
160 | list_of_feature_sizes = [
161 | int(feature_sizes) for feature_sizes in feature_sizes.split(",")
162 | ]
163 | if len(list_of_feature_names) != len(list_of_feature_sizes):
164 | logging.error("length of the feature names (=" +
165 | str(len(list_of_feature_names)) + ") != length of feature "
166 | "sizes (=" + str(len(list_of_feature_sizes)) + ")")
167 |
168 | return list_of_feature_names, list_of_feature_sizes
169 |
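# Example (illustrative only), using the standard YouTube-8M frame features:
#
#   names, sizes = GetListOfFeatureNamesAndSizes("rgb,audio", "1024,128")
#   # names == ["rgb", "audio"], sizes == [1024, 128]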
170 |
171 | def clip_gradient_norms(gradients_to_variables, max_norm):
172 | """Clips the gradients by the given value.
173 |
174 | Args:
175 | gradients_to_variables: A list of gradient to variable pairs (tuples).
176 | max_norm: the maximum norm value.
177 |
178 | Returns:
179 | A list of clipped gradient to variable pairs.
180 | """
181 | clipped_grads_and_vars = []
182 | for grad, var in gradients_to_variables:
183 | if grad is not None:
184 | if isinstance(grad, tf.IndexedSlices):
185 | tmp = tf.clip_by_norm(grad.values, max_norm)
186 | grad = tf.IndexedSlices(tmp, grad.indices, grad.dense_shape)
187 | else:
188 | grad = tf.clip_by_norm(grad, max_norm)
189 | clipped_grads_and_vars.append((grad, var))
190 | return clipped_grads_and_vars
191 |
192 |
193 | def combine_gradients(tower_grads):
194 | """Calculate the combined gradient for each shared variable across all towers.
195 |
196 | Note that this function provides a synchronization point across all towers.
197 |
198 | Args:
199 | tower_grads: List of lists of (gradient, variable) tuples. The outer list is
200 | over individual gradients. The inner list is over the gradient calculation
201 | for each tower.
202 |
203 | Returns:
204 | List of pairs of (gradient, variable) where the gradient has been summed
205 | across all towers.
206 | """
207 | filtered_grads = [
208 | [x for x in grad_list if x[0] is not None] for grad_list in tower_grads
209 | ]
210 | final_grads = []
211 | for i in xrange(len(filtered_grads[0])):
212 | grads = [filtered_grads[t][i] for t in xrange(len(filtered_grads))]
213 | grad = tf.stack([x[0] for x in grads], 0)
214 | grad = tf.reduce_sum(grad, 0)
215 | final_grads.append((
216 | grad,
217 | filtered_grads[0][i][1],
218 | ))
219 |
220 | return final_grads
221 |
--------------------------------------------------------------------------------
/eval_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Provides functions to help with evaluating models."""
15 | import average_precision_calculator as ap_calculator
16 | import mean_average_precision_calculator as map_calculator
17 | import numpy
18 | from tensorflow.python.platform import gfile
19 |
20 |
21 | def flatten(l):
22 | """Merges a list of lists into a single list. """
23 | return [item for sublist in l for item in sublist]
24 |
25 |
26 | def calculate_hit_at_one(predictions, actuals):
27 | """Performs a local (numpy) calculation of the hit at one.
28 |
29 | Args:
30 | predictions: Matrix containing the outputs of the model. Dimensions are
31 | 'batch' x 'num_classes'.
32 | actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x
33 | 'num_classes'.
34 |
35 | Returns:
36 | float: The average hit at one across the entire batch.
37 | """
38 | top_prediction = numpy.argmax(predictions, 1)
39 | hits = actuals[numpy.arange(actuals.shape[0]), top_prediction]
40 | return numpy.average(hits)
41 |
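# Small worked example (illustrative only): the top prediction of the first
# row (class 1) is a true label, while the top prediction of the second row
# (class 0) is not, so hit@1 averages to 0.5.
#
#   preds = numpy.array([[0.1, 0.9], [0.8, 0.2]])
#   truth = numpy.array([[0.0, 1.0], [0.0, 1.0]])
#   assert calculate_hit_at_one(preds, truth) == 0.5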
42 |
43 | def calculate_precision_at_equal_recall_rate(predictions, actuals):
44 | """Performs a local (numpy) calculation of the PERR.
45 |
46 | Args:
47 | predictions: Matrix containing the outputs of the model. Dimensions are
48 | 'batch' x 'num_classes'.
49 | actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x
50 | 'num_classes'.
51 |
52 | Returns:
53 | float: The average precision at equal recall rate across the entire batch.
54 | """
55 | aggregated_precision = 0.0
56 | num_videos = actuals.shape[0]
57 | for row in numpy.arange(num_videos):
58 | num_labels = int(numpy.sum(actuals[row]))
59 | top_indices = numpy.argpartition(predictions[row],
60 | -num_labels)[-num_labels:]
61 | item_precision = 0.0
62 | for label_index in top_indices:
63 | if predictions[row][label_index] > 0:
64 | item_precision += actuals[row][label_index]
65 | item_precision /= top_indices.size
66 | aggregated_precision += item_precision
67 | aggregated_precision /= num_videos
68 | return aggregated_precision
69 |
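# Small worked example (illustrative only): each row has two true labels, so
# the top two predictions are inspected. Row one gets both right (precision
# 1.0), row two gets one of two right (precision 0.5), so PERR is 0.75.
#
#   preds = numpy.array([[0.9, 0.2, 0.7], [0.1, 0.8, 0.3]])
#   truth = numpy.array([[1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
#   assert calculate_precision_at_equal_recall_rate(preds, truth) == 0.75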
70 |
71 | def calculate_gap(predictions, actuals, top_k=20):
72 | """Performs a local (numpy) calculation of the global average precision.
73 |
74 | Only the top_k predictions are taken for each of the videos.
75 |
76 | Args:
77 | predictions: Matrix containing the outputs of the model. Dimensions are
78 | 'batch' x 'num_classes'.
79 | actuals: Matrix containing the ground truth labels. Dimensions are 'batch' x
80 | 'num_classes'.
81 | top_k: How many predictions to use per video.
82 |
83 | Returns:
84 | float: The global average precision.
85 | """
86 | gap_calculator = ap_calculator.AveragePrecisionCalculator()
87 | sparse_predictions, sparse_labels, num_positives = top_k_by_class(
88 | predictions, actuals, top_k)
89 | gap_calculator.accumulate(flatten(sparse_predictions), flatten(sparse_labels),
90 | sum(num_positives))
91 | return gap_calculator.peek_ap_at_n()
92 |
93 |
94 | def top_k_by_class(predictions, labels, k=20):
95 | """Extracts the top k predictions for each video, sorted by class.
96 |
97 | Args:
98 | predictions: A numpy matrix containing the outputs of the model. Dimensions
99 | are 'batch' x 'num_classes'.
100 | k: the top k non-zero entries to preserve in each prediction.
101 |
102 | Returns:
103 |     A tuple (predictions, labels, true_positives). 'predictions' and 'labels'
104 |     are lists of lists of floats. 'true_positives' is a list of scalars. The
105 |     lengths of the lists are equal to the number of classes. The entries in the
106 | predictions variable are probability predictions, and
107 | the corresponding entries in the labels variable are the ground truth for
108 | those predictions. The entries in 'true_positives' are the number of true
109 | positives for each class in the ground truth.
110 |
111 | Raises:
112 | ValueError: An error occurred when the k is not a positive integer.
113 | """
114 | if k <= 0:
115 | raise ValueError("k must be a positive integer.")
116 | k = min(k, predictions.shape[1])
117 | num_classes = predictions.shape[1]
118 | prediction_triplets = []
119 | for video_index in range(predictions.shape[0]):
120 | prediction_triplets.extend(
121 | top_k_triplets(predictions[video_index], labels[video_index], k))
122 | out_predictions = [[] for _ in range(num_classes)]
123 | out_labels = [[] for _ in range(num_classes)]
124 | for triplet in prediction_triplets:
125 | out_predictions[triplet[0]].append(triplet[1])
126 | out_labels[triplet[0]].append(triplet[2])
127 | out_true_positives = [numpy.sum(labels[:, i]) for i in range(num_classes)]
128 |
129 | return out_predictions, out_labels, out_true_positives
130 |
131 |
132 | def top_k_triplets(predictions, labels, k=20):
133 | """Get the top_k for a 1-d numpy array.
134 |
135 | Returns a sparse list of tuples in
136 | (prediction, class) format
137 | """
138 | m = len(predictions)
139 | k = min(k, m)
140 | indices = numpy.argpartition(predictions, -k)[-k:]
141 | return [(index, predictions[index], labels[index]) for index in indices]
142 |
143 |
144 | class EvaluationMetrics(object):
145 | """A class to store the evaluation metrics."""
146 |
147 | def __init__(self, num_class, top_k, top_n):
148 | """Construct an EvaluationMetrics object to store the evaluation metrics.
149 |
150 | Args:
151 | num_class: A positive integer specifying the number of classes.
152 | top_k: A positive integer specifying how many predictions are considered
153 | per video.
154 | top_n: A positive Integer specifying the average precision at n, or None
155 | to use all provided data points.
156 |
157 | Raises:
158 | ValueError: An error occurred when MeanAveragePrecisionCalculator cannot
159 |         be constructed.
160 | """
161 | self.sum_hit_at_one = 0.0
162 | self.sum_perr = 0.0
163 | self.sum_loss = 0.0
164 | self.map_calculator = map_calculator.MeanAveragePrecisionCalculator(
165 | num_class, top_n=top_n)
166 | self.global_ap_calculator = ap_calculator.AveragePrecisionCalculator()
167 | self.top_k = top_k
168 | self.num_examples = 0
169 |
170 | def accumulate(self, predictions, labels, loss):
171 | """Accumulate the metrics calculated locally for this mini-batch.
172 |
173 | Args:
174 | predictions: A numpy matrix containing the outputs of the model.
175 | Dimensions are 'batch' x 'num_classes'.
176 | labels: A numpy matrix containing the ground truth labels. Dimensions are
177 | 'batch' x 'num_classes'.
178 | loss: A numpy array containing the loss for each sample.
179 |
180 | Returns:
181 | dictionary: A dictionary storing the metrics for the mini-batch.
182 |
183 | Raises:
184 | ValueError: An error occurred when the shape of predictions and actuals
185 | does not match.
186 | """
187 | batch_size = labels.shape[0]
188 | mean_hit_at_one = calculate_hit_at_one(predictions, labels)
189 | mean_perr = calculate_precision_at_equal_recall_rate(predictions, labels)
190 | mean_loss = numpy.mean(loss)
191 |
192 | # Take the top 20 predictions.
193 | sparse_predictions, sparse_labels, num_positives = top_k_by_class(
194 | predictions, labels, self.top_k)
195 | self.map_calculator.accumulate(sparse_predictions, sparse_labels,
196 | num_positives)
197 | self.global_ap_calculator.accumulate(flatten(sparse_predictions),
198 | flatten(sparse_labels),
199 | sum(num_positives))
200 |
201 | self.num_examples += batch_size
202 | self.sum_hit_at_one += mean_hit_at_one * batch_size
203 | self.sum_perr += mean_perr * batch_size
204 | self.sum_loss += mean_loss * batch_size
205 |
206 | return {"hit_at_one": mean_hit_at_one, "perr": mean_perr, "loss": mean_loss}
207 |
208 | def get(self):
209 | """Calculate the evaluation metrics for the whole epoch.
210 |
211 | Raises:
212 | ValueError: If no examples were accumulated.
213 |
214 | Returns:
215 | dictionary: a dictionary storing the evaluation metrics for the epoch. The
216 | dictionary has the fields: avg_hit_at_one, avg_perr, avg_loss, and
217 | aps (default nan).
218 | """
219 | if self.num_examples <= 0:
220 |       raise ValueError("num_examples must be positive.")
221 | avg_hit_at_one = self.sum_hit_at_one / self.num_examples
222 | avg_perr = self.sum_perr / self.num_examples
223 | avg_loss = self.sum_loss / self.num_examples
224 |
225 | aps = self.map_calculator.peek_map_at_n()
226 | gap = self.global_ap_calculator.peek_ap_at_n()
227 |
228 | epoch_info_dict = {
229 | "avg_hit_at_one": avg_hit_at_one,
230 | "avg_perr": avg_perr,
231 | "avg_loss": avg_loss,
232 | "aps": aps,
233 | "gap": gap
234 | }
235 | return epoch_info_dict
236 |
237 | def clear(self):
238 | """Clear the evaluation metrics and reset the EvaluationMetrics object."""
239 | self.sum_hit_at_one = 0.0
240 | self.sum_perr = 0.0
241 | self.sum_loss = 0.0
242 | self.map_calculator.clear()
243 | self.global_ap_calculator.clear()
244 | self.num_examples = 0
245 |
--------------------------------------------------------------------------------
/average_precision_calculator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Calculate or keep track of the interpolated average precision.
15 |
16 | It provides an interface for calculating interpolated average precision for an
17 | entire list or the top-n ranked items. For the definition of the
18 | (non-)interpolated average precision:
19 | http://trec.nist.gov/pubs/trec15/appendices/CE.MEASURES06.pdf
20 |
21 | Example usages:
22 | 1) Use it as a static function call to directly calculate average precision for
23 | a short ranked list in the memory.
24 |
25 | ```
26 | import random
27 |
28 | p = np.array([random.random() for _ in xrange(10)])
29 | a = np.array([random.choice([0, 1]) for _ in xrange(10)])
30 |
31 | ap = average_precision_calculator.AveragePrecisionCalculator.ap(p, a)
32 | ```
33 |
34 | 2) Use it as an object for long ranked list that cannot be stored in memory or
35 | the case where partial predictions can be observed at a time (Tensorflow
36 | predictions). In this case, we first call the function accumulate many times
37 | to process parts of the ranked list. After processing all the parts, we call
38 | peek_ap_at_n.
39 | ```
40 | p1 = np.array([random.random() for _ in xrange(5)])
41 | a1 = np.array([random.choice([0, 1]) for _ in xrange(5)])
42 | p2 = np.array([random.random() for _ in xrange(5)])
43 | a2 = np.array([random.choice([0, 1]) for _ in xrange(5)])
44 |
45 | # non-interpolated average precision at 10
46 | calculator = average_precision_calculator.AveragePrecisionCalculator(10)
47 | calculator.accumulate(p1, a1)
48 | calculator.accumulate(p2, a2)
49 | ap3 = calculator.peek_ap_at_n()
50 | ```
51 | """
52 |
53 | import heapq
54 | import random
55 | import numbers
56 |
57 | import numpy
58 |
59 |
60 | class AveragePrecisionCalculator(object):
61 | """Calculate the average precision and average precision at n."""
62 |
63 | def __init__(self, top_n=None):
64 | """Construct an AveragePrecisionCalculator to calculate average precision.
65 |
66 | This class is used to calculate the average precision for a single label.
67 |
68 | Args:
69 | top_n: A positive Integer specifying the average precision at n, or None
70 | to use all provided data points.
71 |
72 | Raises:
73 | ValueError: An error occurred when the top_n is not a positive integer.
74 | """
75 |     if not ((isinstance(top_n, int) and top_n > 0) or top_n is None):
76 | raise ValueError("top_n must be a positive integer or None.")
77 |
78 | self._top_n = top_n # average precision at n
79 | self._total_positives = 0 # total number of positives have seen
80 | self._heap = [] # max heap of (prediction, actual)
81 |
82 | @property
83 | def heap_size(self):
84 | """Gets the heap size maintained in the class."""
85 | return len(self._heap)
86 |
87 | @property
88 | def num_accumulated_positives(self):
89 | """Gets the number of positive samples that have been accumulated."""
90 | return self._total_positives
91 |
92 | def accumulate(self, predictions, actuals, num_positives=None):
93 | """Accumulate the predictions and their ground truth labels.
94 |
95 | After the function call, we may call peek_ap_at_n to actually calculate
96 | the average precision.
97 | Note predictions and actuals must have the same shape.
98 |
99 | Args:
100 | predictions: a list storing the prediction scores.
101 |       actuals: a list storing the ground truth labels. Any value larger than 0
102 |         will be treated as positives, otherwise as negatives.
103 |       num_positives: If the 'predictions' and 'actuals' inputs aren't complete,
104 |         then it's possible some true positives were missed in them. In that
105 |         case, you can provide 'num_positives' in order to accurately track recall.
106 |
107 | Raises:
108 | ValueError: An error occurred when the format of the input is not the
109 | numpy 1-D array or the shape of predictions and actuals does not match.
110 | """
111 | if len(predictions) != len(actuals):
112 | raise ValueError("the shape of predictions and actuals does not match.")
113 |
114 | if num_positives is not None:
115 | if not isinstance(num_positives, numbers.Number) or num_positives < 0:
116 | raise ValueError(
117 | "'num_positives' was provided but it was a negative number.")
118 |
119 | if num_positives is not None:
120 | self._total_positives += num_positives
121 | else:
122 | self._total_positives += numpy.size(
123 | numpy.where(numpy.array(actuals) > 1e-5))
124 | topk = self._top_n
125 | heap = self._heap
126 |
127 | for i in range(numpy.size(predictions)):
128 | if topk is None or len(heap) < topk:
129 | heapq.heappush(heap, (predictions[i], actuals[i]))
130 | else:
131 | if predictions[i] > heap[0][0]: # heap[0] is the smallest
132 | heapq.heappop(heap)
133 | heapq.heappush(heap, (predictions[i], actuals[i]))
134 |
135 | def clear(self):
136 | """Clear the accumulated predictions."""
137 | self._heap = []
138 | self._total_positives = 0
139 |
140 | def peek_ap_at_n(self):
141 | """Peek the non-interpolated average precision at n.
142 |
143 | Returns:
144 | The non-interpolated average precision at n (default 0).
145 | If n is larger than the length of the ranked list,
146 | the average precision will be returned.
147 | """
148 | if self.heap_size <= 0:
149 | return 0
150 | predlists = numpy.array(list(zip(*self._heap)))
151 |
152 | ap = self.ap_at_n(predlists[0],
153 | predlists[1],
154 | n=self._top_n,
155 | total_num_positives=self._total_positives)
156 | return ap
157 |
158 | @staticmethod
159 | def ap(predictions, actuals):
160 | """Calculate the non-interpolated average precision.
161 |
162 | Args:
163 | predictions: a numpy 1-D array storing the sparse prediction scores.
164 | actuals: a numpy 1-D array storing the ground truth labels. Any value
165 | larger than 0 will be treated as positives, otherwise as negatives.
166 |
167 | Returns:
168 | The non-interpolated average precision at n.
169 | If n is larger than the length of the ranked list,
170 | the average precision will be returned.
171 |
172 | Raises:
173 | ValueError: An error occurred when the format of the input is not the
174 | numpy 1-D array or the shape of predictions and actuals does not match.
175 | """
176 | return AveragePrecisionCalculator.ap_at_n(predictions, actuals, n=None)
177 |
178 | @staticmethod
179 | def ap_at_n(predictions, actuals, n=20, total_num_positives=None):
180 | """Calculate the non-interpolated average precision.
181 |
182 | Args:
183 | predictions: a numpy 1-D array storing the sparse prediction scores.
184 | actuals: a numpy 1-D array storing the ground truth labels. Any value
185 | larger than 0 will be treated as positives, otherwise as negatives.
186 | n: the top n items to be considered in ap@n.
187 | total_num_positives : (optionally) you can specify the number of total
188 | positive in the list. If specified, it will be used in calculation.
189 |
190 | Returns:
191 | The non-interpolated average precision at n.
192 | If n is larger than the length of the ranked list,
193 | the average precision will be returned.
194 |
195 | Raises:
196 | ValueError: An error occurred when
197 | 1) the format of the input is not the numpy 1-D array;
198 | 2) the shape of predictions and actuals does not match;
199 | 3) the input n is not a positive integer.
200 | """
201 | if len(predictions) != len(actuals):
202 | raise ValueError("the shape of predictions and actuals does not match.")
203 |
204 | if n is not None:
205 | if not isinstance(n, int) or n <= 0:
206 | raise ValueError("n must be 'None' or a positive integer."
207 | " It was '%s'." % n)
208 |
209 | ap = 0.0
210 |
211 | predictions = numpy.array(predictions)
212 | actuals = numpy.array(actuals)
213 |
214 | # add a shuffler to avoid overestimating the ap
215 | predictions, actuals = AveragePrecisionCalculator._shuffle(
216 | predictions, actuals)
217 | sortidx = sorted(range(len(predictions)),
218 | key=lambda k: predictions[k],
219 | reverse=True)
220 |
221 | if total_num_positives is None:
222 | numpos = numpy.size(numpy.where(actuals > 0))
223 | else:
224 | numpos = total_num_positives
225 |
226 | if numpos == 0:
227 | return 0
228 |
229 | if n is not None:
230 | numpos = min(numpos, n)
231 | delta_recall = 1.0 / numpos
232 | poscount = 0.0
233 |
234 | # calculate the ap
235 | r = len(sortidx)
236 | if n is not None:
237 | r = min(r, n)
238 | for i in range(r):
239 | if actuals[sortidx[i]] > 0:
240 | poscount += 1
241 | ap += poscount / (i + 1) * delta_recall
242 | return ap
243 |
244 | @staticmethod
245 | def _shuffle(predictions, actuals):
246 | random.seed(0)
247 | suffidx = random.sample(range(len(predictions)), len(predictions))
248 | predictions = predictions[suffidx]
249 | actuals = actuals[suffidx]
250 | return predictions, actuals
251 |
252 | @staticmethod
253 | def _zero_one_normalize(predictions, epsilon=1e-7):
254 | """Normalize the predictions to the range between 0.0 and 1.0.
255 |
256 | For some predictions like SVM predictions, we need to normalize them before
257 |     calculating the interpolated average precision. The normalization will not
258 | change the rank in the original list and thus won't change the average
259 | precision.
260 |
261 | Args:
262 | predictions: a numpy 1-D array storing the sparse prediction scores.
263 | epsilon: a small constant to avoid denominator being zero.
264 |
265 | Returns:
266 | The normalized prediction.
267 | """
268 | denominator = numpy.max(predictions) - numpy.min(predictions)
269 |     ret = (predictions - numpy.min(predictions)) / numpy.maximum(
270 |         denominator, epsilon)
271 | return ret
272 |
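# Small worked example (illustrative only): three predictions ranked
# 0.9 > 0.7 > 0.3 with positives at ranks 1 and 3 give
# AP = (1/1 + 2/3) / 2 = 0.8333...
#
#   ap = AveragePrecisionCalculator.ap(
#       numpy.array([0.9, 0.7, 0.3]), numpy.array([1, 0, 1]))
#   assert abs(ap - (1.0 + 2.0 / 3.0) / 2.0) < 1e-6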
--------------------------------------------------------------------------------
/segment_label_ids.csv:
--------------------------------------------------------------------------------
1 | Index
2 | 3
3 | 7
4 | 8
5 | 11
6 | 12
7 | 17
8 | 18
9 | 19
10 | 21
11 | 22
12 | 23
13 | 28
14 | 31
15 | 30
16 | 32
17 | 33
18 | 34
19 | 41
20 | 43
21 | 45
22 | 46
23 | 48
24 | 53
25 | 54
26 | 52
27 | 55
28 | 58
29 | 59
30 | 60
31 | 61
32 | 65
33 | 68
34 | 73
35 | 71
36 | 74
37 | 75
38 | 76
39 | 77
40 | 80
41 | 83
42 | 90
43 | 88
44 | 89
45 | 92
46 | 95
47 | 100
48 | 101
49 | 99
50 | 104
51 | 105
52 | 109
53 | 113
54 | 112
55 | 115
56 | 116
57 | 118
58 | 120
59 | 121
60 | 123
61 | 125
62 | 127
63 | 131
64 | 128
65 | 129
66 | 130
67 | 137
68 | 141
69 | 143
70 | 145
71 | 148
72 | 152
73 | 151
74 | 156
75 | 155
76 | 158
77 | 160
78 | 164
79 | 163
80 | 169
81 | 170
82 | 172
83 | 171
84 | 173
85 | 174
86 | 175
87 | 176
88 | 178
89 | 182
90 | 184
91 | 186
92 | 188
93 | 187
94 | 192
95 | 191
96 | 190
97 | 194
98 | 197
99 | 196
100 | 198
101 | 201
102 | 202
103 | 200
104 | 199
105 | 205
106 | 204
107 | 209
108 | 207
109 | 206
110 | 210
111 | 213
112 | 214
113 | 220
114 | 218
115 | 217
116 | 226
117 | 227
118 | 231
119 | 232
120 | 229
121 | 233
122 | 235
123 | 237
124 | 244
125 | 240
126 | 249
127 | 246
128 | 248
129 | 239
130 | 250
131 | 245
132 | 255
133 | 253
134 | 256
135 | 261
136 | 259
137 | 263
138 | 262
139 | 266
140 | 267
141 | 268
142 | 269
143 | 271
144 | 276
145 | 273
146 | 277
147 | 274
148 | 278
149 | 279
150 | 280
151 | 288
152 | 291
153 | 295
154 | 294
155 | 293
156 | 297
157 | 296
158 | 300
159 | 299
160 | 303
161 | 302
162 | 304
163 | 305
164 | 313
165 | 307
166 | 311
167 | 310
168 | 312
169 | 316
170 | 318
171 | 321
172 | 322
173 | 331
174 | 333
175 | 329
176 | 330
177 | 334
178 | 343
179 | 349
180 | 340
181 | 344
182 | 348
183 | 358
184 | 347
185 | 359
186 | 355
187 | 361
188 | 360
189 | 364
190 | 365
191 | 368
192 | 369
193 | 366
194 | 370
195 | 374
196 | 380
197 | 373
198 | 385
199 | 384
200 | 388
201 | 389
202 | 382
203 | 393
204 | 381
205 | 390
206 | 394
207 | 399
208 | 397
209 | 396
210 | 402
211 | 400
212 | 398
213 | 401
214 | 405
215 | 406
216 | 410
217 | 408
218 | 416
219 | 415
220 | 419
221 | 422
222 | 414
223 | 421
224 | 424
225 | 429
226 | 418
227 | 427
228 | 434
229 | 428
230 | 435
231 | 430
232 | 441
233 | 439
234 | 437
235 | 443
236 | 440
237 | 442
238 | 445
239 | 446
240 | 448
241 | 454
242 | 444
243 | 453
244 | 455
245 | 451
246 | 452
247 | 458
248 | 460
249 | 465
250 | 457
251 | 463
252 | 462
253 | 461
254 | 464
255 | 469
256 | 468
257 | 472
258 | 473
259 | 471
260 | 475
261 | 474
262 | 477
263 | 485
264 | 491
265 | 488
266 | 482
267 | 490
268 | 496
269 | 494
270 | 483
271 | 495
272 | 493
273 | 507
274 | 501
275 | 499
276 | 503
277 | 498
278 | 514
279 | 504
280 | 502
281 | 506
282 | 508
283 | 511
284 | 527
285 | 526
286 | 532
287 | 513
288 | 519
289 | 525
290 | 518
291 | 528
292 | 522
293 | 523
294 | 535
295 | 539
296 | 540
297 | 533
298 | 521
299 | 541
300 | 547
301 | 550
302 | 544
303 | 549
304 | 551
305 | 554
306 | 543
307 | 548
308 | 557
309 | 560
310 | 552
311 | 559
312 | 563
313 | 565
314 | 567
315 | 555
316 | 576
317 | 568
318 | 564
319 | 573
320 | 581
321 | 580
322 | 572
323 | 571
324 | 584
325 | 590
326 | 585
327 | 587
328 | 588
329 | 592
330 | 598
331 | 597
332 | 599
333 | 603
334 | 600
335 | 604
336 | 605
337 | 614
338 | 602
339 | 610
340 | 608
341 | 611
342 | 612
343 | 613
344 | 617
345 | 620
346 | 607
347 | 624
348 | 627
349 | 625
350 | 631
351 | 629
352 | 638
353 | 632
354 | 634
355 | 644
356 | 641
357 | 642
358 | 646
359 | 652
360 | 647
361 | 637
362 | 661
363 | 635
364 | 658
365 | 648
366 | 663
367 | 668
368 | 664
369 | 656
370 | 666
371 | 671
372 | 683
373 | 675
374 | 669
375 | 676
376 | 667
377 | 691
378 | 685
379 | 673
380 | 688
381 | 702
382 | 684
383 | 679
384 | 694
385 | 686
386 | 689
387 | 680
388 | 693
389 | 703
390 | 697
391 | 698
392 | 692
393 | 705
394 | 706
395 | 712
396 | 711
397 | 709
398 | 710
399 | 726
400 | 713
401 | 721
402 | 720
403 | 715
404 | 717
405 | 730
406 | 728
407 | 723
408 | 716
409 | 722
410 | 718
411 | 732
412 | 724
413 | 736
414 | 725
415 | 742
416 | 727
417 | 735
418 | 740
419 | 748
420 | 738
421 | 746
422 | 751
423 | 749
424 | 752
425 | 754
426 | 760
427 | 763
428 | 756
429 | 758
430 | 766
431 | 764
432 | 757
433 | 780
434 | 767
435 | 769
436 | 771
437 | 786
438 | 785
439 | 781
440 | 787
441 | 778
442 | 783
443 | 792
444 | 791
445 | 795
446 | 788
447 | 805
448 | 802
449 | 801
450 | 793
451 | 796
452 | 804
453 | 803
454 | 797
455 | 814
456 | 813
457 | 789
458 | 808
459 | 818
460 | 816
461 | 817
462 | 811
463 | 820
464 | 826
465 | 829
466 | 824
467 | 821
468 | 825
469 | 822
470 | 835
471 | 833
472 | 843
473 | 823
474 | 827
475 | 830
476 | 832
477 | 837
478 | 852
479 | 844
480 | 841
481 | 812
482 | 847
483 | 862
484 | 869
485 | 860
486 | 838
487 | 870
488 | 846
489 | 858
490 | 854
491 | 880
492 | 876
493 | 857
494 | 859
495 | 877
496 | 871
497 | 855
498 | 875
499 | 861
500 | 867
501 | 892
502 | 898
503 | 888
504 | 884
505 | 887
506 | 891
507 | 906
508 | 900
509 | 878
510 | 885
511 | 883
512 | 901
513 | 903
514 | 907
515 | 930
516 | 897
517 | 914
518 | 917
519 | 910
520 | 905
521 | 909
522 | 933
523 | 932
524 | 922
525 | 913
526 | 923
527 | 931
528 | 911
529 | 937
530 | 918
531 | 955
532 | 915
533 | 944
534 | 952
535 | 945
536 | 948
537 | 946
538 | 970
539 | 974
540 | 958
541 | 925
542 | 979
543 | 942
544 | 965
545 | 975
546 | 950
547 | 982
548 | 940
549 | 973
550 | 962
551 | 972
552 | 957
553 | 984
554 | 983
555 | 964
556 | 1007
557 | 971
558 | 981
559 | 954
560 | 993
561 | 991
562 | 996
563 | 1005
564 | 1015
565 | 1009
566 | 995
567 | 986
568 | 1000
569 | 985
570 | 980
571 | 1016
572 | 1011
573 | 999
574 | 1002
575 | 994
576 | 1013
577 | 1010
578 | 992
579 | 1008
580 | 1036
581 | 1025
582 | 1012
583 | 990
584 | 1037
585 | 1040
586 | 1031
587 | 1019
588 | 1052
589 | 1001
590 | 1055
591 | 1032
592 | 1069
593 | 1058
594 | 1014
595 | 1023
596 | 1030
597 | 1061
598 | 1035
599 | 1034
600 | 1053
601 | 1045
602 | 1046
603 | 1067
604 | 1060
605 | 1049
606 | 1056
607 | 1074
608 | 1066
609 | 1044
610 | 1038
611 | 1073
612 | 1077
613 | 1068
614 | 1057
615 | 1072
616 | 1104
617 | 1083
618 | 1089
619 | 1087
620 | 1099
621 | 1076
622 | 1086
623 | 1098
624 | 1094
625 | 1095
626 | 1096
627 | 1101
628 | 1107
629 | 1105
630 | 1117
631 | 1093
632 | 1106
633 | 1122
634 | 1119
635 | 1103
636 | 1128
637 | 1120
638 | 1126
639 | 1102
640 | 1115
641 | 1124
642 | 1123
643 | 1131
644 | 1136
645 | 1144
646 | 1121
647 | 1137
648 | 1132
649 | 1133
650 | 1157
651 | 1134
652 | 1143
653 | 1159
654 | 1164
655 | 1155
656 | 1142
657 | 1150
658 | 1148
659 | 1161
660 | 1165
661 | 1147
662 | 1162
663 | 1152
664 | 1174
665 | 1160
666 | 1166
667 | 1190
668 | 1175
669 | 1167
670 | 1156
671 | 1180
672 | 1171
673 | 1179
674 | 1172
675 | 1186
676 | 1188
677 | 1201
678 | 1177
679 | 1208
680 | 1183
681 | 1189
682 | 1192
683 | 1209
684 | 1214
685 | 1197
686 | 1168
687 | 1202
688 | 1205
689 | 1203
690 | 1199
691 | 1219
692 | 1217
693 | 1187
694 | 1206
695 | 1210
696 | 1241
697 | 1221
698 | 1218
699 | 1223
700 | 1236
701 | 1212
702 | 1237
703 | 1195
704 | 1216
705 | 1247
706 | 1234
707 | 1240
708 | 1257
709 | 1224
710 | 1243
711 | 1259
712 | 1242
713 | 1282
714 | 1222
715 | 1254
716 | 1227
717 | 1235
718 | 1269
719 | 1258
720 | 1290
721 | 1275
722 | 1262
723 | 1252
724 | 1248
725 | 1272
726 | 1246
727 | 1225
728 | 1245
729 | 1277
730 | 1298
731 | 1288
732 | 1271
733 | 1265
734 | 1286
735 | 1260
736 | 1266
737 | 1296
738 | 1280
739 | 1285
740 | 1293
741 | 1276
742 | 1287
743 | 1289
744 | 1261
745 | 1264
746 | 1295
747 | 1291
748 | 1283
749 | 1311
750 | 1303
751 | 1330
752 | 1315
753 | 1300
754 | 1333
755 | 1307
756 | 1325
757 | 1334
758 | 1316
759 | 1314
760 | 1317
761 | 1310
762 | 1329
763 | 1324
764 | 1339
765 | 1346
766 | 1342
767 | 1352
768 | 1321
769 | 1376
770 | 1366
771 | 1308
772 | 1345
773 | 1348
774 | 1386
775 | 1383
776 | 1372
777 | 1367
778 | 1400
779 | 1382
780 | 1375
781 | 1392
782 | 1380
783 | 1371
784 | 1393
785 | 1389
786 | 1353
787 | 1387
788 | 1374
789 | 1379
790 | 1381
791 | 1359
792 | 1360
793 | 1396
794 | 1399
795 | 1365
796 | 1424
797 | 1373
798 | 1411
799 | 1401
800 | 1397
801 | 1395
802 | 1412
803 | 1394
804 | 1368
805 | 1423
806 | 1391
807 | 1435
808 | 1409
809 | 1443
810 | 1402
811 | 1425
812 | 1415
813 | 1421
814 | 1426
815 | 1433
816 | 1420
817 | 1452
818 | 1436
819 | 1430
820 | 1408
821 | 1458
822 | 1429
823 | 1453
824 | 1454
825 | 1447
826 | 1472
827 | 1486
828 | 1468
829 | 1461
830 | 1467
831 | 1484
832 | 1457
833 | 1444
834 | 1450
835 | 1451
836 | 1459
837 | 1462
838 | 1449
839 | 1476
840 | 1470
841 | 1471
842 | 1498
843 | 1488
844 | 1442
845 | 1480
846 | 1456
847 | 1466
848 | 1505
849 | 1517
850 | 1464
851 | 1503
852 | 1490
853 | 1519
854 | 1481
855 | 1493
856 | 1463
857 | 1532
858 | 1487
859 | 1501
860 | 1500
861 | 1495
862 | 1509
863 | 1535
864 | 1506
865 | 1521
866 | 1580
867 | 1540
868 | 1502
869 | 1520
870 | 1496
871 | 1569
872 | 1515
873 | 1489
874 | 1507
875 | 1527
876 | 1545
877 | 1560
878 | 1510
879 | 1514
880 | 1526
881 | 1594
882 | 1511
883 | 1572
884 | 1548
885 | 1584
886 | 1556
887 | 1588
888 | 1628
889 | 1555
890 | 1568
891 | 1550
892 | 1622
893 | 1563
894 | 1603
895 | 1616
896 | 1576
897 | 1549
898 | 1537
899 | 1593
900 | 1618
901 | 1645
902 | 1624
903 | 1617
904 | 1634
905 | 1595
906 | 1597
907 | 1590
908 | 1632
909 | 1575
910 | 1559
911 | 1625
912 | 1615
913 | 1591
914 | 1630
915 | 1608
916 | 1621
917 | 1589
918 | 1646
919 | 1643
920 | 1652
921 | 1627
922 | 1611
923 | 1626
924 | 1613
925 | 1639
926 | 1655
927 | 1620
928 | 1602
929 | 1651
930 | 1653
931 | 1669
932 | 1638
933 | 1696
934 | 1649
935 | 1675
936 | 1660
937 | 1683
938 | 1666
939 | 1671
940 | 1703
941 | 1716
942 | 1637
943 | 1672
944 | 1676
945 | 1692
946 | 1711
947 | 1680
948 | 1641
949 | 1688
950 | 1708
951 | 1704
952 | 1690
953 | 1674
954 | 1718
955 | 1699
956 | 1723
957 | 1756
958 | 1700
959 | 1662
960 | 1715
961 | 1657
962 | 1733
963 | 1728
964 | 1670
965 | 1712
966 | 1685
967 | 1724
968 | 1735
969 | 1714
970 | 1730
971 | 1747
972 | 1656
973 | 1737
974 | 1705
975 | 1693
976 | 1713
977 | 1689
978 | 1753
979 | 1739
980 | 1721
981 | 1725
982 | 1749
983 | 1732
984 | 1743
985 | 1731
986 | 1767
987 | 1738
988 | 1831
989 | 1771
990 | 1726
991 | 1746
992 | 1776
993 | 1775
994 | 1799
995 | 1774
996 | 1780
997 | 1781
998 | 1769
999 | 1805
1000 | 1788
1001 | 1801
1002 |
--------------------------------------------------------------------------------
/frame_level_models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Contains a collection of models which operate on variable-length sequences."""
15 | import math
16 |
17 | import model_utils as utils
18 | import models
19 | import tensorflow as tf
20 | from tensorflow import flags
21 | import tensorflow.contrib.slim as slim
22 | import video_level_models
23 |
24 | FLAGS = flags.FLAGS
25 | flags.DEFINE_integer("iterations", 30, "Number of frames per batch for DBoF.")
26 | flags.DEFINE_bool("dbof_add_batch_norm", True,
27 | "Adds batch normalization to the DBoF model.")
28 | flags.DEFINE_bool(
29 | "sample_random_frames", True,
30 |     "If true, samples random frames (for frame level models). If false, a "
31 |     "random sequence of frames is sampled instead.")
32 | flags.DEFINE_integer("dbof_cluster_size", 8192,
33 | "Number of units in the DBoF cluster layer.")
34 | flags.DEFINE_integer("dbof_hidden_size", 1024,
35 | "Number of units in the DBoF hidden layer.")
36 | flags.DEFINE_string(
37 | "dbof_pooling_method", "max",
38 | "The pooling method used in the DBoF cluster layer. "
39 | "Choices are 'average' and 'max'.")
40 | flags.DEFINE_string(
41 | "dbof_activation", "sigmoid",
42 | "The nonlinear activation method for cluster and hidden dense layer, e.g., "
43 | "sigmoid, relu6, etc.")
44 | flags.DEFINE_string(
45 | "video_level_classifier_model", "MoeModel",
46 | "Some Frame-Level models can be decomposed into a "
47 | "generalized pooling operation followed by a "
48 | "classifier layer")
49 | flags.DEFINE_integer("lstm_cells", 1024, "Number of LSTM cells.")
50 | flags.DEFINE_integer("lstm_layers", 2, "Number of LSTM layers.")
51 |
52 |
53 | class FrameLevelLogisticModel(models.BaseModel):
54 | """Creates a logistic classifier over the aggregated frame-level features."""
55 |
56 | def create_model(self, model_input, vocab_size, num_frames, **unused_params):
57 | """See base class.
58 |
59 | This class is intended to be an example for implementors of frame level
60 | models. If you want to train a model over averaged features it is more
61 | efficient to average them beforehand rather than on the fly.
62 |
63 | Args:
64 | model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
65 | input features.
66 | vocab_size: The number of classes in the dataset.
67 | num_frames: A vector of length 'batch' which indicates the number of
68 | frames for each video (before padding).
69 |
70 | Returns:
71 | A dictionary with a tensor containing the probability predictions of the
72 | model in the 'predictions' key. The dimensions of the tensor are
73 | 'batch_size' x 'num_classes'.
74 | """
75 | num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32)
76 | feature_size = model_input.get_shape().as_list()[2]
77 |
78 | denominators = tf.reshape(tf.tile(num_frames, [1, feature_size]),
79 | [-1, feature_size])
80 | avg_pooled = tf.reduce_sum(model_input, axis=[1]) / denominators
81 |
82 | output = slim.fully_connected(avg_pooled,
83 | vocab_size,
84 | activation_fn=tf.nn.sigmoid,
85 | weights_regularizer=slim.l2_regularizer(1e-8))
86 | return {"predictions": output}
87 |
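# A numpy sketch (illustrative only) of the average-pooling step above: frame
# features are summed over time and divided by each video's true (pre-padding)
# frame count.
#
#   import numpy as np
#   frames = np.zeros((2, 5, 3), dtype=np.float32)  # batch x max_frames x features
#   frames[0, :4] = 1.0                             # video 0 has 4 real frames
#   frames[1, :2] = 2.0                             # video 1 has 2 real frames
#   num_frames = np.array([[4.0], [2.0]], dtype=np.float32)
#   avg_pooled = frames.sum(axis=1) / num_frames    # -> [[1, 1, 1], [2, 2, 2]]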
88 |
89 | class DbofModel(models.BaseModel):
90 | """Creates a Deep Bag of Frames model.
91 |
92 | The model projects the features for each frame into a higher dimensional
93 | 'clustering' space, pools across frames in that space, and then
94 | uses a configurable video-level model to classify the now aggregated features.
95 |
96 | The model will randomly sample either frames or sequences of frames during
97 | training to speed up convergence.
98 | """
99 |
100 | ACT_FN_MAP = {
101 | "sigmoid": tf.nn.sigmoid,
102 | "relu6": tf.nn.relu6,
103 | }
104 |
105 | def create_model(self,
106 | model_input,
107 | vocab_size,
108 | num_frames,
109 | iterations=None,
110 | add_batch_norm=None,
111 | sample_random_frames=None,
112 | cluster_size=None,
113 | hidden_size=None,
114 | is_training=True,
115 | **unused_params):
116 | """See base class.
117 |
118 | Args:
119 | model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
120 | input features.
121 | vocab_size: The number of classes in the dataset.
122 | num_frames: A vector of length 'batch' which indicates the number of
123 | frames for each video (before padding).
124 | iterations: the number of frames to be sampled.
125 | add_batch_norm: whether to add batch norm during training.
126 | sample_random_frames: whether to sample random frames or random sequences.
127 | cluster_size: the output neuron number of the cluster layer.
128 | hidden_size: the output neuron number of the hidden layer.
129 | is_training: whether to build the graph in training mode.
130 |
131 | Returns:
132 | A dictionary with a tensor containing the probability predictions of the
133 | model in the 'predictions' key. The dimensions of the tensor are
134 | 'batch_size' x 'num_classes'.
135 | """
136 | iterations = iterations or FLAGS.iterations
137 | add_batch_norm = add_batch_norm or FLAGS.dbof_add_batch_norm
138 | random_frames = sample_random_frames or FLAGS.sample_random_frames
139 | cluster_size = cluster_size or FLAGS.dbof_cluster_size
140 | hidden1_size = hidden_size or FLAGS.dbof_hidden_size
141 | act_fn = self.ACT_FN_MAP.get(FLAGS.dbof_activation)
142 | assert act_fn is not None, ("dbof_activation is not valid: %s." %
143 | FLAGS.dbof_activation)
144 |
145 | num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32)
146 | if random_frames:
147 | model_input = utils.SampleRandomFrames(model_input, num_frames,
148 | iterations)
149 | else:
150 | model_input = utils.SampleRandomSequence(model_input, num_frames,
151 | iterations)
152 | max_frames = model_input.get_shape().as_list()[1]
153 | feature_size = model_input.get_shape().as_list()[2]
154 | reshaped_input = tf.reshape(model_input, [-1, feature_size])
155 | tf.compat.v1.summary.histogram("input_hist", reshaped_input)
156 |
157 | if add_batch_norm:
158 | reshaped_input = slim.batch_norm(reshaped_input,
159 | center=True,
160 | scale=True,
161 | is_training=is_training,
162 | scope="input_bn")
163 |
164 | cluster_weights = tf.compat.v1.get_variable(
165 | "cluster_weights", [feature_size, cluster_size],
166 | initializer=tf.random_normal_initializer(stddev=1 /
167 | math.sqrt(feature_size)))
168 | tf.compat.v1.summary.histogram("cluster_weights", cluster_weights)
169 | activation = tf.matmul(reshaped_input, cluster_weights)
170 | if add_batch_norm:
171 | activation = slim.batch_norm(activation,
172 | center=True,
173 | scale=True,
174 | is_training=is_training,
175 | scope="cluster_bn")
176 | else:
177 | cluster_biases = tf.compat.v1.get_variable(
178 | "cluster_biases", [cluster_size],
179 | initializer=tf.random_normal_initializer(stddev=1 /
180 | math.sqrt(feature_size)))
181 | tf.compat.v1.summary.histogram("cluster_biases", cluster_biases)
182 | activation += cluster_biases
183 | activation = act_fn(activation)
184 | tf.compat.v1.summary.histogram("cluster_output", activation)
185 |
186 | activation = tf.reshape(activation, [-1, max_frames, cluster_size])
187 | activation = utils.FramePooling(activation, FLAGS.dbof_pooling_method)
188 |
189 | hidden1_weights = tf.compat.v1.get_variable(
190 | "hidden1_weights", [cluster_size, hidden1_size],
191 | initializer=tf.random_normal_initializer(stddev=1 /
192 | math.sqrt(cluster_size)))
193 | tf.compat.v1.summary.histogram("hidden1_weights", hidden1_weights)
194 | activation = tf.matmul(activation, hidden1_weights)
195 | if add_batch_norm:
196 | activation = slim.batch_norm(activation,
197 | center=True,
198 | scale=True,
199 | is_training=is_training,
200 | scope="hidden1_bn")
201 | else:
202 | hidden1_biases = tf.compat.v1.get_variable(
203 | "hidden1_biases", [hidden1_size],
204 | initializer=tf.random_normal_initializer(stddev=0.01))
205 | tf.compat.v1.summary.histogram("hidden1_biases", hidden1_biases)
206 | activation += hidden1_biases
207 | activation = act_fn(activation)
208 | tf.compat.v1.summary.histogram("hidden1_output", activation)
209 |
210 | aggregated_model = getattr(video_level_models,
211 | FLAGS.video_level_classifier_model)
212 | return aggregated_model().create_model(model_input=activation,
213 | vocab_size=vocab_size,
214 | **unused_params)
215 |
216 |
217 | class LstmModel(models.BaseModel):
218 | """Creates a model which uses a stack of LSTMs to represent the video."""
219 |
220 | def create_model(self, model_input, vocab_size, num_frames, **unused_params):
221 | """See base class.
222 |
223 | Args:
224 | model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
225 | input features.
226 | vocab_size: The number of classes in the dataset.
227 | num_frames: A vector of length 'batch' which indicates the number of
228 | frames for each video (before padding).
229 |
230 | Returns:
231 | A dictionary with a tensor containing the probability predictions of the
232 | model in the 'predictions' key. The dimensions of the tensor are
233 | 'batch_size' x 'num_classes'.
234 | """
235 | lstm_size = FLAGS.lstm_cells
236 | number_of_layers = FLAGS.lstm_layers
237 |
238 | stacked_lstm = tf.contrib.rnn.MultiRNNCell([
239 | tf.contrib.rnn.BasicLSTMCell(lstm_size, forget_bias=1.0)
240 | for _ in range(number_of_layers)
241 | ])
242 |
243 | _, state = tf.nn.dynamic_rnn(stacked_lstm,
244 | model_input,
245 | sequence_length=num_frames,
246 | dtype=tf.float32)
247 |
248 | aggregated_model = getattr(video_level_models,
249 | FLAGS.video_level_classifier_model)
250 |
251 | return aggregated_model().create_model(model_input=state[-1].h,
252 | vocab_size=vocab_size,
253 | **unused_params)
254 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/readers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Google Inc. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS-IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Provides readers configured for different datasets."""
15 |
16 | import tensorflow as tf
17 | import utils
18 |
19 |
20 | def resize_axis(tensor, axis, new_size, fill_value=0):
21 |   """Truncates or pads a tensor to new_size on a given axis.
22 |
23 | Truncate or extend tensor such that tensor.shape[axis] == new_size. If the
24 | size increases, the padding will be performed at the end, using fill_value.
25 |
26 | Args:
27 | tensor: The tensor to be resized.
28 | axis: An integer representing the dimension to be sliced.
29 | new_size: An integer or 0d tensor representing the new value for
30 | tensor.shape[axis].
31 | fill_value: Value to use to fill any new entries in the tensor. Will be cast
32 | to the type of tensor.
33 |
34 | Returns:
35 | The resized tensor.
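
  Example (illustrative):
    resize_axis([[1, 2], [3, 4]], axis=0, new_size=3)
    # => [[1, 2], [3, 4], [0, 0]]  (padded at the end with fill_value)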
36 | """
37 | tensor = tf.convert_to_tensor(tensor)
38 | shape = tf.unstack(tf.shape(tensor))
39 |
40 | pad_shape = shape[:]
41 | pad_shape[axis] = tf.maximum(0, new_size - shape[axis])
42 |
43 | shape[axis] = tf.minimum(shape[axis], new_size)
44 | shape = tf.stack(shape)
45 |
46 | resized = tf.concat([
47 | tf.slice(tensor, tf.zeros_like(shape), shape),
48 | tf.fill(tf.stack(pad_shape), tf.cast(fill_value, tensor.dtype))
49 | ], axis)
50 |
51 | # Update shape.
52 | new_shape = tensor.get_shape().as_list() # A copy is being made.
53 | new_shape[axis] = new_size
54 | resized.set_shape(new_shape)
55 | return resized
56 |
57 |
58 | class BaseReader(object):
59 | """Inherit from this class when implementing new readers."""
60 |
61 | def prepare_reader(self, unused_filename_queue):
62 | """Create a thread for generating prediction and label tensors."""
63 | raise NotImplementedError()
64 |
65 |
66 | class YT8MAggregatedFeatureReader(BaseReader):
67 | """Reads TFRecords of pre-aggregated Examples.
68 |
69 | The TFRecords must contain Examples with a sparse int64 'labels' feature and
70 |   a fixed length float32 feature, obtained from the features in 'feature_names'.
71 | The float features are assumed to be an average of dequantized values.
72 | """
73 |
74 | def __init__( # pylint: disable=dangerous-default-value
75 | self,
76 | num_classes=3862,
77 | feature_sizes=[1024, 128],
78 | feature_names=["mean_rgb", "mean_audio"]):
79 | """Construct a YT8MAggregatedFeatureReader.
80 |
81 | Args:
82 | num_classes: a positive integer for the number of classes.
83 | feature_sizes: positive integer(s) for the feature dimensions as a list.
84 | feature_names: the feature name(s) in the tensorflow record as a list.
85 | """
86 |
87 | assert len(feature_names) == len(feature_sizes), (
88 | "length of feature_names (={}) != length of feature_sizes (={})".format(
89 | len(feature_names), len(feature_sizes)))
90 |
91 | self.num_classes = num_classes
92 | self.feature_sizes = feature_sizes
93 | self.feature_names = feature_names
94 |
95 | def prepare_reader(self, filename_queue, batch_size=1024):
96 | """Creates a single reader thread for pre-aggregated YouTube 8M Examples.
97 |
98 | Args:
99 | filename_queue: A tensorflow queue of filename locations.
100 | batch_size: batch size used for feature output.
101 |
102 | Returns:
103 | A dict of video indexes, features, labels, and frame counts.
104 | """
105 | reader = tf.TFRecordReader()
106 | _, serialized_examples = reader.read_up_to(filename_queue, batch_size)
107 |
108 | tf.add_to_collection("serialized_examples", serialized_examples)
109 | return self.prepare_serialized_examples(serialized_examples)
110 |
111 | def prepare_serialized_examples(self, serialized_examples):
112 | """Parse a single video-level TF Example."""
113 | # set the mapping from the fields to data types in the proto
114 | num_features = len(self.feature_names)
115 | assert num_features > 0, "self.feature_names is empty!"
116 | assert len(self.feature_names) == len(self.feature_sizes), \
117 | "length of feature_names (={}) != length of feature_sizes (={})".format(
118 | len(self.feature_names), len(self.feature_sizes))
119 |
120 | feature_map = {
121 | "id": tf.io.FixedLenFeature([], tf.string),
122 | "labels": tf.io.VarLenFeature(tf.int64)
123 | }
124 | for feature_index in range(num_features):
125 | feature_map[self.feature_names[feature_index]] = tf.FixedLenFeature(
126 | [self.feature_sizes[feature_index]], tf.float32)
127 |
128 | features = tf.parse_example(serialized_examples, features=feature_map)
129 | labels = tf.sparse_to_indicator(features["labels"], self.num_classes)
130 | labels.set_shape([None, self.num_classes])
131 | concatenated_features = tf.concat(
132 | [features[feature_name] for feature_name in self.feature_names], 1)
133 |
134 | output_dict = {
135 | "video_ids": features["id"],
136 | "video_matrix": concatenated_features,
137 | "labels": labels,
138 | "num_frames": tf.ones([tf.shape(serialized_examples)[0]])
139 | }
140 |
141 | return output_dict
142 |
143 |
144 | class YT8MFrameFeatureReader(BaseReader):
145 | """Reads TFRecords of SequenceExamples.
146 |
147 |   The TFRecords must contain SequenceExamples with the sparse int64 'labels'
148 | context feature and a fixed length byte-quantized feature vector, obtained
149 | from the features in 'feature_names'. The quantized features will be mapped
150 | back into a range between min_quantized_value and max_quantized_value.
151 | """
152 |
153 | def __init__( # pylint: disable=dangerous-default-value
154 | self,
155 | num_classes=3862,
156 | feature_sizes=[1024, 128],
157 | feature_names=["rgb", "audio"],
158 | max_frames=300,
159 | segment_labels=False,
160 | segment_size=5):
161 | """Construct a YT8MFrameFeatureReader.
162 |
163 | Args:
164 | num_classes: a positive integer for the number of classes.
165 | feature_sizes: positive integer(s) for the feature dimensions as a list.
166 | feature_names: the feature name(s) in the tensorflow record as a list.
167 | max_frames: the maximum number of frames to process.
168 | segment_labels: if we read segment labels instead.
169 | segment_size: the segment_size used for reading segments.
170 | """
171 |
172 | assert len(feature_names) == len(feature_sizes), (
173 | "length of feature_names (={}) != length of feature_sizes (={})".format(
174 | len(feature_names), len(feature_sizes)))
175 |
176 | self.num_classes = num_classes
177 | self.feature_sizes = feature_sizes
178 | self.feature_names = feature_names
179 | self.max_frames = max_frames
180 | self.segment_labels = segment_labels
181 | self.segment_size = segment_size
182 |
183 | def get_video_matrix(self, features, feature_size, max_frames,
184 | max_quantized_value, min_quantized_value):
185 | """Decodes features from an input string and quantizes it.
186 |
187 | Args:
188 | features: raw feature values
189 | feature_size: length of each frame feature vector
190 | max_frames: number of frames (rows) in the output feature_matrix
191 | max_quantized_value: the maximum of the quantized value.
192 | min_quantized_value: the minimum of the quantized value.
193 |
194 | Returns:
195 | feature_matrix: matrix of all frame-features
196 | num_frames: number of frames in the sequence
197 | """
198 | decoded_features = tf.reshape(
199 | tf.cast(tf.decode_raw(features, tf.uint8), tf.float32),
200 | [-1, feature_size])
201 |
202 | num_frames = tf.minimum(tf.shape(decoded_features)[0], max_frames)
203 | feature_matrix = utils.Dequantize(decoded_features, max_quantized_value,
204 | min_quantized_value)
205 | feature_matrix = resize_axis(feature_matrix, 0, max_frames)
206 | return feature_matrix, num_frames
207 |
208 | def prepare_reader(self,
209 | filename_queue,
210 | max_quantized_value=2,
211 | min_quantized_value=-2):
212 | """Creates a single reader thread for YouTube8M SequenceExamples.
213 |
214 | Args:
215 | filename_queue: A tensorflow queue of filename locations.
216 | max_quantized_value: the maximum of the quantized value.
217 | min_quantized_value: the minimum of the quantized value.
218 |
219 | Returns:
220 | A dict of video indexes, video features, labels, and frame counts.
221 | """
222 | reader = tf.TFRecordReader()
223 | _, serialized_example = reader.read(filename_queue)
224 |
225 | return self.prepare_serialized_examples(serialized_example,
226 | max_quantized_value,
227 | min_quantized_value)
228 |
229 | def prepare_serialized_examples(self,
230 | serialized_example,
231 | max_quantized_value=2,
232 | min_quantized_value=-2):
233 | """Parse single serialized SequenceExample from the TFRecords."""
234 |
235 | # Read/parse frame/segment-level labels.
236 | context_features = {
237 | "id": tf.io.FixedLenFeature([], tf.string),
238 | }
239 | if self.segment_labels:
240 | context_features.update({
241 |           # There is no need to read the end time, since we assume all segments
242 |           # have the same size.
243 | "segment_labels": tf.io.VarLenFeature(tf.int64),
244 | "segment_start_times": tf.io.VarLenFeature(tf.int64),
245 | "segment_scores": tf.io.VarLenFeature(tf.float32)
246 | })
247 | else:
248 | context_features.update({"labels": tf.io.VarLenFeature(tf.int64)})
249 | sequence_features = {
250 | feature_name: tf.io.FixedLenSequenceFeature([], dtype=tf.string)
251 | for feature_name in self.feature_names
252 | }
253 | contexts, features = tf.io.parse_single_sequence_example(
254 | serialized_example,
255 | context_features=context_features,
256 | sequence_features=sequence_features)
257 |
258 | # loads (potentially) different types of features and concatenates them
259 | num_features = len(self.feature_names)
260 | assert num_features > 0, "No feature selected: feature_names is empty!"
261 |
262 | assert len(self.feature_names) == len(self.feature_sizes), (
263 | "length of feature_names (={}) != length of feature_sizes (={})".format(
264 | len(self.feature_names), len(self.feature_sizes)))
265 |
266 | num_frames = -1 # the number of frames in the video
267 | feature_matrices = [None] * num_features # an array of different features
268 | for feature_index in range(num_features):
269 | feature_matrix, num_frames_in_this_feature = self.get_video_matrix(
270 | features[self.feature_names[feature_index]],
271 | self.feature_sizes[feature_index], self.max_frames,
272 | max_quantized_value, min_quantized_value)
273 | if num_frames == -1:
274 | num_frames = num_frames_in_this_feature
275 |
276 | feature_matrices[feature_index] = feature_matrix
277 |
278 | # cap the number of frames at self.max_frames
279 | num_frames = tf.minimum(num_frames, self.max_frames)
280 |
281 | # concatenate different features
282 | video_matrix = tf.concat(feature_matrices, 1)
283 |
284 | # Partition frame-level feature matrix to segment-level feature matrix.
285 | if self.segment_labels:
286 | start_times = contexts["segment_start_times"].values
287 |       # Here we assume all the segments that start at the same start time have
288 |       # the same segment_size.
289 | uniq_start_times, seg_idxs = tf.unique(start_times,
290 | out_idx=tf.dtypes.int64)
291 | # TODO(zhengxu): Ensure the segment_sizes are all same.
292 | segment_size = self.segment_size
293 | # Range gather matrix, e.g., [[0,1,2],[1,2,3]] for segment_size == 3.
294 | range_mtx = tf.expand_dims(uniq_start_times, axis=-1) + tf.expand_dims(
295 | tf.range(0, segment_size, dtype=tf.int64), axis=0)
296 | # Shape: [num_segment, segment_size, feature_dim].
297 | batch_video_matrix = tf.gather_nd(video_matrix,
298 | tf.expand_dims(range_mtx, axis=-1))
299 | num_segment = tf.shape(batch_video_matrix)[0]
300 | batch_video_ids = tf.reshape(tf.tile([contexts["id"]], [num_segment]),
301 | (num_segment,))
302 | batch_frames = tf.reshape(tf.tile([segment_size], [num_segment]),
303 | (num_segment,))
304 |
305 |       # For segment labels, not all labels are exhaustively rated, so we only
306 |       # evaluate the rated labels.
307 |
308 | # Label indices for each segment, shape: [num_segment, 2].
309 | label_indices = tf.stack([seg_idxs, contexts["segment_labels"].values],
310 | axis=-1)
311 | label_values = contexts["segment_scores"].values
312 | sparse_labels = tf.sparse.SparseTensor(label_indices, label_values,
313 | (num_segment, self.num_classes))
314 | batch_labels = tf.sparse.to_dense(sparse_labels, validate_indices=False)
315 |
316 | sparse_label_weights = tf.sparse.SparseTensor(
317 | label_indices, tf.ones_like(label_values, dtype=tf.float32),
318 | (num_segment, self.num_classes))
319 | batch_label_weights = tf.sparse.to_dense(sparse_label_weights,
320 | validate_indices=False)
321 | else:
322 | # Process video-level labels.
323 | label_indices = contexts["labels"].values
324 | sparse_labels = tf.sparse.SparseTensor(
325 | tf.expand_dims(label_indices, axis=-1),
326 | tf.ones_like(contexts["labels"].values, dtype=tf.bool),
327 | (self.num_classes,))
328 | labels = tf.sparse.to_dense(sparse_labels,
329 | default_value=False,
330 | validate_indices=False)
331 | # convert to batch format.
332 | batch_video_ids = tf.expand_dims(contexts["id"], 0)
333 | batch_video_matrix = tf.expand_dims(video_matrix, 0)
334 | batch_labels = tf.expand_dims(labels, 0)
335 | batch_frames = tf.expand_dims(num_frames, 0)
336 | batch_label_weights = None
337 |
338 | output_dict = {
339 | "video_ids": batch_video_ids,
340 | "video_matrix": batch_video_matrix,
341 | "labels": batch_labels,
342 | "num_frames": batch_frames,
343 | }
344 | if batch_label_weights is not None:
345 | output_dict["label_weights"] = batch_label_weights
346 |
347 | return output_dict
348 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # YouTube-8M Tensorflow Starter Code
2 |
3 | This repo contains starter code for training and evaluating machine learning
4 | models over the [YouTube-8M](https://research.google.com/youtube8m/) dataset.
5 | This is the starter code for our
6 | [3rd Youtube8M Video Understanding Challenge on Kaggle](https://www.kaggle.com/c/youtube8m-2019)
7 | and is part of a selected workshop session at the International Conference on
8 | Computer Vision (ICCV) 2019. The code gives an end-to-end working example for reading the
9 | dataset, training a TensorFlow model, and evaluating the performance of the
10 | model.
11 |
12 | ## Table of Contents
13 |
14 | * [Running on Your Own Machine](#running-on-your-own-machine)
15 | * [Requirements](#requirements)
16 | * [Download Dataset Locally](#download-dataset-locally)
17 | * [Try the starter code](#try-the-starter-code)
18 | * [Train video-level model on frame-level features and inference at
19 | segment-level.](#train-video-level-model-on-frame-level-features-and-inference-at-segment-level)
20 | * [Tensorboard](#tensorboard)
21 | * [Using GPUs](#using-gpus)
22 | * [Running on Google's Cloud Machine Learning Platform](#running-on-googles-cloud-machine-learning-platform)
23 | * [Requirements](#requirements-1)
24 | * [Accessing Files on Google Cloud](#accessing-files-on-google-cloud)
25 | * [Testing Locally](#testing-locally)
26 | * [Training on the Cloud over Frame-Level Features](#training-on-the-cloud-over-frame-level-features)
27 | * [Evaluation and Inference](#evaluation-and-inference)
28 | * [Create Your Own Dataset Files](#create-your-own-dataset-files)
29 | * [Training without this Starter Code](#training-without-this-starter-code)
30 | * [Export Your Model for MediaPipe Inference](#export-your-model-for-mediapipe-inference)
31 | * [More Documents](#more-documents)
32 | * [About This Project](#about-this-project)
33 |
34 | ## Running on Your Own Machine
35 |
36 | ### Requirements
37 |
38 | The starter code requires Tensorflow. If you haven't installed it yet, follow
39 | the instructions on [tensorflow.org](https://www.tensorflow.org/install/). This
40 | code has been tested with Tensorflow 1.14. Going forward, we will continue to
41 | target the latest released version of Tensorflow.
42 |
43 | Please verify that you have Python 3.6+ and Tensorflow 1.14 or higher installed
44 | by running the following commands:
45 |
46 | ```sh
47 | python --version
48 | python -c 'import tensorflow as tf; print(tf.__version__)'
49 | ```
50 |
51 | ### Download Dataset Locally
52 |
53 | Please see our
54 | [dataset website](https://research.google.com/youtube8m/download.html) for
55 | up-to-date download instructions.
56 |
57 | In this document, we assume you have downloaded the frame-level feature dataset to
58 | `~/yt8m/2/frame` and the segment-level validation/test dataset to `~/yt8m/3/frame`,
59 | so the directory structure looks like this:
60 |
61 | ```
62 | ~/yt8m/
63 | - ~/yt8m/2/frame/
64 | - ~/yt8m/2/frame/train
65 | - ~/yt8m/3/frame/
66 | - ~/yt8m/3/frame/test
67 | - ~/yt8m/3/frame/validate
68 | ```
69 |
70 | ### Try the starter code
71 |
72 | Clone this git repo: `mkdir -p ~/yt8m/code && cd ~/yt8m/code && git clone
73 | https://github.com/google/youtube-8m.git`
74 |
75 | #### Train video-level model on frame-level features and inference at segment-level.
76 |
77 | Train using `train.py`, selecting a frame-level model (e.g.
78 | `FrameLevelLogisticModel`) and instructing the trainer to use
79 | `--frame_features`. In short: frame-level features are stored in a compressed,
80 | quantized form, and this flag tells the pipeline to read and uncompress them.
81 |
82 | ```bash
83 | python train.py --frame_features --model=FrameLevelLogisticModel \
84 | --feature_names='rgb,audio' --feature_sizes='1024,128' \
85 | --train_data_pattern=${HOME}/yt8m/2/frame/train/train*.tfrecord \
86 | --train_dir ~/yt8m/models/frame/sample_model --start_new_model
87 | ```
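
For the curious: the "uncompression" is a per-value dequantization of the
byte-encoded frame features back into a small float range (`readers.py` applies
`utils.Dequantize` with defaults of -2 to 2). A rough sketch of that mapping is
shown below; the 255/512 constants are assumptions for illustration, not a spec.

```python
# Illustrative sketch of dequantizing byte-quantized frame features (uint8)
# back to floats in [min_quantized_value, max_quantized_value].
def dequantize(features, max_quantized_value=2.0, min_quantized_value=-2.0):
  quantized_range = max_quantized_value - min_quantized_value
  scalar = quantized_range / 255.0                         # spread 0..255 over the range
  bias = (quantized_range / 512.0) + min_quantized_value   # center each quantization bucket
  return features * scalar + bias
```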
88 |
89 | Evaluate the model by running:
90 |
91 | ```bash
92 | python eval.py \
93 | --eval_data_pattern=${HOME}/yt8m/3/frame/validate/validate*.tfrecord \
94 | --train_dir ~/yt8m/models/frame/sample_model --segment_labels
95 | ```
96 |
97 | This will report comprehensive metrics, such as gAP and mAP, for your
98 | models.
99 |
100 | Produce a submission CSV (`kaggle_solution.csv`) by running inference:
101 |
102 | ```bash
103 | python inference.py \
104 | --train_dir ~/yt8m/models/frame/sample_model \
105 | --output_file=$HOME/tmp/kaggle_solution.csv \
106 | --input_data_pattern=${HOME}/yt8m/3/frame/test/test*.tfrecord --segment_labels
107 | ```
108 |
109 | (Optional) If you wish to see how the models are evaluated in the Kaggle system,
110 | you can do so by running the following two commands:
111 |
112 | ```bash
113 | python inference.py --train_dir ~/yt8m/models/frame/sample_model \
114 | --output_file=$HOME/tmp/kaggle_solution_validation.csv \
115 | --input_data_pattern=${HOME}/yt8m/3/frame/validate/validate*.tfrecord \
116 | --segment_labels
117 | ```
118 |
119 | ```bash
120 | python segment_eval_inference.py \
121 | --eval_data_pattern=${HOME}/yt8m/3/frame/validate/validate*.tfrecord \
122 | --label_cache=$HOME/tmp/validate.label_cache \
123 | --submission_file=$HOME/tmp/kaggle_solution_validation.csv --top_n=100000
124 | ```
125 |
126 | **NOTE**: This script can be slow the first time it runs, because it reads the
127 | TFRecord data and builds a label cache. Once the label cache is built,
128 | subsequent evaluations are much faster.
129 |
130 | #### Tensorboard
131 |
132 | You can use Tensorboard to compare your frame-level or video-level models, like:
133 |
134 | ```sh
135 | MODELS_DIR=~/yt8m/models
136 | tensorboard --logdir frame:${MODELS_DIR}/frame
137 | ```
138 |
139 | We find it useful to keep a Tensorboard instance running at all times as we
140 | train and evaluate different models.
141 |
142 | #### Using GPUs
143 |
144 | If your Tensorflow installation has GPU support, e.g., installed with `pip
145 | install tensorflow-gpu`, this code will make use of your compatible GPU(s).
146 | You can verify your installation by running:
147 |
148 | ```
149 | python -c 'import tensorflow as tf; tf.Session()'
150 | ```
151 |
152 | This will print out something like the following for each of your compatible
153 | GPUs.
154 |
155 | ```
156 | I tensorflow/core/common_runtime/gpu/gpu_init.cc:102] Found device 0 with properties:
157 | name: Tesla M40
158 | major: 5 minor: 2 memoryClockRate (GHz) 1.112
159 | pciBusID 0000:04:00.0
160 | Total memory: 11.25GiB
161 | Free memory: 11.09GiB
162 | ...
163 | ```
164 |
165 | If at least one GPU was found, the forward and backward passes will be computed
166 | with the GPUs, whereas the CPU will be used primarily for the input and output
167 | pipelines. If you have multiple GPUs, the current default behavior is to use
168 | only one of them.
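
If you have multiple GPUs and want to control which one is used, one common
approach (not a flag provided by this starter code) is to expose a single
device via `CUDA_VISIBLE_DEVICES` before TensorFlow initializes. A minimal
sketch, assuming you want GPU 0:

```python
# Restrict TensorFlow to GPU 0 by setting CUDA_VISIBLE_DEVICES before
# TensorFlow is imported; the device index here is an assumption.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf

with tf.Session() as sess:
  print(sess.list_devices())  # should list only /device:GPU:0 plus the CPU
```

You can equally export the variable in your shell before launching `train.py`.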
169 |
170 |
171 | ## Running on Google's Cloud Machine Learning Platform
172 |
173 | ### Requirements
174 |
175 | This option requires you to have an appropriately configured Google Cloud
176 | Platform account. To create and configure your account, please make sure you
177 | follow the instructions
178 | [here](https://cloud.google.com/ml/docs/how-tos/getting-set-up).
179 |
180 | Please also verify that you have Python 3.6+ and Tensorflow 1.14 or higher
181 | installed by running the following commands:
182 |
183 | ```sh
184 | python --version
185 | python -c 'import tensorflow as tf; print(tf.__version__)'
186 | ```
187 |
188 | ### Accessing Files on Google Cloud
189 |
190 | You can browse the storage buckets you created on Google Cloud, for example, to
191 | access the trained models, prediction CSV files, etc. by visiting the
192 | [Google Cloud storage browser](https://console.cloud.google.com/storage/browser).
193 |
194 | Alternatively, you can use the 'gsutil' command to download the files directly.
195 | For example, to download the output of the inference code from the previous
196 | section to your local machine, run:
197 |
198 | ```
199 | gsutil cp $BUCKET_NAME/${JOB_TO_EVAL}/predictions.csv .
200 | ```
201 |
202 | ### Testing Locally
203 |
204 | All gcloud commands should be done from the directory *immediately above* the
205 | source code. You should be able to see the source code directory if you run
206 | 'ls'.
207 |
208 | As you are developing your own models, you will want to test them quickly to
209 | flush out simple problems without having to submit them to the cloud.
210 |
211 | Here is an example command line for frame-level training:
212 |
213 | ```sh
214 | gcloud ai-platform local train \
215 | --package-path=youtube-8m --module-name=youtube-8m.train -- \
216 | --train_data_pattern='gs://youtube8m-ml/2/frame/train/train*.tfrecord' \
217 | --train_dir=/tmp/yt8m_train --frame_features --model=FrameLevelLogisticModel \
218 | --feature_names='rgb,audio' --feature_sizes='1024,128' --start_new_model
219 | ```
220 |
221 | ### Training on the Cloud over Frame-Level Features
222 |
223 | The following commands will train a model on Google Cloud over frame-level
224 | features.
225 |
226 | ```bash
227 | BUCKET_NAME=gs://${USER}_yt8m_train_bucket
228 | # (One Time) Create a storage bucket to store training logs and checkpoints.
229 | gsutil mb -l us-east1 $BUCKET_NAME
230 | # Submit the training job.
231 | JOB_NAME=yt8m_train_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug ai-platform jobs \
232 | submit training $JOB_NAME \
233 | --package-path=youtube-8m --module-name=youtube-8m.train \
234 | --staging-bucket=$BUCKET_NAME --region=us-east1 \
235 | --config=youtube-8m/cloudml-gpu.yaml \
236 | -- --train_data_pattern='gs://youtube8m-ml/2/frame/train/train*.tfrecord' \
237 | --frame_features --model=FrameLevelLogisticModel \
238 | --feature_names='rgb,audio' --feature_sizes='1024,128' \
239 | --train_dir=$BUCKET_NAME/yt8m_train_frame_level_logistic_model --start_new_model
240 | ```
241 |
242 | In the 'gcloud' command above, the 'package-path' flag refers to the directory
243 | containing the 'train.py' script and more generally the python package which
244 | should be deployed to the cloud worker. The module-name refers to the specific
245 | python script which should be executed (in this case the train module).
246 |
247 | It may take several minutes before the job starts running on Google Cloud. When
248 | it starts you will see outputs like the following:
249 |
250 | ```
251 | training step 270| Hit@1: 0.68 PERR: 0.52 Loss: 638.453
252 | training step 271| Hit@1: 0.66 PERR: 0.49 Loss: 635.537
253 | training step 272| Hit@1: 0.70 PERR: 0.52 Loss: 637.564
254 | ```
255 |
256 | At this point you can disconnect your console by pressing "ctrl-c". The model
257 | will continue to train indefinitely in the Cloud. Later, you can check on its
258 | progress or halt the job by visiting the
259 | [Google Cloud ML Jobs console](https://console.cloud.google.com/ml/jobs).
260 |
261 | You can train many jobs at once and use tensorboard to compare their performance
262 | visually.
263 |
264 | ```sh
265 | tensorboard --logdir=$BUCKET_NAME --port=8080
266 | ```
267 |
268 | Once tensorboard is running, you can access it at the following url:
269 | [http://localhost:8080](http://localhost:8080). If you are using Google Cloud
270 | Shell, you can instead click the Web Preview button on the upper left corner of
271 | the Cloud Shell window and select "Preview on port 8080". This will bring up a
272 | new browser tab with the Tensorboard view.
273 |
274 | ### Evaluation and Inference
275 |
276 | Here's how to evaluate a model on the validation dataset:
277 |
278 | ```sh
279 | JOB_TO_EVAL=yt8m_train_frame_level_logistic_model
280 | JOB_NAME=yt8m_eval_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug ai-platform jobs \
281 | submit training $JOB_NAME \
282 | --package-path=youtube-8m --module-name=youtube-8m.eval \
283 | --staging-bucket=$BUCKET_NAME --region=us-east1 \
284 | --config=youtube-8m/cloudml-gpu.yaml \
285 | -- --eval_data_pattern='gs://youtube8m-ml/3/frame/validate/validate*.tfrecord' \
286 | --frame_features --model=FrameLevelLogisticModel --feature_names='rgb,audio' \
287 | --feature_sizes='1024,128' --train_dir=$BUCKET_NAME/${JOB_TO_EVAL} \
288 | --segment_labels --run_once=True
289 | ```
290 |
291 | And here's how to perform inference with a model on the test set:
292 |
293 | ```sh
294 | JOB_TO_EVAL=yt8m_train_frame_level_logistic_model
295 | JOB_NAME=yt8m_inference_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug ai-platform jobs \
296 | submit training $JOB_NAME \
297 | --package-path=youtube-8m --module-name=youtube-8m.inference \
298 | --staging-bucket=$BUCKET_NAME --region=us-east1 \
299 | --config=youtube-8m/cloudml-gpu.yaml \
300 | -- --input_data_pattern='gs://youtube8m-ml/3/frame/test/test*.tfrecord' \
301 | --train_dir=$BUCKET_NAME/${JOB_TO_EVAL} --segment_labels \
302 | --output_file=$BUCKET_NAME/${JOB_TO_EVAL}/predictions.csv
303 | ```
304 |
305 | Note the confusing use of 'training' in the above gcloud commands. Despite the
306 | name, the 'training' argument really just offers a cloud hosted
307 | python/tensorflow service. From the point of view of the Cloud Platform, there
308 | is no distinction between our training and inference jobs. The Cloud ML platform
309 | also offers specialized functionality for prediction with Tensorflow models, but
310 | discussing that is beyond the scope of this readme.
311 |
312 | Once these jobs start executing, you will see outputs similar to the following
313 | for the evaluation code:
314 |
315 | ```
316 | examples_processed: 1024 | global_step 447044 | Batch Hit@1: 0.782 | Batch PERR: 0.637 | Batch Loss: 7.821 | Examples_per_sec: 834.658
317 | ```
318 |
319 | and the following for the inference code:
320 |
321 | ```
322 | num examples processed: 8192 elapsed seconds: 14.85
323 | ```
324 |
325 | ## Export Your Model for MediaPipe Inference
326 | To run inference with your model in the [MediaPipe inference
327 | demo](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/youtube8m#steps-to-run-the-youtube-8m-inference-graph-with-the-yt8m-dataset), you need to export your checkpoint to a SavedModel.
328 |
329 | Example command:
330 | ```sh
331 | python export_model_mediapipe.py --checkpoint_file ~/yt8m/models/frame/sample_model/inference_model/segment_inference_model --output_dir /tmp/mediapipe/saved_model/
332 | ```
333 |
334 |
335 | ## Create Your Own Dataset Files
336 |
337 | You can create your dataset files from your own videos. Our
338 | [feature extractor](./feature_extractor) code creates `tfrecord` files,
339 | identical to our dataset files. You can use our starter code to train on the
340 | `tfrecord` files output by the feature extractor. In addition, you can fine-tune
341 | your YouTube-8M models on your new dataset.
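
To sanity-check a generated file, you can iterate over it and confirm it
contains the fields the readers expect. A minimal sketch is below; the file
path is a placeholder, and the `rgb` feature name assumes the default features.

```python
# Inspect one SequenceExample from a generated tfrecord file.
import tensorflow as tf

path = "/tmp/my_video_features.tfrecord"  # placeholder: a file produced by the extractor
for record in tf.python_io.tf_record_iterator(path):
  example = tf.train.SequenceExample.FromString(record)
  print(example.context.feature["id"].bytes_list.value)          # video id
  print(len(example.feature_lists.feature_list["rgb"].feature))  # number of frames
  break
```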
342 |
343 | ## Training without this Starter Code
344 |
345 | You are welcome to use our dataset without using our starter code. However, if
346 | you'd like to compete on Kaggle, then you must make sure that you can produce a
347 | prediction CSV file in the format emitted by our `inference.py`. In particular, the
348 | [predictions CSV file](https://www.kaggle.com/c/youtube8m-2018#evaluation) must
349 | have two fields: `Class Id,Segment Ids` where `Class Id` must be class ids
350 | listed in `segment_label_ids.csv` and `Segment Ids` is a space-delimited list of
351 | `