├── AFM ├── afm.py ├── input_fn.py ├── metric.py ├── run.sh ├── train.py └── utils.py ├── DeepCross ├── dcn.py ├── input_fn.py ├── metric.py ├── run.sh └── train.py ├── DeepFM ├── deepfm.py ├── input_fn.py ├── metric.py ├── run.sh └── train.py ├── Din ├── din.py ├── din_feature_column.py ├── input_fn.py ├── metric.py ├── run.sh ├── train.py └── utils.py ├── ESMM ├── esmm.py ├── input_fn.py ├── metric.py ├── run.sh └── train.py ├── Fibinet ├── fibinet.py ├── input_fn.py ├── metric.py ├── run.sh ├── train.py └── utils.py ├── README.md ├── ResNet ├── input_fn.py ├── metric.py ├── resnet.py ├── run.sh ├── train.py └── utils.py ├── Transformer ├── input_fn.py ├── metric.py ├── run.sh ├── train.py ├── transformer.py └── utils.py ├── XDeepFM ├── input_fn.py ├── metric.py ├── run.sh ├── train.py ├── utils.py └── xdeepfm.py └── official ├── .DS_Store ├── __init__.py ├── datasets ├── __init__.py └── movielens.py ├── utils ├── .DS_Store ├── __init__.py ├── accelerator │ ├── __init__.py │ ├── tpu.py │ └── tpu_test.py ├── data │ ├── __init__.py │ ├── file_io.py │ └── file_io_test.py ├── export │ ├── __init__.py │ ├── export.py │ └── export_test.py ├── flags │ ├── README.md │ ├── __init__.py │ ├── _base.py │ ├── _benchmark.py │ ├── _conventions.py │ ├── _device.py │ ├── _misc.py │ ├── _performance.py │ ├── core.py │ ├── flags_test.py │ └── guidelines.md ├── logs │ ├── __init__.py │ ├── cloud_lib.py │ ├── cloud_lib_test.py │ ├── guidelines.md │ ├── hooks.py │ ├── hooks_helper.py │ ├── hooks_helper_test.py │ ├── hooks_test.py │ ├── logger.py │ ├── logger_test.py │ ├── metric_hook.py │ ├── metric_hook_test.py │ └── mlperf_helper.py ├── misc │ ├── __init__.py │ ├── distribution_utils.py │ ├── distribution_utils_test.py │ ├── model_helpers.py │ └── model_helpers_test.py └── testing │ ├── .DS_Store │ ├── __init__.py │ ├── integration.py │ ├── mock_lib.py │ ├── pylint.rcfile │ ├── reference_data.py │ ├── reference_data │ ├── .DS_Store │ ├── reference_data_test │ │ ├── .DS_Store │ │ ├── dense │ │ │ ├── expected_graph │ │ │ ├── model.ckpt.data-00000-of-00001 │ │ │ ├── model.ckpt.index │ │ │ ├── results.json │ │ │ └── tf_version.json │ │ └── uniform_random │ │ │ ├── expected_graph │ │ │ ├── model.ckpt.data-00000-of-00001 │ │ │ ├── model.ckpt.index │ │ │ ├── results.json │ │ │ └── tf_version.json │ └── resnet │ │ ├── .DS_Store │ │ ├── batch-size-32_bottleneck_projection_version-1_width-8_channels-4 │ │ ├── expected_graph │ │ ├── model.ckpt.data-00000-of-00001 │ │ ├── model.ckpt.index │ │ ├── results.json │ │ └── tf_version.json │ │ ├── batch-size-32_bottleneck_projection_version-2_width-8_channels-4 │ │ ├── expected_graph │ │ ├── model.ckpt.data-00000-of-00001 │ │ ├── model.ckpt.index │ │ ├── results.json │ │ └── tf_version.json │ │ ├── batch-size-32_bottleneck_version-1_width-8_channels-4 │ │ ├── expected_graph │ │ ├── model.ckpt.data-00000-of-00001 │ │ ├── model.ckpt.index │ │ ├── results.json │ │ └── tf_version.json │ │ ├── batch-size-32_bottleneck_version-2_width-8_channels-4 │ │ ├── expected_graph │ │ ├── model.ckpt.data-00000-of-00001 │ │ ├── model.ckpt.index │ │ ├── results.json │ │ └── tf_version.json │ │ ├── batch-size-32_building_projection_version-1_width-8_channels-4 │ │ ├── expected_graph │ │ ├── model.ckpt.data-00000-of-00001 │ │ ├── model.ckpt.index │ │ ├── results.json │ │ └── tf_version.json │ │ ├── batch-size-32_building_projection_version-2_width-8_channels-4 │ │ ├── expected_graph │ │ ├── model.ckpt.data-00000-of-00001 │ │ ├── model.ckpt.index │ │ ├── results.json │ │ └── tf_version.json │ │ 
├── batch-size-32_building_version-1_width-8_channels-4 │ │ ├── expected_graph │ │ ├── model.ckpt.data-00000-of-00001 │ │ ├── model.ckpt.index │ │ ├── results.json │ │ └── tf_version.json │ │ ├── batch-size-32_building_version-2_width-8_channels-4 │ │ ├── expected_graph │ │ ├── model.ckpt.data-00000-of-00001 │ │ ├── model.ckpt.index │ │ ├── results.json │ │ └── tf_version.json │ │ └── batch_norm │ │ ├── expected_graph │ │ ├── model.ckpt.data-00000-of-00001 │ │ ├── model.ckpt.index │ │ ├── results.json │ │ └── tf_version.json │ ├── reference_data_test.py │ └── scripts │ └── presubmit.sh └── wide_deep ├── README.md ├── __init__.py ├── census_dataset.py ├── census_main.py ├── census_test.csv ├── census_test.py ├── movielens_dataset.py ├── movielens_main.py ├── movielens_test.py └── wide_deep_run_loop.py /AFM/input_fn.py: -------------------------------------------------------------------------------- 1 | #-*- coding: UTF-8 -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import tensorflow as tf 6 | 7 | 8 | FixedLenFeatureColumns=["label", "user_id", "creative_id", "has_target", "terminal", 9 | "hour", "weekday","template_category", 10 | "day_user_show", "day_user_click", "city_code","network_type"] 11 | StringVarLenFeatureColumns = ["keyword"] #特征长度不固定 12 | FloatFixedLenFeatureColumns = ['creative_history_ctr'] 13 | StringFixedLenFeatureColumns = ["keyword_attention"] 14 | StringFeatureColumns = ["device_type", "device_model", "manufacturer"] 15 | 16 | DayShowSegs = [1, 5, 8, 12, 18, 26, 54, 120, 250, 432, 823] 17 | DayClickSegs = [1, 2, 3, 6, 23] 18 | 19 | 20 | def build_model_columns(): 21 | """Builds a set of wide and deep feature columns.""" 22 | # Continuous variable columns 23 | # hours_per_week = tf.feature_column.numeric_column('hours_per_week') 24 | 25 | creative_id = tf.feature_column.categorical_column_with_hash_bucket( 26 | 'creative_id', hash_bucket_size=200000, dtype=tf.int64) 27 | # To show an example of hashing: 28 | has_target = tf.feature_column.categorical_column_with_identity( 29 | 'has_target', num_buckets=3) 30 | terminal = tf.feature_column.categorical_column_with_identity( 31 | 'terminal', num_buckets=10) 32 | hour = tf.feature_column.categorical_column_with_identity( 33 | 'hour', num_buckets=25) 34 | weekday = tf.feature_column.categorical_column_with_identity( 35 | 'weekday', num_buckets=10) 36 | day_user_show = tf.feature_column.bucketized_column( 37 | tf.feature_column.numeric_column('day_user_show', dtype=tf.int32), boundaries=DayShowSegs) 38 | day_user_click = tf.feature_column.bucketized_column( 39 | tf.feature_column.numeric_column('day_user_click', dtype=tf.int32), boundaries=DayClickSegs) 40 | 41 | city_code = tf.feature_column.categorical_column_with_hash_bucket( 42 | 'city_code', hash_bucket_size=2000, dtype=tf.int64) 43 | 44 | network_type = tf.feature_column.categorical_column_with_identity( 45 | 'network_type', num_buckets=20, default_value=19) 46 | 47 | device_type = tf.feature_column.categorical_column_with_hash_bucket( #androidphone这些 48 | 'device_type', hash_bucket_size=500000, dtype=tf.string 49 | ) 50 | device_model = tf.feature_column.categorical_column_with_hash_bucket( #型号如iPhone10 vivo X9 51 | 'device_model', hash_bucket_size=200000, dtype=tf.string 52 | ) 53 | manufacturer = tf.feature_column.categorical_column_with_hash_bucket( #手机品牌 vivo iphone等 54 | 'manufacturer', hash_bucket_size=50000, dtype=tf.string 55 | ) 56 | 57 | 58 | deep_columns = [ 59 | 
tf.feature_column.embedding_column(creative_id, dimension=15,combiner='sum'), 60 | tf.feature_column.embedding_column(has_target, dimension=15,combiner='sum'), 61 | tf.feature_column.embedding_column(terminal, dimension=15, combiner='sum'), 62 | tf.feature_column.embedding_column(hour, dimension=15, combiner='sum'), 63 | tf.feature_column.embedding_column(weekday, dimension=15, combiner='sum'), 64 | tf.feature_column.embedding_column(day_user_show, dimension=15, combiner='sum'), 65 | tf.feature_column.embedding_column(day_user_click, dimension=15, combiner='sum'), 66 | tf.feature_column.embedding_column(city_code, dimension=15, combiner='sum'), 67 | tf.feature_column.embedding_column(network_type, dimension=15, combiner='sum'), 68 | tf.feature_column.embedding_column(device_type, dimension=15, combiner='sum'), 69 | tf.feature_column.embedding_column(device_model, dimension=15, combiner='sum'), 70 | tf.feature_column.embedding_column(manufacturer, dimension=15, combiner='sum'), 71 | 72 | ] 73 | # base_columns = [user_id, ad_id, creative_id, product_id, brush_num, terminal,terminal_brand] 74 | ''' 75 | crossed_columns = [tf.feature_column.crossed_column( 76 | ['userId', 'adId'], hash_bucket_size = 50000000), 77 | 、、、 78 | ] 79 | ''' 80 | return deep_columns 81 | 82 | def feature_input_fn(data_file, num_epochs, shuffle, batch_size, labels=True): 83 | """Generate an input function for the Estimator.""" 84 | 85 | def parse_tfrecord(value): 86 | tf.logging.info('Parsing {}'.format(data_file[:10])) 87 | FixedLenFeatures = { 88 | key: tf.FixedLenFeature(shape=[1], dtype=tf.int64) for key in FixedLenFeatureColumns 89 | } 90 | 91 | StringVarLenFeatures = { 92 | key: tf.VarLenFeature(dtype=tf.string) for key in StringVarLenFeatureColumns 93 | } 94 | FloatFixedLenFeatures = { 95 | key: tf.FixedLenFeature(shape=[1], dtype=tf.float32) for key in FloatFixedLenFeatureColumns 96 | } 97 | StringFixedLenFeatures = { 98 | key: tf.FixedLenFeature(shape=[20], dtype=tf.string) for key in StringFixedLenFeatureColumns 99 | } 100 | StringFeatures = { 101 | key: tf.FixedLenFeature(shape=[1], dtype=tf.string) for key in StringFeatureColumns 102 | } 103 | features={} 104 | features.update(FixedLenFeatures) 105 | features.update(StringVarLenFeatures) 106 | features.update(FloatFixedLenFeatures) 107 | features.update(StringFixedLenFeatures) 108 | features.update(StringFeatures) 109 | 110 | fea = tf.parse_example(value, features) 111 | feature = { 112 | key: fea[key] for key in features 113 | } 114 | classes = tf.to_float(feature['label']) 115 | return feature, classes 116 | 117 | # Extract lines from input files using the Dataset API. 
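    # (note) The pipeline below batches *before* mapping parse_tfrecord: tf.parse_example
    # parses a whole batch of serialized Examples in a single op, which is generally much
    # cheaper than calling tf.parse_single_example once per record.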
118 | filenames = tf.data.Dataset.list_files(data_file) 119 | dataset = filenames.apply(tf.contrib.data.parallel_interleave( 120 | lambda filename: tf.data.TFRecordDataset(filename), 121 | cycle_length=32)) 122 | 123 | if shuffle: 124 | dataset = dataset.shuffle(buffer_size=batch_size*64) 125 | 126 | dataset = dataset.repeat(num_epochs).batch(batch_size).prefetch(buffer_size=batch_size*8) 127 | dataset = dataset.map(parse_tfrecord, num_parallel_calls=32) 128 | 129 | return dataset 130 | 131 | -------------------------------------------------------------------------------- /AFM/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from sklearn.metrics import roc_auc_score 3 | import numpy as np 4 | 5 | ''' 6 | calculate group_auc and cross_entropy_loss(log loss for binary classification) 7 | 8 | @author: Qiao 9 | ''' 10 | 11 | 12 | def cal_group_auc(labels, preds, user_id_list): 13 | """Calculate group auc""" 14 | 15 | print('*' * 50) 16 | if len(user_id_list) != len(labels): 17 | raise ValueError( 18 | "impression id num should equal to the sample num," \ 19 | "impression id num is {0}".format(len(user_id_list))) 20 | group_score = defaultdict(lambda: []) 21 | group_truth = defaultdict(lambda: []) 22 | for idx, truth in enumerate(labels): 23 | user_id = user_id_list[idx] 24 | score = preds[idx] 25 | truth = labels[idx] 26 | group_score[user_id].append(score) 27 | group_truth[user_id].append(truth) 28 | 29 | group_flag = defaultdict(lambda: False) 30 | for user_id in set(user_id_list): 31 | truths = group_truth[user_id] 32 | flag = False 33 | for i in range(len(truths) - 1): 34 | if truths[i] != truths[i + 1]: 35 | flag = True 36 | break 37 | group_flag[user_id] = flag 38 | 39 | impression_total = 0 40 | total_auc = 0 41 | # 42 | for user_id in group_flag: 43 | if group_flag[user_id]: 44 | auc = roc_auc_score(np.asarray(group_truth[user_id]), np.asarray(group_score[user_id])) 45 | total_auc += auc * len(group_truth[user_id]) 46 | impression_total += len(group_truth[user_id]) 47 | group_auc = float(total_auc) / impression_total 48 | group_auc = round(group_auc, 4) 49 | return group_auc 50 | 51 | 52 | def cross_entropy_loss(labels, preds): 53 | """calculate cross_entropy_loss 54 | 55 | loss = -labels*log(preds)-(1-labels)*log(1-preds) 56 | 57 | Args: 58 | labels, preds 59 | 60 | Returns: 61 | log loss 62 | """ 63 | 64 | if len(labels) != len(preds): 65 | raise ValueError( 66 | "labels num should equal to the preds num,") 67 | 68 | z = np.array(labels) 69 | x = np.array(preds) 70 | res = -z * np.log(x) - (1 - z) * np.log(1 - x) 71 | return res.tolist() 72 | -------------------------------------------------------------------------------- /AFM/run.sh: -------------------------------------------------------------------------------- 1 | export HADOOP_HDFS_HOME=$HADOOP_HOME/../hadoop-hdfs 2 | CLASSPATH=$(${HADOOP_HOME}/bin/hadoop classpath --glob) python train.py -ne 2 -------------------------------------------------------------------------------- /AFM/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def dice(_x, axis=-1, epsilon=0.0000001, name=''): 4 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 5 | alphas = tf.get_variable('alpha'+name, _x.get_shape()[-1], 6 | initializer=tf.constant_initializer(0.0), 7 | dtype=tf.float32) 8 | input_shape = list(_x.get_shape()) 9 | 10 | reduction_axes = list(range(len(input_shape))) 11 | del 
reduction_axes[axis] 12 | broadcast_shape = [1] * len(input_shape) 13 | broadcast_shape[axis] = input_shape[axis] 14 | 15 | # case: train mode (uses stats of the current batch) 16 | mean = tf.reduce_mean(_x, axis=reduction_axes) 17 | brodcast_mean = tf.reshape(mean, broadcast_shape) 18 | std = tf.reduce_mean(tf.square(_x - brodcast_mean) + epsilon, axis=reduction_axes) 19 | std = tf.sqrt(std) 20 | brodcast_std = tf.reshape(std, broadcast_shape) 21 | x_normed = (_x - brodcast_mean) / (brodcast_std + epsilon) 22 | # x_normed = tf.layers.batch_normalization(_x, center=False, scale=False) 23 | x_p = tf.sigmoid(x_normed) 24 | 25 | return alphas * (1.0 - x_p) * _x + x_p * _x 26 | 27 | def prelu(_x, scope=''): 28 | """parametric ReLU activation""" 29 | with tf.variable_scope(name_or_scope=scope, default_name="prelu"): 30 | _alpha = tf.get_variable("prelu_"+scope, shape=_x.get_shape()[-1], 31 | dtype=_x.dtype, initializer=tf.constant_initializer(0.1)) 32 | return tf.maximum(0.0, _x) + _alpha * tf.minimum(0.0, _x) 33 | 34 | -------------------------------------------------------------------------------- /DeepCross/dcn.py: -------------------------------------------------------------------------------- 1 | #-*- coding: UTF-8 -*- 2 | import tensorflow as tf 3 | from tensorflow.python.estimator.canned import head as head_lib 4 | from tensorflow.python.ops.losses import losses 5 | 6 | def build_deep_layers(net, params): 7 | # Build the hidden layers, sized according to the 'hidden_units' param. 8 | 9 | for num_hidden_units in params['hidden_units']: 10 | net = tf.layers.dense(net, units=num_hidden_units, activation=tf.nn.relu, 11 | kernel_initializer=tf.glorot_uniform_initializer()) 12 | return net 13 | 14 | 15 | def build_cross_layers(x0, params): 16 | num_layers = params['num_cross_layers'] 17 | x = x0 18 | for i in range(num_layers): 19 | x = cross_layer(x0, x, 'cross_{}'.format(i)) 20 | return x 21 | 22 | def cross_layer(x0, x, name): 23 | with tf.variable_scope(name): 24 | input_dim = x0.get_shape().as_list()[1] 25 | w = tf.get_variable("weight", [input_dim], initializer=tf.truncated_normal_initializer(stddev=0.01)) 26 | b = tf.get_variable("bias", [input_dim], initializer=tf.truncated_normal_initializer(stddev=0.01)) 27 | xb = tf.tensordot(tf.reshape(x, [-1, 1, input_dim]), w, 1) 28 | return x0 * xb + b + x 29 | 30 | def dcn_model_fn(features, labels, mode, params): 31 | net = tf.feature_column.input_layer(features, params['feature_columns']) 32 | last_deep_layer = build_deep_layers(net, params) 33 | last_cross_layer = build_cross_layers(net, params) 34 | 35 | if params['use_cross']: 36 | print('--use cross layer--') 37 | last_layer = tf.concat([last_deep_layer, last_cross_layer], 1) 38 | else: 39 | last_layer = last_deep_layer 40 | 41 | #head = tf.contrib.estimator.binary_classification_head(loss_reduction=losses.Reduction.SUM) 42 | head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access 43 | n_classes=2, weight_column=None, label_vocabulary=None, loss_reduction=losses.Reduction.SUM) 44 | logits = tf.layers.dense(last_layer, units=head.logits_dimension, kernel_initializer=tf.glorot_uniform_initializer()) 45 | optimizer = tf.train.AdagradOptimizer(learning_rate=params['learning_rate']) 46 | #optimizer = tf.contrib.opt.GGTOptimizer(learning_rate=params['learning_rate']) 47 | preds = tf.sigmoid(logits) 48 | user_id = features['user_id'] 49 | label = features['label'] 50 | 51 | if mode == tf.estimator.ModeKeys.PREDICT: 52 | predictions = { 53 | 'probabilities': 
preds, 54 | 'user_id': user_id, 55 | 'label': label 56 | } 57 | export_outputs = { 58 | 'regression': tf.estimator.export.RegressionOutput(predictions['probabilities']) 59 | } 60 | return tf.estimator.EstimatorSpec(mode, predictions=predictions, export_outputs=export_outputs) 61 | 62 | if mode == tf.estimator.ModeKeys.TRAIN: 63 | 64 | return head.create_estimator_spec( 65 | features=features, 66 | mode=mode, 67 | labels=labels, 68 | logits=logits, 69 | train_op_fn=lambda loss: optimizer.minimize(loss, global_step=tf.train.get_global_step()) 70 | ) 71 | 72 | -------------------------------------------------------------------------------- /DeepCross/input_fn.py: -------------------------------------------------------------------------------- 1 | #-*- coding: UTF-8 -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import tensorflow as tf 6 | 7 | 8 | FixedLenFeatureColumns=["label", "user_id", "creative_id", "has_target", "terminal", 9 | "hour", "weekday","template_category", 10 | "day_user_show", "day_user_click", "city_code","network_type"] 11 | StringVarLenFeatureColumns = ["keyword"] #特征长度不固定 12 | FloatFixedLenFeatureColumns = ['creative_history_ctr'] 13 | StringFixedLenFeatureColumns = ["keyword_attention"] 14 | StringFeatureColumns = ["device_type", "device_model", "manufacturer"] 15 | 16 | DayShowSegs = [1, 5, 8, 12, 18, 26, 54, 120, 250, 432, 823] 17 | DayClickSegs = [1, 2, 3, 6, 23] 18 | 19 | 20 | def build_model_columns(): 21 | """Builds a set of wide and deep feature columns.""" 22 | # Continuous variable columns 23 | # hours_per_week = tf.feature_column.numeric_column('hours_per_week') 24 | 25 | creative_id = tf.feature_column.categorical_column_with_hash_bucket( 26 | 'creative_id', hash_bucket_size=200000, dtype=tf.int64) 27 | # To show an example of hashing: 28 | has_target = tf.feature_column.categorical_column_with_identity( 29 | 'has_target', num_buckets=3) 30 | terminal = tf.feature_column.categorical_column_with_identity( 31 | 'terminal', num_buckets=10) 32 | hour = tf.feature_column.categorical_column_with_identity( 33 | 'hour', num_buckets=25) 34 | weekday = tf.feature_column.categorical_column_with_identity( 35 | 'weekday', num_buckets=10) 36 | day_user_show = tf.feature_column.bucketized_column( 37 | tf.feature_column.numeric_column('day_user_show', dtype=tf.int32), boundaries=DayShowSegs) 38 | day_user_click = tf.feature_column.bucketized_column( 39 | tf.feature_column.numeric_column('day_user_click', dtype=tf.int32), boundaries=DayClickSegs) 40 | 41 | city_code = tf.feature_column.categorical_column_with_hash_bucket( 42 | 'city_code', hash_bucket_size=2000, dtype=tf.int64) 43 | 44 | network_type = tf.feature_column.categorical_column_with_identity( 45 | 'network_type', num_buckets=20, default_value=19) 46 | 47 | device_type = tf.feature_column.categorical_column_with_hash_bucket( #androidphone这些 48 | 'device_type', hash_bucket_size=500000, dtype=tf.string 49 | ) 50 | device_model = tf.feature_column.categorical_column_with_hash_bucket( #型号如iPhone10 vivo X9 51 | 'device_model', hash_bucket_size=200000, dtype=tf.string 52 | ) 53 | manufacturer = tf.feature_column.categorical_column_with_hash_bucket( #手机品牌 vivo iphone等 54 | 'manufacturer', hash_bucket_size=50000, dtype=tf.string 55 | ) 56 | 57 | 58 | deep_columns = [ 59 | tf.feature_column.embedding_column(creative_id, dimension=15,combiner='sum'), 60 | tf.feature_column.embedding_column(has_target, dimension=15,combiner='sum'), 61 | 
tf.feature_column.embedding_column(terminal, dimension=15, combiner='sum'), 62 | tf.feature_column.embedding_column(hour, dimension=15, combiner='sum'), 63 | tf.feature_column.embedding_column(weekday, dimension=15, combiner='sum'), 64 | tf.feature_column.embedding_column(day_user_show, dimension=15, combiner='sum'), 65 | tf.feature_column.embedding_column(day_user_click, dimension=15, combiner='sum'), 66 | tf.feature_column.embedding_column(city_code, dimension=15, combiner='sum'), 67 | tf.feature_column.embedding_column(network_type, dimension=15, combiner='sum'), 68 | tf.feature_column.embedding_column(device_type, dimension=15, combiner='sum'), 69 | tf.feature_column.embedding_column(device_model, dimension=15, combiner='sum'), 70 | tf.feature_column.embedding_column(manufacturer, dimension=15, combiner='sum'), 71 | 72 | ] 73 | # base_columns = [user_id, ad_id, creative_id, product_id, brush_num, terminal,terminal_brand] 74 | ''' 75 | crossed_columns = [tf.feature_column.crossed_column( 76 | ['userId', 'adId'], hash_bucket_size = 50000000), 77 | 、、、 78 | ] 79 | ''' 80 | return deep_columns 81 | 82 | def feature_input_fn(data_file, num_epochs, shuffle, batch_size, labels=True): 83 | """Generate an input function for the Estimator.""" 84 | 85 | def parse_tfrecord(value): 86 | tf.logging.info('Parsing {}'.format(data_file[:10])) 87 | FixedLenFeatures = { 88 | key: tf.FixedLenFeature(shape=[1], dtype=tf.int64) for key in FixedLenFeatureColumns 89 | } 90 | 91 | StringVarLenFeatures = { 92 | key: tf.VarLenFeature(dtype=tf.string) for key in StringVarLenFeatureColumns 93 | } 94 | FloatFixedLenFeatures = { 95 | key: tf.FixedLenFeature(shape=[1], dtype=tf.float32) for key in FloatFixedLenFeatureColumns 96 | } 97 | StringFixedLenFeatures = { 98 | key: tf.FixedLenFeature(shape=[20], dtype=tf.string) for key in StringFixedLenFeatureColumns 99 | } 100 | StringFeatures = { 101 | key: tf.FixedLenFeature(shape=[1], dtype=tf.string) for key in StringFeatureColumns 102 | } 103 | features={} 104 | features.update(FixedLenFeatures) 105 | features.update(StringVarLenFeatures) 106 | features.update(FloatFixedLenFeatures) 107 | features.update(StringFixedLenFeatures) 108 | features.update(StringFeatures) 109 | 110 | fea = tf.parse_example(value, features) 111 | feature = { 112 | key: fea[key] for key in features 113 | } 114 | classes = tf.to_float(feature['label']) 115 | return feature, classes 116 | 117 | # Extract lines from input files using the Dataset API. 
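    # (note) parallel_interleave keeps cycle_length=32 TFRecord files open at once and
    # interleaves their records, so reads from many shards overlap with parsing and training.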
118 | filenames = tf.data.Dataset.list_files(data_file) 119 | dataset = filenames.apply(tf.contrib.data.parallel_interleave( 120 | lambda filename: tf.data.TFRecordDataset(filename), 121 | cycle_length=32)) 122 | 123 | if shuffle: 124 | dataset = dataset.shuffle(buffer_size=batch_size*64) 125 | 126 | dataset = dataset.repeat(num_epochs).batch(batch_size).prefetch(buffer_size=batch_size*8) 127 | dataset = dataset.map(parse_tfrecord, num_parallel_calls=32) 128 | 129 | return dataset 130 | 131 | -------------------------------------------------------------------------------- /DeepCross/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from sklearn.metrics import roc_auc_score 3 | import numpy as np 4 | 5 | ''' 6 | calculate group_auc and cross_entropy_loss(log loss for binary classification) 7 | 8 | @author: Qiao 9 | ''' 10 | 11 | 12 | def cal_group_auc(labels, preds, user_id_list): 13 | """Calculate group auc""" 14 | 15 | print('*' * 50) 16 | if len(user_id_list) != len(labels): 17 | raise ValueError( 18 | "impression id num should equal to the sample num," \ 19 | "impression id num is {0}".format(len(user_id_list))) 20 | group_score = defaultdict(lambda: []) 21 | group_truth = defaultdict(lambda: []) 22 | for idx, truth in enumerate(labels): 23 | user_id = user_id_list[idx] 24 | score = preds[idx] 25 | truth = labels[idx] 26 | group_score[user_id].append(score) 27 | group_truth[user_id].append(truth) 28 | 29 | group_flag = defaultdict(lambda: False) 30 | for user_id in set(user_id_list): 31 | truths = group_truth[user_id] 32 | flag = False 33 | for i in range(len(truths) - 1): 34 | if truths[i] != truths[i + 1]: 35 | flag = True 36 | break 37 | group_flag[user_id] = flag 38 | 39 | impression_total = 0 40 | total_auc = 0 41 | # 42 | for user_id in group_flag: 43 | if group_flag[user_id]: 44 | auc = roc_auc_score(np.asarray(group_truth[user_id]), np.asarray(group_score[user_id])) 45 | total_auc += auc * len(group_truth[user_id]) 46 | impression_total += len(group_truth[user_id]) 47 | group_auc = float(total_auc) / impression_total 48 | group_auc = round(group_auc, 4) 49 | return group_auc 50 | 51 | 52 | def cross_entropy_loss(labels, preds): 53 | """calculate cross_entropy_loss 54 | 55 | loss = -labels*log(preds)-(1-labels)*log(1-preds) 56 | 57 | Args: 58 | labels, preds 59 | 60 | Returns: 61 | log loss 62 | """ 63 | 64 | if len(labels) != len(preds): 65 | raise ValueError( 66 | "labels num should equal to the preds num,") 67 | 68 | z = np.array(labels) 69 | x = np.array(preds) 70 | res = -z * np.log(x) - (1 - z) * np.log(1 - x) 71 | return res.tolist() 72 | -------------------------------------------------------------------------------- /DeepCross/run.sh: -------------------------------------------------------------------------------- 1 | export HADOOP_HDFS_HOME=$HADOOP_HOME/../hadoop-hdfs 2 | CLASSPATH=$(${HADOOP_HOME}/bin/hadoop classpath --glob) python train.py -ne 2 3 | -------------------------------------------------------------------------------- /DeepFM/deepfm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import tensorflow as tf 3 | from tensorflow.python.estimator.canned import head as head_lib 4 | from tensorflow.python.ops.losses import losses 5 | import collections 6 | from tensorflow.python.ops import metrics as metrics_lib 7 | from metric import cal_group_auc 8 | 9 | def build_deep_layers(net, params): 10 | # Build the 
hidden layers, sized according to the 'hidden_units' param.
11 | 
12 |     for num_hidden_units in params['hidden_units']:
13 |         net = tf.layers.dense(net, units=num_hidden_units, activation=tf.nn.relu,
14 |                               kernel_initializer=tf.glorot_uniform_initializer())
15 |     return net
16 | 
17 | def _check_fm_columns(feature_columns):
18 |     if isinstance(feature_columns, collections.Iterator):
19 |         feature_columns = list(feature_columns)
20 |     column_num = len(feature_columns)
21 |     if column_num < 2:
22 |         raise ValueError('feature_columns must have at least two elements.')
23 |     dimension = -1
24 |     for column in feature_columns:
25 |         if dimension != -1 and column.dimension != dimension:
26 |             raise ValueError('fm_feature_columns must have the same dimension.')
27 |         dimension = column.dimension
28 |     return column_num, dimension
29 | 
30 | def dfm_model_fn(features, labels, mode, params):
31 |     net = tf.feature_column.input_layer(features, params['feature_columns'])  # shape (batch_size, column_num * embedding_size)
32 |     last_deep_layer = build_deep_layers(net, params)
33 | 
34 |     column_num, dimension = _check_fm_columns(params['feature_columns'])
35 |     feature_embeddings = tf.reshape(net, (-1, column_num, dimension))  # (batch_size, column_num, embedding_size)
36 | 
37 |     # sum-then-square part: (sum_i v_i)^2
38 |     summed_feature_embeddings = tf.reduce_sum(feature_embeddings, 1)  # (batch_size, embedding_size)
39 |     summed_square_feature_embeddings = tf.square(summed_feature_embeddings)
40 | 
41 |     # square-then-sum part: sum_i v_i^2
42 |     squared_feature_embeddings = tf.square(feature_embeddings)
43 |     squared_sum_feature_embeddings = tf.reduce_sum(squared_feature_embeddings, 1)
44 | 
45 |     fm_second_order = 0.5 * tf.subtract(summed_square_feature_embeddings, squared_sum_feature_embeddings)  # FM identity: = sum_{i<j} <v_i, v_j>
46 |     #print(tf.shape(fm_second_order))
47 |     #print(fm_second_order.get_shape())
48 | 
49 |     if params['use_fm']:
50 |         print('--use fm--')
51 |         last_layer = tf.concat([fm_second_order, last_deep_layer], 1)
52 |     else:
53 |         last_layer = last_deep_layer
54 |     #head = tf.contrib.estimator.binary_classification_head(loss_reduction=losses.Reduction.SUM)
55 |     head = head_lib._binary_logistic_or_multi_class_head(  # pylint: disable=protected-access
56 |         n_classes=2, weight_column=None, label_vocabulary=None, loss_reduction=losses.Reduction.SUM)
57 |     logits = tf.layers.dense(last_layer, units=head.logits_dimension,
58 |                              kernel_initializer=tf.glorot_uniform_initializer())
59 |     optimizer = tf.train.AdagradOptimizer(learning_rate=params['learning_rate'])
60 | 
61 |     preds = tf.sigmoid(logits)
62 |     #print(tf.shape(preds))
63 |     #print(preds.get_shape())
64 |     user_id = features['user_id']
65 |     label = features['label']
66 |     if mode == tf.estimator.ModeKeys.EVAL:
67 |         accuracy = tf.metrics.accuracy(labels=labels['class'],
68 |                                        predictions=tf.to_float(tf.greater_equal(preds, 0.5)))
69 |         auc = tf.metrics.auc(labels['class'], preds)
70 |         label_mean = metrics_lib.mean(labels['class'])
71 |         prediction_mean = metrics_lib.mean(preds)
72 | 
73 |         prediction_squared_difference = tf.math.squared_difference(preds, prediction_mean[0])
74 |         prediction_squared_sum = tf.reduce_sum(prediction_squared_difference)
75 |         num_predictions = tf.to_float(tf.size(preds))
76 |         s_deviation = tf.sqrt(prediction_squared_sum/num_predictions), accuracy[0]  # standard deviation
77 | 
78 |         c_variation = tf.to_float(s_deviation[0]/prediction_mean[0]), accuracy[0]  # coefficient of variation
79 | 
80 |         #group_auc = tf.to_float(cal_group_auc(labels['class'], preds, labels['user_id'])), accuracy[0]  # group auc
81 | 
82 | 
83 |         metrics = {'accuracy': accuracy, 'auc': auc, 'label/mean':
label_mean, 'prediction/mean': prediction_mean, 84 | 'standard deviation': s_deviation, 'coefficient of variation': c_variation} 85 | # 'group auc': group_auc} 86 | tf.summary.scalar('accuracy', accuracy[1]) 87 | tf.summary.scalar('auc', auc[1]) 88 | tf.summary.scalar('label/mean', label_mean[1]) 89 | tf.summary.scalar('prediction/mean', prediction_mean[1]) 90 | tf.summary.scalar('s_deviation', s_deviation[1]) 91 | tf.summary.scalar('c_variation', c_variation[1]) 92 | #tf.summary.scalar('group_auc', group_auc[1]) 93 | 94 | loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels['class'], logits=logits)) 95 | #print(tf.shape(loss)) 96 | #print(loss.get_shape()) 97 | return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics) 98 | 99 | if mode == tf.estimator.ModeKeys.PREDICT: 100 | predictions = { 101 | 'probabilities': preds, 102 | 'user_id': user_id, 103 | 'label': label 104 | } 105 | export_outputs = { 106 | 'prediction': tf.estimator.export.PredictOutput(predictions) 107 | } 108 | return tf.estimator.EstimatorSpec(mode, predictions=predictions, export_outputs=export_outputs) 109 | 110 | return head.create_estimator_spec( 111 | features=features, 112 | mode=mode, 113 | labels=labels, 114 | logits=logits, 115 | train_op_fn=lambda loss: optimizer.minimize(loss, global_step=tf.train.get_global_step()) 116 | ) 117 | 118 | -------------------------------------------------------------------------------- /DeepFM/input_fn.py: -------------------------------------------------------------------------------- 1 | #-*- coding: UTF-8 -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import tensorflow as tf 6 | 7 | 8 | FixedLenFeatureColumns=["label", "user_id", "creative_id", "has_target", "terminal", 9 | "hour", "weekday","template_category", 10 | "day_user_show", "day_user_click", "city_code","network_type"] 11 | StringVarLenFeatureColumns = ["keyword"] #特征长度不固定 12 | FloatFixedLenFeatureColumns = ['creative_history_ctr'] 13 | StringFixedLenFeatureColumns = ["keyword_attention"] 14 | StringFeatureColumns = ["device_type", "device_model", "manufacturer"] 15 | 16 | DayShowSegs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 41, 42, 44, 46, 47, 49, 51, 54, 56, 59, 61, 65, 68, 72, 76, 81, 86, 92, 100, 109, 120, 134, 153, 184, 243, 1195] 17 | DayClickSegs = [1, 2, 3, 6, 23] 18 | 19 | 20 | def build_model_columns(): 21 | """Builds a set of wide and deep feature columns.""" 22 | # Continuous variable columns 23 | # hours_per_week = tf.feature_column.numeric_column('hours_per_week') 24 | 25 | creative_id = tf.feature_column.categorical_column_with_hash_bucket( 26 | 'creative_id', hash_bucket_size=200000, dtype=tf.int64) 27 | # To show an example of hashing: 28 | has_target = tf.feature_column.categorical_column_with_identity( 29 | 'has_target', num_buckets=3) 30 | terminal = tf.feature_column.categorical_column_with_identity( 31 | 'terminal', num_buckets=10) 32 | hour = tf.feature_column.categorical_column_with_identity( 33 | 'hour', num_buckets=25) 34 | weekday = tf.feature_column.categorical_column_with_identity( 35 | 'weekday', num_buckets=10) 36 | day_user_show = tf.feature_column.bucketized_column( 37 | tf.feature_column.numeric_column('day_user_show', dtype=tf.int32), boundaries=DayShowSegs) 38 | day_user_click = tf.feature_column.bucketized_column( 39 | 
tf.feature_column.numeric_column('day_user_click', dtype=tf.int32), boundaries=DayClickSegs) 40 | 41 | city_code = tf.feature_column.categorical_column_with_hash_bucket( 42 | 'city_code', hash_bucket_size=2000, dtype=tf.int64) 43 | 44 | network_type = tf.feature_column.categorical_column_with_identity( 45 | 'network_type', num_buckets=20, default_value=19) 46 | 47 | device_type = tf.feature_column.categorical_column_with_hash_bucket( #androidphone这些 48 | 'device_type', hash_bucket_size=500000, dtype=tf.string 49 | ) 50 | device_model = tf.feature_column.categorical_column_with_hash_bucket( #型号如iPhone10 vivo X9 51 | 'device_model', hash_bucket_size=200000, dtype=tf.string 52 | ) 53 | manufacturer = tf.feature_column.categorical_column_with_hash_bucket( #手机品牌 vivo iphone等 54 | 'manufacturer', hash_bucket_size=50000, dtype=tf.string 55 | ) 56 | 57 | 58 | deep_columns = [ 59 | tf.feature_column.embedding_column(creative_id, dimension=15,combiner='sum'), 60 | tf.feature_column.embedding_column(has_target, dimension=15,combiner='sum'), 61 | tf.feature_column.embedding_column(terminal, dimension=15, combiner='sum'), 62 | tf.feature_column.embedding_column(hour, dimension=15, combiner='sum'), 63 | tf.feature_column.embedding_column(weekday, dimension=15, combiner='sum'), 64 | tf.feature_column.embedding_column(day_user_show, dimension=15, combiner='sum'), 65 | tf.feature_column.embedding_column(day_user_click, dimension=15, combiner='sum'), 66 | tf.feature_column.embedding_column(city_code, dimension=15, combiner='sum'), 67 | tf.feature_column.embedding_column(network_type, dimension=15, combiner='sum'), 68 | tf.feature_column.embedding_column(device_type, dimension=15, combiner='sum'), 69 | tf.feature_column.embedding_column(device_model, dimension=15, combiner='sum'), 70 | tf.feature_column.embedding_column(manufacturer, dimension=15, combiner='sum'), 71 | 72 | ] 73 | # base_columns = [user_id, ad_id, creative_id, product_id, brush_num, terminal,terminal_brand] 74 | ''' 75 | crossed_columns = [tf.feature_column.crossed_column( 76 | ['userId', 'adId'], hash_bucket_size = 50000000), 77 | 、、、 78 | ] 79 | ''' 80 | return deep_columns 81 | 82 | def feature_input_fn(data_file, num_epochs, shuffle, batch_size, labels=True): 83 | """Generate an input function for the Estimator.""" 84 | 85 | def parse_tfrecord(value): 86 | tf.logging.info('Parsing {}'.format(data_file[:10])) 87 | FixedLenFeatures = { 88 | key: tf.FixedLenFeature(shape=[1], dtype=tf.int64) for key in FixedLenFeatureColumns 89 | } 90 | 91 | StringVarLenFeatures = { 92 | key: tf.VarLenFeature(dtype=tf.string) for key in StringVarLenFeatureColumns 93 | } 94 | FloatFixedLenFeatures = { 95 | key: tf.FixedLenFeature(shape=[1], dtype=tf.float32) for key in FloatFixedLenFeatureColumns 96 | } 97 | StringFixedLenFeatures = { 98 | key: tf.FixedLenFeature(shape=[20], dtype=tf.string) for key in StringFixedLenFeatureColumns 99 | } 100 | StringFeatures = { 101 | key: tf.FixedLenFeature(shape=[1], dtype=tf.string) for key in StringFeatureColumns 102 | } 103 | features={} 104 | features.update(FixedLenFeatures) 105 | features.update(StringVarLenFeatures) 106 | features.update(FloatFixedLenFeatures) 107 | features.update(StringFixedLenFeatures) 108 | features.update(StringFeatures) 109 | 110 | fea = tf.parse_example(value, features) 111 | feature = { 112 | key: fea[key] for key in features 113 | } 114 | classes = tf.to_float(feature['label']) 115 | return feature, classes 116 | 117 | # Extract lines from input files using the Dataset API. 
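    # (note) shuffle() below keeps a buffer of batch_size*64 serialized records; a larger
    # buffer shuffles more thoroughly at the cost of memory and startup latency.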
118 | filenames = tf.data.Dataset.list_files(data_file) 119 | dataset = filenames.apply(tf.contrib.data.parallel_interleave( 120 | lambda filename: tf.data.TFRecordDataset(filename), 121 | cycle_length=32)) 122 | 123 | if shuffle: 124 | dataset = dataset.shuffle(buffer_size=batch_size*64) 125 | 126 | dataset = dataset.repeat(num_epochs).batch(batch_size).prefetch(buffer_size=batch_size*8) 127 | dataset = dataset.map(parse_tfrecord, num_parallel_calls=32) 128 | 129 | return dataset 130 | 131 | -------------------------------------------------------------------------------- /DeepFM/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from sklearn.metrics import roc_auc_score 3 | import numpy as np 4 | ''' 5 | calculate group_auc and cross_entropy_loss(log loss for binary classification) 6 | 7 | @author: Qiao 8 | ''' 9 | def cal_group_auc(labels, preds, user_id_list): 10 | """Calculate group auc""" 11 | 12 | print('*'*50) 13 | if len(user_id_list) != len(labels): 14 | raise ValueError( 15 | "impression id num should equal to the sample num," \ 16 | "impression id num is {0}".format(len(user_id_list))) 17 | group_score = defaultdict(lambda: []) 18 | group_truth = defaultdict(lambda: []) 19 | for idx, truth in enumerate(labels): 20 | user_id = user_id_list[idx] 21 | score = preds[idx] 22 | truth = labels[idx] 23 | group_score[user_id].append(score) 24 | group_truth[user_id].append(truth) 25 | 26 | group_flag = defaultdict(lambda: False) 27 | for user_id in set(user_id_list): 28 | truths = group_truth[user_id] 29 | flag = False 30 | for i in range(len(truths) - 1): 31 | if truths[i] != truths[i + 1]: 32 | flag = True 33 | break 34 | group_flag[user_id] = flag 35 | 36 | impression_total = 0 37 | total_auc = 0 38 | # 39 | for user_id in group_flag: 40 | if group_flag[user_id]: 41 | auc = roc_auc_score(np.asarray(group_truth[user_id]), np.asarray(group_score[user_id])) 42 | total_auc += auc * len(group_truth[user_id]) 43 | impression_total += len(group_truth[user_id]) 44 | group_auc = float(total_auc) / impression_total 45 | group_auc = round(group_auc, 4) 46 | return group_auc 47 | 48 | def cross_entropy_loss(labels, preds): 49 | """calculate cross_entropy_loss 50 | 51 | loss = -labels*log(preds)-(1-labels)*log(1-preds) 52 | 53 | Args: 54 | labels, preds 55 | 56 | Returns: 57 | log loss 58 | """ 59 | 60 | if len(labels) != len(preds): 61 | raise ValueError( 62 | "labels num should equal to the preds num,") 63 | 64 | z = np.array(labels) 65 | x = np.array(preds) 66 | res = -z * np.log(x) - (1-z)*np.log(1-x) 67 | return res.tolist() 68 | -------------------------------------------------------------------------------- /DeepFM/run.sh: -------------------------------------------------------------------------------- 1 | export HADOOP_HDFS_HOME=$HADOOP_HOME/../hadoop-hdfs 2 | CLASSPATH=$(${HADOOP_HOME}/bin/hadoop classpath --glob) python train.py -ne 1 -eo 3 | -------------------------------------------------------------------------------- /Din/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from sklearn.metrics import roc_auc_score 3 | import numpy as np 4 | 5 | ''' 6 | calculate group_auc and cross_entropy_loss(log loss for binary classification) 7 | 8 | @author: Qiao 9 | ''' 10 | 11 | 12 | def cal_group_auc(labels, preds, user_id_list): 13 | """Calculate group auc""" 14 | 15 | print('*' * 50) 16 | if len(user_id_list) != len(labels): 
17 | raise ValueError( 18 | "impression id num should equal to the sample num," \ 19 | "impression id num is {0}".format(len(user_id_list))) 20 | group_score = defaultdict(lambda: []) 21 | group_truth = defaultdict(lambda: []) 22 | for idx, truth in enumerate(labels): 23 | user_id = user_id_list[idx] 24 | score = preds[idx] 25 | truth = labels[idx] 26 | group_score[user_id].append(score) 27 | group_truth[user_id].append(truth) 28 | 29 | group_flag = defaultdict(lambda: False) 30 | for user_id in set(user_id_list): 31 | truths = group_truth[user_id] 32 | flag = False 33 | for i in range(len(truths) - 1): 34 | if truths[i] != truths[i + 1]: 35 | flag = True 36 | break 37 | group_flag[user_id] = flag 38 | 39 | impression_total = 0 40 | total_auc = 0 41 | # 42 | for user_id in group_flag: 43 | if group_flag[user_id]: 44 | auc = roc_auc_score(np.asarray(group_truth[user_id]), np.asarray(group_score[user_id])) 45 | total_auc += auc * len(group_truth[user_id]) 46 | impression_total += len(group_truth[user_id]) 47 | group_auc = float(total_auc) / impression_total 48 | group_auc = round(group_auc, 4) 49 | return group_auc 50 | 51 | 52 | def cross_entropy_loss(labels, preds): 53 | """calculate cross_entropy_loss 54 | 55 | loss = -labels*log(preds)-(1-labels)*log(1-preds) 56 | 57 | Args: 58 | labels, preds 59 | 60 | Returns: 61 | log loss 62 | """ 63 | 64 | if len(labels) != len(preds): 65 | raise ValueError( 66 | "labels num should equal to the preds num,") 67 | 68 | z = np.array(labels) 69 | x = np.array(preds) 70 | res = -z * np.log(x) - (1 - z) * np.log(1 - x) 71 | return res.tolist() 72 | -------------------------------------------------------------------------------- /Din/run.sh: -------------------------------------------------------------------------------- 1 | export HADOOP_HDFS_HOME=$HADOOP_HOME/../hadoop-hdfs 2 | CLASSPATH=$(${HADOOP_HOME}/bin/hadoop classpath --glob) python train.py -ne 2 -------------------------------------------------------------------------------- /Din/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def dice(_x, axis=-1, epsilon=0.0000001, name=''): 4 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 5 | alphas = tf.get_variable('alpha'+name, _x.get_shape()[-1], 6 | initializer=tf.constant_initializer(0.0), 7 | dtype=tf.float32) 8 | input_shape = list(_x.get_shape()) 9 | 10 | reduction_axes = list(range(len(input_shape))) 11 | del reduction_axes[axis] 12 | broadcast_shape = [1] * len(input_shape) 13 | broadcast_shape[axis] = input_shape[axis] 14 | 15 | # case: train mode (uses stats of the current batch) 16 | mean = tf.reduce_mean(_x, axis=reduction_axes) 17 | brodcast_mean = tf.reshape(mean, broadcast_shape) 18 | std = tf.reduce_mean(tf.square(_x - brodcast_mean) + epsilon, axis=reduction_axes) 19 | std = tf.sqrt(std) 20 | brodcast_std = tf.reshape(std, broadcast_shape) 21 | x_normed = (_x - brodcast_mean) / (brodcast_std + epsilon) 22 | # x_normed = tf.layers.batch_normalization(_x, center=False, scale=False) 23 | x_p = tf.sigmoid(x_normed) 24 | 25 | return alphas * (1.0 - x_p) * _x + x_p * _x 26 | 27 | def prelu(_x, scope=''): 28 | """parametric ReLU activation""" 29 | with tf.variable_scope(name_or_scope=scope, default_name="prelu"): 30 | _alpha = tf.get_variable("prelu_"+scope, shape=_x.get_shape()[-1], 31 | dtype=_x.dtype, initializer=tf.constant_initializer(0.1)) 32 | return tf.maximum(0.0, _x) + _alpha * tf.minimum(0.0, _x) 33 | 34 | 
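The `dice` and `prelu` helpers above are the adaptive activations from the DIN paper. Below is a minimal sketch of how they might be wired into a hidden layer; the layer sizes, placeholder shape, and the `from utils import ...` path are illustrative assumptions, not code from this repo:

```python
import numpy as np
import tensorflow as tf
from utils import dice, prelu  # the helpers defined above

def hidden_block(net, units, use_dice=True, name='block'):
    """Dense layer followed by Dice (or PReLU as the simpler special case)."""
    net = tf.layers.dense(net, units=units, activation=None,
                          kernel_initializer=tf.glorot_uniform_initializer(),
                          name=name + '_dense')
    # Dice gates each unit with a learned, data-dependent probability x_p;
    # PReLU is the limiting case where the gate is a hard threshold at 0.
    return dice(net, name=name) if use_dice else prelu(net, scope=name)

x = tf.placeholder(tf.float32, [None, 8])
h = hidden_block(x, units=4, use_dice=True, name='h1')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(h, feed_dict={x: np.random.rand(2, 8).astype(np.float32)}))
```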
-------------------------------------------------------------------------------- /ESMM/esmm.py: --------------------------------------------------------------------------------
1 | #-*- coding: UTF-8 -*-
2 | import tensorflow as tf
3 | from tensorflow.python.estimator.canned import head as head_lib
4 | from tensorflow.python.ops.losses import losses
5 | 
6 | def build_deep_layers(net, params):
7 |     # Build the hidden layers, sized according to the 'hidden_units' param.
8 | 
9 |     for num_hidden_units in params['hidden_units']:
10 |         net = tf.layers.dense(net, units=num_hidden_units, activation=tf.nn.relu,
11 |                               kernel_initializer=tf.glorot_uniform_initializer())
12 |     return net
13 | 
14 | def esmm_model_fn(features, labels, mode, params):
15 |     net = tf.feature_column.input_layer(features, params['feature_columns'])
16 |     last_ctr_layer = build_deep_layers(net, params)
17 |     last_cvr_layer = build_deep_layers(net, params)
18 | 
19 |     #head = tf.contrib.estimator.binary_classification_head(loss_reduction=losses.Reduction.SUM)
20 |     head = head_lib._binary_logistic_or_multi_class_head(  # pylint: disable=protected-access
21 |         n_classes=2, weight_column=None, label_vocabulary=None, loss_reduction=losses.Reduction.SUM)
22 |     ctr_logits = tf.layers.dense(last_ctr_layer, units=head.logits_dimension,
23 |                                  kernel_initializer=tf.glorot_uniform_initializer())
24 |     cvr_logits = tf.layers.dense(last_cvr_layer, units=head.logits_dimension,
25 |                                  kernel_initializer=tf.glorot_uniform_initializer())
26 |     ctr_preds = tf.sigmoid(ctr_logits)
27 |     cvr_preds = tf.sigmoid(cvr_logits)
28 |     ctcvr_preds = tf.multiply(ctr_preds, cvr_preds)  # pCTCVR = pCTR * pCVR
29 | 
30 |     optimizer = tf.train.AdagradOptimizer(learning_rate=params['learning_rate'])
31 |     ctr_label = labels['ctr_label']
32 |     cvr_label = labels['cvr_label']
33 | 
34 |     user_id = features['user_id']
35 |     click_label = features['label']
36 |     conversion_label = features['is_conversion']
37 | 
38 | 
39 |     if mode == tf.estimator.ModeKeys.PREDICT:
40 |         predictions = {
41 |             'ctr_preds': ctr_preds,
42 |             'cvr_preds': cvr_preds,
43 |             'ctcvr_preds': ctcvr_preds,
44 |             'user_id': user_id,
45 |             'click_label': click_label,
46 |             'conversion_label': conversion_label
47 |         }
48 |         export_outputs = {
49 |             'regression': tf.estimator.export.RegressionOutput(predictions['cvr_preds'])  # needed for online serving
50 |         }
51 |         return tf.estimator.EstimatorSpec(mode, predictions=predictions, export_outputs=export_outputs)
52 | 
53 |     else:
54 |         ctr_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=ctr_label, logits=ctr_logits))
55 |         ctcvr_loss = tf.reduce_sum(tf.losses.log_loss(labels=cvr_label, predictions=ctcvr_preds))
56 |         loss = ctr_loss + ctcvr_loss  # a weighting term could be added here, following multi-task loss methods
57 | 
58 |         train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
59 |         return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
60 |     """
61 |     return head.create_estimator_spec(
62 |         features=features,
63 |         mode=mode,
64 |         labels=labels,
65 |         logits=logits,
66 |         train_op_fn=lambda loss: optimizer.minimize(loss, global_step=tf.train.get_global_step())
67 |     )
68 |     """
69 | 
-------------------------------------------------------------------------------- /ESMM/input_fn.py: --------------------------------------------------------------------------------
1 | #-*- coding: UTF-8 -*-
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 | import tensorflow as tf
6 | 
7 | 
8 | FixedLenFeatureColumns=["label", "user_id", "creative_id", "has_target",
"terminal", 9 | "hour", "weekday","template_category", 10 | "day_user_show", "day_user_click", "city_code","network_type"] 11 | StringVarLenFeatureColumns = ["keyword"] #特征长度不固定 12 | FloatFixedLenFeatureColumns = ['creative_history_ctr'] 13 | StringFixedLenFeatureColumns = ["keyword_attention"] 14 | StringFeatureColumns = ["device_type", "device_model", "manufacturer"] 15 | 16 | DayShowSegs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 41, 42, 44, 46, 47, 49, 51, 54, 56, 59, 61, 65, 68, 72, 76, 81, 86, 92, 100, 109, 120, 134, 153, 184, 243, 1195] 17 | DayClickSegs = [1, 2, 3, 6, 23] 18 | 19 | 20 | def build_model_columns(): 21 | """Builds a set of wide and deep feature columns.""" 22 | # Continuous variable columns 23 | # hours_per_week = tf.feature_column.numeric_column('hours_per_week') 24 | 25 | creative_id = tf.feature_column.categorical_column_with_hash_bucket( 26 | 'creative_id', hash_bucket_size=200000, dtype=tf.int64) 27 | # To show an example of hashing: 28 | has_target = tf.feature_column.categorical_column_with_identity( 29 | 'has_target', num_buckets=3) 30 | terminal = tf.feature_column.categorical_column_with_identity( 31 | 'terminal', num_buckets=10) 32 | hour = tf.feature_column.categorical_column_with_identity( 33 | 'hour', num_buckets=25) 34 | weekday = tf.feature_column.categorical_column_with_identity( 35 | 'weekday', num_buckets=10) 36 | day_user_show = tf.feature_column.bucketized_column( 37 | tf.feature_column.numeric_column('day_user_show', dtype=tf.int32), boundaries=DayShowSegs) 38 | day_user_click = tf.feature_column.bucketized_column( 39 | tf.feature_column.numeric_column('day_user_click', dtype=tf.int32), boundaries=DayClickSegs) 40 | 41 | city_code = tf.feature_column.categorical_column_with_hash_bucket( 42 | 'city_code', hash_bucket_size=2000, dtype=tf.int64) 43 | 44 | network_type = tf.feature_column.categorical_column_with_identity( 45 | 'network_type', num_buckets=20, default_value=19) 46 | 47 | device_type = tf.feature_column.categorical_column_with_hash_bucket( #androidphone这些 48 | 'device_type', hash_bucket_size=500000, dtype=tf.string 49 | ) 50 | device_model = tf.feature_column.categorical_column_with_hash_bucket( #型号如iPhone10 vivo X9 51 | 'device_model', hash_bucket_size=200000, dtype=tf.string 52 | ) 53 | manufacturer = tf.feature_column.categorical_column_with_hash_bucket( #手机品牌 vivo iphone等 54 | 'manufacturer', hash_bucket_size=50000, dtype=tf.string 55 | ) 56 | 57 | 58 | deep_columns = [ 59 | tf.feature_column.embedding_column(creative_id, dimension=15,combiner='sum'), 60 | tf.feature_column.embedding_column(has_target, dimension=15,combiner='sum'), 61 | tf.feature_column.embedding_column(terminal, dimension=15, combiner='sum'), 62 | tf.feature_column.embedding_column(hour, dimension=15, combiner='sum'), 63 | tf.feature_column.embedding_column(weekday, dimension=15, combiner='sum'), 64 | tf.feature_column.embedding_column(day_user_show, dimension=15, combiner='sum'), 65 | tf.feature_column.embedding_column(day_user_click, dimension=15, combiner='sum'), 66 | tf.feature_column.embedding_column(city_code, dimension=15, combiner='sum'), 67 | tf.feature_column.embedding_column(network_type, dimension=15, combiner='sum'), 68 | tf.feature_column.embedding_column(device_type, dimension=15, combiner='sum'), 69 | tf.feature_column.embedding_column(device_model, dimension=15, combiner='sum'), 70 | tf.feature_column.embedding_column(manufacturer, dimension=15, 
combiner='sum'), 71 | 72 | ] 73 | # base_columns = [user_id, ad_id, creative_id, product_id, brush_num, terminal,terminal_brand] 74 | ''' 75 | crossed_columns = [tf.feature_column.crossed_column( 76 | ['userId', 'adId'], hash_bucket_size = 50000000), 77 | 、、、 78 | ] 79 | ''' 80 | return deep_columns 81 | 82 | def feature_input_fn(data_file, num_epochs, shuffle, batch_size, labels=True): 83 | """Generate an input function for the Estimator.""" 84 | 85 | def parse_tfrecord(value): 86 | tf.logging.info('Parsing {}'.format(data_file[:10])) 87 | FixedLenFeatures = { 88 | key: tf.FixedLenFeature(shape=[1], dtype=tf.int64) for key in FixedLenFeatureColumns 89 | } 90 | 91 | StringVarLenFeatures = { 92 | key: tf.VarLenFeature(dtype=tf.string) for key in StringVarLenFeatureColumns 93 | } 94 | FloatFixedLenFeatures = { 95 | key: tf.FixedLenFeature(shape=[1], dtype=tf.float32) for key in FloatFixedLenFeatureColumns 96 | } 97 | StringFixedLenFeatures = { 98 | key: tf.FixedLenFeature(shape=[20], dtype=tf.string) for key in StringFixedLenFeatureColumns 99 | } 100 | StringFeatures = { 101 | key: tf.FixedLenFeature(shape=[1], dtype=tf.string) for key in StringFeatureColumns 102 | } 103 | features={} 104 | features.update(FixedLenFeatures) 105 | features.update(StringVarLenFeatures) 106 | features.update(FloatFixedLenFeatures) 107 | features.update(StringFixedLenFeatures) 108 | features.update(StringFeatures) 109 | 110 | fea = tf.parse_example(value, features) 111 | feature = { 112 | key: fea[key] for key in features 113 | } 114 | classes = tf.to_float(feature['label']) 115 | return feature, classes 116 | 117 | # Extract lines from input files using the Dataset API. 118 | filenames = tf.data.Dataset.list_files(data_file) 119 | dataset = filenames.apply(tf.contrib.data.parallel_interleave( 120 | lambda filename: tf.data.TFRecordDataset(filename), 121 | cycle_length=32)) 122 | 123 | if shuffle: 124 | dataset = dataset.shuffle(buffer_size=batch_size*64) 125 | 126 | dataset = dataset.repeat(num_epochs).batch(batch_size).prefetch(buffer_size=batch_size*8) 127 | dataset = dataset.map(parse_tfrecord, num_parallel_calls=32) 128 | 129 | return dataset 130 | 131 | -------------------------------------------------------------------------------- /ESMM/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from sklearn.metrics import roc_auc_score 3 | import numpy as np 4 | 5 | ''' 6 | calculate group_auc and cross_entropy_loss(log loss for binary classification) 7 | 8 | @author: Qiao 9 | ''' 10 | 11 | 12 | def cal_group_auc(labels, preds, user_id_list): 13 | """Calculate group auc""" 14 | 15 | print('*' * 50) 16 | if len(user_id_list) != len(labels): 17 | raise ValueError( 18 | "impression id num should equal to the sample num," \ 19 | "impression id num is {0}".format(len(user_id_list))) 20 | group_score = defaultdict(lambda: []) 21 | group_truth = defaultdict(lambda: []) 22 | for idx, truth in enumerate(labels): 23 | user_id = user_id_list[idx] 24 | score = preds[idx] 25 | truth = labels[idx] 26 | group_score[user_id].append(score) 27 | group_truth[user_id].append(truth) 28 | 29 | group_flag = defaultdict(lambda: False) 30 | for user_id in set(user_id_list): 31 | truths = group_truth[user_id] 32 | flag = False 33 | for i in range(len(truths) - 1): 34 | if truths[i] != truths[i + 1]: 35 | flag = True 36 | break 37 | group_flag[user_id] = flag 38 | 39 | impression_total = 0 40 | total_auc = 0 41 | # 42 | for user_id in group_flag: 43 | if 
group_flag[user_id]: 44 | auc = roc_auc_score(np.asarray(group_truth[user_id]), np.asarray(group_score[user_id])) 45 | total_auc += auc * len(group_truth[user_id]) 46 | impression_total += len(group_truth[user_id]) 47 | group_auc = float(total_auc) / impression_total 48 | group_auc = round(group_auc, 4) 49 | return group_auc 50 | 51 | 52 | def cross_entropy_loss(labels, preds): 53 | """calculate cross_entropy_loss 54 | 55 | loss = -labels*log(preds)-(1-labels)*log(1-preds) 56 | 57 | Args: 58 | labels, preds 59 | 60 | Returns: 61 | log loss 62 | """ 63 | 64 | if len(labels) != len(preds): 65 | raise ValueError( 66 | "labels num should equal to the preds num,") 67 | 68 | z = np.array(labels) 69 | x = np.array(preds) 70 | res = -z * np.log(x) - (1 - z) * np.log(1 - x) 71 | return res.tolist() 72 | -------------------------------------------------------------------------------- /ESMM/run.sh: -------------------------------------------------------------------------------- 1 | export HADOOP_HDFS_HOME=$HADOOP_HOME/../hadoop-hdfs 2 | CLASSPATH=$(${HADOOP_HOME}/bin/hadoop classpath --glob) python train.py -ne 1 3 | -------------------------------------------------------------------------------- /Fibinet/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from sklearn.metrics import roc_auc_score 3 | import numpy as np 4 | 5 | ''' 6 | calculate group_auc and cross_entropy_loss(log loss for binary classification) 7 | 8 | @author: Qiao 9 | ''' 10 | 11 | 12 | def cal_group_auc(labels, preds, user_id_list): 13 | """Calculate group auc""" 14 | 15 | print('*' * 50) 16 | if len(user_id_list) != len(labels): 17 | raise ValueError( 18 | "impression id num should equal to the sample num," \ 19 | "impression id num is {0}".format(len(user_id_list))) 20 | group_score = defaultdict(lambda: []) 21 | group_truth = defaultdict(lambda: []) 22 | for idx, truth in enumerate(labels): 23 | user_id = user_id_list[idx] 24 | score = preds[idx] 25 | truth = labels[idx] 26 | group_score[user_id].append(score) 27 | group_truth[user_id].append(truth) 28 | 29 | group_flag = defaultdict(lambda: False) 30 | for user_id in set(user_id_list): 31 | truths = group_truth[user_id] 32 | flag = False 33 | for i in range(len(truths) - 1): 34 | if truths[i] != truths[i + 1]: 35 | flag = True 36 | break 37 | group_flag[user_id] = flag 38 | 39 | impression_total = 0 40 | total_auc = 0 41 | # 42 | for user_id in group_flag: 43 | if group_flag[user_id]: 44 | auc = roc_auc_score(np.asarray(group_truth[user_id]), np.asarray(group_score[user_id])) 45 | total_auc += auc * len(group_truth[user_id]) 46 | impression_total += len(group_truth[user_id]) 47 | group_auc = float(total_auc) / impression_total 48 | group_auc = round(group_auc, 4) 49 | return group_auc 50 | 51 | 52 | def cross_entropy_loss(labels, preds): 53 | """calculate cross_entropy_loss 54 | 55 | loss = -labels*log(preds)-(1-labels)*log(1-preds) 56 | 57 | Args: 58 | labels, preds 59 | 60 | Returns: 61 | log loss 62 | """ 63 | 64 | if len(labels) != len(preds): 65 | raise ValueError( 66 | "labels num should equal to the preds num,") 67 | 68 | z = np.array(labels) 69 | x = np.array(preds) 70 | res = -z * np.log(x) - (1 - z) * np.log(1 - x) 71 | return res.tolist() 72 | -------------------------------------------------------------------------------- /Fibinet/run.sh: -------------------------------------------------------------------------------- 1 | export HADOOP_HDFS_HOME=$HADOOP_HOME/../hadoop-hdfs 2 | 
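# (note) TensorFlow resolves hdfs:// paths through libhdfs, which needs the Hadoop jars
# on CLASSPATH; the `hadoop classpath --glob` expansion below provides exactly that list.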
CLASSPATH=$(${HADOOP_HOME}/bin/hadoop classpath --glob) python train.py -ne 2 -------------------------------------------------------------------------------- /Fibinet/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def dice(_x, axis=-1, epsilon=0.0000001, name=''): 4 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 5 | alphas = tf.get_variable('alpha'+name, _x.get_shape()[-1], 6 | initializer=tf.constant_initializer(0.0), 7 | dtype=tf.float32) 8 | input_shape = list(_x.get_shape()) 9 | 10 | reduction_axes = list(range(len(input_shape))) 11 | del reduction_axes[axis] 12 | broadcast_shape = [1] * len(input_shape) 13 | broadcast_shape[axis] = input_shape[axis] 14 | 15 | # case: train mode (uses stats of the current batch) 16 | mean = tf.reduce_mean(_x, axis=reduction_axes) 17 | brodcast_mean = tf.reshape(mean, broadcast_shape) 18 | std = tf.reduce_mean(tf.square(_x - brodcast_mean) + epsilon, axis=reduction_axes) 19 | std = tf.sqrt(std) 20 | brodcast_std = tf.reshape(std, broadcast_shape) 21 | #x_normed = (_x - brodcast_mean) / (brodcast_std + epsilon) 22 | x_normed = tf.layers.batch_normalization(_x, center=False, scale=False) 23 | x_p = tf.sigmoid(x_normed) 24 | 25 | return alphas * (1.0 - x_p) * _x + x_p * _x 26 | 27 | def prelu(_x, scope=''): 28 | """parametric ReLU activation""" 29 | with tf.variable_scope(name_or_scope=scope, default_name="prelu"): 30 | _alpha = tf.get_variable("prelu_"+scope, shape=_x.get_shape()[-1], 31 | dtype=_x.dtype, initializer=tf.constant_initializer(0.1)) 32 | return tf.maximum(0.0, _x) + _alpha * tf.minimum(0.0, _x) 33 | 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deep-ctr-prediction 2 | 3 | DNN models for advertising algorithms (CTR prediction) 4 | 5 | 6 | * wide&deep (see official/wide_deep) 7 | 8 | * deep&cross 9 | 10 | * deepfm 11 | 12 | * ESMM 13 | 14 | * Deep Interest Network 15 | 16 | * ResNet 17 | 18 | * xDeepFM 19 | 20 | * AFM(Attentional FM) 21 | 22 | * Transformer 23 | 24 | * FiBiNET 25 | 26 | All models are built with the tf.estimator API. Data is stored as TFRecords (dictionaries of key:value pairs) and read through the tf.data.Dataset API to speed up IO, which makes the code suitable for industrial applications. Feature engineering is defined in input_fn and the model in model_fn, so the two are completely decoupled: feature changes only touch input_fn, and model changes only touch model_fn. Data is stored in Hadoop by default, but can be kept locally according to your own needs. For feature engineering and data processing without the TFRecord format, see Google's open-source wide&deep model (code under official/wide_deep). 27 | 28 | # Requirements 29 | * TensorFlow 1.10 30 | 31 | # References 32 | 33 | 【1】Heng-Tze Cheng, Levent Koc et al. "Wide & Deep Learning for Recommender Systems," In 1st Workshop on Deep Learning for Recommender Systems, 2016. 34 | 35 | 【2】Huifeng Guo et al. "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction," In IJCAI, 2017. 36 | 37 | 【3】Ruoxi Wang et al. "Deep & Cross Network for Ad Click Predictions," In ADKDD, 2017. 38 | 39 | 【4】Xiao Ma et al.
"Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate," In SIGIR,2018. 44 | 45 | 【5】Guorui Zhou et all. "Deep Interest Network for Click-Through Rate Prediction," In KDD,2018. 46 | 47 | 【6】Kaiming He et all. "Deep Residual Learning for Image Recognition," In CVPR,2016. 48 | 49 | 【7】Jianxun Lian et all. "xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems," In KDD,2018. 50 | 51 | 【8】Jun Xiao et all. "Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks," In IJCAI, 2017. 52 | 53 | 【9】Ashish Vasmani et all. "Attention is All You Need," In NIPS, 2017. 54 | 55 | 【10】Tongwen et all. "FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction," In RecSys, 2019. -------------------------------------------------------------------------------- /ResNet/input_fn.py: -------------------------------------------------------------------------------- 1 | #-*- coding: UTF-8 -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | import tensorflow as tf 6 | 7 | 8 | FixedLenFeatureColumns=["label", "user_id", "creative_id", "has_target", "terminal", 9 | "hour", "weekday","template_category", 10 | "day_user_show", "day_user_click", "city_code","network_type"] 11 | StringVarLenFeatureColumns = ["keyword"] #特征长度不固定 12 | FloatFixedLenFeatureColumns = ['creative_history_ctr'] 13 | StringFixedLenFeatureColumns = ["keyword_attention"] 14 | StringFeatureColumns = ["device_type", "device_model", "manufacturer"] 15 | 16 | DayShowSegs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37, 38, 39, 41, 42, 44, 46, 47, 49, 51, 54, 56, 59, 61, 65, 68, 72, 76, 81, 86, 92, 100, 109, 120, 134, 153, 184, 243, 1195] 17 | DayClickSegs = [1, 2, 3, 6, 23] 18 | 19 | 20 | def build_model_columns(): 21 | """Builds a set of wide and deep feature columns.""" 22 | # Continuous variable columns 23 | # hours_per_week = tf.feature_column.numeric_column('hours_per_week') 24 | 25 | creative_id = tf.feature_column.categorical_column_with_hash_bucket( 26 | 'creative_id', hash_bucket_size=200000, dtype=tf.int64) 27 | # To show an example of hashing: 28 | has_target = tf.feature_column.categorical_column_with_identity( 29 | 'has_target', num_buckets=3) 30 | terminal = tf.feature_column.categorical_column_with_identity( 31 | 'terminal', num_buckets=10) 32 | hour = tf.feature_column.categorical_column_with_identity( 33 | 'hour', num_buckets=25) 34 | weekday = tf.feature_column.categorical_column_with_identity( 35 | 'weekday', num_buckets=10) 36 | day_user_show = tf.feature_column.bucketized_column( 37 | tf.feature_column.numeric_column('day_user_show', dtype=tf.int32), boundaries=DayShowSegs) 38 | day_user_click = tf.feature_column.bucketized_column( 39 | tf.feature_column.numeric_column('day_user_click', dtype=tf.int32), boundaries=DayClickSegs) 40 | 41 | city_code = tf.feature_column.categorical_column_with_hash_bucket( 42 | 'city_code', hash_bucket_size=2000, dtype=tf.int64) 43 | 44 | network_type = tf.feature_column.categorical_column_with_identity( 45 | 'network_type', num_buckets=20, default_value=19) 46 | 47 | device_type = tf.feature_column.categorical_column_with_hash_bucket( #androidphone这些 48 | 'device_type', hash_bucket_size=500000, dtype=tf.string 49 | ) 50 | device_model = 
tf.feature_column.categorical_column_with_hash_bucket( # device model, e.g. iPhone10, vivo X9 51 | 'device_model', hash_bucket_size=200000, dtype=tf.string 52 | ) 53 | manufacturer = tf.feature_column.categorical_column_with_hash_bucket( # phone brand, e.g. vivo, iPhone 54 | 'manufacturer', hash_bucket_size=50000, dtype=tf.string 55 | ) 56 | 57 | 58 | deep_columns = [ 59 | tf.feature_column.embedding_column(creative_id, dimension=15,combiner='sum'), 60 | tf.feature_column.embedding_column(has_target, dimension=15,combiner='sum'), 61 | tf.feature_column.embedding_column(terminal, dimension=15, combiner='sum'), 62 | tf.feature_column.embedding_column(hour, dimension=15, combiner='sum'), 63 | tf.feature_column.embedding_column(weekday, dimension=15, combiner='sum'), 64 | tf.feature_column.embedding_column(day_user_show, dimension=15, combiner='sum'), 65 | tf.feature_column.embedding_column(day_user_click, dimension=15, combiner='sum'), 66 | tf.feature_column.embedding_column(city_code, dimension=15, combiner='sum'), 67 | tf.feature_column.embedding_column(network_type, dimension=15, combiner='sum'), 68 | tf.feature_column.embedding_column(device_type, dimension=15, combiner='sum'), 69 | tf.feature_column.embedding_column(device_model, dimension=15, combiner='sum'), 70 | tf.feature_column.embedding_column(manufacturer, dimension=15, combiner='sum'), 71 | 72 | ] 73 | # base_columns = [user_id, ad_id, creative_id, product_id, brush_num, terminal,terminal_brand] 74 | ''' 75 | crossed_columns = [tf.feature_column.crossed_column( 76 | ['userId', 'adId'], hash_bucket_size = 50000000), 77 | ... 78 | ] 79 | ''' 80 | return deep_columns 81 | 82 | def feature_input_fn(data_file, num_epochs, shuffle, batch_size, labels=True): 83 | """Generate an input function for the Estimator.""" 84 | 85 | def parse_tfrecord(value): 86 | tf.logging.info('Parsing {}'.format(data_file[:10])) 87 | FixedLenFeatures = { 88 | key: tf.FixedLenFeature(shape=[1], dtype=tf.int64) for key in FixedLenFeatureColumns 89 | } 90 | 91 | StringVarLenFeatures = { 92 | key: tf.VarLenFeature(dtype=tf.string) for key in StringVarLenFeatureColumns 93 | } 94 | FloatFixedLenFeatures = { 95 | key: tf.FixedLenFeature(shape=[1], dtype=tf.float32) for key in FloatFixedLenFeatureColumns 96 | } 97 | StringFixedLenFeatures = { 98 | key: tf.FixedLenFeature(shape=[20], dtype=tf.string) for key in StringFixedLenFeatureColumns 99 | } 100 | StringFeatures = { 101 | key: tf.FixedLenFeature(shape=[1], dtype=tf.string) for key in StringFeatureColumns 102 | } 103 | features = {} 104 | features.update(FixedLenFeatures) 105 | features.update(StringVarLenFeatures) 106 | features.update(FloatFixedLenFeatures) 107 | features.update(StringFixedLenFeatures) 108 | features.update(StringFeatures) 109 | 110 | fea = tf.parse_example(value, features) 111 | feature = { 112 | key: fea[key] for key in features 113 | } 114 | classes = tf.to_float(feature['label']) 115 | return feature, classes 116 | 117 | # Extract lines from input files using the Dataset API.
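# Note on the pipeline built below (descriptive comments, added for clarity):
# parallel_interleave reads up to 32 TFRecord files concurrently, and because
# batching happens before the map, parse_tfrecord receives a whole batch at a
# time and can rely on the vectorized tf.parse_example call above instead of
# parsing record by record.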
118 | filenames = tf.data.Dataset.list_files(data_file) 119 | dataset = filenames.apply(tf.contrib.data.parallel_interleave( 120 | lambda filename: tf.data.TFRecordDataset(filename), 121 | cycle_length=32)) 122 | 123 | if shuffle: 124 | dataset = dataset.shuffle(buffer_size=batch_size*64) 125 | 126 | dataset = dataset.repeat(num_epochs).batch(batch_size).prefetch(buffer_size=batch_size*8) 127 | dataset = dataset.map(parse_tfrecord, num_parallel_calls=32) 128 | 129 | return dataset 130 | 131 | -------------------------------------------------------------------------------- /ResNet/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from sklearn.metrics import roc_auc_score 3 | import numpy as np 4 | 5 | ''' 6 | calculate group_auc and cross_entropy_loss(log loss for binary classification) 7 | 8 | @author: Qiao 9 | ''' 10 | 11 | 12 | def cal_group_auc(labels, preds, user_id_list): 13 | """Calculate group auc""" 14 | 15 | print('*' * 50) 16 | if len(user_id_list) != len(labels): 17 | raise ValueError( 18 | "impression id num should equal to the sample num," \ 19 | "impression id num is {0}".format(len(user_id_list))) 20 | group_score = defaultdict(lambda: []) 21 | group_truth = defaultdict(lambda: []) 22 | for idx, truth in enumerate(labels): 23 | user_id = user_id_list[idx] 24 | score = preds[idx] 25 | truth = labels[idx] 26 | group_score[user_id].append(score) 27 | group_truth[user_id].append(truth) 28 | 29 | group_flag = defaultdict(lambda: False) 30 | for user_id in set(user_id_list): 31 | truths = group_truth[user_id] 32 | flag = False 33 | for i in range(len(truths) - 1): 34 | if truths[i] != truths[i + 1]: 35 | flag = True 36 | break 37 | group_flag[user_id] = flag 38 | 39 | impression_total = 0 40 | total_auc = 0 41 | # 42 | for user_id in group_flag: 43 | if group_flag[user_id]: 44 | auc = roc_auc_score(np.asarray(group_truth[user_id]), np.asarray(group_score[user_id])) 45 | total_auc += auc * len(group_truth[user_id]) 46 | impression_total += len(group_truth[user_id]) 47 | group_auc = float(total_auc) / impression_total 48 | group_auc = round(group_auc, 4) 49 | return group_auc 50 | 51 | 52 | def cross_entropy_loss(labels, preds): 53 | """calculate cross_entropy_loss 54 | 55 | loss = -labels*log(preds)-(1-labels)*log(1-preds) 56 | 57 | Args: 58 | labels, preds 59 | 60 | Returns: 61 | log loss 62 | """ 63 | 64 | if len(labels) != len(preds): 65 | raise ValueError( 66 | "labels num should equal to the preds num,") 67 | 68 | z = np.array(labels) 69 | x = np.array(preds) 70 | res = -z * np.log(x) - (1 - z) * np.log(1 - x) 71 | return res.tolist() 72 | -------------------------------------------------------------------------------- /ResNet/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import tensorflow as tf 3 | from tensorflow.python.estimator.canned import head as head_lib 4 | from tensorflow.python.ops.losses import losses 5 | from utils import dice 6 | import collections 7 | 8 | 9 | def build_deep_layers(net, params): 10 | # Build the hidden layers, sized according to the 'hidden_units' param. 
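# Each entry in params['hidden_units'] adds one fully connected ReLU layer;
# e.g. hidden_units=[512, 256, 128] would give a three-layer MLP (the exact
# sizes are up to the caller and are an illustrative assumption here).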
11 | 12 | for layer_id, num_hidden_units in enumerate(params['hidden_units']): 13 | net = tf.layers.dense(net, units=num_hidden_units, activation=tf.nn.relu, 14 | kernel_initializer=tf.glorot_uniform_initializer()) 15 | return net 16 | 17 | def build_residual_layers(net, params): 18 | # Build the hidden layers, sized according to the 'hidden_units' param. 19 | net = tf.layers.batch_normalization(net) 20 | shortcut = net 21 | residual = tf.layers.dense(net, units=256, activation=tf.nn.relu, 22 | kernel_initializer=tf.glorot_uniform_initializer()) 23 | net = tf.concat([shortcut, residual], 1) 24 | net = tf.layers.batch_normalization(net) 25 | net = tf.layers.dense(net, units=256, activation=tf.nn.relu, 26 | kernel_initializer=tf.glorot_uniform_initializer()) 27 | 28 | net = tf.layers.batch_normalization(net) 29 | shortcut = net 30 | residual = tf.layers.dense(net, units=128, activation=tf.nn.relu, 31 | kernel_initializer=tf.glorot_uniform_initializer()) 32 | net = tf.concat([shortcut, residual], 1) 33 | net = tf.layers.batch_normalization(net) 34 | net = tf.layers.dense(net, units=128, activation=tf.nn.relu, 35 | kernel_initializer=tf.glorot_uniform_initializer()) 36 | 37 | return net 38 | 39 | 40 | def resnet_model_fn(features, labels, mode, params): 41 | net = tf.feature_column.input_layer(features, params['feature_columns']) 42 | 43 | last_layer = build_residual_layers(net, params) 44 | # head = tf.contrib.estimator.binary_classification_head(loss_reduction=losses.Reduction.SUM) 45 | head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access 46 | n_classes=2, weight_column=None, label_vocabulary=None, loss_reduction=losses.Reduction.SUM) 47 | logits = tf.layers.dense(last_layer, units=head.logits_dimension, 48 | kernel_initializer=tf.glorot_uniform_initializer()) 49 | optimizer = tf.train.AdagradOptimizer(learning_rate=params['learning_rate']) 50 | preds = tf.sigmoid(logits) 51 | user_id = features['user_id'] 52 | label = features['label'] 53 | 54 | if mode == tf.estimator.ModeKeys.PREDICT: 55 | predictions = { 56 | 'probabilities': preds, 57 | 'user_id': user_id, 58 | 'label': label 59 | } 60 | export_outputs = { 61 | 'regression': tf.estimator.export.RegressionOutput(predictions['probabilities']) 62 | } 63 | return tf.estimator.EstimatorSpec(mode, predictions=predictions, export_outputs=export_outputs) 64 | 65 | return head.create_estimator_spec( 66 | features=features, 67 | mode=mode, 68 | labels=labels, 69 | logits=logits, 70 | train_op_fn=lambda loss: optimizer.minimize(loss, global_step=tf.train.get_global_step()) 71 | ) 72 | 73 | -------------------------------------------------------------------------------- /ResNet/run.sh: -------------------------------------------------------------------------------- 1 | export HADOOP_HDFS_HOME=$HADOOP_HOME/../hadoop-hdfs 2 | CLASSPATH=$(${HADOOP_HOME}/bin/hadoop classpath --glob) python train.py -ne 2 -------------------------------------------------------------------------------- /ResNet/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def dice(_x, axis=-1, epsilon=0.0000001, name=''): 4 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 5 | alphas = tf.get_variable('alpha'+name, _x.get_shape()[-1], 6 | initializer=tf.constant_initializer(0.0), 7 | dtype=tf.float32) 8 | input_shape = list(_x.get_shape()) 9 | 10 | reduction_axes = list(range(len(input_shape))) 11 | del reduction_axes[axis] 12 | broadcast_shape = [1] * len(input_shape) 
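# broadcast_shape keeps only the feature axis (e.g. [1, C] for a rank-2
# input with axis=-1), so the per-feature mean/std computed below can be
# broadcast back against _x.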
13 | broadcast_shape[axis] = input_shape[axis] 14 | 15 | # case: train mode (uses stats of the current batch) 16 | mean = tf.reduce_mean(_x, axis=reduction_axes) 17 | brodcast_mean = tf.reshape(mean, broadcast_shape) 18 | std = tf.reduce_mean(tf.square(_x - brodcast_mean) + epsilon, axis=reduction_axes) 19 | std = tf.sqrt(std) 20 | brodcast_std = tf.reshape(std, broadcast_shape) 21 | x_normed = (_x - brodcast_mean) / (brodcast_std + epsilon) 22 | # x_normed = tf.layers.batch_normalization(_x, center=False, scale=False) 23 | x_p = tf.sigmoid(x_normed) 24 | 25 | return alphas * (1.0 - x_p) * _x + x_p * _x 26 | 27 | def prelu(_x, scope=''): 28 | """parametric ReLU activation""" 29 | with tf.variable_scope(name_or_scope=scope, default_name="prelu"): 30 | _alpha = tf.get_variable("prelu_"+scope, shape=_x.get_shape()[-1], 31 | dtype=_x.dtype, initializer=tf.constant_initializer(0.1)) 32 | return tf.maximum(0.0, _x) + _alpha * tf.minimum(0.0, _x) 33 | 34 | -------------------------------------------------------------------------------- /Transformer/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from sklearn.metrics import roc_auc_score 3 | import numpy as np 4 | 5 | ''' 6 | calculate group_auc and cross_entropy_loss(log loss for binary classification) 7 | 8 | @author: Qiao 9 | ''' 10 | 11 | 12 | def cal_group_auc(labels, preds, user_id_list): 13 | """Calculate group auc""" 14 | 15 | print('*' * 50) 16 | if len(user_id_list) != len(labels): 17 | raise ValueError( 18 | "impression id num should equal to the sample num," \ 19 | "impression id num is {0}".format(len(user_id_list))) 20 | group_score = defaultdict(lambda: []) 21 | group_truth = defaultdict(lambda: []) 22 | for idx, truth in enumerate(labels): 23 | user_id = user_id_list[idx] 24 | score = preds[idx] 25 | truth = labels[idx] 26 | group_score[user_id].append(score) 27 | group_truth[user_id].append(truth) 28 | 29 | group_flag = defaultdict(lambda: False) 30 | for user_id in set(user_id_list): 31 | truths = group_truth[user_id] 32 | flag = False 33 | for i in range(len(truths) - 1): 34 | if truths[i] != truths[i + 1]: 35 | flag = True 36 | break 37 | group_flag[user_id] = flag 38 | 39 | impression_total = 0 40 | total_auc = 0 41 | # 42 | for user_id in group_flag: 43 | if group_flag[user_id]: 44 | auc = roc_auc_score(np.asarray(group_truth[user_id]), np.asarray(group_score[user_id])) 45 | total_auc += auc * len(group_truth[user_id]) 46 | impression_total += len(group_truth[user_id]) 47 | group_auc = float(total_auc) / impression_total 48 | group_auc = round(group_auc, 4) 49 | return group_auc 50 | 51 | 52 | def cross_entropy_loss(labels, preds): 53 | """calculate cross_entropy_loss 54 | 55 | loss = -labels*log(preds)-(1-labels)*log(1-preds) 56 | 57 | Args: 58 | labels, preds 59 | 60 | Returns: 61 | log loss 62 | """ 63 | 64 | if len(labels) != len(preds): 65 | raise ValueError( 66 | "labels num should equal to the preds num,") 67 | 68 | z = np.array(labels) 69 | x = np.array(preds) 70 | res = -z * np.log(x) - (1 - z) * np.log(1 - x) 71 | return res.tolist() 72 | -------------------------------------------------------------------------------- /Transformer/run.sh: -------------------------------------------------------------------------------- 1 | export HADOOP_HDFS_HOME=$HADOOP_HOME/../hadoop-hdfs 2 | CLASSPATH=$(${HADOOP_HOME}/bin/hadoop classpath --glob) python train.py -ne 2 
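# Exporting HADOOP_HDFS_HOME and the hadoop classpath lets TensorFlow's file
# system layer resolve hdfs:// paths when reading the TFRecords; "-ne" is
# presumably the number-of-epochs flag defined in train.py.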
-------------------------------------------------------------------------------- /XDeepFM/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from sklearn.metrics import roc_auc_score 3 | import numpy as np 4 | 5 | ''' 6 | calculate group_auc and cross_entropy_loss(log loss for binary classification) 7 | 8 | @author: Qiao 9 | ''' 10 | 11 | 12 | def cal_group_auc(labels, preds, user_id_list): 13 | """Calculate group auc""" 14 | 15 | print('*' * 50) 16 | if len(user_id_list) != len(labels): 17 | raise ValueError( 18 | "impression id num should equal to the sample num," \ 19 | "impression id num is {0}".format(len(user_id_list))) 20 | group_score = defaultdict(lambda: []) 21 | group_truth = defaultdict(lambda: []) 22 | for idx, truth in enumerate(labels): 23 | user_id = user_id_list[idx] 24 | score = preds[idx] 25 | truth = labels[idx] 26 | group_score[user_id].append(score) 27 | group_truth[user_id].append(truth) 28 | 29 | group_flag = defaultdict(lambda: False) 30 | for user_id in set(user_id_list): 31 | truths = group_truth[user_id] 32 | flag = False 33 | for i in range(len(truths) - 1): 34 | if truths[i] != truths[i + 1]: 35 | flag = True 36 | break 37 | group_flag[user_id] = flag 38 | 39 | impression_total = 0 40 | total_auc = 0 41 | # 42 | for user_id in group_flag: 43 | if group_flag[user_id]: 44 | auc = roc_auc_score(np.asarray(group_truth[user_id]), np.asarray(group_score[user_id])) 45 | total_auc += auc * len(group_truth[user_id]) 46 | impression_total += len(group_truth[user_id]) 47 | group_auc = float(total_auc) / impression_total 48 | group_auc = round(group_auc, 4) 49 | return group_auc 50 | 51 | 52 | def cross_entropy_loss(labels, preds): 53 | """calculate cross_entropy_loss 54 | 55 | loss = -labels*log(preds)-(1-labels)*log(1-preds) 56 | 57 | Args: 58 | labels, preds 59 | 60 | Returns: 61 | log loss 62 | """ 63 | 64 | if len(labels) != len(preds): 65 | raise ValueError( 66 | "labels num should equal to the preds num,") 67 | 68 | z = np.array(labels) 69 | x = np.array(preds) 70 | res = -z * np.log(x) - (1 - z) * np.log(1 - x) 71 | return res.tolist() 72 | -------------------------------------------------------------------------------- /XDeepFM/run.sh: -------------------------------------------------------------------------------- 1 | export HADOOP_HDFS_HOME=$HADOOP_HOME/../hadoop-hdfs 2 | CLASSPATH=$(${HADOOP_HOME}/bin/hadoop classpath --glob) python train.py -ne 2 -------------------------------------------------------------------------------- /XDeepFM/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def dice(_x, axis=-1, epsilon=0.0000001, name=''): 4 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 5 | alphas = tf.get_variable('alpha'+name, _x.get_shape()[-1], 6 | initializer=tf.constant_initializer(0.0), 7 | dtype=tf.float32) 8 | input_shape = list(_x.get_shape()) 9 | 10 | reduction_axes = list(range(len(input_shape))) 11 | del reduction_axes[axis] 12 | broadcast_shape = [1] * len(input_shape) 13 | broadcast_shape[axis] = input_shape[axis] 14 | 15 | # case: train mode (uses stats of the current batch) 16 | mean = tf.reduce_mean(_x, axis=reduction_axes) 17 | brodcast_mean = tf.reshape(mean, broadcast_shape) 18 | std = tf.reduce_mean(tf.square(_x - brodcast_mean) + epsilon, axis=reduction_axes) 19 | std = tf.sqrt(std) 20 | brodcast_std = tf.reshape(std, broadcast_shape) 21 | x_normed = (_x - brodcast_mean) / 
(brodcast_std + epsilon) 22 | # x_normed = tf.layers.batch_normalization(_x, center=False, scale=False) 23 | x_p = tf.sigmoid(x_normed) 24 | 25 | return alphas * (1.0 - x_p) * _x + x_p * _x 26 | 27 | def prelu(_x, scope=''): 28 | """parametric ReLU activation""" 29 | with tf.variable_scope(name_or_scope=scope, default_name="prelu"): 30 | _alpha = tf.get_variable("prelu_"+scope, shape=_x.get_shape()[-1], 31 | dtype=_x.dtype, initializer=tf.constant_initializer(0.1)) 32 | return tf.maximum(0.0, _x) + _alpha * tf.minimum(0.0, _x) 33 | 34 | -------------------------------------------------------------------------------- /official/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/.DS_Store -------------------------------------------------------------------------------- /official/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/__init__.py -------------------------------------------------------------------------------- /official/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/datasets/__init__.py -------------------------------------------------------------------------------- /official/utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/.DS_Store -------------------------------------------------------------------------------- /official/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/__init__.py -------------------------------------------------------------------------------- /official/utils/accelerator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/accelerator/__init__.py -------------------------------------------------------------------------------- /official/utils/accelerator/tpu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Functions specific to running TensorFlow on TPUs.""" 16 | 17 | import tensorflow as tf 18 | 19 | 20 | # "local" is a magic word in the TPU cluster resolver; it informs the resolver 21 | # to use the local CPU as the compute device. This is useful for testing and 22 | # debugging; the code flow is ostensibly identical, but without the need to 23 | # actually have a TPU on the other end. 24 | LOCAL = "local" 25 | 26 | 27 | def construct_scalar_host_call(metric_dict, model_dir, prefix=""): 28 | """Construct a host call to log scalars when training on TPU. 29 | 30 | Args: 31 | metric_dict: A dict of the tensors to be logged. 32 | model_dir: The location to write the summary. 33 | prefix: The prefix (if any) to prepend to the metric names. 34 | 35 | Returns: 36 | A tuple of (function, args_to_be_passed_to_said_function) 37 | """ 38 | # type: (dict, str) -> (function, list) 39 | metric_names = list(metric_dict.keys()) 40 | 41 | def host_call_fn(global_step, *args): 42 | """Training host call. Creates scalar summaries for training metrics. 43 | 44 | This function is executed on the CPU and should not directly reference 45 | any Tensors in the rest of the `model_fn`. To pass Tensors from the 46 | model to the `metric_fn`, provide as part of the `host_call`. See 47 | https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec 48 | for more information. 49 | 50 | Arguments should match the list of `Tensor` objects passed as the second 51 | element in the tuple passed to `host_call`. 52 | 53 | Args: 54 | global_step: `Tensor with shape `[batch]` for the global_step 55 | *args: Remaining tensors to log. 56 | 57 | Returns: 58 | List of summary ops to run on the CPU host. 59 | """ 60 | step = global_step[0] 61 | with tf.contrib.summary.create_file_writer( 62 | logdir=model_dir, filename_suffix=".host_call").as_default(): 63 | with tf.contrib.summary.always_record_summaries(): 64 | for i, name in enumerate(metric_names): 65 | tf.contrib.summary.scalar(prefix + name, args[i][0], step=step) 66 | 67 | return tf.contrib.summary.all_summary_ops() 68 | 69 | # To log the current learning rate, and gradient norm for Tensorboard, the 70 | # summary op needs to be run on the host CPU via host_call. host_call 71 | # expects [batch_size, ...] Tensors, thus reshape to introduce a batch 72 | # dimension. These Tensors are implicitly concatenated to 73 | # [params['batch_size']]. 74 | global_step_tensor = tf.reshape( 75 | tf.compat.v1.train.get_or_create_global_step(), [1]) 76 | other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names] 77 | 78 | return host_call_fn, [global_step_tensor] + other_tensors 79 | 80 | 81 | def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"): 82 | """Performs embedding lookup via a matmul. 83 | 84 | The matrix to be multiplied by the embedding table Tensor is constructed 85 | via an implementation of scatter based on broadcasting embedding indices 86 | and performing an equality comparison against a broadcasted 87 | range(num_embedding_table_rows). All masked positions will produce an 88 | embedding vector of zeros. 89 | 90 | Args: 91 | embedding_table: Tensor of embedding table. 92 | Rank 2 (table_size x embedding dim) 93 | values: Tensor of embedding indices. Rank 2 (batch x n_indices) 94 | mask: Tensor of mask / weights. 
Rank 2 (batch x n_indices) 95 | name: Optional name scope for created ops 96 | 97 | Returns: 98 | Rank 3 tensor of embedding vectors. 99 | """ 100 | 101 | with tf.name_scope(name): 102 | n_embeddings = embedding_table.get_shape().as_list()[0] 103 | batch_size, padded_size = values.shape.as_list() 104 | 105 | emb_idcs = tf.tile( 106 | tf.reshape(values, (batch_size, padded_size, 1)), (1, 1, n_embeddings)) 107 | emb_weights = tf.tile( 108 | tf.reshape(mask, (batch_size, padded_size, 1)), (1, 1, n_embeddings)) 109 | col_idcs = tf.tile( 110 | tf.reshape(tf.range(n_embeddings), (1, 1, n_embeddings)), 111 | (batch_size, padded_size, 1)) 112 | one_hot = tf.where( 113 | tf.equal(emb_idcs, col_idcs), emb_weights, 114 | tf.zeros((batch_size, padded_size, n_embeddings))) 115 | 116 | return tf.tensordot(one_hot, embedding_table, 1) 117 | -------------------------------------------------------------------------------- /official/utils/accelerator/tpu_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Test TPU optimized matmul embedding.""" 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | from official.utils.accelerator import tpu as tpu_utils 21 | 22 | 23 | TEST_CASES = [ 24 | dict(embedding_dim=256, vocab_size=1000, sequence_length=64, 25 | batch_size=32, seed=54131), 26 | dict(embedding_dim=8, vocab_size=15, sequence_length=12, 27 | batch_size=256, seed=536413), 28 | dict(embedding_dim=2048, vocab_size=512, sequence_length=50, 29 | batch_size=8, seed=35124) 30 | ] 31 | 32 | 33 | class TPUBaseTester(tf.test.TestCase): 34 | def construct_embedding_and_values(self, embedding_dim, vocab_size, 35 | sequence_length, batch_size, seed): 36 | np.random.seed(seed) 37 | 38 | embeddings = np.random.random(size=(vocab_size, embedding_dim)) 39 | embedding_table = tf.convert_to_tensor(value=embeddings, dtype=tf.float32) 40 | 41 | tokens = np.random.randint(low=1, high=vocab_size-1, 42 | size=(batch_size, sequence_length)) 43 | for i in range(batch_size): 44 | tokens[i, np.random.randint(low=0, high=sequence_length-1):] = 0 45 | values = tf.convert_to_tensor(value=tokens, dtype=tf.int32) 46 | mask = tf.cast(tf.not_equal(values, 0), dtype=tf.float32) 47 | return embedding_table, values, mask 48 | 49 | def _test_embedding(self, embedding_dim, vocab_size, 50 | sequence_length, batch_size, seed): 51 | """Test that matmul embedding matches embedding lookup (gather).""" 52 | 53 | with self.test_session(): 54 | embedding_table, values, mask = self.construct_embedding_and_values( 55 | embedding_dim=embedding_dim, 56 | vocab_size=vocab_size, 57 | sequence_length=sequence_length, 58 | batch_size=batch_size, 59 | seed=seed 60 | ) 61 | 62 | embedding = (tf.nn.embedding_lookup(params=embedding_table, ids=values) * 63 | 
tf.expand_dims(mask, -1)) 64 | 65 | matmul_embedding = tpu_utils.embedding_matmul( 66 | embedding_table=embedding_table, values=values, mask=mask) 67 | 68 | self.assertAllClose(embedding, matmul_embedding) 69 | 70 | def _test_masking(self, embedding_dim, vocab_size, 71 | sequence_length, batch_size, seed): 72 | """Test that matmul embedding properly zeros masked positions.""" 73 | with self.test_session(): 74 | embedding_table, values, mask = self.construct_embedding_and_values( 75 | embedding_dim=embedding_dim, 76 | vocab_size=vocab_size, 77 | sequence_length=sequence_length, 78 | batch_size=batch_size, 79 | seed=seed 80 | ) 81 | 82 | matmul_embedding = tpu_utils.embedding_matmul( 83 | embedding_table=embedding_table, values=values, mask=mask) 84 | 85 | self.assertAllClose(matmul_embedding, 86 | matmul_embedding * tf.expand_dims(mask, -1)) 87 | 88 | def test_embedding_0(self): 89 | self._test_embedding(**TEST_CASES[0]) 90 | 91 | def test_embedding_1(self): 92 | self._test_embedding(**TEST_CASES[1]) 93 | 94 | def test_embedding_2(self): 95 | self._test_embedding(**TEST_CASES[2]) 96 | 97 | def test_masking_0(self): 98 | self._test_masking(**TEST_CASES[0]) 99 | 100 | def test_masking_1(self): 101 | self._test_masking(**TEST_CASES[1]) 102 | 103 | def test_masking_2(self): 104 | self._test_masking(**TEST_CASES[2]) 105 | 106 | 107 | if __name__ == "__main__": 108 | tf.test.main() 109 | -------------------------------------------------------------------------------- /official/utils/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/data/__init__.py -------------------------------------------------------------------------------- /official/utils/export/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/export/__init__.py -------------------------------------------------------------------------------- /official/utils/export/export.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Convenience functions for exporting models as SavedModels or other types.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | 24 | def build_tensor_serving_input_receiver_fn(shape, dtype=tf.float32, 25 | batch_size=1): 26 | """Returns a input_receiver_fn that can be used during serving. 
27 | 28 | This expects examples to come through as float tensors, and simply 29 | wraps them as TensorServingInputReceivers. 30 | 31 | Arguably, this should live in tf.estimator.export. Testing here first. 32 | 33 | Args: 34 | shape: list representing target size of a single example. 35 | dtype: the expected datatype for the input example 36 | batch_size: number of input tensors that will be passed for prediction 37 | 38 | Returns: 39 | A function that itself returns a TensorServingInputReceiver. 40 | """ 41 | def serving_input_receiver_fn(): 42 | # Prep a placeholder where the input example will be fed in 43 | features = tf.compat.v1.placeholder( 44 | dtype=dtype, shape=[batch_size] + shape, name='input_tensor') 45 | 46 | return tf.estimator.export.TensorServingInputReceiver( 47 | features=features, receiver_tensors=features) 48 | 49 | return serving_input_receiver_fn 50 | -------------------------------------------------------------------------------- /official/utils/export/export_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for exporting utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf # pylint: disable=g-bad-import-order 22 | 23 | from official.utils.export import export 24 | 25 | 26 | class ExportUtilsTest(tf.test.TestCase): 27 | """Tests for the ExportUtils.""" 28 | 29 | def test_build_tensor_serving_input_receiver_fn(self): 30 | receiver_fn = export.build_tensor_serving_input_receiver_fn(shape=[4, 5]) 31 | with tf.Graph().as_default(): 32 | receiver = receiver_fn() 33 | self.assertIsInstance( 34 | receiver, tf.estimator.export.TensorServingInputReceiver) 35 | 36 | self.assertIsInstance(receiver.features, tf.Tensor) 37 | self.assertEqual(receiver.features.shape, tf.TensorShape([1, 4, 5])) 38 | self.assertEqual(receiver.features.dtype, tf.float32) 39 | self.assertIsInstance(receiver.receiver_tensors, dict) 40 | # Note that Python 3 can no longer index .values() directly; cast to list. 
41 | self.assertEqual(list(receiver.receiver_tensors.values())[0].shape, 42 | tf.TensorShape([1, 4, 5])) 43 | 44 | def test_build_tensor_serving_input_receiver_fn_batch_dtype(self): 45 | receiver_fn = export.build_tensor_serving_input_receiver_fn( 46 | shape=[4, 5], dtype=tf.int8, batch_size=10) 47 | 48 | with tf.Graph().as_default(): 49 | receiver = receiver_fn() 50 | self.assertIsInstance( 51 | receiver, tf.estimator.export.TensorServingInputReceiver) 52 | 53 | self.assertIsInstance(receiver.features, tf.Tensor) 54 | self.assertEqual(receiver.features.shape, tf.TensorShape([10, 4, 5])) 55 | self.assertEqual(receiver.features.dtype, tf.int8) 56 | self.assertIsInstance(receiver.receiver_tensors, dict) 57 | # Note that Python 3 can no longer index .values() directly; cast to list. 58 | self.assertEqual(list(receiver.receiver_tensors.values())[0].shape, 59 | tf.TensorShape([10, 4, 5])) 60 | 61 | 62 | if __name__ == "__main__": 63 | tf.test.main() 64 | -------------------------------------------------------------------------------- /official/utils/flags/README.md: -------------------------------------------------------------------------------- 1 | # Adding Abseil (absl) flags quickstart 2 | ## Defining a flag 3 | absl flag definitions are similar to argparse, although they are defined in a global namespace. 4 | 5 | For instance, defining a string flag looks like: 6 | ```$xslt 7 | from absl import flags 8 | flags.DEFINE_string( 9 | name="my_flag", 10 | default="a_sensible_default", 11 | help="Here is what this flag does." 12 | ) 13 | ``` 14 | 15 | All three arguments are required, but default may be `None`. A common optional argument is 16 | short_name for defining abbreviations. Certain `DEFINE_*` methods will have other required arguments. 17 | For instance `DEFINE_enum` requires the `enum_values` argument to be specified. 18 | 19 | ## Key Flags 20 | absl has the concept of a key flag. Any flag defined in `__main__` is considered a key flag by 21 | default. Key flags are displayed in `--help`, others only appear in `--helpfull`. In order to 22 | handle key flags that are defined outside the module in question, absl provides the 23 | `flags.adopt_module_key_flags()` method. This adds the key flags of a different module to one's own 24 | key flags. For example: 25 | ```$xslt 26 | File: flag_source.py 27 | --------------------------------------- 28 | 29 | from absl import flags 30 | flags.DEFINE_string(name="my_flag", default="abc", help="a flag.") 31 | ``` 32 | 33 | ```$xslt 34 | File: my_module.py 35 | --------------------------------------- 36 | 37 | from absl import app as absl_app 38 | from absl import flags 39 | 40 | import flag_source 41 | 42 | flags.adopt_module_key_flags(flag_source) 43 | 44 | def main(_): 45 | pass 46 | 47 | absl_app.run(main, [__file__, "-h"]) 48 | ``` 49 | 50 | When `my_module.py` is run, it will show the help text for `my_flag`. Because not all flags defined 51 | in a file are equally important, `official/utils/flags/core.py` (generally imported as flags_core) 52 | provides an abstraction for handling key flag declaration in an easy way through the 53 | `register_key_flags_in_core()` function, which allows a module to make a single 54 | `flags.adopt_module_key_flags(flags_core)` call when using the util flag declaration functions. 55 | 56 | ## Validators 57 | Often the constraints on a flag are complicated. absl provides the validator decorator to allow 58 | one to mark a function as a flag validation function. Suppose we want users to provide a flag 59 | which is a palindrome.
60 | 61 | ```$xslt 62 | from absl import flags 63 | 64 | flags.DEFINE_string(name="pal_flag", short_name="pf", default="", help="Give me a palindrome") 65 | 66 | @flags.validator("pal_flag") 67 | def _check_pal(provided_pal_flag): 68 | return provided_pal_flag == provided_pal_flag[::-1] 69 | 70 | ``` 71 | 72 | Validators pass when they return True (truthy); anything else 73 | (False, None, or an exception) fails. 74 | 75 | ## Testing 76 | To test with absl, declare flags in your TestCase's setUpClass method. 77 | 78 | ```$xslt 79 | import unittest 80 | 81 | from absl import flags 82 | import tensorflow as tf 83 | 84 | from official.utils.flags import core as flags_core 85 | 86 | def define_flags(): 87 | flags.DEFINE_string(name="test_flag", default="abc", help="an example flag") 88 | 89 | 90 | class BaseTester(unittest.TestCase): 91 | 92 | @classmethod 93 | def setUpClass(cls): 94 | super(BaseTester, cls).setUpClass() 95 | define_flags() 96 | 97 | def test_trivial(self): 98 | flags_core.parse_flags([__file__, "--test_flag", "def"]) 99 | self.assertEqual(flags.FLAGS.test_flag, "def") 100 | 101 | ``` 102 | -------------------------------------------------------------------------------- /official/utils/flags/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/flags/__init__.py -------------------------------------------------------------------------------- /official/utils/flags/_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Flags for benchmarking models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import flags 22 | 23 | from official.utils.flags._conventions import help_wrap 24 | 25 | 26 | def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True): 27 | """Register benchmarking flags. 28 | 29 | Args: 30 | benchmark_log_dir: Create a flag to specify location for benchmark logging. 31 | bigquery_uploader: Create flags for uploading results to BigQuery. 32 | 33 | Returns: 34 | A list of flags for core.py to mark as key flags. 35 | """ 36 | 37 | key_flags = [] 38 | 39 | flags.DEFINE_enum( 40 | name="benchmark_logger_type", default="BaseBenchmarkLogger", 41 | enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger", 42 | "BenchmarkBigQueryLogger"], 43 | help=help_wrap("The type of benchmark logger to use. Defaults to using " 44 | "BaseBenchmarkLogger which logs to STDOUT. Different " 45 | "loggers will require other flags to be able to work.")) 46 | flags.DEFINE_string( 47 | name="benchmark_test_id", short_name="bti", default=None, 48 | help=help_wrap("The unique test ID of the benchmark run. 
It could be the " 49 | "combination of key parameters. It is hardware " 50 | "independent and could be used compare the performance " 51 | "between different test runs. This flag is designed for " 52 | "human consumption, and does not have any impact within " 53 | "the system.")) 54 | 55 | if benchmark_log_dir: 56 | flags.DEFINE_string( 57 | name="benchmark_log_dir", short_name="bld", default=None, 58 | help=help_wrap("The location of the benchmark logging.") 59 | ) 60 | 61 | if bigquery_uploader: 62 | flags.DEFINE_string( 63 | name="gcp_project", short_name="gp", default=None, 64 | help=help_wrap( 65 | "The GCP project name where the benchmark will be uploaded.")) 66 | 67 | flags.DEFINE_string( 68 | name="bigquery_data_set", short_name="bds", default="test_benchmark", 69 | help=help_wrap( 70 | "The Bigquery dataset name where the benchmark will be uploaded.")) 71 | 72 | flags.DEFINE_string( 73 | name="bigquery_run_table", short_name="brt", default="benchmark_run", 74 | help=help_wrap("The Bigquery table name where the benchmark run " 75 | "information will be uploaded.")) 76 | 77 | flags.DEFINE_string( 78 | name="bigquery_run_status_table", short_name="brst", 79 | default="benchmark_run_status", 80 | help=help_wrap("The Bigquery table name where the benchmark run " 81 | "status information will be uploaded.")) 82 | 83 | flags.DEFINE_string( 84 | name="bigquery_metric_table", short_name="bmt", 85 | default="benchmark_metric", 86 | help=help_wrap("The Bigquery table name where the benchmark metric " 87 | "information will be uploaded.")) 88 | 89 | @flags.multi_flags_validator( 90 | ["benchmark_logger_type", "benchmark_log_dir"], 91 | message="--benchmark_logger_type=BenchmarkFileLogger will require " 92 | "--benchmark_log_dir being set") 93 | def _check_benchmark_log_dir(flags_dict): 94 | benchmark_logger_type = flags_dict["benchmark_logger_type"] 95 | if benchmark_logger_type == "BenchmarkFileLogger": 96 | return flags_dict["benchmark_log_dir"] 97 | return True 98 | 99 | return key_flags 100 | -------------------------------------------------------------------------------- /official/utils/flags/_conventions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Central location for shared arparse convention definitions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import codecs 22 | import functools 23 | 24 | from absl import app as absl_app 25 | from absl import flags 26 | 27 | 28 | # This codifies help string conventions and makes it easy to update them if 29 | # necessary. Currently the only major effect is that help bodies start on the 30 | # line after flags are listed. 
All flag definitions should wrap the text bodies 31 | # with help wrap when calling DEFINE_*. 32 | _help_wrap = functools.partial(flags.text_wrap, length=80, indent="", 33 | firstline_indent="\n") 34 | 35 | 36 | # Pretty formatting causes issues when utf-8 is not installed on a system. 37 | try: 38 | codecs.lookup("utf-8") 39 | help_wrap = _help_wrap 40 | except LookupError: 41 | def help_wrap(text, *args, **kwargs): 42 | return _help_wrap(text, *args, **kwargs).replace("\ufeff", "") 43 | 44 | 45 | # Replace None with h to also allow -h 46 | absl_app.HelpshortFlag.SHORT_NAME = "h" 47 | -------------------------------------------------------------------------------- /official/utils/flags/_device.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Flags for managing compute devices. Currently only contains TPU flags.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import flags 22 | import tensorflow as tf 23 | 24 | from official.utils.flags._conventions import help_wrap 25 | 26 | 27 | def require_cloud_storage(flag_names): 28 | """Register a validator to check directory flags. 29 | Args: 30 | flag_names: An iterable of strings containing the names of flags to be 31 | checked. 32 | """ 33 | msg = "TPU requires GCS path for {}".format(", ".join(flag_names)) 34 | @flags.multi_flags_validator(["tpu"] + flag_names, message=msg) 35 | def _path_check(flag_values): # pylint: disable=missing-docstring 36 | if flag_values["tpu"] is None: 37 | return True 38 | 39 | valid_flags = True 40 | for key in flag_names: 41 | if not flag_values[key].startswith("gs://"): 42 | tf.compat.v1.logging.error("{} must be a GCS path.".format(key)) 43 | valid_flags = False 44 | 45 | return valid_flags 46 | 47 | 48 | def define_device(tpu=True): 49 | """Register device specific flags. 50 | Args: 51 | tpu: Create flags to specify TPU operation. 52 | Returns: 53 | A list of flags for core.py to marks as key flags. 54 | """ 55 | 56 | key_flags = [] 57 | 58 | if tpu: 59 | flags.DEFINE_string( 60 | name="tpu", default=None, 61 | help=help_wrap( 62 | "The Cloud TPU to use for training. This should be either the name " 63 | "used when creating the Cloud TPU, or a " 64 | "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the" 65 | "CPU of the local instance instead. (Good for debugging.)")) 66 | key_flags.append("tpu") 67 | 68 | flags.DEFINE_string( 69 | name="tpu_zone", default=None, 70 | help=help_wrap( 71 | "[Optional] GCE zone where the Cloud TPU is located in. 
If not " 72 | "specified, we will attempt to automatically detect the GCE " 73 | "project from metadata.")) 74 | 75 | flags.DEFINE_string( 76 | name="tpu_gcp_project", default=None, 77 | help=help_wrap( 78 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 79 | "specified, we will attempt to automatically detect the GCE " 80 | "project from metadata.")) 81 | 82 | flags.DEFINE_integer(name="num_tpu_shards", default=8, 83 | help=help_wrap("Number of shards (TPU chips).")) 84 | 85 | return key_flags 86 | -------------------------------------------------------------------------------- /official/utils/flags/_misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Misc flags.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import flags 22 | 23 | from official.utils.flags._conventions import help_wrap 24 | 25 | 26 | def define_image(data_format=True): 27 | """Register image specific flags. 28 | 29 | Args: 30 | data_format: Create a flag to specify image axis convention. 31 | 32 | Returns: 33 | A list of flags for core.py to marks as key flags. 34 | """ 35 | 36 | key_flags = [] 37 | 38 | if data_format: 39 | flags.DEFINE_enum( 40 | name="data_format", short_name="df", default=None, 41 | enum_values=["channels_first", "channels_last"], 42 | help=help_wrap( 43 | "A flag to override the data format used in the model. " 44 | "channels_first provides a performance boost on GPU but is not " 45 | "always compatible with CPU. If left unspecified, the data format " 46 | "will be chosen automatically based on whether TensorFlow was " 47 | "built for CPU or GPU.")) 48 | key_flags.append("data_format") 49 | 50 | return key_flags 51 | -------------------------------------------------------------------------------- /official/utils/flags/core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Public interface for flag definition. 
16 | 17 | See _example.py for detailed instructions on defining flags. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import functools 25 | import sys 26 | 27 | from absl import app as absl_app 28 | from absl import flags 29 | 30 | from official.utils.flags import _base 31 | from official.utils.flags import _benchmark 32 | from official.utils.flags import _conventions 33 | from official.utils.flags import _device 34 | from official.utils.flags import _misc 35 | from official.utils.flags import _performance 36 | 37 | 38 | def set_defaults(**kwargs): 39 | for key, value in kwargs.items(): 40 | flags.FLAGS.set_default(name=key, value=value) 41 | 42 | 43 | def parse_flags(argv=None): 44 | """Reset flags and reparse. Currently only used in testing.""" 45 | flags.FLAGS.unparse_flags() 46 | absl_app.parse_flags_with_usage(argv or sys.argv) 47 | 48 | 49 | def register_key_flags_in_core(f): 50 | """Defines a function in core.py, and registers its key flags. 51 | 52 | absl uses the location of a flags.declare_key_flag() to determine the context 53 | in which a flag is key. By making all declares in core, this allows model 54 | main functions to call flags.adopt_module_key_flags() on core and correctly 55 | chain key flags. 56 | 57 | Args: 58 | f: The function to be wrapped 59 | 60 | Returns: 61 | The "core-defined" version of the input function. 62 | """ 63 | 64 | def core_fn(*args, **kwargs): 65 | key_flags = f(*args, **kwargs) 66 | [flags.declare_key_flag(fl) for fl in key_flags] # pylint: disable=expression-not-assigned 67 | return core_fn 68 | 69 | 70 | define_base = register_key_flags_in_core(_base.define_base) 71 | # Remove options not relevant for Eager from define_base(). 72 | define_base_eager = register_key_flags_in_core(functools.partial( 73 | _base.define_base, epochs_between_evals=False, stop_threshold=False, 74 | hooks=False)) 75 | define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark) 76 | define_device = register_key_flags_in_core(_device.define_device) 77 | define_image = register_key_flags_in_core(_misc.define_image) 78 | define_performance = register_key_flags_in_core(_performance.define_performance) 79 | 80 | 81 | help_wrap = _conventions.help_wrap 82 | 83 | 84 | get_num_gpus = _base.get_num_gpus 85 | get_tf_dtype = _performance.get_tf_dtype 86 | get_loss_scale = _performance.get_loss_scale 87 | DTYPE_MAP = _performance.DTYPE_MAP 88 | require_cloud_storage = _device.require_cloud_storage 89 | -------------------------------------------------------------------------------- /official/utils/flags/flags_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | import unittest 17 | 18 | from absl import flags 19 | import tensorflow as tf 20 | 21 | from official.utils.flags import core as flags_core  # pylint: disable=g-bad-import-order 22 | 23 | 24 | def define_flags(): 25 | flags_core.define_base(num_gpu=False) 26 | flags_core.define_performance() 27 | flags_core.define_image() 28 | flags_core.define_benchmark() 29 | 30 | 31 | class BaseTester(unittest.TestCase): 32 | 33 | @classmethod 34 | def setUpClass(cls): 35 | super(BaseTester, cls).setUpClass() 36 | define_flags() 37 | 38 | def test_default_setting(self): 39 | """Test to ensure fields exist and defaults can be set. 40 | """ 41 | 42 | defaults = dict( 43 | data_dir="dfgasf", 44 | model_dir="dfsdkjgbs", 45 | train_epochs=534, 46 | epochs_between_evals=15, 47 | batch_size=256, 48 | hooks=["LoggingTensorHook"], 49 | num_parallel_calls=18, 50 | inter_op_parallelism_threads=5, 51 | intra_op_parallelism_threads=10, 52 | data_format="channels_first" 53 | ) 54 | 55 | flags_core.set_defaults(**defaults) 56 | flags_core.parse_flags() 57 | 58 | for key, value in defaults.items(): 59 | assert flags.FLAGS.get_flag_value(name=key, default=None) == value 60 | 61 | def test_benchmark_setting(self): 62 | defaults = dict( 63 | hooks=["LoggingMetricHook"], 64 | benchmark_log_dir="/tmp/12345", 65 | gcp_project="project_abc", 66 | ) 67 | 68 | flags_core.set_defaults(**defaults) 69 | flags_core.parse_flags() 70 | 71 | for key, value in defaults.items(): 72 | assert flags.FLAGS.get_flag_value(name=key, default=None) == value 73 | 74 | def test_booleans(self): 75 | """Test to ensure boolean flags trigger as expected. 76 | """ 77 | 78 | flags_core.parse_flags([__file__, "--use_synthetic_data"]) 79 | 80 | assert flags.FLAGS.use_synthetic_data 81 | 82 | def test_parse_dtype_info(self): 83 | for dtype_str, tf_dtype, loss_scale in [["fp16", tf.float16, 128], 84 | ["fp32", tf.float32, 1]]: 85 | flags_core.parse_flags([__file__, "--dtype", dtype_str]) 86 | 87 | self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf_dtype) 88 | self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), loss_scale) 89 | 90 | flags_core.parse_flags( 91 | [__file__, "--dtype", dtype_str, "--loss_scale", "5"]) 92 | 93 | self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5) 94 | 95 | with self.assertRaises(SystemExit): 96 | flags_core.parse_flags([__file__, "--dtype", "int8"]) 97 | 98 | 99 | if __name__ == "__main__": 100 | unittest.main() 101 | -------------------------------------------------------------------------------- /official/utils/flags/guidelines.md: -------------------------------------------------------------------------------- 1 | # Using flags in official models 2 | 3 | 1. **All common flags must be incorporated in the models.** 4 | 5 | Common flags (e.g. batch_size, model_dir) are provided by various flag definition functions, 6 | and channeled through `official.utils.flags.core`. For instance, to define common supervised 7 | learning parameters one could use the following code: 8 | 9 | ``` 10 | from absl import app as absl_app 11 | from absl import flags 12 | 13 | from official.utils.flags import core as flags_core 14 | 15 | 16 | def define_flags(): 17 | flags_core.define_base() 18 | flags.adopt_module_key_flags(flags_core) 19 | 20 | 21 | def main(_): 22 | flags_obj = flags.FLAGS 23 | print(flags_obj) 24 | 25 | 26 | if __name__ == "__main__": 27 | absl_app.run(main) 28 | ``` 29 |
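   The same module also exposes `set_defaults` and `parse_flags` (see core.py above), which are convenient in tests; a minimal sketch, assuming `define_flags()` from the previous example has already run:

```
from absl import flags
from official.utils.flags import core as flags_core

define_flags()  # from the example above
flags_core.set_defaults(batch_size=256, model_dir="/tmp/model")
flags_core.parse_flags()  # resets and reparses; mainly used in testing
assert flags.FLAGS.batch_size == 256
```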
2. **Validate flag values.** 30 | 31 | See the [Validators](#validators) section for implementation details. 32 | 33 | Validators in the official model repo should not access the file system, such as verifying 34 | that files exist, due to the strict ordering requirements. 35 | 36 | 3. **Flag values should not be mutated.** 37 | 38 | Instead of mutating flag values, use getter functions to return the desired values. An example 39 | getter function is the `get_loss_scale` function below: 40 | 41 | ``` 42 | # Map string to (TensorFlow dtype, default loss scale) 43 | DTYPE_MAP = { 44 | "fp16": (tf.float16, 128), 45 | "fp32": (tf.float32, 1), 46 | } 47 | 48 | 49 | def get_loss_scale(flags_obj): 50 | if flags_obj.loss_scale is not None: 51 | return flags_obj.loss_scale 52 | return DTYPE_MAP[flags_obj.dtype][1] 53 | 54 | 55 | def main(_): 56 | flags_obj = flags.FLAGS 57 | 58 | # Do not mutate flags_obj 59 | # if flags_obj.loss_scale is None: 60 | # flags_obj.loss_scale = DTYPE_MAP[flags_obj.dtype][1]  # Don't do this 61 | 62 | print(get_loss_scale(flags_obj)) 63 | ... 64 | ``` --------------------------------------------------------------------------------
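A short worked example of the getter pattern, matching the behavior pinned down in flags_test.py above — parsing `--dtype fp16` with no explicit `--loss_scale` falls back to the fp16 default of 128:

```
from absl import flags
from official.utils.flags import core as flags_core

flags_core.define_performance()  # registers --dtype and --loss_scale
flags_core.parse_flags(["prog", "--dtype", "fp16"])

# fp16 maps to (tf.float16, 128) in DTYPE_MAP, so with no explicit
# --loss_scale the default loss scale is returned.
assert flags_core.get_loss_scale(flags.FLAGS) == 128
```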
/official/utils/logs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/logs/__init__.py -------------------------------------------------------------------------------- /official/utils/logs/cloud_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities that interact with cloud services. 17 | """ 18 | 19 | import requests 20 | 21 | GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname" 22 | GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"} 23 | 24 | 25 | def on_gcp(): 26 | """Detect whether the current running environment is on GCP.""" 27 | try: 28 | # Timeout in 5 seconds, in case the test environment has connectivity issue. 29 | # There is no default timeout, which means it might block forever. 30 | response = requests.get( 31 | GCP_METADATA_URL, headers=GCP_METADATA_HEADER, timeout=5) 32 | return response.status_code == 200 33 | except requests.exceptions.RequestException: 34 | return False 35 | -------------------------------------------------------------------------------- /official/utils/logs/cloud_lib_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for cloud_lib.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import unittest 23 | 24 | import mock 25 | import requests 26 | 27 | from official.utils.logs import cloud_lib 28 | 29 | 30 | class CloudLibTest(unittest.TestCase): 31 | 32 | @mock.patch("requests.get") 33 | def test_on_gcp(self, mock_requests_get): 34 | mock_response = mock.MagicMock() 35 | mock_requests_get.return_value = mock_response 36 | mock_response.status_code = 200 37 | 38 | self.assertEqual(cloud_lib.on_gcp(), True) 39 | 40 | @mock.patch("requests.get") 41 | def test_not_on_gcp(self, mock_requests_get): 42 | mock_requests_get.side_effect = requests.exceptions.ConnectionError() 43 | 44 | self.assertEqual(cloud_lib.on_gcp(), False) 45 | 46 | 47 | if __name__ == "__main__": 48 | unittest.main() 49 | -------------------------------------------------------------------------------- /official/utils/logs/guidelines.md: -------------------------------------------------------------------------------- 1 | # Logging in official models 2 | 3 | This library adds logging functions that print or save tensor values. Official models should define all common hooks 4 | (using hooks helper) and a benchmark logger. 5 | 6 | 1. **Training Hooks** 7 | 8 | Hooks are a TensorFlow concept that define specific actions at certain points of execution. We use them to obtain and log 9 | tensor values during training. 10 | 11 | hooks_helper.py provides an easy way to create common hooks. The following hooks are currently defined: 12 | * LoggingTensorHook: Logs tensor values 13 | * ProfilerHook: Writes a timeline json that can be loaded into chrome://tracing. 14 | * ExamplesPerSecondHook: Logs the number of examples processed per second. 15 | * LoggingMetricHook: Similar to LoggingTensorHook, except that the tensors are logged in a format defined by our data 16 | analysis pipeline. 17 | 18 | 19 | 2. **Benchmarks** 20 | 21 | The benchmark logger provides useful functions for logging environment information and evaluation results. 22 | The module also contains a context which is used to update the status of the run. 23 | 24 | Example usage: 25 | 26 | ``` 27 | from absl import app as absl_app 28 | 29 | from official.utils.logs import hooks_helper 30 | from official.utils.logs import logger 31 | 32 | def model_main(flags_obj): 33 | estimator = ... 34 | 35 | benchmark_logger = logger.get_benchmark_logger() 36 | benchmark_logger.log_run_info(...) 37 | 38 | train_hooks = hooks_helper.get_train_hooks(...) 39 | 40 | for epoch in range(10): 41 | estimator.train(..., hooks=train_hooks) 42 | eval_results = estimator.evaluate(...) 43 | 44 | # Log a dictionary of metrics 45 | benchmark_logger.log_evaluation_result(eval_results) 46 | 47 | # Log an individual metric 48 | benchmark_logger.log_metric(...)
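      # For instance, with hypothetical values (the log_metric signature,
      # name, value, unit=None, global_step=None, extras=None, matches
      # mock_lib.MockBenchmarkLogger further below):
      # benchmark_logger.log_metric("accuracy", 0.76, global_step=1000)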
49 | 50 | 51 | def main(_): 52 | with logger.benchmark_context(flags.FLAGS): 53 | model_main(flags.FLAGS) 54 | 55 | if __name__ == "__main__": 56 | # define flags 57 | absl_app.run(main) 58 | ``` 59 | -------------------------------------------------------------------------------- /official/utils/logs/hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Hook that counts examples per second every N steps or seconds.""" 17 | 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf  # pylint: disable=g-bad-import-order 24 | 25 | from official.utils.logs import logger 26 | 27 | 28 | class ExamplesPerSecondHook(tf.estimator.SessionRunHook): 29 | """Hook to print out examples per second. 30 | 31 | Total time is tracked and then divided by the total number of steps 32 | to get the average step time and then batch_size is used to determine 33 | the running average of examples per second. The examples per second for the 34 | most recent interval is also logged. 35 | """ 36 | 37 | def __init__(self, 38 | batch_size, 39 | every_n_steps=None, 40 | every_n_secs=None, 41 | warm_steps=0, 42 | metric_logger=None): 43 | """Initializer for ExamplesPerSecondHook. 44 | 45 | Args: 46 | batch_size: Total batch size across all workers used to calculate 47 | examples/second from global time. 48 | every_n_steps: Log stats every n steps. 49 | every_n_secs: Log stats every n seconds. Exactly one of 50 | `every_n_steps` or `every_n_secs` should be set. 51 | warm_steps: The number of steps to be skipped before logging and running 52 | average calculation. warm_steps refers to global steps across all 53 | workers, not steps on each worker. 54 | metric_logger: instance of `BenchmarkLogger`, the benchmark logger that 55 | hook should use to write the log. If None, BaseBenchmarkLogger will 56 | be used. 57 | 58 | Raises: 59 | ValueError: if neither `every_n_steps` nor `every_n_secs` is set, or 60 | both are set. 61 | """ 62 | 63 | if (every_n_steps is None) == (every_n_secs is None): 64 | raise ValueError("exactly one of every_n_steps" 65 | " and every_n_secs should be provided.") 66 | 67 | self._logger = metric_logger or logger.BaseBenchmarkLogger() 68 | 69 | self._timer = tf.estimator.SecondOrStepTimer( 70 | every_steps=every_n_steps, every_secs=every_n_secs) 71 | 72 | self._step_train_time = 0 73 | self._total_steps = 0 74 | self._batch_size = batch_size 75 | self._warm_steps = warm_steps 76 | # List of examples per second logged every_n_steps.
77 | self.current_examples_per_sec_list = [] 78 | 79 | def begin(self): 80 | """Called once before using the session to check global step.""" 81 | self._global_step_tensor = tf.compat.v1.train.get_global_step() 82 | if self._global_step_tensor is None: 83 | raise RuntimeError( 84 | "Global step should be created to use ExamplesPerSecondHook.") 85 | 86 | def before_run(self, run_context):  # pylint: disable=unused-argument 87 | """Called before each call to run(). 88 | 89 | Args: 90 | run_context: A SessionRunContext object. 91 | 92 | Returns: 93 | A SessionRunArgs object or None if never triggered. 94 | """ 95 | return tf.estimator.SessionRunArgs(self._global_step_tensor) 96 | 97 | def after_run(self, run_context, run_values):  # pylint: disable=unused-argument 98 | """Called after each call to run(). 99 | 100 | Args: 101 | run_context: A SessionRunContext object. 102 | run_values: A SessionRunValues object. 103 | """ 104 | global_step = run_values.results 105 | 106 | if self._timer.should_trigger_for_step( 107 | global_step) and global_step > self._warm_steps: 108 | elapsed_time, elapsed_steps = self._timer.update_last_triggered_step( 109 | global_step) 110 | if elapsed_time is not None: 111 | self._step_train_time += elapsed_time 112 | self._total_steps += elapsed_steps 113 | 114 | # average examples per second is based on the total (cumulative) 115 | # training steps and training time so far 116 | average_examples_per_sec = self._batch_size * ( 117 | self._total_steps / self._step_train_time) 118 | # current examples per second is based on the elapsed training steps 119 | # and training time per batch 120 | current_examples_per_sec = self._batch_size * ( 121 | elapsed_steps / elapsed_time) 122 | # Logs entries to be read from hook during or after run. 123 | self.current_examples_per_sec_list.append(current_examples_per_sec) 124 | self._logger.log_metric( 125 | "average_examples_per_sec", average_examples_per_sec, 126 | global_step=global_step) 127 | 128 | self._logger.log_metric( 129 | "current_examples_per_sec", current_examples_per_sec, 130 | global_step=global_step) 131 | -------------------------------------------------------------------------------- /official/utils/logs/hooks_helper_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Tests for hooks_helper.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import unittest 23 | 24 | import tensorflow as tf  # pylint: disable=g-bad-import-order 25 | 26 | from official.utils.logs import hooks_helper 27 | 28 | 29 | class BaseTest(unittest.TestCase): 30 | 31 | def test_raise_in_non_list_names(self): 32 | with self.assertRaises(ValueError): 33 | hooks_helper.get_train_hooks( 34 | 'LoggingTensorHook, ProfilerHook', model_dir="", batch_size=256) 35 | 36 | def test_raise_in_invalid_names(self): 37 | invalid_names = ['StepCounterHook', 'StopAtStepHook'] 38 | with self.assertRaises(ValueError): 39 | hooks_helper.get_train_hooks(invalid_names, model_dir="", batch_size=256) 40 | 41 | def validate_train_hook_name(self, 42 | test_hook_name, 43 | expected_hook_name, 44 | **kwargs): 45 | returned_hook = hooks_helper.get_train_hooks( 46 | [test_hook_name], model_dir="", **kwargs) 47 | self.assertEqual(len(returned_hook), 1) 48 | self.assertIsInstance(returned_hook[0], tf.estimator.SessionRunHook) 49 | self.assertEqual(returned_hook[0].__class__.__name__.lower(), 50 | expected_hook_name) 51 | 52 | def test_get_train_hooks_logging_tensor_hook(self): 53 | self.validate_train_hook_name('LoggingTensorHook', 'loggingtensorhook') 54 | 55 | def test_get_train_hooks_profiler_hook(self): 56 | self.validate_train_hook_name('ProfilerHook', 'profilerhook') 57 | 58 | def test_get_train_hooks_examples_per_second_hook(self): 59 | self.validate_train_hook_name('ExamplesPerSecondHook', 60 | 'examplespersecondhook') 61 | 62 | def test_get_logging_metric_hook(self): 63 | test_hook_name = 'LoggingMetricHook' 64 | self.validate_train_hook_name(test_hook_name, 'loggingmetrichook') 65 | 66 | if __name__ == '__main__': 67 | tf.test.main() 68 | --------------------------------------------------------------------------------
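A minimal usage sketch of the helper exercised above — `get_train_hooks` takes a list of hook names plus whatever keyword arguments the individual hook constructors need (the `model_dir` path here is hypothetical):

```
from official.utils.logs import hooks_helper

# Names must be passed as a list; a single comma-separated string raises
# ValueError, as test_raise_in_non_list_names above demonstrates.
train_hooks = hooks_helper.get_train_hooks(
    ["LoggingTensorHook", "ExamplesPerSecondHook"],
    model_dir="/tmp/model",
    batch_size=256)
# estimator.train(input_fn=..., hooks=train_hooks)
```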
/official/utils/logs/hooks_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for hooks.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import time 23 | 24 | import tensorflow as tf  # pylint: disable=g-bad-import-order 25 | 26 | from official.utils.logs import hooks 27 | from official.utils.testing import mock_lib 28 | 29 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG) 30 | 31 | 32 | class ExamplesPerSecondHookTest(tf.test.TestCase): 33 | """Tests for the ExamplesPerSecondHook. 34 | 35 | In the test, we explicitly run the global_step tensor after train_op in order 36 | to keep the global_step value and the train_op (which increases the 37 | global_step by 1) consistent. This is to correct the discrepancies in the 38 | reported global_step value when running on GPUs. 39 | """ 40 | 41 | def setUp(self): 42 | """Mock out logging calls to verify if correct info is being monitored.""" 43 | self._logger = mock_lib.MockBenchmarkLogger() 44 | 45 | self.graph = tf.Graph() 46 | with self.graph.as_default(): 47 | tf.compat.v1.train.create_global_step() 48 | self.train_op = tf.compat.v1.assign_add( 49 | tf.compat.v1.train.get_global_step(), 1) 50 | self.global_step = tf.compat.v1.train.get_global_step() 51 | 52 | def test_raise_in_both_secs_and_steps(self): 53 | with self.assertRaises(ValueError): 54 | hooks.ExamplesPerSecondHook( 55 | batch_size=256, 56 | every_n_steps=10, 57 | every_n_secs=20, 58 | metric_logger=self._logger) 59 | 60 | def test_raise_in_none_secs_and_steps(self): 61 | with self.assertRaises(ValueError): 62 | hooks.ExamplesPerSecondHook( 63 | batch_size=256, 64 | every_n_steps=None, 65 | every_n_secs=None, 66 | metric_logger=self._logger) 67 | 68 | def _validate_log_every_n_steps(self, every_n_steps, warm_steps): 69 | hook = hooks.ExamplesPerSecondHook( 70 | batch_size=256, 71 | every_n_steps=every_n_steps, 72 | warm_steps=warm_steps, 73 | metric_logger=self._logger) 74 | 75 | with tf.compat.v1.train.MonitoredSession( 76 | tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess: 77 | for _ in range(every_n_steps): 78 | # Explicitly run global_step after train_op to get the accurate 79 | # global_step value 80 | mon_sess.run(self.train_op) 81 | mon_sess.run(self.global_step) 82 | # Nothing should be in the list yet 83 | self.assertFalse(self._logger.logged_metric) 84 | 85 | mon_sess.run(self.train_op) 86 | global_step_val = mon_sess.run(self.global_step) 87 | 88 | if global_step_val > warm_steps: 89 | self._assert_metrics() 90 | else: 91 | # Nothing should be in the list yet 92 | self.assertFalse(self._logger.logged_metric) 93 | 94 | # Add additional run to verify proper reset when called multiple times. 95 | prev_log_len = len(self._logger.logged_metric) 96 | mon_sess.run(self.train_op) 97 | global_step_val = mon_sess.run(self.global_step) 98 | 99 | if every_n_steps == 1 and global_step_val > warm_steps: 100 | # Each time, we log two additional metrics. Did exactly 2 get added? 101 | self.assertEqual(len(self._logger.logged_metric), prev_log_len + 2) 102 | else: 103 | # No change in the size of the metric list.
104 | self.assertEqual(len(self._logger.logged_metric), prev_log_len) 105 | 106 | def test_examples_per_sec_every_1_steps(self): 107 | with self.graph.as_default(): 108 | self._validate_log_every_n_steps(1, 0) 109 | 110 | def test_examples_per_sec_every_5_steps(self): 111 | with self.graph.as_default(): 112 | self._validate_log_every_n_steps(5, 0) 113 | 114 | def test_examples_per_sec_every_1_steps_with_warm_steps(self): 115 | with self.graph.as_default(): 116 | self._validate_log_every_n_steps(1, 10) 117 | 118 | def test_examples_per_sec_every_5_steps_with_warm_steps(self): 119 | with self.graph.as_default(): 120 | self._validate_log_every_n_steps(5, 10) 121 | 122 | def _validate_log_every_n_secs(self, every_n_secs): 123 | hook = hooks.ExamplesPerSecondHook( 124 | batch_size=256, 125 | every_n_steps=None, 126 | every_n_secs=every_n_secs, 127 | metric_logger=self._logger) 128 | 129 | with tf.compat.v1.train.MonitoredSession( 130 | tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess: 131 | # Explicitly run global_step after train_op to get the accurate 132 | # global_step value 133 | mon_sess.run(self.train_op) 134 | mon_sess.run(self.global_step) 135 | # Nothing should be in the list yet 136 | self.assertFalse(self._logger.logged_metric) 137 | time.sleep(every_n_secs) 138 | 139 | mon_sess.run(self.train_op) 140 | mon_sess.run(self.global_step) 141 | self._assert_metrics() 142 | 143 | def test_examples_per_sec_every_1_secs(self): 144 | with self.graph.as_default(): 145 | self._validate_log_every_n_secs(1) 146 | 147 | def test_examples_per_sec_every_5_secs(self): 148 | with self.graph.as_default(): 149 | self._validate_log_every_n_secs(5) 150 | 151 | def _assert_metrics(self): 152 | metrics = self._logger.logged_metric 153 | self.assertEqual(metrics[-2]["name"], "average_examples_per_sec") 154 | self.assertEqual(metrics[-1]["name"], "current_examples_per_sec") 155 | 156 | 157 | if __name__ == "__main__": 158 | tf.test.main() 159 | -------------------------------------------------------------------------------- /official/utils/logs/metric_hook.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Session hook for logging benchmark metric.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf  # pylint: disable=g-bad-import-order 22 | 23 | 24 | class LoggingMetricHook(tf.estimator.LoggingTensorHook): 25 | """Hook to log benchmark metric information. 26 | 27 | This hook is very similar to tf.train.LoggingTensorHook, which logs given 28 | tensors every N local steps, every N seconds, or at the end.
The metric 29 | information will be logged to the given log_dir or via metric_logger in JSON 30 | format, which can be consumed by a data analysis pipeline later. 31 | 32 | Note that if `at_end` is True, `tensors` should not include any tensor 33 | whose evaluation produces a side effect such as consuming additional inputs. 34 | """ 35 | 36 | def __init__(self, tensors, metric_logger=None, 37 | every_n_iter=None, every_n_secs=None, at_end=False): 38 | """Initializer for LoggingMetricHook. 39 | 40 | Args: 41 | tensors: `dict` that maps string-valued tags to tensors/tensor names, 42 | or `iterable` of tensors/tensor names. 43 | metric_logger: instance of `BenchmarkLogger`, the benchmark logger that 44 | hook should use to write the log. 45 | every_n_iter: `int`, print the values of `tensors` once every N local 46 | steps taken on the current worker. 47 | every_n_secs: `int` or `float`, print the values of `tensors` once every N 48 | seconds. Exactly one of `every_n_iter` and `every_n_secs` should be 49 | provided. 50 | at_end: `bool` specifying whether to print the values of `tensors` at the 51 | end of the run. 52 | 53 | Raises: 54 | ValueError: 55 | 1. `every_n_iter` is non-positive, or 56 | 2. neither or both of `every_n_iter` and `every_n_secs` are provided, or 57 | 3. `metric_logger` is not provided. 58 | """ 59 | super(LoggingMetricHook, self).__init__( 60 | tensors=tensors, 61 | every_n_iter=every_n_iter, 62 | every_n_secs=every_n_secs, 63 | at_end=at_end) 64 | 65 | if metric_logger is None: 66 | raise ValueError("metric_logger should be provided.") 67 | self._logger = metric_logger 68 | 69 | def begin(self): 70 | super(LoggingMetricHook, self).begin() 71 | self._global_step_tensor = tf.compat.v1.train.get_global_step() 72 | if self._global_step_tensor is None: 73 | raise RuntimeError( 74 | "Global step should be created to use LoggingMetricHook.") 75 | if self._global_step_tensor.name not in self._current_tensors: 76 | self._current_tensors[self._global_step_tensor.name] = ( 77 | self._global_step_tensor) 78 | 79 | def after_run(self, unused_run_context, run_values): 80 | # _should_trigger is an internal state that is populated in before_run, and 81 | # it uses self._timer to determine whether it should trigger. 82 | if self._should_trigger: 83 | self._log_metric(run_values.results) 84 | 85 | self._iter_count += 1 86 | 87 | def end(self, session): 88 | if self._log_at_end: 89 | values = session.run(self._current_tensors) 90 | self._log_metric(values) 91 | 92 | def _log_metric(self, tensor_values): 93 | self._timer.update_last_triggered_step(self._iter_count) 94 | global_step = tensor_values[self._global_step_tensor.name] 95 | # self._tag_order is populated during the init of LoggingTensorHook 96 | for tag in self._tag_order: 97 | self._logger.log_metric(tag, tensor_values[tag], global_step=global_step) 98 | --------------------------------------------------------------------------------
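A minimal usage sketch of the hook above, with a hypothetical tensor name; the mock logger from mock_lib (further below) stands in for a real `BenchmarkLogger`:

```
from official.utils.logs import metric_hook
from official.utils.testing import mock_lib

# metric_logger is required; LoggingMetricHook raises ValueError without it.
hook = metric_hook.LoggingMetricHook(
    tensors={"loss": "loss_tensor_name"},  # hypothetical tag -> tensor name
    metric_logger=mock_lib.MockBenchmarkLogger(),
    every_n_iter=100)
# estimator.train(input_fn=..., hooks=[hook])
```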
/official/utils/misc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/misc/__init__.py -------------------------------------------------------------------------------- /official/utils/misc/distribution_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for distribution util functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf  # pylint: disable=g-bad-import-order 22 | 23 | from official.utils.misc import distribution_utils 24 | 25 | 26 | class GetDistributionStrategyTest(tf.test.TestCase): 27 | """Tests for get_distribution_strategy.""" 28 | def test_one_device_strategy_cpu(self): 29 | ds = distribution_utils.get_distribution_strategy(num_gpus=0) 30 | self.assertEqual(ds.num_replicas_in_sync, 1) 31 | self.assertEqual(len(ds.extended.worker_devices), 1) 32 | self.assertIn('CPU', ds.extended.worker_devices[0]) 33 | 34 | def test_one_device_strategy_gpu(self): 35 | ds = distribution_utils.get_distribution_strategy(num_gpus=1) 36 | self.assertEqual(ds.num_replicas_in_sync, 1) 37 | self.assertEqual(len(ds.extended.worker_devices), 1) 38 | self.assertIn('GPU', ds.extended.worker_devices[0]) 39 | 40 | def test_mirrored_strategy(self): 41 | ds = distribution_utils.get_distribution_strategy(num_gpus=5) 42 | self.assertEqual(ds.num_replicas_in_sync, 5) 43 | self.assertEqual(len(ds.extended.worker_devices), 5) 44 | for device in ds.extended.worker_devices: 45 | self.assertIn('GPU', device) 46 | 47 | 48 | class PerDeviceBatchSizeTest(tf.test.TestCase): 49 | """Tests for per_device_batch_size.""" 50 | 51 | def test_batch_size(self): 52 | self.assertEqual( 53 | distribution_utils.per_device_batch_size(147, num_gpus=0), 147) 54 | self.assertEqual( 55 | distribution_utils.per_device_batch_size(147, num_gpus=1), 147) 56 | self.assertEqual( 57 | distribution_utils.per_device_batch_size(147, num_gpus=7), 21) 58 | 59 | def test_batch_size_with_remainder(self): 60 | with self.assertRaises(ValueError): 61 | distribution_utils.per_device_batch_size(147, num_gpus=5) 62 | 63 | 64 | if __name__ == "__main__": 65 | tf.test.main() 66 | --------------------------------------------------------------------------------
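The behavior pinned down by the tests above, as a short usage sketch — a global batch of 147 split over 7 GPUs gives 21 per device, and an uneven split raises ValueError:

```
from official.utils.misc import distribution_utils

per_device = distribution_utils.per_device_batch_size(147, num_gpus=7)
assert per_device == 21  # 147 / 7

# distribution_utils.per_device_batch_size(147, num_gpus=5)  # ValueError

strategy = distribution_utils.get_distribution_strategy(num_gpus=0)  # CPU-only
```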
/official/utils/misc/model_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Miscellaneous functions that can be called by models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numbers 22 | 23 | import tensorflow as tf 24 | from tensorflow.python.util import nest 25 | 26 | 27 | def past_stop_threshold(stop_threshold, eval_metric): 28 | """Return a boolean representing whether a model should be stopped. 29 | 30 | Args: 31 | stop_threshold: float, the threshold above which a model should stop 32 | training. 33 | eval_metric: float, the current value of the relevant metric to check. 34 | 35 | Returns: 36 | True if training should stop, False otherwise. 37 | 38 | Raises: 39 | ValueError: if either stop_threshold or eval_metric is not a number 40 | """ 41 | if stop_threshold is None: 42 | return False 43 | 44 | if not isinstance(stop_threshold, numbers.Number): 45 | raise ValueError("Threshold for checking stop conditions must be a number.") 46 | if not isinstance(eval_metric, numbers.Number): 47 | raise ValueError("Eval metric being checked against stop conditions " 48 | "must be a number.") 49 | 50 | if eval_metric >= stop_threshold: 51 | tf.compat.v1.logging.info( 52 | "Stop threshold of {} was passed with metric value {}.".format( 53 | stop_threshold, eval_metric)) 54 | return True 55 | 56 | return False 57 | 58 | 59 | def generate_synthetic_data( 60 | input_shape, input_value=0, input_dtype=None, label_shape=None, 61 | label_value=0, label_dtype=None): 62 | """Create a repeating dataset with constant values. 63 | 64 | Args: 65 | input_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape of 66 | the input data. 67 | input_value: Value of each input element. 68 | input_dtype: Input dtype. If None, will be inferred from the input value. 69 | label_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape of 70 | the label data. 71 | label_value: Value of each label element. 72 | label_dtype: Label dtype. If None, will be inferred from the label value. 73 | 74 | Returns: 75 | Dataset of tensors or tuples of tensors (if label_shape is set). 76 | """ 77 | # TODO(kathywu): Replace with SyntheticDataset once it is in contrib. 78 | element = input_element = nest.map_structure( 79 | lambda s: tf.constant(input_value, input_dtype, s), input_shape) 80 | 81 | if label_shape: 82 | label_element = nest.map_structure( 83 | lambda s: tf.constant(label_value, label_dtype, s), label_shape) 84 | element = (input_element, label_element) 85 | 86 | return tf.data.Dataset.from_tensors(element).repeat() 87 | 88 | 89 | def apply_clean(flags_obj): 90 | if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir): 91 | tf.compat.v1.logging.info("--clean flag set. Removing existing model dir:" 92 | " {}".format(flags_obj.model_dir)) 93 | tf.io.gfile.rmtree(flags_obj.model_dir) 94 | --------------------------------------------------------------------------------
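A short sketch of the two helpers above (values are arbitrary): `generate_synthetic_data` builds a repeating constant-valued dataset, and `past_stop_threshold` implements the early-stop check used with `stop_threshold`:

```
import tensorflow as tf

from official.utils.misc import model_helpers

# Repeating dataset whose elements are length-4 float vectors of 1.5.
dataset = model_helpers.generate_synthetic_data(
    input_shape=tf.TensorShape([4]), input_value=1.5, input_dtype=tf.float32)

# True once the eval metric meets or exceeds the threshold.
assert model_helpers.past_stop_threshold(stop_threshold=0.9, eval_metric=0.95)
```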
/official/utils/misc/model_helpers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for Model Helper functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf  # pylint: disable=g-bad-import-order 22 | 23 | from official.utils.misc import model_helpers 24 | 25 | 26 | class PastStopThresholdTest(tf.test.TestCase): 27 | """Tests for past_stop_threshold.""" 28 | 29 | def test_past_stop_threshold(self): 30 | """Tests for normal operating conditions.""" 31 | self.assertTrue(model_helpers.past_stop_threshold(0.54, 1)) 32 | self.assertTrue(model_helpers.past_stop_threshold(54, 100)) 33 | self.assertFalse(model_helpers.past_stop_threshold(0.54, 0.1)) 34 | self.assertFalse(model_helpers.past_stop_threshold(-0.54, -1.5)) 35 | self.assertTrue(model_helpers.past_stop_threshold(-0.54, 0)) 36 | self.assertTrue(model_helpers.past_stop_threshold(0, 0)) 37 | self.assertTrue(model_helpers.past_stop_threshold(0.54, 0.54)) 38 | 39 | def test_past_stop_threshold_none_false(self): 40 | """Tests that a stop_threshold of None returns False.""" 41 | self.assertFalse(model_helpers.past_stop_threshold(None, -1.5)) 42 | self.assertFalse(model_helpers.past_stop_threshold(None, None)) 43 | self.assertFalse(model_helpers.past_stop_threshold(None, 1.5)) 44 | # Zero should be okay, though.
45 | self.assertTrue(model_helpers.past_stop_threshold(0, 1.5)) 46 | 47 | def test_past_stop_threshold_not_number(self): 48 | """Tests for error conditions.""" 49 | with self.assertRaises(ValueError): 50 | model_helpers.past_stop_threshold("str", 1) 51 | 52 | with self.assertRaises(ValueError): 53 | model_helpers.past_stop_threshold("str", tf.constant(5)) 54 | 55 | with self.assertRaises(ValueError): 56 | model_helpers.past_stop_threshold("str", "another") 57 | 58 | with self.assertRaises(ValueError): 59 | model_helpers.past_stop_threshold(0, None) 60 | 61 | with self.assertRaises(ValueError): 62 | model_helpers.past_stop_threshold(0.7, "str") 63 | 64 | with self.assertRaises(ValueError): 65 | model_helpers.past_stop_threshold(tf.constant(4), None) 66 | 67 | 68 | class SyntheticDataTest(tf.test.TestCase): 69 | """Tests for generate_synthetic_data.""" 70 | 71 | def test_generate_synthetic_data(self): 72 | input_element, label_element = tf.compat.v1.data.make_one_shot_iterator( 73 | model_helpers.generate_synthetic_data(input_shape=tf.TensorShape([5]), 74 | input_value=123, 75 | input_dtype=tf.float32, 76 | label_shape=tf.TensorShape([]), 77 | label_value=456, 78 | label_dtype=tf.int32)).get_next() 79 | 80 | with self.test_session() as sess: 81 | for n in range(5): 82 | inp, lab = sess.run((input_element, label_element)) 83 | self.assertAllClose(inp, [123., 123., 123., 123., 123.]) 84 | self.assertEqual(lab, 456) 85 | 86 | def test_generate_only_input_data(self): 87 | d = model_helpers.generate_synthetic_data( 88 | input_shape=tf.TensorShape([4]), 89 | input_value=43.5, 90 | input_dtype=tf.float32) 91 | 92 | element = tf.compat.v1.data.make_one_shot_iterator(d).get_next() 93 | self.assertFalse(isinstance(element, tuple)) 94 | 95 | with self.test_session() as sess: 96 | inp = sess.run(element) 97 | self.assertAllClose(inp, [43.5, 43.5, 43.5, 43.5]) 98 | 99 | def test_generate_nested_data(self): 100 | d = model_helpers.generate_synthetic_data( 101 | input_shape={'a': tf.TensorShape([2]), 102 | 'b': {'c': tf.TensorShape([3]), 'd': tf.TensorShape([])}}, 103 | input_value=1.1) 104 | 105 | element = tf.compat.v1.data.make_one_shot_iterator(d).get_next() 106 | self.assertIn('a', element) 107 | self.assertIn('b', element) 108 | self.assertEqual(len(element['b']), 2) 109 | self.assertIn('c', element['b']) 110 | self.assertIn('d', element['b']) 111 | self.assertNotIn('c', element) 112 | 113 | with self.test_session() as sess: 114 | inp = sess.run(element) 115 | self.assertAllClose(inp['a'], [1.1, 1.1]) 116 | self.assertAllClose(inp['b']['c'], [1.1, 1.1, 1.1]) 117 | self.assertAllClose(inp['b']['d'], 1.1) 118 | 119 | 120 | if __name__ == "__main__": 121 | tf.test.main() 122 | -------------------------------------------------------------------------------- /official/utils/testing/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/.DS_Store -------------------------------------------------------------------------------- /official/utils/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/__init__.py -------------------------------------------------------------------------------- /official/utils/testing/integration.py:
-------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Helper code to run complete models from within python. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import shutil 24 | import sys 25 | import tempfile 26 | 27 | from absl import flags 28 | 29 | from official.utils.flags import core as flags_core 30 | 31 | 32 | def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1): 33 | """Performs a minimal run of a model. 34 | 35 | This function is intended to test for syntax errors throughout a model. A 36 | very limited run is performed using synthetic data. 37 | 38 | Args: 39 | main: The primary function used to exercise a code path. Generally this 40 | function is ".main(argv)". 41 | tmp_root: Root path for the temp directory created by the test class. 42 | extra_flags: Additional flags passed by the caller of this function. 43 | synth: Use synthetic data. 44 | max_train: Maximum number of allowed training steps. 45 | """ 46 | 47 | extra_flags = [] if extra_flags is None else extra_flags 48 | 49 | model_dir = tempfile.mkdtemp(dir=tmp_root) 50 | 51 | args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1", 52 | "--epochs_between_evals", "1"] + extra_flags 53 | 54 | if synth: 55 | args.append("--use_synthetic_data") 56 | 57 | if max_train is not None: 58 | args.extend(["--max_train_steps", str(max_train)]) 59 | 60 | try: 61 | flags_core.parse_flags(argv=args) 62 | main(flags.FLAGS) 63 | finally: 64 | if os.path.exists(model_dir): 65 | shutil.rmtree(model_dir) 66 | -------------------------------------------------------------------------------- /official/utils/testing/mock_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Mock objects and related functions for testing.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | class MockBenchmarkLogger(object): 24 | """This is a mock logger that can be used in dependent tests.""" 25 | 26 | def __init__(self): 27 | self.logged_metric = [] 28 | 29 | def log_metric(self, name, value, unit=None, global_step=None, 30 | extras=None): 31 | self.logged_metric.append({ 32 | "name": name, 33 | "value": float(value), 34 | "unit": unit, 35 | "global_step": global_step, 36 | "extras": extras}) 37 | -------------------------------------------------------------------------------- /official/utils/testing/pylint.rcfile: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | disable=R,W, 3 | bad-option-value 4 | 5 | [REPORTS] 6 | # Tells whether to display a full report or only the messages 7 | reports=no 8 | 9 | # Activate the evaluation score. 10 | score=no 11 | 12 | [BASIC] 13 | 14 | # Regular expression matching correct argument names 15 | argument-rgx=^[a-z][a-z0-9_]*$ 16 | 17 | # Regular expression matching correct attribute names 18 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 19 | 20 | # Regular expression matching correct class attribute names 21 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 22 | 23 | # Regular expression matching correct class names 24 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 25 | 26 | # Regular expression matching correct constant names 27 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 28 | 29 | # Minimum line length for functions/classes that require docstrings, shorter 30 | # ones are exempt. 31 | docstring-min-length=10 32 | 33 | # Regular expression matching correct function names 34 | function-rgx=^(?:(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$ 35 | 36 | # Good variable names which should always be accepted, separated by a comma 37 | good-names=main,_ 38 | 39 | # Regular expression matching correct inline iteration names 40 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 41 | 42 | # Regular expression matching correct method names 43 | method-rgx=^(?:(?P<dunder>__[a-z0-9_]+__|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*)|(setUp|tearDown))$ 44 | 45 | # Regular expression matching correct module names 46 | module-rgx=^(_?[a-z][a-z0-9_]*)|__init__|PRESUBMIT|PRESUBMIT_unittest$ 47 | 48 | # Regular expression which should only match function or class names that do 49 | # not require a docstring. 50 | no-docstring-rgx=(__.*__|main|.*ArgParser) 51 | 52 | # Naming hint for variable names 53 | variable-name-hint=[a-z_][a-z0-9_]{2,30}$ 54 | 55 | # Regular expression matching correct variable names 56 | variable-rgx=^[a-z][a-z0-9_]*$ 57 | 58 | [TYPECHECK] 59 | 60 | # List of module names for which member attributes should not be checked 61 | # (useful for modules/projects where namespaces are manipulated during runtime 62 | # and thus existing member attributes cannot be deduced by static analysis. It 63 | # supports qualified module names, as well as Unix pattern matching. 64 | ignored-modules=absl, absl.*, official, official.*, tensorflow, tensorflow.*, LazyLoader, google, google.cloud.* 65 | 66 | 67 | [CLASSES] 68 | 69 | # List of method names used to declare (i.e. assign) instance attributes.
70 | defining-attr-methods=__init__,__new__,setUp 71 | 72 | # List of member names, which should be excluded from the protected access 73 | # warning. 74 | exclude-protected=_asdict,_fields,_replace,_source,_make 75 | 76 | # This is deprecated, because it is not used anymore. 77 | #ignore-iface-methods= 78 | 79 | # List of valid names for the first argument in a class method. 80 | valid-classmethod-first-arg=cls,class_ 81 | 82 | # List of valid names for the first argument in a metaclass class method. 83 | valid-metaclass-classmethod-first-arg=mcs 84 | 85 | 86 | [DESIGN] 87 | 88 | # Argument names that match this expression will be ignored. Default to name 89 | # with leading underscore 90 | ignored-argument-names=_.* 91 | 92 | # Maximum number of arguments for function / method 93 | max-args=5 94 | 95 | # Maximum number of attributes for a class (see R0902). 96 | max-attributes=7 97 | 98 | # Maximum number of branch for function / method body 99 | max-branches=12 100 | 101 | # Maximum number of locals for function / method body 102 | max-locals=15 103 | 104 | # Maximum number of parents for a class (see R0901). 105 | max-parents=7 106 | 107 | # Maximum number of public methods for a class (see R0904). 108 | max-public-methods=20 109 | 110 | # Maximum number of return / yield for function / method body 111 | max-returns=6 112 | 113 | # Maximum number of statements in function / method body 114 | max-statements=50 115 | 116 | # Minimum number of public methods for a class (see R0903). 117 | min-public-methods=2 118 | 119 | 120 | [EXCEPTIONS] 121 | 122 | # Exceptions that will emit a warning when being caught. Defaults to 123 | # "Exception" 124 | overgeneral-exceptions=StandardError,Exception,BaseException 125 | 126 | 127 | [FORMAT] 128 | 129 | # Number of spaces of indent required inside a hanging or continued line. 130 | indent-after-paren=4 131 | 132 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 133 | # tab). 134 | indent-string=' ' 135 | 136 | # Maximum number of characters on a single line. 137 | max-line-length=80 138 | 139 | # Maximum number of lines in a module 140 | max-module-lines=99999 141 | 142 | # List of optional constructs for which whitespace checking is disabled 143 | no-space-check= 144 | 145 | # Allow the body of an if to be on the same line as the test if there is no 146 | # else. 147 | single-line-if-stmt=yes 148 | 149 | # Allow URLs and comment type annotations to exceed the max line length as neither can be easily 150 | # split across lines. 151 | ignore-long-lines=^\s*(?:(# )?<?https?://\S+>?$|# type:) 152 | 153 | 154 | [VARIABLES] 155 | 156 | # List of additional names supposed to be defined in builtins. Remember that 157 | # you should avoid to define new builtins when possible. 158 | additional-builtins= 159 | 160 | # List of strings which can identify a callback function by name. A callback 161 | # name must start or end with one of those strings. 162 | callbacks=cb_,_cb 163 | 164 | # A regular expression matching the name of dummy variables (i.e. expectedly 165 | # not used). 166 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 167 | 168 | # Tells whether we should check for unused import in __init__ files.
169 | init-import=no 170 | -------------------------------------------------------------------------------- /official/utils/testing/reference_data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/.DS_Store -------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/reference_data_test/.DS_Store -------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/dense/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/reference_data_test/dense/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/dense/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/reference_data_test/dense/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/dense/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/reference_data_test/dense/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/dense/results.json: -------------------------------------------------------------------------------- 1 | [1, 1, 0.4701630473136902, 0.4701630473136902, 0.4701630473136902] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/dense/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180325", "v1.7.0-rc1-750-g6c1737e6c8"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/uniform_random/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/reference_data_test/uniform_random/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/uniform_random/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 1 | ʼ|? 
-------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/uniform_random/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/reference_data_test/uniform_random/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/uniform_random/results.json: -------------------------------------------------------------------------------- 1 | [0.9872556924819946] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/reference_data_test/uniform_random/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180325", "v1.7.0-rc1-750-g6c1737e6c8"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/.DS_Store -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-1_width-8_channels-4/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-1_width-8_channels-4/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-1_width-8_channels-4/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-1_width-8_channels-4/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-1_width-8_channels-4/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-1_width-8_channels-4/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-1_width-8_channels-4/results.json: -------------------------------------------------------------------------------- 1 | [32, 8, 8, 4, 0.08920872211456299, 0.8918969631195068, 4064.7060546875, 32, 4, 4, 8, 0.0, 0.10715862363576889, 2344.4775390625] -------------------------------------------------------------------------------- 
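Each golden-data directory in this dump pairs an expected_graph and a model checkpoint with a results.json of reference values and a tf_version.json recording the TensorFlow build that generated them. Judging from the dense entry above ([1, 1, 0.4701..., 0.4701..., 0.4701...]), each evaluated tensor appears to be flattened into its shape followed by summary statistics (for a single-element tensor, min, max, and sum coincide), which reference_data_test.py later in this dump compares against. A minimal sketch of such a comparison; the check_against_golden helper and its tolerance are illustrative, not code from reference_data.py:

import json

import numpy as np


def check_against_golden(results_path, shape, stats, rtol=1e-5):
  """Compares [shape..., stats...] for a regenerated tensor to results.json."""
  with open(results_path) as f:
    golden = json.load(f)
  candidate = list(shape) + list(stats)
  assert len(candidate) == len(golden), "Result arity changed."
  np.testing.assert_allclose(candidate, golden, rtol=rtol)


# Example with the `dense` golden values above: a (1, 1) tensor whose min,
# max, and sum all equal its single element.
v = 0.4701630473136902
check_against_golden("reference_data_test/dense/results.json",
                     shape=(1, 1), stats=(v, v, v))

--------------------------------------------------------------------------------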
/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-1_width-8_channels-4/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180408", "v1.7.0-1345-gb874783ccd"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-2_width-8_channels-4/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-2_width-8_channels-4/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-2_width-8_channels-4/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-2_width-8_channels-4/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-2_width-8_channels-4/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-2_width-8_channels-4/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-2_width-8_channels-4/results.json: -------------------------------------------------------------------------------- 1 | [32, 8, 8, 4, 0.918815016746521, 0.1826801300048828, 4064.4677734375, 32, 4, 4, 8, -1.3153012990951538, 0.011247094720602036, 261.84716796875] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_projection_version-2_width-8_channels-4/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180408", "v1.7.0-1345-gb874783ccd"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-1_width-8_channels-4/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-1_width-8_channels-4/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-1_width-8_channels-4/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-1_width-8_channels-4/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-1_width-8_channels-4/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-1_width-8_channels-4/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-1_width-8_channels-4/results.json: -------------------------------------------------------------------------------- 1 | [32, 8, 8, 4, 0.1677999496459961, 0.7767924070358276, 4089.44189453125, 32, 8, 8, 4, 0.8615571856498718, 1.1359407901763916, 5806.876953125] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-1_width-8_channels-4/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180408", "v1.7.0-1345-gb874783ccd"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-2_width-8_channels-4/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-2_width-8_channels-4/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-2_width-8_channels-4/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-2_width-8_channels-4/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-2_width-8_channels-4/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-2_width-8_channels-4/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-2_width-8_channels-4/results.json: -------------------------------------------------------------------------------- 1 | [32, 8, 8, 4, 0.8239736557006836, 0.3485994338989258, 4108.87548828125, 32, 8, 8, 4, 0.16798323392868042, -0.2975311279296875, 2860.068359375] -------------------------------------------------------------------------------- 
/official/utils/testing/reference_data/resnet/batch-size-32_bottleneck_version-2_width-8_channels-4/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180408", "v1.7.0-1345-gb874783ccd"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-1_width-8_channels-4/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-1_width-8_channels-4/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-1_width-8_channels-4/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-1_width-8_channels-4/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-1_width-8_channels-4/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-1_width-8_channels-4/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-1_width-8_channels-4/results.json: -------------------------------------------------------------------------------- 1 | [32, 8, 8, 4, 0.5349493026733398, 0.5126370191574097, 4070.01220703125, 32, 4, 4, 8, 0.0, 2.7680201530456543, 2341.23486328125] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-1_width-8_channels-4/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180408", "v1.7.0-1345-gb874783ccd"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-2_width-8_channels-4/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-2_width-8_channels-4/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-2_width-8_channels-4/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-2_width-8_channels-4/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-2_width-8_channels-4/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-2_width-8_channels-4/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-2_width-8_channels-4/results.json: -------------------------------------------------------------------------------- 1 | [32, 8, 8, 4, 0.7820245027542114, 0.8173515796661377, 4095.256591796875, 32, 4, 4, 8, 0.0679062008857727, 0.009305447340011597, -137.36178588867188] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_projection_version-2_width-8_channels-4/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180408", "v1.7.0-1345-gb874783ccd"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_version-1_width-8_channels-4/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_version-1_width-8_channels-4/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_version-1_width-8_channels-4/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_version-1_width-8_channels-4/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_version-1_width-8_channels-4/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_version-1_width-8_channels-4/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_version-1_width-8_channels-4/results.json: -------------------------------------------------------------------------------- 1 | [32, 8, 8, 4, 0.23128163814544678, 0.22117376327514648, 4100.51806640625, 32, 8, 8, 4, 1.1768392324447632, 0.2728465795516968, 5832.6416015625] -------------------------------------------------------------------------------- 
/official/utils/testing/reference_data/resnet/batch-size-32_building_version-1_width-8_channels-4/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180408", "v1.7.0-1345-gb874783ccd"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_version-2_width-8_channels-4/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_version-2_width-8_channels-4/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_version-2_width-8_channels-4/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_version-2_width-8_channels-4/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_version-2_width-8_channels-4/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch-size-32_building_version-2_width-8_channels-4/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_version-2_width-8_channels-4/results.json: -------------------------------------------------------------------------------- 1 | [32, 8, 8, 4, 0.7616699934005737, 0.5485763549804688, 4106.8720703125, 32, 8, 8, 4, -0.056346118450164795, 0.5792689919471741, 2972.37255859375] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch-size-32_building_version-2_width-8_channels-4/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180408", "v1.7.0-1345-gb874783ccd"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch_norm/expected_graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch_norm/expected_graph -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch_norm/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch_norm/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch_norm/model.ckpt.index: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/utils/testing/reference_data/resnet/batch_norm/model.ckpt.index -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch_norm/results.json: -------------------------------------------------------------------------------- 1 | [32, 16, 16, 3, 0.9722558259963989, 0.18413543701171875, 12374.20703125, 32, 16, 16, 3, 1.6126631498336792, -1.096894383430481, -0.041595458984375] -------------------------------------------------------------------------------- /official/utils/testing/reference_data/resnet/batch_norm/tf_version.json: -------------------------------------------------------------------------------- 1 | ["1.8.0-dev20180408", "v1.7.0-1345-gb874783ccd"] -------------------------------------------------------------------------------- /official/utils/testing/reference_data_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """This module tests generic behavior of reference data tests. 16 | 17 | This test is not intended to test every layer of interest, and models should 18 | test the layers that affect them. This test is primarily focused on ensuring 19 | that reference_data.BaseTest functions as intended. If there is a legitimate 20 | change such as a change to TensorFlow which changes graph construction, tests 21 | can be regenerated with the following command: 22 | 23 | $ python3 reference_data_test.py -regen 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import sys 31 | import unittest 32 | import warnings 33 | 34 | import tensorflow as tf # pylint: disable=g-bad-import-order 35 | from official.utils.testing import reference_data 36 | 37 | 38 | class GoldenBaseTest(reference_data.BaseTest): 39 | """Class to ensure that reference data testing runs properly.""" 40 | 41 | @property 42 | def test_name(self): 43 | return "reference_data_test" 44 | 45 | def _uniform_random_ops(self, test=False, wrong_name=False, wrong_shape=False, 46 | bad_seed=False, bad_function=False): 47 | """Tests number generation and failure modes. 48 | 49 | This test is of a very simple graph: the generation of a 1x1 random tensor. 50 | However, it is also used to confirm that the tests are actually checking 51 | properly by failing in predefined ways. 52 | 53 | Args: 54 | test: Whether or not to run as a test case. 55 | wrong_name: Whether to assign the wrong name to the tensor. 56 | wrong_shape: Whether to create a tensor with the wrong shape. 
57 |       bad_seed: Whether or not to perturb the random seed.
58 |       bad_function: Whether to perturb the correctness function.
59 |     """
60 |     name = "uniform_random"
61 |
62 |     g = tf.Graph()
63 |     with g.as_default():
64 |       seed = self.name_to_seed(name)
65 |       seed = seed + 1 if bad_seed else seed
66 |       tf.compat.v1.set_random_seed(seed)
67 |       tensor_name = "wrong_tensor" if wrong_name else "input_tensor"
68 |       tensor_shape = (1, 2) if wrong_shape else (1, 1)
69 |       input_tensor = tf.compat.v1.get_variable(
70 |           tensor_name, dtype=tf.float32,
71 |           initializer=tf.random.uniform(tensor_shape, maxval=1)
72 |       )
73 |
74 |     def correctness_function(tensor_result):
75 |       result = float(tensor_result[0, 0])
76 |       result = result + 0.1 if bad_function else result
77 |       return [result]
78 |
79 |     self._save_or_test_ops(
80 |         name=name, graph=g, ops_to_eval=[input_tensor], test=test,
81 |         correctness_function=correctness_function
82 |     )
83 |
84 |   def _dense_ops(self, test=False):
85 |     name = "dense"
86 |
87 |     g = tf.Graph()
88 |     with g.as_default():
89 |       tf.compat.v1.set_random_seed(self.name_to_seed(name))
90 |       input_tensor = tf.compat.v1.get_variable(
91 |           "input_tensor", dtype=tf.float32,
92 |           initializer=tf.random.uniform((1, 2), maxval=1)
93 |       )
94 |       layer = tf.compat.v1.layers.dense(inputs=input_tensor, units=4)
95 |       layer = tf.compat.v1.layers.dense(inputs=layer, units=1)
96 |
97 |     self._save_or_test_ops(
98 |         name=name, graph=g, ops_to_eval=[layer], test=test,
99 |         correctness_function=self.default_correctness_function
100 |     )
101 |
102 |   def test_uniform_random(self):
103 |     self._uniform_random_ops(test=True)
104 |
105 |   def test_tensor_name_error(self):
106 |     with self.assertRaises(AssertionError):
107 |       self._uniform_random_ops(test=True, wrong_name=True)
108 |
109 |   def test_tensor_shape_error(self):
110 |     with self.assertRaises(AssertionError):
111 |       self._uniform_random_ops(test=True, wrong_shape=True)
112 |
113 |   @unittest.skipIf(sys.version_info[0] == 2,
114 |                    "catch_warnings doesn't catch tf.logging.warn in py 2.")
115 |   def test_bad_seed(self):
116 |     with warnings.catch_warnings(record=True) as warn_catch:
117 |       self._uniform_random_ops(test=True, bad_seed=True)
118 |     assert len(warn_catch) == 1, "Test did not warn of minor graph change."
119 |
120 |   def test_incorrectness_function(self):
121 |     with self.assertRaises(AssertionError):
122 |       self._uniform_random_ops(test=True, bad_function=True)
123 |
124 |   def test_dense(self):
125 |     self._dense_ops(test=True)
126 |
127 |   def regenerate(self):
128 |     self._uniform_random_ops(test=False)
129 |     self._dense_ops(test=False)
130 |
131 |
132 | if __name__ == "__main__":
133 |   reference_data.main(argv=sys.argv, test_class=GoldenBaseTest)
134 |
--------------------------------------------------------------------------------
/official/utils/testing/scripts/presubmit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | # Presubmit script that runs tests and lint in the local environment.
18 | # Make sure that tensorflow and pylint are installed.
19 | # usage: models >: ./official/utils/testing/scripts/presubmit.sh
20 | # usage: models >: ./official/utils/testing/scripts/presubmit.sh lint py2_test py3_test
21 | set +x
22 |
23 | SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
24 | cd "$SCRIPT_DIR/../../../.."
25 | MODEL_ROOT="$(pwd)"
26 |
27 | export PYTHONPATH="$PYTHONPATH:${MODEL_ROOT}"
28 |
29 | cd official
30 |
31 | lint() {
32 |   local exit_code=0
33 |
34 |   RC_FILE="utils/testing/pylint.rcfile"
35 |   PROTO_SKIP="DO\sNOT\sEDIT!"
36 |
37 |   echo "===========Running lint test============"
38 |   for file in `find . -name '*.py' ! -name '*test.py' -print`
39 |   do
40 |     if grep -q "${PROTO_SKIP}" "${file}"; then
41 |       echo "Linting ${file} (Skipped: Machine generated file)"
42 |     else
43 |       echo "Linting ${file}"
44 |       pylint --rcfile="${RC_FILE}" "${file}" || exit_code=$?
45 |     fi
46 |   done
47 |
48 |   # More lenient for test files.
49 |   for file in `find . -name '*test.py' -print`
50 |   do
51 |     echo "Linting ${file}"
52 |     pylint --rcfile="${RC_FILE}" --disable=missing-docstring,protected-access "${file}" || exit_code=$?
53 |   done
54 |
55 |   return "${exit_code}"
56 | }
57 |
58 | py_test() {
59 |   local PY_BINARY="$1"
60 |   local exit_code=0
61 |
62 |   echo "===========Running Python test============"
63 |
64 |   for test_file in `find . -name '*test.py' -print`
65 |   do
66 |     echo "Testing ${test_file}"
67 |     ${PY_BINARY} "${test_file}" || exit_code=$?
68 |   done
69 |
70 |   return "${exit_code}"
71 | }
72 |
73 | py2_test() {
74 |   local PY_BINARY=$(which python2)
75 |   py_test "$PY_BINARY"
76 |   return $?
77 | }
78 |
79 | py3_test() {
80 |   local PY_BINARY=$(which python3)
81 |   py_test "$PY_BINARY"
82 |   return $?
83 | }
84 |
85 | test_result=0
86 |
87 | if [ "$#" -eq 0 ]; then
88 |   TESTS="lint py2_test py3_test"
89 | else
90 |   TESTS="$@"
91 | fi
92 |
93 | for t in ${TESTS}; do  # ${TESTS} must stay unquoted so it word-splits into the individual tests
94 |   ${t} || test_result=$?
95 | done
96 |
97 | exit "${test_result}"
98 |
--------------------------------------------------------------------------------
/official/wide_deep/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaoguan/deep-ctr-prediction/f8d83d6da2ee07158922474d11f444533ec6a7a3/official/wide_deep/__init__.py
--------------------------------------------------------------------------------
/official/wide_deep/census_main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | """Train DNN on census income dataset.""" 16 | 17 | import os 18 | 19 | from absl import app as absl_app 20 | from absl import flags 21 | import tensorflow as tf 22 | 23 | from official.utils.flags import core as flags_core 24 | from official.utils.logs import logger 25 | from official.wide_deep import census_dataset 26 | from official.wide_deep import wide_deep_run_loop 27 | 28 | 29 | def define_census_flags(): 30 | wide_deep_run_loop.define_wide_deep_flags() 31 | flags.adopt_module_key_flags(wide_deep_run_loop) 32 | flags_core.set_defaults(data_dir='/tmp/census_data', 33 | model_dir='/tmp/census_model', 34 | train_epochs=40, 35 | epochs_between_evals=2, 36 | inter_op_parallelism_threads=0, 37 | intra_op_parallelism_threads=0, 38 | batch_size=40) 39 | 40 | 41 | def build_estimator(model_dir, model_type, model_column_fn, inter_op, intra_op): 42 | """Build an estimator appropriate for the given model type.""" 43 | wide_columns, deep_columns = model_column_fn() 44 | hidden_units = [100, 75, 50, 25] 45 | 46 | # Create a tf.estimator.RunConfig to ensure the model is run on CPU, which 47 | # trains faster than GPU for this model. 48 | run_config = tf.estimator.RunConfig().replace( 49 | session_config=tf.ConfigProto(device_count={'GPU': 0}, 50 | inter_op_parallelism_threads=inter_op, 51 | intra_op_parallelism_threads=intra_op)) 52 | 53 | if model_type == 'wide': 54 | return tf.estimator.LinearClassifier( 55 | model_dir=model_dir, 56 | feature_columns=wide_columns, 57 | config=run_config) 58 | elif model_type == 'deep': 59 | return tf.estimator.DNNClassifier( 60 | model_dir=model_dir, 61 | feature_columns=deep_columns, 62 | hidden_units=hidden_units, 63 | config=run_config) 64 | else: 65 | return tf.estimator.DNNLinearCombinedClassifier( 66 | model_dir=model_dir, 67 | linear_feature_columns=wide_columns, 68 | dnn_feature_columns=deep_columns, 69 | dnn_hidden_units=hidden_units, 70 | config=run_config) 71 | 72 | 73 | def run_census(flags_obj): 74 | """Construct all necessary functions and call run_loop. 75 | 76 | Args: 77 | flags_obj: Object containing user specified flags. 78 | """ 79 | if flags_obj.download_if_missing: 80 | census_dataset.download(flags_obj.data_dir) 81 | 82 | train_file = os.path.join(flags_obj.data_dir, census_dataset.TRAINING_FILE) 83 | test_file = os.path.join(flags_obj.data_dir, census_dataset.EVAL_FILE) 84 | 85 | # Train and evaluate the model every `flags.epochs_between_evals` epochs. 
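# Note: the train input_fn below repeats the data for `epochs_between_evals`
# epochs with shuffling, while the eval input_fn makes a single unshuffled
# pass; both close over flags_obj so the Estimator can rebuild the dataset on
# every train/evaluate call.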
86 | def train_input_fn(): 87 | return census_dataset.input_fn( 88 | train_file, flags_obj.epochs_between_evals, True, flags_obj.batch_size) 89 | 90 | def eval_input_fn(): 91 | return census_dataset.input_fn(test_file, 1, False, flags_obj.batch_size) 92 | 93 | tensors_to_log = { 94 | 'average_loss': '{loss_prefix}head/truediv', 95 | 'loss': '{loss_prefix}head/weighted_loss/Sum' 96 | } 97 | 98 | wide_deep_run_loop.run_loop( 99 | name="Census Income", train_input_fn=train_input_fn, 100 | eval_input_fn=eval_input_fn, 101 | model_column_fn=census_dataset.build_model_columns, 102 | build_estimator_fn=build_estimator, 103 | flags_obj=flags_obj, 104 | tensors_to_log=tensors_to_log, 105 | early_stop=True) 106 | 107 | 108 | def main(_): 109 | with logger.benchmark_context(flags.FLAGS): 110 | run_census(flags.FLAGS) 111 | 112 | 113 | if __name__ == '__main__': 114 | tf.logging.set_verbosity(tf.logging.INFO) 115 | define_census_flags() 116 | absl_app.run(main) 117 | -------------------------------------------------------------------------------- /official/wide_deep/census_test.csv: -------------------------------------------------------------------------------- 1 | 39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,,,2174,0,40,,<=50K 2 | 50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,,,0,0,13,,<=50K 3 | 38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,,,0,0,40,,<=50K 4 | 53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,,,0,0,40,,<=50K 5 | 28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,,,0,0,40,,<=50K 6 | 37,Private,284582,Masters,14,Married-civ-spouse,Exec-managerial,Wife,,,0,0,40,,<=50K 7 | 49,Private,160187,9th,5,Married-spouse-absent,Other-service,Not-in-family,,,0,0,16,,<=50K 8 | 52,Self-emp-not-inc,209642,HS-grad,9,Married-civ-spouse,Exec-managerial,Husband,,,0,0,45,,>50K 9 | 31,Private,45781,Masters,14,Never-married,Prof-specialty,Not-in-family,,,14084,0,50,,>50K 10 | 42,Private,159449,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,,,5178,0,40,,>50K 11 | 37,Private,280464,Some-college,10,Married-civ-spouse,Exec-managerial,Husband,,,0,0,80,,>50K 12 | 30,State-gov,141297,Bachelors,13,Married-civ-spouse,Prof-specialty,Husband,,,0,0,40,,>50K 13 | 23,Private,122272,Bachelors,13,Never-married,Adm-clerical,Own-child,,,0,0,30,,<=50K 14 | 32,Private,205019,Assoc-acdm,12,Never-married,Sales,Not-in-family,,,0,0,50,,<=50K 15 | 40,Private,121772,Assoc-voc,11,Married-civ-spouse,Craft-repair,Husband,,,0,0,40,,>50K 16 | 34,Private,245487,7th-8th,4,Married-civ-spouse,Transport-moving,Husband,,,0,0,45,,<=50K 17 | 25,Self-emp-not-inc,176756,HS-grad,9,Never-married,Farming-fishing,Own-child,,,0,0,35,,<=50K 18 | 32,Private,186824,HS-grad,9,Never-married,Machine-op-inspct,Unmarried,,,0,0,40,,<=50K 19 | 38,Private,28887,11th,7,Married-civ-spouse,Sales,Husband,,,0,0,50,,<=50K 20 | 43,Self-emp-not-inc,292175,Masters,14,Divorced,Exec-managerial,Unmarried,,,0,0,45,,>50K 21 | 40,Private,193524,Doctorate,16,Married-civ-spouse,Prof-specialty,Husband,,,0,0,60,,>50K 22 | 56,Local-gov,216851,Bachelors,13,Married-civ-spouse,Tech-support,Husband,,,0,0,40,,>50K 23 | 54,?,180211,Some-college,10,Married-civ-spouse,?,Husband,,,0,0,60,,>50K 24 | 22,State-gov,311512,Some-college,10,Married-civ-spouse,Other-service,Husband,,,0,0,15,,<=50K 25 | 31,Private,84154,Some-college,10,Married-civ-spouse,Sales,Husband,,,0,0,38,,>50K 26 | 
57,Federal-gov,337895,Bachelors,13,Married-civ-spouse,Prof-specialty,Husband,,,0,0,40,,>50K 27 | 47,Private,51835,Prof-school,15,Married-civ-spouse,Prof-specialty,Wife,,,0,1902,60,,>50K 28 | 50,Federal-gov,251585,Bachelors,13,Divorced,Exec-managerial,Not-in-family,,,0,0,55,,>50K 29 | 25,Private,289980,HS-grad,9,Never-married,Handlers-cleaners,Not-in-family,,,0,0,35,,<=50K 30 | 42,Private,116632,Doctorate,16,Married-civ-spouse,Prof-specialty,Husband,,,0,0,45,,>50K 31 | -------------------------------------------------------------------------------- /official/wide_deep/census_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | 22 | import tensorflow as tf # pylint: disable=g-bad-import-order 23 | 24 | from official.utils.testing import integration 25 | from official.wide_deep import census_dataset 26 | from official.wide_deep import census_main 27 | from official.wide_deep import wide_deep_run_loop 28 | 29 | tf.logging.set_verbosity(tf.logging.ERROR) 30 | 31 | TEST_INPUT = ('18,Self-emp-not-inc,987,Bachelors,12,Married-civ-spouse,abc,' 32 | 'Husband,zyx,wvu,34,56,78,tsr,<=50K') 33 | 34 | TEST_INPUT_VALUES = { 35 | 'age': 18, 36 | 'education_num': 12, 37 | 'capital_gain': 34, 38 | 'capital_loss': 56, 39 | 'hours_per_week': 78, 40 | 'education': 'Bachelors', 41 | 'marital_status': 'Married-civ-spouse', 42 | 'relationship': 'Husband', 43 | 'workclass': 'Self-emp-not-inc', 44 | 'occupation': 'abc', 45 | } 46 | 47 | TEST_CSV = os.path.join(os.path.dirname(__file__), 'census_test.csv') 48 | 49 | 50 | class BaseTest(tf.test.TestCase): 51 | """Tests for Wide Deep model.""" 52 | 53 | @classmethod 54 | def setUpClass(cls): # pylint: disable=invalid-name 55 | super(BaseTest, cls).setUpClass() 56 | census_main.define_census_flags() 57 | 58 | def setUp(self): 59 | # Create temporary CSV file 60 | self.temp_dir = self.get_temp_dir() 61 | self.input_csv = os.path.join(self.temp_dir, 'test.csv') 62 | with tf.gfile.Open(self.input_csv, 'w') as temp_csv: 63 | temp_csv.write(TEST_INPUT) 64 | 65 | with tf.gfile.Open(TEST_CSV, "r") as temp_csv: 66 | test_csv_contents = temp_csv.read() 67 | 68 | # Used for end-to-end tests. 
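# (The same 30-row census_test.csv is written out as both the training and
# eval files below, so the end-to-end tests run without downloading data.)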
69 | for fname in [census_dataset.TRAINING_FILE, census_dataset.EVAL_FILE]: 70 | with tf.gfile.Open(os.path.join(self.temp_dir, fname), 'w') as test_csv: 71 | test_csv.write(test_csv_contents) 72 | 73 | def test_input_fn(self): 74 | dataset = census_dataset.input_fn(self.input_csv, 1, False, 1) 75 | features, labels = dataset.make_one_shot_iterator().get_next() 76 | 77 | with self.test_session() as sess: 78 | features, labels = sess.run((features, labels)) 79 | 80 | # Compare the two features dictionaries. 81 | for key in TEST_INPUT_VALUES: 82 | self.assertTrue(key in features) 83 | self.assertEqual(len(features[key]), 1) 84 | feature_value = features[key][0] 85 | 86 | # Convert from bytes to string for Python 3. 87 | if isinstance(feature_value, bytes): 88 | feature_value = feature_value.decode() 89 | 90 | self.assertEqual(TEST_INPUT_VALUES[key], feature_value) 91 | 92 | self.assertFalse(labels) 93 | 94 | def build_and_test_estimator(self, model_type): 95 | """Ensure that model trains and minimizes loss.""" 96 | model = census_main.build_estimator( 97 | self.temp_dir, model_type, 98 | model_column_fn=census_dataset.build_model_columns, 99 | inter_op=0, intra_op=0) 100 | 101 | # Train for 1 step to initialize model and evaluate initial loss 102 | def get_input_fn(num_epochs, shuffle, batch_size): 103 | def input_fn(): 104 | return census_dataset.input_fn( 105 | TEST_CSV, num_epochs=num_epochs, shuffle=shuffle, 106 | batch_size=batch_size) 107 | return input_fn 108 | 109 | model.train(input_fn=get_input_fn(1, True, 1), steps=1) 110 | initial_results = model.evaluate(input_fn=get_input_fn(1, False, 1)) 111 | 112 | # Train for 100 epochs at batch size 3 and evaluate final loss 113 | model.train(input_fn=get_input_fn(100, True, 3)) 114 | final_results = model.evaluate(input_fn=get_input_fn(1, False, 1)) 115 | 116 | print('%s initial results:' % model_type, initial_results) 117 | print('%s final results:' % model_type, final_results) 118 | 119 | # Ensure loss has decreased, while accuracy and both AUCs have increased. 
120 | self.assertLess(final_results['loss'], initial_results['loss']) 121 | self.assertGreater(final_results['auc'], initial_results['auc']) 122 | self.assertGreater(final_results['auc_precision_recall'], 123 | initial_results['auc_precision_recall']) 124 | self.assertGreater(final_results['accuracy'], initial_results['accuracy']) 125 | 126 | def test_wide_deep_estimator_training(self): 127 | self.build_and_test_estimator('wide_deep') 128 | 129 | def test_end_to_end_wide(self): 130 | integration.run_synthetic( 131 | main=census_main.main, tmp_root=self.get_temp_dir(), 132 | extra_flags=[ 133 | '--data_dir', self.get_temp_dir(), 134 | '--model_type', 'wide', 135 | '--download_if_missing=false' 136 | ], 137 | synth=False, max_train=None) 138 | 139 | def test_end_to_end_deep(self): 140 | integration.run_synthetic( 141 | main=census_main.main, tmp_root=self.get_temp_dir(), 142 | extra_flags=[ 143 | '--data_dir', self.get_temp_dir(), 144 | '--model_type', 'deep', 145 | '--download_if_missing=false' 146 | ], 147 | synth=False, max_train=None) 148 | 149 | def test_end_to_end_wide_deep(self): 150 | integration.run_synthetic( 151 | main=census_main.main, tmp_root=self.get_temp_dir(), 152 | extra_flags=[ 153 | '--data_dir', self.get_temp_dir(), 154 | '--model_type', 'wide_deep', 155 | '--download_if_missing=false' 156 | ], 157 | synth=False, max_train=None) 158 | 159 | 160 | if __name__ == '__main__': 161 | tf.test.main() 162 | -------------------------------------------------------------------------------- /official/wide_deep/movielens_main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | """Train a DNN on the MovieLens dataset."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import os
22 |
23 | from absl import app as absl_app
24 | from absl import flags
25 | import tensorflow as tf
26 |
27 | from official.datasets import movielens
28 | from official.utils.flags import core as flags_core
29 | from official.utils.logs import logger
30 | from official.wide_deep import movielens_dataset
31 | from official.wide_deep import wide_deep_run_loop
32 |
33 |
34 | def define_movie_flags():
35 |   """Define flags for movie dataset training."""
36 |   wide_deep_run_loop.define_wide_deep_flags()
37 |   flags.DEFINE_enum(
38 |       name="dataset", default=movielens.ML_1M,
39 |       enum_values=movielens.DATASETS, case_sensitive=False,
40 |       help=flags_core.help_wrap("Dataset to be trained and evaluated."))
41 |   flags.adopt_module_key_flags(wide_deep_run_loop)
42 |   flags_core.set_defaults(data_dir="/tmp/movielens-data/",
43 |                           model_dir='/tmp/movie_model',
44 |                           model_type="deep",
45 |                           train_epochs=50,
46 |                           epochs_between_evals=5,
47 |                           inter_op_parallelism_threads=0,
48 |                           intra_op_parallelism_threads=0,
49 |                           batch_size=256)
50 |
51 |   @flags.validator("stop_threshold",
52 |                    message="stop_threshold not supported for movielens model")
53 |   def _no_stop(stop_threshold):
54 |     return stop_threshold is None
55 |
56 |
57 | def build_estimator(model_dir, model_type, model_column_fn, inter_op, intra_op):
58 |   """Build an estimator appropriate for the given model type."""
59 |   if model_type != "deep":
60 |     raise NotImplementedError("movie dataset only supports `deep` model_type")
61 |   _, deep_columns = model_column_fn()
62 |   hidden_units = [256, 256, 256, 128]
63 |
64 |   run_config = tf.estimator.RunConfig().replace(
65 |       session_config=tf.ConfigProto(device_count={'GPU': 0},
66 |                                     inter_op_parallelism_threads=inter_op,
67 |                                     intra_op_parallelism_threads=intra_op))
68 |   return tf.estimator.DNNRegressor(
69 |       model_dir=model_dir,
70 |       feature_columns=deep_columns,
71 |       hidden_units=hidden_units,
72 |       optimizer=tf.train.AdamOptimizer(),
73 |       activation_fn=tf.nn.sigmoid,
74 |       dropout=0.3,
75 |       loss_reduction=tf.losses.Reduction.MEAN, config=run_config)  # config must be passed, or the session_config above is ignored
76 |
77 |
78 | def run_movie(flags_obj):
79 |   """Construct all necessary functions and call run_loop.
80 |
81 |   Args:
82 |     flags_obj: Object containing user specified flags.
83 | """ 84 | 85 | if flags_obj.download_if_missing: 86 | movielens.download(dataset=flags_obj.dataset, data_dir=flags_obj.data_dir) 87 | 88 | train_input_fn, eval_input_fn, model_column_fn = \ 89 | movielens_dataset.construct_input_fns( 90 | dataset=flags_obj.dataset, data_dir=flags_obj.data_dir, 91 | batch_size=flags_obj.batch_size, repeat=flags_obj.epochs_between_evals) 92 | 93 | tensors_to_log = { 94 | 'loss': '{loss_prefix}head/weighted_loss/value' 95 | } 96 | 97 | wide_deep_run_loop.run_loop( 98 | name="MovieLens", train_input_fn=train_input_fn, 99 | eval_input_fn=eval_input_fn, 100 | model_column_fn=model_column_fn, 101 | build_estimator_fn=build_estimator, 102 | flags_obj=flags_obj, 103 | tensors_to_log=tensors_to_log, 104 | early_stop=False) 105 | 106 | 107 | def main(_): 108 | with logger.benchmark_context(flags.FLAGS): 109 | run_movie(flags.FLAGS) 110 | 111 | 112 | if __name__ == '__main__': 113 | tf.logging.set_verbosity(tf.logging.INFO) 114 | define_movie_flags() 115 | absl_app.run(main) 116 | -------------------------------------------------------------------------------- /official/wide_deep/movielens_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | 22 | import numpy as np 23 | import tensorflow as tf # pylint: disable=g-bad-import-order 24 | 25 | from official.datasets import movielens 26 | from official.utils.testing import integration 27 | from official.wide_deep import movielens_dataset 28 | from official.wide_deep import movielens_main 29 | from official.wide_deep import wide_deep_run_loop 30 | 31 | tf.logging.set_verbosity(tf.logging.ERROR) 32 | 33 | 34 | TEST_INPUT_VALUES = { 35 | "genres": np.array( 36 | [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 37 | "user_id": [3], 38 | "item_id": [4], 39 | } 40 | 41 | TEST_ITEM_DATA = """item_id,titles,genres 42 | 1,Movie_1,Comedy|Romance 43 | 2,Movie_2,Adventure|Children's 44 | 3,Movie_3,Comedy|Drama 45 | 4,Movie_4,Comedy 46 | 5,Movie_5,Action|Crime|Thriller 47 | 6,Movie_6,Action 48 | 7,Movie_7,Action|Adventure|Thriller""" 49 | 50 | TEST_RATING_DATA = """user_id,item_id,rating,timestamp 51 | 1,2,5,978300760 52 | 1,3,3,978302109 53 | 1,6,3,978301968 54 | 2,1,4,978300275 55 | 2,7,5,978824291 56 | 3,1,3,978302268 57 | 3,4,5,978302039 58 | 3,5,5,978300719 59 | """ 60 | 61 | 62 | class BaseTest(tf.test.TestCase): 63 | """Tests for Wide Deep model.""" 64 | 65 | @classmethod 66 | def setUpClass(cls): # pylint: disable=invalid-name 67 | super(BaseTest, cls).setUpClass() 68 | movielens_main.define_movie_flags() 69 | 70 | def setUp(self): 71 | # Create temporary CSV file 72 | self.temp_dir = self.get_temp_dir() 73 | tf.gfile.MakeDirs(os.path.join(self.temp_dir, movielens.ML_1M)) 74 | 75 | self.ratings_csv = os.path.join( 76 | self.temp_dir, movielens.ML_1M, movielens.RATINGS_FILE) 77 | self.item_csv = os.path.join( 78 | self.temp_dir, movielens.ML_1M, movielens.MOVIES_FILE) 79 | 80 | with tf.gfile.Open(self.ratings_csv, "w") as f: 81 | f.write(TEST_RATING_DATA) 82 | 83 | with tf.gfile.Open(self.item_csv, "w") as f: 84 | f.write(TEST_ITEM_DATA) 85 | 86 | 87 | def test_input_fn(self): 88 | train_input_fn, _, _ = movielens_dataset.construct_input_fns( 89 | dataset=movielens.ML_1M, data_dir=self.temp_dir, batch_size=8, repeat=1) 90 | 91 | dataset = train_input_fn() 92 | features, labels = dataset.make_one_shot_iterator().get_next() 93 | 94 | with self.test_session() as sess: 95 | features, labels = sess.run((features, labels)) 96 | 97 | # Compare the two features dictionaries. 98 | for key in TEST_INPUT_VALUES: 99 | self.assertTrue(key in features) 100 | self.assertAllClose(TEST_INPUT_VALUES[key], features[key][0]) 101 | 102 | self.assertAllClose(labels[0], [1.0]) 103 | 104 | def test_end_to_end_deep(self): 105 | integration.run_synthetic( 106 | main=movielens_main.main, tmp_root=self.temp_dir, 107 | extra_flags=[ 108 | "--data_dir", self.temp_dir, 109 | "--download_if_missing=false", 110 | "--train_epochs", "1", 111 | "--epochs_between_evals", "1" 112 | ], 113 | synth=False, max_train=None) 114 | 115 | 116 | if __name__ == "__main__": 117 | tf.test.main() 118 | -------------------------------------------------------------------------------- /official/wide_deep/wide_deep_run_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Core run logic for TensorFlow Wide & Deep Tutorial using tf.estimator API.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import shutil 23 | 24 | from absl import app as absl_app 25 | from absl import flags 26 | import tensorflow as tf # pylint: disable=g-bad-import-order 27 | 28 | from official.utils.flags import core as flags_core 29 | from official.utils.logs import hooks_helper 30 | from official.utils.logs import logger 31 | from official.utils.misc import model_helpers 32 | 33 | 34 | LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'} 35 | 36 | 37 | def define_wide_deep_flags(): 38 | """Add supervised learning flags, as well as wide-deep model type.""" 39 | flags_core.define_base() 40 | flags_core.define_benchmark() 41 | flags_core.define_performance( 42 | num_parallel_calls=False, inter_op=True, intra_op=True, 43 | synthetic_data=False, max_train_steps=False, dtype=False, 44 | all_reduce_alg=False) 45 | 46 | flags.adopt_module_key_flags(flags_core) 47 | 48 | flags.DEFINE_enum( 49 | name="model_type", short_name="mt", default="wide_deep", 50 | enum_values=['wide', 'deep', 'wide_deep'], 51 | help="Select model topology.") 52 | flags.DEFINE_boolean( 53 | name="download_if_missing", default=True, help=flags_core.help_wrap( 54 | "Download data to data_dir if it is not already present.")) 55 | 56 | 57 | def export_model(model, model_type, export_dir, model_column_fn): 58 | """Export to SavedModel format. 59 | 60 | Args: 61 | model: Estimator object 62 | model_type: string indicating model type. "wide", "deep" or "wide_deep" 63 | export_dir: directory to export the model. 64 | model_column_fn: Function to generate model feature columns. 
65 | """ 66 | wide_columns, deep_columns = model_column_fn() 67 | if model_type == 'wide': 68 | columns = wide_columns 69 | elif model_type == 'deep': 70 | columns = deep_columns 71 | else: 72 | columns = wide_columns + deep_columns 73 | feature_spec = tf.feature_column.make_parse_example_spec(columns) 74 | example_input_fn = ( 75 | tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)) 76 | model.export_savedmodel(export_dir, example_input_fn, 77 | strip_default_attrs=True) 78 | 79 | 80 | def run_loop(name, train_input_fn, eval_input_fn, model_column_fn, 81 | build_estimator_fn, flags_obj, tensors_to_log, early_stop=False): 82 | """Define training loop.""" 83 | model_helpers.apply_clean(flags.FLAGS) 84 | model = build_estimator_fn( 85 | model_dir=flags_obj.model_dir, model_type=flags_obj.model_type, 86 | model_column_fn=model_column_fn, 87 | inter_op=flags_obj.inter_op_parallelism_threads, 88 | intra_op=flags_obj.intra_op_parallelism_threads) 89 | 90 | run_params = { 91 | 'batch_size': flags_obj.batch_size, 92 | 'train_epochs': flags_obj.train_epochs, 93 | 'model_type': flags_obj.model_type, 94 | } 95 | 96 | benchmark_logger = logger.get_benchmark_logger() 97 | benchmark_logger.log_run_info('wide_deep', name, run_params, 98 | test_id=flags_obj.benchmark_test_id) 99 | 100 | loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '') 101 | tensors_to_log = {k: v.format(loss_prefix=loss_prefix) 102 | for k, v in tensors_to_log.items()} 103 | train_hooks = hooks_helper.get_train_hooks( 104 | flags_obj.hooks, model_dir=flags_obj.model_dir, 105 | batch_size=flags_obj.batch_size, tensors_to_log=tensors_to_log) 106 | 107 | # Train and evaluate the model every `flags.epochs_between_evals` epochs. 108 | for n in range(flags_obj.train_epochs // flags_obj.epochs_between_evals): 109 | model.train(input_fn=train_input_fn, hooks=train_hooks) 110 | 111 | results = model.evaluate(input_fn=eval_input_fn) 112 | 113 | # Display evaluation metrics 114 | tf.logging.info('Results at epoch %d / %d', 115 | (n + 1) * flags_obj.epochs_between_evals, 116 | flags_obj.train_epochs) 117 | tf.logging.info('-' * 60) 118 | 119 | for key in sorted(results): 120 | tf.logging.info('%s: %s' % (key, results[key])) 121 | 122 | benchmark_logger.log_evaluation_result(results) 123 | 124 | if early_stop and model_helpers.past_stop_threshold( 125 | flags_obj.stop_threshold, results['accuracy']): 126 | break 127 | 128 | # Export the model 129 | if flags_obj.export_dir is not None: 130 | export_model(model, flags_obj.model_type, flags_obj.export_dir, 131 | model_column_fn) 132 | --------------------------------------------------------------------------------