├── .gitignore
├── LICENSE
├── README.md
├── images
│   ├── prediction_formula.png
│   └── tensorboard_wide_only_cmat.png
└── tflearn_wide_and_deep.py

/.gitignore:
--------------------------------------------------------------------------------
OLD
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2016 ichuang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
tflearn_wide_and_deep
=====================

Pedagogical example realization of wide & deep networks, using
[TensorFlow](https://www.tensorflow.org/) and
[TFLearn](http://tflearn.org/).

(Also see: [Pedagogical example of seq2seq RNN](https://github.com/ichuang/tflearn_seq2seq))

This is a re-implementation of the Google paper on [Wide & Deep
Learning for Recommender Systems](http://arxiv.org/abs/1606.07792),
combining a wide linear model and a deep feed-forward neural network
for binary classification (image from the TensorFlow tutorial):

![wide_and_deep](https://www.tensorflow.org/versions/r0.10/images/wide_n_deep.svg)

This example realization is based on TensorFlow's [Wide and Deep Learning Tutorial](https://www.tensorflow.org/versions/r0.10/tutorials/wide_and_deep/index.html),
but implemented in [TFLearn](http://tflearn.org/). Note that despite
the closeness of names, [TFLearn](http://tflearn.org/) is distinct
from TF.Learn (previously known as scikit flow, sometimes referred to
as
[tf.contrib.learn](https://www.tensorflow.org/versions/r0.9/tutorials/tflearn/index.html)).

This implementation explicitly presents the construction of the layers in
the deep part of the network, allows the layer architecture to be changed
directly, and permits customization of the methods used for regression and
optimization.

In contrast, the TF.Learn tutorial offers more sophistication, but
hides the layer architecture behind a black-box function,
[tf.contrib.learn.DNNLinearCombinedClassifier](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py#L41).

Basic Usage
===========

    usage: tflearn_wide_and_deep.py [-h] [--model_type MODEL_TYPE]
                                    [--run_name RUN_NAME]
                                    [--load_weights LOAD_WEIGHTS] [--n_epoch N_EPOCH]
                                    [--snapshot_step SNAPSHOT_STEP]
                                    [--wide_learning_rate WIDE_LEARNING_RATE]
                                    [--deep_learning_rate DEEP_LEARNING_RATE]
                                    [--verbose [VERBOSE]] [--noverbose]

    optional arguments:
      -h, --help            show this help message and exit
      --model_type MODEL_TYPE
                            Valid model types: {'wide', 'deep', 'wide+deep'}.
      --run_name RUN_NAME   name for this run (defaults to model type)
      --load_weights LOAD_WEIGHTS
                            filename with initial weights to load
      --n_epoch N_EPOCH     Number of training epoch steps
      --snapshot_step SNAPSHOT_STEP
                            Step number when snapshot (and validation testing) is
                            done
      --wide_learning_rate WIDE_LEARNING_RATE
                            learning rate for the wide part of the model
      --deep_learning_rate DEEP_LEARNING_RATE
                            learning rate for the deep part of the model
      --verbose [VERBOSE]   Verbose output
      --noverbose

Dataset
=======

The dataset is the same [Census income
data](https://archive.ics.uci.edu/ml/datasets/Census+Income) used in
TensorFlow's [Wide and Deep Learning
Tutorial](https://www.tensorflow.org/versions/r0.10/tutorials/wide_and_deep/index.html).
The goal is to predict whether a given individual has an income of
over 50,000 dollars, based on five continuous variables (`age`,
`education_num`, `capital_gain`, `capital_loss`, `hours_per_week`) and
nine categorical variables.

We simplify the approach used for the categorical variables, and do not
use sparse tensors or anything fancy; instead, for the sake of a
simple demonstration, we map category strings to integers using
pandas, then use embedding layers (whose weights are learned during
training). That part of the code is excerpted here:

```python
cc_input_var = {}
cc_embed_var = {}
flat_vars = []
for cc, cc_size in self.categorical_columns.items():
    cc_input_var[cc] = tflearn.input_data(shape=[None, 1], name="%s_in" % cc, dtype=tf.int32)
    # embedding layers only work on CPU! No GPU implementation in tensorflow, yet!
    cc_embed_var[cc] = tflearn.layers.embedding_ops.embedding(cc_input_var[cc], cc_size, 8, name="deep_%s_embed" % cc)
    flat_vars.append(tf.squeeze(cc_embed_var[cc], squeeze_dims=[1], name="%s_squeeze" % cc))
```

Notice how TFLearn provides input layers, which automatically construct placeholders for input data feeds.
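The string-to-integer mapping itself is done with plain pandas; here is a
minimal standalone sketch of that step (with hypothetical data; the actual
logic lives in `prepare_input_data` in `tflearn_wide_and_deep.py`):

```python
import pandas as pd

df = pd.DataFrame({"workclass": ["Private", "State-gov", "Private"]})
cc_values = sorted(df["workclass"].unique())
cc_map = dict(zip(cc_values, range(1, 1 + len(cc_values))))  # start at 1; save 0 for missing
df = df.replace({"workclass": cc_map})
print(df["workclass"].tolist())  # -> [1, 2, 1]
```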

Layer Architecture
==================

The wide model is realized using a single fully-connected layer, with no bias, and width equal to the number of inputs:

```python
network = tflearn.fully_connected(network, n_inputs, activation="linear", name="wide_linear", bias=False)  # x*W (no bias)
network = tf.reduce_sum(network, 1, name="reduce_sum")  # batched sum, to produce logits
network = tf.reshape(network, [-1, 1])
```

The deep model is realized with two fully-connected layers, with an
input constructed by concatenating the wide inputs with the embedded
categorical variables:

```python
n_nodes = [100, 50]
network = tf.concat(1, [wide_inputs] + flat_vars, name="deep_concat")
for k in range(len(n_nodes)):
    network = tflearn.fully_connected(network, n_nodes[k], activation="relu", name="deep_fc%d" % (k+1))
network = tflearn.fully_connected(network, 1, activation="linear", name="deep_fc_output", bias=False)
```

For the combined wide+deep model, the probability that the outcome is
"1" (versus "0") for input "x" is given by Equation 3 of the [Google research
paper](http://arxiv.org/abs/1606.07792), as

![prediction_formula](https://github.com/ichuang/tflearn_wide_and_deep/raw/master/images/prediction_formula.png "")

Note that the wide and deep models share a single central bias variable:

```python
with tf.variable_op_scope([wide_inputs], None, "cb_unit", reuse=False) as scope:
    central_bias = tflearn.variables.variable('central_bias', shape=[1],
                                              initializer=tf.constant_initializer(np.random.randn()),
                                              trainable=True, restore=True)
    tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/cb_unit', central_bias)
```

The wide and deep networks are combined according to the formula:

```python
wide_network = self.wide_model(wide_inputs, n_cc)
deep_network = self.deep_model(wide_inputs, n_cc)
network = tf.add(wide_network, deep_network)
network = tf.add(network, central_bias, name="add_central_bias")
```

Regression is done separately for the wide and deep networks, and for the central bias:

```python
trainable_vars = tf.trainable_variables()
tv_deep = [v for v in trainable_vars if v.name.startswith('deep_')]
tv_wide = [v for v in trainable_vars if v.name.startswith('wide_')]

wide_network_with_bias = tf.add(wide_network, central_bias, name="wide_with_bias")
tflearn.regression(wide_network_with_bias,
                   placeholder=Y_in,
                   optimizer='sgd',
                   loss='binary_crossentropy',
                   metric="binary_accuracy",
                   learning_rate=learning_rate[0],
                   validation_monitors=vmset,
                   trainable_vars=tv_wide,
                   op_name="wide_regression",
                   name="Y")

deep_network_with_bias = tf.add(deep_network, central_bias, name="deep_with_bias")
tflearn.regression(deep_network_with_bias,
                   placeholder=Y_in,
                   optimizer='adam',
                   loss='binary_crossentropy',
                   metric="binary_accuracy",
                   learning_rate=learning_rate[1],
                   trainable_vars=tv_deep,
                   op_name="deep_regression",
                   name="Y")

tflearn.regression(network,
                   placeholder=Y_in,
                   optimizer='adam',
                   loss='binary_crossentropy',
                   metric="binary_accuracy",
                   learning_rate=learning_rate[0],  # use wide learning rate
                   trainable_vars=[central_bias],
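                   # note: only the shared central bias is updated by this op;
                   # the wide and deep weights are trained by their own regression ops above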
op_name="central_bias_regression", 179 | name="Y") 180 | ``` 181 | 182 | and the confusion matrix is computed at each valiation step, using a 183 | validation monitor which pushes the result as a summary to 184 | TensorBoard: 185 | 186 | ```python 187 | with tf.name_scope('Monitors'): 188 | predictions = tf.cast(tf.greater(network, 0), tf.int64) 189 | Ybool = tf.cast(Y_in, tf.bool) 190 | pos = tf.boolean_mask(predictions, Ybool) 191 | neg = tf.boolean_mask(predictions, ~Ybool) 192 | psize = tf.cast(tf.shape(pos)[0], tf.int64) 193 | nsize = tf.cast(tf.shape(neg)[0], tf.int64) 194 | true_positive = tf.reduce_sum(pos, name="true_positive") 195 | false_negative = tf.sub(psize, true_positive, name="false_negative") 196 | false_positive = tf.reduce_sum(neg, name="false_positive") 197 | true_negative = tf.sub(nsize, false_positive, name="true_negative") 198 | overall_accuracy = tf.truediv(tf.add(true_positive, true_negative), tf.add(nsize, psize), name="overall_accuracy") 199 | vmset = [true_positive, true_negative, false_positive, false_negative, overall_accuracy] 200 | ``` 201 | 202 | Performance Comparisons 203 | ======================= 204 | 205 | How does wide-only compare with wide+deep, or, for that matter, with deep only? 206 | 207 | Wide Model 208 | ---------- 209 | 210 | Run this for the wide model: 211 | 212 | python tflearn_wide_and_deep.py --verbose --n_epoch=2000 --model_type=wide --snapshot_step=500 --wide_learning_rate=0.0001 213 | 214 | The tensorboard plots should show the accuracy and loss, as well as the four confusion matrix entries, e.g.: 215 | 216 | ![tensorboard_confusion_matrix](https://github.com/ichuang/tflearn_wide_and_deep/raw/master/images/tensorboard_wide_only_cmat.png "") 217 | 218 | The tail end of the console output should look something like this: 219 | 220 | ``` 221 | Training Step: 2000 | total loss: 0.82368 222 | | wide_regression | epoch: 2000 | loss: 0.82368 - binary_acc: 0.7489 | val_loss: 0.58739 - val_acc: 0.7813 -- iter: 32561/32561 223 | -- 224 | ============================================================ Evaluation 225 | logits: (16281,), min=-2.59761142731, max=116.775054932 226 | Actual IDV 227 | 0 12435 228 | 1 3846 229 | 230 | Predicted IDV 231 | 0 14726 232 | 1 1555 233 | 234 | Confusion matrix: 235 | actual 0 1 236 | predictions 237 | 0 11800 2926 238 | 1 635 920 239 | ``` 240 | 241 | Note that the accuracy is (920+11800)/16281 = 78.1% 242 | 243 | 244 | Deep Model 245 | ---------- 246 | 247 | Run this: 248 | 249 | python tflearn_wide_and_deep.py --verbose --n_epoch=2000 --model_type=deep --snapshot_step=250 --run_name="deep_run" --deep_learning_rate=0.001 250 | 251 | And the result should look something like: 252 | 253 | ``` 254 | Training Step: 2000 | total loss: 0.31951 255 | | deep_regression | epoch: 2000 | loss: 0.31951 - binary_acc: 0.8515 | val_loss: 0.31093 - val_acc: 0.8553 -- iter: 32561/32561 256 | -- 257 | ============================================================ Evaluation 258 | logits: (16281,), min=-12.0320196152, max=4.89985847473 259 | Actual IDV 260 | 0 12435 261 | 1 3846 262 | 263 | Predicted IDV 264 | 0 12891 265 | 1 3390 266 | 267 | 268 | Confusion matrix: 269 | actual 0 1 270 | predictions 271 | 0 11485 1406 272 | 1 950 2440 273 | 274 | ``` 275 | 276 | Giving a final accuracy of (2440+11485)/16281 = 85.53% 277 | 278 | Wide+Deep Model 279 | --------------- 280 | 281 | Now how does the combined model perform? 
Wide+Deep Model
---------------

Now how does the combined model perform? Run this:

    python tflearn_wide_and_deep.py --verbose --n_epoch=2000 --model_type=wide+deep --snapshot_step=250 \
           --run_name="wide+deep_run" --wide_learning_rate=0.00001 --deep_learning_rate=0.0001

And the output should give something like this:

```
Training Step: 2000  | total loss: 1.33436
| wide_regression | epoch: 1250 | loss: 0.56108 - binary_acc: 0.7800 | val_loss: 0.55753 - val_acc: 0.7780 -- iter: 32561/32561
| deep_regression | epoch: 1250 | loss: 0.30490 - binary_acc: 0.8576 | val_loss: 0.30492 - val_acc: 0.8576 -- iter: 32561/32561
| central_bias_regression | epoch: 1250 | loss: 0.46839 - binary_acc: 0.8158 | val_loss: 0.46368 - val_acc: 0.8176 -- iter: 32561/32561
--
============================================================ Evaluation
 logits: (16281,), min=-14.6657066345, max=74.5122756958
Actual IDV
0    12435
1     3846

Predicted IDV
0    15127
1     1154

Confusion matrix:
actual           0     1
predictions
0            12296  2831
1              139  1015
============================================================
```

(Note how TFLearn shows loss and accuracy numbers for all three regressions.)
The final accuracy for the combined wide+deep model is 81.76%.

It is striking, though, that the deep model evidently gives 85.76%
accuracy, whereas the wide model gives 77.8% accuracy, at least for
the run recorded above. The combined model's performance falls in between.

On more complicated datasets, perhaps the outcome would be different.

Testing
=======

Unit tests are provided, implemented using [pytest](http://doc.pytest.org/en/latest/). Run these using:

    py.test tflearn_wide_and_deep.py

Installation
============

* Requires TensorFlow 0.10 or later
* Requires TFLearn installed from GitHub (with [PR#308](https://github.com/tflearn/tflearn/pull/308)); see the example command below
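For instance, an illustrative install command (assuming pip and git are
available, and that the needed changes have been merged into TFLearn's
master branch):

    pip install git+https://github.com/tflearn/tflearn.git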
--------------------------------------------------------------------------------
/images/prediction_formula.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ichuang/tflearn_wide_and_deep/0a565548ee041391f70244d51a16f74de54ae3d5/images/prediction_formula.png
--------------------------------------------------------------------------------
/images/tensorboard_wide_only_cmat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ichuang/tflearn_wide_and_deep/0a565548ee041391f70244d51a16f74de54ae3d5/images/tensorboard_wide_only_cmat.png
--------------------------------------------------------------------------------
/tflearn_wide_and_deep.py:
--------------------------------------------------------------------------------
'''
Pedagogical example realization of wide & deep networks, using TensorFlow and TFLearn.

This is a re-implementation of http://arxiv.org/abs/1606.07792, using the combination
of a wide linear model and a deep feed-forward neural network, for binary classification.
This example realization is based on TensorFlow's TF.Learn tutorial
(https://www.tensorflow.org/versions/r0.10/tutorials/wide_and_deep/index.html),
but implemented in TFLearn. Note that despite the closeness of names, TFLearn is distinct
from TF.Learn (previously known as scikit flow).

This implementation explicitly presents the construction of the layers in the deep part
of the network, allows the layer architecture to be changed directly, and permits
customization of the methods used for regression and optimization.

In contrast, the TF.Learn tutorial offers more sophistication, but hides the layer
architecture behind a black-box function, tf.contrib.learn.DNNLinearCombinedClassifier.

See https://github.com/ichuang/tflearn_wide_and_deep for more about this example.
'''

from __future__ import division, print_function

import os
import sys
import argparse
import tflearn
import tempfile
import urllib

import numpy as np
import pandas as pd
import tensorflow as tf

#-----------------------------------------------------------------------------

COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num",
           "marital_status", "occupation", "relationship", "race", "gender",
           "capital_gain", "capital_loss", "hours_per_week", "native_country",
           "income_bracket"]
LABEL_COLUMN = "label"
CATEGORICAL_COLUMNS = {"workclass": 10, "education": 17, "marital_status": 8,
                       "occupation": 16, "relationship": 7, "race": 6,
                       "gender": 3, "native_country": 43, "age_binned": 14}
CONTINUOUS_COLUMNS = ["age", "education_num", "capital_gain", "capital_loss",
                      "hours_per_week"]

#-----------------------------------------------------------------------------

class TFLearnWideAndDeep(object):
    '''
    Wide and deep model, implemented using TFLearn
    '''
    AVAILABLE_MODELS = ["wide", "deep", "wide+deep"]
    def __init__(self, model_type="wide+deep", verbose=None, name=None, tensorboard_verbose=3,
                 wide_learning_rate=0.001, deep_learning_rate=0.001, checkpoints_dir=None):
        '''
        model_type = `str`: wide or deep or wide+deep
        verbose = `bool`
        name = `str` used for run_id (defaults to model_type)
        tensorboard_verbose = `int`: logging level for tensorboard (0, 1, 2, or 3)
        wide_learning_rate = `float`: defaults to 0.001
        deep_learning_rate = `float`: defaults to 0.001
        checkpoints_dir = `str`: where checkpoint files will be stored (defaults to "CHECKPOINTS")
        '''
        self.model_type = model_type or "wide+deep"
        assert self.model_type in self.AVAILABLE_MODELS
        self.verbose = verbose or 0
        self.tensorboard_verbose = tensorboard_verbose
        self.name = name or self.model_type  # name is used for the run_id
        self.data_columns = COLUMNS
        self.continuous_columns = CONTINUOUS_COLUMNS
        self.categorical_columns = CATEGORICAL_COLUMNS  # dict with category_name: category_size
        self.label_column = LABEL_COLUMN
        self.checkpoints_dir = checkpoints_dir or "CHECKPOINTS"
        if not os.path.exists(self.checkpoints_dir):
            os.mkdir(self.checkpoints_dir)
            print("Created checkpoints directory %s" % self.checkpoints_dir)
        self.build_model([wide_learning_rate, deep_learning_rate])

    def load_data(self, train_dfn="adult.data", test_dfn="adult.test"):
        '''
        Load data (use files offered in the Tensorflow wide_n_deep_tutorial)
        '''
        if not os.path.exists(train_dfn):
            urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data", train_dfn)
            print("Training data is downloaded to %s" % train_dfn)

        if not os.path.exists(test_dfn):
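            # note: the downloaded adult.test file begins with one non-data line,
            # which is skipped below via skiprows=1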
            urllib.urlretrieve("https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test", test_dfn)
            print("Test data is downloaded to %s" % test_dfn)

        self.train_data = pd.read_csv(train_dfn, names=COLUMNS, skipinitialspace=True)
        self.test_data = pd.read_csv(test_dfn, names=COLUMNS, skipinitialspace=True, skiprows=1)

        self.train_data[self.label_column] = (self.train_data["income_bracket"].apply(lambda x: ">50K" in x)).astype(int)
        self.test_data[self.label_column] = (self.test_data["income_bracket"].apply(lambda x: ">50K" in x)).astype(int)


    def build_model(self, learning_rate=[0.001, 0.01]):
        '''
        Model - wide and deep - built using tflearn
        '''
        n_cc = len(self.continuous_columns)
        n_categories = 1  # two categories: is_idv and is_not_idv
        input_shape = [None, n_cc]
        if self.verbose:
            print ("="*77 + " Model %s (type=%s)" % (self.name, self.model_type))
            print (" Input placeholder shape=%s" % str(input_shape))
        wide_inputs = tflearn.input_data(shape=input_shape, name="wide_X")
        if not isinstance(learning_rate, list):
            learning_rate = [learning_rate, learning_rate]  # wide, deep
        if self.verbose:
            print (" Learning rates (wide, deep)=%s" % learning_rate)

        with tf.name_scope("Y"):  # placeholder for target variable (i.e. trainY input)
            Y_in = tf.placeholder(shape=[None, 1], dtype=tf.float32, name="Y")

        with tf.variable_op_scope([wide_inputs], None, "cb_unit", reuse=False) as scope:
            central_bias = tflearn.variables.variable('central_bias', shape=[1],
                                                      initializer=tf.constant_initializer(np.random.randn()),
                                                      trainable=True, restore=True)
            tf.add_to_collection(tf.GraphKeys.LAYER_VARIABLES + '/cb_unit', central_bias)

        if 'wide' in self.model_type:
            wide_network = self.wide_model(wide_inputs, n_cc)
            network = wide_network
            wide_network_with_bias = tf.add(wide_network, central_bias, name="wide_with_bias")

        if 'deep' in self.model_type:
            deep_network = self.deep_model(wide_inputs, n_cc)
            deep_network_with_bias = tf.add(deep_network, central_bias, name="deep_with_bias")
            if 'wide' in self.model_type:
                network = tf.add(wide_network, deep_network)
                if self.verbose:
                    print ("Wide + deep model network %s" % network)
            else:
                network = deep_network

        network = tf.add(network, central_bias, name="add_central_bias")

        # add validation monitor summaries giving confusion matrix entries
        with tf.name_scope('Monitors'):
            predictions = tf.cast(tf.greater(network, 0), tf.int64)
            print ("predictions=%s" % predictions)
            Ybool = tf.cast(Y_in, tf.bool)
            print ("Ybool=%s" % Ybool)
            pos = tf.boolean_mask(predictions, Ybool)
            neg = tf.boolean_mask(predictions, ~Ybool)
            psize = tf.cast(tf.shape(pos)[0], tf.int64)
            nsize = tf.cast(tf.shape(neg)[0], tf.int64)
            true_positive = tf.reduce_sum(pos, name="true_positive")
            false_negative = tf.sub(psize, true_positive, name="false_negative")
            false_positive = tf.reduce_sum(neg, name="false_positive")
            true_negative = tf.sub(nsize, false_positive, name="true_negative")
            overall_accuracy = tf.truediv(tf.add(true_positive, true_negative), tf.add(nsize, psize), name="overall_accuracy")
            vmset = [true_positive, true_negative, false_positive, false_negative, overall_accuracy]

        trainable_vars = tf.trainable_variables()
        tv_deep = [v for v in trainable_vars if v.name.startswith('deep_')]
        tv_wide = [v for v in trainable_vars if v.name.startswith('wide_')]
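        # the layers were named with "wide_" / "deep_" prefixes above, so these
        # two lists cleanly partition the weights for the separate regression ops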

        if self.verbose:
            print ("DEEP trainable_vars")
            for v in tv_deep:
                print (" Variable %s: %s" % (v.name, v))
            print ("WIDE trainable_vars")
            for v in tv_wide:
                print (" Variable %s: %s" % (v.name, v))

        if 'wide' in self.model_type:
            if not 'deep' in self.model_type:
                tv_wide.append(central_bias)
            tflearn.regression(wide_network_with_bias,
                               placeholder=Y_in,
                               optimizer='sgd',
                               #loss='roc_auc_score',
                               loss='binary_crossentropy',
                               metric="accuracy",
                               learning_rate=learning_rate[0],
                               validation_monitors=vmset,
                               trainable_vars=tv_wide,
                               op_name="wide_regression",
                               name="Y")

        if 'deep' in self.model_type:
            if not 'wide' in self.model_type:
                tv_deep.append(central_bias)  # in deep-only mode, the central bias trains with the deep weights
            tflearn.regression(deep_network_with_bias,
                               placeholder=Y_in,
                               optimizer='adam',
                               #loss='roc_auc_score',
                               loss='binary_crossentropy',
                               metric="accuracy",
                               learning_rate=learning_rate[1],
                               validation_monitors=vmset if not 'wide' in self.model_type else None,
                               trainable_vars=tv_deep,
                               op_name="deep_regression",
                               name="Y")

        if self.model_type=='wide+deep':  # learn central bias separately for wide+deep
            tflearn.regression(network,
                               placeholder=Y_in,
                               optimizer='adam',
                               loss='binary_crossentropy',
                               metric="accuracy",
                               learning_rate=learning_rate[0],  # use wide learning rate
                               trainable_vars=[central_bias],
                               op_name="central_bias_regression",
                               name="Y")

        self.model = tflearn.DNN(network,
                                 tensorboard_verbose=self.tensorboard_verbose,
                                 max_checkpoints=5,
                                 checkpoint_path="%s/%s.tfl" % (self.checkpoints_dir, self.name),
                                 )

        if self.verbose:
            print ("Target variables:")
            for v in tf.get_collection(tf.GraphKeys.TARGETS):
                print (" variable %s: %s" % (v.name, v))

            print ("="*77)


    def deep_model(self, wide_inputs, n_inputs, n_nodes=[100, 50], use_dropout=False):
        '''
        Model - deep, i.e. two-layer fully connected network model
        '''
        cc_input_var = {}
        cc_embed_var = {}
        flat_vars = []
        if self.verbose:
            print ("--> deep model: %s categories, %d continuous" % (len(self.categorical_columns), n_inputs))
        for cc, cc_size in self.categorical_columns.items():
            cc_input_var[cc] = tflearn.input_data(shape=[None, 1], name="%s_in" % cc, dtype=tf.int32)
            # embedding layers only work on CPU! No GPU implementation in tensorflow, yet!
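            # each integer category id is mapped to a learned 8-dimensional vector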
            cc_embed_var[cc] = tflearn.layers.embedding_ops.embedding(cc_input_var[cc], cc_size, 8, name="deep_%s_embed" % cc)
            if self.verbose:
                print (" %s_embed = %s" % (cc, cc_embed_var[cc]))
            flat_vars.append(tf.squeeze(cc_embed_var[cc], squeeze_dims=[1], name="%s_squeeze" % cc))

        network = tf.concat(1, [wide_inputs] + flat_vars, name="deep_concat")
        for k in range(len(n_nodes)):
            network = tflearn.fully_connected(network, n_nodes[k], activation="relu", name="deep_fc%d" % (k+1))
            if use_dropout:
                network = tflearn.dropout(network, 0.5, name="deep_dropout%d" % (k+1))
        if self.verbose:
            print ("Deep model network before output %s" % network)
        network = tflearn.fully_connected(network, 1, activation="linear", name="deep_fc_output", bias=False)
        network = tf.reshape(network, [-1, 1])  # so that accuracy is binary_accuracy
        if self.verbose:
            print ("Deep model network %s" % network)
        return network

    def wide_model(self, inputs, n_inputs):
        '''
        Model - wide, i.e. normal linear model (for logistic regression)
        '''
        network = inputs
        # use fully_connected (instead of single_unit) because fc works properly with batches, whereas single_unit is 1D only
        network = tflearn.fully_connected(network, n_inputs, activation="linear", name="wide_linear", bias=False)  # x*W (no bias)
        network = tf.reduce_sum(network, 1, name="reduce_sum")  # batched sum, to produce logits
        network = tf.reshape(network, [-1, 1])  # so that accuracy is binary_accuracy
        if self.verbose:
            print ("Wide model network %s" % network)
        return network

    def prepare_input_data(self, input_data, name="", category_map=None):
        '''
        Prepare input data dicts
        '''
        print ("-"*40 + " Preparing %s" % name)
        X = input_data[self.continuous_columns].values.astype(np.float32)
        Y = input_data[self.label_column].values.astype(np.float32)
        Y = Y.reshape([-1, 1])
        if self.verbose:
            print (" Y shape=%s, X shape=%s" % (Y.shape, X.shape))

        X_dict = {"wide_X": X}

        if 'deep' in self.model_type:
            # map categorical value strings to integers
            td = input_data
            if category_map is None:
                category_map = {}
                for cc in self.categorical_columns:
                    if not cc in td.columns:
                        continue
                    cc_values = sorted(td[cc].unique())
                    cc_max = 1+len(cc_values)
                    cc_map = dict(zip(cc_values, range(1, cc_max)))  # start from 1 to avoid 0:0 mapping (save 0 for missing)
                    if self.verbose:
                        print (" category %s max=%s, map=%s" % (cc, cc_max, cc_map))
                    category_map[cc] = cc_map

            td = td.replace(category_map)

            # bin ages (cuts off extreme values)
            age_bins = [ 0, 12, 18, 25, 30, 35, 40, 45, 50, 55, 60, 65, 80, 65535 ]
            td['age_binned'] = pd.cut(td['age'], age_bins, labels=False)
            td = td.replace({'age_binned': {np.nan: 0}})
            print (" %d age bins: age bins = %s" % (len(age_bins), age_bins))

            X_dict.update({ ("%s_in" % cc): td[cc].values.astype(np.int32).reshape([-1, 1]) for cc in self.categorical_columns})

        Y_dict = {"Y": Y}
        if self.verbose:
            print ("-"*40)
        return X_dict, Y_dict, category_map


    def train(self, n_epoch=1000, snapshot_step=10, batch_size=None):

        self.X_dict, self.Y_dict, category_map = self.prepare_input_data(self.train_data, "train data")
        self.testX_dict, self.testY_dict, _ = self.prepare_input_data(self.test_data, "test data", category_map)
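        # note: the test data above is prepared with the training-set category_map,
        # so its category-to-integer ids match those seen during training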
        validation_batch_size = batch_size or self.testY_dict['Y'].shape[0]
        batch_size = batch_size or self.Y_dict['Y'].shape[0]

        print ("Input data shape = %s; output data shape=%s, batch_size=%s" % (str(self.X_dict['wide_X'].shape),
                                                                               str(self.Y_dict['Y'].shape),
                                                                               batch_size))
        print ("Test data shape = %s; output data shape=%s, validation_batch_size=%s" % (str(self.testX_dict['wide_X'].shape),
                                                                                         str(self.testY_dict['Y'].shape),
                                                                                         validation_batch_size))
        print ("="*60 + " Training")
        self.model.fit(self.X_dict,
                       self.Y_dict,
                       n_epoch=n_epoch,
                       validation_set=(self.testX_dict, self.testY_dict),
                       snapshot_step=snapshot_step,
                       batch_size=batch_size,
                       validation_batch_size=validation_batch_size,
                       show_metric=True,
                       snapshot_epoch=False,
                       shuffle=True,
                       run_id=self.name,
                       )

    def evaluate(self):
        logits = np.array(self.model.predict(self.testX_dict)).reshape([-1])
        print ("="*60 + " Evaluation")
        print (" logits: %s, min=%s, max=%s" % (logits.shape, logits.min(), logits.max()))
        probs = 1.0 / (1.0 + np.exp(-logits))
        y_pred = pd.Series((probs > 0.5).astype(np.int32))
        Y = pd.Series(self.testY_dict['Y'].astype(np.int32).reshape([-1]))
        self.confusion_matrix = self.output_confusion_matrix(Y, y_pred)
        print ("="*60)

    def output_confusion_matrix(self, y, y_pred):
        assert y.size == y_pred.size
        print("Actual IDV")
        print(y.value_counts())
        print("Predicted IDV")
        print(y_pred.value_counts())
        print()
        print("Confusion matrix:")
        cmat = pd.crosstab(y_pred, y, rownames=['predictions'], colnames=['actual'])
        print(cmat)
        sys.stdout.flush()
        return cmat

#-----------------------------------------------------------------------------

def CommandLine(args=None):
    '''
    Main command line. Accepts args, to allow for simple unit testing.
    '''
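    # note: when args is supplied (e.g. by the unit tests), FLAGS is reset and
    # overridden below, so CommandLine can be invoked repeatedly in one process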
    flags = tf.app.flags
    FLAGS = flags.FLAGS
    if args:
        FLAGS.__init__()
        FLAGS.__dict__.update(args)

    try:
        flags.DEFINE_string("model_type", "wide+deep", "Valid model types: {'wide', 'deep', 'wide+deep'}.")
        flags.DEFINE_string("run_name", None, "name for this run (defaults to model type)")
        flags.DEFINE_string("load_weights", None, "filename with initial weights to load")
        flags.DEFINE_string("checkpoints_dir", None, "name of directory where checkpoints should be saved")
        flags.DEFINE_integer("n_epoch", 200, "Number of training epoch steps")
        flags.DEFINE_integer("snapshot_step", 100, "Step number when snapshot (and validation testing) is done")
        flags.DEFINE_float("wide_learning_rate", 0.001, "learning rate for the wide part of the model")
        flags.DEFINE_float("deep_learning_rate", 0.001, "learning rate for the deep part of the model")
        flags.DEFINE_boolean("verbose", False, "Verbose output")
    except argparse.ArgumentError:
        pass  # so that CommandLine can be run more than once, for testing

    twad = TFLearnWideAndDeep(model_type=FLAGS.model_type, verbose=FLAGS.verbose,
                              name=FLAGS.run_name, wide_learning_rate=FLAGS.wide_learning_rate,
                              deep_learning_rate=FLAGS.deep_learning_rate,
                              checkpoints_dir=FLAGS.checkpoints_dir)
    twad.load_data()
    if FLAGS.load_weights:
        print ("Loading initial weights from %s" % FLAGS.load_weights)
        twad.model.load(FLAGS.load_weights)
    twad.train(n_epoch=FLAGS.n_epoch, snapshot_step=FLAGS.snapshot_step)
    twad.evaluate()
    return twad

#-----------------------------------------------------------------------------
# unit tests

def test_wide_and_deep():
    import glob
    tf.reset_default_graph()
    cdir = "test_checkpoints"
    if os.path.exists(cdir):
        os.system("rm -rf %s" % cdir)
    twad = CommandLine(args=dict(verbose=True, n_epoch=5, model_type="wide+deep", snapshot_step=5,
                                 wide_learning_rate=0.0001, checkpoints_dir=cdir))
    cfiles = glob.glob("%s/*.tfl-*" % cdir)
    print ("cfiles=%s" % cfiles)
    assert(len(cfiles))
    cm = twad.confusion_matrix.values.astype(np.float32)
    assert(cm[1][1])

def test_deep():
    import glob
    tf.reset_default_graph()
    cdir = "test_checkpoints"
    if os.path.exists(cdir):
        os.system("rm -rf %s" % cdir)
    twad = CommandLine(args=dict(verbose=True, n_epoch=5, model_type="deep", snapshot_step=5,
                                 wide_learning_rate=0.0001, checkpoints_dir=cdir))
    cfiles = glob.glob("%s/*.tfl-*" % cdir)
    print ("cfiles=%s" % cfiles)
    assert(len(cfiles))
    cm = twad.confusion_matrix.values.astype(np.float32)
    assert(cm[1][1])

def test_wide():
    import glob
    tf.reset_default_graph()
    cdir = "test_checkpoints"
    if os.path.exists(cdir):
        os.system("rm -rf %s" % cdir)
    twad = CommandLine(args=dict(verbose=True, n_epoch=5, model_type="wide", snapshot_step=5,
                                 wide_learning_rate=0.0001, checkpoints_dir=cdir))
    cfiles = glob.glob("%s/*.tfl-*" % cdir)
    print ("cfiles=%s" % cfiles)
    assert(len(cfiles))
    cm = twad.confusion_matrix.values.astype(np.float32)
    assert(cm[1][1])

#-----------------------------------------------------------------------------

if __name__=="__main__":
    CommandLine()
--------------------------------------------------------------------------------