├── data
├── .idea
│   ├── dictionaries
│   │   └── kevin.xml
│   └── vcs.xml
├── .gitignore
├── python
│   ├── main.py
│   ├── utils.py
│   └── models.py
└── README.md
/data:
--------------------------------------------------------------------------------
1 | ../make-ipinyou-data/3358
--------------------------------------------------------------------------------
/.idea/dictionaries/kevin.xml:
--------------------------------------------------------------------------------
(IDE configuration file; content not captured in this dump)
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(IDE configuration file; content not captured in this dump)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 |
58 | # Flask stuff:
59 | instance/
60 | .webassets-cache
61 |
62 | # Scrapy stuff:
63 | .scrapy
64 |
65 | # Sphinx documentation
66 | docs/_build/
67 |
68 | # PyBuilder
69 | target/
70 |
71 | # IPython Notebook
72 | .ipynb_checkpoints
73 |
74 | # pyenv
75 | .python-version
76 |
77 | # celery beat schedule file
78 | celerybeat-schedule
79 |
80 | # dotenv
81 | .env
82 |
83 | # virtualenv
84 | venv/
85 | ENV/
86 |
87 | # Spyder project settings
88 | .spyderproject
89 |
90 | # Rope project settings
91 | .ropeproject
92 |
93 | log
94 | .idea
95 | *.ipynb
--------------------------------------------------------------------------------
/python/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 | from __future__ import division
4 |
5 | import sys
6 | if sys.version_info[0] == 2:
7 | import cPickle as pkl
8 | else:
9 | import pickle as pkl
10 |
11 | import numpy as np
12 | from sklearn.metrics import roc_auc_score
13 |
14 | import progressbar
15 |
16 | import os
17 | p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
18 | if p not in sys.path:
19 | sys.path.append(p)
20 |
21 | from python import utils
22 | from python.models import LR, FM, PNN1, PNN2, FNN, CCPM, DeepFM
23 |
24 | train_file = '../data/train.txt'
25 | test_file = '../data/test.txt'
26 |
27 | input_dim = utils.INPUT_DIM
28 |
29 | train_data = utils.read_data(train_file)
30 | # train_data = pkl.load(open('../data/train.yx.pkl', 'rb'))
31 | train_data = utils.shuffle(train_data)
32 | test_data = utils.read_data(test_file)
33 | # test_data = pkl.load(open('../data/test.yx.pkl', 'rb'))
34 | # pkl.dump(train_data, open('../data/train.yx.pkl', 'wb'))
35 | # pkl.dump(test_data, open('../data/test.yx.pkl', 'wb'))
36 |
37 | if train_data[1].ndim > 1:
38 | print('label must be 1-dim')
39 |     sys.exit(1)
40 | print('read finish')
41 | print('train data size:', train_data[0].shape)
42 | print('test data size:', test_data[0].shape)
43 |
44 | train_size = train_data[0].shape[0]
45 | test_size = test_data[0].shape[0]
46 | num_feas = len(utils.FIELD_SIZES)
47 |
48 | min_round = 1
49 | num_round = 200
50 | early_stop_round = 5
51 | batch_size = 1024
52 |
53 | field_sizes = utils.FIELD_SIZES
54 | field_offsets = utils.FIELD_OFFSETS
55 |
56 | algo = 'pnn2'  # one of: 'lr', 'fm', 'fnn', 'ccpm', 'pnn1', 'pnn2', 'deepfm'
57 |
58 | if algo in {'fnn', 'ccpm', 'pnn1', 'pnn2', 'deepfm'}:
59 | train_data = utils.split_data(train_data)
60 | test_data = utils.split_data(test_data)
61 | tmp = []
62 | for x in field_sizes:
63 | if x > 0:
64 | tmp.append(x)
65 | field_sizes = tmp
66 | print('remove empty fields', field_sizes)
67 |
68 | if algo == 'lr':
69 | lr_params = {
70 | 'input_dim': input_dim,
71 | 'opt_algo': 'gd',
72 | 'learning_rate': 0.1,
73 | 'l2_weight': 0,
74 | 'random_seed': 0
75 | }
76 | print(lr_params)
77 | model = LR(**lr_params)
78 | elif algo == 'fm':
79 | fm_params = {
80 | 'input_dim': input_dim,
81 | 'factor_order': 10,
82 | 'opt_algo': 'gd',
83 | 'learning_rate': 0.1,
84 | 'l2_w': 0,
85 | 'l2_v': 0,
86 | }
87 | print(fm_params)
88 | model = FM(**fm_params)
89 | elif algo == 'fnn':
90 | fnn_params = {
91 | 'field_sizes': field_sizes,
92 | 'embed_size': 10,
93 | 'layer_sizes': [500, 1],
94 | 'layer_acts': ['relu', None],
95 | 'drop_out': [0, 0],
96 | 'opt_algo': 'gd',
97 | 'learning_rate': 0.1,
98 | 'embed_l2': 0,
99 | 'layer_l2': [0, 0],
100 | 'random_seed': 0
101 | }
102 | print(fnn_params)
103 | model = FNN(**fnn_params)
104 | elif algo == 'deepfm':
105 | deepfm_params = {
106 | 'field_sizes': field_sizes,
107 | 'embed_size': 10,
108 | 'layer_sizes': [500, 1],
109 | 'layer_acts': ['relu', None],
110 | 'drop_out': [0, 0],
111 | 'opt_algo': 'gd',
112 | 'learning_rate': 0.1,
113 | 'embed_l2': 0,
114 | 'layer_l2': [0, 0],
115 | 'random_seed': 0
116 | }
117 | print(deepfm_params)
118 | model = DeepFM(**deepfm_params)
119 | elif algo == 'ccpm':
120 | ccpm_params = {
121 | 'field_sizes': field_sizes,
122 | 'embed_size': 10,
123 | 'filter_sizes': [5, 3],
124 | 'layer_acts': ['relu'],
125 | 'drop_out': [0],
126 | 'opt_algo': 'gd',
127 | 'learning_rate': 0.1,
128 | 'random_seed': 0
129 | }
130 | print(ccpm_params)
131 | model = CCPM(**ccpm_params)
132 | elif algo == 'pnn1':
133 | pnn1_params = {
134 | 'field_sizes': field_sizes,
135 | 'embed_size': 10,
136 | 'layer_sizes': [500, 1],
137 | 'layer_acts': ['relu', None],
138 | 'drop_out': [0, 0],
139 | 'opt_algo': 'gd',
140 | 'learning_rate': 0.1,
141 | 'embed_l2': 0,
142 | 'layer_l2': [0, 0],
143 | 'random_seed': 0
144 | }
145 | print(pnn1_params)
146 | model = PNN1(**pnn1_params)
147 | elif algo == 'pnn2':
148 | pnn2_params = {
149 | 'field_sizes': field_sizes,
150 | 'embed_size': 10,
151 | 'layer_sizes': [500, 1],
152 | 'layer_acts': ['relu', None],
153 | 'drop_out': [0, 0],
154 | 'opt_algo': 'gd',
155 | 'learning_rate': 0.1,
156 | 'embed_l2': 0,
157 | 'layer_l2': [0., 0.],
158 | 'random_seed': 0,
159 | 'layer_norm': True,
160 | }
161 | print(pnn2_params)
162 | model = PNN2(**pnn2_params)
163 |
164 |
165 | def train(model):
166 | history_score = []
167 | for i in range(num_round):
168 | fetches = [model.optimizer, model.loss]
169 | if batch_size > 0:
170 | ls = []
171 | bar = progressbar.ProgressBar()
172 | print('[%d]\ttraining...' % i)
173 | for j in bar(range(int(train_size / batch_size + 1))):
174 | X_i, y_i = utils.slice(train_data, j * batch_size, batch_size)
175 | _, l = model.run(fetches, X_i, y_i)
176 | ls.append(l)
177 | elif batch_size == -1:
178 | X_i, y_i = utils.slice(train_data)
179 | _, l = model.run(fetches, X_i, y_i)
180 | ls = [l]
181 | train_preds = []
182 | print('[%d]\tevaluating...' % i)
183 | bar = progressbar.ProgressBar()
184 | for j in bar(range(int(train_size / 10000 + 1))):
185 | X_i, _ = utils.slice(train_data, j * 10000, 10000)
186 | preds = model.run(model.y_prob, X_i, mode='test')
187 | train_preds.extend(preds)
188 | test_preds = []
189 | bar = progressbar.ProgressBar()
190 | for j in bar(range(int(test_size / 10000 + 1))):
191 | X_i, _ = utils.slice(test_data, j * 10000, 10000)
192 | preds = model.run(model.y_prob, X_i, mode='test')
193 | test_preds.extend(preds)
194 | train_score = roc_auc_score(train_data[1], train_preds)
195 | test_score = roc_auc_score(test_data[1], test_preds)
196 | print('[%d]\tloss (with l2 norm):%f\ttrain-auc: %f\teval-auc: %f' % (i, np.mean(ls), train_score, test_score))
197 | history_score.append(test_score)
198 |         if i > min_round and i > early_stop_round:
199 |             if np.argmax(history_score) == i - early_stop_round and \
200 |                     history_score[-1] - history_score[-early_stop_round] < 1e-5:
201 | print('early stop\nbest iteration:\n[%d]\teval-auc: %f' % (
202 | np.argmax(history_score), np.max(history_score)))
203 | break
204 |
205 | train(model)
206 |
--------------------------------------------------------------------------------
/python/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 | from __future__ import division
4 |
5 | import sys
6 | if sys.version_info[0] == 2:
7 | import cPickle as pkl
8 | else:
9 | import pickle as pkl
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 | from scipy.sparse import coo_matrix
14 |
15 | DTYPE = tf.float32
16 |
17 | FIELD_SIZES = [0] * 26
18 | with open('../data/featindex.txt') as fin:
19 | for line in fin:
20 | line = line.strip().split(':')
21 | if len(line) > 1:
22 | f = int(line[0]) - 1
23 | FIELD_SIZES[f] += 1
24 | print('field sizes:', FIELD_SIZES)
25 | FIELD_OFFSETS = [sum(FIELD_SIZES[:i]) for i in range(len(FIELD_SIZES))]
26 | INPUT_DIM = sum(FIELD_SIZES)
27 | OUTPUT_DIM = 1
28 | STDDEV = 1e-3
29 | MINVAL = -1e-3
30 | MAXVAL = 1e-3
31 |
32 |
33 | def read_data(file_name):  # parse libsvm-style lines: "label idx:value idx:value ..."
34 | X = []
35 | D = []
36 | y = []
37 | with open(file_name) as fin:
38 | for line in fin:
39 | fields = line.strip().split()
40 | y_i = int(fields[0])
41 | X_i = [int(x.split(':')[0]) for x in fields[1:]]
42 | D_i = [int(x.split(':')[1]) for x in fields[1:]]
43 | y.append(y_i)
44 | X.append(X_i)
45 | D.append(D_i)
46 | y = np.reshape(np.array(y), [-1])
47 | X = libsvm_2_coo(zip(X, D), (len(X), INPUT_DIM)).tocsr()
48 | return X, y
49 |
50 |
51 | def shuffle(data):
52 | X, y = data
53 | ind = np.arange(X.shape[0])
54 | for i in range(7):
55 | np.random.shuffle(ind)
56 | return X[ind], y[ind]
57 |
58 |
59 | def libsvm_2_coo(libsvm_data, shape):
60 | coo_rows = []
61 | coo_cols = []
62 | coo_data = []
63 | n = 0
64 | for x, d in libsvm_data:
65 | coo_rows.extend([n] * len(x))
66 | coo_cols.extend(x)
67 | coo_data.extend(d)
68 | n += 1
69 | coo_rows = np.array(coo_rows)
70 | coo_cols = np.array(coo_cols)
71 | coo_data = np.array(coo_data)
72 | return coo_matrix((coo_data, (coo_rows, coo_cols)), shape=shape)
73 |
74 |
75 | def csr_2_input(csr_mat):
76 | if not isinstance(csr_mat, list):
77 | coo_mat = csr_mat.tocoo()
78 | indices = np.vstack((coo_mat.row, coo_mat.col)).transpose()
79 | values = csr_mat.data
80 | shape = csr_mat.shape
81 | return indices, values, shape
82 | else:
83 | inputs = []
84 | for csr_i in csr_mat:
85 | inputs.append(csr_2_input(csr_i))
86 | return inputs
87 |
88 |
89 | def slice(csr_data, start=0, size=-1):  # note: shadows the built-in 'slice'
90 | if not isinstance(csr_data[0], list):
91 | if size == -1 or start + size >= csr_data[0].shape[0]:
92 | slc_data = csr_data[0][start:]
93 | slc_labels = csr_data[1][start:]
94 | else:
95 | slc_data = csr_data[0][start:start + size]
96 | slc_labels = csr_data[1][start:start + size]
97 | else:
98 | if size == -1 or start + size >= csr_data[0][0].shape[0]:
99 | slc_data = []
100 | for d_i in csr_data[0]:
101 | slc_data.append(d_i[start:])
102 | slc_labels = csr_data[1][start:]
103 | else:
104 | slc_data = []
105 | for d_i in csr_data[0]:
106 | slc_data.append(d_i[start:start + size])
107 | slc_labels = csr_data[1][start:start + size]
108 | return csr_2_input(slc_data), slc_labels
109 |
110 |
111 | def split_data(data, skip_empty=True):  # split the feature matrix into per-field column blocks
112 | fields = []
113 | for i in range(len(FIELD_OFFSETS) - 1):
114 | start_ind = FIELD_OFFSETS[i]
115 | end_ind = FIELD_OFFSETS[i + 1]
116 | if skip_empty and start_ind == end_ind:
117 | continue
118 | field_i = data[0][:, start_ind:end_ind]
119 | fields.append(field_i)
120 | fields.append(data[0][:, FIELD_OFFSETS[-1]:])
121 | return fields, data[1]
122 |
123 |
124 | def init_var_map(init_vars, init_path=None):
125 | if init_path is not None:
126 | load_var_map = pkl.load(open(init_path, 'rb'))
127 | print('load variable map from', init_path, load_var_map.keys())
128 | var_map = {}
129 | for var_name, var_shape, init_method, dtype in init_vars:
130 | if init_method == 'zero':
131 | var_map[var_name] = tf.Variable(tf.zeros(var_shape, dtype=dtype), name=var_name, dtype=dtype)
132 | elif init_method == 'one':
133 | var_map[var_name] = tf.Variable(tf.ones(var_shape, dtype=dtype), name=var_name, dtype=dtype)
134 | elif init_method == 'normal':
135 | var_map[var_name] = tf.Variable(tf.random_normal(var_shape, mean=0.0, stddev=STDDEV, dtype=dtype),
136 | name=var_name, dtype=dtype)
137 | elif init_method == 'tnormal':
138 | var_map[var_name] = tf.Variable(tf.truncated_normal(var_shape, mean=0.0, stddev=STDDEV, dtype=dtype),
139 | name=var_name, dtype=dtype)
140 | elif init_method == 'uniform':
141 | var_map[var_name] = tf.Variable(tf.random_uniform(var_shape, minval=MINVAL, maxval=MAXVAL, dtype=dtype),
142 | name=var_name, dtype=dtype)
143 | elif init_method == 'xavier':
144 | maxval = np.sqrt(6. / np.sum(var_shape))
145 | minval = -maxval
146 | var_map[var_name] = tf.Variable(tf.random_uniform(var_shape, minval=minval, maxval=maxval, dtype=dtype),
147 | name=var_name, dtype=dtype)
148 | elif isinstance(init_method, int) or isinstance(init_method, float):
149 | var_map[var_name] = tf.Variable(tf.ones(var_shape, dtype=dtype) * init_method, name=var_name, dtype=dtype)
150 | elif init_method in load_var_map:
151 | if load_var_map[init_method].shape == tuple(var_shape):
152 | var_map[var_name] = tf.Variable(load_var_map[init_method], name=var_name, dtype=dtype)
153 | else:
154 | print('BadParam: init method', init_method, 'shape', var_shape, load_var_map[init_method].shape)
155 | else:
156 | print('BadParam: init method', init_method)
157 | return var_map
158 |
159 |
160 | def activate(weights, activation_function):
161 | if activation_function == 'sigmoid':
162 | return tf.nn.sigmoid(weights)
163 | elif activation_function == 'softmax':
164 | return tf.nn.softmax(weights)
165 | elif activation_function == 'relu':
166 | return tf.nn.relu(weights)
167 | elif activation_function == 'tanh':
168 | return tf.nn.tanh(weights)
169 | elif activation_function == 'elu':
170 | return tf.nn.elu(weights)
171 | elif activation_function == 'none':
172 | return weights
173 | else:
174 | return weights
175 |
176 |
177 | def get_optimizer(opt_algo, learning_rate, loss):
178 |     if opt_algo == 'adadelta':
179 | return tf.train.AdadeltaOptimizer(learning_rate).minimize(loss)
180 | elif opt_algo == 'adagrad':
181 | return tf.train.AdagradOptimizer(learning_rate).minimize(loss)
182 | elif opt_algo == 'adam':
183 | return tf.train.AdamOptimizer(learning_rate).minimize(loss)
184 | elif opt_algo == 'ftrl':
185 | return tf.train.FtrlOptimizer(learning_rate).minimize(loss)
186 | elif opt_algo == 'gd':
187 | return tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
188 | elif opt_algo == 'padagrad':
189 | return tf.train.ProximalAdagradOptimizer(learning_rate).minimize(loss)
190 | elif opt_algo == 'pgd':
191 | return tf.train.ProximalGradientDescentOptimizer(learning_rate).minimize(loss)
192 | elif opt_algo == 'rmsprop':
193 | return tf.train.RMSPropOptimizer(learning_rate).minimize(loss)
194 | else:
195 | return tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
196 |
197 |
198 | def gather_2d(params, indices):
199 | shape = tf.shape(params)
200 | flat = tf.reshape(params, [-1])
201 | flat_idx = indices[:, 0] * shape[1] + indices[:, 1]
202 | flat_idx = tf.reshape(flat_idx, [-1])
203 | return tf.gather(flat, flat_idx)
204 |
205 |
206 | def gather_3d(params, indices):
207 | shape = tf.shape(params)
208 | flat = tf.reshape(params, [-1])
209 | flat_idx = indices[:, 0] * shape[1] * shape[2] + indices[:, 1] * shape[2] + indices[:, 2]
210 | flat_idx = tf.reshape(flat_idx, [-1])
211 | return tf.gather(flat, flat_idx)
212 |
213 |
214 | def gather_4d(params, indices):
215 | shape = tf.shape(params)
216 | flat = tf.reshape(params, [-1])
217 | flat_idx = indices[:, 0] * shape[1] * shape[2] * shape[3] + \
218 | indices[:, 1] * shape[2] * shape[3] + indices[:, 2] * shape[3] + indices[:, 3]
219 | flat_idx = tf.reshape(flat_idx, [-1])
220 | return tf.gather(flat, flat_idx)
221 |
222 |
223 | def max_pool_2d(params, k):
224 | _, indices = tf.nn.top_k(params, k, sorted=False)
225 | shape = tf.shape(indices)
226 | r1 = tf.reshape(tf.range(shape[0]), [-1, 1])
227 | r1 = tf.tile(r1, [1, k])
228 | r1 = tf.reshape(r1, [-1, 1])
229 | indices = tf.concat([r1, tf.reshape(indices, [-1, 1])], 1)
230 | return tf.reshape(gather_2d(params, indices), [-1, k])
231 |
232 |
233 | def max_pool_3d(params, k):
234 | _, indices = tf.nn.top_k(params, k, sorted=False)
235 | shape = tf.shape(indices)
236 | r1 = tf.reshape(tf.range(shape[0]), [-1, 1])
237 | r2 = tf.reshape(tf.range(shape[1]), [-1, 1])
238 | r1 = tf.tile(r1, [1, k * shape[1]])
239 | r2 = tf.tile(r2, [1, k])
240 | r1 = tf.reshape(r1, [-1, 1])
241 | r2 = tf.tile(tf.reshape(r2, [-1, 1]), [shape[0], 1])
242 | indices = tf.concat([r1, r2, tf.reshape(indices, [-1, 1])], 1)
243 | return tf.reshape(gather_3d(params, indices), [-1, shape[1], k])
244 |
245 |
246 | def max_pool_4d(params, k):
247 | _, indices = tf.nn.top_k(params, k, sorted=False)
248 | shape = tf.shape(indices)
249 | r1 = tf.reshape(tf.range(shape[0]), [-1, 1])
250 | r2 = tf.reshape(tf.range(shape[1]), [-1, 1])
251 | r3 = tf.reshape(tf.range(shape[2]), [-1, 1])
252 | r1 = tf.tile(r1, [1, shape[1] * shape[2] * k])
253 | r2 = tf.tile(r2, [1, shape[2] * k])
254 | r3 = tf.tile(r3, [1, k])
255 | r1 = tf.reshape(r1, [-1, 1])
256 | r2 = tf.tile(tf.reshape(r2, [-1, 1]), [shape[0], 1])
257 | r3 = tf.tile(tf.reshape(r3, [-1, 1]), [shape[0] * shape[1], 1])
258 | indices = tf.concat([r1, r2, r3, tf.reshape(indices, [-1, 1])], 1)
259 | return tf.reshape(gather_4d(params, indices), [-1, shape[1], shape[2], k])
260 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Product-based Neural Networks for User Response Prediction
2 |
3 | ``Note``: An extended version of the conference paper is available at https://arxiv.org/abs/1807.00311 , which has been accepted by TOIS.
4 | Compared with this simple demo, a more detailed implementation of the journal paper is at https://github.com/Atomu2014/product-nets-distributed , with large-scale data access, multi-GPU support, and distributed training support.
5 |
6 | ``Note``: I would like to share some interesting and advanced discussions in the [extended version](https://github.com/Atomu2014/product-nets-distributed).
7 |
8 | ``Note``: If you run into any problems, you can contact me at kevinqu16@gmail.com; email gets the fastest response.
9 |
10 | This repository maintains the demo code of the paper
11 | [Product-based Neural Network for User Response Prediction](https://arxiv.org/abs/1611.00144),
12 | published at ICDM 2016,
13 | together with other baseline models, all implemented with ``tensorflow``.
14 |
15 | ## Introduction to User Response Prediction
16 |
17 | User response prediction plays a fundamental and crucial role in today's business, especially in personalized recommender systems and online display advertising.
18 | Different from traditional machine learning tasks,
19 | user response prediction always involves ``categorical features`` grouped by different ``fields``,
20 | which we call ``multi-field categorical data``, e.g.:
21 |
22 |     ad_request = {
23 |         'weekday': 3,
24 |         'hour': 18,
25 |         'IP': '255.255.255.255',
26 |         'domain': 'xxx.com',
27 |         'advertiser': 2997,
28 |         'click': 1
29 |     }
30 |
31 | In practice, these categorical features are usually one-hot encoded for training.
32 | However, this representation results in extreme sparsity.
33 | To cope with this sparsity, linear models (e.g., ``LR``), latent factor-based models (e.g., ``FM``, ``FFM``), tree models (e.g., ``GBDT``), and DNN models (e.g., ``FNN``, ``DeepFM``) have been proposed.
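
For illustration, here is a minimal sketch of field-wise one-hot encoding (the fields and vocabularies are hypothetical, not the actual iPinYou ones):

```
# each field occupies a contiguous block of the one-hot index space
fields = {'weekday': list(range(7)), 'hour': list(range(24))}
offsets, dim = {}, 0
for name, vocab in fields.items():
    offsets[name] = dim
    dim += len(vocab)

def encode(record):
    # active index of each field, e.g. {'weekday': 3, 'hour': 18} -> [3, 25]
    return [offsets[name] + vocab.index(record[name])
            for name, vocab in fields.items()]
```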
34 |
35 | A core problem in user response prediction is how to represent complex feature interactions. Industrial applications prefer feature engineering plus simple models. With GPU servers becoming more and more popular, it is promising to design complex models that explore feature interactions automatically. Through our analysis and experiments, we find a ``coupled gradient`` issue in latent factor-based models and an ``insensitive gradient`` issue in DNN models.
36 |
37 | Take FM as an example: the gradient of each feature vector is a sum over the other feature vectors. If two features are independent, FM can hardly learn two orthogonal feature vectors for them. The gradient issue of DNNs is discussed in the paper ``Failures of Gradient-Based Deep Learning``.
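
To make the ``coupled gradient`` issue concrete: with the FM interaction term below, the gradient with respect to one feature vector is a sum over all the other (active) feature vectors, so the update of $v_i$ is coupled with every $v_j$:

$\hat{y}_{FM} = \sum_{i<j} \langle v_i, v_j \rangle x_i x_j, \qquad \frac{\partial \hat{y}_{FM}}{\partial v_i} = \sum_{j \neq i} v_j x_i x_j$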
38 |
39 |
42 |
43 | In order to solve these issues, we propose to use product operators in DNNs to help explore feature interactions. We discuss these issues in an extended paper, which was submitted to TOIS in Sep. 2017 (see the note above for the released version).
44 | Any discussion is welcome; please contact kevinqu16@gmail.com.
45 |
46 | ## Product-based Neural Networks
47 |
48 | Through discussion of previous works, we think a good predictor should have a good feature extractor (to convert sparse features into dense representations) as well as a powerful classifier (e.g., a DNN as a universal approximator). Since FM is good at representing feature interactions, we introduce product operators into the DNN. The proposed PNN models follow this architecture: an embedding layer to represent sparse features, a product layer to explore feature interactions, and a DNN classifier.
49 |
50 | For the product layer, we propose 2 types of product operators in the paper: inner product and outer product. These operators output $n(n-1)/2$ feature interactions, which are concatenated with the embeddings and fed to the following fully connected layers.
51 |
52 | The inner product is easy to understand; the outer product is actually equivalent to projecting the embeddings into a hidden space and computing the inner product of the projected embeddings:
53 |
54 | $\langle uv^\top, w \rangle = \sum_{i,j} u_i w_{ij} v_j = u^\top w v$
55 |
56 | Since there are $n(n-1)/2$ feature interactions, we proposed some tricks in the paper to reduce the complexity.
57 | However, we find these tricks restrict model capacity and are unnecessary.
58 | In a recent update of the code, we removed the tricks for better performance.
59 |
60 | In our implementation, we add the parameter ``kernel_type: {mat, vec, num}`` for the outer product.
61 | The default type is ``mat``; you can switch to the other types to save time and memory, as sketched below.
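
The three kernel types can be sketched in ``numpy`` as follows (shapes follow the ``PNN2`` implementation in `models.py`; this is an illustration, not the training code):

```
import numpy as np

batch, num_pairs, k = 32, 10, 4            # toy shapes
p = np.random.randn(batch, num_pairs, k)   # left embedding of each pair
q = np.random.randn(batch, num_pairs, k)   # right embedding of each pair

# 'mat': one k x k kernel per pair, kp[b, pair] = p^T W q (most expressive)
W_mat = np.random.randn(k, num_pairs, k)
kp_mat = np.einsum('bpi,ipj,bpj->bp', p, W_mat, q)

# 'vec': one k-dim kernel per pair (cheaper)
W_vec = np.random.randn(num_pairs, k)
kp_vec = (p * q * W_vec).sum(-1)

# 'num': one scalar kernel per pair (cheapest)
W_num = np.random.randn(num_pairs, 1)
kp_num = (p * q * W_num).sum(-1)
```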
62 |
63 | A potential risk arises in training the first hidden layer. Feature embeddings and interactions are concatenated and fed to the first hidden layer, but the embeddings and interactions have different distributions. A simple remedy is to add a linear transformation to the embeddings to balance the distributions. ``Layer norm`` is also worth trying.
64 |
65 | ## How to Use
66 |
67 | For simplicity, we provide the iPinYou dataset at [make-ipinyou-data](https://github.com/Atomu2014/make-ipinyou-data).
68 | Follow the instructions there and update the soft link `data`:
69 |
70 | ```
71 | XXX/product-nets$ ln -sfn XXX/make-ipinyou-data/2997 data
72 | ```
73 |
74 | Then run ``main.py``:
75 |
76 |     cd python
77 |     python main.py
78 |
79 | As for datasets, we maintain a repository on GitHub serving as a benchmark in our lab:
80 | [APEX-Datasets](https://github.com/Atomu2014/Ads-RecSys-Datasets).
81 | This repository contains detailed data processing, feature engineering,
82 | data storage/buffering/access, and other implementations.
83 | For better I/O performance, this benchmark provides HDF5 APIs.
84 | Currently we provide download links for 4 large-scale ad-click datasets (already processed):
85 | Criteo-8day, Avazu, iPinYou-all, and Criteo Challenge. More datasets will be added later.
86 |
87 | This code was originally written in Python 2.7; ``numpy``, ``scipy``, and ``tensorflow`` are required.
88 | A recent update makes it compatible with Python 3.x as well,
89 | so you can use it as a starting point with whichever Python version you like.
90 | LR, FM, FNN, CCPM, DeepFM, and PNN are all implemented in `models.py`, based on TensorFlow.
91 | You can train any of the models in `main.py` and configure the parameters via a dict.
92 |
93 | More models and mxnet implementation will be released in the extended version.
94 |
95 | ## Practical Issues
96 |
97 | In this section we share selected discussions from my emails and GitHub issues.
98 |
99 | ``Note``: 2 advanced discussions, about the overfitting of adam and the performance gain of DNNs, are presented in the [extended version](https://github.com/Atomu2014/product-nets-distributed). You are welcome to discuss relevant problems through issues or emails.
100 |
101 | ### 1. Sparse Regularization (L2)
102 |
103 | L2 regularization is fundamental in controlling over-fitting.
104 | For sparse input, we suggest sparse regularization,
105 | i.e. regularizing only the activated weights/neurons.
106 | Traditional L2 regularization penalizes all parameters through $\Vert w\Vert$, $w = [w_1, \dots, w_n]$, even when some inputs are zero ($x_i = 0$),
107 | which means every parameter $w_i$ gets a non-zero regularization gradient on every training example $x$.
108 | Sparse regularization instead penalizes only the non-zero terms, via $\Vert xw \Vert$.
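
In TF 1.x code this is the difference between penalizing the full weight matrix and penalizing the activated part $xw$; the ``LR`` model in `models.py` takes the sparse approach:

```
import tensorflow as tf

X = tf.sparse_placeholder(tf.float32)   # sparse one-hot input
w = tf.Variable(tf.zeros([100, 1]))     # toy weight shape
l2_weight = 1e-4

# traditional L2: every parameter is penalized, even with zero input
dense_l2 = l2_weight * tf.nn.l2_loss(w)

# sparse L2 (as in models.py): only activated weights get a regularization
# gradient, because xw = Xw zeroes out the inactive features
xw = tf.sparse_tensor_dense_matmul(X, w)
sparse_l2 = l2_weight * tf.nn.l2_loss(xw)
```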
109 |
110 | ### 2. Initialization
111 |
112 | Initializing weights with small random numbers is always promising in deep learning.
113 | Usually we use a ``uniform`` or ``normal`` distribution around 0.
114 | An empirical choice is to set the distribution's standard deviation near $\sqrt{1/n}$, where $n$ is the input dimension.
115 | Another choice is ``xavier``: for a uniform distribution,
116 | ``xavier`` uses $\sqrt{3/node_i}$, $\sqrt{3/node_o}$,
117 | or $\sqrt{6/(node_i+node_o)}$ as the upper/lower bound.
118 | This keeps unit variance across different layers.
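
This is exactly what the ``xavier`` branch of `utils.init_var_map` does:

```
import numpy as np
import tensorflow as tf

var_shape = [500, 1]                      # [node_in, node_out]
maxval = np.sqrt(6. / np.sum(var_shape))  # sqrt(6 / (node_in + node_out))
w = tf.Variable(tf.random_uniform(var_shape, minval=-maxval, maxval=maxval))
```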
119 |
120 | ### 3. Learning Rate
121 |
122 | For deep neural networks with a lot of parameters,
123 | a large learning rate usually causes divergence.
124 | SGD with a small learning rate usually has promising performance, but converges slowly.
125 | For extremely sparse input, adaptive learning rates converge much faster,
126 | e.g. AdaGrad, Adam, FTRL, etc.
127 | [This blog](http://sebastianruder.com/optimizing-gradient-descent/)
128 | compares most of the adaptive algorithms.
129 | Even though adaptive algorithms speed up training and sometimes escape local minima,
130 | there is no guarantee of better generalization performance.
131 | To sum up, ``Adam`` and ``AdaGrad`` are good choices. ``Adam`` converges faster than ``AdaGrad``, but is also easier to overfit.
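
In this code base the optimizer is selected by the ``opt_algo`` string passed to each model (see `utils.get_optimizer`), so switching from SGD to an adaptive method is a one-line change in the parameter dict; the smaller learning rate below is an illustrative choice, not a tuned value:

```
lr_params = {
    'input_dim': input_dim,
    'opt_algo': 'adam',     # instead of 'gd'; 'adagrad', 'ftrl', ... also work
    'learning_rate': 1e-3,  # adaptive methods usually want a smaller rate
    'l2_weight': 0,
    'random_seed': 0
}
```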
132 |
133 | ### 4. Data Processing
134 |
135 | Usually you need to build a feature map to convert categorical data into a one-hot representation.
136 | These features usually follow a long-tailed distribution,
137 | resulting in an extremely large feature space, e.g. IP addresses.
138 | A simple way is to remove the low-frequency features by a threshold,
139 | which dramatically reduces the input dimension with little loss of performance, as sketched below.
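
A minimal sketch of such a frequency-thresholded feature map (``records`` and the ``<rare>`` token are hypothetical):

```
from collections import Counter

threshold = 10
counts = Counter(f for record in records for f in record)

# keep frequent features; map the long tail to one shared '<rare>' index
feat_map = {'<rare>': 0}
for feat, c in counts.items():
    if c >= threshold:
        feat_map[feat] = len(feat_map)

def index(feat):
    return feat_map.get(feat, feat_map['<rare>'])
```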
140 |
141 | For unbalanced datasets, a typical positive/negative ratio is 0.1% - 1%,
142 | and Facebook has published a paper discussing negative down-sampling.
143 | Negative down-sampling can speed up training as well as reduce the data size, but it requires calibration in some cases.
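
Under uniform negative down-sampling with rate $w$ (each negative kept with probability $w$), the usual correction maps a probability $p$ predicted on the sampled data back to the original distribution:

```
def calibrate(p, w):
    # q: calibrated probability on the un-sampled distribution
    return p / (p + (1.0 - p) / w)
```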
144 |
145 | ### 5. Normalization
146 |
147 | There are two kinds of normalization: feature-level and instance-level.
148 | Feature-level normalization works within one field,
149 | e.g. setting the mean of a field to 0 and its variance to 1.
150 | Instance-level normalization keeps different records consistent,
151 | e.g. when you have a multi-value field whose length varies between 5 and 100 values,
152 | you can set the magnitude of every record to 1 by shifting and scaling, as sketched below.
153 | Besides, ``batch/weight/layer normalization`` are worth trying as the network grows deeper.
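
A toy sketch of instance-level normalization for a multi-value field, giving every record the same total magnitude regardless of how many values it has:

```
import numpy as np

def normalize_multivalue(indices, dim):
    x = np.zeros(dim)
    x[indices] = 1.0 / len(indices)  # or 1/sqrt(len) to fix the L2 norm
    return x

normalize_multivalue([2, 7, 11], dim=16)
```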
154 |
155 | ### 6. Continuous/Discrete/Multi-value Feature
156 |
157 | Most features in user response prediction take discrete values (categorical features). The key difference between continuous and discrete features is that only continuous features are comparable in value. For example, the encodings {``male``: 0, ``female``: 1} and {``male``: 1, ``female``: 0} are equivalent.
158 |
159 | When the data contains both continuous and discrete values, one solution is to discretize the continuous values using bucketing. Taking 'age' as an example, you can set [0, 12] as ``children``, [13, 18] as ``teenagers``, [19, ~] as ``adults``, and so on, as sketched below.
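
A sketch of such bucketing, using the hypothetical boundaries above:

```
import bisect

boundaries = [13, 19]  # [0, 12] children, [13, 18] teenagers, 19+ adults

def age_bucket(age):
    return bisect.bisect_right(boundaries, age)

[age_bucket(a) for a in (5, 15, 40)]  # -> [0, 1, 2]
```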
160 |
161 | Multi-value features are a special case of discrete features,
162 | e.g. recently reviewed items = [``item2``, ``item7``, ``item11``] or [``item1``, ``item4``, ``item9``, ``item13``].
163 | This type of data is also called set data; its key property is ``permutation invariance``, which is discussed in the paper ``DeepSet``.
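
One permutation-invariant way to handle such a set field is to pool the embeddings of its values, e.g. by mean pooling (a sketch, not part of this repo; the embedding table is a toy):

```
import numpy as np

embed = np.random.randn(1000, 10)        # toy item embedding table

def encode_set(item_ids):
    # reordering item_ids does not change the result
    return embed[item_ids].mean(axis=0)
```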
164 |
165 | ### 7. Activation Function
166 |
167 | Do not use ``sigmoid`` in hidden layers; use ``tanh`` or ``relu`` instead.
168 | Recently, ``selu`` was proposed to maintain a fixed point during training.
169 |
170 | ### 8. Numerical Stable Parameters
171 |
172 | Adaptive optimizers usually require hyperparameters for numerical stability, e.g., $\epsilon$ in ``Adam`` and the ``initial accumulator value`` of ``AdaGrad``. Sometimes these parameters have a large impact on model convergence and performance.
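
For example, TF 1.x exposes both knobs directly (the values shown are illustrative):

```
import tensorflow as tf

opt_adam = tf.train.AdamOptimizer(learning_rate=1e-3, epsilon=1e-8)
opt_adagrad = tf.train.AdagradOptimizer(learning_rate=0.1,
                                        initial_accumulator_value=0.1)
```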
173 |
--------------------------------------------------------------------------------
/python/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 | from __future__ import division
4 |
5 | import sys
6 | if sys.version_info[0] == 2:
7 | import cPickle as pkl
8 | else:
9 | import pickle as pkl
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from python import utils
15 |
16 | dtype = utils.DTYPE
17 |
18 |
19 | class Model:
20 | def __init__(self):
21 | self.sess = None
22 | self.X = None
23 | self.y = None
24 | self.layer_keeps = None
25 | self.vars = None
26 | self.keep_prob_train = None
27 | self.keep_prob_test = None
28 |
29 | def run(self, fetches, X=None, y=None, mode='train'):
30 | feed_dict = {}
31 | if type(self.X) is list:
32 | for i in range(len(X)):
33 | feed_dict[self.X[i]] = X[i]
34 | else:
35 | feed_dict[self.X] = X
36 | if y is not None:
37 | feed_dict[self.y] = y
38 | if self.layer_keeps is not None:
39 | if mode == 'train':
40 | feed_dict[self.layer_keeps] = self.keep_prob_train
41 | elif mode == 'test':
42 | feed_dict[self.layer_keeps] = self.keep_prob_test
43 | return self.sess.run(fetches, feed_dict)
44 |
45 | def dump(self, model_path):
46 | var_map = {}
47 |         for name, var in self.vars.items():  # iteritems() is Python 2 only
48 | var_map[name] = self.run(var)
49 | pkl.dump(var_map, open(model_path, 'wb'))
50 | print('model dumped at', model_path)
51 |
52 |
53 | class LR(Model):
54 | def __init__(self, input_dim=None, output_dim=1, init_path=None, opt_algo='gd', learning_rate=1e-2, l2_weight=0,
55 | random_seed=None):
56 | Model.__init__(self)
57 | init_vars = [('w', [input_dim, output_dim], 'xavier', dtype),
58 | ('b', [output_dim], 'zero', dtype)]
59 | self.graph = tf.Graph()
60 | with self.graph.as_default():
61 | if random_seed is not None:
62 | tf.set_random_seed(random_seed)
63 | self.X = tf.sparse_placeholder(dtype)
64 | self.y = tf.placeholder(dtype)
65 | self.vars = utils.init_var_map(init_vars, init_path)
66 |
67 | w = self.vars['w']
68 | b = self.vars['b']
69 | xw = tf.sparse_tensor_dense_matmul(self.X, w)
70 | logits = tf.reshape(xw + b, [-1])
71 | self.y_prob = tf.sigmoid(logits)
72 |
73 | self.loss = tf.reduce_mean(
74 | tf.nn.sigmoid_cross_entropy_with_logits(labels=self.y, logits=logits)) + \
75 | l2_weight * tf.nn.l2_loss(xw)
76 | self.optimizer = utils.get_optimizer(opt_algo, learning_rate, self.loss)
77 |
78 | config = tf.ConfigProto()
79 | config.gpu_options.allow_growth = True
80 | self.sess = tf.Session(config=config)
81 | tf.global_variables_initializer().run(session=self.sess)
82 |
83 |
84 | class FM(Model):
85 | def __init__(self, input_dim=None, output_dim=1, factor_order=10, init_path=None, opt_algo='gd', learning_rate=1e-2,
86 | l2_w=0, l2_v=0, random_seed=None):
87 | Model.__init__(self)
88 | init_vars = [('w', [input_dim, output_dim], 'xavier', dtype),
89 | ('v', [input_dim, factor_order], 'xavier', dtype),
90 | ('b', [output_dim], 'zero', dtype)]
91 | self.graph = tf.Graph()
92 | with self.graph.as_default():
93 | if random_seed is not None:
94 | tf.set_random_seed(random_seed)
95 | self.X = tf.sparse_placeholder(dtype)
96 | self.y = tf.placeholder(dtype)
97 | self.vars = utils.init_var_map(init_vars, init_path)
98 |
99 | w = self.vars['w']
100 | v = self.vars['v']
101 | b = self.vars['b']
102 |
103 | X_square = tf.SparseTensor(self.X.indices, tf.square(self.X.values), tf.to_int64(tf.shape(self.X)))
104 | xv = tf.square(tf.sparse_tensor_dense_matmul(self.X, v))
105 | p = 0.5 * tf.reshape(
106 | tf.reduce_sum(xv - tf.sparse_tensor_dense_matmul(X_square, tf.square(v)), 1),
107 | [-1, output_dim])
108 | xw = tf.sparse_tensor_dense_matmul(self.X, w)
109 | logits = tf.reshape(xw + b + p, [-1])
110 | self.y_prob = tf.sigmoid(logits)
111 |
112 | self.loss = tf.reduce_mean(
113 | tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=self.y)) + \
114 | l2_w * tf.nn.l2_loss(xw) + \
115 | l2_v * tf.nn.l2_loss(xv)
116 | self.optimizer = utils.get_optimizer(opt_algo, learning_rate, self.loss)
117 |
118 | config = tf.ConfigProto()
119 | config.gpu_options.allow_growth = True
120 | self.sess = tf.Session(config=config)
121 | tf.global_variables_initializer().run(session=self.sess)
122 |
123 |
124 | class FNN(Model):
125 | def __init__(self, field_sizes=None, embed_size=10, layer_sizes=None, layer_acts=None, drop_out=None,
126 | embed_l2=None, layer_l2=None, init_path=None, opt_algo='gd', learning_rate=1e-2, random_seed=None):
127 | Model.__init__(self)
128 | init_vars = []
129 | num_inputs = len(field_sizes)
130 | for i in range(num_inputs):
131 | init_vars.append(('embed_%d' % i, [field_sizes[i], embed_size], 'xavier', dtype))
132 | node_in = num_inputs * embed_size
133 | for i in range(len(layer_sizes)):
134 | init_vars.append(('w%d' % i, [node_in, layer_sizes[i]], 'xavier', dtype))
135 | init_vars.append(('b%d' % i, [layer_sizes[i]], 'zero', dtype))
136 | node_in = layer_sizes[i]
137 | self.graph = tf.Graph()
138 | with self.graph.as_default():
139 | if random_seed is not None:
140 | tf.set_random_seed(random_seed)
141 | self.X = [tf.sparse_placeholder(dtype) for i in range(num_inputs)]
142 | self.y = tf.placeholder(dtype)
143 | self.keep_prob_train = 1 - np.array(drop_out)
144 | self.keep_prob_test = np.ones_like(drop_out)
145 | self.layer_keeps = tf.placeholder(dtype)
146 | self.vars = utils.init_var_map(init_vars, init_path)
147 | w0 = [self.vars['embed_%d' % i] for i in range(num_inputs)]
148 | xw = tf.concat([tf.sparse_tensor_dense_matmul(self.X[i], w0[i]) for i in range(num_inputs)], 1)
149 | l = xw
150 |
151 | for i in range(len(layer_sizes)):
152 | wi = self.vars['w%d' % i]
153 | bi = self.vars['b%d' % i]
154 | print(l.shape, wi.shape, bi.shape)
155 | l = tf.nn.dropout(
156 | utils.activate(
157 | tf.matmul(l, wi) + bi,
158 | layer_acts[i]),
159 | self.layer_keeps[i])
160 |
161 | l = tf.squeeze(l)
162 | self.y_prob = tf.sigmoid(l)
163 |
164 | self.loss = tf.reduce_mean(
165 | tf.nn.sigmoid_cross_entropy_with_logits(logits=l, labels=self.y))
166 | if layer_l2 is not None:
167 | self.loss += embed_l2 * tf.nn.l2_loss(xw)
168 | for i in range(len(layer_sizes)):
169 | wi = self.vars['w%d' % i]
170 | self.loss += layer_l2[i] * tf.nn.l2_loss(wi)
171 | self.optimizer = utils.get_optimizer(opt_algo, learning_rate, self.loss)
172 |
173 | config = tf.ConfigProto()
174 | config.gpu_options.allow_growth = True
175 | self.sess = tf.Session(config=config)
176 | tf.global_variables_initializer().run(session=self.sess)
177 |
178 |
179 | class DeepFM(Model):
180 | def __init__(self, field_sizes=None, embed_size=10, layer_sizes=None, layer_acts=None, drop_out=None,
181 | embed_l2=None, layer_l2=None, init_path=None, opt_algo='gd', learning_rate=1e-2, random_seed=None):
182 | Model.__init__(self)
183 | init_vars = []
184 | num_inputs = len(field_sizes)
185 | for i in range(num_inputs):
186 | init_vars.append(('embed_%d' % i, [field_sizes[i], embed_size], 'xavier', dtype))
187 | init_vars.append(('weight_%d' % i, [field_sizes[i], 1], 'xavier', dtype))
188 | init_vars.append(('bias', [1], 'zero', dtype))
189 | node_in = num_inputs * embed_size
190 | for i in range(len(layer_sizes)):
191 | init_vars.append(('w%d' % i, [node_in, layer_sizes[i]], 'xavier', dtype))
192 | init_vars.append(('b%d' % i, [layer_sizes[i]], 'zero', dtype))
193 | node_in = layer_sizes[i]
194 | self.graph = tf.Graph()
195 | with self.graph.as_default():
196 | if random_seed is not None:
197 | tf.set_random_seed(random_seed)
198 | self.X = [tf.sparse_placeholder(dtype) for i in range(num_inputs)]
199 | self.y = tf.placeholder(dtype)
200 | self.keep_prob_train = 1 - np.array(drop_out)
201 | self.keep_prob_test = np.ones_like(drop_out)
202 | self.layer_keeps = tf.placeholder(dtype)
203 | self.vars = utils.init_var_map(init_vars, init_path)
204 | w = [self.vars['weight_%d' % i] for i in range(num_inputs)]
205 | v = [self.vars['embed_%d' % i] for i in range(num_inputs)]
206 | b = self.vars['bias']
207 | xw = tf.concat([tf.sparse_tensor_dense_matmul(self.X[i], w[i]) for i in range(num_inputs)], 1)
208 | xv = tf.concat([tf.sparse_tensor_dense_matmul(self.X[i], v[i]) for i in range(num_inputs)], 1)
209 | l = xv
210 |
211 | for i in range(len(layer_sizes)):
212 | wi = self.vars['w%d' % i]
213 | bi = self.vars['b%d' % i]
214 | print(l.shape, wi.shape, bi.shape)
215 | l = tf.nn.dropout(
216 | utils.activate(
217 | tf.matmul(l, wi) + bi,
218 | layer_acts[i]),
219 | self.layer_keeps[i])
220 | l = tf.squeeze(l)
221 |
222 | xv = tf.reshape(xv, [-1, num_inputs, embed_size])
223 | p = 0.5 * tf.reduce_sum(
224 | tf.square(tf.reduce_sum(xv, 1)) -
225 | tf.reduce_sum(tf.square(xv), 1),
226 | 1)
227 | xw = tf.reduce_sum(xw, 1)
228 | logits = tf.reshape(l + xw + b + p, [-1])
229 |
230 | self.y_prob = tf.sigmoid(logits)
231 |
232 | self.loss = tf.reduce_mean(
233 | tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=self.y))
234 | if layer_l2 is not None:
235 | self.loss += embed_l2 * tf.nn.l2_loss(xw)
236 | for i in range(len(layer_sizes)):
237 | wi = self.vars['w%d' % i]
238 | self.loss += layer_l2[i] * tf.nn.l2_loss(wi)
239 | self.optimizer = utils.get_optimizer(opt_algo, learning_rate, self.loss)
240 |
241 | config = tf.ConfigProto()
242 | config.gpu_options.allow_growth = True
243 | self.sess = tf.Session(config=config)
244 | tf.global_variables_initializer().run(session=self.sess)
245 |
246 |
247 | class CCPM(Model):
248 | def __init__(self, field_sizes=None, embed_size=10, filter_sizes=None, layer_acts=None, drop_out=None,
249 | init_path=None, opt_algo='gd', learning_rate=1e-2, random_seed=None):
250 | Model.__init__(self)
251 | init_vars = []
252 | num_inputs = len(field_sizes)
253 | for i in range(num_inputs):
254 | init_vars.append(('embed_%d' % i, [field_sizes[i], embed_size], 'xavier', dtype))
255 | init_vars.append(('f1', [embed_size, filter_sizes[0], 1, 2], 'xavier', dtype))
256 | init_vars.append(('f2', [embed_size, filter_sizes[1], 2, 2], 'xavier', dtype))
257 | init_vars.append(('w1', [2 * 3 * embed_size, 1], 'xavier', dtype))
258 | init_vars.append(('b1', [1], 'zero', dtype))
259 |
260 | self.graph = tf.Graph()
261 | with self.graph.as_default():
262 | if random_seed is not None:
263 | tf.set_random_seed(random_seed)
264 | self.X = [tf.sparse_placeholder(dtype) for i in range(num_inputs)]
265 | self.y = tf.placeholder(dtype)
266 | self.keep_prob_train = 1 - np.array(drop_out)
267 | self.keep_prob_test = np.ones_like(drop_out)
268 | self.layer_keeps = tf.placeholder(dtype)
269 | self.vars = utils.init_var_map(init_vars, init_path)
270 | w0 = [self.vars['embed_%d' % i] for i in range(num_inputs)]
271 | xw = tf.concat([tf.sparse_tensor_dense_matmul(self.X[i], w0[i]) for i in range(num_inputs)], 1)
272 | l = xw
273 |
274 | l = tf.transpose(tf.reshape(l, [-1, num_inputs, embed_size, 1]), [0, 2, 1, 3])
275 | f1 = self.vars['f1']
276 | l = tf.nn.conv2d(l, f1, [1, 1, 1, 1], 'SAME')
277 | l = tf.transpose(
278 | utils.max_pool_4d(
279 | tf.transpose(l, [0, 1, 3, 2]),
280 | int(num_inputs / 2)),
281 | [0, 1, 3, 2])
282 | f2 = self.vars['f2']
283 | l = tf.nn.conv2d(l, f2, [1, 1, 1, 1], 'SAME')
284 | l = tf.transpose(
285 | utils.max_pool_4d(
286 | tf.transpose(l, [0, 1, 3, 2]), 3),
287 | [0, 1, 3, 2])
288 | l = tf.nn.dropout(
289 | utils.activate(
290 | tf.reshape(l, [-1, embed_size * 3 * 2]),
291 | layer_acts[0]),
292 | self.layer_keeps[0])
293 | w1 = self.vars['w1']
294 | b1 = self.vars['b1']
295 | l = tf.matmul(l, w1) + b1
296 |
297 | l = tf.squeeze(l)
298 | self.y_prob = tf.sigmoid(l)
299 |
300 | self.loss = tf.reduce_mean(
301 | tf.nn.sigmoid_cross_entropy_with_logits(logits=l, labels=self.y))
302 | self.optimizer = utils.get_optimizer(opt_algo, learning_rate, self.loss)
303 |
304 | config = tf.ConfigProto()
305 | config.gpu_options.allow_growth = True
306 | self.sess = tf.Session(config=config)
307 | tf.global_variables_initializer().run(session=self.sess)
308 |
309 |
310 | class PNN1(Model):
311 | def __init__(self, field_sizes=None, embed_size=10, layer_sizes=None, layer_acts=None, drop_out=None,
312 | embed_l2=None, layer_l2=None, init_path=None, opt_algo='gd', learning_rate=1e-2, random_seed=None):
313 | Model.__init__(self)
314 | init_vars = []
315 | num_inputs = len(field_sizes)
316 | for i in range(num_inputs):
317 | init_vars.append(('embed_%d' % i, [field_sizes[i], embed_size], 'xavier', dtype))
318 | num_pairs = int(num_inputs * (num_inputs - 1) / 2)
319 | node_in = num_inputs * embed_size + num_pairs
320 | # node_in = num_inputs * (embed_size + num_inputs)
321 | for i in range(len(layer_sizes)):
322 | init_vars.append(('w%d' % i, [node_in, layer_sizes[i]], 'xavier', dtype))
323 | init_vars.append(('b%d' % i, [layer_sizes[i]], 'zero', dtype))
324 | node_in = layer_sizes[i]
325 | self.graph = tf.Graph()
326 | with self.graph.as_default():
327 | if random_seed is not None:
328 | tf.set_random_seed(random_seed)
329 | self.X = [tf.sparse_placeholder(dtype) for i in range(num_inputs)]
330 | self.y = tf.placeholder(dtype)
331 | self.keep_prob_train = 1 - np.array(drop_out)
332 | self.keep_prob_test = np.ones_like(drop_out)
333 | self.layer_keeps = tf.placeholder(dtype)
334 | self.vars = utils.init_var_map(init_vars, init_path)
335 | w0 = [self.vars['embed_%d' % i] for i in range(num_inputs)]
336 | xw = tf.concat([tf.sparse_tensor_dense_matmul(self.X[i], w0[i]) for i in range(num_inputs)], 1)
337 | xw3d = tf.reshape(xw, [-1, num_inputs, embed_size])
338 |
339 | row = []
340 | col = []
341 | for i in range(num_inputs-1):
342 | for j in range(i+1, num_inputs):
343 | row.append(i)
344 | col.append(j)
345 | # batch * pair * k
346 | p = tf.transpose(
347 | # pair * batch * k
348 | tf.gather(
349 | # num * batch * k
350 | tf.transpose(
351 | xw3d, [1, 0, 2]),
352 | row),
353 | [1, 0, 2])
354 | # batch * pair * k
355 | q = tf.transpose(
356 | tf.gather(
357 | tf.transpose(
358 | xw3d, [1, 0, 2]),
359 | col),
360 | [1, 0, 2])
361 | p = tf.reshape(p, [-1, num_pairs, embed_size])
362 | q = tf.reshape(q, [-1, num_pairs, embed_size])
363 | ip = tf.reshape(tf.reduce_sum(p * q, [-1]), [-1, num_pairs])
364 |
365 | # simple but redundant
366 | # batch * n * 1 * k, batch * 1 * n * k
367 | # ip = tf.reshape(
368 | # tf.reduce_sum(
369 | # tf.expand_dims(xw3d, 2) *
370 | # tf.expand_dims(xw3d, 1),
371 | # 3),
372 | # [-1, num_inputs**2])
373 | l = tf.concat([xw, ip], 1)
374 |
375 | for i in range(len(layer_sizes)):
376 | wi = self.vars['w%d' % i]
377 | bi = self.vars['b%d' % i]
378 | l = tf.nn.dropout(
379 | utils.activate(
380 | tf.matmul(l, wi) + bi,
381 | layer_acts[i]),
382 | self.layer_keeps[i])
383 |
384 | l = tf.squeeze(l)
385 | self.y_prob = tf.sigmoid(l)
386 |
387 | self.loss = tf.reduce_mean(
388 | tf.nn.sigmoid_cross_entropy_with_logits(logits=l, labels=self.y))
389 | if layer_l2 is not None:
390 | self.loss += embed_l2 * tf.nn.l2_loss(xw)
391 | for i in range(len(layer_sizes)):
392 | wi = self.vars['w%d' % i]
393 | self.loss += layer_l2[i] * tf.nn.l2_loss(wi)
394 | self.optimizer = utils.get_optimizer(opt_algo, learning_rate, self.loss)
395 |
396 | config = tf.ConfigProto()
397 | config.gpu_options.allow_growth = True
398 | self.sess = tf.Session(config=config)
399 | tf.global_variables_initializer().run(session=self.sess)
400 |
401 |
402 | class PNN2(Model):
403 | def __init__(self, field_sizes=None, embed_size=10, layer_sizes=None, layer_acts=None, drop_out=None,
404 | embed_l2=None, layer_l2=None, init_path=None, opt_algo='gd', learning_rate=1e-2, random_seed=None,
405 | layer_norm=True, kernel_type='mat'):
406 | Model.__init__(self)
407 | init_vars = []
408 | num_inputs = len(field_sizes)
409 | for i in range(num_inputs):
410 | init_vars.append(('embed_%d' % i, [field_sizes[i], embed_size], 'xavier', dtype))
411 | num_pairs = int(num_inputs * (num_inputs - 1) / 2)
412 | node_in = num_inputs * embed_size + num_pairs
413 | if kernel_type == 'mat':
414 | init_vars.append(('kernel', [embed_size, num_pairs, embed_size], 'xavier', dtype))
415 | elif kernel_type == 'vec':
416 | init_vars.append(('kernel', [num_pairs, embed_size], 'xavier', dtype))
417 | elif kernel_type == 'num':
418 | init_vars.append(('kernel', [num_pairs, 1], 'xavier', dtype))
419 | for i in range(len(layer_sizes)):
420 | init_vars.append(('w%d' % i, [node_in, layer_sizes[i]], 'xavier', dtype))
421 | init_vars.append(('b%d' % i, [layer_sizes[i]], 'zero', dtype))
422 | node_in = layer_sizes[i]
423 | self.graph = tf.Graph()
424 | with self.graph.as_default():
425 | if random_seed is not None:
426 | tf.set_random_seed(random_seed)
427 | self.X = [tf.sparse_placeholder(dtype) for i in range(num_inputs)]
428 | self.y = tf.placeholder(dtype)
429 | self.keep_prob_train = 1 - np.array(drop_out)
430 | self.keep_prob_test = np.ones_like(drop_out)
431 | self.layer_keeps = tf.placeholder(dtype)
432 | self.vars = utils.init_var_map(init_vars, init_path)
433 | w0 = [self.vars['embed_%d' % i] for i in range(num_inputs)]
434 | xw = tf.concat([tf.sparse_tensor_dense_matmul(self.X[i], w0[i]) for i in range(num_inputs)], 1)
435 | xw3d = tf.reshape(xw, [-1, num_inputs, embed_size])
436 |
437 | row = []
438 | col = []
439 | for i in range(num_inputs - 1):
440 | for j in range(i + 1, num_inputs):
441 | row.append(i)
442 | col.append(j)
443 | # batch * pair * k
444 | p = tf.transpose(
445 | # pair * batch * k
446 | tf.gather(
447 | # num * batch * k
448 | tf.transpose(
449 | xw3d, [1, 0, 2]),
450 | row),
451 | [1, 0, 2])
452 | # batch * pair * k
453 | q = tf.transpose(
454 | tf.gather(
455 | tf.transpose(
456 | xw3d, [1, 0, 2]),
457 | col),
458 | [1, 0, 2])
459 | # b * p * k
460 | p = tf.reshape(p, [-1, num_pairs, embed_size])
461 | # b * p * k
462 | q = tf.reshape(q, [-1, num_pairs, embed_size])
463 | k = self.vars['kernel']
464 |
465 | if kernel_type == 'mat':
466 | # batch * 1 * pair * k
467 | p = tf.expand_dims(p, 1)
468 | # batch * pair
469 | kp = tf.reduce_sum(
470 | # batch * pair * k
471 | tf.multiply(
472 | # batch * pair * k
473 | tf.transpose(
474 | # batch * k * pair
475 | tf.reduce_sum(
476 | # batch * k * pair * k
477 | tf.multiply(
478 | p, k),
479 | -1),
480 | [0, 2, 1]),
481 | q),
482 | -1)
483 | else:
484 | # 1 * pair * (k or 1)
485 | k = tf.expand_dims(k, 0)
486 | # batch * pair
487 | kp = tf.reduce_sum(p * q * k, -1)
488 |
489 | #
490 | # if layer_norm:
491 | # # x_mean, x_var = tf.nn.moments(xw, [1], keep_dims=True)
492 | # # xw = (xw - x_mean) / tf.sqrt(x_var)
493 | # # x_g = tf.Variable(tf.ones([num_inputs * embed_size]), name='x_g')
494 | # # x_b = tf.Variable(tf.zeros([num_inputs * embed_size]), name='x_b')
495 | # # x_g = tf.Print(x_g, [x_g[:10], x_b])
496 | # # xw = xw * x_g + x_b
497 | # p_mean, p_var = tf.nn.moments(op, [1], keep_dims=True)
498 | # op = (op - p_mean) / tf.sqrt(p_var)
499 | # p_g = tf.Variable(tf.ones([embed_size**2]), name='p_g')
500 | # p_b = tf.Variable(tf.zeros([embed_size**2]), name='p_b')
501 | # # p_g = tf.Print(p_g, [p_g[:10], p_b])
502 | # op = op * p_g + p_b
503 |
504 | l = tf.concat([xw, kp], 1)
505 | for i in range(len(layer_sizes)):
506 | wi = self.vars['w%d' % i]
507 | bi = self.vars['b%d' % i]
508 | l = tf.nn.dropout(
509 | utils.activate(
510 | tf.matmul(l, wi) + bi,
511 | layer_acts[i]),
512 | self.layer_keeps[i])
513 |
514 | l = tf.squeeze(l)
515 | self.y_prob = tf.sigmoid(l)
516 |
517 | self.loss = tf.reduce_mean(
518 | tf.nn.sigmoid_cross_entropy_with_logits(logits=l, labels=self.y))
519 | if layer_l2 is not None:
520 |             self.loss += embed_l2 * tf.nn.l2_loss(xw)  # alternative: tf.nn.l2_loss(tf.concat(w0, 0))
521 | for i in range(len(layer_sizes)):
522 | wi = self.vars['w%d' % i]
523 | self.loss += layer_l2[i] * tf.nn.l2_loss(wi)
524 | self.optimizer = utils.get_optimizer(opt_algo, learning_rate, self.loss)
525 |
526 | config = tf.ConfigProto()
527 | config.gpu_options.allow_growth = True
528 | self.sess = tf.Session(config=config)
529 | tf.global_variables_initializer().run(session=self.sess)
530 |
--------------------------------------------------------------------------------