├── README.md
├── amazon_prepro.ipynb
├── config.py
├── data.py
├── layer.py
├── main.py
├── ml_prepro.ipynb
├── model.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Deep Interest Network / Deep Interest Evolution Network for Click-Through Rate Prediction
2 |
3 | This implementation references the code of [zhougr1993](https://github.com/zhougr1993/DeepInterestNetwork) and [mouna99](https://github.com/mouna99/dien), converted to TensorFlow 2.0.
4 | It performs comparably to the papers on the MovieLens-20M and Amazon (Electronics) datasets.
5 | To switch between the ```Base```, ```DIN```, and ```DIEN``` models, change the ```model``` instantiated in ```main.py```, as shown in the sketch below.
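A minimal sketch of the swap, assuming (as ```main.py``` suggests) that ```DIN``` and ```DIEN``` take the same constructor arguments as ```Base```:

```python
from model import Base, DIN, DIEN

# default in main.py:
# model = Base(user_count, item_count, cate_count, cate_list,
#              args.user_dim, args.item_dim, args.cate_dim, args.dim_layers)

# swap in DIN (or DIEN) instead:
model = DIN(user_count, item_count, cate_count, cate_list,
            args.user_dim, args.item_dim, args.cate_dim, args.dim_layers)
```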
6 |
7 | Requirements
8 | * python 3.6
9 | * tensorflow 2.0
10 |
11 | Run ```python main.py```
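Hyperparameters are declared in ```config.py``` and can be overridden on the command line, for example:

```
python main.py --lr 0.1 --epochs 10 --train_batch_size 32
```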
12 |
13 |
14 |
--------------------------------------------------------------------------------
/amazon_prepro.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 13,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import tensorflow as tf\n",
10 | "import numpy as np\n",
11 | "import csv\n",
12 | "from tqdm import tqdm\n",
13 | "import pandas as pd\n",
14 | "import random\n",
15 | "import pickle\n",
16 | "\n",
17 | "root = '/data/private/Ad/amazon/'"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 14,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "with open(root+'Electronics_5.json') as fin:\n",
27 | " df = {}\n",
28 | " for i, line in enumerate(fin):\n",
29 | " df[i] = eval(line)\n",
30 | " reviews_df = pd.DataFrame.from_dict(df, orient='index')"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 15,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "with open(root+'np_prepro/reviews.pkl', 'wb') as f:\n",
40 | " pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 16,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "with open(root+'meta_Electronics.json') as fin:\n",
50 | " df = {}\n",
51 | " for i, line in enumerate(fin):\n",
52 | " df[i] = eval(line)\n",
53 | " meta_df = pd.DataFrame.from_dict(df, orient='index')\n",
54 | "\n",
55 | "meta_df = meta_df[meta_df['asin'].isin(reviews_df['asin'].unique())]\n",
56 | "meta_df = meta_df.reset_index(drop=True)\n",
57 | "with open(root+'np_prepro/meta.pkl', 'wb') as f:\n",
58 | " pickle.dump(meta_df, f, pickle.HIGHEST_PROTOCOL)"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 17,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]\n",
68 | "meta_df = meta_df[['asin', 'categories']]\n",
69 | "# only one category...\n",
70 | "meta_df['categories'] = meta_df['categories'].map(lambda x: x[-1][-1])"
71 | ]
72 | },
73 | {
74 | "cell_type": "code",
75 | "execution_count": 18,
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "def build_map(df, col_name):\n",
80 | " key = sorted(df[col_name].unique().tolist())\n",
81 | " m = dict(zip(key, range(len(key))))\n",
82 | " df[col_name] = df[col_name].map(lambda x: m[x])\n",
83 | " return m, key"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 19,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "asin_map, asin_key = build_map(meta_df, 'asin')\n",
93 | "cate_map, cate_key = build_map(meta_df, 'categories')\n",
94 | "revi_map, revi_key = build_map(reviews_df, 'reviewerID')"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 20,
100 | "metadata": {},
101 | "outputs": [
102 | {
103 | "name": "stdout",
104 | "output_type": "stream",
105 | "text": [
106 | "user_count: 192403\titem_count: 63001\tcate_count: 801\texample_count: 1689188\n"
107 | ]
108 | }
109 | ],
110 | "source": [
111 | "user_count, item_count, cate_count, example_count =\\\n",
112 | " len(revi_map), len(asin_map), len(cate_map), reviews_df.shape[0]\n",
113 | "print('user_count: %d\\titem_count: %d\\tcate_count: %d\\texample_count: %d' %\n",
114 | " (user_count, item_count, cate_count, example_count))"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 21,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "meta_df = meta_df.sort_values('asin')\n",
124 | "meta_df = meta_df.reset_index(drop=True)"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 22,
130 | "metadata": {
131 | "collapsed": true,
132 | "jupyter": {
133 | "outputs_hidden": true
134 | }
135 | },
136 | "outputs": [
137 | {
138 | "name": "stdout",
139 | "output_type": "stream",
140 | "text": [
141 | " reviewerID asin unixReviewTime\n",
142 | "0 0 13179 1400457600\n",
143 | "1 0 17993 1400457600\n",
144 | "2 0 28326 1400457600\n",
145 | "3 0 29247 1400457600\n",
146 | "4 0 62275 1400457600\n",
147 | "5 1 58134 1379548800\n",
148 | "6 1 62555 1379548800\n",
149 | "7 1 41862 1384041600\n",
150 | "8 1 46010 1385769600\n",
151 | "9 1 54171 1385769600\n",
152 | "10 1 56540 1385769600\n",
153 | "11 2 42298 1366156800\n",
154 | "12 2 46782 1366156800\n",
155 | "13 2 50682 1366156800\n",
156 | "14 2 42390 1370563200\n",
157 | "15 2 47355 1370563200\n",
158 | "16 3 25578 1371772800\n",
159 | "17 3 21989 1375142400\n",
160 | "18 3 58444 1402876800\n",
161 | "19 3 60072 1402876800\n",
162 | "20 3 62274 1402876800\n",
163 | "21 4 54245 1359331200\n",
164 | "22 4 3112 1361145600\n",
165 | "23 4 40094 1361145600\n",
166 | "24 4 48963 1361145600\n",
167 | "25 4 30275 1389744000\n",
168 | "26 4 58671 1402358400\n",
169 | "27 4 62022 1402358400\n",
170 | "28 5 30462 1373241600\n",
171 | "29 5 55698 1373241600\n",
172 | "... ... ... ...\n",
173 | "1689158 192402 36004 1357776000\n",
174 | "1689159 192402 37977 1357776000\n",
175 | "1689160 192402 39411 1357776000\n",
176 | "1689161 192402 7681 1358035200\n",
177 | "1689162 192402 18186 1358035200\n",
178 | "1689163 192402 27522 1358035200\n",
179 | "1689164 192402 29206 1358035200\n",
180 | "1689165 192402 30547 1358035200\n",
181 | "1689166 192402 42076 1358035200\n",
182 | "1689167 192402 29518 1377907200\n",
183 | "1689168 192402 35691 1377907200\n",
184 | "1689169 192402 44123 1377907200\n",
185 | "1689170 192402 54615 1377907200\n",
186 | "1689171 192402 45465 1385856000\n",
187 | "1689172 192402 48862 1385856000\n",
188 | "1689173 192402 52445 1385856000\n",
189 | "1689174 192402 60275 1385856000\n",
190 | "1689175 192402 51004 1386633600\n",
191 | "1689176 192402 53509 1386633600\n",
192 | "1689177 192402 61519 1386633600\n",
193 | "1689178 192402 28581 1388534400\n",
194 | "1689179 192402 29369 1388534400\n",
195 | "1689180 192402 41590 1388534400\n",
196 | "1689181 192402 51306 1388534400\n",
197 | "1689182 192402 49816 1389744000\n",
198 | "1689183 192402 57576 1389744000\n",
199 | "1689184 192402 22519 1396396800\n",
200 | "1689185 192402 20977 1404172800\n",
201 | "1689186 192402 60283 1404172800\n",
202 | "1689187 192402 62677 1405123200\n",
203 | "\n",
204 | "[1689188 rows x 3 columns]\n"
205 | ]
206 | }
207 | ],
208 | "source": [
209 | "reviews_df['asin'] = reviews_df['asin'].map(lambda x: asin_map[x])\n",
210 | "reviews_df = reviews_df.sort_values(['reviewerID', 'unixReviewTime'])\n",
211 | "reviews_df = reviews_df.reset_index(drop=True)"
212 | ]
213 | },
214 | {
215 | "cell_type": "code",
216 | "execution_count": 24,
217 | "metadata": {},
218 | "outputs": [],
219 | "source": [
220 | "cate_list = [meta_df['categories'][i] for i in range(len(asin_map))]\n",
221 | "cate_list = np.array(cate_list, dtype=np.int32)"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 25,
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "with open(root+'np_prepro/remap.pkl', 'wb') as f:\n",
231 | " pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL) # uid, iid\n",
232 | " pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL) # cid of iid line\n",
233 | " pickle.dump((user_count, item_count, cate_count, example_count),\n",
234 | " f, pickle.HIGHEST_PROTOCOL)\n",
235 | " pickle.dump((asin_key, cate_key, revi_key), f, pickle.HIGHEST_PROTOCOL)"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": 26,
241 | "metadata": {},
242 | "outputs": [],
243 | "source": [
244 | "random.seed(1234)\n",
245 | "\n",
246 | "train_set = []\n",
247 | "test_set = []\n",
248 | "for reviewerID, hist in reviews_df.groupby('reviewerID'):\n",
249 | " pos_list = hist['asin'].tolist()\n",
250 | " neg_list = []\n",
251 | " for _ in range(len(pos_list)):\n",
252 | " neg = pos_list[0]\n",
253 | " while neg in pos_list + neg_list :\n",
254 | " neg = random.randint(0, item_count-1)\n",
255 | " neg_list.append(neg)\n",
256 | " \n",
257 | " for i in range(1, len(pos_list)-1):\n",
258 | " hist = pos_list[:i]\n",
259 | " train_set.append((reviewerID, hist, pos_list[i], 1))\n",
260 | " train_set.append((reviewerID, hist, neg_list[i], 0))\n",
261 | " label = (pos_list[-1], neg_list[-1])\n",
262 | " test_set.append((reviewerID, hist, label))\n",
263 | "\n",
264 | "random.shuffle(train_set)\n",
265 | "random.shuffle(test_set)\n",
266 | "\n",
267 | "assert len(test_set) == user_count\n",
268 | "\n",
269 | "with open(root+'np_prepro/dataset.pkl', 'wb') as f:\n",
270 | " pickle.dump(train_set, f, pickle.HIGHEST_PROTOCOL)\n",
271 | " pickle.dump(test_set, f, pickle.HIGHEST_PROTOCOL)\n",
272 | " pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL)\n",
273 | " pickle.dump((user_count, item_count, cate_count), f, pickle.HIGHEST_PROTOCOL)"
274 | ]
275 | }
276 | ],
277 | "metadata": {
278 | "kernelspec": {
279 | "display_name": "Python 3",
280 | "language": "python",
281 | "name": "python3"
282 | },
283 | "language_info": {
284 | "codemirror_mode": {
285 | "name": "ipython",
286 | "version": 3
287 | },
288 | "file_extension": ".py",
289 | "mimetype": "text/x-python",
290 | "name": "python",
291 | "nbconvert_exporter": "python",
292 | "pygments_lexer": "ipython3",
293 | "version": "3.6.8"
294 | }
295 | },
296 | "nbformat": 4,
297 | "nbformat_minor": 4
298 | }
299 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def argparser():
4 | parser = argparse.ArgumentParser()
5 |
6 | parser.add_argument('--lr', default=0.1, help='learning rate', type=float)
7 |     parser.add_argument('--train_batch_size', default=32, help='train batch size', type=int)
8 |     parser.add_argument('--test_batch_size', default=512, help='test batch size', type=int)
9 | parser.add_argument('--epochs', default=10, help='number of epochs', type=int)
10 | parser.add_argument('--print_step', default=1000, help='step size for print log', type=int)
11 |
12 | parser.add_argument('--dataset_dir', default='/data/private/Ad/amazon/np_prepro/', help='dataset path')
13 | parser.add_argument('--model_path', default='./models/', help='model load path', type=str)
14 |     parser.add_argument('--log_path', default='./logs/', help='log path for tensorboard', type=str)
15 | parser.add_argument('--is_reuse', default=False)
16 | parser.add_argument('--multi_gpu', default=False)
17 |
18 | parser.add_argument('--user_count', default=192403, help='number of users', type=int)
19 | parser.add_argument('--item_count', default=63001, help='number of items', type=int)
20 | parser.add_argument('--cate_count', default=801, help='number of categories', type=int)
21 |
22 | parser.add_argument('--user_dim', default=128, help='dimension of user', type=int)
23 | parser.add_argument('--item_dim', default=64, help='dimension of item', type=int)
24 | parser.add_argument('--cate_dim', default=64, help='dimension of category', type=int)
25 |
26 |     parser.add_argument('--dim_layers', default=[80,40,1], nargs='+', help='hidden layer dimensions', type=int)
27 |
28 | args = parser.parse_args()
29 |
30 | return args
31 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import random
3 | import pickle
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from config import argparser
8 |
9 | args = argparser()
10 |
11 | with open(args.dataset_dir+'dataset.pkl', 'rb') as f:
12 | train_set = pickle.load(f, encoding='latin1')
13 | test_set = pickle.load(f, encoding='latin1')
14 | cate_list = pickle.load(f, encoding='latin1')
15 | cate_list = tf.convert_to_tensor(cate_list, dtype=tf.int64)
16 | user_count, item_count, cate_count = pickle.load(f)
17 |
18 | class DataLoader:
19 | def __init__(self, batch_size, data):
20 | self.batch_size = batch_size
21 | self.data = data
22 | self.epoch_size = len(self.data) // self.batch_size
23 | if self.epoch_size * self.batch_size < len(self.data):
24 | self.epoch_size += 1
25 | self.i = 0
26 |
27 | def __iter__(self):
28 | self.i = 0
29 | return self
30 |
31 | def __next__(self):
32 | if self.i == self.epoch_size:
33 | raise StopIteration
34 | ts = self.data[self.i * self.batch_size : min((self.i+1) * self.batch_size,
35 | len(self.data))]
36 | self.i += 1
37 |
38 | u, i, y, sl = [], [], [], []
39 | for t in ts:
40 | u.append(t[0])
41 | i.append(t[2])
42 | y.append(t[3])
43 | sl.append(len(t[1]))
44 | max_sl = max(sl)
45 |
46 | hist_i = np.zeros([len(ts), max_sl], np.int64)
47 |
48 | k = 0
49 | for t in ts:
50 | for l in range(len(t[1])):
51 | hist_i[k][l] = t[1][l]
52 | k += 1
53 |
54 | return tf.convert_to_tensor(u), tf.convert_to_tensor(i), \
55 | tf.convert_to_tensor(y), tf.convert_to_tensor(hist_i), \
56 | sl
57 |
58 | class DataLoaderTest:
59 | def __init__(self, batch_size, data):
60 |
61 | self.batch_size = batch_size
62 | self.data = data
63 | self.epoch_size = len(self.data) // self.batch_size
64 | if self.epoch_size * self.batch_size < len(self.data):
65 | self.epoch_size += 1
66 | self.i = 0
67 |
68 | def __iter__(self):
69 | self.i = 0
70 | return self
71 |
72 | def __next__(self):
73 |
74 | if self.i == self.epoch_size:
75 | raise StopIteration
76 |
77 | ts = self.data[self.i * self.batch_size : min((self.i+1) * self.batch_size,
78 | len(self.data))]
79 | self.i += 1
80 |
81 | u, i, j, sl = [], [], [], []
82 | for t in ts:
83 | u.append(t[0])
84 | i.append(t[2][0])
85 | j.append(t[2][1])
86 | sl.append(len(t[1]))
87 | max_sl = max(sl)
88 |
89 | hist_i = np.zeros([len(ts), max_sl], np.int64)
90 |
91 | k = 0
92 | for t in ts:
93 | for l in range(len(t[1])):
94 | hist_i[k][l] = t[1][l]
95 | k += 1
96 |
97 | return tf.convert_to_tensor(u), tf.convert_to_tensor(i), \
98 | tf.convert_to_tensor(j), tf.convert_to_tensor(hist_i), \
99 | sl
100 |
101 | def __len__(self):
102 | return len(self.data)
103 |
104 | def get_dataloader(train_batch_size, test_batch_size):
105 | return DataLoader(train_batch_size, train_set), DataLoaderTest(test_batch_size, test_set), \
106 | user_count, item_count, cate_count, cate_list
107 |
--------------------------------------------------------------------------------
/layer.py:
--------------------------------------------------------------------------------
1 | # https://github.com/zhougr1993/DeepInterestNetwork/blob/master/din/Dice.py
2 | import tensorflow as tf
3 | import tensorflow.keras.layers as nn
4 |
5 | class attention(tf.keras.layers.Layer):
6 | def __init__(self, keys_dim, dim_layers):
7 | super(attention, self).__init__()
8 | self.keys_dim = keys_dim
9 |
10 | self.fc = tf.keras.Sequential()
11 | for dim_layer in dim_layers[:-1]:
12 | self.fc.add(nn.Dense(dim_layer, activation='sigmoid'))
13 | self.fc.add(nn.Dense(dim_layers[-1], activation=None))
14 |
15 | def call(self, queries, keys, keys_length):
16 | queries = tf.tile(tf.expand_dims(queries, 1), [1, tf.shape(keys)[1], 1])
17 |         # concatenate query, key, their difference and element-wise product as the attention MLP input
18 | din_all = tf.concat([queries, keys, queries-keys, queries*keys], axis=-1)
19 | outputs = tf.transpose(self.fc(din_all), [0,2,1])
20 |
21 | # Mask
22 | key_masks = tf.sequence_mask(keys_length, max(keys_length), dtype=tf.bool) # [B, T]
23 | key_masks = tf.expand_dims(key_masks, 1)
24 | paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
25 | outputs = tf.where(key_masks, outputs, paddings) # [B, 1, T]
26 |
27 | # Scale
28 | outputs = outputs / (self.keys_dim ** 0.5)
29 |
30 | # Activation
31 | outputs = tf.keras.activations.softmax(outputs, -1) # [B, 1, T]
32 |
33 | # Weighted sum
34 |         outputs = tf.squeeze(tf.matmul(outputs, keys), axis=1) # [B, H]
35 |
36 | return outputs
37 |
38 | class dice(tf.keras.layers.Layer):
39 | def __init__(self, feat_dim):
40 | super(dice, self).__init__()
41 | self.feat_dim = feat_dim
42 | self.alphas= tf.Variable(tf.zeros([feat_dim]), dtype=tf.float32)
43 | self.beta = tf.Variable(tf.zeros([feat_dim]), dtype=tf.float32)
44 |
45 | self.bn = tf.keras.layers.BatchNormalization(center=False, scale=False)
46 |
47 | def call(self, _x, axis=-1, epsilon=0.000000001):
48 |
49 | reduction_axes = list(range(len(_x.get_shape())))
50 | del reduction_axes[axis]
51 | broadcast_shape = [1] * len(_x.get_shape())
52 | broadcast_shape[axis] = self.feat_dim
53 |
54 | mean = tf.reduce_mean(_x, axis=reduction_axes)
55 | brodcast_mean = tf.reshape(mean, broadcast_shape)
56 | std = tf.reduce_mean(tf.square(_x - brodcast_mean) + epsilon, axis=reduction_axes)
57 | std = tf.sqrt(std)
58 | brodcast_std = tf.reshape(std, broadcast_shape)
59 |
60 |         x_normed = self.bn(_x)  # BatchNormalization replaces the manual mean/std normalization computed above (left unused)
61 | x_p = tf.keras.activations.sigmoid(self.beta * x_normed)
62 |
63 | return self.alphas * (1.0 - x_p) * _x + x_p * _x
64 |
65 | def parametric_relu(_x):
66 |     with tf.compat.v1.variable_scope(name_or_scope='', reuse=tf.compat.v1.AUTO_REUSE):  # TF1-style PReLU kept from the reference code (graph mode only)
67 |         alphas = tf.compat.v1.get_variable('alpha', _x.get_shape()[-1],
68 |                                            initializer=tf.constant_initializer(0.0),
69 |                                            dtype=tf.float32)
70 | pos = tf.nn.relu(_x)
71 | neg = alphas * (_x - abs(_x)) * 0.5
72 |
73 | return pos + neg
74 |
75 | class Bilinear(tf.keras.layers.Layer):
76 | def __init__(self, units):
77 | super(Bilinear, self).__init__()
78 | self.linear_act = nn.Dense(units, activation=None, use_bias=True)
79 | self.linear_noact = nn.Dense(units, activation=None, use_bias=False)
80 |
81 | def call(self, a, b, gate_b=None):
82 | if gate_b is None:
83 | return tf.keras.activations.sigmoid(self.linear_act(a) + self.linear_noact(b))
84 | else:
85 | return tf.keras.activations.tanh(self.linear_act(a) + tf.math.multiply(gate_b, self.linear_noact(b)))
86 |
87 | class AUGRU(tf.keras.layers.Layer):
88 | def __init__(self, units):
89 | super(AUGRU, self).__init__()
90 |
91 | self.u_gate = Bilinear(units)
92 | self.r_gate = Bilinear(units)
93 | self.c_memo = Bilinear(units)
94 |
95 | def call(self, inputs, state, att_score):
96 | u = self.u_gate(inputs, state)
97 | r = self.r_gate(inputs, state)
98 | c = self.c_memo(inputs, state, r)
99 |
100 | u_= att_score * u
101 | final = (1 - u_) * state + u_ * c
102 |
103 | return final
104 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 | '''0 = all messages are logged (default behavior)
4 | 1 = INFO messages are not printed
5 | 2 = INFO and WARNING messages are not printed
6 | 3 = INFO, WARNING, and ERROR messages are not printed'''
7 | import time
8 |
9 | import tensorflow as tf
10 |
11 | from config import argparser
12 | from data import get_dataloader
13 | from model import Base, DIN, DIEN
14 | from utils import eval
15 |
16 | # Config
17 | print(tf.__version__)
18 | print("GPU Available: ", tf.test.is_gpu_available())
19 |
20 | args = argparser()
21 |
22 | # Data Load
23 | train_data, test_data, \
24 | user_count, item_count, cate_count, \
25 | cate_list = get_dataloader(args.train_batch_size, args.test_batch_size)
26 |
27 | # Loss, Optim
28 | optimizer = tf.keras.optimizers.SGD(learning_rate=args.lr, momentum=0.0)
29 | loss_metric = tf.keras.metrics.Sum()
30 | auc_metric = tf.keras.metrics.AUC()
31 |
32 | # Model
33 | model = Base(user_count, item_count, cate_count, cate_list,
34 | args.user_dim, args.item_dim, args.cate_dim, args.dim_layers)
35 |
36 | # Board
37 | train_summary_writer = tf.summary.create_file_writer(args.log_path)
38 |
39 | #@tf.function
40 | def train_one_step(u,i,y,hist_i,sl):
41 | with tf.GradientTape() as tape:
42 | output,_ = model(u,i,hist_i,sl)
43 | loss = tf.reduce_mean(
44 | tf.nn.sigmoid_cross_entropy_with_logits(logits=output,
45 | labels=tf.cast(y, dtype=tf.float32)))
46 | gradient = tape.gradient(loss, model.trainable_variables)
47 | clip_gradient, _ = tf.clip_by_global_norm(gradient, 5.0)
48 | optimizer.apply_gradients(zip(clip_gradient, model.trainable_variables))
49 |
50 | loss_metric(loss)
51 |
52 | # Train
53 | def train(optimizer):
54 |     best_loss = 0.
55 | best_auc = 0.
56 | start_time = time.time()
57 | for epoch in range(args.epochs):
58 | for step, (u, i, y, hist_i, sl) in enumerate(train_data, start=1):
59 | train_one_step(u, i, y, hist_i, sl)
60 |
61 | if step % args.print_step == 0:
62 | test_gauc, auc = eval(model, test_data)
63 | print('Epoch %d Global_step %d\tTrain_loss: %.4f\tEval_GAUC: %.4f\tEval_AUC: %.4f' %
64 | (epoch, step, loss_metric.result() / args.print_step, test_gauc, auc))
65 |
66 | if best_auc < test_gauc:
67 |                     best_loss = loss_metric.result() / args.print_step
68 | best_auc = test_gauc
69 | model.save_weights(args.model_path+'cp-%d.ckpt'%epoch)
70 | loss_metric.reset_states()
71 |
72 | with train_summary_writer.as_default():
73 | tf.summary.scalar('loss', best_loss, step=epoch)
74 | tf.summary.scalar('test_gauc', best_auc, step=epoch)
75 |
76 | loss_metric.reset_states()
77 |         optimizer.learning_rate = 0.01  # decay the LR in place; rebinding the name would not affect the global optimizer used in train_one_step
78 |
79 | print('Epoch %d DONE\tCost time: %.2f' % (epoch, time.time()-start_time))
80 | print('Best test_gauc: ', best_auc)
81 |
82 |
83 | # Main
84 | if __name__ == '__main__':
85 | train(optimizer)
86 |
--------------------------------------------------------------------------------
/ml_prepro.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd\n",
10 | "import random\n",
11 | "import numpy as np\n",
12 | "import pickle\n",
13 | "import csv\n",
14 | "import tensorflow as tf\n",
15 | "\n",
16 | "dataset_dir = '/data/private/Ad/amazon/np_prepro/'\n",
17 | "with open(dataset_dir+'dataset.pkl', 'rb') as f:\n",
18 | " train_set = pickle.load(f, encoding='latin1') # uid, [hist], vid, label\n",
19 | " test_set = pickle.load(f, encoding='latin1') # uid, [hist], pid, nid\n",
20 | " cate_list = pickle.load(f, encoding='latin1') # id2cate list\n",
21 | " cate_list = tf.convert_to_tensor(cate_list, dtype=tf.int64)\n",
22 | " user_count, item_count, cate_count = pickle.load(f) # (user_count, item_count, cate_count)\n"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 6,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "root = '/data/private/Ad/ml-20m/'\n",
32 | "reviews_df = pd.read_csv(root+'ratings.csv')"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 7,
38 | "metadata": {
39 | "collapsed": true,
40 | "jupyter": {
41 | "outputs_hidden": true
42 | }
43 | },
44 | "outputs": [
45 | {
46 | "data": {
47 | "text/html": [
[HTML output omitted: ratings.csv preview (columns userId, movieId, rating, timestamp; 20000263 rows × 4 columns)]
[HTML output omitted: movie metadata preview (columns movieId, title, genres; 27278 rows × 3 columns)]