├── Dataprocess ├── Avazu │ └── preprocess.py ├── Criteo │ ├── config.py │ ├── preprocess.py │ └── scale.py ├── KDD2012 │ ├── preprocess.py │ └── scale.py └── Kfold_split │ ├── config.py │ └── stratifiedKfold.py ├── LICENSE ├── README.md ├── figures └── model.png ├── model.py ├── sample_preprocess.sh ├── test_code.sh ├── train.py └── train_examples.txt /Dataprocess/Avazu/preprocess.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | #Email of the author: zjduan@pku.edu.cn 3 | ''' 4 | 0.id: ad identifier 5 | 1.click: 0/1 for non-click/click 6 | 2.hour: format is YYMMDDHH, so 14091123 means 23:00 on Sept. 11, 2014 UTC. 7 | 3.C1 -- anonymized categorical variable 8 | 4.banner_pos 9 | 5.site_id 10 | 6.site_domain 11 | 7.site_category 12 | 8.app_id 13 | 9.app_domain 14 | 10.app_category 15 | 11.device_id 16 | 12.device_ip 17 | 13.device_model 18 | 14.device_type 19 | 15.device_conn_type 20 | 16.C14 21 | 17.C15 22 | 18.C16 23 | 19.C17 24 | 20.C18 25 | 21.C19 26 | 22.C20 27 | 23.C21 28 | ''' 29 | import pandas as pd 30 | import math 31 | train_path = './train.csv' 32 | f1 = open(train_path, 'r') 33 | dic = {} 34 | f_train_value = open('./train_x.txt', 'w') 35 | f_train_index = open('./train_i.txt', 'w') 36 | f_train_label = open('./train_y.txt', 'w') 37 | debug = False 38 | tune = False 39 | Bound = [5] * 24 40 | 41 | label_index = 1 42 | Column = 24 43 | 44 | numr_feat = [] 45 | numerical = [0] * Column 46 | numerical[label_index] = -1 47 | 48 | cate_feat = [] 49 | for i in range(Column): 50 | if (numerical[i] == 0): 51 | cate_feat.extend([i]) 52 | 53 | index_cnt = 0 54 | index_others = [0] * Column 55 | Max = [0] * Column 56 | 57 | 58 | for i in numr_feat: 59 | index_others[i] = index_cnt 60 | index_cnt += 1 61 | numerical[i] = 1 62 | for i in cate_feat: 63 | index_others[i] = index_cnt 64 | index_cnt += 1 65 | 66 | for i in range(Column): 67 | dic[i] = dict() 68 | 69 | cnt_line = 0 70 | for line in f1: 71 | cnt_line += 1 72 | if (cnt_line == 1): continue # header 73 | if (cnt_line % 1000000 == 0): 74 | print ("cnt_line = %d, index_cnt = %d" % (cnt_line, index_cnt)) 75 | if (debug == True): 76 | if (cnt_line >= 10000): 77 | break 78 | split = line.strip('\n').split(',') 79 | for i in cate_feat: 80 | if (split[i] != ''): 81 | if split[i] not in dic[i]: 82 | dic[i][split[i]] = [index_others[i], 0] 83 | dic[i][split[i]][1] += 1 84 | if (dic[i][split[i]][0] == index_others[i] and dic[i][split[i]][1] == Bound[i]): 85 | dic[i][split[i]][0] = index_cnt 86 | index_cnt += 1 87 | 88 | if (tune == False): 89 | label = split[label_index] 90 | if (label != '0'): label = '1' 91 | index = [0] * (Column - 1) 92 | value = ['0'] * (Column - 1) 93 | for i in range(Column): 94 | cur = i 95 | if (i == label_index): continue 96 | if (i > label_index): cur = i - 1 97 | if (numerical[i] == 1): 98 | index[cur] = index_others[i] 99 | if (split[i] != ''): 100 | value[cur] = split[i] 101 | # Max[i] = max(int(split[i]), Max[i]) 102 | else: 103 | if (split[i] != ''): 104 | index[cur] = dic[i][split[i]][0] 105 | value[cur] = '1' 106 | 107 | if (split[i] == ''): 108 | value[cur] = '0' 109 | 110 | f_train_index.write(' '.join(str(i) for i in index) + '\n') 111 | f_train_value.write(' '.join(value) + '\n') 112 | f_train_label.write(label + '\n') 113 | 114 | f1.close() 115 | f_train_index.close() 116 | f_train_value.close() 117 | f_train_label.close() 118 | print ("Finished!") 119 | print ("index_cnt = %d" % index_cnt) 120 | # print ("max number for numerical features:") 121 
| # for i in numr_feat: 122 | # print ("no.:%d max: %d" % (i, Max[i])) 123 | 124 | 125 | -------------------------------------------------------------------------------- /Dataprocess/Criteo/config.py: -------------------------------------------------------------------------------- 1 | DATA_PATH = './Criteo/' 2 | SOURCE_DATA = './train_examples.txt' -------------------------------------------------------------------------------- /Dataprocess/Criteo/preprocess.py: -------------------------------------------------------------------------------- 1 | import config 2 | 3 | train_path = config.SOURCE_DATA 4 | f1 = open(train_path,'r') 5 | dic= {} 6 | # generate three fold. 7 | # train_x: value 8 | # train_i: index 9 | # train_y: label 10 | f_train_value = open(config.DATA_PATH + 'train_x.txt','w') 11 | f_train_index = open(config.DATA_PATH + 'train_i.txt','w') 12 | f_train_label = open(config.DATA_PATH + 'train_y.txt','w') 13 | 14 | for i in range(39): 15 | dic[i] = {} 16 | 17 | cnt_train = 0 18 | 19 | #for debug 20 | #limits = 10000 21 | index = [1] * 26 22 | for line in f1: 23 | cnt_train +=1 24 | if cnt_train % 100000 ==0: 25 | print('now train cnt : %d\n' % cnt_train) 26 | #if cnt_train > limits: 27 | # break 28 | split = line.strip('\n').split('\t') 29 | # 0-label, 1-13 numerical, 14-39 category 30 | for i in range(13,39): 31 | #dic_len = len(dic[i]) 32 | if split[i+1] not in dic[i]: 33 | # [1, 0] 1 is the index for those whose appear times <= 10 0 indicates the appear times 34 | dic[i][split[i+1]] = [1,0] 35 | dic[i][split[i+1]][1] += 1 36 | if dic[i][split[i+1]][0] == 1 and dic[i][split[i+1]][1] > 10: 37 | index[i-13] += 1 38 | dic[i][split[i+1]][0] = index[i-13] 39 | f1.close() 40 | print('total entries :%d\n' % (cnt_train - 1)) 41 | 42 | # calculate number of category features of every dimension 43 | kinds = [13] 44 | for i in range(13,39): 45 | kinds.append(index[i-13]) 46 | print('number of dimensions : %d' % (len(kinds)-1)) 47 | print(kinds) 48 | 49 | for i in range(1,len(kinds)): 50 | kinds[i] += kinds[i-1] 51 | print(kinds) 52 | 53 | # make new data 54 | 55 | f1 = open(train_path,'r') 56 | cnt_train = 0 57 | print('remake training data...\n') 58 | for line in f1: 59 | cnt_train +=1 60 | if cnt_train % 100000 ==0: 61 | print('now train cnt : %d\n' % cnt_train) 62 | #if cnt_train > limits: 63 | # break 64 | entry = ['0'] * 39 65 | index = [None] * 39 66 | split = line.strip('\n').split('\t') 67 | label = str(split[0]) 68 | for i in range(13): 69 | if split[i+1] != '': 70 | entry[i] = (split[i+1]) 71 | index[i] = (i+1) 72 | for i in range(13,39): 73 | if split[i+1] != '': 74 | entry[i] = '1' 75 | index[i] = (dic[i][split[i+1]][0]) 76 | for j in range(26): 77 | index[13+j] += kinds[j] 78 | index = [str(item) for item in index] 79 | f_train_value.write(' '.join(entry)+'\n') 80 | f_train_index.write(' '.join(index)+'\n') 81 | f_train_label.write(label+'\n') 82 | f1.close() 83 | 84 | 85 | f_train_value.close() 86 | f_train_index.close() 87 | f_train_label.close() 88 | 89 | 90 | -------------------------------------------------------------------------------- /Dataprocess/Criteo/scale.py: -------------------------------------------------------------------------------- 1 | import math 2 | import config 3 | import numpy as np 4 | def scale(x): 5 | if x > 2: 6 | x = int(math.log(float(x))**2) 7 | return x 8 | 9 | 10 | 11 | def scale_each_fold(): 12 | for i in range(1,11): 13 | print('now part %d' % i) 14 | data = np.load(config.DATA_PATH + 'part'+str(i)+'/train_x.npy') 15 | part = data[:,0:13] 
16 | for j in range(part.shape[0]): 17 | if j % 100000 ==0: 18 | print(j) 19 | part[j] = list(map(scale, part[j])) 20 | np.save(config.DATA_PATH + 'part' + str(i) + '/train_x2.npy', data) 21 | 22 | 23 | if __name__ == '__main__': 24 | scale_each_fold() -------------------------------------------------------------------------------- /Dataprocess/KDD2012/preprocess.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | #Email of the author: zjduan@pku.edu.cn 3 | ''' 4 | 0. Click: 5 | 1. Impression(numerical) 6 | 2. DisplayURL: (categorical) 7 | 3. AdID:(categorical) 8 | 4. AdvertiserID:(categorical) 9 | 5. Depth:(numerical) 10 | 6. Position:(numerical) 11 | 7. QueryID: (categorical) the key of the data file 'queryid_tokensid.txt'. 12 | 8. KeywordID: (categorical)the key of 'purchasedkeyword_tokensid.txt'. 13 | 9. TitleID: (categorical)the key of 'titleid_tokensid.txt'. 14 | 10. DescriptionID: (categorical)the key of 'descriptionid_tokensid.txt'. 15 | 11. UserID: (categorical)the key of 'userid_profile.txt' 16 | 12. User's Gender: (categorical) 17 | 13. User's Age: (categorical) 18 | ''' 19 | import math 20 | train_path = './training.txt' 21 | f1 = open(train_path, 'r') 22 | f2 = open('./userid_profile.txt', 'r') 23 | dic = {} 24 | f_train_value = open('./train_x.txt', 'w') 25 | f_train_index = open('./train_i.txt', 'w') 26 | f_train_label = open('./train_y.txt', 'w') 27 | debug = False 28 | tune = False 29 | Column = 12 30 | Field = 13 31 | 32 | numr_feat = [1,5,6] 33 | numerical = [0] * Column 34 | cate_feat = [2,3,4,7,8,9,10,11] 35 | index_cnt = 0 36 | index_others = [0] * (Field + 1) 37 | Max = [0] * 12 38 | numerical[0] = -1 39 | for i in numr_feat: 40 | index_others[i] = index_cnt 41 | index_cnt += 1 42 | numerical[i] = 1 43 | for i in cate_feat: 44 | index_others[i] = index_cnt 45 | index_cnt += 1 46 | 47 | for i in range(Field + 1): 48 | dic[i] = dict() 49 | 50 | ###init user_dic 51 | user_dic = dict() 52 | 53 | cnt_line = 0 54 | for line in f2: 55 | cnt_line += 1 56 | if (cnt_line % 1000000 == 0): 57 | print ("cnt_line = %d, index_cnt = %d" % (cnt_line, index_cnt)) 58 | # if (debug == True): 59 | # if (cnt_line >= 10000): 60 | # break 61 | split = line.strip('\n').split('\t') 62 | user_dic[split[0]] = [split[1], split[2]] 63 | if (split[1] not in dic[12]): 64 | dic[12][split[1]] = [index_cnt, 0] 65 | index_cnt += 1 66 | if (split[2] not in dic[13]): 67 | dic[13][split[2]] = [index_cnt, 0] 68 | index_cnt += 1 69 | 70 | cnt_line = 0 71 | for line in f1: 72 | cnt_line += 1 73 | if (cnt_line % 1000000 == 0): 74 | print ("cnt_line = %d, index_cnt = %d" % (cnt_line, index_cnt)) 75 | if (debug == True): 76 | if (cnt_line >= 10000): 77 | break 78 | split = line.strip('\n').split('\t') 79 | for i in cate_feat: 80 | if (split[i] != ''): 81 | if split[i] not in dic[i]: 82 | dic[i][split[i]] = [index_others[i], 0] 83 | dic[i][split[i]][1] += 1 84 | if (dic[i][split[i]][0] == index_others[i] and dic[i][split[i]][1] == 10): 85 | dic[i][split[i]][0] = index_cnt 86 | index_cnt += 1 87 | 88 | if (tune == False): 89 | label = split[0] 90 | if (label != '0'): label = '1' 91 | index = [0] * Field 92 | value = ['0'] * Field 93 | for i in range(1, 12): 94 | if (numerical[i] == 1): 95 | index[i - 1] = index_others[i] 96 | if (split[i] != ''): 97 | value[i - 1] = split[i] 98 | Max[i] = max(int(split[i]), Max[i]) 99 | else: 100 | if (split[i] != ''): 101 | index[i - 1] = dic[i][split[i]][0] 102 | value[i - 1] = '1' 103 | 104 | if (split[i] == ''): 105 | 
value[i - 1] = '0' 106 | if (i == 11 and split[i] == '0'): 107 | value[i - 1] = '0' 108 | ### gender and age 109 | if (split[11] == '' or (split[11] not in user_dic)): 110 | index[12 - 1] = index_others[12] 111 | value[12 - 1] = '0' 112 | index[13 - 1] = index_others[13] 113 | value[13 - 1] = '0' 114 | else: 115 | index[12 - 1] = dic[12][user_dic[split[11]][0]][0] 116 | value[12 - 1] = '1' 117 | index[13 - 1] = dic[13][user_dic[split[11]][1]][0] 118 | value[13 - 1] = '1' 119 | 120 | f_train_index.write(' '.join(str(i) for i in index) + '\n') 121 | f_train_value.write(' '.join(value) + '\n') 122 | f_train_label.write(label + '\n') 123 | 124 | f1.close() 125 | f_train_index.close() 126 | f_train_value.close() 127 | f_train_label.close() 128 | print ("Finished!") 129 | print ("index_cnt = %d" % index_cnt) 130 | print ("max number for numerical features:") 131 | for i in numr_feat: 132 | print ("no.:%d max: %d" % (i, Max[i])) 133 | -------------------------------------------------------------------------------- /Dataprocess/KDD2012/scale.py: -------------------------------------------------------------------------------- 1 | import math 2 | import config 3 | import numpy as np 4 | def scale(x): 5 | if x > 2: 6 | x = int(math.log(float(x))**2) 7 | return x 8 | 9 | 10 | 11 | def scale_each_fold(): 12 | for i in range(1,11): 13 | print('now part %d' % i) 14 | data = np.load(config.DATA_PATH + 'part'+str(i)+'/train_x.npy') 15 | part = data[:,0:13] 16 | for j in range(part.shape[0]): 17 | if j % 100000 ==0: 18 | print(j) 19 | part[j] = list(map(scale, part[j])) 20 | np.save(config.DATA_PATH + 'part' + str(i) + '/train_x2.npy', data) 21 | 22 | 23 | if __name__ == '__main__': 24 | scale_each_fold() -------------------------------------------------------------------------------- /Dataprocess/Kfold_split/config.py: -------------------------------------------------------------------------------- 1 | DATA_PATH = './Criteo/' 2 | TRAIN_I = DATA_PATH + 'train_i.txt' 3 | TRAIN_X = DATA_PATH + 'train_x.txt' 4 | TRAIN_Y = DATA_PATH + 'train_y.txt' 5 | 6 | NUM_SPLITS = 10 7 | RANDOM_SEED = 2018 8 | 9 | -------------------------------------------------------------------------------- /Dataprocess/Kfold_split/stratifiedKfold.py: -------------------------------------------------------------------------------- 1 | #Email of the author: zjduan@pku.edu.cn 2 | import numpy as np 3 | import config 4 | import os 5 | import pandas as pd 6 | from sklearn.model_selection import StratifiedKFold 7 | from sklearn import preprocessing 8 | 9 | scale = "" 10 | train_x_name = "train_x.npy" 11 | train_y_name = "train_y.npy" 12 | 13 | # numr_feat = [] 14 | Column = 13 15 | 16 | def _load_data(_nrows=None, debug = False): 17 | 18 | train_x = pd.read_csv(config.TRAIN_X,header=None,sep=' ',nrows=_nrows, dtype=np.float) 19 | train_y = pd.read_csv(config.TRAIN_Y,header=None,sep=' ',nrows=_nrows, dtype=np.int32) 20 | 21 | # for i in range(11): 22 | # print ("argmax feat %d = %d, max = %d" % (i, train_x[i].argmax(), train_x[i].max())) 23 | 24 | train_x = train_x.values 25 | train_y = train_y.values.reshape([-1]) 26 | 27 | # print ("begin to scale") 28 | # if (scale == "minmax"): 29 | # train_x = preprocessing.MinMaxScaler().fit_transform(train_x) 30 | 31 | # if (scale == "std"): 32 | # train_x[:,0:12] = preprocessing.scale(train_x[:,0:12]) 33 | # train_x[:,0:12] += 1 34 | 35 | print('data loading done!') 36 | print('training data : %d' % train_y.shape[0]) 37 | 38 | 39 | assert train_x.shape[0]==train_y.shape[0] 40 | 41 | return train_x, 
train_y 42 | 43 | 44 | def save_x_y(fold_index, train_x, train_y): 45 | _get = lambda x, l: [x[i] for i in l] 46 | for i in range(len(fold_index)): 47 | print("now part %d" % (i+1)) 48 | part_index = fold_index[i] 49 | Xv_train_, y_train_ = _get(train_x, part_index), _get(train_y, part_index) 50 | save_dir_Xv = config.DATA_PATH + "part" + str(i+1) + "/" 51 | save_dir_y = config.DATA_PATH + "part" + str(i+1) + "/" 52 | if (os.path.exists(save_dir_Xv) == False): 53 | os.makedirs(save_dir_Xv) 54 | if (os.path.exists(save_dir_y) == False): 55 | os.makedirs(save_dir_y) 56 | save_path_Xv = save_dir_Xv + train_x_name 57 | save_path_y = save_dir_y + train_y_name 58 | np.save(save_path_Xv, Xv_train_) 59 | np.save(save_path_y, y_train_) 60 | 61 | 62 | # def save_test(test_x, test_y): 63 | # np.save("../data/test/test_x.npy", test_x) 64 | # np.save("../data/test/test_y.npy", test_y) 65 | 66 | 67 | def save_i(fold_index): 68 | _get = lambda x, l: [x[i] for i in l] 69 | train_i = pd.read_csv(config.TRAIN_I,header=None,sep=' ',nrows=None, dtype=np.int32) 70 | train_i = train_i.values 71 | feature_size = train_i.max() + 1 72 | print ("feature_size = %d" % feature_size) 73 | feature_size = [feature_size] 74 | feature_size = np.array(feature_size) 75 | np.save(config.DATA_PATH + "feature_size.npy", feature_size) 76 | 77 | # pivot = 40000000 78 | 79 | # test_i = train_i[pivot:] 80 | # train_i = train_i[:pivot] 81 | 82 | # print("test_i size: %d" % len(test_i)) 83 | print("train_i size: %d" % len(train_i)) 84 | 85 | # np.save("../data/test/test_i.npy", test_i) 86 | 87 | for i in range(len(fold_index)): 88 | print("now part %d" % (i+1)) 89 | part_index = fold_index[i] 90 | Xi_train_ = _get(train_i, part_index) 91 | save_path_Xi = config.DATA_PATH + "part" + str(i+1)+ '/train_i.npy' 92 | np.save(save_path_Xi, Xi_train_) 93 | 94 | 95 | def main(): 96 | 97 | train_x, train_y = _load_data() 98 | print('loading data done!') 99 | 100 | folds = list(StratifiedKFold(n_splits=10, shuffle=True, 101 | random_state=config.RANDOM_SEED).split(train_x, train_y)) 102 | 103 | fold_index = [] 104 | for i,(train_id, valid_id) in enumerate(folds): 105 | fold_index.append(valid_id) 106 | 107 | print("fold num: %d" % (len(fold_index))) 108 | 109 | fold_index = np.array(fold_index) 110 | np.save(config.DATA_PATH + "fold_index.npy", fold_index) 111 | 112 | save_x_y(fold_index, train_x, train_y) 113 | print("save train_x_y done!") 114 | 115 | fold_index = np.load(config.DATA_PATH + "fold_index.npy") 116 | save_i(fold_index) 117 | print("save index done!") 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Chence Shi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Note 2 | We have moved the repo to [https://github.com/DeepGraphLearning/RecommenderSystems/tree/master/featureRec](https://github.com/DeepGraphLearning/RecommenderSystems/tree/master/featureRec). Please check out the latest version there. 3 | 4 | # AutoInt 5 | 6 | This is a TensorFlow implementation of ***AutoInt*** for the CTR prediction task, as described in our paper: 7 | 8 | Weiping Song, Chence Shi, Zhiping Xiao, Zhijian Duan, Yewen Xu, Ming Zhang and Jian Tang. [AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/pdf/1810.11921.pdf). arXiv preprint arXiv:1810.11921, 2018. 9 | 10 | ## Requirements: 11 | * **TensorFlow 1.4.0-rc1** 12 | * Python 3 13 | * CUDA 8.0+ (for GPU) 14 | 15 | ## Introduction 16 | 17 | AutoInt: an effective and efficient algorithm to 18 | automatically learn high-order feature interactions for (sparse) categorical and numerical features. 19 | 20 |
21 | ![The AutoInt model](./figures/model.png) 22 |
23 | The illustration of AutoInt. We first project all sparse features 24 | (both categorical and numerical) into a low-dimensional space. Next, we feed the embeddings of all fields into multiple stacked interacting layers implemented by a self-attentive neural network. The output of the final interacting layer is a low-dimensional representation of the learned combinatorial features, which is then used to estimate the CTR via a sigmoid function. 25 | 26 | ## Usage 27 | ### Input Format 28 | AutoInt requires the input data in the following format: 29 | * train_x: matrix with shape *(num_sample, num_field)*. train_x[s][t] is the feature value of feature field t of sample s in the dataset. The default value for a categorical feature is 1. 30 | * train_i: matrix with shape *(num_sample, num_field)*. train_i[s][t] is the feature index of feature field t of sample s in the dataset. The feature size is the maximal value of train_i plus one. 31 | * train_y: label of each sample in the dataset. 32 | 33 | If you want to know how the data is preprocessed, please refer to `./Dataprocess/Criteo/preprocess.py` (a toy sketch of this input format is also given at the end of this document). 34 | 35 | ### Example 36 | We use four public real-world datasets (Avazu, Criteo, KDD12, MovieLens-1M) in our experiments. Since the first three datasets are very large, they cannot be loaded into memory as a whole. In our implementation, we split each dataset into 10 parts and use the first part as the test set and the second part as the validation set. We provide the preprocessing code for these three datasets in `./Dataprocess`. If you want to reuse it, first run `preprocess.py` to generate `train_x.txt, train_i.txt, train_y.txt` as described in `Input Format`. Then run `./Dataprocess/Kfold_split/stratifiedKfold.py` to split the whole dataset into ten folds. Finally, you can run `scale.py` to scale the numerical values (optional). 37 | 38 | To help you test the correctness of the code and familiarize yourself with it, we upload the first `10000` samples of the `Criteo` dataset in `train_examples.txt`, and we provide scripts for preprocessing and training (please refer to `sample_preprocess.sh` and `test_code.sh`; you may need to modify the paths in `config.py` and `test_code.sh`). 39 | 40 | After you run `sample_preprocess.sh`, you should get a folder named `Criteo` which contains `part*, feature_size.npy, fold_index.npy, train_*.txt`. `feature_size.npy` contains the total number of features, which is used to initialize the model. `train_*.txt` is the whole dataset. If you use another small dataset, say `MovieLens-1M`, you only need to modify the function `_run_` in `train.py`. 41 | 42 | Here's how to run the preprocessing. 43 | ``` 44 | mkdir Criteo 45 | python ./Dataprocess/Criteo/preprocess.py 46 | python ./Dataprocess/Kfold_split/stratifiedKfold.py 47 | python ./Dataprocess/Criteo/scale.py 48 | ``` 49 | 50 | Here's how to run the training. 51 | ``` 52 | python -u train.py \ 53 | --data "Criteo" --blocks 3 --heads 2 --block_shape "[64, 64, 64]" \ 54 | --is_save "True" --save_path "./test_code/Criteo/b3h2_64x64x64/" \ 55 | --field_size 39 --run_times 1 --data_path "./" \ 56 | --epoch 3 --has_residual "True" --has_wide "False" \ 57 | --batch_size 1024 \ 58 | > test_code_single.out & 59 | ``` 60 | 61 | You should see output like this: 62 | 63 | ``` 64 | ... 65 | train logs 66 | ... 67 | start testing!...
68 | restored from ./test_code/Criteo/b3h2_dnn_dropkeep1_400x2/1/ 69 | test-result = 0.8088, test-logloss = 0.4430 70 | test_auc [0.8088305055534442] 71 | test_log_loss [0.44297631300399626] 72 | avg_auc 0.8088305055534442 73 | avg_log_loss 0.44297631300399626 74 | ``` 75 | 76 | ## Citation 77 | If you find AutoInt useful for your research, please consider citing the following paper: 78 | ``` 79 | @article{weiping2018autoint, 80 | title={AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks}, 81 | author={Weiping, Song and Chence, Shi and Zhiping, Xiao and Zhijian, Duan and Yewen, Xu and Ming, Zhang and Jian, Tang}, 82 | journal={arXiv preprint arXiv:1810.11921}, 83 | year={2018} 84 | } 85 | ``` 86 | 87 | 88 | ## Contact information 89 | If you have questions related to the code, feel free to contact Weiping Song (`songweiping@pku.edu.cn`), Chence Shi (`chenceshi@pku.edu.cn`) or Zhijian Duan (`zjduan@pku.edu.cn`). 90 | 91 | ## License 92 | MIT 93 | 94 | ## Acknowledgement 95 | This implementation draws inspiration from Kyubyong Park's [transformer](https://github.com/Kyubyong/transformer) and Chenglong Chen's [DeepFM](https://github.com/ChenglongChen/tensorflow-DeepFM). 96 | -------------------------------------------------------------------------------- /figures/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shichence/AutoInt/d88536fbe9733e45d182fb600657d0bdf82b640c/figures/model.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | TensorFlow implementation of AutoInt described in: 3 | AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks. 4 | author: Chence Shi 5 | email: chenceshi@pku.edu.cn 6 | ''' 7 | 8 | import os 9 | import numpy as np 10 | import tensorflow as tf 11 | from time import time 12 | from sklearn.base import BaseEstimator, TransformerMixin 13 | from sklearn.metrics import roc_auc_score, log_loss 14 | from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm 15 | 16 | 17 | 18 | ''' 19 | The following two functions are adapted from Kyubyong Park's implementation of the Transformer. 20 | We slightly modified the code to make it suitable for our work (added ReLU, removed key masking and the causality mask). 21 | June 2017 by kyubyong park. 22 | kbpark.linguist@gmail.com.
23 | https://www.github.com/kyubyong/transformer 24 | ''' 25 | 26 | 27 | def normalize(inputs, epsilon=1e-8): 28 | ''' 29 | Applies layer normalization 30 | Args: 31 | inputs: A tensor with 2 or more dimensions 32 | epsilon: A floating number to prevent Zero Division 33 | Returns: 34 | A tensor with the same shape and data dtype 35 | ''' 36 | inputs_shape = inputs.get_shape() 37 | params_shape = inputs_shape[-1:] 38 | 39 | mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True) 40 | beta = tf.Variable(tf.zeros(params_shape)) 41 | gamma = tf.Variable(tf.ones(params_shape)) 42 | normalized = (inputs - mean) / ((variance + epsilon) ** (.5)) 43 | outputs = gamma * normalized + beta 44 | 45 | return outputs 46 | 47 | 48 | def multihead_attention(queries, 49 | keys, 50 | values, 51 | num_units=None, 52 | num_heads=1, 53 | dropout_keep_prob=1, 54 | is_training=True, 55 | has_residual=True): 56 | 57 | if num_units is None: 58 | num_units = queries.get_shape().as_list[-1] 59 | 60 | # Linear projections 61 | Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu) 62 | K = tf.layers.dense(keys, num_units, activation=tf.nn.relu) 63 | V = tf.layers.dense(values, num_units, activation=tf.nn.relu) 64 | if has_residual: 65 | V_res = tf.layers.dense(values, num_units, activation=tf.nn.relu) 66 | 67 | # Split and concat 68 | Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) 69 | K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) 70 | V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) 71 | 72 | # Multiplication 73 | weights = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1])) 74 | 75 | # Scale 76 | weights = weights / (K_.get_shape().as_list()[-1] ** 0.5) 77 | 78 | # Activation 79 | weights = tf.nn.softmax(weights) 80 | 81 | 82 | # Dropouts 83 | weights = tf.layers.dropout(weights, rate=1-dropout_keep_prob, 84 | training=tf.convert_to_tensor(is_training)) 85 | 86 | # Weighted sum 87 | outputs = tf.matmul(weights, V_) 88 | 89 | # Restore shape 90 | outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2) 91 | 92 | # Residual connection 93 | if has_residual: 94 | outputs += V_res 95 | 96 | outputs = tf.nn.relu(outputs) 97 | # Normalize 98 | outputs = normalize(outputs) 99 | 100 | return outputs 101 | 102 | 103 | class AutoInt(): 104 | def __init__(self, args, feature_size, run_cnt): 105 | #print(args.block_shape) 106 | #print(type(args.block_shape)) 107 | 108 | self.feature_size = feature_size # denote as n, dimension of concatenated features 109 | self.field_size = args.field_size # denote as M, number of total feature fields 110 | self.embedding_size = args.embedding_size # denote as d, size of the feature embedding 111 | self.blocks = args.blocks # number of the blocks 112 | self.heads = args.heads # number of the heads 113 | self.block_shape = args.block_shape 114 | self.output_size = args.block_shape[-1] 115 | self.has_residual = args.has_residual 116 | self.has_wide = args.has_wide # whether to add wide part 117 | self.deep_layers = args.deep_layers # whether to joint train with deep networks as described in paper 118 | 119 | 120 | self.batch_norm = args.batch_norm 121 | self.batch_norm_decay = args.batch_norm_decay 122 | self.drop_keep_prob = args.dropout_keep_prob 123 | self.l2_reg = args.l2_reg 124 | self.epoch = args.epoch 125 | self.batch_size = args.batch_size 126 | self.learning_rate = args.learning_rate 127 | self.learning_rate_wide = args.learning_rate_wide 128 | self.optimizer_type = args.optimizer_type 129 | 130 | self.save_path = args.save_path + str(run_cnt) + '/' 
131 | self.is_save = args.is_save 132 | if (args.is_save == True and os.path.exists(self.save_path) == False): 133 | os.makedirs(self.save_path) 134 | 135 | self.verbose = args.verbose 136 | self.random_seed = args.random_seed 137 | self.loss_type = args.loss_type 138 | self.eval_metric = roc_auc_score 139 | self.best_loss = 1.0 140 | self.greater_is_better = args.greater_is_better 141 | self.train_result, self.valid_result = [], [] 142 | self.train_loss, self.valid_loss = [], [] 143 | 144 | self._init_graph() 145 | 146 | 147 | def _init_graph(self): 148 | self.graph = tf.Graph() 149 | with self.graph.as_default(): 150 | 151 | tf.set_random_seed(self.random_seed) 152 | 153 | self.feat_index = tf.placeholder(tf.int32, shape=[None, None], 154 | name="feat_index") # None * M 155 | self.feat_value = tf.placeholder(tf.float32, shape=[None, None], 156 | name="feat_value") # None * M 157 | self.label = tf.placeholder(tf.float32, shape=[None, 1], name="label") # None * 1 158 | # In our implementation, the shape of dropout_keep_prob is [3], used in 3 different parts. 159 | self.dropout_keep_prob = tf.placeholder(tf.float32, shape=[None], name="dropout_keep_prob") 160 | self.train_phase = tf.placeholder(tf.bool, name="train_phase") 161 | 162 | self.weights = self._initialize_weights() 163 | 164 | # model 165 | self.embeddings = tf.nn.embedding_lookup(self.weights["feature_embeddings"], 166 | self.feat_index) # None * M * d 167 | feat_value = tf.reshape(self.feat_value, shape=[-1, self.field_size, 1]) 168 | self.embeddings = tf.multiply(self.embeddings, feat_value) # None * M * d 169 | self.embeddings = tf.nn.dropout(self.embeddings, self.dropout_keep_prob[1]) # None * M * d 170 | if self.has_wide: 171 | self.y_first_order = tf.nn.embedding_lookup(self.weights["feature_bias"], self.feat_index) # None * M * 1 172 | self.y_first_order = tf.reduce_sum(tf.multiply(self.y_first_order, feat_value), 1) # None * 1 173 | 174 | # joint training with feedforward nn 175 | if self.deep_layers != None: 176 | self.y_dense = tf.reshape(self.embeddings, shape=[-1, self.field_size * self.embedding_size]) 177 | for i in range(0, len(self.deep_layers)): 178 | self.y_dense = tf.add(tf.matmul(self.y_dense, self.weights["layer_%d" %i]), self.weights["bias_%d"%i]) # None * layer[i] 179 | if self.batch_norm: 180 | self.y_dense = self.batch_norm_layer(self.y_dense, train_phase=self.train_phase, scope_bn="bn_%d" %i) 181 | self.y_dense = tf.nn.relu(self.y_dense) 182 | self.y_dense = tf.nn.dropout(self.y_dense, self.dropout_keep_prob[2]) 183 | self.y_dense = tf.add(tf.matmul(self.y_dense, self.weights["prediction_dense"]), 184 | self.weights["prediction_bias_dense"], name='logits_dense') # None * 1 185 | 186 | 187 | # ---------- main part of AutoInt------------------- 188 | self.y_deep = self.embeddings # None * M * d 189 | for i in range(self.blocks): 190 | self.y_deep = multihead_attention(queries=self.y_deep, 191 | keys=self.y_deep, 192 | values=self.y_deep, 193 | num_units=self.block_shape[i], 194 | num_heads=self.heads, 195 | dropout_keep_prob=self.dropout_keep_prob[0], 196 | is_training=self.train_phase, 197 | has_residual=self.has_residual) 198 | 199 | self.flat = tf.reshape(self.y_deep, 200 | shape=[-1, self.output_size * self.field_size]) 201 | #if self.has_wide: 202 | # self.flat = tf.concat([self.flat, self.y_first_order], axis=1) 203 | #if self.deep_layers != None: 204 | # self.flat = tf.concat([self.flat, self.y_dense], axis=1) 205 | self.out = tf.add(tf.matmul(self.flat, self.weights["prediction"]), 206 | 
self.weights["prediction_bias"], name='logits') # None * 1 207 | 208 | if self.has_wide: 209 | self.out += self.y_first_order 210 | 211 | if self.deep_layers != None: 212 | self.out += self.y_dense 213 | 214 | # ---------- Compute the loss ---------- 215 | # loss 216 | if self.loss_type == "logloss": 217 | self.out = tf.nn.sigmoid(self.out, name='pred') 218 | self.loss = tf.losses.log_loss(self.label, self.out) 219 | elif self.loss_type == "mse": 220 | self.loss = tf.nn.l2_loss(tf.subtract(self.label, self.out)) 221 | 222 | # l2 regularization on weights 223 | if self.l2_reg > 0: 224 | if self.deep_layers != None: 225 | for i in range(len(self.deep_layers)): 226 | self.loss += tf.contrib.layers.l2_regularizer( 227 | self.l2_reg)(self.weights["layer_%d"%i]) 228 | #self.loss += tf.contrib.layers.l2_regularizer(self.l2_reg)(self.embeddings) 229 | #all_vars = tf.trainable_variables() 230 | #lossL2 = tf.add_n([ tf.nn.l2_loss(v) for v in all_vars 231 | # if 'bias' not in v.name and 'embeddings' not in v.name]) * self.l2_reg 232 | #self.loss += lossL2 233 | 234 | self.global_step = tf.Variable(0, name="global_step", trainable=False) 235 | self.var1 = [v for v in tf.trainable_variables() if v.name != 'feature_bias:0'] 236 | self.var2 = [tf.trainable_variables()[1]] # self.var2 = [feature_bias] 237 | # optimizer 238 | # here we should use two different optimizer for wide and deep model(if we add wide part). 239 | if self.optimizer_type == "adam": 240 | if self.has_wide: 241 | optimizer1 = tf.train.AdamOptimizer(learning_rate=self.learning_rate, 242 | beta1=0.9, beta2=0.999, epsilon=1e-8) 243 | optimizer2 = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate_wide) 244 | #minimize(self.loss, global_step=self.global_step) 245 | var_list1 = self.var1 246 | var_list2 = self.var2 247 | grads = tf.gradients(self.loss, var_list1 + var_list2) 248 | grads1 = grads[:len(var_list1)] 249 | grads2 = grads[len(var_list1):] 250 | train_op1 = optimizer1.apply_gradients(zip(grads1, var_list1), global_step=self.global_step) 251 | train_op2 = optimizer2.apply_gradients(zip(grads2, var_list2)) 252 | self.optimizer = tf.group(train_op1, train_op2) 253 | else: 254 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, 255 | beta1=0.9, beta2=0.999, epsilon=1e-8).\ 256 | minimize(self.loss, global_step=self.global_step) 257 | elif self.optimizer_type == "adagrad": 258 | self.optimizer = tf.train.AdagradOptimizer(learning_rate=self.learning_rate, 259 | initial_accumulator_value=1e-8).\ 260 | minimize(self.loss) 261 | elif self.optimizer_type == "gd": 262 | self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate).\ 263 | minimize(self.loss) 264 | elif self.optimizer_type == "momentum": 265 | self.optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.95).\ 266 | minimize(self.loss) 267 | 268 | # init 269 | self.saver = tf.train.Saver(max_to_keep=5) 270 | init = tf.global_variables_initializer() 271 | self.sess = self._init_session() 272 | self.sess.run(init) 273 | self.count_param() 274 | 275 | 276 | def count_param(self): 277 | k = (np.sum([np.prod(v.get_shape().as_list()) 278 | for v in tf.trainable_variables()])) 279 | 280 | #print(tf.trainable_variables()) 281 | print("total parameters :%d" % k) 282 | print("extra parameters : %d" % (k - self.feature_size * self.embedding_size)) 283 | 284 | 285 | def _init_session(self): 286 | config = tf.ConfigProto(allow_soft_placement=True) 287 | config.gpu_options.allow_growth = True 288 | 
return tf.Session(config=config) 289 | 290 | 291 | def _initialize_weights(self): 292 | weights = dict() 293 | 294 | # embeddings 295 | weights["feature_embeddings"] = tf.Variable( 296 | tf.random_normal([self.feature_size, self.embedding_size], 0.0, 0.01), 297 | name="feature_embeddings") # feature_size(n) * d 298 | 299 | if self.has_wide: 300 | weights["feature_bias"] = tf.Variable( 301 | tf.random_normal([self.feature_size, 1], 0.0, 0.001), 302 | name="feature_bias") # feature_size(n) * 1 303 | input_size = self.output_size * self.field_size 304 | #if self.deep_layers != None: 305 | # input_size += self.deep_layers[-1] 306 | #if self.has_wide: 307 | # input_size += self.field_size 308 | 309 | # dense layers 310 | if self.deep_layers != None: 311 | num_layer = len(self.deep_layers) 312 | layer0_size = self.field_size * self.embedding_size 313 | glorot = np.sqrt(2.0 / (layer0_size + self.deep_layers[0])) 314 | weights["layer_0"] = tf.Variable( 315 | np.random.normal(loc=0, scale=glorot, size=(layer0_size, self.deep_layers[0])), dtype=np.float32) 316 | weights["bias_0"] = tf.Variable(np.random.normal(loc=0, scale=glorot, size=(1, self.deep_layers[0])), 317 | dtype=np.float32) # 1 * layers[0] 318 | for i in range(1, num_layer): 319 | glorot = np.sqrt(2.0 / (self.deep_layers[i-1] + self.deep_layers[i])) 320 | weights["layer_%d" % i] = tf.Variable( 321 | np.random.normal(loc=0, scale=glorot, size=(self.deep_layers[i-1], self.deep_layers[i])), 322 | dtype=np.float32) # layers[i-1] * layers[i] 323 | weights["bias_%d" % i] = tf.Variable( 324 | np.random.normal(loc=0, scale=glorot, size=(1, self.deep_layers[i])), 325 | dtype=np.float32) # 1 * layer[i] 326 | glorot = np.sqrt(2.0 / (self.deep_layers[-1] + 1)) 327 | weights["prediction_dense"] = tf.Variable( 328 | np.random.normal(loc=0, scale=glorot, size=(self.deep_layers[-1], 1)), 329 | dtype=np.float32, name="prediction_dense") 330 | weights["prediction_bias_dense"] = tf.Variable( 331 | np.random.normal(), dtype=np.float32, name="prediction_bias_dense") 332 | 333 | 334 | #---------- prediciton weight ------------------# 335 | glorot = np.sqrt(2.0 / (input_size + 1)) 336 | weights["prediction"] = tf.Variable( 337 | np.random.normal(loc=0, scale=glorot, size=(input_size, 1)), 338 | dtype=np.float32, name="prediction") 339 | weights["prediction_bias"] = tf.Variable( 340 | np.random.normal(), dtype=np.float32, name="prediction_bias") 341 | 342 | return weights 343 | 344 | def batch_norm_layer(self, x, train_phase, scope_bn): 345 | bn_train = batch_norm(x, decay=self.batch_norm_decay, center=True, scale=True, updates_collections=None, 346 | is_training=True, reuse=None, trainable=True, scope=scope_bn) 347 | bn_inference = batch_norm(x, decay=self.batch_norm_decay, center=True, scale=True, updates_collections=None, 348 | is_training=False, reuse=True, trainable=True, scope=scope_bn) 349 | z = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference) 350 | return z 351 | 352 | 353 | def get_batch(self, Xi, Xv, y, batch_size, index): 354 | start = index * batch_size 355 | end = (index+1) * batch_size 356 | end = end if end < len(y) else len(y) 357 | return Xi[start:end], Xv[start:end], [[y_] for y_ in y[start:end]] 358 | 359 | 360 | # shuffle three lists simutaneously 361 | def shuffle_in_unison_scary(self, a, b, c): 362 | rng_state = np.random.get_state() 363 | np.random.shuffle(a) 364 | np.random.set_state(rng_state) 365 | np.random.shuffle(b) 366 | np.random.set_state(rng_state) 367 | np.random.shuffle(c) 368 | 369 | 370 | def fit_on_batch(self, 
Xi, Xv, y): 371 | feed_dict = {self.feat_index: Xi, 372 | self.feat_value: Xv, 373 | self.label: y, 374 | self.dropout_keep_prob: self.drop_keep_prob, 375 | self.train_phase: True} 376 | step, loss, opt = self.sess.run((self.global_step, self.loss, self.optimizer), feed_dict=feed_dict) 377 | return step, loss 378 | 379 | # Since the train data is very large, they can not be fit into the memory at the same time. 380 | # We separate the whole train data into several files and call "fit_once" for each file. 381 | def fit_once(self, Xi_train, Xv_train, y_train, 382 | epoch, file_count, Xi_valid=None, 383 | Xv_valid=None, y_valid=None, 384 | early_stopping=False): 385 | 386 | has_valid = Xv_valid is not None 387 | last_step = 0 388 | t1 = time() 389 | self.shuffle_in_unison_scary(Xi_train, Xv_train, y_train) 390 | total_batch = int(len(y_train) / self.batch_size) 391 | for i in range(total_batch): 392 | Xi_batch, Xv_batch, y_batch = self.get_batch(Xi_train, Xv_train, y_train, self.batch_size, i) 393 | step, loss = self.fit_on_batch(Xi_batch, Xv_batch, y_batch) 394 | last_step = step 395 | 396 | # evaluate training and validation datasets 397 | train_result, train_loss = self.evaluate(Xi_train, Xv_train, y_train) 398 | self.train_result.append(train_result) 399 | self.train_loss.append(train_loss) 400 | if has_valid: 401 | valid_result, valid_loss = self.evaluate(Xi_valid, Xv_valid, y_valid) 402 | self.valid_result.append(valid_result) 403 | self.valid_loss.append(valid_loss) 404 | if valid_loss < self.best_loss and self.is_save == True: 405 | old_loss = self.best_loss 406 | self.best_loss = valid_loss 407 | self.saver.save(self.sess, self.save_path + 'model.ckpt',global_step=last_step) 408 | print("[%d-%d] model saved!. Valid loss is improved from %.4f to %.4f" 409 | % (epoch, file_count, old_loss, self.best_loss)) 410 | 411 | if self.verbose > 0 and ((epoch-1)*9 + file_count) % self.verbose == 0: 412 | if has_valid: 413 | print("[%d-%d] train-result=%.4f, train-logloss=%.4f, valid-result=%.4f, valid-logloss=%.4f [%.1f s]" % (epoch, file_count, train_result, train_loss, valid_result, valid_loss, time() - t1)) 414 | else: 415 | print("[%d-%d] train-result=%.4f [%.1f s]" \ 416 | % (epoch, file_count, train_result, time() - t1)) 417 | if has_valid and early_stopping and self.training_termination(self.valid_loss): 418 | return False 419 | else: 420 | return True 421 | 422 | 423 | 424 | def training_termination(self, valid_result): 425 | if len(valid_result) > 5: 426 | if self.greater_is_better: 427 | if valid_result[-1] < valid_result[-2] and \ 428 | valid_result[-2] < valid_result[-3] and \ 429 | valid_result[-3] < valid_result[-4] and \ 430 | valid_result[-4] < valid_result[-5]: 431 | return True 432 | else: 433 | if valid_result[-1] > valid_result[-2] and \ 434 | valid_result[-2] > valid_result[-3] and \ 435 | valid_result[-3] > valid_result[-4] and \ 436 | valid_result[-4] > valid_result[-5]: 437 | return True 438 | return False 439 | 440 | 441 | def predict(self, Xi, Xv): 442 | """ 443 | :param Xi: list of list of feature indices of each sample in the dataset 444 | :param Xv: list of list of feature values of each sample in the dataset 445 | :return: predicted probability of each sample 446 | """ 447 | # dummy y 448 | dummy_y = [1] * len(Xi) 449 | batch_index = 0 450 | Xi_batch, Xv_batch, y_batch = self.get_batch(Xi, Xv, dummy_y, self.batch_size, batch_index) 451 | y_pred = None 452 | #y_loss = None 453 | while len(Xi_batch) > 0: 454 | num_batch = len(y_batch) 455 | feed_dict = 
{self.feat_index: Xi_batch, 456 | self.feat_value: Xv_batch, 457 | self.label: y_batch, 458 | self.dropout_keep_prob: [1.0] * len(self.drop_keep_prob), 459 | self.train_phase: False} 460 | batch_out = self.sess.run(self.out, feed_dict=feed_dict) 461 | 462 | if batch_index == 0: 463 | y_pred = np.reshape(batch_out, (num_batch,)) 464 | #y_loss = np.reshape(batch_loss, (num_batch,)) 465 | else: 466 | y_pred = np.concatenate((y_pred, np.reshape(batch_out, (num_batch,)))) 467 | #y_loss = np.concatenate((y_loss, np.reshape(batch_loss, (num_batch,)))) 468 | 469 | batch_index += 1 470 | Xi_batch, Xv_batch, y_batch = self.get_batch(Xi, Xv, dummy_y, self.batch_size, batch_index) 471 | 472 | return y_pred 473 | 474 | 475 | def evaluate(self, Xi, Xv, y): 476 | """ 477 | :param Xi: list of list of feature indices of each sample in the dataset 478 | :param Xv: list of list of feature values of each sample in the dataset 479 | :param y: label of each sample in the dataset 480 | :return: metric of the evaluation 481 | """ 482 | y_pred = self.predict(Xi, Xv) 483 | y_pred = np.clip(y_pred,1e-6,1-1e-6) 484 | return self.eval_metric(y, y_pred), log_loss(y, y_pred) 485 | 486 | def restore(self, save_path=None): 487 | if (save_path == None): 488 | save_path = self.save_path 489 | ckpt = tf.train.get_checkpoint_state(save_path) 490 | if ckpt and ckpt.model_checkpoint_path: 491 | self.saver.restore(self.sess, ckpt.model_checkpoint_path) 492 | if self.verbose > 0: 493 | print ("restored from %s" % (save_path)) 494 | -------------------------------------------------------------------------------- /sample_preprocess.sh: -------------------------------------------------------------------------------- 1 | mkdir Criteo 2 | python ./Dataprocess/Criteo/preprocess.py 3 | python ./Dataprocess/Kfold_split/stratifiedKfold.py 4 | python ./Dataprocess/Criteo/scale.py 5 | -------------------------------------------------------------------------------- /test_code.sh: -------------------------------------------------------------------------------- 1 | echo "start training ..." 
2 | 3 | python -u train.py \ 4 | --data "Criteo" --blocks 3 --heads 2 --block_shape "[64, 64, 64]" \ 5 | --is_save "True" --save_path "./test_code/Criteo/b3h2_64x64x64/" \ 6 | --field_size 39 --run_times 1 --data_path "./" \ 7 | --epoch 3 --has_residual "True" --has_wide "False" \ 8 | --batch_size 1024 \ 9 | > test_code_single.out & 10 | 11 | 12 | python -u train.py \ 13 | --data "Criteo" --blocks 3 --heads 2 --block_shape "[64, 64, 64]" \ 14 | --is_save "True" --save_path "./test_code/Criteo/b3h2_dnn_dropkeep1_400x2/" \ 15 | --field_size 39 --run_times 1 --dropout_keep_prob "[1, 1, 1]" --data_path "./" \ 16 | --epoch 3 --has_residual "True" --has_wide "False" --deep_layers "[400, 400]"\ 17 | --batch_size 1024 \ 18 | > ./test_code_dnn.out & 19 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import pandas as pd 4 | import tensorflow as tf 5 | from sklearn.metrics import make_scorer 6 | from sklearn.model_selection import StratifiedKFold 7 | from time import time 8 | from model import AutoInt 9 | import argparse 10 | import os 11 | 12 | 13 | def str2list(v): 14 | v=v.split(',') 15 | v=[int(_.strip('[]')) for _ in v] 16 | 17 | return v 18 | 19 | 20 | def str2list2(v): 21 | v=v.split(',') 22 | v=[float(_.strip('[]')) for _ in v] 23 | 24 | return v 25 | 26 | 27 | def str2bool(v): 28 | if v.lower() in ['yes', 'true', 't', 'y', '1']: 29 | return True 30 | elif v.lower() in ['no', 'false', 'f', 'n', '0']: 31 | return False 32 | else: 33 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 34 | 35 | def parse_args(): 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--blocks', type=int, default=2, help='#blocks') 38 | parser.add_argument('--block_shape', type=str2list, default=[16,16], help='output shape of each block') 39 | parser.add_argument('--heads', type=int, default=2, help='#heads') 40 | parser.add_argument('--embedding_size', type=int, default=16) 41 | parser.add_argument('--dropout_keep_prob', type=str2list2, default=[1, 1, 0.5]) 42 | parser.add_argument('--epoch', type=int, default=2) 43 | parser.add_argument('--batch_size', type=int, default=1024) 44 | parser.add_argument('--learning_rate', type=float, default=0.001) 45 | parser.add_argument('--learning_rate_wide', type=float, default=0.001) 46 | parser.add_argument('--optimizer_type', type=str, default='adam') 47 | parser.add_argument('--l2_reg', type=float, default=0.0) 48 | parser.add_argument('--random_seed', type=int, default=2018) 49 | parser.add_argument('--save_path', type=str, default='./model/') 50 | parser.add_argument('--field_size', type=int, default=23, help='#fields') 51 | parser.add_argument('--loss_type', type=str, default='logloss') 52 | parser.add_argument('--verbose', type=int, default=1) 53 | parser.add_argument('--run_times', type=int, default=3,help='run multiple times to eliminate error') 54 | parser.add_argument('--is_save', type=str2bool, default=False) 55 | parser.add_argument('--greater_is_better', type=str2bool, default=False, help='early stop criterion') 56 | parser.add_argument('--has_residual', type=str2bool, default=True, help='add residual or not') 57 | parser.add_argument('--has_wide', type=str2bool, default=False) 58 | parser.add_argument('--deep_layers', type=str2list, default=None, help='config for dnn in joint train') 59 | parser.add_argument('--batch_norm', type=int, default=0) 60 | 
parser.add_argument('--batch_norm_decay', type=float, default=0.995) 61 | parser.add_argument('--data', type=str, help='data name') 62 | parser.add_argument('--data_path', type=str, help='root path for all the data') 63 | return parser.parse_args() 64 | 65 | 66 | 67 | def _run_(args, file_name, run_cnt): 68 | #path_prefix = '../Dataprocess/' + args.data 69 | path_prefix = os.path.join(args.data_path, args.data) 70 | feature_size = np.load(path_prefix + '/feature_size.npy')[0] 71 | 72 | # test: file1, valid: file2, train: file3-10 73 | model = AutoInt(args=args, feature_size=feature_size, run_cnt=run_cnt) 74 | 75 | #variables = tf.contrib.framework.get_variables_to_restore() 76 | #print(variables) 77 | #return 78 | Xi_valid = np.load(path_prefix + '/part2/' + file_name[0]) 79 | Xv_valid = np.load(path_prefix + '/part2/' + file_name[1]) 80 | y_valid = np.load(path_prefix + '/part2/' + file_name[2]) 81 | 82 | is_continue = True 83 | for k in range(model.epoch): 84 | if not is_continue: 85 | print('early stopping at epoch %d' % (k+1)) 86 | break 87 | file_count = 0 88 | time_epoch = 0 89 | for j in range(3, 11): 90 | if not is_continue: 91 | print('early stopping at epoch %d file %d' % (k+1, j)) 92 | break 93 | file_count += 1 94 | Xi_train = np.load(path_prefix + '/part' + str(j) + '/' + file_name[0]) 95 | Xv_train = np.load(path_prefix + '/part' + str(j) + '/' + file_name[1]) 96 | y_train = np.load(path_prefix + '/part' + str(j) + '/' + file_name[2]) 97 | 98 | print("epoch %d, file %d" %(k+1, j)) 99 | t1 = time() 100 | is_continue = model.fit_once(Xi_train, Xv_train, y_train, k+1, file_count, 101 | Xi_valid, Xv_valid, y_valid, early_stopping=True) 102 | time_epoch += time() - t1 103 | 104 | print("epoch %d, time %d" % (k+1, time_epoch)) 105 | 106 | 107 | print('start testing!...') 108 | Xi_test = np.load(path_prefix + '/part1/' + file_name[0]) 109 | Xv_test = np.load(path_prefix + '/part1/' + file_name[1]) 110 | y_test = np.load(path_prefix + '/part1/' + file_name[2]) 111 | 112 | model.restore() 113 | 114 | test_result, test_loss = model.evaluate(Xi_test, Xv_test, y_test) 115 | print("test-result = %.4lf, test-logloss = %.4lf" % (test_result, test_loss)) 116 | return test_result, test_loss 117 | 118 | if __name__ == "__main__": 119 | args = parse_args() 120 | print(args.__dict__) 121 | print('**************') 122 | #file_name = [] 123 | if args.data in ['Avazu']: 124 | # Avazu does not have numerical features so we didn't scale the data. 125 | file_name = ['train_i.npy', 'train_x.npy', 'train_y.npy'] 126 | elif args.data in ['Criteo', 'KDD2012']: 127 | file_name = ['train_i.npy', 'train_x2.npy', 'train_y.npy'] 128 | test_auc = [] 129 | test_log = [] 130 | 131 | print('run time : %d' % args.run_times) 132 | for i in range(1, args.run_times + 1): 133 | test_result, test_loss = _run_(args, file_name, i) 134 | test_auc.append(test_result) 135 | test_log.append(test_loss) 136 | print('test_auc', test_auc) 137 | print('test_log_loss', test_log) 138 | print('avg_auc', sum(test_auc)/len(test_auc)) 139 | print('avg_log_loss', sum(test_log)/len(test_log)) 140 | 141 | --------------------------------------------------------------------------------
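To make the `Input Format` section of the README concrete, here is a small illustrative sketch of the arrays that the preprocessing pipeline produces and that `train.py` loads for each fold. All field layouts and index/value numbers below are hypothetical toy data, not taken from any real Criteo vocabulary; only the conventions follow the repository's code (categorical values are set to 1, numerical fields keep a fixed index, and `feature_size` is the maximum index plus one, as computed in `stratifiedKfold.py`).

```python
# Toy sketch of the AutoInt input format (hypothetical values: 2 samples x 4 fields).
import numpy as np

# train_i[s][t]: global feature index of field t in sample s.
# Numerical fields keep one fixed index; each categorical value gets its own index.
train_i = np.array([[0, 1, 5, 9],
                    [0, 1, 6, 9]], dtype=np.int32)

# train_x[s][t]: feature value; the (scaled) raw value for numerical fields, 1 for categorical fields.
train_x = np.array([[0.3, 2.0, 1.0, 1.0],
                    [1.7, 4.0, 1.0, 1.0]], dtype=np.float32)

# train_y[s]: binary click label of sample s.
train_y = np.array([0, 1], dtype=np.int32)

# feature_size.npy stores max index + 1, which train.py uses to size the embedding table.
feature_size = np.array([train_i.max() + 1])

for name, arr in [('train_i.npy', train_i), ('train_x.npy', train_x),
                  ('train_y.npy', train_y), ('feature_size.npy', feature_size)]:
    np.save(name, arr)
```

In the real pipeline, `stratifiedKfold.py` writes these arrays per fold under `part1/` ... `part10/`, and `train.py` uses `part1` as the test set, `part2` as the validation set, and the remaining parts for training.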