├── README.md
├── amazon_prepro.ipynb
├── config.py
├── data.py
├── layer.py
├── main.py
├── ml_prepro.ipynb
├── model.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | ## Deep Interest Network for Click-Through Rate Prediction
Deep Interest Evolution Network for Click-Through Rate Prediction
2 | 3 | I referenced the [zhougr1993](https://github.com/zhougr1993/DeepInterestNetwork) and [mouna99](https://github.com/mouna99/dien) implementations and converted them to TensorFlow 2.0. 4 | This code performs similarly to the papers on the ml-20m and Amazon datasets. 5 | You can change the ```model``` instantiated in ```main.py``` to any of ```Base```, ```DIN```, or ```DIEN```. 6 | 7 | Requirements 8 | * python 3.6 9 | * tensorflow 2.0 10 | 11 | Run ```python main.py``` 12 | 13 | 14 |
-------------------------------------------------------------------------------- /amazon_prepro.ipynb: --------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 13, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "import numpy as np\n", 11 | "import csv\n", 12 | "from tqdm import tqdm\n", 13 | "import pandas as pd\n", 14 | "import random\n", 15 | "import pickle\n", 16 | "\n", 17 | "root = '/data/private/Ad/amazon/'" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 14, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "with open(root+'Electronics_5.json') as fin:\n", 27 | "    df = {}\n", 28 | "    for i, line in enumerate(fin):\n", 29 | "        df[i] = eval(line)\n", 30 | "    reviews_df = pd.DataFrame.from_dict(df, orient='index')" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 15, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "with open(root+'np_prepro/reviews.pkl', 'wb') as f:\n", 40 | "    pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 16, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "with open(root+'meta_Electronics.json') as fin:\n", 50 | "    df = {}\n", 51 | "    for i, line in enumerate(fin):\n", 52 | "        df[i] = eval(line)\n", 53 | "    meta_df = pd.DataFrame.from_dict(df, orient='index')\n", 54 | "\n", 55 | "meta_df = meta_df[meta_df['asin'].isin(reviews_df['asin'].unique())]\n", 56 | "meta_df = meta_df.reset_index(drop=True)\n", 57 | "with open(root+'np_prepro/meta.pkl', 'wb') as f:\n", 58 | "    pickle.dump(meta_df, f, pickle.HIGHEST_PROTOCOL)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 17, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]\n", 68 | "meta_df = meta_df[['asin', 'categories']]\n", 69 | "# keep only the last (most fine-grained) category of each item\n", 70 | "meta_df['categories'] = meta_df['categories'].map(lambda x: x[-1][-1])" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 18, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "def build_map(df, col_name):\n", 80 | "    key = sorted(df[col_name].unique().tolist())\n", 81 | "    m = dict(zip(key, range(len(key))))\n", 82 | "    df[col_name] = df[col_name].map(lambda x: m[x])\n", 83 | "    return m, key" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 19, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "asin_map, asin_key = build_map(meta_df, 'asin')\n", 93 | "cate_map, cate_key = build_map(meta_df, 'categories')\n", 94 | "revi_map, revi_key = build_map(reviews_df, 'reviewerID')" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 20, 100 | "metadata": {}, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 |
"user_count: 192403\titem_count: 63001\tcate_count: 801\texample_count: 1689188\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "user_count, item_count, cate_count, example_count =\\\n", 112 | " len(revi_map), len(asin_map), len(cate_map), reviews_df.shape[0]\n", 113 | "print('user_count: %d\\titem_count: %d\\tcate_count: %d\\texample_count: %d' %\n", 114 | " (user_count, item_count, cate_count, example_count))" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 21, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "meta_df = meta_df.sort_values('asin')\n", 124 | "meta_df = meta_df.reset_index(drop=True)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 22, 130 | "metadata": { 131 | "collapsed": true, 132 | "jupyter": { 133 | "outputs_hidden": true 134 | } 135 | }, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | " reviewerID asin unixReviewTime\n", 142 | "0 0 13179 1400457600\n", 143 | "1 0 17993 1400457600\n", 144 | "2 0 28326 1400457600\n", 145 | "3 0 29247 1400457600\n", 146 | "4 0 62275 1400457600\n", 147 | "5 1 58134 1379548800\n", 148 | "6 1 62555 1379548800\n", 149 | "7 1 41862 1384041600\n", 150 | "8 1 46010 1385769600\n", 151 | "9 1 54171 1385769600\n", 152 | "10 1 56540 1385769600\n", 153 | "11 2 42298 1366156800\n", 154 | "12 2 46782 1366156800\n", 155 | "13 2 50682 1366156800\n", 156 | "14 2 42390 1370563200\n", 157 | "15 2 47355 1370563200\n", 158 | "16 3 25578 1371772800\n", 159 | "17 3 21989 1375142400\n", 160 | "18 3 58444 1402876800\n", 161 | "19 3 60072 1402876800\n", 162 | "20 3 62274 1402876800\n", 163 | "21 4 54245 1359331200\n", 164 | "22 4 3112 1361145600\n", 165 | "23 4 40094 1361145600\n", 166 | "24 4 48963 1361145600\n", 167 | "25 4 30275 1389744000\n", 168 | "26 4 58671 1402358400\n", 169 | "27 4 62022 1402358400\n", 170 | "28 5 30462 1373241600\n", 171 | "29 5 55698 1373241600\n", 172 | "... ... ... 
...\n", 173 | "1689158 192402 36004 1357776000\n", 174 | "1689159 192402 37977 1357776000\n", 175 | "1689160 192402 39411 1357776000\n", 176 | "1689161 192402 7681 1358035200\n", 177 | "1689162 192402 18186 1358035200\n", 178 | "1689163 192402 27522 1358035200\n", 179 | "1689164 192402 29206 1358035200\n", 180 | "1689165 192402 30547 1358035200\n", 181 | "1689166 192402 42076 1358035200\n", 182 | "1689167 192402 29518 1377907200\n", 183 | "1689168 192402 35691 1377907200\n", 184 | "1689169 192402 44123 1377907200\n", 185 | "1689170 192402 54615 1377907200\n", 186 | "1689171 192402 45465 1385856000\n", 187 | "1689172 192402 48862 1385856000\n", 188 | "1689173 192402 52445 1385856000\n", 189 | "1689174 192402 60275 1385856000\n", 190 | "1689175 192402 51004 1386633600\n", 191 | "1689176 192402 53509 1386633600\n", 192 | "1689177 192402 61519 1386633600\n", 193 | "1689178 192402 28581 1388534400\n", 194 | "1689179 192402 29369 1388534400\n", 195 | "1689180 192402 41590 1388534400\n", 196 | "1689181 192402 51306 1388534400\n", 197 | "1689182 192402 49816 1389744000\n", 198 | "1689183 192402 57576 1389744000\n", 199 | "1689184 192402 22519 1396396800\n", 200 | "1689185 192402 20977 1404172800\n", 201 | "1689186 192402 60283 1404172800\n", 202 | "1689187 192402 62677 1405123200\n", 203 | "\n", 204 | "[1689188 rows x 3 columns]\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "reviews_df['asin'] = reviews_df['asin'].map(lambda x: asin_map[x])\n", 210 | "reviews_df = reviews_df.sort_values(['reviewerID', 'unixReviewTime'])\n", 211 | "reviews_df = reviews_df.reset_index(drop=True)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 24, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "cate_list = [meta_df['categories'][i] for i in range(len(asin_map))]\n", 221 | "cate_list = np.array(cate_list, dtype=np.int32)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 25, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "with open(root+'np_prepro/remap.pkl', 'wb') as f:\n", 231 | " pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL) # uid, iid\n", 232 | " pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL) # cid of iid line\n", 233 | " pickle.dump((user_count, item_count, cate_count, example_count),\n", 234 | " f, pickle.HIGHEST_PROTOCOL)\n", 235 | " pickle.dump((asin_key, cate_key, revi_key), f, pickle.HIGHEST_PROTOCOL)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": 26, 241 | "metadata": {}, 242 | "outputs": [], 243 | "source": [ 244 | "random.seed(1234)\n", 245 | "\n", 246 | "train_set = []\n", 247 | "test_set = []\n", 248 | "for reviewerID, hist in reviews_df.groupby('reviewerID'):\n", 249 | " pos_list = hist['asin'].tolist()\n", 250 | " neg_list = []\n", 251 | " for _ in range(len(pos_list)):\n", 252 | " neg = pos_list[0]\n", 253 | " while neg in pos_list + neg_list :\n", 254 | " neg = random.randint(0, item_count-1)\n", 255 | " neg_list.append(neg)\n", 256 | " \n", 257 | " for i in range(1, len(pos_list)-1):\n", 258 | " hist = pos_list[:i]\n", 259 | " train_set.append((reviewerID, hist, pos_list[i], 1))\n", 260 | " train_set.append((reviewerID, hist, neg_list[i], 0))\n", 261 | " label = (pos_list[-1], neg_list[-1])\n", 262 | " test_set.append((reviewerID, hist, label))\n", 263 | "\n", 264 | "random.shuffle(train_set)\n", 265 | "random.shuffle(test_set)\n", 266 | "\n", 267 | "assert len(test_set) == user_count\n", 268 | "\n", 269 | "with open(root+'np_prepro/dataset.pkl', 'wb') as 
f:\n", 270 | " pickle.dump(train_set, f, pickle.HIGHEST_PROTOCOL)\n", 271 | " pickle.dump(test_set, f, pickle.HIGHEST_PROTOCOL)\n", 272 | " pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL)\n", 273 | " pickle.dump((user_count, item_count, cate_count), f, pickle.HIGHEST_PROTOCOL)" 274 | ] 275 | } 276 | ], 277 | "metadata": { 278 | "kernelspec": { 279 | "display_name": "Python 3", 280 | "language": "python", 281 | "name": "python3" 282 | }, 283 | "language_info": { 284 | "codemirror_mode": { 285 | "name": "ipython", 286 | "version": 3 287 | }, 288 | "file_extension": ".py", 289 | "mimetype": "text/x-python", 290 | "name": "python", 291 | "nbconvert_exporter": "python", 292 | "pygments_lexer": "ipython3", 293 | "version": "3.6.8" 294 | } 295 | }, 296 | "nbformat": 4, 297 | "nbformat_minor": 4 298 | } 299 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def argparser(): 4 | parser = argparse.ArgumentParser() 5 | 6 | parser.add_argument('--lr', default=0.1, help='learning rate', type=float) 7 | parser.add_argument('--train_batch_size', default=32, help='batch size', type=int) 8 | parser.add_argument('--test_batch_size', default=512, help='batch size', type=int) 9 | parser.add_argument('--epochs', default=10, help='number of epochs', type=int) 10 | parser.add_argument('--print_step', default=1000, help='step size for print log', type=int) 11 | 12 | parser.add_argument('--dataset_dir', default='/data/private/Ad/amazon/np_prepro/', help='dataset path') 13 | parser.add_argument('--model_path', default='./models/', help='model load path', type=str) 14 | parser.add_argument('--log_path', default='./logs/', help='log path fot tensorboard', type=str) 15 | parser.add_argument('--is_reuse', default=False) 16 | parser.add_argument('--multi_gpu', default=False) 17 | 18 | parser.add_argument('--user_count', default=192403, help='number of users', type=int) 19 | parser.add_argument('--item_count', default=63001, help='number of items', type=int) 20 | parser.add_argument('--cate_count', default=801, help='number of categories', type=int) 21 | 22 | parser.add_argument('--user_dim', default=128, help='dimension of user', type=int) 23 | parser.add_argument('--item_dim', default=64, help='dimension of item', type=int) 24 | parser.add_argument('--cate_dim', default=64, help='dimension of category', type=int) 25 | 26 | parser.add_argument('--dim_layers', default=[80,40,1], type=int) 27 | 28 | args = parser.parse_args() 29 | 30 | return args 31 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | import pickle 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from config import argparser 8 | 9 | args = argparser() 10 | 11 | with open(args.dataset_dir+'dataset.pkl', 'rb') as f: 12 | train_set = pickle.load(f, encoding='latin1') 13 | test_set = pickle.load(f, encoding='latin1') 14 | cate_list = pickle.load(f, encoding='latin1') 15 | cate_list = tf.convert_to_tensor(cate_list, dtype=tf.int64) 16 | user_count, item_count, cate_count = pickle.load(f) 17 | 18 | class DataLoader: 19 | def __init__(self, batch_size, data): 20 | self.batch_size = batch_size 21 | self.data = data 22 | self.epoch_size = len(self.data) // self.batch_size 23 | if self.epoch_size * self.batch_size < len(self.data): 24 | 
self.epoch_size += 1 25 | self.i = 0 26 | 27 | def __iter__(self): 28 | self.i = 0 29 | return self 30 | 31 | def __next__(self): 32 | if self.i == self.epoch_size: 33 | raise StopIteration 34 | ts = self.data[self.i * self.batch_size : min((self.i+1) * self.batch_size, 35 | len(self.data))] 36 | self.i += 1 37 | 38 | u, i, y, sl = [], [], [], [] 39 | for t in ts: 40 | u.append(t[0]) 41 | i.append(t[2]) 42 | y.append(t[3]) 43 | sl.append(len(t[1])) 44 | max_sl = max(sl) 45 | 46 | hist_i = np.zeros([len(ts), max_sl], np.int64) 47 | 48 | k = 0 49 | for t in ts: 50 | for l in range(len(t[1])): 51 | hist_i[k][l] = t[1][l] 52 | k += 1 53 | 54 | return tf.convert_to_tensor(u), tf.convert_to_tensor(i), \ 55 | tf.convert_to_tensor(y), tf.convert_to_tensor(hist_i), \ 56 | sl 57 | 58 | class DataLoaderTest: 59 | def __init__(self, batch_size, data): 60 | 61 | self.batch_size = batch_size 62 | self.data = data 63 | self.epoch_size = len(self.data) // self.batch_size 64 | if self.epoch_size * self.batch_size < len(self.data): 65 | self.epoch_size += 1 66 | self.i = 0 67 | 68 | def __iter__(self): 69 | self.i = 0 70 | return self 71 | 72 | def __next__(self): 73 | 74 | if self.i == self.epoch_size: 75 | raise StopIteration 76 | 77 | ts = self.data[self.i * self.batch_size : min((self.i+1) * self.batch_size, 78 | len(self.data))] 79 | self.i += 1 80 | 81 | u, i, j, sl = [], [], [], [] 82 | for t in ts: 83 | u.append(t[0]) 84 | i.append(t[2][0]) 85 | j.append(t[2][1]) 86 | sl.append(len(t[1])) 87 | max_sl = max(sl) 88 | 89 | hist_i = np.zeros([len(ts), max_sl], np.int64) 90 | 91 | k = 0 92 | for t in ts: 93 | for l in range(len(t[1])): 94 | hist_i[k][l] = t[1][l] 95 | k += 1 96 | 97 | return tf.convert_to_tensor(u), tf.convert_to_tensor(i), \ 98 | tf.convert_to_tensor(j), tf.convert_to_tensor(hist_i), \ 99 | sl 100 | 101 | def __len__(self): 102 | return len(self.data) 103 | 104 | def get_dataloader(train_batch_size, test_batch_size): 105 | return DataLoader(train_batch_size, train_set), DataLoaderTest(test_batch_size, test_set), \ 106 | user_count, item_count, cate_count, cate_list 107 | -------------------------------------------------------------------------------- /layer.py: -------------------------------------------------------------------------------- 1 | # https://github.com/zhougr1993/DeepInterestNetwork/blob/master/din/Dice.py 2 | import tensorflow as tf 3 | import tensorflow.keras.layers as nn 4 | 5 | class attention(tf.keras.layers.Layer): 6 | def __init__(self, keys_dim, dim_layers): 7 | super(attention, self).__init__() 8 | self.keys_dim = keys_dim 9 | 10 | self.fc = tf.keras.Sequential() 11 | for dim_layer in dim_layers[:-1]: 12 | self.fc.add(nn.Dense(dim_layer, activation='sigmoid')) 13 | self.fc.add(nn.Dense(dim_layers[-1], activation=None)) 14 | 15 | def call(self, queries, keys, keys_length): 16 | queries = tf.tile(tf.expand_dims(queries, 1), [1, tf.shape(keys)[1], 1]) 17 | # outer product ? 
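# (despite the note above, this is not an outer product: the attention MLP input
#  concatenates the query, the keys, and their element-wise difference and product)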
18 | din_all = tf.concat([queries, keys, queries-keys, queries*keys], axis=-1) 19 | outputs = tf.transpose(self.fc(din_all), [0,2,1]) 20 | 21 | # Mask 22 | key_masks = tf.sequence_mask(keys_length, max(keys_length), dtype=tf.bool) # [B, T] 23 | key_masks = tf.expand_dims(key_masks, 1) 24 | paddings = tf.ones_like(outputs) * (-2 ** 32 + 1) 25 | outputs = tf.where(key_masks, outputs, paddings) # [B, 1, T] 26 | 27 | # Scale 28 | outputs = outputs / (self.keys_dim ** 0.5) 29 | 30 | # Activation 31 | outputs = tf.keras.activations.softmax(outputs, -1) # [B, 1, T] 32 | 33 | # Weighted sum (axis given explicitly so a batch of one keeps its batch dimension) 34 | outputs = tf.squeeze(tf.matmul(outputs, keys), axis=1) # [B, H] 35 | 36 | return outputs 37 | 38 | class dice(tf.keras.layers.Layer): 39 | def __init__(self, feat_dim): 40 | super(dice, self).__init__() 41 | self.feat_dim = feat_dim 42 | self.alphas = tf.Variable(tf.zeros([feat_dim]), dtype=tf.float32) 43 | self.beta = tf.Variable(tf.zeros([feat_dim]), dtype=tf.float32) 44 | 45 | self.bn = tf.keras.layers.BatchNormalization(center=False, scale=False) 46 | 47 | def call(self, _x, axis=-1, epsilon=1e-9): 48 | 49 | reduction_axes = list(range(len(_x.get_shape()))) 50 | del reduction_axes[axis] 51 | broadcast_shape = [1] * len(_x.get_shape()) 52 | broadcast_shape[axis] = self.feat_dim 53 | 54 | mean = tf.reduce_mean(_x, axis=reduction_axes) 55 | broadcast_mean = tf.reshape(mean, broadcast_shape) 56 | std = tf.reduce_mean(tf.square(_x - broadcast_mean) + epsilon, axis=reduction_axes) 57 | std = tf.sqrt(std) 58 | broadcast_std = tf.reshape(std, broadcast_shape) # (unused: self.bn below does the normalization) 59 | 60 | x_normed = self.bn(_x) 61 | x_p = tf.keras.activations.sigmoid(self.beta * x_normed) 62 | 63 | return self.alphas * (1.0 - x_p) * _x + x_p * _x 64 | 65 | def parametric_relu(_x, alphas=None): 66 | # PReLU helper (unused by the models here); the original TF1 variable_scope/get_variable 67 | # calls do not run under TF 2.x eager execution, so the slope is a parameter instead 68 | # (None keeps the original zero initialization, i.e. a plain ReLU). 69 | if alphas is None: alphas = tf.zeros(_x.get_shape()[-1], dtype=tf.float32) 70 | pos = tf.nn.relu(_x) 71 | neg = alphas * (_x - abs(_x)) * 0.5 72 | 73 | return pos + neg 74 | 75 | class Bilinear(tf.keras.layers.Layer): 76 | def __init__(self, units): 77 | super(Bilinear, self).__init__() 78 | self.linear_act = nn.Dense(units, activation=None, use_bias=True) 79 | self.linear_noact = nn.Dense(units, activation=None, use_bias=False) 80 | 81 | def call(self, a, b, gate_b=None): 82 | if gate_b is None: 83 | return tf.keras.activations.sigmoid(self.linear_act(a) + self.linear_noact(b)) 84 | else: 85 | return tf.keras.activations.tanh(self.linear_act(a) + tf.math.multiply(gate_b, self.linear_noact(b))) 86 | 87 | class AUGRU(tf.keras.layers.Layer): 88 | def __init__(self, units): 89 | super(AUGRU, self).__init__() 90 | 91 | self.u_gate = Bilinear(units) 92 | self.r_gate = Bilinear(units) 93 | self.c_memo = Bilinear(units) 94 | 95 | def call(self, inputs, state, att_score): 96 | u = self.u_gate(inputs, state) 97 | r = self.r_gate(inputs, state) 98 | c = self.c_memo(inputs, state, r) 99 | 100 | u_ = att_score * u # attention score scales the update gate (AUGRU) 101 | final = (1 - u_) * state + u_ * c 102 | 103 | return final 104 |
-------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------
1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | '''0 = all messages are logged (default behavior) 4 | 1 = INFO messages are not printed 5 | 2 = INFO and WARNING messages are not printed 6 | 3 = INFO, WARNING, and ERROR messages are not printed''' 7 | import time 8 | 9 | import tensorflow as tf 10 | 11 |
from config import argparser 12 | from data import get_dataloader 13 | from model import Base, DIN, DIEN 14 | from utils import eval 15 | 16 | # Config 17 | print(tf.__version__) 18 | print("GPU Available: ", tf.test.is_gpu_available()) 19 | 20 | args = argparser() 21 | 22 | # Data Load 23 | train_data, test_data, \ 24 | user_count, item_count, cate_count, \ 25 | cate_list = get_dataloader(args.train_batch_size, args.test_batch_size) 26 | 27 | # Loss, Optim 28 | optimizer = tf.keras.optimizers.SGD(learning_rate=args.lr, momentum=0.0) 29 | loss_metric = tf.keras.metrics.Sum() 30 | auc_metric = tf.keras.metrics.AUC() 31 | 32 | # Model 33 | model = Base(user_count, item_count, cate_count, cate_list, 34 | args.user_dim, args.item_dim, args.cate_dim, args.dim_layers) 35 | 36 | # Board 37 | train_summary_writer = tf.summary.create_file_writer(args.log_path) 38 | 39 | #@tf.function 40 | def train_one_step(u,i,y,hist_i,sl): 41 | with tf.GradientTape() as tape: 42 | output,_ = model(u,i,hist_i,sl) 43 | loss = tf.reduce_mean( 44 | tf.nn.sigmoid_cross_entropy_with_logits(logits=output, 45 | labels=tf.cast(y, dtype=tf.float32))) 46 | gradient = tape.gradient(loss, model.trainable_variables) 47 | clip_gradient, _ = tf.clip_by_global_norm(gradient, 5.0) 48 | optimizer.apply_gradients(zip(clip_gradient, model.trainable_variables)) 49 | 50 | loss_metric(loss) 51 | 52 | # Train 53 | def train(optimizer): 54 | best_loss= 0. 55 | best_auc = 0. 56 | start_time = time.time() 57 | for epoch in range(args.epochs): 58 | for step, (u, i, y, hist_i, sl) in enumerate(train_data, start=1): 59 | train_one_step(u, i, y, hist_i, sl) 60 | 61 | if step % args.print_step == 0: 62 | test_gauc, auc = eval(model, test_data) 63 | print('Epoch %d Global_step %d\tTrain_loss: %.4f\tEval_GAUC: %.4f\tEval_AUC: %.4f' % 64 | (epoch, step, loss_metric.result() / args.print_step, test_gauc, auc)) 65 | 66 | if best_auc < test_gauc: 67 | best_loss= loss_metric.result() / args.print_step 68 | best_auc = test_gauc 69 | model.save_weights(args.model_path+'cp-%d.ckpt'%epoch) 70 | loss_metric.reset_states() 71 | 72 | with train_summary_writer.as_default(): 73 | tf.summary.scalar('loss', best_loss, step=epoch) 74 | tf.summary.scalar('test_gauc', best_auc, step=epoch) 75 | 76 | loss_metric.reset_states() 77 | optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0) 78 | 79 | print('Epoch %d DONE\tCost time: %.2f' % (epoch, time.time()-start_time)) 80 | print('Best test_gauc: ', best_auc) 81 | 82 | 83 | # Main 84 | if __name__ == '__main__': 85 | train(optimizer) 86 | -------------------------------------------------------------------------------- /ml_prepro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import random\n", 11 | "import numpy as np\n", 12 | "import pickle\n", 13 | "import csv\n", 14 | "import tensorflow as tf\n", 15 | "\n", 16 | "dataset_dir = '/data/private/Ad/amazon/np_prepro/'\n", 17 | "with open(dataset_dir+'dataset.pkl', 'rb') as f:\n", 18 | " train_set = pickle.load(f, encoding='latin1') # uid, [hist], vid, label\n", 19 | " test_set = pickle.load(f, encoding='latin1') # uid, [hist], pid, nid\n", 20 | " cate_list = pickle.load(f, encoding='latin1') # id2cate list\n", 21 | " cate_list = tf.convert_to_tensor(cate_list, dtype=tf.int64)\n", 22 | " user_count, item_count, cate_count = pickle.load(f) # 
(user_count, item_count, cate_count)\n" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 6, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "root = '/data/private/Ad/ml-20m/'\n", 32 | "reviews_df = pd.read_csv(root+'ratings.csv')" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 7, 38 | "metadata": { 39 | "collapsed": true, 40 | "jupyter": { 41 | "outputs_hidden": true 42 | } 43 | }, 44 | "outputs": [ 45 | { 46 | "data": { 47 | "text/html": [ 48 | "
\n", 49 | "\n", 62 | "\n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 
337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | "
userIdmovieIdratingtimestamp
0123.51112486027
11293.51112484676
21323.51112484819
31473.51112484727
41503.51112484580
511123.51094785740
611514.01094785734
712234.01112485573
812534.01112484940
912604.01112484826
1012934.01112484703
1112964.01112484767
1213184.01112484798
1313373.51094785709
1413673.51112485980
1515414.01112484603
1615893.51112485557
1715933.51112484661
1816533.01094785691
1919193.51094785621
2019243.51094785598
21110093.51112486013
22110364.01112485480
23110794.01094785665
24110803.51112485375
25110893.51112484669
26110904.01112485453
27110974.01112485701
28111363.51112484609
29111933.51112484690
...............
20000233138493508723.51256750388
20000234138493510863.51255810566
20000235138493516624.51255856908
20000236138493518844.51256294768
20000237138493525794.01255856957
20000238138493529754.01256680293
20000239138493531234.01255816320
20000240138493531253.01255810649
20000241138493533224.01255812146
20000242138493534644.01260209920
20000243138493539964.51259865104
20000244138493552695.01255816088
20000245138493558145.01255811181
20000246138493567573.01255810698
20000247138493568013.01255809988
20000248138493588794.51255816798
20000249138493593154.01255818138
20000250138493597253.01255818078
20000251138493597845.01255816901
20000252138493600694.01258134687
20000253138493608164.51259865163
20000254138493611604.01258390537
20000255138493656824.51255816373
20000256138493667624.51255805408
20000257138493683194.51260209720
20000258138493689544.51258126920
20000259138493695264.51259865108
20000260138493696443.01260209457
20000261138493702865.01258126944
20000262138493716192.51255811136
\n", 502 | "

20000263 rows × 4 columns

\n", 503 | "
" 504 | ], 505 | "text/plain": [ 506 | " userId movieId rating timestamp\n", 507 | "0 1 2 3.5 1112486027\n", 508 | "1 1 29 3.5 1112484676\n", 509 | "2 1 32 3.5 1112484819\n", 510 | "3 1 47 3.5 1112484727\n", 511 | "4 1 50 3.5 1112484580\n", 512 | "5 1 112 3.5 1094785740\n", 513 | "6 1 151 4.0 1094785734\n", 514 | "7 1 223 4.0 1112485573\n", 515 | "8 1 253 4.0 1112484940\n", 516 | "9 1 260 4.0 1112484826\n", 517 | "10 1 293 4.0 1112484703\n", 518 | "11 1 296 4.0 1112484767\n", 519 | "12 1 318 4.0 1112484798\n", 520 | "13 1 337 3.5 1094785709\n", 521 | "14 1 367 3.5 1112485980\n", 522 | "15 1 541 4.0 1112484603\n", 523 | "16 1 589 3.5 1112485557\n", 524 | "17 1 593 3.5 1112484661\n", 525 | "18 1 653 3.0 1094785691\n", 526 | "19 1 919 3.5 1094785621\n", 527 | "20 1 924 3.5 1094785598\n", 528 | "21 1 1009 3.5 1112486013\n", 529 | "22 1 1036 4.0 1112485480\n", 530 | "23 1 1079 4.0 1094785665\n", 531 | "24 1 1080 3.5 1112485375\n", 532 | "25 1 1089 3.5 1112484669\n", 533 | "26 1 1090 4.0 1112485453\n", 534 | "27 1 1097 4.0 1112485701\n", 535 | "28 1 1136 3.5 1112484609\n", 536 | "29 1 1193 3.5 1112484690\n", 537 | "... ... ... ... ...\n", 538 | "20000233 138493 50872 3.5 1256750388\n", 539 | "20000234 138493 51086 3.5 1255810566\n", 540 | "20000235 138493 51662 4.5 1255856908\n", 541 | "20000236 138493 51884 4.5 1256294768\n", 542 | "20000237 138493 52579 4.0 1255856957\n", 543 | "20000238 138493 52975 4.0 1256680293\n", 544 | "20000239 138493 53123 4.0 1255816320\n", 545 | "20000240 138493 53125 3.0 1255810649\n", 546 | "20000241 138493 53322 4.0 1255812146\n", 547 | "20000242 138493 53464 4.0 1260209920\n", 548 | "20000243 138493 53996 4.5 1259865104\n", 549 | "20000244 138493 55269 5.0 1255816088\n", 550 | "20000245 138493 55814 5.0 1255811181\n", 551 | "20000246 138493 56757 3.0 1255810698\n", 552 | "20000247 138493 56801 3.0 1255809988\n", 553 | "20000248 138493 58879 4.5 1255816798\n", 554 | "20000249 138493 59315 4.0 1255818138\n", 555 | "20000250 138493 59725 3.0 1255818078\n", 556 | "20000251 138493 59784 5.0 1255816901\n", 557 | "20000252 138493 60069 4.0 1258134687\n", 558 | "20000253 138493 60816 4.5 1259865163\n", 559 | "20000254 138493 61160 4.0 1258390537\n", 560 | "20000255 138493 65682 4.5 1255816373\n", 561 | "20000256 138493 66762 4.5 1255805408\n", 562 | "20000257 138493 68319 4.5 1260209720\n", 563 | "20000258 138493 68954 4.5 1258126920\n", 564 | "20000259 138493 69526 4.5 1259865108\n", 565 | "20000260 138493 69644 3.0 1260209457\n", 566 | "20000261 138493 70286 5.0 1258126944\n", 567 | "20000262 138493 71619 2.5 1255811136\n", 568 | "\n", 569 | "[20000263 rows x 4 columns]" 570 | ] 571 | }, 572 | "execution_count": 7, 573 | "metadata": {}, 574 | "output_type": "execute_result" 575 | } 576 | ], 577 | "source": [ 578 | "with open(root+'np_prepro/reviews.pkl', 'wb') as f:\n", 579 | " pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)" 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": 9, 585 | "metadata": {}, 586 | "outputs": [], 587 | "source": [ 588 | "meta_df = pd.read_csv(root+'movies.csv')\n", 589 | "\n", 590 | "meta_df[meta_df['movieId'].isin(reviews_df['movieId'].unique())]\n", 591 | "meta_df = meta_df.reset_index(drop=True)\n", 592 | "with open(root+'np_prepro/meta.pkl', 'wb') as f:\n", 593 | " pickle.dump(meta_df, f, pickle.HIGHEST_PROTOCOL)" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": 10, 599 | "metadata": { 600 | "collapsed": true, 601 | "jupyter": { 602 | "outputs_hidden": true 603 | } 604 | }, 605 | 
"outputs": [ 606 | { 607 | "data": { 608 | "text/html": [ 609 | "
\n", 610 | "\n", 623 | "\n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " \n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 
895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | " \n", 971 | " \n", 972 | " \n", 973 | " \n", 974 | " \n", 975 | " \n", 976 | " \n", 977 | " \n", 978 | " \n", 979 | " \n", 980 | " \n", 981 | " \n", 982 | " \n", 983 | " \n", 984 | " \n", 985 | " \n", 986 | " \n", 987 | " \n", 988 | " \n", 989 | " \n", 990 | " \n", 991 | " \n", 992 | " \n", 993 | " \n", 994 | " \n", 995 | " \n", 996 | " \n", 997 | " \n", 998 | " \n", 999 | " \n", 1000 | "
movieIdtitlegenres
01Toy Story (1995)Adventure|Animation|Children|Comedy|Fantasy
12Jumanji (1995)Adventure|Children|Fantasy
23Grumpier Old Men (1995)Comedy|Romance
34Waiting to Exhale (1995)Comedy|Drama|Romance
45Father of the Bride Part II (1995)Comedy
56Heat (1995)Action|Crime|Thriller
67Sabrina (1995)Comedy|Romance
78Tom and Huck (1995)Adventure|Children
89Sudden Death (1995)Action
910GoldenEye (1995)Action|Adventure|Thriller
1011American President, The (1995)Comedy|Drama|Romance
1112Dracula: Dead and Loving It (1995)Comedy|Horror
1213Balto (1995)Adventure|Animation|Children
1314Nixon (1995)Drama
1415Cutthroat Island (1995)Action|Adventure|Romance
1516Casino (1995)Crime|Drama
1617Sense and Sensibility (1995)Drama|Romance
1718Four Rooms (1995)Comedy
1819Ace Ventura: When Nature Calls (1995)Comedy
1920Money Train (1995)Action|Comedy|Crime|Drama|Thriller
2021Get Shorty (1995)Comedy|Crime|Thriller
2122Copycat (1995)Crime|Drama|Horror|Mystery|Thriller
2223Assassins (1995)Action|Crime|Thriller
2324Powder (1995)Drama|Sci-Fi
2425Leaving Las Vegas (1995)Drama|Romance
2526Othello (1995)Drama
2627Now and Then (1995)Children|Drama
2728Persuasion (1995)Drama|Romance
2829City of Lost Children, The (Cité des enfants p...Adventure|Drama|Fantasy|Mystery|Sci-Fi
2930Shanghai Triad (Yao a yao yao dao waipo qiao) ...Crime|Drama
............
27248131146Werner - Volles Rooäää (1999)Animation|Comedy
27249131148What A Man (2011)Comedy|Romance
272501311507 Dwarves: The Forest Is Not Enough (2006)Comedy
27251131152The Fat Spy (1966)Comedy
27252131154Die Bademeister – Weiber, saufen, Leben retten...Comedy
27253131156Ants in the Pants 2 (2002)Comedy
27254131158Manta, Manta (1991)Comedy
27255131160Oscar and the Lady in Pink (2009)Drama
27256131162Por un puñado de besos (2014)Drama|Romance
27257131164Vietnam in HD (2011)War
27258131166WWII IN HD (2009)(no genres listed)
27259131168Phoenix (2014)Drama
27260131170Parallels (2015)Sci-Fi
27261131172Closed Curtain (2013)(no genres listed)
27262131174Gentlemen (2014)Drama|Romance|Thriller
27263131176A Second Chance (2014)Drama
27264131180Dead Rising: Watchtower (2015)Action|Horror|Thriller
27265131231Standby (2014)Comedy|Romance
27266131237What Men Talk About (2010)Comedy
27267131239Three Quarter Moon (2011)Comedy|Drama
27268131241Ants in the Pants (2000)Comedy|Romance
27269131243Werner - Gekotzt wird später (2003)Animation|Comedy
27270131248Brother Bear 2 (2006)Adventure|Animation|Children|Comedy|Fantasy
27271131250No More School (2000)Comedy
27272131252Forklift Driver Klaus: The First Day on the Jo...Comedy|Horror
27273131254Kein Bund für's Leben (2007)Comedy
27274131256Feuer, Eis & Dosenbier (2002)Comedy
27275131258The Pirates (2014)Adventure
27276131260Rentun Ruusu (2001)(no genres listed)
27277131262Innocence (2014)Adventure|Fantasy|Horror
\n", 1001 | "

27278 rows × 3 columns

\n", 1002 | "
" 1003 | ], 1004 | "text/plain": [ 1005 | " movieId title \\\n", 1006 | "0 1 Toy Story (1995) \n", 1007 | "1 2 Jumanji (1995) \n", 1008 | "2 3 Grumpier Old Men (1995) \n", 1009 | "3 4 Waiting to Exhale (1995) \n", 1010 | "4 5 Father of the Bride Part II (1995) \n", 1011 | "5 6 Heat (1995) \n", 1012 | "6 7 Sabrina (1995) \n", 1013 | "7 8 Tom and Huck (1995) \n", 1014 | "8 9 Sudden Death (1995) \n", 1015 | "9 10 GoldenEye (1995) \n", 1016 | "10 11 American President, The (1995) \n", 1017 | "11 12 Dracula: Dead and Loving It (1995) \n", 1018 | "12 13 Balto (1995) \n", 1019 | "13 14 Nixon (1995) \n", 1020 | "14 15 Cutthroat Island (1995) \n", 1021 | "15 16 Casino (1995) \n", 1022 | "16 17 Sense and Sensibility (1995) \n", 1023 | "17 18 Four Rooms (1995) \n", 1024 | "18 19 Ace Ventura: When Nature Calls (1995) \n", 1025 | "19 20 Money Train (1995) \n", 1026 | "20 21 Get Shorty (1995) \n", 1027 | "21 22 Copycat (1995) \n", 1028 | "22 23 Assassins (1995) \n", 1029 | "23 24 Powder (1995) \n", 1030 | "24 25 Leaving Las Vegas (1995) \n", 1031 | "25 26 Othello (1995) \n", 1032 | "26 27 Now and Then (1995) \n", 1033 | "27 28 Persuasion (1995) \n", 1034 | "28 29 City of Lost Children, The (Cité des enfants p... \n", 1035 | "29 30 Shanghai Triad (Yao a yao yao dao waipo qiao) ... \n", 1036 | "... ... ... \n", 1037 | "27248 131146 Werner - Volles Rooäää (1999) \n", 1038 | "27249 131148 What A Man (2011) \n", 1039 | "27250 131150 7 Dwarves: The Forest Is Not Enough (2006) \n", 1040 | "27251 131152 The Fat Spy (1966) \n", 1041 | "27252 131154 Die Bademeister – Weiber, saufen, Leben retten... \n", 1042 | "27253 131156 Ants in the Pants 2 (2002) \n", 1043 | "27254 131158 Manta, Manta (1991) \n", 1044 | "27255 131160 Oscar and the Lady in Pink (2009) \n", 1045 | "27256 131162 Por un puñado de besos (2014) \n", 1046 | "27257 131164 Vietnam in HD (2011) \n", 1047 | "27258 131166 WWII IN HD (2009) \n", 1048 | "27259 131168 Phoenix (2014) \n", 1049 | "27260 131170 Parallels (2015) \n", 1050 | "27261 131172 Closed Curtain (2013) \n", 1051 | "27262 131174 Gentlemen (2014) \n", 1052 | "27263 131176 A Second Chance (2014) \n", 1053 | "27264 131180 Dead Rising: Watchtower (2015) \n", 1054 | "27265 131231 Standby (2014) \n", 1055 | "27266 131237 What Men Talk About (2010) \n", 1056 | "27267 131239 Three Quarter Moon (2011) \n", 1057 | "27268 131241 Ants in the Pants (2000) \n", 1058 | "27269 131243 Werner - Gekotzt wird später (2003) \n", 1059 | "27270 131248 Brother Bear 2 (2006) \n", 1060 | "27271 131250 No More School (2000) \n", 1061 | "27272 131252 Forklift Driver Klaus: The First Day on the Jo... 
\n", 1062 | "27273 131254 Kein Bund für's Leben (2007) \n", 1063 | "27274 131256 Feuer, Eis & Dosenbier (2002) \n", 1064 | "27275 131258 The Pirates (2014) \n", 1065 | "27276 131260 Rentun Ruusu (2001) \n", 1066 | "27277 131262 Innocence (2014) \n", 1067 | "\n", 1068 | " genres \n", 1069 | "0 Adventure|Animation|Children|Comedy|Fantasy \n", 1070 | "1 Adventure|Children|Fantasy \n", 1071 | "2 Comedy|Romance \n", 1072 | "3 Comedy|Drama|Romance \n", 1073 | "4 Comedy \n", 1074 | "5 Action|Crime|Thriller \n", 1075 | "6 Comedy|Romance \n", 1076 | "7 Adventure|Children \n", 1077 | "8 Action \n", 1078 | "9 Action|Adventure|Thriller \n", 1079 | "10 Comedy|Drama|Romance \n", 1080 | "11 Comedy|Horror \n", 1081 | "12 Adventure|Animation|Children \n", 1082 | "13 Drama \n", 1083 | "14 Action|Adventure|Romance \n", 1084 | "15 Crime|Drama \n", 1085 | "16 Drama|Romance \n", 1086 | "17 Comedy \n", 1087 | "18 Comedy \n", 1088 | "19 Action|Comedy|Crime|Drama|Thriller \n", 1089 | "20 Comedy|Crime|Thriller \n", 1090 | "21 Crime|Drama|Horror|Mystery|Thriller \n", 1091 | "22 Action|Crime|Thriller \n", 1092 | "23 Drama|Sci-Fi \n", 1093 | "24 Drama|Romance \n", 1094 | "25 Drama \n", 1095 | "26 Children|Drama \n", 1096 | "27 Drama|Romance \n", 1097 | "28 Adventure|Drama|Fantasy|Mystery|Sci-Fi \n", 1098 | "29 Crime|Drama \n", 1099 | "... ... \n", 1100 | "27248 Animation|Comedy \n", 1101 | "27249 Comedy|Romance \n", 1102 | "27250 Comedy \n", 1103 | "27251 Comedy \n", 1104 | "27252 Comedy \n", 1105 | "27253 Comedy \n", 1106 | "27254 Comedy \n", 1107 | "27255 Drama \n", 1108 | "27256 Drama|Romance \n", 1109 | "27257 War \n", 1110 | "27258 (no genres listed) \n", 1111 | "27259 Drama \n", 1112 | "27260 Sci-Fi \n", 1113 | "27261 (no genres listed) \n", 1114 | "27262 Drama|Romance|Thriller \n", 1115 | "27263 Drama \n", 1116 | "27264 Action|Horror|Thriller \n", 1117 | "27265 Comedy|Romance \n", 1118 | "27266 Comedy \n", 1119 | "27267 Comedy|Drama \n", 1120 | "27268 Comedy|Romance \n", 1121 | "27269 Animation|Comedy \n", 1122 | "27270 Adventure|Animation|Children|Comedy|Fantasy \n", 1123 | "27271 Comedy \n", 1124 | "27272 Comedy|Horror \n", 1125 | "27273 Comedy \n", 1126 | "27274 Comedy \n", 1127 | "27275 Adventure \n", 1128 | "27276 (no genres listed) \n", 1129 | "27277 Adventure|Fantasy|Horror \n", 1130 | "\n", 1131 | "[27278 rows x 3 columns]" 1132 | ] 1133 | }, 1134 | "execution_count": 10, 1135 | "metadata": {}, 1136 | "output_type": "execute_result" 1137 | } 1138 | ], 1139 | "source": [ 1140 | "meta_df" 1141 | ] 1142 | }, 1143 | { 1144 | "cell_type": "code", 1145 | "execution_count": 36, 1146 | "metadata": {}, 1147 | "outputs": [], 1148 | "source": [ 1149 | "# with open(root+'np_prepro/reviews.pkl', 'rb') as f:\n", 1150 | "# reviews_df = pickle.load(f)\n", 1151 | "# with open(root+'np_prepro/meta.pkl', 'rb') as f:\n", 1152 | "# meta_df = pickle.load(f)" 1153 | ] 1154 | }, 1155 | { 1156 | "cell_type": "code", 1157 | "execution_count": 37, 1158 | "metadata": {}, 1159 | "outputs": [], 1160 | "source": [ 1161 | "reviews_df = reviews_df[['userId','movieId','rating','timestamp']]\n", 1162 | "reviews_df.loc[:,'rating'] = reviews_df['rating'].map(lambda x: 1 if x > 3 else 0)\n", 1163 | "meta_df = meta_df[['movieId', 'genres']]\n", 1164 | "meta_df.loc[:,'genres'] = meta_df['genres'].map(lambda x: x.split('|')[0])" 1165 | ] 1166 | }, 1167 | { 1168 | "cell_type": "code", 1169 | "execution_count": 23, 1170 | "metadata": {}, 1171 | "outputs": [], 1172 | "source": [ 1173 | "def build_map(df, col_name):\n", 1174 | " key = 
sorted(df[col_name].unique().tolist())\n", 1175 | " m = dict(zip(key, range(len(key))))\n", 1176 | " df.loc[:,col_name] = df[col_name].map(lambda x: m[x])\n", 1177 | " return m, key" 1178 | ] 1179 | }, 1180 | { 1181 | "cell_type": "code", 1182 | "execution_count": 38, 1183 | "metadata": {}, 1184 | "outputs": [], 1185 | "source": [ 1186 | "vid_map, vid_key = build_map(meta_df, 'movieId')\n", 1187 | "cat_map, cat_key = build_map(meta_df, 'genres')\n", 1188 | "uid_map, uid_key = build_map(reviews_df, 'userId')" 1189 | ] 1190 | }, 1191 | { 1192 | "cell_type": "code", 1193 | "execution_count": 41, 1194 | "metadata": {}, 1195 | "outputs": [ 1196 | { 1197 | "name": "stdout", 1198 | "output_type": "stream", 1199 | "text": [ 1200 | "user_count: 138493\titem_count: 27278\tcate_count: 20\texample_count: 20000263\n" 1201 | ] 1202 | } 1203 | ], 1204 | "source": [ 1205 | "user_count, item_count, cate_count, example_count =\\\n", 1206 | " len(uid_map), len(vid_map), len(cat_map), reviews_df.shape[0]\n", 1207 | "print('user_count: %d\\titem_count: %d\\tcate_count: %d\\texample_count: %d' %\n", 1208 | " (user_count, item_count, cate_count, example_count))" 1209 | ] 1210 | }, 1211 | { 1212 | "cell_type": "code", 1213 | "execution_count": 44, 1214 | "metadata": { 1215 | "collapsed": true, 1216 | "jupyter": { 1217 | "outputs_hidden": true 1218 | } 1219 | }, 1220 | "outputs": [ 1221 | { 1222 | "ename": "KeyError", 1223 | "evalue": "'userID'", 1224 | "output_type": "error", 1225 | "traceback": [ 1226 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 1227 | "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", 1228 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mreviews_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'movieId'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreviews_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'movieId'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mvid_map\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mreviews_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreviews_df\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msort_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'userID'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'timestamp'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mreviews_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreviews_df\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdrop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 1229 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36msort_values\u001b[0;34m(self, by, axis, ascending, inplace, kind, na_position)\u001b[0m\n\u001b[1;32m 4709\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4710\u001b[0m keys = [self._get_label_or_level_values(x, axis=axis)\n\u001b[0;32m-> 4711\u001b[0;31m for x in by]\n\u001b[0m\u001b[1;32m 4712\u001b[0m indexer = lexsort_indexer(keys, orders=ascending,\n\u001b[1;32m 
4713\u001b[0m na_position=na_position)\n", 1230 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 4709\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4710\u001b[0m keys = [self._get_label_or_level_values(x, axis=axis)\n\u001b[0;32m-> 4711\u001b[0;31m for x in by]\n\u001b[0m\u001b[1;32m 4712\u001b[0m indexer = lexsort_indexer(keys, orders=ascending,\n\u001b[1;32m 4713\u001b[0m na_position=na_position)\n", 1231 | "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m_get_label_or_level_values\u001b[0;34m(self, key, axis)\u001b[0m\n\u001b[1;32m 1704\u001b[0m \u001b[0mvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_level_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1705\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1706\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1707\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1708\u001b[0m \u001b[0;31m# Check for duplicates\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 1232 | "\u001b[0;31mKeyError\u001b[0m: 'userID'" 1233 | ] 1234 | } 1235 | ], 1236 | "source": [ 1237 | "meta_df = meta_df.sort_values('movieId')\n", 1238 | "meta_df = meta_df.reset_index(drop=True)\n", 1239 | "\n", 1240 | "reviews_df['movieId'] = reviews_df['movieId'].map(lambda x: vid_map[x])\n", 1241 | "reviews_df = reviews_df.sort_values(['userId', 'timestamp'])\n", 1242 | "reviews_df = reviews_df.reset_index(drop=True)" 1243 | ] 1244 | }, 1245 | { 1246 | "cell_type": "code", 1247 | "execution_count": 46, 1248 | "metadata": {}, 1249 | "outputs": [], 1250 | "source": [ 1251 | "cate_list = [meta_df['genres'][i] for i in range(len(vid_map))]\n", 1252 | "cate_list = np.array(cate_list, dtype=np.int32)" 1253 | ] 1254 | }, 1255 | { 1256 | "cell_type": "code", 1257 | "execution_count": 47, 1258 | "metadata": {}, 1259 | "outputs": [], 1260 | "source": [ 1261 | "with open(root+'np_prepro/remap.pkl', 'wb') as f:\n", 1262 | " pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL) # uid, iid\n", 1263 | " pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL) # cid of iid line\n", 1264 | " pickle.dump((user_count, item_count, cate_count, example_count),\n", 1265 | " f, pickle.HIGHEST_PROTOCOL)\n", 1266 | " pickle.dump((vid_key, cat_key, uid_key), f, pickle.HIGHEST_PROTOCOL)" 1267 | ] 1268 | }, 1269 | { 1270 | "cell_type": "code", 1271 | "execution_count": null, 1272 | "metadata": {}, 1273 | "outputs": [], 1274 | "source": [ 1275 | "pos_cnt, neg_cnt = 0, 0\n", 1276 | "for userId, hist in reviews_df.groupby('userId'):\n", 1277 | " movie_list = hist['movieId'].tolist()\n", 1278 | " label_list = hist['rating'].tolist()\n", 1279 | "\n", 1280 | " pos_cnt += sum(label_list)\n", 1281 | " neg_cnt += len(label_list) - sum(label_list)\n", 1282 | " \n", 1283 | "print(pos_cnt, neg_cnt, pos_cnt/(pos_cnt+neg_cnt))" 1284 | ] 1285 | }, 1286 | { 1287 | "cell_type": "code", 1288 | "execution_count": 61, 1289 
| "metadata": {}, 1290 | "outputs": [], 1291 | "source": [ 1292 | "random.seed(1234)\n", 1293 | "\n", 1294 | "train_set = []\n", 1295 | "test_set = []\n", 1296 | "train_count = 100000\n", 1297 | "train_user = np.random.choice(user_count, train_count, replace=False)\n", 1298 | "for userId, hist in reviews_df.groupby('userId'):\n", 1299 | " movie_list = hist['movieId'].tolist()\n", 1300 | " label_list = hist['rating'].tolist()\n", 1301 | " pos_list, neg_list = [], []\n", 1302 | " for i, (v,r) in enumerate(zip(movie_list, label_list)):\n", 1303 | " if r == 1: pos_list.append(v);\n", 1304 | " else: neg_list.append(v)\n", 1305 | " \n", 1306 | " if len(pos_list) > len(neg_list):\n", 1307 | " for _ in range(len(pos_list)-len(neg_list)):\n", 1308 | " neg = pos_list[0]\n", 1309 | " while neg in pos_list + neg_list :\n", 1310 | " neg = random.randint(0, item_count-1)\n", 1311 | " neg_list.append(neg)\n", 1312 | "\n", 1313 | " if userId in train_user:\n", 1314 | " for i in range(1, len(pos_list)):\n", 1315 | " hist = pos_list[:i]\n", 1316 | " train_set.append((userId, hist, pos_list[i], 1))\n", 1317 | " train_set.append((userId, hist, neg_list[i], 0))\n", 1318 | " else:\n", 1319 | " for i in range(1, len(pos_list)):\n", 1320 | " hist = movie_list[:i]\n", 1321 | " label = (pos_list[i], neg_list[i])\n", 1322 | " test_set.append((userId, hist, label))\n", 1323 | "\n", 1324 | "random.shuffle(train_set)\n", 1325 | "random.shuffle(test_set)\n", 1326 | "\n", 1327 | "with open(root+'np_prepro/dataset.pkl', 'wb') as f:\n", 1328 | " pickle.dump(train_set, f, pickle.HIGHEST_PROTOCOL)\n", 1329 | " pickle.dump(test_set, f, pickle.HIGHEST_PROTOCOL)\n", 1330 | " pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL)\n", 1331 | " pickle.dump((user_count, item_count, cate_count), f, pickle.HIGHEST_PROTOCOL)" 1332 | ] 1333 | }, 1334 | { 1335 | "cell_type": "code", 1336 | "execution_count": 62, 1337 | "metadata": {}, 1338 | "outputs": [], 1339 | "source": [ 1340 | "with open(root+'dataset.pkl', 'wb') as f:\n", 1341 | " pickle.dump(train_set, f, protocol=2)\n", 1342 | " pickle.dump(test_set, f, protocol=2)\n", 1343 | " pickle.dump(cate_list, f, protocol=2)\n", 1344 | " pickle.dump((user_count, item_count, cate_count), f, protocol=2)" 1345 | ] 1346 | } 1347 | ], 1348 | "metadata": { 1349 | "kernelspec": { 1350 | "display_name": "Python 3", 1351 | "language": "python", 1352 | "name": "python3" 1353 | }, 1354 | "language_info": { 1355 | "codemirror_mode": { 1356 | "name": "ipython", 1357 | "version": 3 1358 | }, 1359 | "file_extension": ".py", 1360 | "mimetype": "text/x-python", 1361 | "name": "python", 1362 | "nbconvert_exporter": "python", 1363 | "pygments_lexer": "ipython3", 1364 | "version": "3.6.8" 1365 | } 1366 | }, 1367 | "nbformat": 4, 1368 | "nbformat_minor": 4 1369 | } 1370 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.keras.layers as nn 4 | 5 | from layer import attention, dice, AUGRU 6 | from utils import sequence_mask 7 | 8 | class Base(tf.keras.Model): 9 | def __init__(self, user_count, item_count, cate_count, cate_list, 10 | user_dim, item_dim, cate_dim, 11 | dim_layers): 12 | super(Base, self).__init__() 13 | self.item_dim = item_dim 14 | self.cate_dim = cate_dim 15 | 16 | self.user_emb = nn.Embedding(user_count, user_dim) 17 | self.item_emb = nn.Embedding(item_count, item_dim) 18 | self.cate_emb = 
19 |         self.item_bias = tf.Variable(tf.zeros([item_count]), trainable=True)
20 |         self.cate_list = cate_list
21 | 
22 |         self.hist_bn = nn.BatchNormalization()
23 |         self.hist_fc = nn.Dense(item_dim+cate_dim)
24 | 
25 |         self.fc = tf.keras.Sequential()
26 |         self.fc.add(nn.BatchNormalization())
27 |         for dim_layer in dim_layers[:-1]:
28 |             self.fc.add(nn.Dense(dim_layer, activation='sigmoid'))
29 |         self.fc.add(nn.Dense(dim_layers[-1], activation=None))
30 | 
31 |     def get_emb(self, user, item, history):
32 |         user_emb = self.user_emb(user)
33 | 
34 |         item_emb = self.item_emb(item)
35 |         item_cate_emb = self.cate_emb(tf.gather(self.cate_list, item))
36 |         item_join_emb = tf.concat([item_emb, item_cate_emb], -1)
37 |         item_bias = tf.gather(self.item_bias, item)
38 | 
39 |         hist_emb = self.item_emb(history)
40 |         hist_cate_emb = self.cate_emb(tf.gather(self.cate_list, history))
41 |         hist_join_emb = tf.concat([hist_emb, hist_cate_emb], -1)
42 | 
43 |         return user_emb, item_join_emb, item_bias, hist_join_emb
44 | 
45 |     def call(self, user, item, history, length):
46 |         user_emb, item_join_emb, item_bias, hist_join_emb = self.get_emb(user, item, history)
47 | 
48 |         hist_mask = tf.sequence_mask(length, tf.reduce_max(length), dtype=tf.float32)  # 1.0 at valid steps, 0.0 at padding
49 |         hist_mask = tf.tile(tf.expand_dims(hist_mask, -1), (1,1,self.item_dim+self.cate_dim))  # broadcast over the embedding dim
50 |         hist_join_emb = tf.math.multiply(hist_join_emb, hist_mask)
51 |         hist_join_emb = tf.reduce_sum(hist_join_emb, 1)
52 |         hist_join_emb = tf.math.divide(hist_join_emb, tf.cast(tf.tile(tf.expand_dims(length, -1),
53 |                                        [1,self.item_dim+self.cate_dim]), tf.float32))  # mean-pool over the true history length
54 | 
55 |         hist_hid_emb = self.hist_fc(self.hist_bn(hist_join_emb))
56 |         join_emb = tf.concat([user_emb, item_join_emb, hist_hid_emb], -1)
57 | 
58 |         output = tf.squeeze(self.fc(join_emb)) + item_bias
59 |         logit = tf.keras.activations.sigmoid(output)
60 | 
61 |         return output, logit
62 | 
63 | 
64 | class DIN(Base):
65 |     def __init__(self, user_count, item_count, cate_count, cate_list,
66 |                  user_dim, item_dim, cate_dim,
67 |                  dim_layers):
68 |         super(DIN, self).__init__(user_count, item_count, cate_count, cate_list,
69 |                                   user_dim, item_dim, cate_dim,
70 |                                   dim_layers)
71 | 
72 |         self.hist_at = attention(item_dim+cate_dim, dim_layers)
73 | 
74 |         self.fc = tf.keras.Sequential()
75 |         self.fc.add(nn.BatchNormalization())
76 |         for dim_layer in dim_layers[:-1]:
77 |             self.fc.add(nn.Dense(dim_layer, activation=None))
78 |             self.fc.add(dice(dim_layer))
79 |         self.fc.add(nn.Dense(dim_layers[-1], activation=None))
80 | 
81 |     def call(self, user, item, history, length):
82 |         user_emb, item_join_emb, item_bias, hist_join_emb = self.get_emb(user, item, history)
83 | 
84 |         hist_attn_emb = self.hist_at(item_join_emb, hist_join_emb, length)  # history weighted by relevance to the candidate item
85 |         hist_attn_emb = self.hist_fc(self.hist_bn(hist_attn_emb))
86 | 
87 |         join_emb = tf.concat([user_emb, item_join_emb, hist_attn_emb], -1)
88 | 
89 |         output = tf.squeeze(self.fc(join_emb)) + item_bias
90 |         logit = tf.keras.activations.sigmoid(output)
91 | 
92 |         return output, logit
93 | 
94 | class DIEN(Base):
95 |     def __init__(self, user_count, item_count, cate_count, cate_list,
96 |                  user_dim, item_dim, cate_dim,
97 |                  dim_layers):
98 |         super(DIEN, self).__init__(user_count, item_count, cate_count, cate_list,
99 |                                    user_dim, item_dim, cate_dim,
100 |                                    dim_layers)
101 | 
102 |         self.hist_gru = nn.GRU(item_dim+cate_dim, return_sequences=True)
103 |         self.hist_augru = AUGRU(item_dim+cate_dim)
104 | 
105 |     def call(self, user, item, history, length):
106 |         user_emb, item_join_emb, item_bias, hist_join_emb = self.get_emb(user, item, history)
107 | 
108 |         hist_gru_emb = self.hist_gru(hist_join_emb)  # interest extraction over the full history
109 |         hist_mask = tf.expand_dims(tf.sequence_mask(length, tf.reduce_max(length)), 1)  # (batch, 1, T), True at valid steps
110 |         attn_logits = tf.matmul(tf.expand_dims(item_join_emb, 1), hist_gru_emb, transpose_b=True)  # (batch, 1, T)
111 |         hist_attn = tf.nn.softmax(tf.where(hist_mask, attn_logits, tf.fill(tf.shape(attn_logits), -1e9)))  # mask padding before softmax
112 | 
113 |         hist_hid_emb = tf.zeros_like(hist_gru_emb[:,0,:])
114 |         for in_emb, in_att in zip(tf.transpose(hist_gru_emb, [1,0,2]),
115 |                                   tf.transpose(hist_attn, [2,0,1])):
116 |             hist_hid_emb = self.hist_augru(in_emb, hist_hid_emb, in_att)  # interest evolution, one step at a time
117 | 
118 |         join_emb = tf.concat([user_emb, item_join_emb, hist_hid_emb], -1)
119 | 
120 |         output = tf.squeeze(self.fc(join_emb)) + item_bias
121 |         logit = tf.keras.activations.sigmoid(output)
122 | 
123 |         return output, logit
124 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | 
4 | def calc_auc(raw_arr):
5 |     """Compute AUC from a list of per-example records.
6 |     Args:
7 |         raw_arr: list of [noclick, click, score] triples, one per example.
8 |     Returns:
9 |         AUC as a float; -0.5 if all records fall in one class, or None if degenerate.
10 |     """
11 |     # sort by predicted score, ascending
12 |     arr = sorted(raw_arr, key=lambda d: d[2])
13 | 
14 |     auc = 0.0
15 |     fp1, tp1, fp2, tp2 = 0.0, 0.0, 0.0, 0.0
16 |     for record in arr:
17 |         fp2 += record[0] # noclick
18 |         tp2 += record[1] # click
19 |         auc += (fp2 - fp1) * (tp2 + tp1)  # trapezoidal rule on the unnormalized ROC curve
20 |         fp1, tp1 = fp2, tp2
21 | 
22 |     # if all records are nonclick or all are click, discard
23 |     threshold = len(arr) - 1e-3
24 |     if tp2 > threshold or fp2 > threshold:
25 |         return -0.5
26 | 
27 |     if tp2 * fp2 > 0.0:  # normal auc
28 |         return (1.0 - auc / (2.0 * tp2 * fp2))
29 |     else:
30 |         return None
31 | 
32 | def auc_arr(score_p, score_n):
33 |     score_arr = []
34 |     for s in score_p.numpy():
35 |         score_arr.append([0, 1, s])
36 |     for s in score_n.numpy():
37 |         score_arr.append([1, 0, s])
38 |     return score_arr
39 | 
40 | def eval(model, test_data):
41 |     auc_sum = 0.0
42 |     score_arr = []
43 |     for u, i, j, hist_i, sl in test_data:
44 |         p_out, p_logit = model(u, i, hist_i, sl)
45 |         n_out, n_logit = model(u, j, hist_i, sl)
46 |         mf_auc = tf.reduce_sum(tf.cast(p_out > n_out, dtype=tf.float32))
47 | 
48 |         score_arr += auc_arr(p_logit, n_logit)
49 |         auc_sum += mf_auc
50 |     test_gauc = auc_sum / len(test_data)
51 |     auc = calc_auc(score_arr)
52 |     return test_gauc, auc
53 | 
54 | def sequence_mask(lengths, maxlen=None, dtype=tf.bool):
55 |     """Returns a mask tensor representing the first N positions of each cell.
56 | 
57 |     If `lengths` has shape `[d_1, d_2, ..., d_n]` the resulting tensor `mask` has
58 |     dtype `dtype` and shape `[d_1, d_2, ..., d_n, maxlen]`, with
59 | 
60 |     ```
61 |     mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n])
62 |     ```
63 | 
64 |     Examples:
65 | 
66 |     ```python
67 |     sequence_mask([1, 3, 2], 5)  # [[True, False, False, False, False],
68 |                                  #  [True, True, True, False, False],
69 |                                  #  [True, True, False, False, False]]
70 |     sequence_mask([[1, 3],[2,0]])  # [[[True, False, False],
71 |                                    #   [True, True, True]],
72 |                                    #  [[True, True, False],
73 |                                    #   [False, False, False]]]
74 |     ```
75 | 
76 |     Args:
77 |         lengths: integer tensor, all its values <= maxlen.
78 |         maxlen: scalar integer tensor, size of last dimension of returned tensor.
79 |             Default is the maximum value in `lengths`.
80 |         dtype: output type of the resulting tensor.
81 |             Defaults to `tf.bool`.
82 | 
83 |     Returns:
84 |         A mask tensor of shape `lengths.shape + (maxlen,)`, cast to specified dtype.
85 |     """
86 |     if maxlen is None:
87 |         maxlen = np.max(lengths)  # np.max handles nested lists and arrays, unlike builtin max
88 | 
89 |     # Compare a row vector [0, 1, ..., maxlen-1] with the lengths expanded to a
90 |     # trailing column; broadcasting yields a boolean array of shape
91 |     # lengths.shape + (maxlen,) where entry [..., j] is (j < lengths[...]).
92 |     row_vector = np.arange(maxlen)
93 |     matrix = np.expand_dims(lengths, -1)
94 |     result = row_vector < matrix
95 | 
96 |     if dtype is None:
97 |         return tf.convert_to_tensor(result)
98 |     else:
99 |         return tf.cast(tf.convert_to_tensor(result), dtype)
100 | 
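101 | if __name__ == '__main__':
102 |     # Minimal smoke tests, added as an illustrative sketch (assumed usage, not
103 |     # part of the training pipeline): a perfectly ranked positive/negative pair
104 |     # should give AUC 1.0, and sequence_mask should agree with tf.sequence_mask
105 |     # for 1-D lengths.
106 |     assert calc_auc([[1, 0, 0.1], [0, 1, 0.9]]) == 1.0
107 |     mask = sequence_mask([1, 3, 2], 5)
108 |     expected = tf.sequence_mask([1, 3, 2], 5)
109 |     assert bool(tf.reduce_all(tf.equal(mask, expected)))
110 | 
--------------------------------------------------------------------------------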