├── Dataprocess
├── Avazu
│ └── preprocess.py
├── Criteo
│ ├── config.py
│ ├── preprocess.py
│ └── scale.py
├── KDD2012
│ ├── preprocess.py
│ └── scale.py
└── Kfold_split
│ ├── config.py
│ └── stratifiedKfold.py
├── LICENSE
├── README.md
├── figures
└── model.png
├── model.py
├── sample_preprocess.sh
├── test_code.sh
├── train.py
└── train_examples.txt
/Dataprocess/Avazu/preprocess.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | #Email of the author: zjduan@pku.edu.cn
3 | '''
4 | 0.id: ad identifier
5 | 1.click: 0/1 for non-click/click
6 | 2.hour: format is YYMMDDHH, so 14091123 means 23:00 on Sept. 11, 2014 UTC.
7 | 3.C1 -- anonymized categorical variable
8 | 4.banner_pos
9 | 5.site_id
10 | 6.site_domain
11 | 7.site_category
12 | 8.app_id
13 | 9.app_domain
14 | 10.app_category
15 | 11.device_id
16 | 12.device_ip
17 | 13.device_model
18 | 14.device_type
19 | 15.device_conn_type
20 | 16.C14
21 | 17.C15
22 | 18.C16
23 | 19.C17
24 | 20.C18
25 | 21.C19
26 | 22.C20
27 | 23.C21
28 | '''
29 | import pandas as pd
30 | import math
# Converts the raw Avazu CSV into three aligned text files:
#   train_x.txt - feature values, train_i.txt - feature indices,
#   train_y.txt - binary labels.
# The vocabulary is built in the same single pass that writes the output, so
# a category keeps the shared "others" index of its field until it has been
# seen Bound[field] times; from then on it gets an index of its own.
train_path = './train.csv'
dic = {}
debug = False   # when True, stop after roughly the first 10000 lines
tune = False    # when True, only build the vocabulary and write no samples

label_index = 1
Column = 24
# Number of occurrences after which a category receives its own index.
Bound = [5] * Column

# numerical[i]: -1 for the label column, 1 for numerical, 0 for categorical.
numr_feat = []
numerical = [0] * Column
numerical[label_index] = -1

index_cnt = 0
index_others = [0] * Column
Max = [0] * Column

# Flag the numerical columns *before* deriving the categorical ones, so a
# column listed in numr_feat is never treated as categorical as well.
for i in numr_feat:
    index_others[i] = index_cnt
    index_cnt += 1
    numerical[i] = 1

cate_feat = [i for i in range(Column) if numerical[i] == 0]
for i in cate_feat:
    index_others[i] = index_cnt
    index_cnt += 1

# dic[col][raw_value] = [assigned_index, occurrence_count]
for i in range(Column):
    dic[i] = dict()

cnt_line = 0
# Context managers guarantee every file is closed even if a line is malformed.
with open(train_path, 'r') as f1, \
        open('./train_x.txt', 'w') as f_train_value, \
        open('./train_i.txt', 'w') as f_train_index, \
        open('./train_y.txt', 'w') as f_train_label:
    for line in f1:
        cnt_line += 1
        if cnt_line == 1:
            continue  # header
        if cnt_line % 1000000 == 0:
            print("cnt_line = %d, index_cnt = %d" % (cnt_line, index_cnt))
        if debug == True and cnt_line >= 10000:
            break
        split = line.strip('\n').split(',')

        # Update the per-field vocabulary with this line's categories.
        for i in cate_feat:
            if split[i] == '':
                continue
            if split[i] not in dic[i]:
                dic[i][split[i]] = [index_others[i], 0]
            record = dic[i][split[i]]
            record[1] += 1
            # Promote the category once it reaches the frequency bound.
            if record[0] == index_others[i] and record[1] == Bound[i]:
                record[0] = index_cnt
                index_cnt += 1

        if tune != False:
            continue

        label = split[label_index]
        if label != '0':
            label = '1'
        index = [0] * (Column - 1)
        value = ['0'] * (Column - 1)
        for i in range(Column):
            if i == label_index:
                continue
            # Output columns skip the label, so later fields shift left by one.
            cur = i - 1 if i > label_index else i
            if numerical[i] == 1:
                index[cur] = index_others[i]
                if split[i] != '':
                    value[cur] = split[i]
            elif split[i] != '':
                index[cur] = dic[i][split[i]][0]
                value[cur] = '1'
            # An empty field keeps the initial value '0' (and index 0 when
            # the field is categorical).

        f_train_index.write(' '.join(str(i) for i in index) + '\n')
        f_train_value.write(' '.join(value) + '\n')
        f_train_label.write(label + '\n')

print("Finished!")
print("index_cnt = %d" % index_cnt)
120 | # print ("max number for numerical features:")
121 | # for i in numr_feat:
122 | # print ("no.:%d max: %d" % (i, Max[i]))
123 |
124 |
125 |
--------------------------------------------------------------------------------
/Dataprocess/Criteo/config.py:
--------------------------------------------------------------------------------
1 | DATA_PATH = './Criteo/'
2 | SOURCE_DATA = './train_examples.txt'
--------------------------------------------------------------------------------
/Dataprocess/Criteo/preprocess.py:
--------------------------------------------------------------------------------
1 | import config
2 |
# Two-pass conversion of the raw Criteo file (tab separated: label, 13
# numerical columns, 26 categorical columns) into three aligned files:
#   train_x.txt: value
#   train_i.txt: index
#   train_y.txt: label
train_path = config.SOURCE_DATA

# dic[col][raw_value] = [index_within_field, occurrence_count]; only the
# categorical columns 13..38 are ever filled.
dic = {}
for i in range(39):
    dic[i] = {}

cnt_train = 0

# for debug
# limits = 10000

# index[k] is the largest index handed out so far in categorical field k.
# Index 1 is shared by every value that appears at most 10 times.
index = [1] * 26

# ---- pass 1: count categories and hand out per-field indices ----
with open(train_path, 'r') as f1:
    for line in f1:
        cnt_train += 1
        if cnt_train % 100000 == 0:
            print('now train cnt : %d\n' % cnt_train)
        # if cnt_train > limits:
        #     break
        split = line.strip('\n').split('\t')
        # 0-label, 1-13 numerical, 14-39 category
        for i in range(13, 39):
            raw = split[i + 1]
            if raw not in dic[i]:
                dic[i][raw] = [1, 0]
            record = dic[i][raw]
            record[1] += 1
            # Promote the value to its own index once it has been seen more
            # than 10 times.
            if record[0] == 1 and record[1] > 10:
                index[i - 13] += 1
                record[0] = index[i - 13]

# cnt_train already equals the number of lines read (there is no header and
# the debug break above is disabled), so no "- 1" correction is applied.
print('total entries :%d\n' % cnt_train)

# calculate number of category features of every dimension
kinds = [13]
kinds.extend(index)
print('number of dimensions : %d' % (len(kinds) - 1))
print(kinds)

# Turn the per-field sizes into cumulative offsets.
for i in range(1, len(kinds)):
    kinds[i] += kinds[i - 1]
print(kinds)

# ---- pass 2: make new data ----
cnt_train = 0
print('remake training data...\n')
# Context managers guarantee the outputs are flushed and closed on any error.
with open(train_path, 'r') as f1, \
        open(config.DATA_PATH + 'train_x.txt', 'w') as f_train_value, \
        open(config.DATA_PATH + 'train_i.txt', 'w') as f_train_index, \
        open(config.DATA_PATH + 'train_y.txt', 'w') as f_train_label:
    for line in f1:
        cnt_train += 1
        if cnt_train % 100000 == 0:
            print('now train cnt : %d\n' % cnt_train)
        # if cnt_train > limits:
        #     break
        entry = ['0'] * 39
        index = [None] * 39
        split = line.strip('\n').split('\t')
        label = str(split[0])
        # Numerical fields: fixed index i + 1, value copied when present.
        for i in range(13):
            if split[i + 1] != '':
                entry[i] = split[i + 1]
            index[i] = i + 1
        # Categorical fields: the index is assigned even for an empty value
        # (pass 1 counts '' like any other category); only the value differs.
        for i in range(13, 39):
            if split[i + 1] != '':
                entry[i] = '1'
            index[i] = dic[i][split[i + 1]][0]
        # Shift each per-field index by the cumulative size of earlier fields.
        for j in range(26):
            index[13 + j] += kinds[j]
        index = [str(item) for item in index]
        f_train_value.write(' '.join(entry) + '\n')
        f_train_index.write(' '.join(index) + '\n')
        f_train_label.write(label + '\n')
88 |
89 |
90 |
--------------------------------------------------------------------------------
/Dataprocess/Criteo/scale.py:
--------------------------------------------------------------------------------
1 | import math
2 | import config
3 | import numpy as np
def scale(x):
    """Compress a large numerical value with a squared-log transform.

    Values greater than 2 are replaced by ``int(log(x) ** 2)``; every other
    value is returned unchanged.
    """
    return int(math.log(float(x)) ** 2) if x > 2 else x
8 |
9 |
10 |
def scale_each_fold(num_parts=10, num_numerical=13):
    """Apply ``scale`` to the leading numerical columns of every fold.

    For each fold ``config.DATA_PATH + 'part<i>/train_x.npy'`` is loaded, its
    first ``num_numerical`` columns are rescaled element-wise and the full
    matrix is written next to it as ``train_x2.npy``.

    Args:
        num_parts: number of folds; parts are numbered 1..num_parts.
        num_numerical: how many leading columns are passed through ``scale``.
    """
    for i in range(1, num_parts + 1):
        print('now part %d' % i)
        part_dir = config.DATA_PATH + 'part' + str(i)
        data = np.load(part_dir + '/train_x.npy')
        # Basic slicing returns a view, so writing to `part` updates `data`.
        part = data[:, 0:num_numerical]
        for j in range(part.shape[0]):
            if j % 100000 == 0:
                print(j)
            part[j] = list(map(scale, part[j]))
        np.save(part_dir + '/train_x2.npy', data)
21 |
22 |
23 | if __name__ == '__main__':
24 | scale_each_fold()
--------------------------------------------------------------------------------
/Dataprocess/KDD2012/preprocess.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | #Email of the author: zjduan@pku.edu.cn
3 | '''
4 | 0. Click:
5 | 1. Impression(numerical)
6 | 2. DisplayURL: (categorical)
7 | 3. AdID:(categorical)
8 | 4. AdvertiserID:(categorical)
9 | 5. Depth:(numerical)
10 | 6. Position:(numerical)
11 | 7. QueryID: (categorical) the key of the data file 'queryid_tokensid.txt'.
12 | 8. KeywordID: (categorical)the key of 'purchasedkeyword_tokensid.txt'.
13 | 9. TitleID: (categorical)the key of 'titleid_tokensid.txt'.
14 | 10. DescriptionID: (categorical)the key of 'descriptionid_tokensid.txt'.
15 | 11. UserID: (categorical)the key of 'userid_profile.txt'
16 | 12. User's Gender: (categorical)
17 | 13. User's Age: (categorical)
18 | '''
19 | import math
# Converts the KDD2012 training file plus the user profile table into three
# aligned text files: train_x.txt (values), train_i.txt (indices) and
# train_y.txt (binary labels). Each sample has `Field` output columns: input
# columns 1..11 followed by the user's gender and age.
train_path = './training.txt'
dic = {}
debug = False   # when True, stop after roughly the first 10000 lines
tune = False    # when True, only build the vocabulary and write no samples
Column = 12
Field = 13

numr_feat = [1,5,6]
numerical = [0] * Column
cate_feat = [2,3,4,7,8,9,10,11]
index_cnt = 0
index_others = [0] * (Field + 1)
Max = [0] * 12
numerical[0] = -1   # column 0 is the label
for i in numr_feat:
    index_others[i] = index_cnt
    index_cnt += 1
    numerical[i] = 1
for i in cate_feat:
    index_others[i] = index_cnt
    index_cnt += 1

# dic[field][raw_value] = [assigned_index, occurrence_count]
for i in range(Field + 1):
    dic[i] = dict()

###init user_dic
# user_dic[user_id] = [gender, age]; every distinct gender / age value gets
# its own index immediately (fields 12 and 13).
user_dic = dict()

cnt_line = 0
# The profile file is closed as soon as it has been read.
with open('./userid_profile.txt', 'r') as f2:
    for line in f2:
        cnt_line += 1
        if cnt_line % 1000000 == 0:
            print("cnt_line = %d, index_cnt = %d" % (cnt_line, index_cnt))
        split = line.strip('\n').split('\t')
        user_dic[split[0]] = [split[1], split[2]]
        if split[1] not in dic[12]:
            dic[12][split[1]] = [index_cnt, 0]
            index_cnt += 1
        if split[2] not in dic[13]:
            dic[13][split[2]] = [index_cnt, 0]
            index_cnt += 1

cnt_line = 0
# Context managers guarantee every file is closed even if a line is malformed.
with open(train_path, 'r') as f1, \
        open('./train_x.txt', 'w') as f_train_value, \
        open('./train_i.txt', 'w') as f_train_index, \
        open('./train_y.txt', 'w') as f_train_label:
    for line in f1:
        cnt_line += 1
        if cnt_line % 1000000 == 0:
            print("cnt_line = %d, index_cnt = %d" % (cnt_line, index_cnt))
        if debug == True and cnt_line >= 10000:
            break
        split = line.strip('\n').split('\t')

        # Update the per-field vocabulary with this line's categories.
        for i in cate_feat:
            if split[i] == '':
                continue
            if split[i] not in dic[i]:
                dic[i][split[i]] = [index_others[i], 0]
            record = dic[i][split[i]]
            record[1] += 1
            # Promote the category on its 10th occurrence.
            if record[0] == index_others[i] and record[1] == 10:
                record[0] = index_cnt
                index_cnt += 1

        if tune != False:
            continue

        label = split[0]
        if label != '0':
            label = '1'
        index = [0] * Field
        value = ['0'] * Field
        for i in range(1, 12):
            if numerical[i] == 1:
                index[i - 1] = index_others[i]
                if split[i] != '':
                    value[i - 1] = split[i]
                    Max[i] = max(int(split[i]), Max[i])
            elif split[i] != '':
                index[i - 1] = dic[i][split[i]][0]
                value[i - 1] = '1'

            if split[i] == '':
                value[i - 1] = '0'
            # A UserID of '0' is written with value '0', like a missing field.
            if i == 11 and split[i] == '0':
                value[i - 1] = '0'

        ### gender and age
        if split[11] == '' or (split[11] not in user_dic):
            index[12 - 1] = index_others[12]
            value[12 - 1] = '0'
            index[13 - 1] = index_others[13]
            value[13 - 1] = '0'
        else:
            gender, age = user_dic[split[11]]
            index[12 - 1] = dic[12][gender][0]
            value[12 - 1] = '1'
            index[13 - 1] = dic[13][age][0]
            value[13 - 1] = '1'

        f_train_index.write(' '.join(str(i) for i in index) + '\n')
        f_train_value.write(' '.join(value) + '\n')
        f_train_label.write(label + '\n')

print("Finished!")
print("index_cnt = %d" % index_cnt)
print("max number for numerical features:")
for i in numr_feat:
    print("no.:%d max: %d" % (i, Max[i]))
133 |
--------------------------------------------------------------------------------
/Dataprocess/KDD2012/scale.py:
--------------------------------------------------------------------------------
1 | import math
2 | import config
3 | import numpy as np
def scale(x):
    """Compress a large numerical value with a squared-log transform.

    Values greater than 2 are replaced by ``int(log(x) ** 2)``; every other
    value is returned unchanged.
    """
    return int(math.log(float(x)) ** 2) if x > 2 else x
8 |
9 |
10 |
def scale_each_fold(num_parts=10, num_numerical=13):
    """Apply ``scale`` to the leading columns of every fold.

    For each fold ``config.DATA_PATH + 'part<i>/train_x.npy'`` is loaded, its
    first ``num_numerical`` columns are rescaled element-wise and the full
    matrix is written next to it as ``train_x2.npy``.

    Args:
        num_parts: number of folds; parts are numbered 1..num_parts.
        num_numerical: how many leading columns are passed through ``scale``.
    """
    # NOTE(review): the directory listing shows no config.py next to this
    # script - confirm which `config` module is meant to be imported here.
    for i in range(1, num_parts + 1):
        print('now part %d' % i)
        part_dir = config.DATA_PATH + 'part' + str(i)
        data = np.load(part_dir + '/train_x.npy')
        # Basic slicing returns a view, so writing to `part` updates `data`.
        part = data[:, 0:num_numerical]
        for j in range(part.shape[0]):
            if j % 100000 == 0:
                print(j)
            part[j] = list(map(scale, part[j]))
        np.save(part_dir + '/train_x2.npy', data)
21 |
22 |
23 | if __name__ == '__main__':
24 | scale_each_fold()
--------------------------------------------------------------------------------
/Dataprocess/Kfold_split/config.py:
--------------------------------------------------------------------------------
1 | DATA_PATH = './Criteo/'
2 | TRAIN_I = DATA_PATH + 'train_i.txt'
3 | TRAIN_X = DATA_PATH + 'train_x.txt'
4 | TRAIN_Y = DATA_PATH + 'train_y.txt'
5 |
6 | NUM_SPLITS = 10
7 | RANDOM_SEED = 2018
8 |
9 |
--------------------------------------------------------------------------------
/Dataprocess/Kfold_split/stratifiedKfold.py:
--------------------------------------------------------------------------------
1 | #Email of the author: zjduan@pku.edu.cn
2 | import numpy as np
3 | import config
4 | import os
5 | import pandas as pd
6 | from sklearn.model_selection import StratifiedKFold
7 | from sklearn import preprocessing
8 |
9 | scale = ""
10 | train_x_name = "train_x.npy"
11 | train_y_name = "train_y.npy"
12 |
13 | # numr_feat = []
14 | Column = 13
15 |
def _load_data(_nrows=None, debug = False):
    """Load the feature-value matrix and the label vector from text files.

    Args:
        _nrows: number of rows to read from each file; ``None`` reads all.
        debug: unused; kept so existing callers keep working.

    Returns:
        ``(train_x, train_y)``: a float64 array read from ``config.TRAIN_X``
        and a flattened int32 array read from ``config.TRAIN_Y``.

    Raises:
        AssertionError: if the two files do not hold the same number of rows.
    """
    # np.float64 replaces the `np.float` alias, which was removed from NumPy;
    # the alias meant the builtin float, i.e. the same 64-bit dtype.
    train_x = pd.read_csv(config.TRAIN_X, header=None, sep=' ',
                          nrows=_nrows, dtype=np.float64)
    train_y = pd.read_csv(config.TRAIN_Y, header=None, sep=' ',
                          nrows=_nrows, dtype=np.int32)

    train_x = train_x.values
    train_y = train_y.values.reshape([-1])

    print('data loading done!')
    print('training data : %d' % train_y.shape[0])

    assert train_x.shape[0]==train_y.shape[0]

    return train_x, train_y
42 |
43 |
def save_x_y(fold_index, train_x, train_y):
    """Write the feature values and labels of every fold to its own folder.

    Fold ``i`` (0-based) is saved under ``config.DATA_PATH + 'part<i+1>/'``
    using the module-level file names ``train_x_name`` and ``train_y_name``.

    Args:
        fold_index: sequence whose i-th item holds the row numbers of fold i.
        train_x: feature-value matrix, one row per sample.
        train_y: label vector aligned with ``train_x``.
    """
    # Convert once so rows can be gathered with NumPy indexing instead of a
    # per-row Python loop.
    train_x = np.asarray(train_x)
    train_y = np.asarray(train_y)
    for i in range(len(fold_index)):
        print("now part %d" % (i+1))
        part_index = np.asarray(fold_index[i])
        save_dir = config.DATA_PATH + "part" + str(i+1) + "/"
        # Values and labels share one directory; create it if missing.
        os.makedirs(save_dir, exist_ok=True)
        np.save(save_dir + train_x_name, train_x[part_index])
        np.save(save_dir + train_y_name, train_y[part_index])
60 |
61 |
62 | # def save_test(test_x, test_y):
63 | # np.save("../data/test/test_x.npy", test_x)
64 | # np.save("../data/test/test_y.npy", test_y)
65 |
66 |
def save_i(fold_index):
    """Save the total feature size and the feature indices of every fold.

    Reads ``config.TRAIN_I``, stores ``max(index) + 1`` as a one-element
    array in ``feature_size.npy`` and writes the index rows of fold ``i``
    (0-based) to ``part<i+1>/train_i.npy`` under ``config.DATA_PATH``. The
    ``part<i+1>`` folders must already exist.

    Args:
        fold_index: sequence whose i-th item holds the row numbers of fold i.
    """
    train_i = pd.read_csv(config.TRAIN_I, header=None, sep=' ',
                          nrows=None, dtype=np.int32).values
    feature_size = train_i.max() + 1
    print ("feature_size = %d" % feature_size)
    np.save(config.DATA_PATH + "feature_size.npy", np.array([feature_size]))

    print("train_i size: %d" % len(train_i))

    for i in range(len(fold_index)):
        print("now part %d" % (i+1))
        # NumPy indexing gathers the fold's rows without a Python-level loop.
        part_index = np.asarray(fold_index[i])
        save_path_Xi = config.DATA_PATH + "part" + str(i+1)+ '/train_i.npy'
        np.save(save_path_Xi, train_i[part_index])
93 |
94 |
def main():
    """Split the dataset into stratified folds and save each fold to disk.

    Loads the values and labels, builds ``config.NUM_SPLITS`` stratified
    folds seeded with ``config.RANDOM_SEED``, stores the row numbers of every
    fold in ``fold_index.npy`` and then writes the per-fold value, label and
    index files.
    """
    train_x, train_y = _load_data()
    print('loading data done!')

    splitter = StratifiedKFold(n_splits=config.NUM_SPLITS, shuffle=True,
                               random_state=config.RANDOM_SEED)
    # Only the held-out part of each split is kept: those parts are disjoint
    # and together cover the whole dataset.
    fold_index = [valid_id for _, valid_id in splitter.split(train_x, train_y)]

    print("fold num: %d" % (len(fold_index)))

    # Folds can differ in length when the sample count is not a multiple of
    # the fold count; np.array() rejects such ragged input on current NumPy,
    # so fall back to an explicit object array in that case.
    if len({len(fold) for fold in fold_index}) == 1:
        fold_array = np.array(fold_index)
    else:
        fold_array = np.empty(len(fold_index), dtype=object)
        for k, fold in enumerate(fold_index):
            fold_array[k] = fold
    np.save(config.DATA_PATH + "fold_index.npy", fold_array)

    save_x_y(fold_index, train_x, train_y)
    print("save train_x_y done!")

    # The in-memory folds are reused instead of reloading fold_index.npy:
    # reading an object array back would require np.load(allow_pickle=True).
    save_i(fold_index)
    print("save index done!")
119 | if __name__ == "__main__":
120 | main()
121 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Chence Shi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Note
2 | We have moved the repo to [https://github.com/DeepGraphLearning/RecommenderSystems/tree/master/featureRec](https://github.com/DeepGraphLearning/RecommenderSystems/tree/master/featureRec). Please check out the latest version there.
3 |
4 | # AutoInt
5 |
6 | This is a TensorFlow implementation of ***AutoInt*** for the CTR prediction task, as described in our paper:
7 |
8 | Weiping Song, Chence Shi, Zhiping Xiao, Zhijian Duan, Yewen Xu, Ming Zhang and Jian Tang. [AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/pdf/1810.11921.pdf). arXiv preprint arXiv:1810.11921, 2018.
9 |
10 | ## Requirements:
11 | * **Tensorflow 1.4.0-rc1**
12 | * Python 3
13 | * CUDA 8.0+ (For GPU)
14 |
15 | ## Introduction
16 |
17 | AutoInt:An effective and efficient algorithm to
18 | automatically learn high-order feature interactions for (sparse) categorical and numerical features.
19 |
20 |
21 |

22 |
23 | The illustration of AutoInt. We first project all sparse features
24 | (both categorical and numerical features) into the low-dimensional space. Next, we feed embeddings of all fields into stacked multiple interacting layers implemented by self-attentive neural network. The output of the final interacting layer is the low-dimensional representation of learnt combinatorial features, which is further used for estimating the CTR via sigmoid function.
25 |
26 | ## Usage
27 | ### Input Format
28 | AutoInt requires the input data in the following format:
29 | * train_x: matrix with shape *(num_sample, num_field)*. train_x[s][t] is the feature value of feature field t of sample s in the dataset. The default value for categorical feature is 1.
30 | * train_i: matrix with shape *(num_sample, num_field)*. train_i[s][t] is the feature index of feature field t of sample s in the dataset. The feature size is the maximal value of train_i plus one.
31 | * train_y: label of each sample in the dataset.
32 |
33 | If you want to know how to preprocess the data, please refer to `./Dataprocess/Criteo/preprocess.py`
34 |
35 | ### Example
36 | We use four public real-world datasets (Avazu, Criteo, KDD12, MovieLens-1M) in our experiments. Since the first three datasets are very large, they cannot fit into memory as a whole. In our implementation, we split the whole dataset into 10 parts, and we use the first part as the test set and the second part as the validation set. We provide the code for preprocessing these three datasets in `./Dataprocess`. If you want to reuse this code, you should first run `preprocess.py` to generate `train_x.txt, train_i.txt, train_y.txt` as described in `Input Format`. Then you should run `./Dataprocess/Kfold_split/stratifiedKfold.py` to split the whole dataset into ten folds. Finally, you can run `scale.py` to scale the numerical values (optional).
37 |
38 | To help you test the correctness of the code and familiarize yourself with it, we have uploaded the first `10000` samples of the `Criteo` dataset in `train_examples.txt`. We also provide scripts for preprocessing and training. (Please refer to `sample_preprocess.sh` and `test_code.sh`; you may need to modify the paths in `config.py` and `test_code.sh`.)
39 |
40 | After you run `test_code.sh`, you should get a folder named `Criteo` which contains `part*, feature_size.npy, fold_index.npy, train_*.txt`. `feature_size.npy` contains the total number of features, which is used to initialize the model. `train_*.txt` is the whole dataset. If you use another small dataset, say `MovieLens-1M`, you only need to modify the function `_run_` in `train.py`.
41 |
42 | Here's how to run the preprocessing.
43 | ```
44 | mkdir Criteo
45 | python ./Dataprocess/Criteo/preprocess.py
46 | python ./Dataprocess/Kfold_split/stratifiedKfold.py
47 | python ./Dataprocess/Criteo/scale.py
48 | ```
49 |
50 | Here's how to run the training.
51 | ```
52 | python -u train.py \
53 | --data "Criteo" --blocks 3 --heads 2 --block_shape "[64, 64, 64]" \
54 | --is_save "True" --save_path "./test_code/Criteo/b3h2_64x64x64/" \
55 | --field_size 39 --run_times 1 --data_path "./" \
56 | --epoch 3 --has_residual "True" --has_wide "False" \
57 | --batch_size 1024 \
58 | > test_code_single.out &
59 | ```
60 |
61 | You should see output like this:
62 |
63 | ```
64 | ...
65 | train logs
66 | ...
67 | start testing!...
68 | restored from ./test_code/Criteo/b3h2_dnn_dropkeep1_400x2/1/
69 | test-result = 0.8088, test-logloss = 0.4430
70 | test_auc [0.8088305055534442]
71 | test_log_loss [0.44297631300399626]
72 | avg_auc 0.8088305055534442
73 | avg_log_loss 0.44297631300399626
74 | ```
75 |
76 | ## Citation
77 | If you find AutoInt useful for your research, please consider citing the following paper:
78 | ```
79 | @article{weiping2018autoint,
80 | title={AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks},
81 | author={Weiping, Song and Chence, Shi and Zhiping, Xiao and Zhijian, Duan and Yewen, Xu and Ming, Zhang and Jian, Tang},
82 | journal={arXiv preprint arXiv:1810.11921},
83 | year={2018}
84 | }
85 | ```
86 |
87 |
88 | ## Contact information
89 | If you have questions related to the code, feel free to contact Weiping Song (`songweiping@pku.edu.cn`), Chence Shi (`chenceshi@pku.edu.cn`) and Zhijian Duan (`zjduan@pku.edu.cn`).
90 |
91 | ## License
92 | MIT
93 |
94 | ## Acknowledgement
95 | This implementation is inspired by Kyubyong Park's [transformer](https://github.com/Kyubyong/transformer) and Chenglong Chen's [DeepFM](https://github.com/ChenglongChen/tensorflow-DeepFM).
96 |
--------------------------------------------------------------------------------
/figures/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shichence/AutoInt/d88536fbe9733e45d182fb600657d0bdf82b640c/figures/model.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | '''
2 | Tensorflow implementation of AutoInt described in:
3 | AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks.
4 | author: Chence Shi
5 | email: chenceshi@pku.edu.cn
6 | '''
7 |
8 | import os
9 | import numpy as np
10 | import tensorflow as tf
11 | from time import time
12 | from sklearn.base import BaseEstimator, TransformerMixin
13 | from sklearn.metrics import roc_auc_score, log_loss
14 | from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
15 |
16 |
17 |
18 | '''
19 | The following two functions are adapted from kyubyong park's implementation of transformer
20 | We slightly modify the code to make it suitable for our work.(add relu, delete key masking and causality mask)
21 | June 2017 by kyubyong park.
22 | kbpark.linguist@gmail.com.
23 | https://www.github.com/kyubyong/transformer
24 | '''
25 |
26 |
def normalize(inputs, epsilon=1e-8):
    '''
    Applies layer normalization over the last dimension.

    The inputs are standardised with the mean and variance taken along the
    last axis, then multiplied by a gain (initialised to ones) and shifted
    by a bias (initialised to zeros); both are created as new variables.

    Args:
      inputs: A tensor with 2 or more dimensions
      epsilon: A floating number added to the variance to prevent division
        by zero
    Returns:
      A tensor with the same shape and data dtype
    '''
    # One gain / bias entry per element of the last dimension.
    last_dim_shape = inputs.get_shape()[-1:]

    mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
    # The bias is created before the gain so variables keep their order.
    beta = tf.Variable(tf.zeros(last_dim_shape))
    gamma = tf.Variable(tf.ones(last_dim_shape))
    standardised = (inputs - mean) / ((variance + epsilon) ** (.5))

    return gamma * standardised + beta
46 |
47 |
def multihead_attention(queries,
                        keys,
                        values,
                        num_units=None,
                        num_heads=1,
                        dropout_keep_prob=1,
                        is_training=True,
                        has_residual=True):
    '''
    Applies one multi-head self-attention (interacting) layer.

    Queries, keys and values are projected with ReLU dense layers, split
    into `num_heads` heads along the last axis, combined with scaled
    dot-product attention and concatenated again. An optional residual
    projection of `values` is added before the final ReLU and layer
    normalization.

    Args:
      queries, keys, values: rank-3 tensors (the heads are split along axis 2).
      num_units: size of the projections; defaults to the last dimension of
        `queries`. Must be divisible by `num_heads`.
      num_heads: number of attention heads.
      dropout_keep_prob: keep probability for the attention weights.
      is_training: python bool or bool tensor that enables dropout.
      has_residual: whether to add the residual projection of `values`.
    Returns:
      A tensor whose last dimension is `num_units`.
    '''
    if num_units is None:
        # as_list is a method; without the call parentheses the subscript
        # raised a TypeError whenever num_units was left at its default.
        num_units = queries.get_shape().as_list()[-1]

    # Linear projections
    Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)
    K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)
    V = tf.layers.dense(values, num_units, activation=tf.nn.relu)
    if has_residual:
        V_res = tf.layers.dense(values, num_units, activation=tf.nn.relu)

    # Split and concat: the heads are stacked along the batch axis
    Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0)
    K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0)
    V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0)

    # Multiplication
    weights = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))

    # Scale by the square root of the per-head key size
    weights = weights / (K_.get_shape().as_list()[-1] ** 0.5)

    # Activation
    weights = tf.nn.softmax(weights)

    # Dropouts
    weights = tf.layers.dropout(weights, rate=1-dropout_keep_prob,
                                training=tf.convert_to_tensor(is_training))

    # Weighted sum
    outputs = tf.matmul(weights, V_)

    # Restore shape: move the heads back from the batch axis to the last axis
    outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)

    # Residual connection
    if has_residual:
        outputs += V_res

    outputs = tf.nn.relu(outputs)
    # Normalize
    outputs = normalize(outputs)

    return outputs
101 |
102 |
103 | class AutoInt():
104 | def __init__(self, args, feature_size, run_cnt):
105 | #print(args.block_shape)
106 | #print(type(args.block_shape))
107 |
108 | self.feature_size = feature_size # denote as n, dimension of concatenated features
109 | self.field_size = args.field_size # denote as M, number of total feature fields
110 | self.embedding_size = args.embedding_size # denote as d, size of the feature embedding
111 | self.blocks = args.blocks # number of the blocks
112 | self.heads = args.heads # number of the heads
113 | self.block_shape = args.block_shape
114 | self.output_size = args.block_shape[-1]
115 | self.has_residual = args.has_residual
116 | self.has_wide = args.has_wide # whether to add wide part
117 | self.deep_layers = args.deep_layers # whether to joint train with deep networks as described in paper
118 |
119 |
120 | self.batch_norm = args.batch_norm
121 | self.batch_norm_decay = args.batch_norm_decay
122 | self.drop_keep_prob = args.dropout_keep_prob
123 | self.l2_reg = args.l2_reg
124 | self.epoch = args.epoch
125 | self.batch_size = args.batch_size
126 | self.learning_rate = args.learning_rate
127 | self.learning_rate_wide = args.learning_rate_wide
128 | self.optimizer_type = args.optimizer_type
129 |
130 | self.save_path = args.save_path + str(run_cnt) + '/'
131 | self.is_save = args.is_save
132 | if (args.is_save == True and os.path.exists(self.save_path) == False):
133 | os.makedirs(self.save_path)
134 |
135 | self.verbose = args.verbose
136 | self.random_seed = args.random_seed
137 | self.loss_type = args.loss_type
138 | self.eval_metric = roc_auc_score
139 | self.best_loss = 1.0
140 | self.greater_is_better = args.greater_is_better
141 | self.train_result, self.valid_result = [], []
142 | self.train_loss, self.valid_loss = [], []
143 |
144 | self._init_graph()
145 |
146 |
147 | def _init_graph(self):
148 | self.graph = tf.Graph()
149 | with self.graph.as_default():
150 |
151 | tf.set_random_seed(self.random_seed)
152 |
153 | self.feat_index = tf.placeholder(tf.int32, shape=[None, None],
154 | name="feat_index") # None * M
155 | self.feat_value = tf.placeholder(tf.float32, shape=[None, None],
156 | name="feat_value") # None * M
157 | self.label = tf.placeholder(tf.float32, shape=[None, 1], name="label") # None * 1
158 | # In our implementation, the shape of dropout_keep_prob is [3], used in 3 different parts.
159 | self.dropout_keep_prob = tf.placeholder(tf.float32, shape=[None], name="dropout_keep_prob")
160 | self.train_phase = tf.placeholder(tf.bool, name="train_phase")
161 |
162 | self.weights = self._initialize_weights()
163 |
164 | # model
165 | self.embeddings = tf.nn.embedding_lookup(self.weights["feature_embeddings"],
166 | self.feat_index) # None * M * d
167 | feat_value = tf.reshape(self.feat_value, shape=[-1, self.field_size, 1])
168 | self.embeddings = tf.multiply(self.embeddings, feat_value) # None * M * d
169 | self.embeddings = tf.nn.dropout(self.embeddings, self.dropout_keep_prob[1]) # None * M * d
170 | if self.has_wide:
171 | self.y_first_order = tf.nn.embedding_lookup(self.weights["feature_bias"], self.feat_index) # None * M * 1
172 | self.y_first_order = tf.reduce_sum(tf.multiply(self.y_first_order, feat_value), 1) # None * 1
173 |
174 | # joint training with feedforward nn
175 | if self.deep_layers != None:
176 | self.y_dense = tf.reshape(self.embeddings, shape=[-1, self.field_size * self.embedding_size])
177 | for i in range(0, len(self.deep_layers)):
178 | self.y_dense = tf.add(tf.matmul(self.y_dense, self.weights["layer_%d" %i]), self.weights["bias_%d"%i]) # None * layer[i]
179 | if self.batch_norm:
180 | self.y_dense = self.batch_norm_layer(self.y_dense, train_phase=self.train_phase, scope_bn="bn_%d" %i)
181 | self.y_dense = tf.nn.relu(self.y_dense)
182 | self.y_dense = tf.nn.dropout(self.y_dense, self.dropout_keep_prob[2])
183 | self.y_dense = tf.add(tf.matmul(self.y_dense, self.weights["prediction_dense"]),
184 | self.weights["prediction_bias_dense"], name='logits_dense') # None * 1
185 |
186 |
187 | # ---------- main part of AutoInt-------------------
188 | self.y_deep = self.embeddings # None * M * d
189 | for i in range(self.blocks):
190 | self.y_deep = multihead_attention(queries=self.y_deep,
191 | keys=self.y_deep,
192 | values=self.y_deep,
193 | num_units=self.block_shape[i],
194 | num_heads=self.heads,
195 | dropout_keep_prob=self.dropout_keep_prob[0],
196 | is_training=self.train_phase,
197 | has_residual=self.has_residual)
198 |
199 | self.flat = tf.reshape(self.y_deep,
200 | shape=[-1, self.output_size * self.field_size])
201 | #if self.has_wide:
202 | # self.flat = tf.concat([self.flat, self.y_first_order], axis=1)
203 | #if self.deep_layers != None:
204 | # self.flat = tf.concat([self.flat, self.y_dense], axis=1)
205 | self.out = tf.add(tf.matmul(self.flat, self.weights["prediction"]),
206 | self.weights["prediction_bias"], name='logits') # None * 1
207 |
208 | if self.has_wide:
209 | self.out += self.y_first_order
210 |
211 | if self.deep_layers != None:
212 | self.out += self.y_dense
213 |
214 | # ---------- Compute the loss ----------
215 | # loss
216 | if self.loss_type == "logloss":
217 | self.out = tf.nn.sigmoid(self.out, name='pred')
218 | self.loss = tf.losses.log_loss(self.label, self.out)
219 | elif self.loss_type == "mse":
220 | self.loss = tf.nn.l2_loss(tf.subtract(self.label, self.out))
221 |
222 | # l2 regularization on weights
223 | if self.l2_reg > 0:
224 | if self.deep_layers != None:
225 | for i in range(len(self.deep_layers)):
226 | self.loss += tf.contrib.layers.l2_regularizer(
227 | self.l2_reg)(self.weights["layer_%d"%i])
228 | #self.loss += tf.contrib.layers.l2_regularizer(self.l2_reg)(self.embeddings)
229 | #all_vars = tf.trainable_variables()
230 | #lossL2 = tf.add_n([ tf.nn.l2_loss(v) for v in all_vars
231 | # if 'bias' not in v.name and 'embeddings' not in v.name]) * self.l2_reg
232 | #self.loss += lossL2
233 |
234 | self.global_step = tf.Variable(0, name="global_step", trainable=False)
235 | self.var1 = [v for v in tf.trainable_variables() if v.name != 'feature_bias:0']
236 | self.var2 = [tf.trainable_variables()[1]] # self.var2 = [feature_bias]
237 | # optimizer
238 | # here we should use two different optimizer for wide and deep model(if we add wide part).
239 | if self.optimizer_type == "adam":
240 | if self.has_wide:
241 | optimizer1 = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
242 | beta1=0.9, beta2=0.999, epsilon=1e-8)
243 | optimizer2 = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate_wide)
244 | #minimize(self.loss, global_step=self.global_step)
245 | var_list1 = self.var1
246 | var_list2 = self.var2
247 | grads = tf.gradients(self.loss, var_list1 + var_list2)
248 | grads1 = grads[:len(var_list1)]
249 | grads2 = grads[len(var_list1):]
250 | train_op1 = optimizer1.apply_gradients(zip(grads1, var_list1), global_step=self.global_step)
251 | train_op2 = optimizer2.apply_gradients(zip(grads2, var_list2))
252 | self.optimizer = tf.group(train_op1, train_op2)
253 | else:
254 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate,
255 | beta1=0.9, beta2=0.999, epsilon=1e-8).\
256 | minimize(self.loss, global_step=self.global_step)
257 | elif self.optimizer_type == "adagrad":
258 | self.optimizer = tf.train.AdagradOptimizer(learning_rate=self.learning_rate,
259 | initial_accumulator_value=1e-8).\
260 | minimize(self.loss)
261 | elif self.optimizer_type == "gd":
262 | self.optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate).\
263 | minimize(self.loss)
264 | elif self.optimizer_type == "momentum":
265 | self.optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.95).\
266 | minimize(self.loss)
267 |
268 | # init
269 | self.saver = tf.train.Saver(max_to_keep=5)
270 | init = tf.global_variables_initializer()
271 | self.sess = self._init_session()
272 | self.sess.run(init)
273 | self.count_param()
274 |
275 |
def count_param(self):
    """Print the number of trainable parameters in the current graph.

    Two lines are printed: the element count summed over every variable
    returned by ``tf.trainable_variables()``, and that total minus
    ``feature_size * embedding_size`` (the shape given to the
    "feature_embeddings" table in ``_initialize_weights``).
    """
    variable_sizes = [np.prod(var.get_shape().as_list())
                      for var in tf.trainable_variables()]
    total = np.sum(variable_sizes)

    print("total parameters :%d" % total)
    print("extra parameters : %d" % (total - self.feature_size * self.embedding_size))
283 |
284 |
def _init_session(self):
    """Build and return the ``tf.Session`` the model runs in.

    The session config sets ``allow_soft_placement=True`` and
    ``gpu_options.allow_growth = True``.
    """
    session_config = tf.ConfigProto(allow_soft_placement=True)
    session_config.gpu_options.allow_growth = True
    session = tf.Session(config=session_config)
    return session
289 |
290 |
def _initialize_weights(self):
    """Create the model's trainable variables.

    :return: dict mapping weight names to ``tf.Variable`` objects:
             "feature_embeddings" (feature_size * embedding_size),
             "feature_bias" (feature_size * 1, only when ``has_wide``),
             "layer_%d" / "bias_%d" plus "prediction_dense" /
             "prediction_bias_dense" (only when ``deep_layers`` is given),
             and "prediction" / "prediction_bias" for the attention output.
    """
    weights = dict()

    # NOTE: keep the creation order of the first two variables. The graph
    # code selects the wide bias positionally via tf.trainable_variables()[1].
    weights["feature_embeddings"] = tf.Variable(
        tf.random_normal([self.feature_size, self.embedding_size], 0.0, 0.01),
        name="feature_embeddings")  # feature_size(n) * d

    if self.has_wide:
        weights["feature_bias"] = tf.Variable(
            tf.random_normal([self.feature_size, 1], 0.0, 0.001),
            name="feature_bias")  # feature_size(n) * 1

    # Width of the flattened attention output fed to the prediction layer.
    input_size = self.output_size * self.field_size

    # dense layers (only used for joint training with a feed-forward net)
    if self.deep_layers is not None:
        # Layer i maps layer_sizes[i] -> layer_sizes[i + 1]; the first input
        # is the flattened embedding matrix (field_size * embedding_size).
        layer_sizes = [self.field_size * self.embedding_size] + list(self.deep_layers)
        for i, (fan_in, fan_out) in enumerate(zip(layer_sizes, layer_sizes[1:])):
            glorot = np.sqrt(2.0 / (fan_in + fan_out))
            weights["layer_%d" % i] = tf.Variable(
                np.random.normal(loc=0, scale=glorot, size=(fan_in, fan_out)),
                dtype=np.float32)  # layers[i-1] * layers[i]
            weights["bias_%d" % i] = tf.Variable(
                np.random.normal(loc=0, scale=glorot, size=(1, fan_out)),
                dtype=np.float32)  # 1 * layers[i]

        glorot = np.sqrt(2.0 / (self.deep_layers[-1] + 1))
        weights["prediction_dense"] = tf.Variable(
            np.random.normal(loc=0, scale=glorot, size=(self.deep_layers[-1], 1)),
            dtype=np.float32, name="prediction_dense")
        weights["prediction_bias_dense"] = tf.Variable(
            np.random.normal(), dtype=np.float32, name="prediction_bias_dense")

    # ---------- prediction weight ------------------#
    glorot = np.sqrt(2.0 / (input_size + 1))
    weights["prediction"] = tf.Variable(
        np.random.normal(loc=0, scale=glorot, size=(input_size, 1)),
        dtype=np.float32, name="prediction")
    weights["prediction_bias"] = tf.Variable(
        np.random.normal(), dtype=np.float32, name="prediction_bias")

    return weights
343 |
def batch_norm_layer(self, x, train_phase, scope_bn):
    """Apply batch normalisation to ``x`` under the scope ``scope_bn``.

    Two ``batch_norm`` ops are built on the same scope: one with
    ``is_training=True`` and one with ``is_training=False`` that reuses
    the first one's variables (``reuse=True``). ``tf.cond`` then selects
    between them according to ``train_phase``.

    :param x: tensor to normalise
    :param train_phase: predicate passed to ``tf.cond`` to pick the branch
    :param scope_bn: variable scope name shared by both ops
    :return: the tensor produced by ``tf.cond``
    """
    shared_kwargs = dict(decay=self.batch_norm_decay, center=True, scale=True,
                         updates_collections=None, trainable=True, scope=scope_bn)
    training_branch = batch_norm(x, is_training=True, reuse=None, **shared_kwargs)
    inference_branch = batch_norm(x, is_training=False, reuse=True, **shared_kwargs)
    return tf.cond(train_phase, lambda: training_branch, lambda: inference_branch)
351 |
352 |
def get_batch(self, Xi, Xv, y, batch_size, index):
    """Slice out the ``index``-th mini-batch.

    :param Xi: sequence of per-sample feature indices
    :param Xv: sequence of per-sample feature values
    :param y: sequence of labels; its length bounds the slice
    :param batch_size: number of samples per batch
    :param index: zero-based batch number
    :return: ``(Xi_batch, Xv_batch, y_batch)`` where each label is wrapped
             in its own one-element list; the last batch may be shorter
             and batches past the end are empty
    """
    first = index * batch_size
    last = min(first + batch_size, len(y))
    window = slice(first, last)
    labels = [[label] for label in y[window]]
    return Xi[window], Xv[window], labels
358 |
359 |
# Shuffle three sequences in place, simultaneously.
def shuffle_in_unison_scary(self, a, b, c):
    """Shuffle ``a``, ``b`` and ``c`` in place from one shared RNG state.

    The global NumPy RNG state is captured once and put back before each
    shuffle, so every sequence is shuffled starting from the same state.
    """
    saved_state = np.random.get_state()
    for sequence in (a, b, c):
        # Rewind the generator so this shuffle starts from the captured
        # state (a no-op before the first shuffle).
        np.random.set_state(saved_state)
        np.random.shuffle(sequence)
368 |
369 |
def fit_on_batch(self, Xi, Xv, y):
    """Run the optimizer once on a single mini-batch.

    :param Xi: feature indices fed to ``self.feat_index``
    :param Xv: feature values fed to ``self.feat_value``
    :param y: labels fed to ``self.label``
    :return: ``(step, loss)`` as fetched from ``self.global_step`` and
             ``self.loss``
    """
    # Batch inputs first, then the training-time settings.
    feed_dict = {self.feat_index: Xi, self.feat_value: Xv, self.label: y}
    feed_dict[self.dropout_keep_prob] = self.drop_keep_prob
    feed_dict[self.train_phase] = True

    fetches = (self.global_step, self.loss, self.optimizer)
    step, loss, _ = self.sess.run(fetches, feed_dict=feed_dict)
    return step, loss
378 |
# The training data is too large to fit into memory at once, so it is split
# into several files and "fit_once" is called once per file.
def fit_once(self, Xi_train, Xv_train, y_train,
             epoch, file_count, Xi_valid=None,
             Xv_valid=None, y_valid=None,
             early_stopping=False):
    """Train on one in-memory data file, then evaluate it.

    The training arrays are shuffled in place, consumed in batches of
    ``self.batch_size``, and then evaluated. When validation data is
    given it is evaluated too, and the model is checkpointed whenever the
    validation loss beats ``self.best_loss`` and saving is enabled.

    :param Xi_train: feature indices of the training samples
    :param Xv_train: feature values of the training samples
    :param y_train: labels of the training samples
    :param epoch: 1-based epoch number (used for logging)
    :param file_count: 1-based index of this file within the epoch (logging)
    :param Xi_valid: feature indices of the validation samples, optional
    :param Xv_valid: feature values of the validation samples, optional
    :param y_valid: labels of the validation samples, optional
    :param early_stopping: if True, consult ``training_termination`` on the
                           recorded validation losses
    :return: False when early stopping says training should stop,
             True otherwise
    """
    import os  # local import: only needed to build the checkpoint path

    has_valid = Xv_valid is not None
    last_step = 0
    t1 = time()
    self.shuffle_in_unison_scary(Xi_train, Xv_train, y_train)
    # Floor division: a trailing partial batch is not trained on in this pass.
    total_batch = len(y_train) // self.batch_size
    for i in range(total_batch):
        Xi_batch, Xv_batch, y_batch = self.get_batch(Xi_train, Xv_train, y_train, self.batch_size, i)
        last_step, _ = self.fit_on_batch(Xi_batch, Xv_batch, y_batch)

    # evaluate training and validation datasets
    train_result, train_loss = self.evaluate(Xi_train, Xv_train, y_train)
    self.train_result.append(train_result)
    self.train_loss.append(train_loss)
    if has_valid:
        valid_result, valid_loss = self.evaluate(Xi_valid, Xv_valid, y_valid)
        self.valid_result.append(valid_result)
        self.valid_loss.append(valid_loss)
        if self.is_save and valid_loss < self.best_loss:
            old_loss = self.best_loss
            self.best_loss = valid_loss
            # Join instead of concatenating so the checkpoint always lands
            # inside save_path (the directory restore() reads), even when
            # save_path has no trailing slash.
            checkpoint_prefix = os.path.join(self.save_path, 'model.ckpt')
            self.saver.save(self.sess, checkpoint_prefix, global_step=last_step)
            print("[%d-%d] model saved!. Valid loss is improved from %.4f to %.4f"
                  % (epoch, file_count, old_loss, self.best_loss))

    # NOTE(review): the 9 looks like an assumed files-per-epoch count, but
    # train.py iterates 8 files (parts 3-10) -- confirm. It only matters
    # when verbose > 1.
    if self.verbose > 0 and ((epoch-1)*9 + file_count) % self.verbose == 0:
        if has_valid:
            print("[%d-%d] train-result=%.4f, train-logloss=%.4f, valid-result=%.4f, valid-logloss=%.4f [%.1f s]" % (epoch, file_count, train_result, train_loss, valid_result, valid_loss, time() - t1))
        else:
            print("[%d-%d] train-result=%.4f [%.1f s]"
                  % (epoch, file_count, train_result, time() - t1))

    if has_valid and early_stopping and self.training_termination(self.valid_loss):
        return False
    return True
421 |
422 |
423 |
def training_termination(self, valid_result):
    """Decide whether the recorded validation history calls for stopping.

    Needs more than five recorded values. Looks at the last five: with
    ``greater_is_better`` set, stops when each value is strictly lower
    than the one before it; otherwise stops when each value is strictly
    higher than the one before it.

    :param valid_result: chronological list of validation scores
    :return: True to stop training, False to continue
    """
    if len(valid_result) <= 5:
        return False

    recent = valid_result[-5:]
    consecutive = list(zip(recent, recent[1:]))
    if self.greater_is_better:
        return all(later < earlier for earlier, later in consecutive)
    return all(later > earlier for earlier, later in consecutive)
439 |
440 |
def predict(self, Xi, Xv):
    """
    Run the model over a dataset in batches of ``self.batch_size``.

    :param Xi: list of list of feature indices of each sample in the dataset
    :param Xv: list of list of feature values of each sample in the dataset
    :return: 1-D array with the value of ``self.out`` for each sample (the
             sigmoid output when ``loss_type`` is "logloss"), or None when
             the dataset is empty
    """
    # dummy y: labels are fed as placeholders only; just self.out is fetched
    dummy_y = [1] * len(Xi)
    # Inference keeps every unit: all keep-probabilities are 1.0.
    keep_all = [1.0] * len(self.drop_keep_prob)

    # Collect per-batch outputs and concatenate once at the end, rather
    # than re-copying the growing result array on every batch.
    batch_outputs = []
    batch_index = 0
    while True:
        Xi_batch, Xv_batch, y_batch = self.get_batch(Xi, Xv, dummy_y, self.batch_size, batch_index)
        if len(Xi_batch) == 0:
            break
        feed_dict = {self.feat_index: Xi_batch,
                     self.feat_value: Xv_batch,
                     self.label: y_batch,
                     self.dropout_keep_prob: keep_all,
                     self.train_phase: False}
        batch_out = self.sess.run(self.out, feed_dict=feed_dict)
        batch_outputs.append(np.reshape(batch_out, (len(y_batch),)))
        batch_index += 1

    if not batch_outputs:
        return None
    return np.concatenate(batch_outputs)
473 |
474 |
def evaluate(self, Xi, Xv, y):
    """
    Score the model on a labelled dataset.

    Predictions are clipped to [1e-6, 1 - 1e-6] before scoring.

    :param Xi: list of list of feature indices of each sample in the dataset
    :param Xv: list of list of feature values of each sample in the dataset
    :param y: label of each sample in the dataset
    :return: tuple ``(self.eval_metric(y, y_pred), log_loss(y, y_pred))``
    """
    clipped = np.clip(self.predict(Xi, Xv), 1e-6, 1-1e-6)
    metric_value = self.eval_metric(y, clipped)
    loss_value = log_loss(y, clipped)
    return metric_value, loss_value
485 |
def restore(self, save_path=None):
    """Load the checkpoint recorded in ``save_path`` into the session.

    :param save_path: checkpoint directory; defaults to ``self.save_path``

    If no checkpoint state is found the session is left untouched; with
    ``verbose > 0`` that is reported instead of being skipped silently.
    """
    if save_path is None:
        save_path = self.save_path
    ckpt = tf.train.get_checkpoint_state(save_path)
    if ckpt and ckpt.model_checkpoint_path:
        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        if self.verbose > 0:
            print("restored from %s" % (save_path))
    elif self.verbose > 0:
        # Make the fallback visible: the caller goes on with whatever
        # weights the session currently holds.
        print("no checkpoint found in %s, keeping current weights" % (save_path))
494 |
--------------------------------------------------------------------------------
/sample_preprocess.sh:
--------------------------------------------------------------------------------
# Runs the Criteo preprocessing scripts in order: preprocess.py, then
# Kfold_split/stratifiedKfold.py, then scale.py.

# Stop at the first failing step so later steps never run after a failure.
set -e

# -p: do not fail when the directory already exists (e.g. on a re-run).
mkdir -p Criteo
python ./Dataprocess/Criteo/preprocess.py
python ./Dataprocess/Kfold_split/stratifiedKfold.py
python ./Dataprocess/Criteo/scale.py
5 |
--------------------------------------------------------------------------------
/test_code.sh:
--------------------------------------------------------------------------------
1 | echo "start training ..."
2 |
3 | python -u train.py \
4 | --data "Criteo" --blocks 3 --heads 2 --block_shape "[64, 64, 64]" \
5 | --is_save "True" --save_path "./test_code/Criteo/b3h2_64x64x64/" \
6 | --field_size 39 --run_times 1 --data_path "./" \
7 | --epoch 3 --has_residual "True" --has_wide "False" \
8 | --batch_size 1024 \
9 | > test_code_single.out &
10 |
11 |
12 | python -u train.py \
13 | --data "Criteo" --blocks 3 --heads 2 --block_shape "[64, 64, 64]" \
14 | --is_save "True" --save_path "./test_code/Criteo/b3h2_dnn_dropkeep1_400x2/" \
15 | --field_size 39 --run_times 1 --dropout_keep_prob "[1, 1, 1]" --data_path "./" \
16 | --epoch 3 --has_residual "True" --has_wide "False" --deep_layers "[400, 400]"\
17 | --batch_size 1024 \
18 | > ./test_code_dnn.out &
19 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import pandas as pd
4 | import tensorflow as tf
5 | from sklearn.metrics import make_scorer
6 | from sklearn.model_selection import StratifiedKFold
7 | from time import time
8 | from model import AutoInt
9 | import argparse
10 | import os
11 |
12 |
def str2list(v):
    """Parse a comma-separated string such as "[64, 64, 64]" into a list of ints.

    Square brackets are stripped from both ends of every comma-separated
    token before it is converted with ``int``.
    """
    return [int(token.strip('[]')) for token in v.split(',')]
18 |
19 |
def str2list2(v):
    """Parse a comma-separated string such as "[1, 1, 0.5]" into a list of floats.

    Square brackets are stripped from both ends of every comma-separated
    token before it is converted with ``float``.
    """
    return [float(token.strip('[]')) for token in v.split(',')]
25 |
26 |
def str2bool(v):
    """Convert a command-line truth string to a bool (case-insensitive).

    Accepts yes/true/t/y/1 as True and no/false/f/n/0 as False.

    :raises argparse.ArgumentTypeError: for any other value
    """
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
34 |
def parse_args():
    """Parse the command-line options of the training script.

    ``--data`` and ``--data_path`` are required, and ``--data`` is limited
    to the datasets this script knows file names for. Leaving them out
    previously parsed fine but crashed later with a ``TypeError`` or
    ``NameError``.

    :return: the ``argparse.Namespace`` holding all options
    """
    parser = argparse.ArgumentParser()

    # model structure
    parser.add_argument('--blocks', type=int, default=2, help='#blocks')
    parser.add_argument('--block_shape', type=str2list, default=[16,16], help='output shape of each block')
    parser.add_argument('--heads', type=int, default=2, help='#heads')
    parser.add_argument('--embedding_size', type=int, default=16)
    parser.add_argument('--dropout_keep_prob', type=str2list2, default=[1, 1, 0.5])
    parser.add_argument('--has_residual', type=str2bool, default=True, help='add residual or not')
    parser.add_argument('--has_wide', type=str2bool, default=False)
    parser.add_argument('--deep_layers', type=str2list, default=None, help='config for dnn in joint train')
    parser.add_argument('--batch_norm', type=int, default=0)
    parser.add_argument('--batch_norm_decay', type=float, default=0.995)
    parser.add_argument('--field_size', type=int, default=23, help='#fields')

    # optimisation
    parser.add_argument('--epoch', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=1024)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--learning_rate_wide', type=float, default=0.001)
    parser.add_argument('--optimizer_type', type=str, default='adam')
    parser.add_argument('--l2_reg', type=float, default=0.0)
    parser.add_argument('--loss_type', type=str, default='logloss')
    parser.add_argument('--greater_is_better', type=str2bool, default=False, help='early stop criterion')

    # run control, logging and checkpoints
    parser.add_argument('--random_seed', type=int, default=2018)
    parser.add_argument('--verbose', type=int, default=1)
    parser.add_argument('--run_times', type=int, default=3,help='run multiple times to eliminate error')
    parser.add_argument('--is_save', type=str2bool, default=False)
    parser.add_argument('--save_path', type=str, default='./model/')

    # data location: both are needed to build every input path
    parser.add_argument('--data', type=str, required=True,
                        choices=['Avazu', 'Criteo', 'KDD2012'], help='data name')
    parser.add_argument('--data_path', type=str, required=True,
                        help='root path for all the data')
    return parser.parse_args()
64 |
65 |
66 |
def _run_(args, file_name, run_cnt):
    """Train one AutoInt model, then evaluate it on the test split.

    Data layout under ``<data_path>/<data>``: part1 is the test set, part2
    the validation set and part3..part10 the training files, which are
    loaded one at a time in every epoch.

    :param args: parsed command-line options
    :param file_name: three file names (feature indices, feature values,
                      labels) looked up inside every part directory
    :param run_cnt: run number forwarded to the ``AutoInt`` constructor
    :return: ``(test_result, test_loss)`` from ``model.evaluate``
    """
    path_prefix = os.path.join(args.data_path, args.data)
    feature_size = np.load(path_prefix + '/feature_size.npy')[0]

    def load_part(part_no):
        # Load (indices, values, labels) of one data part, in that order.
        part_dir = path_prefix + '/part' + str(part_no) + '/'
        return tuple(np.load(part_dir + name) for name in file_name[:3])

    # test: file1, valid: file2, train: file3-10
    model = AutoInt(args=args, feature_size=feature_size, run_cnt=run_cnt)

    Xi_valid, Xv_valid, y_valid = load_part(2)

    is_continue = True
    for k in range(model.epoch):
        epoch_no = k + 1
        if not is_continue:
            print('early stopping at epoch %d' % epoch_no)
            break
        time_epoch = 0
        for file_count, j in enumerate(range(3, 11), start=1):
            if not is_continue:
                print('early stopping at epoch %d file %d' % (epoch_no, j))
                break
            Xi_train, Xv_train, y_train = load_part(j)

            print("epoch %d, file %d" % (epoch_no, j))
            started = time()
            is_continue = model.fit_once(Xi_train, Xv_train, y_train, epoch_no, file_count,
                                         Xi_valid, Xv_valid, y_valid, early_stopping=True)
            time_epoch += time() - started

        print("epoch %d, time %d" % (epoch_no, time_epoch))

    print('start testing!...')
    Xi_test, Xv_test, y_test = load_part(1)

    model.restore()

    test_result, test_loss = model.evaluate(Xi_test, Xv_test, y_test)
    print("test-result = %.4lf, test-logloss = %.4lf" % (test_result, test_loss))
    return test_result, test_loss
117 |
# Script entry point: run the experiment `run_times` times and report the
# per-run and average test metrics.
if __name__ == "__main__":
    args = parse_args()
    print(args.__dict__)
    print('**************')
    if args.data in ['Avazu']:
        # Avazu does not have numerical features so we didn't scale the data.
        file_name = ['train_i.npy', 'train_x.npy', 'train_y.npy']
    elif args.data in ['Criteo', 'KDD2012']:
        file_name = ['train_i.npy', 'train_x2.npy', 'train_y.npy']
    else:
        # Fail with a clear message instead of a NameError on file_name below.
        raise ValueError("unsupported --data value %r; expected 'Avazu', 'Criteo' or 'KDD2012'"
                         % (args.data,))
    test_auc = []
    test_log = []

    print('run time : %d' % args.run_times)
    for i in range(1, args.run_times + 1):
        test_result, test_loss = _run_(args, file_name, i)
        test_auc.append(test_result)
        test_log.append(test_loss)
    print('test_auc', test_auc)
    print('test_log_loss', test_log)
    # With run_times < 1 there is nothing to average (avoids dividing by zero).
    if test_auc:
        print('avg_auc', sum(test_auc)/len(test_auc))
        print('avg_log_loss', sum(test_log)/len(test_log))
140 |
141 |
--------------------------------------------------------------------------------