├── .gitignore ├── DeepFM.py ├── LICENSE ├── README.md └── example ├── DataReader.py ├── README.md ├── __init__.py ├── config.py ├── data └── README.md ├── fig ├── DNN.png ├── DeepFM.png └── FM.png ├── main.py ├── metrics.py └── output └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | *.py[cod] 3 | *$py.class 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | build/ 11 | develop-eggs/ 12 | dist/ 13 | downloads/ 14 | eggs/ 15 | .eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | MANIFEST 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *.cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | .static_storage/ 55 | .media/ 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # Environments 84 | .env 85 | .venv 86 | env/ 87 | venv/ 88 | ENV/ 89 | env.bak/ 90 | venv.bak/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | .spyproject 95 | 96 | # Rope project settings 97 | .ropeproject 98 | 99 | # mkdocs documentation 100 | /site 101 | 102 | # mypy 103 | .mypy_cache/ 104 | 105 | # 106 | __pycache__/ 107 | example/__pycache__/ 108 | example/data/*.csv 109 | example/output/*.csv 110 | yellowfin.py 111 | -------------------------------------------------------------------------------- /DeepFM.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tensorflow implementation of DeepFM [1] 3 | 4 | Reference: 5 | [1] DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, 6 | Huifeng Guo, Ruiming Tang, Yunming Yey, Zhenguo Li, Xiuqiang He. 
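    Paper: IJCAI 2017, https://arxiv.org/abs/1703.04247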
7 | """ 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from sklearn.base import BaseEstimator, TransformerMixin 12 | from sklearn.metrics import roc_auc_score 13 | from time import time 14 | from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm 15 | from yellowfin import YFOptimizer 16 | 17 | 18 | class DeepFM(BaseEstimator, TransformerMixin): 19 | def __init__(self, feature_size, field_size, 20 | embedding_size=8, dropout_fm=[1.0, 1.0], 21 | deep_layers=[32, 32], dropout_deep=[0.5, 0.5, 0.5], 22 | deep_layers_activation=tf.nn.relu, 23 | epoch=10, batch_size=256, 24 | learning_rate=0.001, optimizer_type="adam", 25 | batch_norm=0, batch_norm_decay=0.995, 26 | verbose=False, random_seed=2016, 27 | use_fm=True, use_deep=True, 28 | loss_type="logloss", eval_metric=roc_auc_score, 29 | l2_reg=0.0, greater_is_better=True): 30 | assert (use_fm or use_deep) 31 | assert loss_type in ["logloss", "mse"], \ 32 | "loss_type can be either 'logloss' for classification task or 'mse' for regression task" 33 | 34 | self.feature_size = feature_size # denote as M, size of the feature dictionary 35 | self.field_size = field_size # denote as F, size of the feature fields 36 | self.embedding_size = embedding_size # denote as K, size of the feature embedding 37 | 38 | self.dropout_fm = dropout_fm 39 | self.deep_layers = deep_layers 40 | self.dropout_deep = dropout_deep 41 | self.deep_layers_activation = deep_layers_activation 42 | self.use_fm = use_fm 43 | self.use_deep = use_deep 44 | self.l2_reg = l2_reg 45 | 46 | self.epoch = epoch 47 | self.batch_size = batch_size 48 | self.learning_rate = learning_rate 49 | self.optimizer_type = optimizer_type 50 | 51 | self.batch_norm = batch_norm 52 | self.batch_norm_decay = batch_norm_decay 53 | 54 | self.verbose = verbose 55 | self.random_seed = random_seed 56 | self.loss_type = loss_type 57 | self.eval_metric = eval_metric 58 | self.greater_is_better = greater_is_better 59 | self.train_result, self.valid_result = [], [] 60 | 61 | self._init_graph() 62 | 63 | 64 | def _init_graph(self): 65 | self.graph = tf.Graph() 66 | with self.graph.as_default(): 67 | 68 | tf.set_random_seed(self.random_seed) 69 | 70 | self.feat_index = tf.placeholder(tf.int32, shape=[None, None], 71 | name="feat_index") # None * F 72 | self.feat_value = tf.placeholder(tf.float32, shape=[None, None], 73 | name="feat_value") # None * F 74 | self.label = tf.placeholder(tf.float32, shape=[None, 1], name="label") # None * 1 75 | self.dropout_keep_fm = tf.placeholder(tf.float32, shape=[None], name="dropout_keep_fm") 76 | self.dropout_keep_deep = tf.placeholder(tf.float32, shape=[None], name="dropout_keep_deep") 77 | self.train_phase = tf.placeholder(tf.bool, name="train_phase") 78 | 79 | self.weights = self._initialize_weights() 80 | 81 | # model 82 | self.embeddings = tf.nn.embedding_lookup(self.weights["feature_embeddings"], 83 | self.feat_index) # None * F * K 84 | feat_value = tf.reshape(self.feat_value, shape=[-1, self.field_size, 1]) 85 | self.embeddings = tf.multiply(self.embeddings, feat_value) 86 | 87 | # ---------- first order term ---------- 88 | self.y_first_order = tf.nn.embedding_lookup(self.weights["feature_bias"], self.feat_index) # None * F * 1 89 | self.y_first_order = tf.reduce_sum(tf.multiply(self.y_first_order, feat_value), 2) # None * F 90 | self.y_first_order = tf.nn.dropout(self.y_first_order, self.dropout_keep_fm[0]) # None * F 91 | 92 | # ---------- second order term --------------- 93 | # sum_square part 94 | self.summed_features_emb = 
tf.reduce_sum(self.embeddings, 1) # None * K 95 | self.summed_features_emb_square = tf.square(self.summed_features_emb) # None * K 96 | 97 | # square_sum part 98 | self.squared_features_emb = tf.square(self.embeddings) 99 | self.squared_sum_features_emb = tf.reduce_sum(self.squared_features_emb, 1) # None * K 100 | 101 | # second order 102 | self.y_second_order = 0.5 * tf.subtract(self.summed_features_emb_square, self.squared_sum_features_emb) # None * K 103 | self.y_second_order = tf.nn.dropout(self.y_second_order, self.dropout_keep_fm[1]) # None * K 104 | 105 | # ---------- Deep component ---------- 106 | self.y_deep = tf.reshape(self.embeddings, shape=[-1, self.field_size * self.embedding_size]) # None * (F*K) 107 | self.y_deep = tf.nn.dropout(self.y_deep, self.dropout_keep_deep[0]) 108 | for i in range(0, len(self.deep_layers)): 109 | self.y_deep = tf.add(tf.matmul(self.y_deep, self.weights["layer_%d" %i]), self.weights["bias_%d"%i]) # None * layer[i] * 1 110 | if self.batch_norm: 111 | self.y_deep = self.batch_norm_layer(self.y_deep, train_phase=self.train_phase, scope_bn="bn_%d" %i) # None * layer[i] * 1 112 | self.y_deep = self.deep_layers_activation(self.y_deep) 113 | self.y_deep = tf.nn.dropout(self.y_deep, self.dropout_keep_deep[1+i]) # dropout at each Deep layer 114 | 115 | # ---------- DeepFM ---------- 116 | if self.use_fm and self.use_deep: 117 | concat_input = tf.concat([self.y_first_order, self.y_second_order, self.y_deep], axis=1) 118 | elif self.use_fm: 119 | concat_input = tf.concat([self.y_first_order, self.y_second_order], axis=1) 120 | elif self.use_deep: 121 | concat_input = self.y_deep 122 | self.out = tf.add(tf.matmul(concat_input, self.weights["concat_projection"]), self.weights["concat_bias"]) 123 | 124 | # loss 125 | if self.loss_type == "logloss": 126 | self.out = tf.nn.sigmoid(self.out) 127 | self.loss = tf.losses.log_loss(self.label, self.out) 128 | elif self.loss_type == "mse": 129 | self.loss = tf.nn.l2_loss(tf.subtract(self.label, self.out)) 130 | # l2 regularization on weights 131 | if self.l2_reg > 0: 132 | self.loss += tf.contrib.layers.l2_regularizer( 133 | self.l2_reg)(self.weights["concat_projection"]) 134 | if self.use_deep: 135 | for i in range(len(self.deep_layers)): 136 | self.loss += tf.contrib.layers.l2_regularizer( 137 | self.l2_reg)(self.weights["layer_%d"%i]) 138 | 139 | # optimizer 140 | if self.optimizer_type == "adam": 141 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.9, beta2=0.999, 142 | epsilon=1e-8).minimize(self.loss) 143 | elif self.optimizer_type == "adagrad": 144 | self.optimizer = tf.train.AdagradOptimizer(learning_rate=self.learning_rate, 145 | initial_accumulator_value=1e-8).minimize(self.loss) 146 | elif self.optimizer_type == "gd": 147 | self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate).minimize(self.loss) 148 | elif self.optimizer_type == "momentum": 149 | self.optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.95).minimize( 150 | self.loss) 151 | elif self.optimizer_type == "yellowfin": 152 | self.optimizer = YFOptimizer(learning_rate=self.learning_rate, momentum=0.0).minimize( 153 | self.loss) 154 | 155 | # init 156 | self.saver = tf.train.Saver() 157 | init = tf.global_variables_initializer() 158 | self.sess = self._init_session() 159 | self.sess.run(init) 160 | 161 | # number of params 162 | total_parameters = 0 163 | for variable in self.weights.values(): 164 | shape = variable.get_shape() 165 | variable_parameters 
= 1 166 | for dim in shape: 167 | variable_parameters *= dim.value 168 | total_parameters += variable_parameters 169 | if self.verbose > 0: 170 | print("#params: %d" % total_parameters) 171 | 172 | 173 | def _init_session(self): 174 | config = tf.ConfigProto(device_count={"gpu": 0}) 175 | config.gpu_options.allow_growth = True 176 | return tf.Session(config=config) 177 | 178 | 179 | def _initialize_weights(self): 180 | weights = dict() 181 | 182 | # embeddings 183 | weights["feature_embeddings"] = tf.Variable( 184 | tf.random_normal([self.feature_size, self.embedding_size], 0.0, 0.01), 185 | name="feature_embeddings") # feature_size * K 186 | weights["feature_bias"] = tf.Variable( 187 | tf.random_uniform([self.feature_size, 1], 0.0, 1.0), name="feature_bias") # feature_size * 1 188 | 189 | # deep layers 190 | num_layer = len(self.deep_layers) 191 | input_size = self.field_size * self.embedding_size 192 | glorot = np.sqrt(2.0 / (input_size + self.deep_layers[0])) 193 | weights["layer_0"] = tf.Variable( 194 | np.random.normal(loc=0, scale=glorot, size=(input_size, self.deep_layers[0])), dtype=np.float32) 195 | weights["bias_0"] = tf.Variable(np.random.normal(loc=0, scale=glorot, size=(1, self.deep_layers[0])), 196 | dtype=np.float32) # 1 * layers[0] 197 | for i in range(1, num_layer): 198 | glorot = np.sqrt(2.0 / (self.deep_layers[i-1] + self.deep_layers[i])) 199 | weights["layer_%d" % i] = tf.Variable( 200 | np.random.normal(loc=0, scale=glorot, size=(self.deep_layers[i-1], self.deep_layers[i])), 201 | dtype=np.float32) # layers[i-1] * layers[i] 202 | weights["bias_%d" % i] = tf.Variable( 203 | np.random.normal(loc=0, scale=glorot, size=(1, self.deep_layers[i])), 204 | dtype=np.float32) # 1 * layer[i] 205 | 206 | # final concat projection layer 207 | if self.use_fm and self.use_deep: 208 | input_size = self.field_size + self.embedding_size + self.deep_layers[-1] 209 | elif self.use_fm: 210 | input_size = self.field_size + self.embedding_size 211 | elif self.use_deep: 212 | input_size = self.deep_layers[-1] 213 | glorot = np.sqrt(2.0 / (input_size + 1)) 214 | weights["concat_projection"] = tf.Variable( 215 | np.random.normal(loc=0, scale=glorot, size=(input_size, 1)), 216 | dtype=np.float32) # layers[i-1]*layers[i] 217 | weights["concat_bias"] = tf.Variable(tf.constant(0.01), dtype=np.float32) 218 | 219 | return weights 220 | 221 | 222 | def batch_norm_layer(self, x, train_phase, scope_bn): 223 | bn_train = batch_norm(x, decay=self.batch_norm_decay, center=True, scale=True, updates_collections=None, 224 | is_training=True, reuse=None, trainable=True, scope=scope_bn) 225 | bn_inference = batch_norm(x, decay=self.batch_norm_decay, center=True, scale=True, updates_collections=None, 226 | is_training=False, reuse=True, trainable=True, scope=scope_bn) 227 | z = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference) 228 | return z 229 | 230 | 231 | def get_batch(self, Xi, Xv, y, batch_size, index): 232 | start = index * batch_size 233 | end = (index+1) * batch_size 234 | end = end if end < len(y) else len(y) 235 | return Xi[start:end], Xv[start:end], [[y_] for y_ in y[start:end]] 236 | 237 | 238 | # shuffle three lists simutaneously 239 | def shuffle_in_unison_scary(self, a, b, c): 240 | rng_state = np.random.get_state() 241 | np.random.shuffle(a) 242 | np.random.set_state(rng_state) 243 | np.random.shuffle(b) 244 | np.random.set_state(rng_state) 245 | np.random.shuffle(c) 246 | 247 | 248 | def fit_on_batch(self, Xi, Xv, y): 249 | feed_dict = {self.feat_index: Xi, 250 | self.feat_value: Xv, 
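# (the remaining feed entries supply the labels, the configured dropout keep probabilities, and the batch-norm train/inference flag)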
251 | self.label: y, 252 | self.dropout_keep_fm: self.dropout_fm, 253 | self.dropout_keep_deep: self.dropout_deep, 254 | self.train_phase: True} 255 | loss, opt = self.sess.run((self.loss, self.optimizer), feed_dict=feed_dict) 256 | return loss 257 | 258 | 259 | def fit(self, Xi_train, Xv_train, y_train, 260 | Xi_valid=None, Xv_valid=None, y_valid=None, 261 | early_stopping=False, refit=False): 262 | """ 263 | :param Xi_train: [[ind1_1, ind1_2, ...], [ind2_1, ind2_2, ...], ..., [indi_1, indi_2, ..., indi_j, ...], ...] 264 | indi_j is the feature index of feature field j of sample i in the training set 265 | :param Xv_train: [[val1_1, val1_2, ...], [val2_1, val2_2, ...], ..., [vali_1, vali_2, ..., vali_j, ...], ...] 266 | vali_j is the feature value of feature field j of sample i in the training set 267 | vali_j can be either binary (1/0, for binary/categorical features) or float (e.g., 10.24, for numerical features) 268 | :param y_train: label of each sample in the training set 269 | :param Xi_valid: list of list of feature indices of each sample in the validation set 270 | :param Xv_valid: list of list of feature values of each sample in the validation set 271 | :param y_valid: label of each sample in the validation set 272 | :param early_stopping: perform early stopping or not 273 | :param refit: refit the model on the train+valid dataset or not 274 | :return: None 275 | """ 276 | has_valid = Xv_valid is not None 277 | for epoch in range(self.epoch): 278 | t1 = time() 279 | self.shuffle_in_unison_scary(Xi_train, Xv_train, y_train) 280 | total_batch = int(len(y_train) / self.batch_size) 281 | for i in range(total_batch): 282 | Xi_batch, Xv_batch, y_batch = self.get_batch(Xi_train, Xv_train, y_train, self.batch_size, i) 283 | self.fit_on_batch(Xi_batch, Xv_batch, y_batch) 284 | 285 | # evaluate training and validation datasets 286 | train_result = self.evaluate(Xi_train, Xv_train, y_train) 287 | self.train_result.append(train_result) 288 | if has_valid: 289 | valid_result = self.evaluate(Xi_valid, Xv_valid, y_valid) 290 | self.valid_result.append(valid_result) 291 | if self.verbose > 0 and epoch % self.verbose == 0: 292 | if has_valid: 293 | print("[%d] train-result=%.4f, valid-result=%.4f [%.1f s]" 294 | % (epoch + 1, train_result, valid_result, time() - t1)) 295 | else: 296 | print("[%d] train-result=%.4f [%.1f s]" 297 | % (epoch + 1, train_result, time() - t1)) 298 | if has_valid and early_stopping and self.training_termination(self.valid_result): 299 | break 300 | 301 | # fit a few more epoch on train+valid until result reaches the best_train_score 302 | if has_valid and refit: 303 | if self.greater_is_better: 304 | best_valid_score = max(self.valid_result) 305 | else: 306 | best_valid_score = min(self.valid_result) 307 | best_epoch = self.valid_result.index(best_valid_score) 308 | best_train_score = self.train_result[best_epoch] 309 | Xi_train = Xi_train + Xi_valid 310 | Xv_train = Xv_train + Xv_valid 311 | y_train = y_train + y_valid 312 | for epoch in range(100): 313 | self.shuffle_in_unison_scary(Xi_train, Xv_train, y_train) 314 | total_batch = int(len(y_train) / self.batch_size) 315 | for i in range(total_batch): 316 | Xi_batch, Xv_batch, y_batch = self.get_batch(Xi_train, Xv_train, y_train, 317 | self.batch_size, i) 318 | self.fit_on_batch(Xi_batch, Xv_batch, y_batch) 319 | # check 320 | train_result = self.evaluate(Xi_train, Xv_train, y_train) 321 | if abs(train_result - best_train_score) < 0.001 or \ 322 | (self.greater_is_better and train_result > best_train_score) or \ 323 | 
((not self.greater_is_better) and train_result < best_train_score): 324 | break 325 | 326 | 327 | def training_termination(self, valid_result): 328 | if len(valid_result) > 5: 329 | if self.greater_is_better: 330 | if valid_result[-1] < valid_result[-2] and \ 331 | valid_result[-2] < valid_result[-3] and \ 332 | valid_result[-3] < valid_result[-4] and \ 333 | valid_result[-4] < valid_result[-5]: 334 | return True 335 | else: 336 | if valid_result[-1] > valid_result[-2] and \ 337 | valid_result[-2] > valid_result[-3] and \ 338 | valid_result[-3] > valid_result[-4] and \ 339 | valid_result[-4] > valid_result[-5]: 340 | return True 341 | return False 342 | 343 | 344 | def predict(self, Xi, Xv): 345 | """ 346 | :param Xi: list of list of feature indices of each sample in the dataset 347 | :param Xv: list of list of feature values of each sample in the dataset 348 | :return: predicted probability of each sample 349 | """ 350 | # dummy y 351 | dummy_y = [1] * len(Xi) 352 | batch_index = 0 353 | Xi_batch, Xv_batch, y_batch = self.get_batch(Xi, Xv, dummy_y, self.batch_size, batch_index) 354 | y_pred = None 355 | while len(Xi_batch) > 0: 356 | num_batch = len(y_batch) 357 | feed_dict = {self.feat_index: Xi_batch, 358 | self.feat_value: Xv_batch, 359 | self.label: y_batch, 360 | self.dropout_keep_fm: [1.0] * len(self.dropout_fm), 361 | self.dropout_keep_deep: [1.0] * len(self.dropout_deep), 362 | self.train_phase: False} 363 | batch_out = self.sess.run(self.out, feed_dict=feed_dict) 364 | 365 | if batch_index == 0: 366 | y_pred = np.reshape(batch_out, (num_batch,)) 367 | else: 368 | y_pred = np.concatenate((y_pred, np.reshape(batch_out, (num_batch,)))) 369 | 370 | batch_index += 1 371 | Xi_batch, Xv_batch, y_batch = self.get_batch(Xi, Xv, dummy_y, self.batch_size, batch_index) 372 | 373 | return y_pred 374 | 375 | 376 | def evaluate(self, Xi, Xv, y): 377 | """ 378 | :param Xi: list of list of feature indices of each sample in the dataset 379 | :param Xv: list of list of feature values of each sample in the dataset 380 | :param y: label of each sample in the dataset 381 | :return: metric of the evaluation 382 | """ 383 | y_pred = self.predict(Xi, Xv) 384 | return self.eval_metric(y, y_pred) 385 | 386 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Chenglong Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

# tensorflow-DeepFM

This project includes a TensorFlow implementation of DeepFM [1].

# NEWS
- A modified version of DeepFM was used to win 4th place in the [Mercari Price Suggestion Challenge on Kaggle](https://www.kaggle.com/c/mercari-price-suggestion-challenge). See the slides [here](https://github.com/ChenglongChen/tensorflow-XNN/blob/master/doc/Mercari_Price_Suggesion_Competition_ChenglongChen_4th_Place.pdf) for how we deal with fields containing sequences and how we incorporate various FM components into the deep model.

# Usage
## Input Format
This implementation requires the input data in the following format:
- **Xi**: *[[ind1_1, ind1_2, ...], [ind2_1, ind2_2, ...], ..., [indi_1, indi_2, ..., indi_j, ...], ...]*
    - *indi_j* is the feature index of feature field *j* of sample *i* in the dataset
- **Xv**: *[[val1_1, val1_2, ...], [val2_1, val2_2, ...], ..., [vali_1, vali_2, ..., vali_j, ...], ...]*
    - *vali_j* is the feature value of feature field *j* of sample *i* in the dataset
    - *vali_j* can be either binary (1/0, for binary/categorical features) or float (e.g., 10.24, for numerical features)
- **y**: target of each sample in the dataset (1/0 for classification, a numeric value for regression)

For instance, following the index scheme in `example/DataReader.py`, a sample with one categorical field (whose level maps to index 5) and one numerical field (assigned index 42, with value 10.24) is encoded as *Xi = [5, 42]* and *Xv = [1, 10.24]*.

Please see `example/DataReader.py` for an example of how to prepare the data in the required format for DeepFM.

## Init and train a model
```
import tensorflow as tf
from sklearn.metrics import roc_auc_score

from DeepFM import DeepFM

# params
dfm_params = {
    "use_fm": True,
    "use_deep": True,
    "embedding_size": 8,
    "dropout_fm": [1.0, 1.0],
    "deep_layers": [32, 32],
    "dropout_deep": [0.5, 0.5, 0.5],
    "deep_layers_activation": tf.nn.relu,
    "epoch": 30,
    "batch_size": 1024,
    "learning_rate": 0.001,
    "optimizer_type": "adam",
    "batch_norm": 1,
    "batch_norm_decay": 0.995,
    "l2_reg": 0.01,
    "verbose": True,
    "eval_metric": roc_auc_score,
    "random_seed": 2017
    # "feature_size" and "field_size" must also be supplied;
    # they are computed from the data (see example/main.py)
}

# prepare training and validation data in the required format
Xi_train, Xv_train, y_train = prepare(...)
Xi_valid, Xv_valid, y_valid = prepare(...)

# init a DeepFM model
dfm = DeepFM(**dfm_params)

# fit a DeepFM model
dfm.fit(Xi_train, Xv_train, y_train)

# make prediction
dfm.predict(Xi_valid, Xv_valid)

# evaluate a trained model
dfm.evaluate(Xi_valid, Xv_valid, y_valid)
```

You can use early stopping during training as follows:
```
dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True)
```

You can refit the model on the whole training and validation set as follows:
```
dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True, refit=True)
```

You can use the FM or DNN part only by setting the parameter `use_fm` or `use_deep` to `False`.

## Regression
This implementation also supports regression tasks. To use DeepFM for regression, set `loss_type` to `"mse"`. Accordingly, you should use an evaluation metric suited to regression, e.g., MSE or MAE.
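As a minimal sketch of a regression setup (assuming `feature_size` and `field_size` have already been computed from your data, e.g., with `example/DataReader.py`):
```
from sklearn.metrics import mean_squared_error

from DeepFM import DeepFM

dfm = DeepFM(feature_size=feature_size, field_size=field_size,
             loss_type="mse",
             eval_metric=mean_squared_error,
             greater_is_better=False)  # lower MSE is better
dfm.fit(Xi_train, Xv_train, y_train)
y_pred = dfm.predict(Xi_valid, Xv_valid)
```
Note that `greater_is_better=False` keeps early stopping and `refit` consistent with an error metric that should decrease.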
# Example
Folder `example` includes an example usage of DeepFM/FM/DNN models for [Porto Seguro's Safe Driver Prediction competition on Kaggle](https://www.kaggle.com/c/porto-seguro-safe-driver-prediction).

Please download the data from the competition website and put it into the `example/data` folder.

To train a DeepFM model for this dataset, run

```
$ cd example
$ python main.py
```
Please see `example/DataReader.py` for how to parse the raw dataset into the required format for DeepFM.

## Performance

### DeepFM

![dfm](example/fig/DeepFM.png)

### FM

![fm](example/fig/FM.png)

### DNN

![dnn](example/fig/DNN.png)

## Some tips
- You should tune the parameters for each model in order to get reasonable performance.
- You can also try to ensemble these models or ensemble them with other models (e.g., XGBoost or LightGBM).

# Reference
[1] *DeepFM: A Factorization-Machine based Neural Network for CTR Prediction*, Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. IJCAI 2017.

# Acknowledgments
This project draws inspiration from the following projects:
- He Xiangnan's [neural_factorization_machine](https://github.com/hexiangnan/neural_factorization_machine)
- Jian Zhang's [YellowFin](https://github.com/JianGoForIt/YellowFin) (the YellowFin optimizer is taken from here)

# License
MIT
-------------------------------------------------------------------------------- /example/DataReader.py: --------------------------------------------------------------------------------
1 | """
2 | A data parser for the Porto Seguro's Safe Driver Prediction competition dataset.
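    Categorical fields are mapped to one index per distinct value (with feature value 1), while each numeric field keeps a single index and uses the raw number as its feature value; see FeatureDictionary.gen_feat_dict and DataParser.parse below.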
3 | URL: https://www.kaggle.com/c/porto-seguro-safe-driver-prediction 4 | """ 5 | import pandas as pd 6 | 7 | 8 | class FeatureDictionary(object): 9 | def __init__(self, trainfile=None, testfile=None, 10 | dfTrain=None, dfTest=None, numeric_cols=[], ignore_cols=[]): 11 | assert not ((trainfile is None) and (dfTrain is None)), "trainfile or dfTrain at least one is set" 12 | assert not ((trainfile is not None) and (dfTrain is not None)), "only one can be set" 13 | assert not ((testfile is None) and (dfTest is None)), "testfile or dfTest at least one is set" 14 | assert not ((testfile is not None) and (dfTest is not None)), "only one can be set" 15 | self.trainfile = trainfile 16 | self.testfile = testfile 17 | self.dfTrain = dfTrain 18 | self.dfTest = dfTest 19 | self.numeric_cols = numeric_cols 20 | self.ignore_cols = ignore_cols 21 | self.gen_feat_dict() 22 | 23 | def gen_feat_dict(self): 24 | if self.dfTrain is None: 25 | dfTrain = pd.read_csv(self.trainfile) 26 | else: 27 | dfTrain = self.dfTrain 28 | if self.dfTest is None: 29 | dfTest = pd.read_csv(self.testfile) 30 | else: 31 | dfTest = self.dfTest 32 | df = pd.concat([dfTrain, dfTest]) 33 | self.feat_dict = {} 34 | tc = 0 35 | for col in df.columns: 36 | if col in self.ignore_cols: 37 | continue 38 | if col in self.numeric_cols: 39 | # map to a single index 40 | self.feat_dict[col] = tc 41 | tc += 1 42 | else: 43 | us = df[col].unique() 44 | self.feat_dict[col] = dict(zip(us, range(tc, len(us)+tc))) 45 | tc += len(us) 46 | self.feat_dim = tc 47 | 48 | 49 | class DataParser(object): 50 | def __init__(self, feat_dict): 51 | self.feat_dict = feat_dict 52 | 53 | def parse(self, infile=None, df=None, has_label=False): 54 | assert not ((infile is None) and (df is None)), "infile or df at least one is set" 55 | assert not ((infile is not None) and (df is not None)), "only one can be set" 56 | if infile is None: 57 | dfi = df.copy() 58 | else: 59 | dfi = pd.read_csv(infile) 60 | if has_label: 61 | y = dfi["target"].values.tolist() 62 | dfi.drop(["id", "target"], axis=1, inplace=True) 63 | else: 64 | ids = dfi["id"].values.tolist() 65 | dfi.drop(["id"], axis=1, inplace=True) 66 | # dfi for feature index 67 | # dfv for feature value which can be either binary (1/0) or float (e.g., 10.24) 68 | dfv = dfi.copy() 69 | for col in dfi.columns: 70 | if col in self.feat_dict.ignore_cols: 71 | dfi.drop(col, axis=1, inplace=True) 72 | dfv.drop(col, axis=1, inplace=True) 73 | continue 74 | if col in self.feat_dict.numeric_cols: 75 | dfi[col] = self.feat_dict.feat_dict[col] 76 | else: 77 | dfi[col] = dfi[col].map(self.feat_dict.feat_dict[col]) 78 | dfv[col] = 1. 79 | 80 | # list of list of feature indices of each sample in the dataset 81 | Xi = dfi.values.tolist() 82 | # list of list of feature values of each sample in the dataset 83 | Xv = dfv.values.tolist() 84 | if has_label: 85 | return Xi, Xv, y 86 | else: 87 | return Xi, Xv, ids 88 | 89 | -------------------------------------------------------------------------------- /example/README.md: -------------------------------------------------------------------------------- 1 | 2 | An example usage of DeepFM/FM/DNN models for [Porto Seguro's Safe Driver Prediction competition on Kaggle](https://www.kaggle.com/c/porto-seguro-safe-driver-prediction). 
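To run it, download the competition data into `example/data` and launch the training script:

```
$ cd example
$ python main.py
```

Cross-validated Gini scores are printed for each model; training curves are written to `example/fig` and submission files to `example/output`.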
3 | -------------------------------------------------------------------------------- /example/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-DeepFM/a43dd5ff1f61a275c8a6e0fd659e200af64093cc/example/__init__.py -------------------------------------------------------------------------------- /example/config.py: -------------------------------------------------------------------------------- 1 | 2 | # set the path-to-files 3 | TRAIN_FILE = "./data/train.csv" 4 | TEST_FILE = "./data/test.csv" 5 | 6 | SUB_DIR = "./output" 7 | 8 | 9 | NUM_SPLITS = 3 10 | RANDOM_SEED = 2017 11 | 12 | # types of columns of the dataset dataframe 13 | CATEGORICAL_COLS = [ 14 | # 'ps_ind_02_cat', 'ps_ind_04_cat', 'ps_ind_05_cat', 15 | # 'ps_car_01_cat', 'ps_car_02_cat', 'ps_car_03_cat', 16 | # 'ps_car_04_cat', 'ps_car_05_cat', 'ps_car_06_cat', 17 | # 'ps_car_07_cat', 'ps_car_08_cat', 'ps_car_09_cat', 18 | # 'ps_car_10_cat', 'ps_car_11_cat', 19 | ] 20 | 21 | NUMERIC_COLS = [ 22 | # # binary 23 | # "ps_ind_06_bin", "ps_ind_07_bin", "ps_ind_08_bin", 24 | # "ps_ind_09_bin", "ps_ind_10_bin", "ps_ind_11_bin", 25 | # "ps_ind_12_bin", "ps_ind_13_bin", "ps_ind_16_bin", 26 | # "ps_ind_17_bin", "ps_ind_18_bin", 27 | # "ps_calc_15_bin", "ps_calc_16_bin", "ps_calc_17_bin", 28 | # "ps_calc_18_bin", "ps_calc_19_bin", "ps_calc_20_bin", 29 | # numeric 30 | "ps_reg_01", "ps_reg_02", "ps_reg_03", 31 | "ps_car_12", "ps_car_13", "ps_car_14", "ps_car_15", 32 | 33 | # feature engineering 34 | "missing_feat", "ps_car_13_x_ps_reg_03", 35 | ] 36 | 37 | IGNORE_COLS = [ 38 | "id", "target", 39 | "ps_calc_01", "ps_calc_02", "ps_calc_03", "ps_calc_04", 40 | "ps_calc_05", "ps_calc_06", "ps_calc_07", "ps_calc_08", 41 | "ps_calc_09", "ps_calc_10", "ps_calc_11", "ps_calc_12", 42 | "ps_calc_13", "ps_calc_14", 43 | "ps_calc_15_bin", "ps_calc_16_bin", "ps_calc_17_bin", 44 | "ps_calc_18_bin", "ps_calc_19_bin", "ps_calc_20_bin" 45 | ] 46 | -------------------------------------------------------------------------------- /example/data/README.md: -------------------------------------------------------------------------------- 1 | 2 | Please download the data from the [competition website](https://www.kaggle.com/c/porto-seguro-safe-driver-prediction) and put them here. 
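The code expects `train.csv` and `test.csv` in this folder (paths are set via `TRAIN_FILE` and `TEST_FILE` in `example/config.py`). Assuming the Kaggle CLI is installed and configured, one way to fetch them:

```
$ kaggle competitions download -c porto-seguro-safe-driver-prediction
$ unzip porto-seguro-safe-driver-prediction.zip -d .
```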
3 | -------------------------------------------------------------------------------- /example/fig/DNN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-DeepFM/a43dd5ff1f61a275c8a6e0fd659e200af64093cc/example/fig/DNN.png -------------------------------------------------------------------------------- /example/fig/DeepFM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-DeepFM/a43dd5ff1f61a275c8a6e0fd659e200af64093cc/example/fig/DeepFM.png -------------------------------------------------------------------------------- /example/fig/FM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenglongChen/tensorflow-DeepFM/a43dd5ff1f61a275c8a6e0fd659e200af64093cc/example/fig/FM.png -------------------------------------------------------------------------------- /example/main.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import tensorflow as tf 8 | from matplotlib import pyplot as plt 9 | from sklearn.metrics import make_scorer 10 | from sklearn.model_selection import StratifiedKFold 11 | 12 | import config 13 | from metrics import gini_norm 14 | from DataReader import FeatureDictionary, DataParser 15 | sys.path.append("..") 16 | from DeepFM import DeepFM 17 | 18 | gini_scorer = make_scorer(gini_norm, greater_is_better=True, needs_proba=True) 19 | 20 | 21 | def _load_data(): 22 | 23 | dfTrain = pd.read_csv(config.TRAIN_FILE) 24 | dfTest = pd.read_csv(config.TEST_FILE) 25 | 26 | def preprocess(df): 27 | cols = [c for c in df.columns if c not in ["id", "target"]] 28 | df["missing_feat"] = np.sum((df[cols] == -1).values, axis=1) 29 | df["ps_car_13_x_ps_reg_03"] = df["ps_car_13"] * df["ps_reg_03"] 30 | return df 31 | 32 | dfTrain = preprocess(dfTrain) 33 | dfTest = preprocess(dfTest) 34 | 35 | cols = [c for c in dfTrain.columns if c not in ["id", "target"]] 36 | cols = [c for c in cols if (not c in config.IGNORE_COLS)] 37 | 38 | X_train = dfTrain[cols].values 39 | y_train = dfTrain["target"].values 40 | X_test = dfTest[cols].values 41 | ids_test = dfTest["id"].values 42 | cat_features_indices = [i for i,c in enumerate(cols) if c in config.CATEGORICAL_COLS] 43 | 44 | return dfTrain, dfTest, X_train, y_train, X_test, ids_test, cat_features_indices 45 | 46 | 47 | def _run_base_model_dfm(dfTrain, dfTest, folds, dfm_params): 48 | fd = FeatureDictionary(dfTrain=dfTrain, dfTest=dfTest, 49 | numeric_cols=config.NUMERIC_COLS, 50 | ignore_cols=config.IGNORE_COLS) 51 | data_parser = DataParser(feat_dict=fd) 52 | Xi_train, Xv_train, y_train = data_parser.parse(df=dfTrain, has_label=True) 53 | Xi_test, Xv_test, ids_test = data_parser.parse(df=dfTest) 54 | 55 | dfm_params["feature_size"] = fd.feat_dim 56 | dfm_params["field_size"] = len(Xi_train[0]) 57 | 58 | y_train_meta = np.zeros((dfTrain.shape[0], 1), dtype=float) 59 | y_test_meta = np.zeros((dfTest.shape[0], 1), dtype=float) 60 | _get = lambda x, l: [x[i] for i in l] 61 | gini_results_cv = np.zeros(len(folds), dtype=float) 62 | gini_results_epoch_train = np.zeros((len(folds), dfm_params["epoch"]), dtype=float) 63 | gini_results_epoch_valid = np.zeros((len(folds), dfm_params["epoch"]), dtype=float) 64 | for i, (train_idx, valid_idx) in enumerate(folds): 65 | Xi_train_, Xv_train_, 
y_train_ = _get(Xi_train, train_idx), _get(Xv_train, train_idx), _get(y_train, train_idx) 66 | Xi_valid_, Xv_valid_, y_valid_ = _get(Xi_train, valid_idx), _get(Xv_train, valid_idx), _get(y_train, valid_idx) 67 | 68 | dfm = DeepFM(**dfm_params) 69 | dfm.fit(Xi_train_, Xv_train_, y_train_, Xi_valid_, Xv_valid_, y_valid_) 70 | 71 | y_train_meta[valid_idx,0] = dfm.predict(Xi_valid_, Xv_valid_) 72 | y_test_meta[:,0] += dfm.predict(Xi_test, Xv_test) 73 | 74 | gini_results_cv[i] = gini_norm(y_valid_, y_train_meta[valid_idx]) 75 | gini_results_epoch_train[i] = dfm.train_result 76 | gini_results_epoch_valid[i] = dfm.valid_result 77 | 78 | y_test_meta /= float(len(folds)) 79 | 80 | # save result 81 | if dfm_params["use_fm"] and dfm_params["use_deep"]: 82 | clf_str = "DeepFM" 83 | elif dfm_params["use_fm"]: 84 | clf_str = "FM" 85 | elif dfm_params["use_deep"]: 86 | clf_str = "DNN" 87 | print("%s: %.5f (%.5f)"%(clf_str, gini_results_cv.mean(), gini_results_cv.std())) 88 | filename = "%s_Mean%.5f_Std%.5f.csv"%(clf_str, gini_results_cv.mean(), gini_results_cv.std()) 89 | _make_submission(ids_test, y_test_meta, filename) 90 | 91 | _plot_fig(gini_results_epoch_train, gini_results_epoch_valid, clf_str) 92 | 93 | return y_train_meta, y_test_meta 94 | 95 | 96 | def _make_submission(ids, y_pred, filename="submission.csv"): 97 | pd.DataFrame({"id": ids, "target": y_pred.flatten()}).to_csv( 98 | os.path.join(config.SUB_DIR, filename), index=False, float_format="%.5f") 99 | 100 | 101 | def _plot_fig(train_results, valid_results, model_name): 102 | colors = ["red", "blue", "green"] 103 | xs = np.arange(1, train_results.shape[1]+1) 104 | plt.figure() 105 | legends = [] 106 | for i in range(train_results.shape[0]): 107 | plt.plot(xs, train_results[i], color=colors[i], linestyle="solid", marker="o") 108 | plt.plot(xs, valid_results[i], color=colors[i], linestyle="dashed", marker="o") 109 | legends.append("train-%d"%(i+1)) 110 | legends.append("valid-%d"%(i+1)) 111 | plt.xlabel("Epoch") 112 | plt.ylabel("Normalized Gini") 113 | plt.title("%s"%model_name) 114 | plt.legend(legends) 115 | plt.savefig("./fig/%s.png"%model_name) 116 | plt.close() 117 | 118 | 119 | # load data 120 | dfTrain, dfTest, X_train, y_train, X_test, ids_test, cat_features_indices = _load_data() 121 | 122 | # folds 123 | folds = list(StratifiedKFold(n_splits=config.NUM_SPLITS, shuffle=True, 124 | random_state=config.RANDOM_SEED).split(X_train, y_train)) 125 | 126 | 127 | # ------------------ DeepFM Model ------------------ 128 | # params 129 | dfm_params = { 130 | "use_fm": True, 131 | "use_deep": True, 132 | "embedding_size": 8, 133 | "dropout_fm": [1.0, 1.0], 134 | "deep_layers": [32, 32], 135 | "dropout_deep": [0.5, 0.5, 0.5], 136 | "deep_layers_activation": tf.nn.relu, 137 | "epoch": 30, 138 | "batch_size": 1024, 139 | "learning_rate": 0.001, 140 | "optimizer_type": "adam", 141 | "batch_norm": 1, 142 | "batch_norm_decay": 0.995, 143 | "l2_reg": 0.01, 144 | "verbose": True, 145 | "eval_metric": gini_norm, 146 | "random_seed": config.RANDOM_SEED 147 | } 148 | y_train_dfm, y_test_dfm = _run_base_model_dfm(dfTrain, dfTest, folds, dfm_params) 149 | 150 | # ------------------ FM Model ------------------ 151 | fm_params = dfm_params.copy() 152 | fm_params["use_deep"] = False 153 | y_train_fm, y_test_fm = _run_base_model_dfm(dfTrain, dfTest, folds, fm_params) 154 | 155 | 156 | # ------------------ DNN Model ------------------ 157 | dnn_params = dfm_params.copy() 158 | dnn_params["use_fm"] = False 159 | y_train_dnn, y_test_dnn = 
_run_base_model_dfm(dfTrain, dfTest, folds, dnn_params)
-------------------------------------------------------------------------------- /example/metrics.py: --------------------------------------------------------------------------------

import numpy as np

def gini(actual, pred):
    assert (len(actual) == len(pred))
    # sort by predicted score (descending), breaking ties by original order
    arr = np.asarray(np.c_[actual, pred, np.arange(len(actual))], dtype=float)
    arr = arr[np.lexsort((arr[:, 2], -1 * arr[:, 1]))]
    totalLosses = arr[:, 0].sum()
    giniSum = arr[:, 0].cumsum().sum() / totalLosses

    giniSum -= (len(actual) + 1) / 2.
    return giniSum / len(actual)

def gini_norm(actual, pred):
    # normalize by the Gini of a perfect ranking
    return gini(actual, pred) / gini(actual, actual)
-------------------------------------------------------------------------------- /example/output/README.md: --------------------------------------------------------------------------------

Submissions are saved here.
--------------------------------------------------------------------------------