├── auc_flop.png
├── start_train_avazu.sh
├── start_train_criteo.sh
├── .github
│   └── pull_request_template.md
├── bash
│   ├── FFM.sh
│   ├── FM.sh
│   ├── LR.sh
│   ├── FmFM.sh
│   ├── DCN.sh
│   ├── FvFM.sh
│   └── FwFM.sh
├── LICENSE-MIT
├── Contributing.md
├── README.md
├── features.py
├── data
│   ├── avazu
│   │   └── trans_avazu_dataset.py
│   └── criteo
│       └── trans_criteo_dataset.py
├── fmfm_optimizer.py
├── train.py
├── Code-of-Conduct.md
└── models.py

/auc_flop.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yahoo/FmFM/main/auc_flop.png
--------------------------------------------------------------------------------
/start_train_avazu.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | export data_folder="avazu"
3 | 
4 | sh bash/LR.sh
5 | sh bash/FM.sh
6 | sh bash/FFM.sh
7 | sh bash/FwFM.sh
8 | sh bash/FvFM.sh
9 | sh bash/FmFM.sh
10 | sh bash/DCN.sh
11 | 
--------------------------------------------------------------------------------
/start_train_criteo.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | export data_folder="criteo"
3 | 
4 | sh bash/LR.sh
5 | sh bash/FM.sh
6 | sh bash/FFM.sh
7 | sh bash/FwFM.sh
8 | sh bash/FvFM.sh
9 | sh bash/FmFM.sh
10 | sh bash/DCN.sh
11 | 
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | I confirm that this contribution is made under the terms of the license found in the root directory of this repository's source tree and that I have the authority necessary to make this contribution on behalf of its copyright owner.
2 | -------------------------------------------------------------------------------- /bash/FFM.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | model_type="FFM" 3 | model_folder=${data_folder}_${model_type} 4 | 5 | python3.6 train.py \ 6 | --model_dir models/${model_folder} \ 7 | --train_data_file data/${data_folder}/train.csv \ 8 | --val_data_file data/${data_folder}/val.csv \ 9 | --test_data_file data/${data_folder}/test.csv \ 10 | --batch_size 1024 \ 11 | --train_epoch 20 \ 12 | --max_steps 50000 \ 13 | --l2_linear 1e-5 \ 14 | --l2_latent 1e-5 \ 15 | --l2_r 1e-5 \ 16 | --learning_rate 1e-4 \ 17 | --default_feat_dim 16 \ 18 | --feature_meta data/${data_folder}/features.json \ 19 | --feature_dict data/${data_folder}/feature_index \ 20 | --model_type $model_type -------------------------------------------------------------------------------- /bash/FM.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | model_type="FM" 3 | model_folder=${data_folder}_${model_type} 4 | 5 | python3.6 train.py \ 6 | --model_dir models/${model_folder} \ 7 | --train_data_file data/${data_folder}/train.csv \ 8 | --val_data_file data/${data_folder}/val.csv \ 9 | --test_data_file data/${data_folder}/test.csv \ 10 | --batch_size 1024 \ 11 | --train_epoch 20 \ 12 | --max_steps 50000 \ 13 | --l2_linear 1e-5 \ 14 | --l2_latent 1e-5 \ 15 | --l2_r 1e-5 \ 16 | --learning_rate 1e-4 \ 17 | --default_feat_dim 16 \ 18 | --feature_meta data/${data_folder}/features.json \ 19 | --feature_dict data/${data_folder}/feature_index \ 20 | --model_type $model_type -------------------------------------------------------------------------------- /bash/LR.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | model_type="LR" 3 | model_folder=${data_folder}_${model_type} 4 | 5 | python3.6 train.py \ 6 | --model_dir models/${model_folder} \ 7 | --train_data_file data/${data_folder}/train.csv \ 8 | --val_data_file data/${data_folder}/val.csv \ 9 | --test_data_file data/${data_folder}/test.csv \ 10 | --batch_size 1024 \ 11 | --train_epoch 20 \ 12 | --max_steps 50000 \ 13 | --l2_linear 1e-5 \ 14 | --l2_latent 1e-5 \ 15 | --l2_r 1e-5 \ 16 | --learning_rate 1e-4 \ 17 | --default_feat_dim 16 \ 18 | --feature_meta data/${data_folder}/features.json \ 19 | --feature_dict data/${data_folder}/feature_index \ 20 | --model_type $model_type -------------------------------------------------------------------------------- /bash/FmFM.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | model_type="FmFM" 3 | model_folder=${data_folder}_${model_type} 4 | 5 | python3.6 train.py \ 6 | --model_dir models/${model_folder} \ 7 | --train_data_file data/${data_folder}/train.csv \ 8 | --val_data_file data/${data_folder}/val.csv \ 9 | --test_data_file data/${data_folder}/test.csv \ 10 | --batch_size 1024 \ 11 | --train_epoch 20 \ 12 | --max_steps 50000 \ 13 | --l2_linear 1e-5 \ 14 | --l2_latent 1e-5 \ 15 | --l2_r 1e-5 \ 16 | --learning_rate 1e-4 \ 17 | --default_feat_dim 16 \ 18 | --feature_meta data/${data_folder}/features.json \ 19 | --feature_dict data/${data_folder}/feature_index \ 20 | --model_type $model_type -------------------------------------------------------------------------------- /bash/DCN.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | model_type="DCN" 
3 | model_folder=${data_folder}_${model_type} 4 | 5 | python3.6 train.py \ 6 | --model_dir models/${model_folder} \ 7 | --train_data_file data/${data_folder}/train.csv \ 8 | --val_data_file data/${data_folder}/val.csv \ 9 | --test_data_file data/${data_folder}/test.csv \ 10 | --batch_size 1024 \ 11 | --train_epoch 20 \ 12 | --max_steps 50000 \ 13 | --l2_linear 1e-5 \ 14 | --l2_latent 1e-5 \ 15 | --l2_r 1e-5 \ 16 | --learning_rate 1e-4 \ 17 | --default_feat_dim 16 \ 18 | --feature_meta data/${data_folder}/features.json \ 19 | --feature_dict data/${data_folder}/feature_index \ 20 | --model_type $model_type 21 | 22 | -------------------------------------------------------------------------------- /bash/FvFM.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | model_type="FvFM" 3 | model_folder=${data_folder}_${model_type} 4 | 5 | python3.6 train.py \ 6 | --model_dir models/${model_folder} \ 7 | --train_data_file data/${data_folder}/train.csv \ 8 | --val_data_file data/${data_folder}/val.csv \ 9 | --test_data_file data/${data_folder}/test.csv \ 10 | --batch_size 1024 \ 11 | --train_epoch 20 \ 12 | --max_steps 50000 \ 13 | --l2_linear 1e-5 \ 14 | --l2_latent 1e-5 \ 15 | --l2_r 1e-5 \ 16 | --learning_rate 1e-4 \ 17 | --default_feat_dim 16 \ 18 | --feature_meta data/${data_folder}/features.json \ 19 | --feature_dict data/${data_folder}/feature_index \ 20 | --model_type $model_type 21 | 22 | -------------------------------------------------------------------------------- /bash/FwFM.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | model_type="FwFM" 3 | model_folder=${data_folder}_${model_type} 4 | 5 | python3.6 train.py \ 6 | --model_dir models/${model_folder} \ 7 | --train_data_file data/${data_folder}/train.csv \ 8 | --val_data_file data/${data_folder}/val.csv \ 9 | --test_data_file data/${data_folder}/test.csv \ 10 | --batch_size 1024 \ 11 | --train_epoch 20 \ 12 | --max_steps 50000 \ 13 | --l2_linear 1e-5 \ 14 | --l2_latent 1e-5 \ 15 | --l2_r 1e-5 \ 16 | --learning_rate 1e-4 \ 17 | --default_feat_dim 16 \ 18 | --feature_meta data/${data_folder}/features.json \ 19 | --feature_dict data/${data_folder}/feature_index \ 20 | --model_type $model_type 21 | 22 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright Verizon Media 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 | 
--------------------------------------------------------------------------------
/Contributing.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 | First, thanks for taking the time to contribute to our project! There are many ways you can help out.
3 | 
4 | ### Questions
5 | 
6 | If you have a question that needs an answer, [create an issue](https://help.github.com/articles/creating-an-issue/), and label it as a question.
7 | 
8 | ### Issues for bugs or feature requests
9 | 
10 | If you encounter any bugs in the code, or want to request a new feature or enhancement, please [create an issue](https://help.github.com/articles/creating-an-issue/) to report it. Kindly add a label to indicate what type of issue it is.
11 | 
12 | ### Contribute Code
13 | We welcome your pull requests for bug fixes. To implement something new, please create an issue first so we can discuss it together.
14 | 
15 | ***Creating a Pull Request***
16 | Please follow [best practices](https://github.com/trein/dev-best-practices/wiki/Git-Commit-Best-Practices) for creating git commits.
17 | 
18 | When your code is ready to be submitted, [submit a pull request](https://help.github.com/articles/creating-a-pull-request/) to begin the code review process.
19 | 
20 | We only seek to accept code that you are authorized to contribute to the project. We have added a pull request template on our projects so that your contributions are made with the following confirmation:
21 | 
22 | > I confirm that this contribution is made under the terms of the license found in the root directory of this repository's source tree and that I have the authority necessary to make this contribution on behalf of its copyright owner.
23 | 
24 | ## Code of Conduct
25 | 
26 | We encourage inclusive and professional interactions on our project. We welcome everyone to open an issue, improve the documentation, report a bug, or submit a pull request. By participating in this project, you agree to abide by the [Verizon Media Code of Conduct](Code-of-Conduct.md). If you feel there is a conduct issue related to this project, please raise it per the Code of Conduct process and we will address it.
27 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FM^2: Field-matrixed Factorization Machines for Recommender Systems
2 | 
3 | ## Table of Contents
4 | 
5 | - [Background](#background)
6 | - [Install](#install)
7 | - [Usage](#usage)
8 | - [Contribute](#contribute)
9 | - [License](#license)
10 | 
11 | ## Background
12 | This is the code implementing the FM^2 (Field-matrixed Factorization Machines) algorithm. It can run a quick benchmark across LR, FM, FFM, FwFM, FvFM, FmFM and DCN,
13 | and it also supports data processing and feature extraction for the public Criteo and Avazu data sets.
14 | 
15 | 
16 | ## Install
17 | First you will need to have [TensorFlow](https://github.com/tensorflow) (v1.15 with GPU support is preferred), numpy, pandas, scikit-learn and tqdm installed; `pickle` ships with the Python standard library.
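A minimal environment setup could look like the following (the pinned versions are an assumption, not something the repo specifies; scikit-learn is needed by `fmfm_optimizer.py`):

```bash
# Sketch only: package versions are assumptions, not pinned by this repo.
pip install tensorflow-gpu==1.15 numpy pandas scikit-learn tqdm
```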
18 | 
19 | You may need to log in and download the [Criteo](http://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset/) and [Avazu](https://www.kaggle.com/c/avazu-ctr-prediction/data) data sets from their websites.
20 | The unzipped raw data files should be placed in the folders `data/criteo/` and `data/avazu/` respectively.
21 | 
22 | ## Usage
23 | 
24 | This project contains the following components:
25 | 1. **train.py** The main entry point to train a model
26 | 2. **features.py** Functions to process the data files and generate features
27 | 3. **models.py** The core functions describing the models, including the newly proposed FmFM and FvFM as well as the baseline models LR, FM, FFM and FwFM
28 | 
29 | The folder **bash** contains the individual training tasks with their hyper-parameters, and **start_train_avazu.sh** / **start_train_criteo.sh** schedule multiple tasks in one bash file.
30 | 
31 | ![AUC vs FLOP comparison](/auc_flop.png)
32 | 
33 | ## Contribute
34 | 
35 | Please refer to [the contributing.md file](Contributing.md) for information about how to get involved. We welcome issues, questions, and pull requests.
36 | 
37 | ## Maintainers
38 | Yang Sun, yang.sun@verizonmedia.com
39 | 
40 | ## License
41 | This project is licensed under the terms of the MIT open source license. Please refer to [LICENSE-MIT](LICENSE-MIT) for the full terms.
42 | 
--------------------------------------------------------------------------------
/features.py:
--------------------------------------------------------------------------------
1 | # Copyright Verizon Media
2 | # This project is licensed under the MIT. See license in project root for terms.
3 | 
4 | import tensorflow as tf
5 | import logging
6 | import json
7 | 
8 | logger = logging.getLogger(__name__)
9 | 
10 | def build_feature_meta(features_meta_file = 'features.json', features_dict_file = 'dict.txt'):
11 |     ft_cat_counts = {}
12 |     ft_types = {}
13 |     ft_dims = {}
14 |     ft_names = ['label']
15 |     ft_defaults = [[0]]
16 |     with tf.gfile.Open(features_meta_file, 'r') as ftfile:
17 |         feature_meta = json.load(ftfile)
18 |         for ft_name, ft_type, ft_dim in feature_meta:
19 |             ft_names.append(ft_name)
20 |             ft_types[ft_name] = ft_type
21 |             ft_dims[ft_name] = ft_dim
22 |             if ft_type == 'CATEGORICAL':
23 |                 ft_defaults.append([0])
24 |             elif ft_type == 'NUMERIC':
25 |                 ft_defaults.append([0.0])
26 | 
27 |     with tf.gfile.Open(features_dict_file) as fm_file:
28 |         logger.info('Reading features from file %s', features_dict_file)
29 |         for feature in fm_file:
30 |             ft = feature.strip().split('\1')
31 |             feature_name = ft[0].strip()
32 |             if ft_cat_counts.get(feature_name) is None:
33 |                 ft_cat_counts[feature_name] = 1
34 |             else:
35 |                 ft_cat_counts[feature_name] += 1
36 | 
37 |     return ft_names, ft_defaults, ft_cat_counts, ft_types, ft_dims
38 | 
39 | 
40 | def load_cross_fields(cross_fields_file):
41 |     if cross_fields_file is None:
42 |         return None
43 |     else:
44 |         return set(json.load(open(cross_fields_file)))
45 | 
46 | 
47 | # Parse one CSV record into a features dict and a label
48 | def parse_record(record, feature_names, feature_defaults):
49 |     feature_array = tf.decode_csv(record, feature_defaults)
50 |     features = dict(zip(feature_names, feature_array))
51 |     label = features.pop('label')
52 |     #features.pop('tag') # unused
53 |     # if features['ads_category'] < 0:
54 |     #     features['ads_category'] = 0
55 |     return features, label
56 | 
57 | 
58 | def input_fn(train_files, shuffle, batch_size, epoch, feature_names, feature_defaults):
59 |     dataset = tf.data.TextLineDataset(train_files)
60 |     if epoch:
61 |         dataset = dataset.repeat(epoch)
62 | 
63 |     if shuffle:
64 |         dataset = dataset.shuffle(shuffle)
65 | 
66 |     dataset = dataset.map(lambda x: parse_record(x, feature_names, feature_defaults))
67 |     dataset = dataset.batch(batch_size)
68 | 
69 |     return dataset
70 | 
71 | 
72 | if __name__ == "__main__":
73 |     feature_names, feature_defaults, categorical_feature_counts, feature_types, feature_dim = \
74 |         build_feature_meta('data/criteo/features.json', 'data/criteo/feature_index')
75 | 
76 |     print(feature_dim)
--------------------------------------------------------------------------------
/data/avazu/trans_avazu_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright Verizon Media
2 | # This project is licensed under the MIT. See license in project root for terms.
3 | 
4 | from tqdm import tqdm
5 | from datetime import datetime
6 | import json
7 | 
8 | def trans_date_feat(datetime_str):
9 |     date_hr = datetime.strptime(datetime_str, '%y%m%d%H')
10 |     return date_hr.weekday(), date_hr.hour
11 | 
12 | def pre_parse_line(line, feat_list, fields_cnt = 23):
13 |     splits = line.rstrip('\r\n').split(',', fields_cnt+2)
14 | 
15 |     datetime_str = splits[2]
16 |     weekday, hour = trans_date_feat(datetime_str)
17 |     features = [weekday, hour] + [int(val) for val in splits[3:5]] +\
18 |                [int(val, 16) for val in splits[5:14]] + [int(val) for val in splits[14:]]
19 | 
20 |     for idx in range(0, fields_cnt):
21 |         val = features[idx]
22 |         if val not in feat_list[idx]:
23 |             feat_list[idx][val] = 1
24 |         else:
25 |             feat_list[idx][val] += 1
26 | 
27 |     return
28 | 
29 | def parse_line(line, feat_list, fields_cnt):
30 |     splits = line.rstrip('\r\n').split(',', fields_cnt+2)
31 | 
32 |     label = int(splits[1])
33 |     vals = []
34 | 
35 |     datetime_str = splits[2]
36 |     weekday, hour = trans_date_feat(datetime_str)
37 |     features = [weekday, hour] + [int(val) for val in splits[3:5]] + \
38 |                [int(val, 16) for val in splits[5:14]] + [int(val) for val in splits[14:]]
39 | 
40 |     for idx in range(0, fields_cnt):
41 |         val = features[idx]
42 |         if val not in feat_list[idx]:
43 |             vals.append(0)
44 |         else:
45 |             vals.append(feat_list[idx][val])
46 |     return label, vals
47 | 
48 | if __name__ == "__main__":
49 |     thres = 5
50 |     fields_cnt = 23
51 | 
52 |     data_file = 'train'
53 |     out_file = open('all_data.csv', 'w')
54 |     feature_index = open('feature_index', 'w')
55 |     feature_json = open('features.json', 'w')
56 | 
57 |     dataset_ptr = open(data_file, 'r')
58 |     titles = dataset_ptr.readline().rstrip('\r\n').split(',')
59 |     titles[1] = 'weekday'
60 |     del titles[0]
61 | 
62 |     dataset = dataset_ptr.readlines()
63 | 
64 |     feat_list = []
65 |     for i in range(fields_cnt):
66 |         feat_list.append({})
67 | 
68 |     for line in tqdm(dataset):
69 |         pre_parse_line(line, feat_list, fields_cnt)
70 | 
71 |     for lst in tqdm(feat_list):
72 |         idx = 1
73 |         for key, val in list(lst.items()):  # iterate over a copy: keys are deleted below (Python 3 safe)
74 |             if val < thres:
75 |                 del lst[key]
76 |             else:
77 |                 lst[key] = idx
78 |                 idx += 1
79 | 
80 |     for idx, field in tqdm(enumerate(feat_list)):
81 |         for feat, id in field.items():
82 |             feature_index.write('%s\1|raw_feat_%s|\1%d\n' % (titles[idx], str(feat), id))
83 |     feature_index.close()
84 | 
85 |     for line in tqdm(dataset):
86 |         key, vals = parse_line(line, feat_list, fields_cnt)
87 |         out_file.write('%s,%s\n' % (key, ','.join([str(s) for s in vals])))
88 | 
89 |     feature_meta = []
90 |     for idx in range(0, fields_cnt):
91 |         feature_meta.append(('%s' % titles[idx], 'CATEGORICAL', 20))
92 |     json.dump(feature_meta, feature_json, indent=2)
93 | 
94 |     out_file.close()
95 |     del dataset
96 | 
97 | 
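# Illustrative output formats, derived from the writes above (example values are
# made up): feature_index holds one line per kept value, with fields joined by
# the \x01 byte, e.g. banner_pos\x01|raw_feat_1|\x013 ; all_data.csv holds
# "label,idx_1,...,idx_23" rows of integer indices; features.json lists
# (field_name, 'CATEGORICAL', 20) triples consumed by features.py's build_feature_meta().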
--------------------------------------------------------------------------------
/fmfm_optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright Verizon Media
2 | # This project is licensed under the MIT. See license in project root for terms.
3 | 
4 | import numpy as np
5 | from os.path import join
6 | from sklearn.decomposition import PCA
7 | import json
8 | import models
9 | import features
10 | 
11 | def emb_dim_opt(model, feature_types, feat_dim, p_thres):
12 |     # Embedding dimension optimizer: shrink each field's embedding to the smallest dim whose PCA components explain p_thres% of the variance
13 |     for feature_column, feature_type in feature_types.items():
14 |         if feature_column in {'label', 'tag'} or feature_type == 'NUMERIC':
15 |             continue
16 |         else:
17 |             mats = model.get_variable_value('v_%s' % feature_column)[1:]
18 |             n_example, dim = mats.shape
19 |             pca = PCA(n_components=min(dim, n_example), svd_solver='full')
20 |             pca.fit(mats)
21 |             ratio = pca.explained_variance_ratio_
22 | 
23 |             p = 0.
24 |             for idx, r in enumerate(ratio):
25 |                 p += r
26 |                 if p > p_thres/100.:
27 |                     dim_opt = max(2, idx+1)
28 |                     print('%s\t%d -> %d' % (feature_column, dim, dim_opt))
29 |                     feat_dim[feature_column] = dim_opt
30 |                     break
31 | 
32 |     return feat_dim
33 | 
34 | 
35 | def cross_terms_pruning(model, feat_dim, p_thres):
36 |     # Singular values of each field-pair interaction matrix
37 |     mat_eigvals = {}
38 |     field_order = sorted(feat_dim.items(), key=lambda x: x[0])
39 |     for idx_l, (feat_l, _) in enumerate(field_order):
40 |         for idx_r, (feat_r, _) in enumerate(field_order[idx_l + 1:]):
41 |             idx_r += (idx_l + 1)
42 |             mat_name = '%s_%s' % (feat_l, feat_r)
43 |             w = model.get_variable_value(mat_name)
44 |             _, sigma, _ = np.linalg.svd(w)
45 |             mat_eigvals[mat_name] = sigma.tolist()
46 |     sorted_terms = sorted(mat_eigvals.items(), key=lambda x:np.var(x[1]), reverse=True)
47 |     n_terms = len(sorted_terms) * p_thres // 100
48 |     return [x[0] for x in sorted_terms[:n_terms]]
49 | 
50 | 
51 | if __name__ == "__main__":
52 |     # mode 1: optimize the embedding dimensions
53 |     # mode 2: pruning the cross terms
54 |     mode = 2
55 |     model_dir = 'models/criteo_FmFM'
56 |     data_dir = 'data/criteo'
57 |     feature_meta = join(data_dir, 'features.json')
58 |     feature_dict = join(data_dir, 'feature_index')
59 | 
60 |     feature_names, feature_defaults, categorical_feature_counts, feature_types, feat_dim = \
61 |         features.build_feature_meta(feature_meta, feature_dict)
62 | 
63 |     model = models.build_custom_linear_classifier(
64 |         model_dir, feature_names, feature_types, categorical_feature_counts,
65 |         None, None, None, None, None, 'FmFM', feat_dim, None)  # positional args match train.py: l2/lr/dim Nones, model_type, feature_dims, cross_fields
66 | 
67 |     if mode == 1:
68 |         p_thres=95
69 |         feature_meta_opt = join(data_dir, 'features_opt_p%d.json')
70 |         feat_dim_opt = emb_dim_opt(model, feature_types, feat_dim, p_thres)
71 |         feat_meta_list = []
72 |         for feat_name, feat_type, _ in json.load(open(feature_meta)):
73 |             feat_meta_list.append((feat_name, feat_type, feat_dim_opt[feat_name]))
74 |         json.dump(feat_meta_list, open(feature_meta_opt % p_thres, 'w'), indent=2)
75 | 
76 |     elif mode == 2:
77 |         p_thres=20
78 |         cross_terms_file = join(data_dir, 'cross_fields_p%d.json')
79 |         cross_terms = cross_terms_pruning(model, feat_dim, p_thres)
80 |         json.dump(cross_terms, open(cross_terms_file % p_thres, 'w'), indent=2)
81 | 
82 | 
--------------------------------------------------------------------------------
/data/criteo/trans_criteo_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright Verizon Media
2 | # This project is licensed under the MIT. See license in project root for terms.
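# Note on the transforms below: integer features are bucketed on a log scale
# before being indexed as categorical values: trans_int_feat(v) = ceil(ln(v)^2 + 1)
# for v > 2, so for example v = 1000 maps to ceil(6.908**2 + 1) = 49, while
# values <= 2 are kept as-is by the callers.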
3 | 
4 | from tqdm import tqdm
5 | import json
6 | import math
7 | 
8 | def trans_int_feat(val):
9 |     return int(math.ceil(math.log(val)**2+1))
10 | 
11 | def pre_parse_line(line, feat_list, int_feat_cnt=13, cate_feat_cnt=26, with_int_feat=False):
12 |     fields_cnt = int_feat_cnt + cate_feat_cnt
13 |     splits = line.rstrip('\n').split('\t', fields_cnt+1)
14 | 
15 |     start_index = 0 if with_int_feat else int_feat_cnt
16 |     for idx in range(start_index, fields_cnt):
17 |         val = splits[idx+1]
18 |         if val == '':
19 |             continue
20 |         elif idx < int_feat_cnt:
21 |             val = int(val)
22 |             if val > 2:
23 |                 val = trans_int_feat(val)
24 |         else:
25 |             val = int(val, 16)
26 | 
27 |         if val not in feat_list[idx]:
28 |             feat_list[idx][val] = 1
29 |         else:
30 |             feat_list[idx][val] += 1
31 | 
32 |     return
33 | 
34 | def parse_line(line, feat_list, int_feat_cnt=13, cate_feat_cnt=26, with_int_feat = False):
35 |     fields_cnt = int_feat_cnt + cate_feat_cnt
36 |     splits = line.rstrip('\n').split('\t', fields_cnt+1)
37 | 
38 |     label = int(splits[0])
39 |     vals = []
40 | 
41 |     start_index = 0 if with_int_feat else int_feat_cnt
42 |     for idx in range(start_index, fields_cnt):
43 |         val = splits[idx+1]
44 |         if val == '':
45 |             vals.append(0)
46 |             continue
47 |         elif idx < int_feat_cnt:
48 |             val = int(val)
49 |             if val > 2:
50 |                 val = trans_int_feat(val)
51 |         else:
52 |             val = int(val, 16)
53 | 
54 |         if val not in feat_list[idx]:
55 |             vals.append(0)
56 |         else:
57 |             vals.append(feat_list[idx][val])
58 |     return label, vals
59 | 
60 | if __name__ == "__main__":
61 |     thres = 8
62 |     int_feat_cnt = 13
63 |     cate_feat_cnt = 26
64 |     with_int_feat = True
65 | 
66 |     data_file = 'train.txt'
67 |     out_file = open('all_data.csv', 'w')
68 |     feature_index = open('feature_index', 'w')
69 |     feature_json = open('features.json', 'w')
70 | 
71 |     dataset = open(data_file, 'r').readlines()
72 | 
73 |     feat_list = []
74 |     for i in range(int_feat_cnt + cate_feat_cnt):
75 |         feat_list.append({})
76 | 
77 |     for line in tqdm(dataset):
78 |         pre_parse_line(line, feat_list, int_feat_cnt, cate_feat_cnt, with_int_feat)
79 | 
80 |     for lst in tqdm(feat_list[:int_feat_cnt]):
81 |         idx = 1
82 |         for key, val in lst.items():
83 |             lst[key] = idx
84 |             idx += 1
85 | 
86 |     for lst in tqdm(feat_list[int_feat_cnt:]):
87 |         idx = 1
88 |         for key, val in list(lst.items()):  # iterate over a copy: keys are deleted below (Python 3 safe)
89 |             if val < thres:
90 |                 del lst[key]
91 |             else:
92 |                 lst[key] = idx
93 |                 idx += 1
94 | 
95 |     for idx, field in tqdm(enumerate(feat_list)):
96 |         # feat_id = sorted(field.items(), key=lambda x:x[1])
97 |         for feat, id in field.items():
98 |             feature_index.write('field_%02d\1|raw_feat_%s|\1%d\n' % (idx+1, str(feat), id))
99 |     feature_index.close()
100 | 
101 |     for line in tqdm(dataset):
102 |         key, vals = parse_line(line, feat_list, int_feat_cnt, cate_feat_cnt, with_int_feat)
103 |         if vals is None:
104 |             continue
105 |         out_file.write('%s,%s\n' % (key, ','.join([str(s) for s in vals])))
106 | 
107 |     feature_meta = []
108 |     for idx in range(1, int_feat_cnt + cate_feat_cnt + 1):
109 |         feature_meta.append(('field_%02d' % idx, 'CATEGORICAL', 20))
110 |     json.dump(feature_meta, feature_json, indent=2)
111 | 
112 |     out_file.close()
113 |     del dataset
114 | 
115 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright Verizon Media
2 | # This project is licensed under the MIT. See license in project root for terms.
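# train.py drives the whole benchmark: it builds one of the models defined in
# models.py (LR, FM, FFM, FwFM, FvFM, FmFM or DCN, chosen via --model_type) as a
# tf.estimator, evaluates on the validation set after every epoch, and stops
# early once validation AUC stops improving (see the loop in main() below).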
3 | 4 | from __future__ import (absolute_import, division, print_function) 5 | import tensorflow as tf 6 | from os.path import join 7 | from os import getenv 8 | import logging 9 | import features 10 | import models 11 | import os 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | def report_update(file_ptr, mode, itr, metrics): 16 | info = 'Iter:%d\t%s Set metrics - %s\n' % (itr+1, mode, format(metrics)) 17 | file_ptr.write(info) 18 | file_ptr.flush() 19 | return info 20 | 21 | def main(args): 22 | if not os.path.exists(args.model_dir): 23 | os.makedirs(args.model_dir) 24 | report_file = open(os.path.join(args.model_dir, 'report.txt'), 'w+') 25 | tf.logging.set_verbosity(tf.logging.INFO) 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | logger.info('Training start with arguments... {}'.format(args)) 29 | logger.info('tf version {}'.format(tf.__version__)) 30 | logger.info('TF_CONFIG : {}'.format(getenv('TF_CONFIG'))) 31 | 32 | logger.info('Starting to load feature info from {},{}'.format(args.feature_meta, args.feature_dict)) 33 | feature_names, feature_defaults, categorical_feature_counts, feature_types, feature_dims = \ 34 | features.build_feature_meta(args.feature_meta, args.feature_dict) 35 | 36 | cross_fields = features.load_cross_fields(args.cross_fields) 37 | 38 | print('{}'.format(feature_names)) 39 | print('{}'.format(feature_defaults)) 40 | print('{}'.format(categorical_feature_counts)) 41 | print('{}'.format(sorted(feature_types.items()))) 42 | 43 | logger.info('train file: {}'.format(args.train_data_file)) 44 | logger.info('val file: {}'.format(args.val_data_file)) 45 | logger.info('test file: {}'.format(args.test_data_file)) 46 | 47 | def train_input_fn(): 48 | return features.input_fn(args.train_data_file, True, args.batch_size, None, 49 | feature_names, feature_defaults) 50 | def val_input_fn(): 51 | return features.input_fn(args.val_data_file, True, args.batch_size, None, 52 | feature_names, feature_defaults) 53 | def test_input_fn(): 54 | return features.input_fn(args.test_data_file, True, args.batch_size, None, 55 | feature_names, feature_defaults) 56 | 57 | model = models.build_custom_linear_classifier(args.model_dir, feature_names, feature_types, 58 | categorical_feature_counts, args.l2_linear, args.l2_latent, args.l2_r, args.learning_rate, 59 | args.default_feat_dim, args.model_type, feature_dims, cross_fields) 60 | 61 | tf.reset_default_graph() 62 | best_auc = 0 63 | 64 | 65 | for n in range(args.train_epoch): 66 | model.train(input_fn=train_input_fn) 67 | 68 | metrics = model.evaluate(input_fn=val_input_fn, name='Val') 69 | logger.info(report_update(report_file, 'Val', n, metrics)) 70 | if metrics['auc'] > best_auc: 71 | best_auc = metrics['auc'] 72 | else: 73 | break 74 | 75 | metrics = model.evaluate(input_fn=test_input_fn, name='Test') 76 | logger.info(report_update(report_file, 'Test', n, metrics)) 77 | 78 | metrics = model.evaluate(input_fn=train_input_fn, name='train') 79 | logger.info(report_update(report_file, 'Train', n, metrics)) 80 | 81 | report_file.close() 82 | 83 | 84 | if __name__ == "__main__": 85 | import argparse 86 | 87 | parser = argparse.ArgumentParser() 88 | 89 | parser.add_argument("--model_dir", help="path to save model/checkpoint", default="dsp_cpc_tf_model") 90 | parser.add_argument("--train_data_file", help="path to training input data", default="dsp_cpc_train") 91 | parser.add_argument("--val_data_file", help="path to validate input data", default="dsp_cpc_val") 92 | parser.add_argument("--test_data_file", help="path to test 
input data", default="dsp_cpc_test") 93 | parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100) 94 | parser.add_argument("--eval_batch_size", help="number of records per batch for eval", type=int, default=100) 95 | parser.add_argument("--train_epoch", help="Training epoch count", type=int, default=20) 96 | parser.add_argument("--max_steps", help="maximum number of steps", type=int, default=100000) 97 | parser.add_argument("--l2_linear", help="l2 regularization scale for linear term", type=float, default=0.0) 98 | parser.add_argument("--l2_latent", help="l2 regularization scale for latent factor", type=float, default=0.0) 99 | parser.add_argument("--l2_r", help="l2 regularization scale for fields", type=float, default=0.0) 100 | parser.add_argument("--learning_rate", help="learning rate in Adam Optimizer", type=float, default=1e-4) 101 | parser.add_argument("--default_feat_dim", help="dimension of latent vector for each feature", type=int, default=10) 102 | parser.add_argument("--cross_fields", help="json file, list of cross fields", default=None) 103 | parser.add_argument("--model_type", help="the model used in training (lr, FM, FwFM)", default="LR") 104 | parser.add_argument("--feature_meta", help="path to feature meta data", default="features.json") 105 | parser.add_argument("--feature_dict", help="path to feature dictionary", default="feature_dict") 106 | 107 | 108 | args = parser.parse_args() 109 | print("args:", args) 110 | 111 | print('Executing in local mode') 112 | main(args) -------------------------------------------------------------------------------- /Code-of-Conduct.md: -------------------------------------------------------------------------------- 1 | # Verizon Media Open Source Code of Conduct 2 | 3 | ## Summary 4 | This Code of Conduct is our way to encourage good behavior and discourage bad behavior in our open source projects. We invite participation from many people to bring different perspectives to our projects. We will do our part to foster a welcoming and professional environment free of harassment. We expect participants to communicate professionally and thoughtfully during their involvement with this project. 5 | 6 | Participants may lose their good standing by engaging in misconduct. For example: insulting, threatening, or conveying unwelcome sexual content. We ask participants who observe conduct issues to report the incident directly to the project's Response Team at opensource-conduct@verizonmedia.com. Verizon Media will assign a respondent to address the issue. We may remove harassers from this project. 7 | 8 | This code does not replace the terms of service or acceptable use policies of the websites used to support this project. We acknowledge that participants may be subject to additional conduct terms based on their employment which may govern their online expressions. 9 | 10 | ## Details 11 | This Code of Conduct makes our expectations of participants in this community explicit. 12 | * We forbid harassment and abusive speech within this community. 13 | * We request participants to report misconduct to the project’s Response Team. 14 | * We urge participants to refrain from using discussion forums to play out a fight. 15 | 16 | ### Expected Behaviors 17 | We expect participants in this community to conduct themselves professionally. Since our primary mode of communication is text on an online forum (e.g. 
issues, pull requests, comments, emails, or chats) devoid of vocal tone, gestures, or other context that is often vital to understanding, it is important that participants are attentive to their interaction style.
18 | 
19 | * **Assume positive intent.** We ask community members to assume positive intent on the part of other people’s communications. We may disagree on details, but we expect all suggestions to be supportive of the community goals.
20 | * **Respect participants.** We expect occasional disagreements. Open Source projects are learning experiences. Ask, explore, challenge, and then _respectfully_ state if you agree or disagree. If your idea is rejected, be more persuasive, not bitter.
21 | * **Welcome new members.** New members bring new perspectives. Some ask questions that have been addressed before. _Kindly_ point to existing discussions. Everyone is new to every project once.
22 | * **Be kind to beginners.** Beginners use open source projects to get experience. They might not be talented coders yet, and projects should not accept poor quality code. But we were all beginners once, and we need to engage kindly.
23 | * **Consider your impact on others.** Your work will be used by others, and you depend on the work of others. We expect community members to be considerate and to balance their self-interest with the communal interest.
24 | * **Use words carefully.** We may not understand intent when you say something ironic. Often, people will misinterpret sarcasm in online communications. We ask community members to communicate plainly.
25 | * **Leave with class.** When you wish to resign from participating in this project for any reason, you are free to fork the code and create a competitive project. Open Source explicitly allows this. Your exit should not be dramatic or bitter.
26 | 
27 | ### Unacceptable Behaviors
28 | Participants remain in good standing when they do not engage in misconduct or harassment (some examples follow). We do not list all forms of harassment, nor imply some forms of harassment are not worthy of action. Any participant who *feels* harassed or *observes* harassment, should report the incident to the Response Team.
29 | * **Don't be a bigot.** Calling out project members by their identity or background in a negative or insulting manner. This includes, but is not limited to, slurs or insinuations related to protected or suspect classes e.g. race, color, citizenship, national origin, political belief, religion, sexual orientation, gender identity and expression, age, size, culture, ethnicity, genetic features, language, profession, national minority status, mental or physical ability.
30 | * **Don't insult.** Insulting remarks about a person’s lifestyle practices.
31 | * **Don't dox.** Revealing private information about other participants without explicit permission.
32 | * **Don't intimidate.** Threats of violence or intimidation of any project member.
33 | * **Don't creep.** Unwanted sexual attention or content unsuited for the subject of this project.
34 | * **Don't inflame.** We ask that victims of harassment not address their grievances in the public forum, as this often intensifies the problem. Report it, and let us address it off-line.
35 | * **Don't disrupt.** Sustained disruptions in a discussion.
36 | 37 | ### Reporting Issues 38 | If you experience or witness misconduct, or have any other concerns about the conduct of members of this project, please report it by contacting our Response Team at opensource-conduct@verizonmedia.com who will handle your report with discretion. Your report should include: 39 | * Your preferred contact information. We cannot process anonymous reports. 40 | * Names (real or usernames) of those involved in the incident. 41 | * Your account of what occurred, and if the incident is ongoing. Please provide links to or transcripts of the publicly available records (e.g. a mailing list archive or a public IRC logger), so that we can review it. 42 | * Any additional information that may be helpful to achieve resolution. 43 | 44 | After filing a report, a representative will contact you directly to review the incident and ask additional questions. If a member of the Verizon Media Response Team is named in an incident report, that member will be recused from handling your incident. If the complaint originates from a member of the Response Team, it will be addressed by a different member of the Response Team. We will consider reports to be confidential for the purpose of protecting victims of abuse. 45 | 46 | ### Scope 47 | Verizon Media will assign a Response Team member with admin rights on the project and legal rights on the project copyright. The Response Team is empowered to restrict some privileges to the project as needed. Since this project is governed by an open source license, any participant may fork the code under the terms of the project license. The Response Team’s goal is to preserve the project if possible, and will restrict or remove participation from those who disrupt the project. 48 | 49 | This code does not replace the terms of service or acceptable use policies that are provided by the websites used to support this community. Nor does this code apply to communications or actions that take place outside of the context of this community. Many participants in this project are also subject to codes of conduct based on their employment. This code is a social-contract that informs participants of our social expectations. It is not a terms of service or legal contract. 50 | 51 | ## License and Acknowledgment. 52 | This text is shared under the [CC-BY-4.0 license](https://creativecommons.org/licenses/by/4.0/). This code is based on a study conducted by the [TODO Group](https://todogroup.org/) of many codes used in the open source community. If you have feedback about this code, contact our Response Team at the address listed above. 53 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # Copyright Verizon Media 2 | # This project is licensed under the MIT. See license in project root for terms. 
3 | 4 | import tensorflow as tf 5 | import numpy as np 6 | import pandas as pd 7 | from os.path import join 8 | import logging 9 | import pickle 10 | 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | def write_weight(outfile, weight, rec): 15 | if weight != 0.0: 16 | outfile.write(rec) 17 | 18 | 19 | def build_feature_columns(feature_names, feature_types, categorical_feature_counts): 20 | features = [] 21 | for feature_name in feature_names: 22 | if feature_name == 'label' or feature_name == 'tag': 23 | pass 24 | elif feature_types[feature_name] == 'CATEGORICAL': 25 | f1 = tf.feature_column.categorical_column_with_identity(feature_name, categorical_feature_counts[feature_name], default_value=0) 26 | features.append(f1) 27 | elif feature_types[feature_name] == 'NUMERIC': 28 | features.append(tf.feature_column.numeric_column(feature_name, default_value=0.0)) 29 | return features 30 | 31 | 32 | def metric_auc(labels, predictions): 33 | return { 34 | 'auc_precision_recall': tf.metrics.auc( 35 | labels=labels, predictions=predictions['logistic'], num_thresholds=200, 36 | curve='PR', summation_method='careful_interpolation') 37 | } 38 | 39 | 40 | def export_model(model, model_export_path, feature_names, feature_types, categorical_feature_counts): 41 | return model.export_savedmodel(model_export_path, 42 | tf.estimator.export.build_parsing_serving_input_receiver_fn( 43 | tf.feature_column.make_parse_example_spec(build_feature_columns( 44 | feature_names, feature_types, categorical_feature_counts 45 | )), 46 | ), as_text=True) 47 | 48 | 49 | def save_model_weights(model, model_weights_dir, feature_types): 50 | if not tf.gfile.Exists(model_weights_dir): 51 | tf.gfile.MkDir(model_weights_dir) 52 | 53 | print('Saving model weights to {}'.format(model_weights_dir)) 54 | with tf.gfile.Open(join(model_weights_dir, 'model.txt'), 'w+') as modelfile: 55 | dense_weight = model.get_variable_value('dense/kernel')[0][1] 56 | bias = model.get_variable_value('linear_model/bias_weights') * dense_weight + model.get_variable_value('dense/bias')[1] 57 | # format is feature_column_name, feature_idx, weights 58 | write_weight(modelfile, bias, '{},{},{}\n'.format('BiasWeight', 0, bias)) 59 | for f_order, (feature_column, feature_type) in enumerate(feature_types.items()): 60 | if feature_type == 'CATEGORICAL': 61 | t = model.get_variable_value('linear_model/' + feature_column + '/weights') 62 | for ft_idx, wt in enumerate(t): 63 | write_weight(modelfile, wt[0] * dense_weight, '{},{},{}\n'.format(feature_column, ft_idx, wt[0] * dense_weight)) 64 | else: 65 | write_weight(modelfile, dense_weight, '{},{},{}\n'.format(feature_column, 0, dense_weight)) 66 | 67 | 68 | def generate_field_dic(path, field): 69 | '''read and create the field feature index - value dictionary''' 70 | df = pd.read_csv(path, sep=chr(1), header=None) 71 | df.columns = ['field', 'value', 'index'] 72 | df_line = df[df['field'] == field] 73 | field_dic = {} 74 | for index, row in df_line.iterrows(): 75 | # access data using column names 76 | field_dic[row['index']] = row['value'] 77 | return field_dic 78 | 79 | 80 | def save_FwFM_model_weights(model, model_weights_dir, categorical_feature_counts, feature_types, feature_dir, embed_dim): 81 | if not tf.gfile.Exists(model_weights_dir): 82 | tf.gfile.MkDir(model_weights_dir) 83 | num_cat_fields = 0 84 | num_features = [] 85 | field_order = [] 86 | for feature_name, feature_type in feature_types.items(): 87 | if feature_name in {'label', 'tag'} or feature_type == 'NUMERIC': 88 | continue 89 | else: 90 
| num_cat_fields += 1 91 | num_features.append(categorical_feature_counts[feature_name]) 92 | field_order.append(feature_name) 93 | print('Number of categorical fields are: {}'.format(num_cat_fields)) 94 | print('Number of features are: {}'.format(sum(num_features))) 95 | 96 | print('Saving model weights to {}'.format(model_weights_dir)) 97 | 98 | with tf.gfile.Open(join(model_weights_dir, 'model.txt'), 'w+') as modelfile: 99 | #Write the header to the file 100 | modelfile.write('==== Header ====\n') 101 | modelfile.write('model_version_id: 100\n') 102 | modelfile.write('n_field: {}\n'.format(num_cat_fields)) 103 | modelfile.write('has_bias: 0\n') 104 | modelfile.write('embedding_dim: {}\n'.format(embed_dim)) 105 | modelfile.write('total_number_of_features: {}\n'.format(sum(num_features))) 106 | 107 | # Check number of interaction pairs (used for model_pruned version) 108 | wp = model.get_variable_value('w_p') 109 | num_pair = 0 110 | for i in range(len(wp)): 111 | if wp[i][0] != 0.0: 112 | num_pair += 1 113 | modelfile.write('total_number_of_interaction_pairs: {}\n'.format(num_pair)) 114 | 115 | # Write the default weight of each field 116 | for feature_column, feature_type in feature_types.items(): 117 | if feature_column in {'label', 'tag'} or feature_type == 'NUMERIC': 118 | continue 119 | else: 120 | default_vector = model.get_variable_value('v_%s' % feature_column)[0] 121 | modelfile.write('default_weights: {}\u0001{}\u0001'.format(feature_column, categorical_feature_counts[feature_column])) 122 | N = len(default_vector) 123 | for i in range(N - 1): 124 | modelfile.write('{},'.format(default_vector[i])) 125 | modelfile.write('{}\n'.format(default_vector[-1])) 126 | 127 | # Write the model weights to the file 128 | modelfile.write('==== Model ====\n') 129 | global_interception = model.get_variable_value('b') 130 | modelfile.write('b1: {}\n'.format(global_interception[0])) 131 | 132 | linear_term = model.get_variable_value('w_l') 133 | field_scalar = model.get_variable_value('w_p') 134 | assert(len(field_scalar) == num_cat_fields * (num_cat_fields - 1) / 2) 135 | 136 | # First write the feature embedding 137 | for feature_column, feature_type in feature_types.items(): 138 | if feature_column in {'label', 'tag'} or feature_type == 'NUMERIC': 139 | continue 140 | else: 141 | f_embedding = model.get_variable_value('v_%s' % feature_column)[1:] 142 | field_dic = generate_field_dic(feature_dir, feature_column) 143 | for ft_idx, wt in enumerate(f_embedding): 144 | ft_idx += 1 145 | f_value = field_dic[ft_idx] 146 | #remove the prefix of field_name 147 | m = len(feature_column) 148 | if feature_column == 'subdomain' or feature_column == 'page_tld' or feature_column == 'app_name': 149 | modelfile.write('{}\u0001{}: '.format(feature_column, f_value)) 150 | else: 151 | modelfile.write('{}\u0001{}: '.format(feature_column, f_value[m+1:])) 152 | N = len(wt) 153 | for i in range(N-1): 154 | modelfile.write('{},'.format(wt[i])) 155 | modelfile.write('{}\n'.format(wt[-1])) 156 | 157 | #Then for linear embedding 158 | for i in range(len(field_order)): 159 | modelfile.write('w_l\u0001{}: '.format(field_order[i])) 160 | for j in range(embed_dim - 1): 161 | modelfile.write('{},'.format(linear_term[i * embed_dim + j][0])) 162 | modelfile.write('{}\n'.format(linear_term[i * embed_dim + embed_dim - 1][0])) 163 | 164 | #Finally for field interaction: 165 | num_pair = 0 166 | for i in range(len(field_order)): 167 | for j in range(i+1, len(field_order)): 168 | if field_scalar[num_pair][0] != 0.0: 169 | 
modelfile.write('r\u0001{}\u0001{}: {}\n'.format(field_order[i], field_order[j], field_scalar[num_pair][0])) 170 | num_pair += 1 171 | 172 | 173 | def save_FmFM_model_weights(model, model_weights_dir, categorical_feature_counts, feature_types, feature_dir, 174 | embed_dim): 175 | if not tf.gfile.Exists(model_weights_dir): 176 | tf.gfile.MkDir(model_weights_dir) 177 | num_cat_fields = 0 178 | num_features = [] 179 | field_order = [] 180 | for feature_name, feature_type in feature_types.items(): 181 | if feature_name in {'label', 'tag'} or feature_type == 'NUMERIC': 182 | continue 183 | else: 184 | num_cat_fields += 1 185 | num_features.append(categorical_feature_counts[feature_name]) 186 | field_order.append(feature_name) 187 | field_order = sorted(field_order) 188 | print('Number of categorical fields are: {}'.format(num_cat_fields)) 189 | print('Number of features are: {}'.format(sum(num_features))) 190 | 191 | print('Saving model weights to {}'.format(model_weights_dir)) 192 | 193 | with tf.gfile.Open(join(model_weights_dir, 'model.txt'), 'w') as modelfile: 194 | # Write the header to the file 195 | modelfile.write('==== Header ====\n') 196 | modelfile.write('model_version_id: 100\n') 197 | modelfile.write('n_field: {}\n'.format(num_cat_fields)) 198 | modelfile.write('has_bias: 0\n') 199 | modelfile.write('embedding_dim: {}\n'.format(embed_dim)) 200 | modelfile.write('total_number_of_features: {}\n'.format(sum(num_features))) 201 | 202 | # Check number of interaction pairs (used for model_pruned version) 203 | # wp = model.get_variable_value('w_p') 204 | # num_pair = 0 205 | # for i in range(len(wp)): 206 | # if wp[i][0] != 0.0: 207 | # num_pair += 1 208 | # modelfile.write('total_number_of_interaction_pairs: {}\n'.format(num_pair)) 209 | 210 | # Write the default weight of each field 211 | for feature_column, feature_type in feature_types.items(): 212 | if feature_column in {'label', 'tag'} or feature_type == 'NUMERIC': 213 | continue 214 | else: 215 | default_vector = model.get_variable_value('v_%s' % feature_column)[0] 216 | modelfile.write('default_weights: {}\u0001{}\u0001'.format(feature_column, 217 | categorical_feature_counts[feature_column])) 218 | N = len(default_vector) 219 | for i in range(N - 1): 220 | modelfile.write('{},'.format(default_vector[i])) 221 | modelfile.write('{}\n'.format(default_vector[-1])) 222 | 223 | # Write the model weights to the file 224 | modelfile.write('==== Model ====\n') 225 | global_interception = model.get_variable_value('b') 226 | modelfile.write('b1: {}\n'.format(global_interception[0])) 227 | 228 | # linear_term = model.get_variable_value('w_l') 229 | # field_matrices = model.get_variable_value('w_p') 230 | field_matrices = {} 231 | for idx_l, feat_l in enumerate(field_order): 232 | for idx_r, feat_r in enumerate(field_order): 233 | if idx_r <= idx_l: 234 | continue 235 | w = model.get_variable_value('%s_%s' % (feat_l, feat_r)) 236 | modelfile.write( 237 | 'r\u0001%s\u0001%s:\t%s\n' % (feat_l, feat_r, pickle.dumps(w))) 238 | field_matrices['%s X %s' % (feat_l, feat_r)] = w 239 | pickle.dump(field_matrices, open(join(model_weights_dir, 'field_matrices.pickle'), 'wb')) 240 | 241 | # Then for linear embedding 242 | # for i in range(len(field_order)): 243 | # modelfile.write('w_l\u0001{}: '.format(field_order[i])) 244 | # for j in range(embed_dim - 1): 245 | # modelfile.write('{},'.format(linear_term[i * embed_dim + j][0])) 246 | # modelfile.write('{}\n'.format(linear_term[i * embed_dim + embed_dim - 1][0])) 247 | 248 | #Then for linear embedding 
249 | for i in range(len(field_order)): 250 | modelfile.write('w_\u0001{}:\t'.format(field_order[i])) 251 | modelfile.write('{}\n'.format(model.get_variable_value('w_%s' % field_order[i]).flatten() ) ) 252 | 253 | # First write the feature embedding 254 | for feature_column, feature_type in feature_types.items(): 255 | if feature_column in {'label', 'tag'} or feature_type == 'NUMERIC': 256 | continue 257 | else: 258 | f_embedding = model.get_variable_value('v_%s' % feature_column)[1:] 259 | field_dic = generate_field_dic(feature_dir, feature_column) 260 | for ft_idx, wt in enumerate(f_embedding): 261 | ft_idx += 1 262 | f_value = field_dic[ft_idx] 263 | # remove the prefix of field_name 264 | m = len(feature_column) 265 | if feature_column == 'subdomain' or feature_column == 'page_tld' or feature_column == 'app_name': 266 | modelfile.write('{}\u0001{}: '.format(feature_column, f_value)) 267 | else: 268 | modelfile.write('{}\u0001{}: '.format(feature_column, f_value[m + 1:])) 269 | N = len(wt) 270 | for i in range(N - 1): 271 | modelfile.write('{},'.format(wt[i])) 272 | modelfile.write('{}\n'.format(wt[-1])) 273 | 274 | def dense_to_sparse(dense_tensor, n_dim): 275 | zero_t = tf.zeros_like(dense_tensor) 276 | dense_final = tf.where(tf.greater_equal(dense_tensor, tf.ones_like(dense_tensor) * n_dim), zero_t, dense_tensor) 277 | indices = tf.to_int64( 278 | tf.transpose([tf.range(tf.shape(dense_tensor)[0]), tf.reshape(dense_final, [-1])])) 279 | values = tf.ones_like(dense_tensor, dtype=tf.float32) 280 | shape = [tf.shape(dense_tensor)[0], n_dim] 281 | return tf.SparseTensor( 282 | indices=indices, 283 | values=values, 284 | dense_shape=shape 285 | ) 286 | 287 | 288 | def LR(features, labels, mode, params): 289 | 290 | X = [] 291 | W = [] 292 | X_numeric = [] 293 | W_numeric = [] 294 | feat_cnt = params["categorical_feature_counts"] 295 | 296 | for f_name in features.keys(): 297 | if f_name in {'label', 'tag'}: #Not used in model training 298 | continue 299 | elif params["feature_types"][f_name] == 'NUMERIC': 300 | X_numeric.append(features[f_name]) 301 | w = tf.get_variable('w0_%s' % f_name, shape=[1, 2]) 302 | W_numeric.append(w) 303 | else: 304 | sparse_t = dense_to_sparse(features[f_name], feat_cnt[f_name] + 1) 305 | X.append(sparse_t) 306 | w = tf.get_variable('w0_%s' % f_name, shape=[feat_cnt[f_name] + 1, 2]) 307 | W.append(w) 308 | 309 | b = tf.get_variable('b', shape=[2]) 310 | logits = b 311 | for i in range(len(X)): 312 | logits = logits + tf.sparse_tensor_dense_matmul(X[i], W[i]) 313 | for i in range(len(X_numeric)): 314 | logits = logits + tf.matmul(tf.reshape(X_numeric[i], [-1, 1]), W_numeric[i]) 315 | 316 | predicted_classes = tf.argmax(logits, 1) 317 | if mode == tf.estimator.ModeKeys.PREDICT: 318 | predictions = { 319 | 'class_ids': predicted_classes[:, tf.newaxis], 320 | 'probabilities': tf.nn.softmax(logits)[:,1], 321 | 'logits': logits, 322 | } 323 | return tf.estimator.EstimatorSpec(mode, predictions=predictions) 324 | 325 | # Compute loss. 
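# (added note) With class_weights = [1, 1] below, every example receives weight
# 1 and the weighted loss reduces to plain softmax cross-entropy; raising the
# positive-class weight would rebalance the loss toward clicks. This is an
# observation about the existing code, not a repo-provided option.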
326 |     labels = tf.cast(labels, tf.int64)
327 |     one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.int64)
328 | 
329 |     class_weights = tf.constant(([1, 1]), dtype=tf.int64)
330 |     sample_weights = tf.reduce_sum(tf.multiply(one_hot_labels, class_weights), 1)
331 | 
332 |     #loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
333 |     loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits, weights=sample_weights)
334 | 
335 |     l2_loss = tf.nn.l2_loss(b) * params['l2_linear']
336 |     for w in W:
337 |         l2_loss += params['l2_linear'] * tf.nn.l2_loss(w)
338 |     for w_numeric in W_numeric:
339 |         l2_loss += params['l2_linear'] * tf.nn.l2_loss(w_numeric)
340 | 
341 |     # Compute evaluation metrics.
342 |     accuracy = tf.metrics.accuracy(labels=labels,
343 |                                    predictions=predicted_classes,
344 |                                    name='acc_op')
345 |     auc = tf.metrics.auc(labels=labels,
346 |                          predictions=tf.nn.softmax(logits)[:, 1],
347 |                          name='auc_op')
348 |     metric_orig_loss = tf.metrics.mean_tensor(loss,
349 |                                               name='orig_loss_op')
350 |     metric_l2_loss = tf.metrics.mean_tensor(l2_loss,
351 |                                             name='l2_loss_op')
352 |     metrics = {'accuracy': accuracy, 'auc': auc, 'orig_loss': metric_orig_loss, 'l2_loss': metric_l2_loss}
353 | 
354 |     loss += l2_loss
355 | 
356 |     if mode == tf.estimator.ModeKeys.EVAL:
357 |         return tf.estimator.EstimatorSpec(
358 |             mode, loss=loss, eval_metric_ops=metrics)
359 | 
360 |     # Create training op.
361 |     assert mode == tf.estimator.ModeKeys.TRAIN
362 | 
363 |     optimizer = tf.train.AdamOptimizer(params['learning_rate'])
364 |     train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
365 |     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
366 | 
367 | 
368 | def FM(features, labels, mode, params):
369 |     # Factorization Machine: linear terms plus factorized second-order interactions.
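# (added note) The interaction term computed below uses the standard O(k*n) FM
# identity instead of the naive O(k*n^2) double sum:
#     sum_{i<j} <v_i, v_j> x_i x_j
#         = 0.5 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i (v_{i,f} x_i)^2 ]
# In the code, p1 accumulates sum_i v_i x_i and p2 accumulates the per-dimension
# squares, so p_final = 0.5 * reduce_sum(square(p1) - p2, axis=1).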
370 | 371 | X = [] 372 | W = [] 373 | V = [] 374 | X_square = [] 375 | XV = [] 376 | XV_square = [] 377 | feat_cnt = params["categorical_feature_counts"] 378 | 379 | for f_name in features.keys(): 380 | if f_name in {'label', 'tag'} or params["feature_types"][f_name] == 'NUMERIC': 381 | continue 382 | else: 383 | # X.append(one_hot_dense) 384 | sparse_t = dense_to_sparse(features[f_name], feat_cnt[f_name] + 1) 385 | X.append(sparse_t) 386 | w = tf.get_variable('w_%s' % f_name, shape=[feat_cnt[f_name] + 1, 1]) 387 | W.append(w) 388 | v = tf.get_variable('v_%s' % f_name, shape=[feat_cnt[f_name] + 1, params['latent_factor']]) 389 | V.append(v) 390 | 391 | b = tf.get_variable('b', shape=[1]) 392 | logits = b 393 | for i in range(len(X)): 394 | logits = logits + tf.sparse_tensor_dense_matmul(X[i], W[i]) 395 | x_square = tf.SparseTensor(X[i].indices, tf.square(X[i].values), tf.to_int64(tf.shape(X[i]))) 396 | X_square.append(x_square) 397 | xv = tf.sparse_tensor_dense_matmul(X[i], V[i]) 398 | XV.append(xv) 399 | xv_square = tf.sparse_tensor_dense_matmul(x_square, tf.square(V[i])) 400 | XV_square.append(xv_square) 401 | p1 = XV[0] 402 | p2 = XV_square[0] 403 | for i in range(1, len(XV)): 404 | p1 = p1 + XV[i] 405 | p2 = p2 + XV_square[i] 406 | p_final = tf.reshape(0.5 * tf.reduce_sum(tf.square(p1) - p2, 1), [-1, 1]) 407 | 408 | 409 | logits = logits + p_final 410 | 411 | y_prob = tf.sigmoid(logits) 412 | pred_class = tf.cast((y_prob >= 0.5), tf.bool) 413 | if mode == tf.estimator.ModeKeys.PREDICT: 414 | predictions = { 415 | 'class_ids': pred_class, 416 | 'probabilities': y_prob, 417 | 'logits': logits, 418 | } 419 | return tf.estimator.EstimatorSpec(mode, predictions=predictions) 420 | 421 | # Compute loss. 422 | labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1]) 423 | loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) 424 | loss = tf.math.reduce_mean(loss) 425 | 426 | l2_loss = tf.nn.l2_loss(b) * params['l2_linear'] 427 | for w in W: 428 | l2_loss += params['l2_linear'] * tf.nn.l2_loss(w) 429 | for v in V: 430 | l2_loss += params['l2_latent'] * tf.nn.l2_loss(v) 431 | 432 | 433 | # Compute evaluation metrics. 434 | accuracy = tf.metrics.accuracy(labels=labels, 435 | predictions=pred_class, 436 | name='acc_op1') 437 | auc = tf.metrics.auc(labels=labels, 438 | predictions=y_prob, 439 | name='auc_op1') 440 | 441 | metric_orig_loss = tf.metrics.mean(loss, name='orig_loss_op') 442 | metric_l2_loss = tf.metrics.mean(l2_loss, name='l2_loss_op') 443 | metrics = {'accuracy': accuracy, 'auc': auc, 'orig_loss': metric_orig_loss, 'l2_loss': metric_l2_loss} 444 | 445 | loss += l2_loss 446 | 447 | if mode == tf.estimator.ModeKeys.EVAL: 448 | return tf.estimator.EstimatorSpec( 449 | mode, loss=loss, eval_metric_ops=metrics) 450 | 451 | # Create training op. 
452 |     assert mode == tf.estimator.ModeKeys.TRAIN
453 | 
454 |     optimizer = tf.train.AdamOptimizer(params['learning_rate'])
455 |     train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
456 |     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
457 | 
458 | 
459 | def FFM(features, labels, mode, params):
460 |     X = []
461 |     V = []
462 |     W = []
463 |     F = params['latent_factor']
464 |     feat_cnt = params["categorical_feature_counts"]
465 | 
466 |     M = 0  # Number of categorical features
467 |     for f_name in features.keys():
468 |         if f_name in {'label', 'tag'} or params["feature_types"][f_name] == 'NUMERIC':
469 |             continue
470 |         else:
471 |             M += 1
472 |             # X.append(one_hot_dense)
473 |             sparse_t = dense_to_sparse(features[f_name], feat_cnt[f_name] + 1)
474 |             X.append(sparse_t)
475 |             w = tf.get_variable('w_%s' % f_name, shape=[feat_cnt[f_name] + 1, 1])  # [Mi, 1]
476 |             W.append(w)  # linear term
477 | 
478 |     for f_name in features.keys():
479 |         if f_name in {'label', 'tag'} or params["feature_types"][f_name] == 'NUMERIC':
480 |             continue
481 |         else:
482 |             v = tf.get_variable('v_%s' % f_name, shape=[feat_cnt[f_name] + 1, M * F])  # [Mi, M*F]
483 |             V.append(v)
484 | 
485 |     # global intercept
486 |     b = tf.get_variable('b', shape=[1])
487 | 
488 |     # linear term
489 |     xw = [tf.sparse_tensor_dense_matmul(X[i], W[i]) for i in range(M)]
490 | 
491 |     # field tensor list with M of [N, M * F]
492 |     xv = [tf.sparse_tensor_dense_matmul(X[i], V[i]) for i in range(M)]
493 | 
494 |     # concat to matrix [N, M^2 * F]
495 |     l = tf.concat([xv[i] for i in range(M)], 1)
496 |     xw1 = tf.concat([xw[i] for i in range(M)], 1)
497 | 
498 |     # Create field index
499 |     index_left = []
500 |     index_right = []
501 | 
502 |     for i in range(M):
503 |         for j in range(M):
504 |             if i != j:
505 |                 index_left.append(i * M + j)
506 |                 index_right.append(j * M + i)
507 | 
508 |     # reshape l_ to [N, M^2, F]
509 |     l_ = tf.reshape(l, [-1, M*M, F])
510 |     l_left = tf.gather(l_, index_left, axis=1)
511 |     l_right = tf.gather(l_, index_right, axis=1)
512 |     # element-wise multiplication of [N, M(M-1), F]
513 |     p_full = tf.multiply(l_left, l_right)
514 |     # Reduce to [N, M(M-1)]
515 |     p = tf.reduce_sum(p_full, 2)
516 | 
517 |     # Reduce to [N, 1]
518 |     p = tf.reduce_sum(p, 1, keepdims=True)
519 | 
520 |     logits = tf.reduce_sum(xw1, 1, keepdims=True) + b + p
521 | 
522 |     y_prob = tf.sigmoid(logits)
523 |     pred_class = tf.cast((y_prob >= 0.5), tf.bool)
524 |     if mode == tf.estimator.ModeKeys.PREDICT:
525 |         predictions = {
526 |             'class_ids': pred_class,
527 |             'probabilities': y_prob,
528 |             'logits': logits,
529 |         }
530 |         return tf.estimator.EstimatorSpec(mode, predictions=predictions)
531 | 
532 |     # Compute loss.
533 |     labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1])
534 |     loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
535 |     loss = tf.math.reduce_mean(loss)
536 | 
537 |     l2_loss = 0
538 |     for w in W:
539 |         l2_loss += tf.nn.l2_loss(w) * params['l2_linear']
540 |     for v in V:
541 |         l2_loss += tf.nn.l2_loss(v) * params['l2_latent']
542 | 
543 |     # Compute evaluation metrics.
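    # Note on the interaction block above: each field holds M field-aware
    # embedding tables of width F (one [Mi, M*F] variable per field), and the
    # index_left/index_right gather pairs embedding j-of-field-i with embedding
    # i-of-field-j for every ordered pair i != j, so each unordered pair is
    # summed twice; this only rescales the pairwise term by a constant factor.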
544 |     accuracy = tf.metrics.accuracy(labels=labels,
545 |                                    predictions=pred_class,
546 |                                    name='acc_op1')
547 |     auc = tf.metrics.auc(labels=labels,
548 |                          predictions=y_prob,
549 |                          name='auc_op1')
550 | 
551 |     metric_orig_loss = tf.metrics.mean(loss, name='orig_loss_op')
552 |     metric_l2_loss = tf.metrics.mean(l2_loss, name='l2_loss_op')
553 |     metrics = {'accuracy': accuracy, 'auc': auc, 'orig_loss': metric_orig_loss, 'l2_loss': metric_l2_loss}
554 | 
555 |     loss += l2_loss
556 | 
557 |     if mode == tf.estimator.ModeKeys.EVAL:
558 |         return tf.estimator.EstimatorSpec(
559 |             mode, loss=loss, eval_metric_ops=metrics)
560 | 
561 |     # Create training op.
562 |     assert mode == tf.estimator.ModeKeys.TRAIN
563 | 
564 |     optimizer = tf.train.AdamOptimizer(params['learning_rate'])
565 |     train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
566 |     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
567 | 
568 | 
569 | def FwFM(features, labels, mode, params):
570 |     V = []
571 |     xv_cat = []
572 |     F = params['latent_factor']
573 |     feat_cnt = params["categorical_feature_counts"]
574 | 
575 |     M = 0  # Number of categorical features
576 |     for f_name in features.keys():
577 |         if f_name in {'label', 'tag'} or params["feature_types"][f_name] == 'NUMERIC':
578 |             continue
579 |         else:
580 |             M += 1
581 |             sparse_t = dense_to_sparse(features[f_name], feat_cnt[f_name] + 1)
582 |             v = tf.get_variable('v_%s' % f_name, dtype=tf.float32,
583 |                                 initializer=tf.random_normal(shape=[feat_cnt[f_name] + 1, F], mean=0.0, stddev=0.01))
584 |             V.append(v)
585 |             xv_cat.append(tf.sparse_tensor_dense_matmul(sparse_t, v))  # tf.nn.l2_normalize(v, axis=-1)))
586 | 
587 |     # linear term
588 |     w_l = tf.get_variable('w_l', dtype=tf.float32,
589 |                           initializer=tf.random_normal(shape=[M * F, 1], mean=0.0, stddev=1.0))
590 |     # field-pair scalar weights
591 |     w_p = tf.get_variable('w_p', dtype=tf.float32,
592 |                           initializer=tf.random_normal(shape=[int(M * (M-1)/2), 1], mean=0.0, stddev=1.0))
593 |     # global intercept
594 |     b = tf.get_variable('b', shape=[1], dtype=tf.float32)
595 | 
596 |     # concat to matrix [N, M * F]
597 |     l_cat = tf.concat([xv_cat[i] for i in range(M)], 1)
598 | 
599 |     l = l_cat
600 |     index_left = []
601 |     index_right = []
602 | 
603 |     for i in range(M):
604 |         for j in range(i + 1, M):
605 |             index_left.append(i)
606 |             index_right.append(j)
607 | 
608 |     # reshape l_ to [N, M, F]
609 |     l_ = tf.reshape(l, [-1, M, F])
610 |     l_left = tf.gather(l_, index_left, axis=1)
611 |     l_right = tf.gather(l_, index_right, axis=1)
612 | 
613 |     # element-wise multiplication of [N, M(M-1)/2, F]
614 |     p_full = tf.multiply(l_left, l_right)
615 |     # Reduce to [N, M(M-1)/2]
616 |     p = tf.reduce_sum(p_full, 2)
617 | 
618 |     # Reduce to [N, 1]
619 |     p = tf.matmul(p, w_p)
620 | 
621 |     logits = tf.matmul(l_cat, w_l) + b + p
622 | 
623 |     y_prob = tf.sigmoid(logits)
624 |     pred_class = tf.cast((y_prob >= 0.5), tf.bool)
625 |     if mode == tf.estimator.ModeKeys.PREDICT:
626 |         predictions = {
627 |             'class_ids': pred_class,
628 |             'probabilities': y_prob,
629 |             'logits': logits,
630 |         }
631 |         return tf.estimator.EstimatorSpec(mode, predictions=predictions)
632 | 
633 |     # Compute loss.
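    # Note on FwFM above: every field pair (i, j) shares one learned scalar
    # r_{F(i),F(j)} (the rows of w_p), so the interaction term is
    #   sum_{i<j} r_{F(i),F(j)} * <v_i, v_j>,
    # and the linear term reuses the embeddings through w_l rather than
    # keeping separate per-feature weights.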
634 |     labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1])
635 | 
636 |     loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
637 |     loss = tf.math.reduce_mean(loss)
638 | 
639 |     # L2_loss
640 |     l2_loss = tf.nn.l2_loss(w_p) * params['l2_r'] \
641 |               + tf.nn.l2_loss(w_l) * params['l2_linear'] \
642 |               + sum([tf.nn.l2_loss(v) * params['l2_latent'] for v in V])
643 | 
644 |     # Compute evaluation metrics.
645 |     accuracy = tf.metrics.accuracy(labels=labels,
646 |                                    predictions=pred_class,
647 |                                    name='acc_op1')
648 |     auc = tf.metrics.auc(labels=labels,
649 |                          predictions=y_prob,
650 |                          name='auc_op1')
651 | 
652 |     metric_orig_loss = tf.metrics.mean(loss, name='orig_loss_op')
653 |     metric_l2_loss = tf.metrics.mean(l2_loss, name='l2_loss_op')
654 |     metrics = {'accuracy': accuracy, 'auc': auc, 'orig_loss': metric_orig_loss, 'l2_loss': metric_l2_loss}
655 | 
656 |     loss += l2_loss
657 | 
658 |     if mode == tf.estimator.ModeKeys.EVAL:
659 |         return tf.estimator.EstimatorSpec(
660 |             mode, loss=loss, eval_metric_ops=metrics)
661 | 
662 |     # Create training op.
663 |     assert mode == tf.estimator.ModeKeys.TRAIN
664 | 
665 |     # variables_to_restore = tf.contrib.get_variables_to_restore()
666 |     optimizer = tf.train.AdamOptimizer(params['learning_rate'])
667 |     train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
668 |     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
669 | 
670 | 
671 | def FvFM(features, labels, mode, params):
672 |     V = []
673 |     xv_cat = []
674 |     F = params['latent_factor']
675 |     feat_cnt = params["categorical_feature_counts"]
676 | 
677 |     M = 0  # Number of categorical features
678 |     for f_name in features.keys():
679 |         if f_name in {'label', 'tag'} or params["feature_types"][f_name] == 'NUMERIC':
680 |             continue
681 |         else:
682 |             M += 1
683 |             sparse_t = dense_to_sparse(features[f_name], feat_cnt[f_name] + 1)
684 |             v = tf.get_variable('v_%s' % f_name, dtype=tf.float32,
685 |                                 initializer=tf.random_normal(shape=[feat_cnt[f_name] + 1, F], mean=0.0, stddev=0.01))
686 |             V.append(v)
687 |             xv_cat.append(tf.sparse_tensor_dense_matmul(sparse_t, v))
688 | 
689 |     # linear term
690 |     w_l = tf.get_variable('w_l', dtype=tf.float32,
691 |                           initializer=tf.random_normal(shape=[M * F, 1], mean=0.0, stddev=1.0))
692 |     # field-pair weight vectors, [P, F]; a [P, 1] shape would reduce FvFM to FwFM
693 |     w_p = tf.get_variable('w_p', dtype=tf.float32,
694 |                           initializer=tf.random_normal(shape=[int(M * (M - 1) / 2), F], mean=0.0, stddev=1.0))
695 |     # global intercept
696 |     b = tf.get_variable('b', shape=[1], dtype=tf.float32)
697 | 
698 |     # concat to matrix [N, M * F]
699 |     l_cat = tf.concat([xv_cat[i] for i in range(M)], 1)
700 | 
701 |     l = l_cat
702 |     index_left = []
703 |     index_right = []
704 | 
705 |     for i in range(M):
706 |         for j in range(i + 1, M):
707 |             index_left.append(i)
708 |             index_right.append(j)
709 | 
710 |     # reshape l_ to [N, M, F]
711 |     l_ = tf.reshape(l, [-1, M, F])
712 |     l_left = tf.gather(l_, index_left, axis=1)
713 |     l_right = tf.gather(l_, index_right, axis=1)
714 | 
715 |     # element-wise multiplication of [N, M(M-1)/2, F]
716 |     p_full = tf.multiply(l_left, l_right)
717 |     p_full = tf.multiply(p_full, w_p)
718 | 
719 |     # Reduce to [N, M(M-1)/2]
720 |     p = tf.reduce_sum(p_full, 2)
721 | 
722 |     # Reduce to [N, 1]
723 |     p = tf.reduce_sum(p, 1, keepdims=True)
724 | 
725 |     logits = tf.matmul(l_cat, w_l) + b + p
726 | 
727 |     y_prob = tf.sigmoid(logits)
728 |     pred_class = tf.cast((y_prob >= 0.5), tf.bool)
729 |     if mode == tf.estimator.ModeKeys.PREDICT:
730 |         predictions = {
731 |             'class_ids': pred_class,
732 |             'probabilities': y_prob,
733 |             'logits': logits,
734 |         }
735 |         return tf.estimator.EstimatorSpec(mode, predictions=predictions)
736 | 
737 |     # Compute loss.
738 |     labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1])
739 |     loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
740 |     loss = tf.math.reduce_mean(loss)
741 | 
742 |     # L2_loss
743 |     l2_loss = tf.nn.l2_loss(w_l) * params['l2_linear'] \
744 |               + tf.nn.l2_loss(w_p) * params['l2_r'] \
745 |               + sum([tf.nn.l2_loss(v) * params['l2_latent'] for v in V])
746 | 
747 |     # Compute evaluation metrics.
748 |     accuracy = tf.metrics.accuracy(labels=labels,
749 |                                    predictions=pred_class,
750 |                                    name='acc_op1')
751 |     auc = tf.metrics.auc(labels=labels,
752 |                          predictions=y_prob,
753 |                          name='auc_op1')
754 | 
755 |     metric_orig_loss = tf.metrics.mean(loss, name='orig_loss_op')
756 |     metric_l2_loss = tf.metrics.mean(l2_loss, name='l2_loss_op')
757 |     metrics = {'accuracy': accuracy, 'auc': auc, 'orig_loss': metric_orig_loss, 'l2_loss': metric_l2_loss}
758 | 
759 |     loss += l2_loss
760 | 
761 |     if mode == tf.estimator.ModeKeys.EVAL:
762 |         return tf.estimator.EstimatorSpec(
763 |             mode, loss=loss, eval_metric_ops=metrics)
764 | 
765 |     # Create training op.
766 |     assert mode == tf.estimator.ModeKeys.TRAIN
767 | 
768 |     optimizer = tf.train.AdamOptimizer(params['learning_rate'])
769 |     train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
770 |     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
771 | 
772 | 
773 | def FmFM(features, labels, mode, params):
774 |     # Field-matrixed FM: a learned [dim_l, dim_r] matrix projects each field-pair interaction.
775 |     V = []
776 |     W = []
777 |     xv = {}
778 | 
779 |     feat_cnt = params["categorical_feature_counts"]
780 |     feat_dims = params['feat_dims']
781 |     print('Field Dim - {}'.format(feat_dims))
782 |     feat_dim_sum = sum(feat_dims.values())
783 |     cross_fields = params['cross_fields']
784 | 
785 |     for f_name in features.keys():
786 |         if f_name in {'label', 'tag'} or params["feature_types"][f_name] == 'NUMERIC':
787 |             continue
788 |         sparse_t = dense_to_sparse(features[f_name], feat_cnt[f_name] + 1)
789 |         v = tf.get_variable('v_%s' % f_name, dtype=tf.float32, initializer=tf.random_normal(
790 |             shape=[feat_cnt[f_name] + 1, feat_dims[f_name]], mean=0.0, stddev=0.01))
791 |         V.append(v)
792 |         xv[f_name] = tf.sparse_tensor_dense_matmul(sparse_t, v)
793 | 
794 |     # linear term
795 |     w_l = tf.get_variable('w_l', dtype=tf.float32,
796 |                           initializer=tf.random_normal(shape=[feat_dim_sum, 1], mean=0.0, stddev=1.0))
797 | 
798 |     b = tf.get_variable('b', shape=[1], dtype=tf.float32)
799 | 
800 |     # concat to matrix [N, feat_dim_sum]
801 |     l_cat = tf.concat(list(xv.values()), 1)
802 | 
803 |     l_left, l_right = [], []
804 |     cross_fields_selected = []
805 | 
806 |     field_order = sorted(feat_dims.items(), key=lambda x: (-x[1], x[0]))
807 |     for idx_l, (feat_l, dim_l) in enumerate(field_order):
808 |         for idx_r, (feat_r, dim_r) in enumerate(field_order[idx_l+1:]):
809 |             idx_r += (idx_l + 1)
810 |             name = '%s_%s' % (feat_l, feat_r)
811 |             name_alt = '%s_%s' % (feat_r, feat_l)
812 |             if cross_fields and name not in cross_fields and name_alt not in cross_fields:
813 |                 continue
814 |             cross_fields_selected.append(name)
815 |             w = tf.get_variable(name, dtype=tf.float32, initializer=tf.random_normal(
816 |                 shape=[dim_l, dim_r], mean=0.0, stddev=0.01))
817 |             W.append(w)
818 |             l_left.append(tf.matmul(xv[feat_l], w))
819 |             l_right.append(xv[feat_r])
820 |     print(cross_fields_selected)
821 |     l_left = tf.concat(l_left, 1)
822 |     l_right = tf.concat(l_right, 1)
823 | 
824 |     p = tf.multiply(l_left, l_right)
825 | 
826 |     # Reduce to [N, 1]
827 |     p = tf.reduce_sum(p, 1, keepdims=True)
828 | 
829 |     # Add the linear part and bias
830 |     logits = tf.matmul(l_cat, w_l) + b + p
831 | 
832 |     y_prob = tf.sigmoid(logits)
833 |     pred_class = tf.cast((y_prob >= 0.5), tf.bool)
834 |     if mode == tf.estimator.ModeKeys.PREDICT:
835 |         predictions = {
836 |             'class_ids': pred_class,
837 |             'probabilities': y_prob,
838 |             'logits': logits,
839 |         }
840 |         return tf.estimator.EstimatorSpec(mode, predictions=predictions)
841 | 
842 |     # Compute loss.
843 |     labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1])
844 |     loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
845 |     loss = tf.math.reduce_mean(loss)
846 | 
847 |     # L2_loss
848 |     l2_loss = tf.nn.l2_loss(w_l) * params['l2_linear'] \
849 |               + sum([tf.nn.l2_loss(v) * params['l2_latent'] for v in V]) \
850 |               + sum([tf.nn.l2_loss(w) * params['l2_r'] for w in W])
851 | 
852 |     # Compute evaluation metrics.
853 |     accuracy = tf.metrics.accuracy(labels=labels,
854 |                                    predictions=pred_class,
855 |                                    name='acc_op1')
856 |     auc = tf.metrics.auc(labels=labels,
857 |                          predictions=y_prob,
858 |                          name='auc_op1')
859 | 
860 |     metric_orig_loss = tf.metrics.mean(loss, name='orig_loss_op')
861 |     metric_l2_loss = tf.metrics.mean(l2_loss, name='l2_loss_op')
862 |     metrics = {'accuracy': accuracy, 'auc': auc, 'orig_loss': metric_orig_loss, 'l2_loss': metric_l2_loss}
863 | 
864 |     loss += l2_loss
865 | 
866 |     if mode == tf.estimator.ModeKeys.EVAL:
867 |         return tf.estimator.EstimatorSpec(
868 |             mode, loss=loss, eval_metric_ops=metrics)
869 | 
870 |     # Create training op.
871 |     assert mode == tf.estimator.ModeKeys.TRAIN
872 | 
873 |     # variables_to_restore = tf.contrib.get_variables_to_restore()
874 |     optimizer = tf.train.AdamOptimizer(params['learning_rate'])
875 |     # optimizer = tf.train.AdagradOptimizer(0.2)
876 |     train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
877 |     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
878 | 
879 | 
880 | def deepFwFM(features, labels, mode, params):
881 |     # FwFM plus a two-layer deep component over the shared field embeddings.
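    # deepFwFM combines three signal paths on the shared embeddings: a linear
    # projection (w_l), the FM-style pairwise interactions (p), and a small
    # two-layer component (deepmatrix1/deepmatrix2); their concatenation is
    # collapsed to a logit by one final weight vector w_f.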
882 | 
883 |     X = []
884 |     V = []
885 |     F = params['latent_factor']
886 |     feat_cnt = params["categorical_feature_counts"]
887 | 
888 |     M_cat = 0  # Number of categorical features
889 |     for f_name in features.keys():
890 |         if f_name in {'label', 'tag'} or params["feature_types"][f_name] == 'NUMERIC':
891 |             continue
892 |         else:
893 |             M_cat += 1
894 |             sparse_t = dense_to_sparse(features[f_name], feat_cnt[f_name] + 1)
895 |             X.append(sparse_t)
896 |             v = tf.get_variable('v_%s' % f_name, initializer=tf.random_normal(shape=[feat_cnt[f_name] + 1, F], mean=0.0, stddev=0.01))
897 |             # w = tf.get_variable('w_%s' % f_name, initializer=tf.random_normal(shape=[F, 1], mean=0.0, stddev=0.01))
898 |             V.append(v)
899 |     M = M_cat
900 |     # linear term
901 |     w_l = tf.get_variable('w_l', initializer=tf.random_normal(shape=[M_cat * F, 1], mean=0.0, stddev=1.0))
902 | 
903 |     # final linear combination layer; first dim matches l_final below: [l_linear | p | d2]
904 |     w_f = tf.get_variable('w_f', initializer=tf.random_normal(shape=[1 + int(M * (M-1)/2) + params['deep_dimension'], 1], mean=0.0, stddev=1.0))
905 |     # global intercept
906 |     b = tf.get_variable('b', shape=[1])
907 | 
908 |     xv_cat = [tf.sparse_tensor_dense_matmul(X[i], V[i]) for i in range(M_cat)]
909 | 
910 |     # concat to matrix [N, M * F]
911 |     l_cat = tf.concat([xv_cat[i] for i in range(M_cat)], 1)
912 | 
913 |     # linear part
914 |     l_linear = tf.matmul(l_cat, w_l)  # [N, 1]
915 | 
916 |     # second order interaction part
917 |     l = l_cat
918 |     index_left = []
919 |     index_right = []
920 | 
921 |     for i in range(M):
922 |         for j in range(i + 1, M):
923 |             index_left.append(i)
924 |             index_right.append(j)
925 | 
926 |     # reshape l_ to [N, M, F]
927 |     l_ = tf.reshape(l, [-1, M, F])
928 |     l_left = tf.gather(l_, index_left, axis=1)
929 |     l_right = tf.gather(l_, index_right, axis=1)
930 |     # element-wise multiplication of [N, M(M-1)/2, F]
931 |     p_full = tf.multiply(l_left, l_right)
932 |     # Reduce to [N, M(M-1)/2]
933 |     p = tf.reduce_sum(p_full, 2)
934 | 
935 |     # Deep part
936 |     deepmatrix1 = tf.get_variable('deepmatrix1', initializer=tf.random_normal(shape=[M_cat * F, params['deep_dimension']], mean=0.0, stddev=0.01))
937 |     deepmatrix2 = tf.get_variable('deepmatrix2', initializer=tf.random_normal(shape=[params['deep_dimension'], params['deep_dimension']], mean=0.0, stddev=0.01))
938 |     d1 = tf.matmul(l, deepmatrix1)  # [N, deep_dim]
939 |     d2 = tf.matmul(d1, deepmatrix2)  # [N, deep_dim]
940 | 
941 |     # Combine
942 |     l_final = tf.concat([l_linear, p, d2], 1)  # [N, 1 + M*(M-1)/2 + deep_dim]
943 | 
944 |     logits = b + tf.matmul(l_final, w_f)
945 | 
946 | 
947 | 
948 |     y_prob = tf.sigmoid(logits)
949 |     pred_class = tf.cast((y_prob >= 0.5), tf.bool)
950 |     if mode == tf.estimator.ModeKeys.PREDICT:
951 |         predictions = {
952 |             'class_ids': pred_class,
953 |             'probabilities': y_prob,
954 |             'logits': logits,
955 |         }
956 |         return tf.estimator.EstimatorSpec(mode, predictions=predictions)
957 | 
958 |     # Compute loss.
959 |     labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1])
960 |     loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
961 |     loss = tf.math.reduce_mean(loss)
962 | 
963 |     # l2_loss = 0
964 |     l2_loss = (tf.nn.l2_loss(w_f) + tf.nn.l2_loss(w_l)) * params['l2_linear']
965 |     for v in V:
966 |         l2_loss += tf.nn.l2_loss(v) * params['l2_latent']
967 |     l2_loss += (tf.nn.l2_loss(deepmatrix1) + tf.nn.l2_loss(deepmatrix2)) * params['l2_latent']
968 | 
969 |     # Compute evaluation metrics.
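    # Note on the deep part above: d1 and d2 are plain matmuls with no bias or
    # activation, so the two layers compose to a single linear map of l_cat;
    # deepFmFM below inserts biases and ReLUs between its layers.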
970 |     accuracy = tf.metrics.accuracy(labels=labels,
971 |                                    predictions=pred_class,
972 |                                    name='acc_op1')
973 |     auc = tf.metrics.auc(labels=labels,
974 |                          predictions=y_prob,
975 |                          name='auc_op1')
976 | 
977 |     metric_orig_loss = tf.metrics.mean(loss, name='orig_loss_op')
978 |     metric_l2_loss = tf.metrics.mean(l2_loss, name='l2_loss_op')
979 |     metrics = {'accuracy': accuracy, 'auc': auc, 'orig_loss': metric_orig_loss, 'l2_loss': metric_l2_loss}
980 | 
981 |     loss += l2_loss
982 | 
983 | 
984 |     if mode == tf.estimator.ModeKeys.EVAL:
985 |         return tf.estimator.EstimatorSpec(
986 |             mode, loss=loss, eval_metric_ops=metrics)
987 | 
988 |     # Create training op.
989 |     assert mode == tf.estimator.ModeKeys.TRAIN
990 | 
991 |     optimizer = tf.train.AdamOptimizer(params['learning_rate'])
992 |     # optimizer = tf.train.AdagradOptimizer(0.2)
993 |     train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
994 |     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
995 | 
996 | 
997 | def deepFmFM(features, labels, mode, params):
998 |     # FmFM plus a two-layer MLP over the concatenated field embeddings.
999 |     V = []
1000 |     W = []
1001 |     xv = {}
1002 |     # xw = []
1003 |     # F = params['latent_factor']
1004 | 
1005 |     feat_cnt = params["categorical_feature_counts"]
1006 | 
1007 |     feat_dim = {}
1008 |     if params['feat_dims']:
1009 |         feat_dim = params['feat_dims']
1010 |     else:
1011 |         for k, v in feat_cnt.items():
1012 |             feat_dim[k] = int(np.ceil(np.log2(v)/2) + 1)
1013 |     feat_dim_sum = sum(feat_dim.values())
1014 |     print('Field Dim - {}'.format(feat_dim))
1015 | 
1016 |     M_cat = 0  # Number of categorical features
1017 |     # M_num = 0
1018 |     for f_name in features.keys():
1019 |         if f_name in {'label', 'tag'} or params["feature_types"][f_name] == 'NUMERIC':
1020 |             continue
1021 |         M_cat += 1
1022 |         sparse_t = dense_to_sparse(features[f_name], feat_cnt[f_name] + 1)
1023 |         v = tf.get_variable('v_%s' % f_name, dtype=tf.float32, shape=[feat_cnt[f_name] + 1, feat_dim[f_name]])
1024 |         # w = tf.get_variable('w_%s' % f_name, dtype=tf.float32, initializer=tf.random_normal(
1025 |         #     shape=[feat_cnt[f_name] + 1, 1], mean=0.0, stddev=0.01))
1026 |         V.append(v)
1027 |         # W.append(w)
1028 |         xv[f_name] = tf.sparse_tensor_dense_matmul(sparse_t, v)
1029 |         # xw.append(tf.sparse_tensor_dense_matmul(sparse_t, w))
1030 | 
1031 |     print(features.keys())
1032 |     # M = M_num + M_cat
1033 |     # linear term
1034 |     # w_l = tf.get_variable('w_l', dtype=tf.float32,
1035 |     #                       initializer=tf.random_normal(shape=[feat_dim_sum, 1], mean=0.0, stddev=1.0))
1036 |     # field scalar
1037 |     # w_p = tf.get_variable('w_p', dtype=tf.float32,
1038 |     #                       initializer=tf.random_normal(shape=[int(M * (M - 1) / 2), F], mean=0.0, stddev=1.0))
1039 |     # global intercept
1040 |     b = tf.get_variable('b', shape=[1], dtype=tf.float32)
1041 | 
1042 |     # concat to matrix [N, feat_dim_sum]
1043 |     l_cat = tf.concat(list(xv.values()), 1)
1044 |     # w_cat = tf.concat(xw, 1)
1045 | 
1046 |     # l = l_cat
1047 | 
1048 |     l_left, l_right = [], []
1049 |     # interactions = []
1050 | 
1051 |     cross_dim_sum = 0
1052 |     field_order = sorted(feat_dim.items(), key=lambda x: x[1], reverse=True)
1053 |     for idx_l, (feat_l, dim_l) in enumerate(field_order):
1054 |         for idx_r, (feat_r, dim_r) in enumerate(field_order[idx_l+1:]):
1055 |             idx_r += (idx_l + 1)
1056 |             w = tf.get_variable('%s_%s' % (feat_l, feat_r), dtype=tf.float32, shape=[dim_l, dim_r])
1057 |             W.append(w)
1058 |             l_left.append(tf.matmul(xv[feat_l], w))
1059 |             l_right.append(xv[feat_r])
1060 |             cross_dim_sum += dim_r
1061 |             # int_vec = tf.matmul(xv[feat_l], w)
1062 |             # interactions.append(tf.reduce_sum(tf.multiply(int_vec, xv[feat_r]), 1, keepdims=True))
1063 | 
1064 |     # interactions = tf.concat(interactions, 1)
1065 |     l_left = tf.concat(l_left, 1)
1066 |     l_right = tf.concat(l_right, 1)
1067 | 
1068 |     p = tf.multiply(l_left, l_right)
1069 | 
1070 | 
1071 | 
1072 |     # Deep part
1073 |     deepmatrix1 = tf.get_variable('deepmatrix1', dtype=tf.float32, shape=[feat_dim_sum, feat_dim_sum])
1074 |     deep_b1 = tf.get_variable('deep_b1', dtype=tf.float32, shape=[feat_dim_sum])
1075 |     deepmatrix2 = tf.get_variable('deepmatrix2', dtype=tf.float32, shape=[feat_dim_sum, feat_dim_sum])
1076 |     deep_b2 = tf.get_variable('deep_b2', dtype=tf.float32, shape=[feat_dim_sum])
1077 |     d1 = tf.matmul(l_cat, deepmatrix1) + deep_b1  # [N, feat_dim_sum]
1078 |     d1 = tf.nn.relu(d1)
1079 | 
1080 |     d2 = tf.matmul(d1, deepmatrix2) + deep_b2  # [N, feat_dim_sum]
1081 |     d2 = tf.nn.relu(d2)
1082 | 
1083 |     # final linear combination layer
1084 |     w_f = tf.get_variable('w_f', shape=[feat_dim_sum + cross_dim_sum, 1])
1085 | 
1086 |     # Reduce to [N, 1]
1087 |     # p = tf.reduce_sum(p, 1, keepdims=True)
1088 |     l_f = tf.concat([p, d2], 1)
1089 |     logits = tf.matmul(l_f, w_f) + b
1090 |     y_prob = tf.sigmoid(logits)
1091 |     pred_class = tf.cast((y_prob >= 0.5), tf.bool)
1092 |     if mode == tf.estimator.ModeKeys.PREDICT:
1093 |         predictions = {
1094 |             'class_ids': pred_class,
1095 |             'probabilities': y_prob,
1096 |             'logits': logits,
1097 |         }
1098 |         return tf.estimator.EstimatorSpec(mode, predictions=predictions)
1099 | 
1100 |     # Compute loss.
1101 |     labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1])
1102 |     loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
1103 |     loss = tf.math.reduce_mean(loss)
1104 | 
1105 |     # L2_loss
1106 |     l2_loss = (tf.nn.l2_loss(w_f) + tf.nn.l2_loss(deepmatrix1) + tf.nn.l2_loss(deepmatrix2)) * params['l2_linear'] \
1107 |               + sum([tf.nn.l2_loss(v) * params['l2_latent'] for v in V])
1108 |     # + sum([tf.nn.l2_loss(w) * params['l2_latent'] for w in W])
1109 | 
1110 |     # Compute evaluation metrics.
1111 |     accuracy = tf.metrics.accuracy(labels=labels,
1112 |                                    predictions=pred_class,
1113 |                                    name='acc_op1')
1114 |     auc = tf.metrics.auc(labels=labels,
1115 |                          predictions=y_prob,
1116 |                          name='auc_op1')
1117 | 
1118 |     metric_orig_loss = tf.metrics.mean(loss, name='orig_loss_op')
1119 |     metric_l2_loss = tf.metrics.mean(l2_loss, name='l2_loss_op')
1120 |     metrics = {'accuracy': accuracy, 'auc': auc, 'orig_loss': metric_orig_loss, 'l2_loss': metric_l2_loss}
1121 | 
1122 |     loss += l2_loss
1123 | 
1124 |     if mode == tf.estimator.ModeKeys.EVAL:
1125 |         return tf.estimator.EstimatorSpec(
1126 |             mode, loss=loss, eval_metric_ops=metrics)
1127 | 
1128 |     # Create training op.
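    # Note on feat_dim above: when params['feat_dims'] is not supplied, each
    # field's embedding width defaults to ceil(log2(vocab_size) / 2) + 1, e.g.
    # a field with 1,000 values gets ceil(9.97 / 2) + 1 = 6 dims and one with
    # 1,000,000 values gets ceil(19.93 / 2) + 1 = 11.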
1129 |     assert mode == tf.estimator.ModeKeys.TRAIN
1130 | 
1131 |     # variables_to_restore = tf.contrib.get_variables_to_restore()
1132 |     optimizer = tf.train.AdamOptimizer(params['learning_rate'])
1133 |     train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
1134 |     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
1135 | 
1136 | 
1137 | def DCN(features, labels, mode, params):
1138 |     V = []
1139 |     xv = {}
1140 | 
1141 |     feat_cnt = params["categorical_feature_counts"]
1142 |     feat_dims = params['feat_dims']
1143 |     feat_dim_sum = sum(feat_dims.values())
1144 |     print('Field Dim - {}'.format(feat_dims))
1145 | 
1146 |     M_cat = 0  # Number of categorical features
1147 |     for f_name in features.keys():
1148 |         if f_name in {'label', 'tag'} or params["feature_types"][f_name] == 'NUMERIC':
1149 |             continue
1150 |         M_cat += 1
1151 |         sparse_t = dense_to_sparse(features[f_name], feat_cnt[f_name] + 1)
1152 |         v = tf.get_variable('v_%s' % f_name, dtype=tf.float32, shape=[feat_cnt[f_name] + 1, feat_dims[f_name]])
1153 |         V.append(v)
1154 |         xv[f_name] = tf.sparse_tensor_dense_matmul(sparse_t, v)
1155 | 
1156 |     print(features.keys())
1157 | 
1158 |     # global intercept
1159 |     b = tf.get_variable('b', shape=[1], dtype=tf.float32)
1160 | 
1161 |     # concat to matrix [N, feat_dim_sum]
1162 |     l_cat = tf.concat(list(xv.values()), 1)
1163 | 
1164 |     # mat_list, for L2 reg
1165 |     mat_list = []
1166 | 
1167 |     # DNN part
1168 |     h_layer = l_cat
1169 |     for i in range(4):
1170 |         deep_mat = tf.get_variable('deep_mat_%d' % i, dtype=tf.float32, shape=[feat_dim_sum, feat_dim_sum])
1171 |         deep_bias = tf.get_variable('deep_b_%d' % i, dtype=tf.float32, shape=[feat_dim_sum])
1172 |         mat_list.extend([deep_mat, deep_bias])
1173 |         h_layer = tf.matmul(h_layer, deep_mat) + deep_bias  # [N, feat_dim_sum]
1174 |         h_layer = tf.nn.relu(h_layer)
1175 | 
1176 |     # cross network
1177 |     y_cross_0 = tf.reshape(l_cat, shape=[-1, feat_dim_sum, 1])
1178 |     y_cross_i = tf.reshape(l_cat, shape=[-1, 1, feat_dim_sum])
1179 |     y_cross = l_cat
1180 | 
1181 |     for i in range(4):
1182 |         x0T_x_x1 = tf.matmul(y_cross_0, y_cross_i)
1183 |         cross_layer_w = tf.get_variable('cross_layer_w_%d' % i, dtype=tf.float32, shape=[feat_dim_sum, 1])
1184 |         cross_layer_b = tf.get_variable('cross_layer_b_%d' % i, dtype=tf.float32, shape=[1, 1, feat_dim_sum])
1185 |         mat_list.extend([cross_layer_b, cross_layer_w])
1186 |         y_cross_i = tf.add(tf.reshape(tf.matmul(x0T_x_x1, cross_layer_w), shape=[-1, 1, feat_dim_sum]), cross_layer_b)
1187 |         y_cross = tf.add(y_cross, tf.reshape(y_cross_i, shape=[-1, feat_dim_sum]))
1188 | 
1189 |     # final linear combination layer
1190 |     w_f = tf.get_variable('w_f', shape=[2 * feat_dim_sum, 1])
1191 |     mat_list.append(w_f)
1192 | 
1193 |     # Reduce to [N, 1]
1194 |     l_f = tf.concat([y_cross, h_layer], 1)
1195 |     logits = tf.matmul(l_f, w_f) + b
1196 |     y_prob = tf.sigmoid(logits)
1197 |     pred_class = tf.cast((y_prob >= 0.5), tf.bool)
1198 |     if mode == tf.estimator.ModeKeys.PREDICT:
1199 |         predictions = {
1200 |             'class_ids': pred_class,
1201 |             'probabilities': y_prob,
1202 |             'logits': logits,
1203 |         }
1204 |         return tf.estimator.EstimatorSpec(mode, predictions=predictions)
1205 | 
1206 |     # Compute loss.
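    # Note on the cross network above: each layer computes
    #   y_{i+1} = x_0 (y_i^T w_i) + b_i
    # via the [N, D, D] outer product x0T_x_x1, and y_cross accumulates
    # x_0 + sum_i y_i as a residual-style sum, a slight variant of the
    # standard DCN recursion x_{l+1} = x_0 x_l^T w_l + b_l + x_l.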
1207 |     labels = tf.reshape(tf.cast(labels, tf.float32), [-1, 1])
1208 |     loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
1209 |     loss = tf.math.reduce_mean(loss)
1210 | 
1211 |     # L2_loss
1212 |     l2_loss = sum([tf.nn.l2_loss(v) * params['l2_linear'] for v in mat_list]) \
1213 |               + sum([tf.nn.l2_loss(v) * params['l2_latent'] for v in V])
1214 | 
1215 |     # Compute evaluation metrics.
1216 |     accuracy = tf.metrics.accuracy(labels=labels,
1217 |                                    predictions=pred_class,
1218 |                                    name='acc_op1')
1219 |     auc = tf.metrics.auc(labels=labels,
1220 |                          predictions=y_prob,
1221 |                          name='auc_op1')
1222 | 
1223 |     metric_orig_loss = tf.metrics.mean(loss, name='orig_loss_op')
1224 |     metric_l2_loss = tf.metrics.mean(l2_loss, name='l2_loss_op')
1225 |     metrics = {'accuracy': accuracy, 'auc': auc, 'orig_loss': metric_orig_loss, 'l2_loss': metric_l2_loss}
1226 | 
1227 |     loss += l2_loss
1228 | 
1229 |     if mode == tf.estimator.ModeKeys.EVAL:
1230 |         return tf.estimator.EstimatorSpec(
1231 |             mode, loss=loss, eval_metric_ops=metrics)
1232 | 
1233 |     # Create training op.
1234 |     assert mode == tf.estimator.ModeKeys.TRAIN
1235 | 
1236 |     # variables_to_restore = tf.contrib.get_variables_to_restore()
1237 |     optimizer = tf.train.AdamOptimizer(params['learning_rate'])
1238 |     train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
1239 |     return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
1240 | 
1241 | 
1242 | def build_custom_linear_classifier(model_dir, feature_names, feature_types, categorical_feature_counts, l2_linear,
1243 |                                    l2_latent, l2_r, learning_rate, latent_dimension, model=None, feat_dims=None, cross_fields=None):
1244 |     Call_functions = {
1245 |         'LR': LR,
1246 |         'FM': FM,
1247 |         'FFM': FFM,
1248 |         'FwFM': FwFM,
1249 |         'FvFM': FvFM,
1250 |         'FmFM': FmFM,
1251 |         'deepFwFM': deepFwFM,
1252 |         'deepFmFM': deepFmFM,
1253 |         'DCN': DCN
1254 |     }
1255 |     config = tf.estimator.RunConfig(keep_checkpoint_max=1, save_checkpoints_steps=10000)
1256 | 
1257 |     estimator = tf.estimator.Estimator(
1258 |         model_fn=Call_functions[model],
1259 |         model_dir=model_dir,
1260 |         params={
1261 |             'feature_names': feature_names,
1262 |             'feature_types': feature_types,
1263 |             'categorical_feature_counts': categorical_feature_counts,
1264 |             'n_classes': 2,
1265 |             'latent_factor': latent_dimension,
1266 |             'l2_linear': l2_linear,
1267 |             'l2_latent': l2_latent,
1268 |             'l2_r': l2_r,
1269 |             'learning_rate': learning_rate,
1270 |             'deep_dimension': 200,
1271 |             'feat_dims': feat_dims,
1272 |             'cross_fields': cross_fields,
1273 |             'emb_dim': 128},
1274 |         config=config)
1275 | 
1276 |     return estimator
1277 | 
1278 | 
--------------------------------------------------------------------------------
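A minimal usage sketch for build_custom_linear_classifier follows; the feature
metadata values and the toy input_fn are illustrative assumptions standing in
for what train.py loads from data/<dataset>/features.json and feature_index,
not the repository's actual wiring:

    import tensorflow as tf
    from models import build_custom_linear_classifier

    # Hypothetical metadata for two categorical fields.
    feature_types = {'site_id': 'CATEGORICAL', 'device_type': 'CATEGORICAL'}
    feature_counts = {'site_id': 4581, 'device_type': 4}

    estimator = build_custom_linear_classifier(
        model_dir='models/avazu_FwFM',
        feature_names=list(feature_types.keys()),
        feature_types=feature_types,
        categorical_feature_counts=feature_counts,
        l2_linear=1e-5, l2_latent=1e-5, l2_r=1e-5,   # values from bash/FwFM.sh
        learning_rate=1e-4, latent_dimension=16, model='FwFM')

    def train_input_fn():
        # Toy in-memory batch; the real pipeline reads data/<dataset>/train.csv.
        feats = {'site_id': tf.constant([1, 2]), 'device_type': tf.constant([0, 3])}
        labels = tf.constant([0, 1])
        return tf.data.Dataset.from_tensors((feats, labels)).repeat()

    estimator.train(input_fn=train_input_fn, max_steps=50000)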