├── LICENSE ├── README.md ├── data_generator.py ├── data_provider.py ├── emotion_eval.py ├── emotion_train.py ├── inception_processing.py ├── losses.py ├── metrics.py └── models.py
/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, tzirakis 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # End-to-End Multimodal Emotion Recognition using Deep Neural Networks 2 | 3 | This package provides training and evaluation code for the end-to-end multimodal emotion recognition paper. If you use this codebase in your experiments, please cite: 4 | 5 | `P. Tzirakis, G. Trigeorgis, M. A. Nicolaou, B. Schuller and S. Zafeiriou, "End-to-End Multimodal Emotion Recognition using Deep Neural Networks," in IEEE Journal of Selected Topics in Signal Processing, vol. PP, no. 99, pp. 1-1.` (http://ieeexplore.ieee.org/document/8070966/) 6 | 7 | ## UPDATE 8 | ### A PyTorch implementation of this method (along with pretrained models) can be found in our [End2You toolkit](https://github.com/end2you/end2you) 9 | 10 | ## Requirements 11 | The following modules are required to run the code: 12 | 13 | * Python <= 2.7 14 | * NumPy >= 1.11.1 15 | * TensorFlow <= 0.12 16 | * Menpo >= 0.6.2 17 | * MoviePy >= 0.2.2.11 18 | 19 | ## Content 20 | This repository contains the following files: 21 | * models.py: contains the audio and video networks. 22 | * emotion_train.py: trains the model. 23 | * emotion_eval.py: evaluates the model. 24 | * data_provider.py: provides the data to the training and evaluation scripts. 25 | * data_generator.py: creates the tfrecords from the raw RECOLA recordings and annotations. 26 | * metrics.py: contains the concordance metric used for evaluation. 27 | * losses.py: contains the loss function used for training. 28 | * inception_processing.py: provides functions for visual data augmentation.
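The snippet below is a minimal, illustrative sketch (it is not one of the scripts above, which use queue-based input pipelines) of how the modules compose: `models.get_model` builds the audio/visual/recurrent network and `losses.concordance_cc` provides the CCC loss, wired up the same way as in emotion_train.py. The placeholder shapes follow data_provider.py (96x96x3 frames and 640 raw audio samples per 40 ms step); the batch size, sequence length and hidden units are arbitrary example values, and the TF <= 0.12 API used throughout this repository is assumed.

```python
import tensorflow as tf

import losses
import models

# Example shapes only: 2 sequences of 100 consecutive 40 ms steps.
batch_size, seq_length = 2, 100

frames = tf.placeholder(tf.float32, [batch_size, seq_length, 96, 96, 3])
audio = tf.placeholder(tf.float32, [batch_size, seq_length, 640])
labels = tf.placeholder(tf.float32, [batch_size, seq_length, 2])  # arousal, valence

# 'audio', 'video' or 'both': the same values accepted by the scripts' --model flag.
prediction = models.get_model('both')(frames, audio, hidden_units=256)

# One concordance (CCC) loss term per output dimension, averaged as in emotion_train.py.
loss = sum(losses.concordance_cc(tf.reshape(prediction[:, :, i], (-1,)),
                                 tf.reshape(labels[:, :, i], (-1,)))
           for i in range(2)) / 2.
```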
29 | 30 | The multimodal model can be downloaded from here: https://www.doc.ic.ac.uk/~pt511/emotion_recognition_model.zip 31 | --------------------------------------------------------------------------------
/data_generator.py: -------------------------------------------------------------------------------- 1 | import menpo 2 | import tensorflow as tf 3 | import numpy as np 4 | import os 5 | 6 | from io import BytesIO 7 | from pathlib import Path 8 | from moviepy.editor import VideoFileClip 9 | from menpo.visualize import progress_bar_str, print_progress 10 | from moviepy.audio.AudioClip import AudioArrayClip 11 | 12 | 13 | root_dir = Path('path_of_RECOLA') 14 | 15 | # Insert the subject numbers of each partition of the RECOLA dataset. 16 | portion_to_id = dict( 17 | train = [], 18 | valid = [], 19 | test = [] 20 | ) 21 | 22 | def get_samples(subject_id): 23 | arousal_label_path = root_dir / 'Ratings_affective_behaviour_CCC_centred/arousal/{}.csv'.format(subject_id) 24 | valence_label_path = root_dir / 'Ratings_affective_behaviour_CCC_centred/valence/{}.csv'.format(subject_id) 25 | 26 | clip = VideoFileClip(str(root_dir / "Video_recordings_MP4/{}.mp4".format(subject_id))) 27 | 28 | subsampled_audio = clip.audio.set_fps(16000) 29 | # Extract 7500 windows of 40 ms (640 samples at 16 kHz), one per annotation step; the stereo channels are averaged to mono. 30 | audio_frames = [] 31 | for i in range(1, 7501): 32 | time = 0.04 * i 33 | 34 | audio = np.array(list(subsampled_audio.subclip(time - 0.04, time).iter_frames())) 35 | audio = audio.mean(1)[:640] 36 | 37 | audio_frames.append(audio.astype(np.float32)) 38 | 39 | arousal = np.loadtxt(str(arousal_label_path), delimiter=',')[:, 1][1:] 40 | valence = np.loadtxt(str(valence_label_path), delimiter=',')[:, 1][1:] 41 | 42 | return audio_frames, np.dstack([arousal, valence])[0].astype(np.float32) 43 | 44 | def get_jpg_string(im): 45 | # Gets the serialized jpg from a menpo `Image`.
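# Note: this helper is provided for serializing face frames as JPEG features, but it is not
# used by serialize_sample below, which only writes the raw audio and the labels. A 'frame'
# feature (as expected by data_provider.py) would have to be added to each example before
# the video / 'both' models can be trained.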
46 | fp = BytesIO() 47 | menpo.io.export_image(im, fp, extension='jpg') 48 | fp.seek(0) 49 | return fp.read() 50 | 51 | def _int_feature(value): 52 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 53 | 54 | def _bytes_feature(value): 55 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 56 | 57 | def serialize_sample(writer, subject_id): 58 | subject_name = 'P{}'.format(subject_id) 59 | 60 | for i, (audio, label) in enumerate(zip(*get_samples(subject_name))): 61 | 62 | example = tf.train.Example(features=tf.train.Features(feature={ 63 | 'sample_id': _int_feature(i), 64 | 'subject_id': _int_feature(subject_id), 65 | 'label': _bytes_feature(label.tobytes()), 66 | 'raw_audio': _bytes_feature(audio.tobytes()), 67 | })) 68 | 69 | writer.write(example.SerializeToString()) 70 | del audio, label 71 | 72 | def main(directory): 73 | for portion in portion_to_id.keys(): 74 | print(portion) 75 | 76 | for subj_id in print_progress(portion_to_id[portion]): 77 | 78 | writer = tf.python_io.TFRecordWriter( 79 | (directory / 'tf_records' / portion / '{}.tfrecords'.format(subj_id) 80 | ).as_posix()) 81 | serialize_sample(writer, subj_id) 82 | writer.close() 83 | if __name__ == "__main__": 84 | main(Path('path_to_save_tfrecords')) 85 | 86 | --------------------------------------------------------------------------------
/data_provider.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | from pathlib import Path 7 | from inception_processing import distort_color 8 | 9 | 10 | slim = tf.contrib.slim 11 | 12 | def get_split(dataset_dir, is_training=True, split_name='train', batch_size=32, 13 | seq_length=100, debugging=False): 14 | """Returns a data split of the RECOLA dataset, which was saved in tfrecords format. 15 | 16 | Args: 17 | split_name: A train/test/valid split name. 18 | Returns: 19 | The video frames, the raw audio examples, the corresponding 20 | arousal/valence labels, and the subject ids. 21 | """ 22 | 23 | root_path = Path(dataset_dir) / split_name 24 | paths = [str(x) for x in root_path.glob('*.tfrecords')] 25 | 26 | filename_queue = tf.train.string_input_producer(paths, shuffle=is_training) 27 | 28 | reader = tf.TFRecordReader() 29 | 30 | _, serialized_example = reader.read(filename_queue) 31 | 32 | features = tf.parse_single_example( 33 | serialized_example, 34 | features={ 35 | 'raw_audio': tf.FixedLenFeature([], tf.string), 36 | 'label': tf.FixedLenFeature([], tf.string), 37 | 'subject_id': tf.FixedLenFeature([], tf.int64), 38 | 'frame': tf.FixedLenFeature([], tf.string), 39 | } 40 | ) 41 | 42 | raw_audio = tf.decode_raw(features['raw_audio'], tf.float32) 43 | frame = tf.image.decode_jpeg(features['frame']) 44 | label = tf.decode_raw(features['label'], tf.float32) 45 | subject_id = features['subject_id'] 46 | 47 | # 640 samples at 16 kHz correspond to 40 ms, which is the frequency of the 48 | # annotations. 49 | raw_audio.set_shape([640]) 50 | label.set_shape([2]) 51 | frame.set_shape([96, 96, 3]) 52 | frame = tf.cast(frame, tf.float32) / 255. 53 | 54 | if is_training: 55 | resized_image = tf.image.resize_images(frame, [110, 110]) 56 | frame = tf.random_crop(resized_image, [96, 96, 3]) 57 | frame = distort_color(frame, 1) 58 | 59 | # Number of threads should always be one, in order to load samples 60 | # sequentially.
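# Batching is done in two stages: tf.train.batch below first groups `seq_length` consecutive
# 40 ms frames into a single sequence; after expand_dims, a second batch/shuffle_batch call
# then groups `batch_size` such sequences together (shuffled only during training).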
61 | frames, audio_samples, labels, subject_ids = tf.train.batch( 62 | [frame, raw_audio, label, subject_id], seq_length, num_threads=1, capacity=1000) 63 | 64 | 65 | # Assert is an expensive op so we only want to use it when it's a must. 66 | if debugging: 67 | # Asserts that a sequence contains samples from a single subject. 68 | assert_op = tf.Assert( 69 | tf.reduce_all(tf.equal(subject_ids[0], subject_ids)), 70 | [subject_ids]) 71 | 72 | with tf.control_dependencies([assert_op]): 73 | audio_samples = tf.identity(audio_samples) 74 | 75 | audio_samples = tf.expand_dims(audio_samples, 0) 76 | labels = tf.expand_dims(labels, 0) 77 | frames = tf.expand_dims(frames, 0) 78 | 79 | if is_training: 80 | frames, audio_samples, labels, subject_ids = tf.train.shuffle_batch( 81 | [frames, audio_samples, labels, subject_ids], batch_size, 1000, 50, num_threads=1) 82 | else: 83 | frames, audio_samples, labels, subject_ids = tf.train.batch( 84 | [frames, audio_samples, labels, subject_ids], batch_size, num_threads=1, capacity=1000) 85 | 86 | return frames[:, 0, :, :], audio_samples[:, 0, :, :], labels[:, 0, :, :], subject_ids 87 | -------------------------------------------------------------------------------- /emotion_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import data_provider 7 | import models 8 | import losses 9 | import math 10 | import metrics 11 | 12 | from menpo.visualize import print_progress 13 | from pathlib import Path 14 | from tensorflow.python.platform import tf_logging as logging 15 | 16 | 17 | slim = tf.contrib.slim 18 | 19 | # Create FLAGS 20 | FLAGS = tf.app.flags.FLAGS 21 | tf.app.flags.DEFINE_integer('batch_size', 1, 'The batch size to use.') 22 | tf.app.flags.DEFINE_string('model', 'both','Which model is going to be used: audio, video, or both ') 23 | tf.app.flags.DEFINE_string('dataset_dir', 'path_to_tfrecords', 'The tfrecords directory.') 24 | tf.app.flags.DEFINE_string('checkpoint_dir', 'ckpt/train', 'The directory that contains the saved model.') 25 | tf.app.flags.DEFINE_string('log_dir', 'ckpt/log', 'The directory to save log files.') 26 | tf.app.flags.DEFINE_integer('num_examples', 1000, 'The number of examples in the data set') 27 | tf.app.flags.DEFINE_integer('hidden_units', 256, 'The number of hidden units in the recurrent model') 28 | tf.app.flags.DEFINE_integer('seq_length', 150, 29 | 'The number of consecutive examples to be used in the recurrent model') 30 | tf.app.flags.DEFINE_string('eval_interval_secs', 300, 'How often to run the evaluation (in sec).') 31 | tf.app.flags.DEFINE_string('portion', 'valid', 'The portion (train, valid, test) to use for evaluation') 32 | 33 | def evaluate(data_folder): 34 | """Evaluates the model (audio/video/both). 35 | 36 | Args: 37 | data_folder: The folder that contains the data to evaluate the model. 38 | """ 39 | 40 | g = tf.Graph() 41 | with g.as_default(): 42 | 43 | # Load dataset. 44 | frames, audio, ground_truth,_ = data_provider.get_split(data_folder, False, 45 | FLAGS.portion, FLAGS.batch_size, 46 | FLAGS.seq_length) 47 | 48 | # Define model graph. 
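# The graph is built exactly as in emotion_train.py, but with is_training=False so that
# batch normalisation and dropout run in inference mode and the variables restored from
# the checkpoint match the training graph.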
49 | with slim.arg_scope([slim.batch_norm, slim.layers.dropout], 50 | is_training=False): 51 | with slim.arg_scope(slim.nets.resnet_utils.resnet_arg_scope(is_training=False)): 52 | prediction = models.get_model(FLAGS.model)(frames, audio, 53 | hidden_units=FLAGS.hidden_units) 54 | 55 | # Computing MSE and Concordance values, and adding them to summary 56 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 57 | 'eval/mse_arousal':slim.metrics.streaming_mean_squared_error(prediction[:,:,0], ground_truth[:,:,0]), 58 | 'eval/mse_valence':slim.metrics.streaming_mean_squared_error(prediction[:,:,1], ground_truth[:,:,1]), 59 | }) 60 | 61 | summary_ops = [] 62 | 63 | conc_total = 0 64 | mse_total = 0 65 | for i, name in enumerate(['arousal', 'valence']): 66 | with tf.name_scope(name) as scope: 67 | concordance_cc2, values, updates = metrics.concordance_cc2( 68 | tf.reshape(prediction[:,:,i], [-1]), 69 | tf.reshape(ground_truth[:,:,i], [-1])) 70 | 71 | for n, v in updates.items(): 72 | names_to_updates[n + '/' + name] = v 73 | 74 | op = tf.summary.scalar('eval/concordance_' + name, concordance_cc2) 75 | op = tf.Print(op, [concordance_cc2], 'eval/concordance_'+name) 76 | summary_ops.append(op) 77 | 78 | mse_eval = 'eval/mse_' + name 79 | op = tf.summary.scalar(mse_eval, names_to_values[mse_eval]) 80 | op = tf.Print(op, [names_to_values[mse_eval]], mse_eval) 81 | summary_ops.append(op) 82 | 83 | mse_total += names_to_values[mse_eval] 84 | conc_total += concordance_cc2 85 | 86 | conc_total = conc_total / 2 87 | mse_total = mse_total / 2 88 | 89 | op = tf.summary.scalar('eval/concordance_total', conc_total) 90 | op = tf.Print(op, [conc_total], 'eval/concordance_total') 91 | summary_ops.append(op) 92 | 93 | op = tf.summary.scalar('eval/mse_total', mse_total) 94 | op = tf.Print(op, [mse_total], 'eval/mse_total') 95 | summary_ops.append(op) 96 | 97 | num_batches = int(FLAGS.num_examples / (FLAGS.batch_size * FLAGS.seq_length)) 98 | logging.set_verbosity(1) 99 | 100 | slim.evaluation.evaluation_loop( 101 | '', 102 | FLAGS.checkpoint_dir, 103 | FLAGS.log_dir, 104 | num_evals=num_batches, 105 | eval_op=list(names_to_updates.values()), 106 | summary_op=tf.summary.merge(summary_ops), 107 | eval_interval_secs=FLAGS.eval_interval_secs) 108 | 109 | def main(_): 110 | evaluate(FLAGS.dataset_dir) 111 | 112 | if __name__ == '__main__': 113 | tf.app.run() 114 | -------------------------------------------------------------------------------- /emotion_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import absolute_import 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import data_provider 7 | import losses 8 | import models 9 | 10 | from tensorflow.python.platform import tf_logging as logging 11 | 12 | 13 | slim = tf.contrib.slim 14 | 15 | # Create FLAGS 16 | FLAGS = tf.app.flags.FLAGS 17 | tf.app.flags.DEFINE_float('initial_learning_rate', 0.0001, 'Initial learning rate.') 18 | tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.97, 'Learning rate decay factor.') 19 | tf.app.flags.DEFINE_integer('batch_size', 2, 'The batch size to use.') 20 | tf.app.flags.DEFINE_string('train_dir', 'ckpt/train', 21 | 'Directory where to write event logs ' 22 | 'and checkpoint.') 23 | tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path', '', 24 | 'If specified, restore this pretrained model ' 25 | 'before beginning any training.') 26 | 
tf.app.flags.DEFINE_integer('hidden_units', 256, 27 | 'The number of hidden units in the recurrent model') 28 | tf.app.flags.DEFINE_integer('seq_length', 2, 29 | 'The number of consecutive examples to be used' 30 | 'in the recurrent model') 31 | tf.app.flags.DEFINE_string('model', 'both', 32 | 'Which model is going to be used: audio, video, or both ') 33 | tf.app.flags.DEFINE_string('dataset_dir', 'path_to_tfrecords', 34 | 'The tfrecords directory.') 35 | 36 | def train(data_folder): 37 | 38 | g = tf.Graph() 39 | with g.as_default(): 40 | # Load dataset. 41 | frames, audio, ground_truth, _ = data_provider.get_split(data_folder, True, 42 | 'train', FLAGS.batch_size, 43 | seq_length=FLAGS.seq_length) 44 | 45 | # Define model graph. 46 | with slim.arg_scope([slim.batch_norm, slim.layers.dropout], 47 | is_training=True): 48 | with slim.arg_scope(slim.nets.resnet_utils.resnet_arg_scope(is_training=True)): 49 | prediction = models.get_model(FLAGS.model)(frames, audio, 50 | hidden_units=FLAGS.hidden_units) 51 | 52 | for i, name in enumerate(['arousal', 'valence']): 53 | pred_single = tf.reshape(prediction[:, :, i], (-1,)) 54 | gt_single = tf.reshape(ground_truth[:, :, i], (-1,)) 55 | 56 | loss = losses.concordance_cc(pred_single, gt_single) 57 | tf.summary.scalar('losses/{} loss'.format(name), loss) 58 | 59 | mse = tf.reduce_mean(tf.square(pred_single - gt_single)) 60 | tf.summary.scalar('losses/mse {} loss'.format(name), mse) 61 | 62 | slim.losses.add_loss(loss / 2.) 63 | 64 | total_loss = slim.losses.get_total_loss() 65 | tf.summary.scalar('losses/total loss', total_loss) 66 | 67 | optimizer = tf.train.AdamOptimizer(FLAGS.initial_learning_rate) 68 | 69 | init_fn = None 70 | with tf.Session(graph=g) as sess: 71 | if FLAGS.pretrained_model_checkpoint_path: 72 | # Need to specify which variables to restore (use scope of models) 73 | variables_to_restore = slim.get_variables() 74 | init_fn = slim.assign_from_checkpoint_fn( 75 | FLAGS.pretrained_model_checkpoint_path, variables_to_restore) 76 | 77 | train_op = slim.learning.create_train_op(total_loss, 78 | optimizer, 79 | summarize_gradients=True) 80 | 81 | 82 | logging.set_verbosity(1) 83 | slim.learning.train(train_op, 84 | FLAGS.train_dir, 85 | init_fn=init_fn, 86 | save_summaries_secs=60, 87 | save_interval_secs=300) 88 | 89 | 90 | if __name__ == '__main__': 91 | train(FLAGS.dataset_dir) 92 | -------------------------------------------------------------------------------- /inception_processing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Provides utilities to preprocess images for the Inception networks.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from tensorflow.python.ops import control_flow_ops 24 | 25 | 26 | def apply_with_random_selector(x, func, num_cases): 27 | """Computes func(x, sel), with sel sampled from [0...num_cases-1]. 28 | 29 | Args: 30 | x: input Tensor. 31 | func: Python function to apply. 32 | num_cases: Python int32, number of cases to sample sel from. 33 | 34 | Returns: 35 | The result of func(x, sel), where func receives the value of the 36 | selector as a python integer, but sel is sampled dynamically. 37 | """ 38 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32) 39 | # Pass the real x only to one of the func calls. 40 | return control_flow_ops.merge([ 41 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) 42 | for case in range(num_cases)])[0] 43 | 44 | 45 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None): 46 | """Distort the color of a Tensor image. 47 | 48 | Each color distortion is non-commutative and thus ordering of the color ops 49 | matters. Ideally we would randomly permute the ordering of the color ops. 50 | Rather then adding that level of complication, we select a distinct ordering 51 | of color ops for each preprocessing thread. 52 | 53 | Args: 54 | image: 3-D Tensor containing single image in [0, 1]. 55 | color_ordering: Python int, a type of distortion (valid values: 0-3). 56 | fast_mode: Avoids slower ops (random_hue and random_contrast) 57 | scope: Optional scope for name_scope. 58 | Returns: 59 | 3-D Tensor color-distorted image on range [0, 1] 60 | Raises: 61 | ValueError: if color_ordering not in [0, 3] 62 | """ 63 | with tf.name_scope(scope, 'distort_color', [image]): 64 | if fast_mode: 65 | if color_ordering == 0: 66 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 67 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 68 | else: 69 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 70 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 71 | else: 72 | if color_ordering == 0: 73 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 74 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 75 | image = tf.image.random_hue(image, max_delta=0.2) 76 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 77 | elif color_ordering == 1: 78 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 79 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 80 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 81 | image = tf.image.random_hue(image, max_delta=0.2) 82 | elif color_ordering == 2: 83 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 84 | image = tf.image.random_hue(image, max_delta=0.2) 85 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 86 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 87 | elif color_ordering == 3: 88 | image = tf.image.random_hue(image, max_delta=0.2) 89 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 90 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 91 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 
92 | else: 93 | raise ValueError('color_ordering must be in [0, 3]') 94 | 95 | # The random_* ops do not necessarily clamp. 96 | return tf.clip_by_value(image, 0.0, 1.0) 97 | 98 | 99 | def distorted_bounding_box_crop(image, 100 | bbox, 101 | min_object_covered=0.1, 102 | aspect_ratio_range=(0.75, 1.33), 103 | area_range=(0.05, 1.0), 104 | max_attempts=100, 105 | scope=None): 106 | """Generates cropped_image using a one of the bboxes randomly distorted. 107 | 108 | See `tf.image.sample_distorted_bounding_box` for more documentation. 109 | 110 | Args: 111 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]). 112 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 113 | where each coordinate is [0, 1) and the coordinates are arranged 114 | as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole 115 | image. 116 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 117 | area of the image must contain at least this fraction of any bounding box 118 | supplied. 119 | aspect_ratio_range: An optional list of `floats`. The cropped area of the 120 | image must have an aspect ratio = width / height within this range. 121 | area_range: An optional list of `floats`. The cropped area of the image 122 | must contain a fraction of the supplied image within in this range. 123 | max_attempts: An optional `int`. Number of attempts at generating a cropped 124 | region of the image of the specified constraints. After `max_attempts` 125 | failures, return the entire image. 126 | scope: Optional scope for name_scope. 127 | Returns: 128 | A tuple, a 3-D Tensor cropped_image and the distorted bbox 129 | """ 130 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]): 131 | # Each bounding box has shape [1, num_boxes, box coords] and 132 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 133 | 134 | # A large fraction of image datasets contain a human-annotated bounding 135 | # box delineating the region of the image containing the object of interest. 136 | # We choose to create a new bounding box for the object which is a randomly 137 | # distorted version of the human-annotated bounding box that obeys an 138 | # allowed range of aspect ratios, sizes and overlap with the human-annotated 139 | # bounding box. If no box is supplied, then we assume the bounding box is 140 | # the entire image. 141 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 142 | tf.shape(image), 143 | bounding_boxes=bbox, 144 | min_object_covered=min_object_covered, 145 | aspect_ratio_range=aspect_ratio_range, 146 | area_range=area_range, 147 | max_attempts=max_attempts, 148 | use_image_if_no_bounding_boxes=True) 149 | bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box 150 | 151 | # Crop the image to the specified bounding box. 152 | cropped_image = tf.slice(image, bbox_begin, bbox_size) 153 | return cropped_image, distort_bbox 154 | 155 | 156 | def preprocess_for_train(image, height, width, bbox, 157 | fast_mode=True, 158 | scope=None): 159 | """Distort one image for training a network. 160 | 161 | Distorting images provides a useful technique for augmenting the data 162 | set during training in order to make the network invariant to aspects 163 | of the image that do not effect the label. 164 | 165 | Additionally it would create image_summaries to display the different 166 | transformations applied to the image. 167 | 168 | Args: 169 | image: 3-D Tensor of image. 
If dtype is tf.float32 then the range should be 170 | [0, 1], otherwise it would converted to tf.float32 assuming that the range 171 | is [0, MAX], where MAX is largest positive representable number for 172 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details). 173 | height: integer 174 | width: integer 175 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 176 | where each coordinate is [0, 1) and the coordinates are arranged 177 | as [ymin, xmin, ymax, xmax]. 178 | fast_mode: Optional boolean, if True avoids slower transformations (i.e. 179 | bi-cubic resizing, random_hue or random_contrast). 180 | scope: Optional scope for name_scope. 181 | Returns: 182 | 3-D float Tensor of distorted image used for training with range [-1, 1]. 183 | """ 184 | with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]): 185 | if bbox is None: 186 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], 187 | dtype=tf.float32, 188 | shape=[1, 1, 4]) 189 | if image.dtype != tf.float32: 190 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 191 | # Each bounding box has shape [1, num_boxes, box coords] and 192 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 193 | image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), 194 | bbox) 195 | tf.image_summary('image_with_bounding_boxes', image_with_box) 196 | 197 | distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox) 198 | # Restore the shape since the dynamic slice based upon the bbox_size loses 199 | # the third dimension. 200 | distorted_image.set_shape([None, None, 3]) 201 | image_with_distorted_box = tf.image.draw_bounding_boxes( 202 | tf.expand_dims(image, 0), distorted_bbox) 203 | tf.image_summary('images_with_distorted_bounding_box', 204 | image_with_distorted_box) 205 | 206 | # This resizing operation may distort the images because the aspect 207 | # ratio is not respected. We select a resize method in a round robin 208 | # fashion based on the thread number. 209 | # Note that ResizeMethod contains 4 enumerated resizing methods. 210 | 211 | # We select only 1 case for fast_mode bilinear. 212 | num_resize_cases = 1 if fast_mode else 4 213 | distorted_image = apply_with_random_selector( 214 | distorted_image, 215 | lambda x, method: tf.image.resize_images(x, [height, width], method=method), 216 | num_cases=num_resize_cases) 217 | 218 | tf.image_summary('cropped_resized_image', 219 | tf.expand_dims(distorted_image, 0)) 220 | 221 | # Randomly flip the image horizontally. 222 | distorted_image = tf.image.random_flip_left_right(distorted_image) 223 | 224 | # Randomly distort the colors. There are 4 ways to do it. 225 | distorted_image = apply_with_random_selector( 226 | distorted_image, 227 | lambda x, ordering: distort_color(x, ordering, fast_mode), 228 | num_cases=4) 229 | 230 | tf.image_summary('final_distorted_image', 231 | tf.expand_dims(distorted_image, 0)) 232 | distorted_image = tf.sub(distorted_image, 0.5) 233 | distorted_image = tf.mul(distorted_image, 2.0) 234 | return distorted_image 235 | 236 | 237 | def preprocess_for_eval(image, height, width, 238 | central_fraction=0.875, scope=None): 239 | """Prepare one image for evaluation. 240 | 241 | If height and width are specified it would output an image with that size by 242 | applying resize_bilinear. 243 | 244 | If central_fraction is specified it would cropt the central fraction of the 245 | input image. 246 | 247 | Args: 248 | image: 3-D Tensor of image. 
If dtype is tf.float32 then the range should be 249 | [0, 1], otherwise it would converted to tf.float32 assuming that the range 250 | is [0, MAX], where MAX is largest positive representable number for 251 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details) 252 | height: integer 253 | width: integer 254 | central_fraction: Optional Float, fraction of the image to crop. 255 | scope: Optional scope for name_scope. 256 | Returns: 257 | 3-D float Tensor of prepared image. 258 | """ 259 | with tf.name_scope(scope, 'eval_image', [image, height, width]): 260 | if image.dtype != tf.float32: 261 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 262 | # Crop the central region of the image with an area containing 87.5% of 263 | # the original image. 264 | if central_fraction: 265 | image = tf.image.central_crop(image, central_fraction=central_fraction) 266 | 267 | if height and width: 268 | # Resize the image to the specified height and width. 269 | image = tf.expand_dims(image, 0) 270 | image = tf.image.resize_bilinear(image, [height, width], 271 | align_corners=False) 272 | image = tf.squeeze(image, [0]) 273 | image = tf.sub(image, 0.5) 274 | image = tf.mul(image, 2.0) 275 | return image 276 | 277 | 278 | def preprocess_image(image, height, width, 279 | is_training=False, 280 | bbox=None, 281 | fast_mode=True): 282 | """Pre-process one image for training or evaluation. 283 | 284 | Args: 285 | image: 3-D Tensor [height, width, channels] with the image. 286 | height: integer, image expected height. 287 | width: integer, image expected width. 288 | is_training: Boolean. If true it would transform an image for train, 289 | otherwise it would transform it for evaluation. 290 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 291 | where each coordinate is [0, 1) and the coordinates are arranged as 292 | [ymin, xmin, ymax, xmax]. 293 | fast_mode: Optional boolean, if True avoids slower transformations. 294 | 295 | Returns: 296 | 3-D float Tensor containing an appropriately scaled image 297 | 298 | Raises: 299 | ValueError: if user does not provide bounding box 300 | """ 301 | if is_training: 302 | return preprocess_for_train(image, height, width, bbox, fast_mode) 303 | else: 304 | return preprocess_for_eval(image, height, width) 305 | 306 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | 8 | slim = tf.contrib.slim 9 | 10 | def concordance_cc(prediction, ground_truth): 11 | """Defines concordance loss for training the model. 12 | 13 | Args: 14 | prediction: prediction of the model. 15 | ground_truth: ground truth values. 16 | Returns: 17 | The concordance value. 
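(More precisely, the returned value is the concordance loss 1 - CCC, where
CCC = 2 * cov(x, y) / (var(x) + var(y) + (mean(x) - mean(y))^2), so minimising
this loss maximises the concordance correlation coefficient between the
predictions and the ground truth.)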
18 | """ 19 | 20 | pred_mean, pred_var = tf.nn.moments(prediction, (0,)) 21 | gt_mean, gt_var = tf.nn.moments(ground_truth, (0,)) 22 | 23 | mean_cent_prod = tf.reduce_mean((prediction - pred_mean) * (ground_truth - gt_mean)) 24 | 25 | return 1 - (2 * mean_cent_prod) / (pred_var + gt_var + tf.square(pred_mean - gt_mean)) 26 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | slim = tf.contrib.slim 5 | 6 | def concordance_cc2(prediction, ground_truth): 7 | """Defines concordance metric for model evaluation. 8 | 9 | Args: 10 | prediction: prediction of the model. 11 | ground_truth: ground truth values. 12 | Returns: 13 | The concordance value. 14 | """ 15 | 16 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 17 | 'eval/mean_pred':slim.metrics.streaming_mean(prediction), 18 | 'eval/mean_lab':slim.metrics.streaming_mean(ground_truth), 19 | 'eval/cov_pred':slim.metrics.streaming_covariance(prediction, prediction), 20 | 'eval/cov_lab':slim.metrics.streaming_covariance(ground_truth, ground_truth), 21 | 'eval/cov_lab_pred':slim.metrics.streaming_covariance(prediction, ground_truth) 22 | }) 23 | 24 | metrics = dict() 25 | for name, value in names_to_values.items(): 26 | metrics[name] = value 27 | 28 | mean_pred = metrics['eval/mean_pred'] 29 | var_pred = metrics['eval/cov_pred'] 30 | mean_lab = metrics['eval/mean_lab'] 31 | var_lab = metrics['eval/cov_lab'] 32 | var_lab_pred = metrics['eval/cov_lab_pred'] 33 | 34 | denominator = (var_pred + var_lab + (mean_pred - mean_lab) ** 2) 35 | 36 | concordance_cc2 = (2 * var_lab_pred) / denominator 37 | 38 | return concordance_cc2, names_to_values, names_to_updates 39 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from tensorflow.contrib.slim.nets import resnet_v1 8 | 9 | 10 | slim = tf.contrib.slim 11 | 12 | def recurrent_model(net, hidden_units=256, number_of_outputs=2): 13 | """Adds the recurrent network on top of the spatial 14 | audio / video / audio-visual model. 15 | 16 | Args: 17 | net: A `Tensor` of dimensions [batch_size, seq_length, num_features]. 18 | hidden_units: The number of hidden units of the LSTM cell. 19 | num_classes: The number of classes. 20 | Returns: 21 | The prediction of the network. 22 | """ 23 | 24 | batch_size, seq_length, num_features = net.get_shape().as_list() 25 | 26 | lstm = tf.nn.rnn_cell.LSTMCell(hidden_units, 27 | use_peepholes=True, 28 | cell_clip=100, 29 | state_is_tuple=True) 30 | 31 | stacked_lstm = tf.nn.rnn_cell.MultiRNNCell([lstm] * 2, state_is_tuple=True) 32 | 33 | # We have to specify the dimensionality of the Tensor so we can allocate 34 | # weights for the fully connected layers. 35 | outputs, _ = tf.nn.dynamic_rnn(stacked_lstm, net, dtype=tf.float32) 36 | 37 | net = tf.reshape(outputs, (batch_size * seq_length, hidden_units)) 38 | 39 | prediction = slim.layers.linear(net, number_of_outputs) 40 | 41 | return tf.reshape(prediction, (batch_size, seq_length, number_of_outputs)) 42 | 43 | def video_model(video_frames=None, audio_frames=None): 44 | """Creates the video model. 
45 | 46 | Args: 47 | video_frames: A tensor that contains the video input. 48 | audio_frames: not needed (leave None). 49 | Returns: 50 | The video model. 51 | """ 52 | 53 | with tf.variable_scope("video_model"): 54 | batch_size, seq_length, height, width, channels = video_frames.get_shape().as_list() 55 | 56 | video_input = tf.reshape(video_frames, (batch_size * seq_length, height, width, channels)) 57 | video_input = tf.cast(video_input, tf.float32) 58 | 59 | features, end_points = resnet_v1.resnet_v1_50(video_input, None) 60 | features = tf.reshape(features, (batch_size, seq_length, int(features.get_shape()[3]))) 61 | 62 | return features 63 | 64 | def audio_model(video_frames=None, audio_frames=None, conv_filters=40): 65 | """Creates the audio model. 66 | 67 | Args: 68 | video_frames: not needed (leave None). 69 | audio_frames: A tensor that contains the audio input. 70 | conv_filters: The number of convolutional filters to use. 71 | Returns: 72 | The audio model. 73 | """ 74 | 75 | with tf.variable_scope("audio_model"): 76 | batch_size, seq_length, num_features = audio_frames.get_shape().as_list() 77 | audio_input = tf.reshape(audio_frames, [batch_size * seq_length, 1, num_features, 1]) 78 | 79 | with slim.arg_scope([slim.layers.conv2d], padding='SAME'): 80 | net = slim.dropout(audio_input) 81 | net = slim.layers.conv2d(net, conv_filters, (1, 20)) 82 | 83 | # Subsampling of the signal to 8KhZ. 84 | net = tf.nn.max_pool( 85 | net, 86 | ksize=[1, 1, 2, 1], 87 | strides=[1, 1, 2, 1], 88 | padding='SAME', 89 | name='pool1') 90 | 91 | # Original model had 400 output filters for the second conv layer 92 | # but this trains much faster and achieves comparable accuracy. 93 | net = slim.layers.conv2d(net, conv_filters, (1, 40)) 94 | 95 | net = tf.reshape(net, (batch_size * seq_length, num_features // 2, conv_filters, 1)) 96 | 97 | # Pooling over the feature maps. 98 | net = tf.nn.max_pool( 99 | net, 100 | ksize=[1, 1, 10, 1], 101 | strides=[1, 1, 10, 1], 102 | padding='SAME', 103 | name='pool2') 104 | 105 | net = tf.reshape(net, (batch_size, seq_length, num_features //2 * 4 )) 106 | 107 | return net 108 | 109 | 110 | def combined_model(video_frames, audio_frames): 111 | """Creates the audio-visual model. 112 | 113 | Args: 114 | video_frames: A tensor that contains the video input. 115 | audio_frames: A tensor that contains the audio input. 116 | Returns: 117 | The audio-visual model. 118 | """ 119 | 120 | audio_features = audio_model([], audio_frames) 121 | visual_features = video_model(video_frames,[]) 122 | 123 | return tf.concat(2, (audio_features, visual_features), name='concat') 124 | 125 | 126 | def get_model(name): 127 | """Returns the recurrent model. 128 | 129 | Args: 130 | name: one of the 'audio', 'video', or 'both' 131 | Returns: 132 | The recurrent model. 133 | """ 134 | 135 | name_to_fun = {'audio': audio_model, 'video': video_model, 'both': combined_model} 136 | 137 | if name in name_to_fun: 138 | model = name_to_fun[name] 139 | else: 140 | raise ValueError('Requested name [{}] not a valid model'.format(name)) 141 | 142 | def wrapper(*args, **kwargs): 143 | return recurrent_model(model(*args), **kwargs) 144 | 145 | return wrapper 146 | --------------------------------------------------------------------------------