├── .gitignore
├── README.md
├── code
│   ├── baselines.py
│   ├── dataloader.py
│   ├── elastic_client.py
│   ├── preprocess_alipay.py
│   ├── preprocess_taobao.py
│   ├── preprocess_tmall.py
│   ├── rec.py
│   ├── rnn.py
│   ├── train.py
│   ├── train_baselines.py
│   ├── ubr.py
│   └── utils.py
└── data
    └── taobao
        └── raw_data
            └── UserBehavior.csv

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | .vscode/
107 | *.png
108 | logs*/
109 | *.pdf
110 | save_model*/
111 | summary*/
112 | .DS_Store
113 | *swp
114 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # User Behavior Retrieval for CTR Prediction (UBR4CTR)
2 | A `tensorflow` implementation of all the compared models for our SIGIR 2020 paper:
3 |
4 | [User Behavior Retrieval for Click-Through Rate Prediction](https://arxiv.org/pdf/2005.14171.pdf)
5 |
6 | If you have any questions, please contact the author: [Jiarui Qin](http://jiaruiqin.me).
7 |
8 |
9 | ## Abstract
10 | > Click-through rate (CTR) prediction plays a key role in modern online personalization services.
11 | In practice, it is necessary to capture users' drifting interests by modeling sequential user behaviors to build an accurate CTR prediction model.
12 | However, as users accumulate more and more behavioral data on the platform, it becomes non-trivial for sequential models to make use of each user's whole behavior history. First, directly feeding the long behavior sequence makes the online inference time and system load infeasible. Second, such long histories contain much noise that hampers the learning of sequential models.
13 | The current industrial solutions mainly truncate the sequences and feed only the most recent behaviors to the prediction model, so sequential patterns such as periodicity or long-term dependency, which live not in the most recent behaviors but far back in the history, are lost.
14 | To tackle these issues, in this paper we approach the problem from the data perspective instead of merely designing more sophisticated yet complicated models, and propose the User Behavior Retrieval for CTR prediction (UBR4CTR) framework. In UBR4CTR, the most relevant and appropriate user behaviors are first retrieved from the entire user history using a learnable search method. These retrieved behaviors, rather than simply the most recent ones, are then fed into a deep model to make the final prediction. UBR4CTR can be deployed into an industrial model pipeline at low cost. Experiments on three real-world large-scale datasets demonstrate the superiority and efficacy of our proposed framework and models.
15 |
16 | ## Citation
17 | ```
18 | @inproceedings{qin2020user,
19 |   title={User Behavior Retrieval for Click-Through Rate Prediction},
20 |   author={Qin, Jiarui and Zhang, Weinan and Wu, Xin and Jin, Jiarui and Fang, Yuchen and Yu, Yong},
21 |   booktitle={Proceedings of the 43rd International ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR ’20)},
22 |   year={2020},
23 |   organization={ACM}
24 | }
25 | ```
26 | ## Dependencies
27 | - [TensorFlow](https://www.tensorflow.org) >= 1.4
28 | - [Python](https://www.python.org) >= 3.5
29 | - [Elasticsearch](https://www.elastic.co)
30 | - [numpy](https://numpy.org)
31 | - [sklearn](https://scikit-learn.org)
32 |
33 | ## Data Preparation & Preprocessing
34 | - We provide sample raw data in the `data` folder. The full raw datasets are [Tmall](https://tianchi.aliyun.com/dataset/dataDetail?dataId=42), [Taobao](https://tianchi.aliyun.com/dataset/dataDetail?dataId=649) and [Alipay](https://tianchi.aliyun.com/dataset/dataDetail?dataId=53). **Remove the header row (the first line) of each downloaded file before preprocessing**, e.g. with the sketch below.
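  A minimal sketch for stripping the header row, streaming so a multi-GB file is not loaded into memory (the path shown is illustrative only; substitute whichever raw file you downloaded):

  ```python
  # Drop the first (header) line of a raw CSV.
  import shutil

  src = '../data/tmall/raw_data/user_log_format1.csv'  # example path
  with open(src) as fin, open(src + '.noheader', 'w') as fout:
      next(fin)                      # skip the header row
      shutil.copyfileobj(fin, fout)  # copy the remaining lines unchanged
  ```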
35 | - Feature Engineering:
36 | ```
37 | python3 preprocess_tmall.py # for Tmall
38 | python3 preprocess_taobao.py # for Taobao
39 | python3 preprocess_alipay.py # for Alipay
40 | ```
41 |
42 |
43 | ## Train the Models
44 | - To run UBR4CTR, with `rec_model` in `['RecAtt', 'RecSum']` and `ubr_model` in `['UBR_SA']`:
45 | ```
46 | python3 train.py [rec_model] [ubr_model] [gpu] [dataset]
47 | ```
48 |
49 | - To run the baselines, with `model_name` in `['GRU4Rec', 'Caser', 'SASRec', 'HPMN', 'MIMN', 'DIN', 'DIEN']`:
50 | ```
51 | python3 train_baselines.py [model_name] [gpu] [dataset]
52 | ```
53 |
--------------------------------------------------------------------------------
/code/baselines.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf 2 | from tensorflow.python.ops.rnn_cell import GRUCell 3 | from utils import MIMNCell, VecAttGRUCell 4 | from rnn import dynamic_rnn 5 | 6 | class BaseModel(object): 7 | def __init__(self, feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer): 8 | # reset graph 9 | tf.reset_default_graph() 10 | 11 | # input placeholders 12 | with tf.name_scope('inputs'): 13 | self.user_seq_ph = tf.placeholder(tf.int32, [None, max_time_len, item_fnum], name='user_seq_ph') 14 | self.user_seq_length_ph = tf.placeholder(tf.int32, [None,], name='user_seq_length_ph') 15 | self.target_user_ph = tf.placeholder(tf.int32, [None, user_fnum], name='target_user_ph') 16 | self.target_item_ph = tf.placeholder(tf.int32, [None, item_fnum], name='target_item_ph') 17 | self.label_ph = tf.placeholder(tf.int32, [None,], name='label_ph') 18 | 19 | # lr 20 | self.lr = tf.placeholder(tf.float32, []) 21 | # reg lambda 22 | self.reg_lambda = tf.placeholder(tf.float32, []) 23 | # keep prob 24 | self.keep_prob = tf.placeholder(tf.float32, []) 25 | 26 | # embedding 27 | with tf.name_scope('embedding'): 28 | if emb_initializer is not None: 29 | self.emb_mtx = tf.get_variable('emb_mtx', initializer=emb_initializer) 30 | else: 31 | self.emb_mtx = tf.get_variable('emb_mtx', [feature_size, eb_dim], initializer=tf.truncated_normal_initializer) 32 | self.emb_mtx_mask = tf.constant(value=1., shape=[feature_size - 1, eb_dim]) 33 | self.emb_mtx_mask = tf.concat([tf.constant(value=0., shape=[1, eb_dim]), self.emb_mtx_mask], axis=0) 34 | self.emb_mtx = self.emb_mtx * self.emb_mtx_mask 35 | 36 | self.user_seq = tf.nn.embedding_lookup(self.emb_mtx, self.user_seq_ph) 37 | self.user_seq = tf.reshape(self.user_seq, [-1, max_time_len, item_fnum * eb_dim]) 38 | self.target_item = tf.nn.embedding_lookup(self.emb_mtx, self.target_item_ph) 39 | self.target_item = tf.reshape(self.target_item, [-1, item_fnum * eb_dim]) 40 | self.target_user = tf.nn.embedding_lookup(self.emb_mtx, self.target_user_ph) 41 | self.target_user = tf.reshape(self.target_user, [-1, user_fnum * eb_dim]) 42 | 43 | def build_fc_net(self, inp): 44 | bn1 = tf.layers.batch_normalization(inputs=inp, name='bn1') 45 | fc1 = tf.layers.dense(bn1, 200, activation=tf.nn.relu, name='fc1') 46 | dp1 = tf.nn.dropout(fc1, self.keep_prob, name='dp1') 47 | fc2 = tf.layers.dense(dp1, 80, activation=tf.nn.relu, name='fc2') 48 | dp2 = tf.nn.dropout(fc2, self.keep_prob, name='dp2') 49 | fc3 = tf.layers.dense(dp2, 2, activation=None, name='fc3') 50 | score = tf.nn.softmax(fc3) 51 | # output 52 | self.y_pred = tf.reshape(score[:,0], [-1,]) 53 | 54 | def build_logloss(self): 55 | # loss 56 | self.log_loss = tf.losses.log_loss(self.label_ph, self.y_pred) 57 | self.loss = self.log_loss 58 | for v in 
tf.trainable_variables(): 59 | if 'bias' not in v.name and 'emb' not in v.name: 60 | self.loss += self.reg_lambda * tf.nn.l2_loss(v) 61 | # optimizer and training step 62 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr) 63 | self.train_step = self.optimizer.minimize(self.loss) 64 | 65 | def build_mseloss(self): 66 | self.loss = tf.losses.mean_squared_error(self.label_ph, self.y_pred) 67 | # regularization term 68 | for v in tf.trainable_variables(): 69 | if 'bias' not in v.name and 'emb' not in v.name: 70 | self.loss += self.reg_lambda * tf.nn.l2_loss(v) 71 | # optimizer and training step 72 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr) 73 | self.train_step = self.optimizer.minimize(self.loss) 74 | 75 | def train(self, sess, batch_data, lr, reg_lambda): 76 | loss, _ = sess.run([self.loss, self.train_step], feed_dict = { 77 | self.user_seq_ph : batch_data[0], 78 | self.user_seq_length_ph : batch_data[1], 79 | self.target_user_ph : batch_data[2], 80 | self.target_item_ph : batch_data[3], 81 | self.label_ph : batch_data[4], 82 | self.lr : lr, 83 | self.reg_lambda : reg_lambda, 84 | self.keep_prob : 0.8 85 | }) 86 | return loss 87 | 88 | def eval(self, sess, batch_data, reg_lambda): 89 | pred, label, loss = sess.run([self.y_pred, self.label_ph, self.loss], feed_dict = { 90 | self.user_seq_ph : batch_data[0], 91 | self.user_seq_length_ph : batch_data[1], 92 | self.target_user_ph : batch_data[2], 93 | self.target_item_ph : batch_data[3], 94 | self.label_ph : batch_data[4], 95 | self.reg_lambda : reg_lambda, 96 | self.keep_prob : 1. 97 | }) 98 | 99 | return pred.reshape([-1,]).tolist(), label.reshape([-1,]).tolist(), loss 100 | 101 | def save(self, sess, path): 102 | saver = tf.train.Saver() 103 | saver.save(sess, save_path=path) 104 | 105 | def restore(self, sess, path): 106 | saver = tf.train.Saver() 107 | saver.restore(sess, save_path=path) 108 | print('model restored from {}'.format(path)) 109 | 110 | class SumPooling(BaseModel): 111 | def __init__(self, feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer): 112 | super(SumPooling, self).\ 113 | __init__(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 114 | 115 | # use sum pooling to model the user behaviors, padding is zero (embedding id is also zero) 116 | user_behavior_rep = tf.reduce_sum(self.user_seq, axis=1) 117 | 118 | inp = tf.concat([user_behavior_rep, self.target_item, self.target_user], axis=1) 119 | 120 | # fc layer 121 | self.build_fc_net(inp) 122 | self.build_logloss() 123 | 124 | class GRU4Rec(BaseModel): 125 | def __init__(self, feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer): 126 | super(GRU4Rec, self).__init__(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 127 | 128 | # GRU 129 | with tf.name_scope('rnn'): 130 | user_seq_ht, user_seq_final_state = tf.nn.dynamic_rnn(GRUCell(hidden_size), inputs=self.user_seq, 131 | sequence_length=self.user_seq_length_ph, dtype=tf.float32, scope='gru1') 132 | # _, user_seq_final_state = tf.nn.dynamic_rnn(GRUCell(hidden_size), inputs=user_seq_ht, 133 | # sequence_length=self.user_seq_length_ph, dtype=tf.float32, scope='gru2') 134 | 135 | inp = tf.concat([user_seq_final_state, self.target_item, self.target_user], axis=1) 136 | 137 | # fc layer 138 | self.build_fc_net(inp) 139 | self.build_logloss() 140 | 141 | 142 | class Caser(BaseModel): 143 | def __init__(self, feature_size, eb_dim, hidden_size, 
max_time_len, user_fnum, item_fnum, emb_initializer): 144 | super(Caser, self).__init__(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 145 | 146 | with tf.name_scope('user_seq_cnn'): 147 | # horizontal filters 148 | filters_user = 4 149 | h_kernel_size_user = [5, eb_dim * item_fnum] 150 | v_kernel_size_user = [self.user_seq.get_shape().as_list()[1], 1] 151 | 152 | self.user_seq = tf.expand_dims(self.user_seq, 3) 153 | conv1 = tf.layers.conv2d(self.user_seq, filters_user, h_kernel_size_user) 154 | max1 = tf.layers.max_pooling2d(conv1, [conv1.get_shape().as_list()[1], 1], 1) 155 | user_hori_out = tf.reshape(max1, [-1, filters_user]) #[B, F] 156 | 157 | # vertical 158 | conv2 = tf.layers.conv2d(self.user_seq, filters_user, v_kernel_size_user) 159 | conv2 = tf.reshape(conv2, [-1, eb_dim * item_fnum, filters_user]) 160 | user_vert_out = tf.reshape(tf.layers.dense(conv2, 1), [-1, eb_dim * item_fnum]) 161 | 162 | inp = tf.concat([user_hori_out, user_vert_out, self.target_item, self.target_user], axis=1) 163 | 164 | # fully connected layer 165 | self.build_fc_net(inp) 166 | self.build_logloss() 167 | 168 | class DIN(BaseModel): 169 | def __init__(self, feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer): 170 | super(DIN, self).__init__(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 171 | mask = tf.sequence_mask(self.user_seq_length_ph, max_time_len, dtype=tf.float32) 172 | _, user_behavior_rep = self.attention(self.user_seq, self.user_seq, self.target_item, mask) 173 | 174 | inp = tf.concat([user_behavior_rep, self.target_user, self.target_item], axis=1) 175 | 176 | # fc layer 177 | self.build_fc_net(inp) 178 | self.build_logloss() 179 | 180 | 181 | def attention(self, key, value, query, mask): 182 | # key: [B, T, Dk], query: [B, Dq], mask: [B, T] 183 | _, max_len, k_dim = key.get_shape().as_list() 184 | query = tf.layers.dense(query, k_dim, activation=None) 185 | queries = tf.tile(tf.expand_dims(query, 1), [1, max_len, 1]) # [B, T, Dk] 186 | kq_inter = queries * key 187 | atten = tf.reduce_sum(kq_inter, axis=2) 188 | 189 | mask = tf.equal(mask, tf.ones_like(mask)) #[B, T] 190 | paddings = tf.ones_like(atten) * (-2 ** 32 + 1) 191 | atten = tf.nn.softmax(tf.where(mask, atten, paddings)) #[B, T] 192 | atten = tf.expand_dims(atten, 2) 193 | 194 | res = tf.reduce_sum(atten * value, axis=1) 195 | return atten, res 196 | 197 | class HPMN(BaseModel): 198 | def __init__(self, feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer): 199 | super(HPMN, self).__init__(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 200 | self.layer_num = 3 201 | self.split_by = 2 202 | self.memory =[] 203 | with tf.name_scope('rnn'): 204 | inp = self.user_seq 205 | length = max_time_len 206 | for i in range(self.layer_num): 207 | user_seq_ht, user_seq_final_state = tf.nn.dynamic_rnn(GRUCell(hidden_size), inputs=inp, dtype=tf.float32, scope='GRU%s' % i) 208 | 209 | user_seq_final_state = tf.expand_dims(user_seq_final_state, 1) 210 | self.memory.append(user_seq_final_state) 211 | 212 | length = int(length / self.split_by) 213 | user_seq_ht = tf.reshape(user_seq_ht, [-1, length, self.split_by, hidden_size]) 214 | inp = tf.reshape(tf.gather(user_seq_ht, [self.split_by - 1], axis=2), [-1, length, hidden_size]) 215 | 216 | self.memory = tf.concat(self.memory, axis=1) 217 | _, output = self.attention(self.memory, self.memory, self.target_item) 
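HPMN's `attention` above (and, with masking added, the `attention` in DIN and DIEN) is a projected dot-product attention: project the query to the key dimension, score each slot by an inner product, softmax over time, and pool the values. A minimal self-contained NumPy sketch of the unmasked computation, where `w_proj` is a hypothetical stand-in for the `tf.layers.dense` query projection:

```python
import numpy as np

def attention_np(key, value, query, w_proj):
    # key, value: [B, T, Dk]; query: [B, Dq]; w_proj: [Dq, Dk]
    q = query @ w_proj                                # project query to key dim: [B, Dk]
    logits = np.sum(key * q[:, None, :], axis=2)      # dot-product scores: [B, T]
    logits -= logits.max(axis=1, keepdims=True)       # numerical stability
    atten = np.exp(logits)
    atten /= atten.sum(axis=1, keepdims=True)         # softmax over the T slots
    return np.sum(atten[:, :, None] * value, axis=1)  # weighted sum of values: [B, Dk]
```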
218 | self.repre = tf.concat([self.target_user, self.target_item, output], axis=1) 219 | self.build_fc_net(self.repre) 220 | self.build_loss() 221 | 222 | def attention(self, key, value, query): 223 | # key: [B, T, Dk], query: [B, Dq] 224 | _, max_len, k_dim = key.get_shape().as_list() 225 | query = tf.layers.dense(query, k_dim, activation=None) 226 | queries = tf.tile(tf.expand_dims(query, 1), [1, max_len, 1]) # [B, T, Dk] 227 | kq_inter = queries * key 228 | atten = tf.reduce_sum(kq_inter, axis=2) 229 | 230 | atten = tf.nn.softmax(atten) #[B, T] 231 | atten = tf.expand_dims(atten, 2) 232 | 233 | res = tf.reduce_sum(atten * value, axis=1) 234 | return atten, res 235 | 236 | def get_covreg(self, memory, k): 237 | mean = tf.reduce_mean(memory, axis=2, keep_dims=True) 238 | C = memory - mean 239 | C = tf.matmul(C, tf.transpose(C, [0, 2, 1])) / tf.cast( 240 | tf.shape(memory)[2], tf.float32) 241 | C_diag = tf.linalg.diag_part(C) 242 | C_diag = tf.linalg.diag(C_diag) 243 | C = C - C_diag 244 | norm = tf.norm(C, ord='fro', axis=[1, 2]) 245 | return tf.reduce_sum(norm) 246 | 247 | def build_loss(self): 248 | # loss 249 | self.log_loss = tf.losses.log_loss(self.label_ph, self.y_pred) 250 | self.loss = self.log_loss 251 | # l2 norm 252 | for v in tf.trainable_variables(): 253 | if 'bias' not in v.name and 'emb' not in v.name: 254 | self.loss += self.reg_lambda * tf.nn.l2_loss(v) 255 | # covreg loss 256 | self.loss += self.reg_lambda * self.get_covreg(self.memory, self.layer_num) 257 | 258 | # optimizer and training step 259 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr) 260 | self.train_step = self.optimizer.minimize(self.loss) 261 | 262 | 263 | class MIMN(BaseModel): 264 | def __init__(self, feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer): 265 | super(MIMN, self).__init__(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 266 | 267 | with tf.name_scope('inputs'): 268 | self.batch_size = tf.placeholder(tf.int32, [], name='batch_size_ph') 269 | batch_size = self.batch_size 270 | 271 | cell = MIMNCell(hidden_size, item_fnum * eb_dim, batch_size) 272 | 273 | state = cell.zero_state(batch_size) 274 | 275 | for t in range(max_time_len): 276 | _, state = cell(self.user_seq[:, t, :], state) 277 | 278 | # [batch_size, memory_size, fnum * eb_dim] -> [batch_size, fnum * eb_dim] 279 | mean_memory = tf.reduce_mean(state['sum_aggre'], axis=1) 280 | 281 | read_out, _ = cell(self.target_item, state) 282 | 283 | self.item_his_eb_sum = tf.reduce_sum(self.user_seq, 1) # [batch_size, fnum * eb_dim] 284 | inp = tf.concat([self.target_item, self.item_his_eb_sum, read_out, mean_memory * self.target_item], 1) 285 | self.build_fc_net(inp) 286 | self.build_logloss() 287 | 288 | # ''' 289 | def train(self, sess, batch_data, lr, reg_lambda): 290 | loss, _ = sess.run([self.loss, self.train_step], feed_dict={ 291 | self.user_seq_ph: batch_data[0], 292 | self.user_seq_length_ph: batch_data[1], 293 | self.target_user_ph: batch_data[2], 294 | self.target_item_ph: batch_data[3], 295 | self.label_ph: batch_data[4], 296 | self.lr: lr, 297 | self.reg_lambda: reg_lambda, 298 | self.keep_prob: 0.8, 299 | self.batch_size: len(batch_data[0]) 300 | }) 301 | return loss 302 | 303 | def eval(self, sess, batch_data, reg_lambda): 304 | pred, label, loss = sess.run([self.y_pred, self.label_ph, self.loss], feed_dict={ 305 | self.user_seq_ph: batch_data[0], 306 | self.user_seq_length_ph: batch_data[1], 307 | self.target_user_ph: batch_data[2], 308 | 
self.target_item_ph: batch_data[3], 309 | self.label_ph: batch_data[4], 310 | self.reg_lambda: reg_lambda, 311 | self.keep_prob: 1., 312 | self.batch_size: len(batch_data[0]) 313 | }) 314 | 315 | return pred.reshape([-1, ]).tolist(), label.reshape([-1, ]).tolist(), loss 316 | # ''' 317 | 318 | class DIEN(BaseModel): 319 | def __init__(self, feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer): 320 | super(DIEN, self).__init__(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 321 | mask = tf.sequence_mask(self.user_seq_length_ph, max_time_len, dtype=tf.float32) 322 | 323 | # attention RNN layer 324 | with tf.name_scope('rnn_1'): 325 | user_seq_ht, _ = tf.nn.dynamic_rnn(GRUCell(hidden_size), inputs=self.user_seq, 326 | sequence_length=self.user_seq_length_ph, dtype=tf.float32, scope='gru1') 327 | with tf.name_scope('attention'): 328 | atten_score, _, = self.attention(user_seq_ht, user_seq_ht, self.target_item, mask) 329 | with tf.name_scope('rnn_2'): 330 | _, seq_rep = dynamic_rnn(VecAttGRUCell(hidden_size), inputs=user_seq_ht, 331 | att_scores = atten_score, 332 | sequence_length=self.user_seq_length_ph, dtype=tf.float32, scope="argru1") 333 | 334 | inp = tf.concat([seq_rep, self.target_user, self.target_item], axis=1) 335 | 336 | # fully connected layer 337 | self.build_fc_net(inp) 338 | self.build_logloss() 339 | 340 | def attention(self, key, value, query, mask): 341 | # key: [B, T, Dk], query: [B, Dq], mask: [B, T] 342 | _, max_len, k_dim = key.get_shape().as_list() 343 | query = tf.layers.dense(query, k_dim, activation=None) 344 | queries = tf.tile(tf.expand_dims(query, 1), [1, max_len, 1]) # [B, T, Dk] 345 | kq_inter = queries * key 346 | atten = tf.reduce_sum(kq_inter, axis=2) 347 | 348 | mask = tf.equal(mask, tf.ones_like(mask)) #[B, T] 349 | paddings = tf.ones_like(atten) * (-2 ** 32 + 1) 350 | atten = tf.nn.softmax(tf.where(mask, atten, paddings)) #[B, T] 351 | atten = tf.expand_dims(atten, 2) 352 | 353 | res = tf.reduce_sum(atten * value, axis=1) 354 | return atten, res 355 | 356 | class SASRec(BaseModel): 357 | def __init__(self, feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer): 358 | super(SASRec, self).__init__(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 359 | self.user_seq = self.multihead_attention(self.normalize(self.user_seq), self.user_seq) 360 | self.target_user_t = tf.tile(tf.expand_dims(self.target_user, 1), [1, max_time_len, 1]) 361 | self.target_item_t = tf.tile(tf.expand_dims(self.target_item, 1), [1, max_time_len, 1]) 362 | 363 | self.mask = tf.expand_dims(tf.sequence_mask(self.user_seq_length_ph, max_time_len, dtype=tf.float32), axis=-1) 364 | self.mask_1 = tf.expand_dims(tf.sequence_mask(self.user_seq_length_ph - 1, max_time_len, dtype=tf.float32), axis=-1) 365 | self.get_mask = self.mask - self.mask_1 366 | self.seq_rep = self.user_seq * self.mask 367 | self.final_pred_rep = tf.reduce_sum(self.user_seq * self.get_mask, axis=1) 368 | 369 | # pos and neg for sequence 370 | self.pos = self.user_seq[:, 1:, :] 371 | self.neg = self.user_seq[:, 2:, :] 372 | 373 | self.target_user_t = tf.tile(tf.expand_dims(self.target_user, 1), [1, max_time_len, 1]) 374 | 375 | self.pos_seq_rep = tf.concat([self.seq_rep[:, 1:, :], self.pos, self.target_user_t[:, 1:, :]], axis=2) 376 | self.neg_seq_rep = tf.concat([self.seq_rep[:, 2:, :], self.neg, self.target_user_t[:, 2:, :]], axis=2) 377 | 378 | self.preds_pos = 
self.build_fc_net(self.pos_seq_rep) 379 | self.preds_neg = self.build_fc_net(self.neg_seq_rep) 380 | self.label_pos = tf.ones_like(self.preds_pos) 381 | self.label_neg = tf.zeros_like(self.preds_neg) 382 | 383 | self.loss = tf.losses.log_loss(self.label_pos, self.preds_pos) + tf.losses.log_loss(self.label_neg, self.preds_neg) 384 | 385 | # prediction for target user and item 386 | inp = tf.concat([self.final_pred_rep, self.target_item, self.target_user], axis=1) 387 | self.y_pred = self.build_fc_net(inp) 388 | self.y_pred = tf.reshape(self.y_pred, [-1,]) 389 | 390 | self.loss += tf.losses.log_loss(self.label_ph, self.y_pred) 391 | for v in tf.trainable_variables(): 392 | if 'bias' not in v.name and 'emb' not in v.name: 393 | self.loss += self.reg_lambda * tf.nn.l2_loss(v) 394 | # optimizer and training step 395 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr) 396 | self.train_step = self.optimizer.minimize(self.loss) 397 | 398 | def build_fc_net(self, inp): 399 | with tf.variable_scope('prediction_layer'): 400 | fc1 = tf.layers.dense(inp, 200, activation=tf.nn.relu, name='fc1', reuse=tf.AUTO_REUSE) 401 | dp1 = tf.nn.dropout(fc1, self.keep_prob, name='dp1') 402 | fc2 = tf.layers.dense(dp1, 80, activation=tf.nn.relu, name='fc2', reuse=tf.AUTO_REUSE) 403 | dp2 = tf.nn.dropout(fc2, self.keep_prob, name='dp2') 404 | fc3 = tf.layers.dense(dp2, 1, activation=tf.sigmoid, name='fc3', reuse=tf.AUTO_REUSE) 405 | return fc3 406 | 407 | def multihead_attention(self, 408 | queries, 409 | keys, 410 | num_units=None, 411 | num_heads=2, 412 | scope="multihead_attention", 413 | reuse=None): 414 | '''Applies multihead attention. 415 | 416 | Args: 417 | queries: A 3d tensor with shape of [N, T_q, C_q]. 418 | keys: A 3d tensor with shape of [N, T_k, C_k]. 419 | num_units: A scalar. Attention size. 420 | num_heads: An int. Number of heads. 421 | scope: Optional scope for `variable_scope`. 422 | reuse: Boolean, whether to reuse the weights of a previous layer 423 | by the same name. 
424 | 425 | Returns 426 | A 3d tensor with shape of (N, T_q, C) 427 | ''' 428 | with tf.variable_scope(scope, reuse=reuse): 429 | # Set the fall back option for num_units 430 | if num_units is None: 431 | num_units = queries.get_shape().as_list()[-1] 432 | 433 | # Linear projections 434 | # Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu) # (N, T_q, C) 435 | # K = tf.layers.dense(keys, num_units, activation=tf.nn.relu) # (N, T_k, C) 436 | # V = tf.layers.dense(keys, num_units, activation=tf.nn.relu) # (N, T_k, C) 437 | Q = tf.layers.dense(queries, num_units, activation=None) # (N, T_q, C) 438 | K = tf.layers.dense(keys, num_units, activation=None) # (N, T_k, C) 439 | V = tf.layers.dense(keys, num_units, activation=None) # (N, T_k, C) 440 | 441 | # Split and concat 442 | Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, C/h) 443 | K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, C/h) 444 | V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) # (h*N, T_k, C/h) 445 | 446 | # Multiplication 447 | outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1])) # (h*N, T_q, T_k) 448 | 449 | # Scale 450 | outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5) 451 | 452 | # Key Masking 453 | key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1))) # (N, T_k) 454 | key_masks = tf.tile(key_masks, [num_heads, 1]) # (h*N, T_k) 455 | key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1]) # (h*N, T_q, T_k) 456 | 457 | paddings = tf.ones_like(outputs)*(-2**32+1) 458 | outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs) # (h*N, T_q, T_k) 459 | 460 | # Activation 461 | outputs = tf.nn.softmax(outputs) # (h*N, T_q, T_k) 462 | 463 | # Query Masking 464 | query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1))) # (N, T_q) 465 | query_masks = tf.tile(query_masks, [num_heads, 1]) # (h*N, T_q) 466 | query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]]) # (h*N, T_q, T_k) 467 | outputs *= query_masks # broadcasting. (N, T_q, C) 468 | 469 | # Dropouts 470 | outputs = tf.nn.dropout(outputs, self.keep_prob) 471 | 472 | # Weighted sum 473 | outputs = tf.matmul(outputs, V_) # ( h*N, T_q, C/h) 474 | 475 | # Restore shape 476 | outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2 ) # (N, T_q, C) 477 | 478 | # Residual connection 479 | outputs += queries 480 | 481 | # Normalize 482 | #outputs = normalize(outputs) # (N, T_q, C) 483 | 484 | return outputs 485 | 486 | def normalize(self, 487 | inputs, 488 | epsilon = 1e-8, 489 | scope="ln", 490 | reuse=None): 491 | '''Applies layer normalization. 492 | 493 | Args: 494 | inputs: A tensor with 2 or more dimensions, where the first dimension has 495 | `batch_size`. 496 | epsilon: A floating number. A very small number for preventing ZeroDivision Error. 497 | scope: Optional scope for `variable_scope`. 498 | reuse: Boolean, whether to reuse the weights of a previous layer 499 | by the same name. 500 | 501 | Returns: 502 | A tensor with the same shape and data dtype as `inputs`. 
503 | ''' 504 | with tf.variable_scope(scope, reuse=reuse): 505 | inputs_shape = inputs.get_shape() 506 | params_shape = inputs_shape[-1:] 507 | 508 | mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True) 509 | beta= tf.Variable(tf.zeros(params_shape)) 510 | gamma = tf.Variable(tf.ones(params_shape)) 511 | normalized = (inputs - mean) / ( (variance + epsilon) ** (.5) ) 512 | outputs = gamma * normalized + beta 513 | 514 | return outputs 515 | 516 | 517 | -------------------------------------------------------------------------------- /code/dataloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pickle as pkl 3 | import time 4 | import numpy as np 5 | import random 6 | from common import * 7 | from elastic_client import * 8 | import multiprocessing 9 | 10 | random.seed(1111) 11 | 12 | class DataLoader(object): 13 | def __init__(self, 14 | batch_size, 15 | seq_file, 16 | target_file, 17 | user_feat_dict_file, 18 | item_feat_dict_file, 19 | max_len): 20 | self.batch_size = batch_size 21 | self.seq_file = open(seq_file, 'r') 22 | self.target_file = open(target_file, 'r') 23 | 24 | if user_feat_dict_file != None: 25 | with open(user_feat_dict_file, 'rb') as f: 26 | self.user_feat_dict = pkl.load(f) 27 | else: 28 | self.user_feat_dict = None 29 | 30 | # item has to have multiple feature fields 31 | with open(item_feat_dict_file, 'rb') as f: 32 | self.item_feat_dict = pkl.load(f) 33 | 34 | self.max_len = max_len 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | def __next__(self): 40 | target_user_batch = [] 41 | target_item_batch = [] 42 | label_batch = [] 43 | user_seq_batch = [] 44 | user_seq_len_batch = [] 45 | 46 | for i in range(self.batch_size): 47 | seq_line = self.seq_file.readline() 48 | target_line = self.target_file.readline() 49 | if seq_line == '': 50 | raise StopIteration 51 | 52 | target_uid, target_iid = target_line[:-1].split(',') 53 | if self.user_feat_dict != None: 54 | target_user_batch.append([int(target_uid)] + self.user_feat_dict[target_uid]) 55 | else: 56 | target_user_batch.append([int(target_uid)]) 57 | 58 | target_item_batch.append([int(target_iid)] + self.item_feat_dict[target_iid]) 59 | if i % 2 == 0: 60 | label_batch.append(1) 61 | else: 62 | label_batch.append(0) 63 | 64 | seq = seq_line[:-1].split(',') 65 | seqlen = len(seq) 66 | user_seq = [] 67 | for iid in seq: 68 | item = [int(iid)] + self.item_feat_dict[iid] 69 | user_seq.append(item) 70 | if seqlen >= self.max_len: 71 | user_seq = user_seq[-self.max_len:] 72 | user_seq_len_batch.append(self.max_len) 73 | else: 74 | user_seq += [[0] * len(user_seq[-1])] * (self.max_len - seqlen) 75 | user_seq_len_batch.append(seqlen) 76 | user_seq_batch.append(user_seq) 77 | 78 | 79 | return [user_seq_batch, user_seq_len_batch, target_user_batch, target_item_batch, label_batch] 80 | 81 | class DataLoader_Target(object): 82 | def __init__(self, 83 | batch_size, 84 | target_file, 85 | user_feat_dict_file, 86 | item_feat_dict_file, 87 | context_dict_file): 88 | self.batch_size = batch_size 89 | self.target_file = open(target_file, 'r') 90 | 91 | if user_feat_dict_file != None: 92 | with open(user_feat_dict_file, 'rb') as f: 93 | self.user_feat_dict = pkl.load(f) 94 | else: 95 | self.user_feat_dict = None 96 | 97 | with open(item_feat_dict_file, 'rb') as f: 98 | self.item_feat_dict = pkl.load(f) 99 | with open(context_dict_file, 'rb') as f: 100 | self.context_dict = pkl.load(f) 101 | 102 | def __iter__(self): 103 | return self 104 | 105 | def 
__next__(self): 106 | target_batch = [] 107 | label_batch = [] 108 | 109 | for i in range(self.batch_size): 110 | target_line = self.target_file.readline() 111 | if target_line == '': 112 | raise StopIteration 113 | 114 | target_uid, target_iid = target_line[:-1].split(',') 115 | if self.user_feat_dict != None: 116 | target_batch.append([int(target_uid)] + self.user_feat_dict[target_uid] + [int(target_iid)] + self.item_feat_dict[target_iid] + self.context_dict[target_uid]) 117 | else: 118 | target_batch.append([int(target_uid)] + [int(target_iid)] + self.item_feat_dict[target_iid] + self.context_dict[target_uid]) 119 | 120 | if i % 2 == 0: 121 | label_batch.append(1) 122 | else: 123 | label_batch.append(0) 124 | return target_batch, label_batch 125 | 126 | class Taker(object): 127 | def __init__(self, es_reader, batch_size, b_num, record_fnum): 128 | self.es_reader = es_reader 129 | self.batch_size = batch_size 130 | self.b_num = b_num 131 | self.record_fnum = record_fnum 132 | 133 | def take_behave(self, target_batch, index_batch): 134 | seq_batch = [] 135 | seq_len_batch = [self.b_num] * self.batch_size 136 | 137 | queries = [] 138 | for i in range(self.batch_size): 139 | target = np.array(target_batch[i][1:]) # with out uid 140 | index = np.array(index_batch[i]) # F-1 141 | query_tup = (str(target_batch[i][0]), ','.join(list(map(str, target[index==1].tolist())))) 142 | queries.append(query_tup) 143 | seq_batch = self.es_reader.query(queries, self.b_num, self.record_fnum) 144 | 145 | return seq_batch, seq_len_batch 146 | 147 | 148 | class DataLoader_Multi(object): 149 | def __init__(self, workload_list, taker, worker_num=2, wait_time=0.001): 150 | self.taker = taker 151 | self.worker_num = worker_num 152 | self.wait_time = wait_time 153 | self.threads = [] 154 | self.work = multiprocessing.Queue() 155 | self.res = multiprocessing.Queue() 156 | 157 | for workload_tuple in workload_list: 158 | self.work.put(workload_tuple) 159 | print("workload queue size: {}".format(self.work.qsize())) 160 | 161 | for i in range(self.worker_num): 162 | thread = multiprocessing.Process(target=self.worker) 163 | self.threads.append(thread) 164 | thread.daemon = True 165 | thread.start() 166 | 167 | def worker(self): 168 | while self.work.empty() == False: 169 | target_batch, label_batch, index_batch = self.work.get() 170 | seq_batch, seq_len_batch = self.taker.take_behave(target_batch, index_batch) 171 | self.res.put([seq_batch, seq_len_batch, target_batch, label_batch]) 172 | 173 | def __iter__(self): 174 | return self 175 | 176 | def __next__(self): 177 | if self.res.empty(): 178 | if self.work.empty(): 179 | for thread in self.threads: 180 | thread.terminate() 181 | raise StopIteration 182 | else: 183 | time.sleep(self.wait_time) 184 | 185 | re = self.res.get() 186 | return re -------------------------------------------------------------------------------- /code/elastic_client.py: -------------------------------------------------------------------------------- 1 | from elasticsearch import Elasticsearch 2 | from elasticsearch_dsl import Search, MultiSearch 3 | import elasticsearch.helpers 4 | import time 5 | import numpy as np 6 | 7 | class ESWriter(object): 8 | def __init__(self, input_file, index_name, host_url = 'localhost:9200'): 9 | self.input_file = input_file 10 | self.es = Elasticsearch(host_url) 11 | self.index_name = index_name 12 | 13 | # delete if there is old 14 | self.es.indices.delete(index=self.index_name, ignore=[400, 404]) 15 | 16 | self.create_index_body = { 17 | "settings": { 18 | 
"analysis": { 19 | "analyzer": { 20 | "my_analyzer": { 21 | "tokenizer": "my_tokenizer" 22 | } 23 | }, 24 | "tokenizer": { 25 | "my_tokenizer": { 26 | "type": "pattern", 27 | "pattern": "," 28 | } 29 | } 30 | }, 31 | 32 | }, 33 | "mappings": { 34 | "properties": { 35 | "record": { 36 | "type": "text", 37 | 'analyzer': 'my_analyzer', 38 | 'search_analyzer': 'my_analyzer' 39 | } 40 | } 41 | }, 42 | } 43 | self.es.indices.create(index=self.index_name, body=self.create_index_body, ignore=400) 44 | print('index created') 45 | 46 | def write(self): 47 | t = time.time() 48 | with open(self.input_file, 'r') as f: 49 | docs = [] 50 | batch_num = 0 51 | for line in f: 52 | line_item = line[:-1].split(',') 53 | userid = line_item[0] 54 | record = ','.join(line_item[1:-1]) 55 | timestamp = int(line_item[-1]) 56 | doc = { 57 | 'userid' : userid, 58 | 'record': record, 59 | 'timestamp': timestamp 60 | } 61 | docs.append(doc) 62 | if len(docs) == 1000: 63 | actions = [{ 64 | '_op_type': 'index', 65 | '_index': self.index_name, 66 | '_source': d 67 | } 68 | for d in docs] 69 | 70 | elasticsearch.helpers.bulk(self.es, actions) 71 | batch_num += 1 72 | docs = [] 73 | if batch_num % 1000 == 0: 74 | print('{} data samples have been inserted'.format(batch_num * 1000)) 75 | 76 | # the last bulk 77 | if docs != []: 78 | actions = [{ 79 | '_op_type': 'index', 80 | '_index': self.index_name, 81 | '_source': d 82 | } 83 | for d in docs] 84 | elasticsearch.helpers.bulk(self.es, actions) 85 | 86 | print('data insert time: %.2f seconds' % (time.time() - t)) 87 | 88 | 89 | class ESReader(object): 90 | def __init__(self, index_name, host_url = 'localhost:9200'): 91 | self.es = Elasticsearch(host_url) 92 | self.index_name = index_name 93 | 94 | def query(self, queries, size, record_fnum): 95 | ms = MultiSearch(using=self.es, index=self.index_name) 96 | for q in queries: 97 | s = Search().query("match", userid=q[0]).query("match", record=q[1])[:size] 98 | ms = ms.add(s) 99 | responses = ms.execute() 100 | 101 | res_batch = [] 102 | for response in responses: 103 | res = [] 104 | for hit in response: 105 | res.append([int(hit.userid)] + list(map(int, hit.record.split(',')))) 106 | if len(res) < size: 107 | res += [np.zeros([record_fnum,]).astype(np.int32).tolist()] * (size - len(res)) 108 | res_batch.append(res) 109 | return res_batch 110 | -------------------------------------------------------------------------------- /code/preprocess_alipay.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import random 3 | import numpy as np 4 | import sys 5 | from elastic_client import * 6 | import datetime 7 | 8 | RAW_DIR = '../ubr4rec-data/alipay/raw_data/' 9 | FEATENG_DIR = '../ubr4rec-data/alipay/feateng_data/' 10 | 11 | ORI_FEAT_SIZE = 2836404 12 | FEAT_SIZE = ORI_FEAT_SIZE + 6 13 | SECONDS_PER_DAY = 24 * 3600 14 | 15 | def feateng(in_file, remap_dict_file): 16 | uid_remap_dict = {} 17 | iid_remap_dict = {} 18 | sid_remap_dict = {} 19 | cid_remap_dict = {} 20 | 21 | uid_set = set() 22 | iid_set = set() 23 | sid_set = set() 24 | cid_set = set() 25 | 26 | with open(in_file, 'r') as r: 27 | i = 0 28 | for line in r: 29 | if i == 0: 30 | i += 1 31 | continue 32 | uid, sid, iid, cid, btype, date = line[:-1].split(',') 33 | if btype == '0': 34 | uid_set.add(uid) 35 | iid_set.add(iid) 36 | sid_set.add(sid) 37 | cid_set.add(cid) 38 | 39 | uid_list = list(uid_set) 40 | iid_list = list(iid_set) 41 | cid_list = list(cid_set) 42 | sid_list = list(sid_set) 43 | 44 | print('user 
number is: {}'.format(len(uid_list))) 45 | print('item number is: {}'.format(len(iid_list))) 46 | 47 | feature_id = 1 48 | for uid in uid_list: 49 | uid_remap_dict[uid] = str(feature_id) 50 | feature_id += 1 51 | for iid in iid_list: 52 | iid_remap_dict[iid] = str(feature_id) 53 | feature_id += 1 54 | for cid in cid_list: 55 | cid_remap_dict[cid] = str(feature_id) 56 | feature_id += 1 57 | for sid in sid_list: 58 | sid_remap_dict[sid] = str(feature_id) 59 | feature_id += 1 60 | print('total original feature number: {}'.format(feature_id)) 61 | 62 | with open(remap_dict_file, 'wb') as f: 63 | pkl.dump(uid_remap_dict, f) 64 | pkl.dump(iid_remap_dict, f) 65 | pkl.dump(cid_remap_dict, f) 66 | pkl.dump(sid_remap_dict, f) 67 | print('remap dict dumpped') 68 | 69 | def get_season(month): 70 | if month >= 10: 71 | return 3 72 | elif month >= 7 and month <= 9: 73 | return 2 74 | elif month >= 4 and month <= 6: 75 | return 1 76 | else: 77 | return 0 78 | 79 | def get_ud(day): 80 | if day <= 15: 81 | return 0 82 | else: 83 | return 1 84 | 85 | def remap_log_file(input_log_file, remap_dict_file, output_log_file, item_feat_dict_file): 86 | with open(remap_dict_file, 'rb') as f: 87 | uid_remap_dict = pkl.load(f) 88 | iid_remap_dict = pkl.load(f) 89 | cid_remap_dict = pkl.load(f) 90 | sid_remap_dict = pkl.load(f) 91 | item_feat_dict = {} 92 | newlines = [] 93 | 94 | with open(input_log_file, 'r') as f: 95 | for line in f: 96 | uid, sid, iid, cid, btype, date = line[:-1].split(',') 97 | if btype != '0': 98 | continue 99 | uid = uid_remap_dict[uid] 100 | iid = iid_remap_dict[iid] 101 | cid = cid_remap_dict[cid] 102 | sid = sid_remap_dict[sid] 103 | 104 | ts = str(int(time.mktime(datetime.datetime.strptime(date, "%Y%m%d").timetuple()))) 105 | 106 | month = int(date[4:6]) 107 | day = int(date[6:]) 108 | sea_id = str(get_season(month) + ORI_FEAT_SIZE) 109 | ud_id = str(get_ud(day) + ORI_FEAT_SIZE + 4) 110 | 111 | if iid not in item_feat_dict: 112 | item_feat_dict[iid] = [cid, sid] 113 | 114 | newline = ','.join([uid, iid, cid, sid, sea_id, ud_id, ts]) + '\n' 115 | newlines.append(newline) 116 | 117 | with open(output_log_file, 'w') as f: 118 | f.writelines(newlines) 119 | with open(item_feat_dict_file, 'wb') as f: 120 | pkl.dump(item_feat_dict, f) 121 | 122 | 123 | def sort_raw_log(raw_log_ts_file, sorted_raw_log_ts_file): 124 | line_dict = {} 125 | with open(raw_log_ts_file) as f: 126 | for line in f: 127 | uid, _, _, _, _, _, ts = line[:-1].split(',') 128 | if uid not in line_dict: 129 | line_dict[uid] = [[line, int(ts)]] 130 | else: 131 | line_dict[uid].append([line, int(ts)]) 132 | 133 | for uid in line_dict: 134 | line_dict[uid].sort(key = lambda x:x[1]) 135 | print('sort complete') 136 | print(len(line_dict.keys())) 137 | newlines = [] 138 | for uid in line_dict: 139 | for tup in line_dict[uid]: 140 | newlines.append(tup[0]) 141 | with open(sorted_raw_log_ts_file, 'w') as f: 142 | f.writelines(newlines) 143 | 144 | 145 | def random_sample(min = 626042, max = 2826332): 146 | return str(random.randint(min, max)) 147 | 148 | def neg_sample(user_seq): 149 | r = random.randint(0, 4) 150 | if r == 0: 151 | return random_sample() 152 | else: 153 | return random.choice(user_seq) 154 | 155 | def gen_target_seq(input_file, 156 | target_train_file, 157 | target_vali_file, 158 | target_test_file, 159 | user_seq_file, 160 | database_file, 161 | context_dict_train_file, 162 | context_dict_vali_file, 163 | context_dict_test_file): 164 | line_dict = {} 165 | user_seq_dict = {} 166 | context_dict_train = {} 167 | 
context_dict_vali = {} 168 | context_dict_test = {} 169 | 170 | 171 | with open(input_file, 'r') as f: 172 | for line in f: 173 | uid, iid, cid, sid, sea_id, ud_id, ts = line[:-1].split(',') 174 | if uid not in line_dict: 175 | line_dict[uid] = [line] 176 | user_seq_dict[uid] = [iid] 177 | else: 178 | line_dict[uid].append(line) 179 | user_seq_dict[uid].append(iid) 180 | 181 | 182 | target_train_lines = [] 183 | target_vali_lines = [] 184 | target_test_lines = [] 185 | user_seq_lines = [] 186 | database_lines = [] 187 | 188 | for uid in user_seq_dict: 189 | if len(user_seq_dict[uid]) > 3: 190 | target_train_lines += [','.join([uid, user_seq_dict[uid][-3]]) + '\n'] 191 | target_train_lines += [','.join([uid, neg_sample(user_seq_dict[uid][:-3])]) + '\n'] 192 | context_dict_train[uid] = list(map(int, line_dict[uid][-3][:-1].split(',')[-3:-1])) 193 | 194 | target_vali_lines += [','.join([uid, user_seq_dict[uid][-2]]) + '\n'] 195 | target_vali_lines += [','.join([uid, neg_sample(user_seq_dict[uid][:-3])]) + '\n'] 196 | context_dict_vali[uid] = list(map(int, line_dict[uid][-2][:-1].split(',')[-3:-1])) 197 | 198 | target_test_lines += [','.join([uid, user_seq_dict[uid][-1]]) + '\n'] 199 | target_test_lines += [','.join([uid, neg_sample(user_seq_dict[uid][:-3])]) + '\n'] 200 | context_dict_test[uid] = list(map(int, line_dict[uid][-1][:-1].split(',')[-3:-1])) 201 | 202 | user_seq = user_seq_dict[uid][:-3] 203 | user_seq_lines += [','.join(user_seq) + '\n'] * 2 #(1 pos and 1 neg item) 204 | 205 | database_lines += line_dict[uid][:-3] 206 | 207 | with open(target_train_file, 'w') as f: 208 | f.writelines(target_train_lines) 209 | with open(target_vali_file, 'w') as f: 210 | f.writelines(target_vali_lines) 211 | with open(target_test_file, 'w') as f: 212 | f.writelines(target_test_lines) 213 | 214 | with open(user_seq_file, 'w') as f: 215 | f.writelines(user_seq_lines) 216 | with open(database_file, 'w') as f: 217 | f.writelines(database_lines) 218 | 219 | with open(context_dict_train_file, 'wb') as f: 220 | pkl.dump(context_dict_train, f) 221 | with open(context_dict_vali_file, 'wb') as f: 222 | pkl.dump(context_dict_vali, f) 223 | with open(context_dict_test_file, 'wb') as f: 224 | pkl.dump(context_dict_test, f) 225 | 226 | 227 | def insert_elastic(input_file): 228 | writer = ESWriter(input_file, 'alipay') 229 | writer.write() 230 | 231 | if __name__ == "__main__": 232 | feateng(RAW_DIR + 'ijcai2016_taobao.csv', FEATENG_DIR + 'id_remap_dict.pkl') 233 | remap_log_file(RAW_DIR + 'ijcai2016_taobao.csv', FEATENG_DIR + 'id_remap_dict.pkl', FEATENG_DIR + 'remapped_log.csv', FEATENG_DIR + 'item_feat_dict.pkl') 234 | sort_raw_log(FEATENG_DIR + 'remapped_log.csv', FEATENG_DIR + 'sorted_remapped_log.csv') 235 | gen_target_seq(FEATENG_DIR + 'sorted_remapped_log.csv', 236 | FEATENG_DIR + 'target_train.txt', FEATENG_DIR + 'target_vali.txt', FEATENG_DIR + 'target_test.txt', FEATENG_DIR + 'user_seq.txt', FEATENG_DIR + 'database.txt', 237 | FEATENG_DIR + 'context_dict_train.pkl', FEATENG_DIR + 'context_dict_vali.pkl', FEATENG_DIR + 'context_dict_test.pkl') 238 | insert_elastic(FEATENG_DIR + 'database.txt') 239 | -------------------------------------------------------------------------------- /code/preprocess_taobao.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import random 3 | import numpy as np 4 | import sys 5 | from elastic_client import * 6 | 7 | RAW_DIR = '../data/taobao/raw_data/' 8 | FEATENG_DIR = '../data/taobao/feateng_data/' 9 | 10 | 
ORI_FEAT_SIZE = 5062312 11 | FEAT_SIZE = ORI_FEAT_SIZE + 2 12 | START_TIME = 1511539200 13 | SECONDS_PER_DAY = 24 * 3600 14 | 15 | def feateng(in_file, remap_dict_file): 16 | uid_remap_dict = {} 17 | iid_remap_dict = {} 18 | cid_remap_dict = {} 19 | 20 | uid_set = set() 21 | iid_set = set() 22 | cid_set = set() 23 | 24 | with open(in_file, 'r') as r: 25 | for line in r: 26 | uid, iid, cid, btype, ts = line.split(',') 27 | if btype == 'pv': 28 | uid_set.add(uid) 29 | iid_set.add(iid) 30 | cid_set.add(cid) 31 | 32 | uid_list = list(uid_set) 33 | iid_list = list(iid_set) 34 | cid_list = list(cid_set) 35 | 36 | 37 | feature_id = 1 38 | for uid in uid_list: 39 | uid_remap_dict[uid] = str(feature_id) 40 | feature_id += 1 41 | for iid in iid_list: 42 | iid_remap_dict[iid] = str(feature_id) 43 | feature_id += 1 44 | for cid in cid_list: 45 | cid_remap_dict[cid] = str(feature_id) 46 | feature_id += 1 47 | print('total original feature number: {}'.format(feature_id)) 48 | 49 | with open(remap_dict_file, 'wb') as f: 50 | pkl.dump(uid_remap_dict, f) 51 | pkl.dump(iid_remap_dict, f) 52 | pkl.dump(cid_remap_dict, f) 53 | print('remap dict dumpped') 54 | 55 | def isweekday(date): 56 | if date in [0, 1, 8]: 57 | return str(ORI_FEAT_SIZE) 58 | else: 59 | return str(ORI_FEAT_SIZE + 1) 60 | 61 | def remap_log_file(input_log_file, remap_dict_file, output_log_file, user_seq_dict_file, item_feat_dict_file): 62 | with open(remap_dict_file, 'rb') as f: 63 | uid_remap_dict = pkl.load(f) 64 | iid_remap_dict = pkl.load(f) 65 | cid_remap_dict = pkl.load(f) 66 | user_seq_dict = {} 67 | item_feat_dict = {} 68 | newlines = [] 69 | 70 | with open(input_log_file, 'r') as f: 71 | for line in f: 72 | uid, iid, cid, btype, ts = line[:-1].split(',') 73 | if btype != 'pv': 74 | continue 75 | uid = uid_remap_dict[uid] 76 | iid = iid_remap_dict[iid] 77 | cid = cid_remap_dict[cid] 78 | 79 | date = (int(ts) - START_TIME) // SECONDS_PER_DAY 80 | if date < 0: 81 | continue 82 | date = isweekday(date) 83 | 84 | if uid not in user_seq_dict: 85 | user_seq_dict[uid] = [iid] 86 | else: 87 | user_seq_dict[uid].append(iid) 88 | 89 | if iid not in item_feat_dict: 90 | item_feat_dict[iid] = [cid] 91 | 92 | newline = ','.join([uid, iid, cid, date, ts]) + '\n' 93 | newlines.append(newline) 94 | 95 | with open(output_log_file, 'w') as f: 96 | f.writelines(newlines) 97 | with open(item_feat_dict_file, 'wb') as f: 98 | pkl.dump(item_feat_dict, f) 99 | with open(user_seq_dict_file, 'wb') as f: 100 | pkl.dump(user_seq_dict, f) 101 | 102 | def neg_sample(user_seq, items): 103 | r = random.randint(0, 1) 104 | if r == 1: 105 | return random.choice(user_seq) 106 | else: 107 | return random.choice(items) 108 | 109 | def gen_target_seq(input_file, 110 | item_feat_dict_file, 111 | target_train_file, 112 | target_vali_file, 113 | target_test_file, 114 | user_seq_file, 115 | database_file, 116 | context_dict_train_file, 117 | context_dict_vali_file, 118 | context_dict_test_file): 119 | with open(item_feat_dict_file, 'rb') as f: 120 | d = pkl.load(f) 121 | items = [] 122 | for item in d.keys(): 123 | items.append(item) 124 | 125 | line_dict = {} 126 | user_seq_dict = {} 127 | context_dict_train = {} 128 | context_dict_vali = {} 129 | context_dict_test = {} 130 | 131 | with open(input_file, 'r') as f: 132 | for line in f: 133 | uid, iid, cid, did, time_stamp = line[:-1].split(',') 134 | if uid not in line_dict: 135 | line_dict[uid] = [line] 136 | user_seq_dict[uid] = [iid] 137 | else: 138 | line_dict[uid].append(line) 139 | user_seq_dict[uid].append(iid) 140 
| 141 | 142 | target_train_lines = [] 143 | target_vali_lines = [] 144 | target_test_lines = [] 145 | user_seq_lines = [] 146 | database_lines = [] 147 | 148 | for uid in user_seq_dict: 149 | if len(user_seq_dict[uid]) > 3: 150 | target_train_lines += [','.join([uid, user_seq_dict[uid][-3]]) + '\n'] 151 | target_train_lines += [','.join([uid, neg_sample(user_seq_dict[uid][:-3], items)]) + '\n'] 152 | context_dict_train[uid] = [int(line_dict[uid][-3][:-1].split(',')[-2])] 153 | 154 | target_vali_lines += [','.join([uid, user_seq_dict[uid][-2]]) + '\n'] 155 | target_vali_lines += [','.join([uid, neg_sample(user_seq_dict[uid][:-3], items)]) + '\n'] 156 | context_dict_vali[uid] = [int(line_dict[uid][-2][:-1].split(',')[-2])] 157 | 158 | target_test_lines += [','.join([uid, user_seq_dict[uid][-1]]) + '\n'] 159 | target_test_lines += [','.join([uid, neg_sample(user_seq_dict[uid][:-3], items)]) + '\n'] 160 | context_dict_test[uid] = [int(line_dict[uid][-1][:-1].split(',')[-2])] 161 | 162 | user_seq = user_seq_dict[uid][:-3] 163 | user_seq_lines += [','.join(user_seq) + '\n'] * 2 #(1 pos and 1 neg item) 164 | 165 | database_lines += line_dict[uid][:-3] 166 | 167 | with open(target_train_file, 'w') as f: 168 | f.writelines(target_train_lines) 169 | with open(target_vali_file, 'w') as f: 170 | f.writelines(target_vali_lines) 171 | with open(target_test_file, 'w') as f: 172 | f.writelines(target_test_lines) 173 | 174 | with open(user_seq_file, 'w') as f: 175 | f.writelines(user_seq_lines) 176 | with open(database_file, 'w') as f: 177 | f.writelines(database_lines) 178 | 179 | with open(context_dict_train_file, 'wb') as f: 180 | pkl.dump(context_dict_train, f) 181 | with open(context_dict_vali_file, 'wb') as f: 182 | pkl.dump(context_dict_vali, f) 183 | with open(context_dict_test_file, 'wb') as f: 184 | pkl.dump(context_dict_test, f) 185 | 186 | 187 | def insert_elastic(input_file): 188 | writer = ESWriter(input_file, 'taobao') 189 | writer.write() 190 | 191 | 192 | if __name__ == "__main__": 193 | feateng(RAW_DIR + 'UserBehavior.csv', FEATENG_DIR + 'id_remap_dict.pkl') 194 | remap_log_file(RAW_DIR + 'UserBehavior.csv', FEATENG_DIR + 'id_remap_dict.pkl', FEATENG_DIR + 'remapped_log.csv', FEATENG_DIR + 'user_seq_dict.pkl', FEATENG_DIR + 'item_feat_dict.pkl') 195 | gen_target_seq(FEATENG_DIR + 'remapped_log.csv', FEATENG_DIR + 'item_feat_dict.pkl', FEATENG_DIR + 'target_train.txt', FEATENG_DIR + 'target_vali.txt', FEATENG_DIR + 'target_test.txt', FEATENG_DIR + 'user_seq.txt', FEATENG_DIR + 'database.txt', 196 | FEATENG_DIR + 'context_dict_train.pkl', FEATENG_DIR + 'context_dict_vali.pkl', FEATENG_DIR + 'context_dict_test.pkl') 197 | insert_elastic(FEATENG_DIR + 'database.txt') 198 | 199 | -------------------------------------------------------------------------------- /code/preprocess_tmall.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import random 3 | import numpy as np 4 | import datetime 5 | import time 6 | import sys 7 | from elastic_client import * 8 | 9 | random.seed(1111) 10 | 11 | RAW_DIR = '../data/tmall/raw_data/' 12 | FEATENG_DIR = '../data/tmall/feateng_data/' 13 | ORI_FEATSIZE = 1529672 14 | 15 | def join_user_profile(user_profile_file, behavior_file, joined_file): 16 | user_profile_dict = {} 17 | with open(user_profile_file, 'r') as f: 18 | for line in f: 19 | uid, aid, gid = line[:-1].split(',') 20 | user_profile_dict[uid] = ','.join([aid, gid]) 21 | 22 | # join 23 | newlines = [] 24 | with open(behavior_file, 'r') as 
f: 25 | for line in f: 26 | uid = line[:-1].split(',')[0] 27 | user_profile = user_profile_dict[uid] 28 | newlines.append(line[:-1] + ',' + user_profile + '\n') 29 | with open(joined_file, 'w') as f: 30 | f.writelines(newlines) 31 | 32 | 33 | def feateng(joined_raw_file, remap_dict_file, user_feat_dict_file, item_feat_dict_file): 34 | uid_set = set() 35 | iid_set = set() 36 | cid_set = set() 37 | sid_set = set() 38 | bid_set = set() 39 | aid_set = set() 40 | gid_set = set() 41 | with open(joined_raw_file, 'r') as f: 42 | lines = f.readlines()[1:] 43 | for line in lines: 44 | uid, iid, cid, sid, bid, date_str, btypeid, aid, gid = line[:-1].split(',') 45 | uid_set.add(uid) 46 | iid_set.add(iid) 47 | cid_set.add(cid) 48 | sid_set.add(sid) 49 | bid_set.add(bid) 50 | aid_set.add(aid) 51 | gid_set.add(gid) 52 | date_str = '2015' + date_str 53 | time_int = int(time.mktime(datetime.datetime.strptime(date_str, "%Y%m%d").timetuple())) 54 | 55 | # remap 56 | uid_list = list(uid_set) 57 | iid_list = list(iid_set) 58 | cid_list = list(cid_set) 59 | sid_list = list(sid_set) 60 | bid_list = list(bid_set) 61 | aid_list = list(aid_set) 62 | gid_list = list(gid_set) 63 | 64 | print('user num: {}'.format(len(uid_list))) 65 | print('item num: {}'.format(len(iid_list))) 66 | print('cate num: {}'.format(len(cid_list))) 67 | print('seller num: {}'.format(len(sid_list))) 68 | print('brand num: {}'.format(len(bid_list))) 69 | print('age num: {}'.format(len(aid_list))) 70 | print('gender num: {}'.format(len(gid_list))) 71 | 72 | remap_id = 1 73 | uid_remap_dict = {} 74 | iid_remap_dict = {} 75 | cid_remap_dict = {} 76 | sid_remap_dict = {} 77 | bid_remap_dict = {} 78 | aid_remap_dict = {} 79 | gid_remap_dict = {} 80 | 81 | for uid in uid_list: 82 | uid_remap_dict[uid] = str(remap_id) 83 | remap_id += 1 84 | for iid in iid_list: 85 | iid_remap_dict[iid] = str(remap_id) 86 | remap_id += 1 87 | for cid in cid_list: 88 | cid_remap_dict[cid] = str(remap_id) 89 | remap_id += 1 90 | for sid in sid_list: 91 | sid_remap_dict[sid] = str(remap_id) 92 | remap_id += 1 93 | for bid in bid_list: 94 | bid_remap_dict[bid] = str(remap_id) 95 | remap_id += 1 96 | for aid in aid_list: 97 | aid_remap_dict[aid] = str(remap_id) 98 | remap_id += 1 99 | for gid in gid_list: 100 | gid_remap_dict[gid] = str(remap_id) 101 | remap_id += 1 102 | print('feat size: {}'.format(remap_id)) 103 | 104 | with open(remap_dict_file, 'wb') as f: 105 | pkl.dump(uid_remap_dict, f) 106 | pkl.dump(iid_remap_dict, f) 107 | pkl.dump(cid_remap_dict, f) 108 | pkl.dump(sid_remap_dict, f) 109 | pkl.dump(bid_remap_dict, f) 110 | pkl.dump(aid_remap_dict, f) 111 | pkl.dump(gid_remap_dict, f) 112 | print('remap ids completed') 113 | 114 | # remap file generate 115 | item_feat_dict = {} 116 | user_feat_dict = {} 117 | # for dummy user 118 | user_feat_dict['0'] = [0, 0] 119 | with open(joined_raw_file, 'r') as f: 120 | lines = f.readlines()[1:] 121 | for i in range(len(lines)): 122 | uid, iid, cid, sid, bid, time_stamp, btypeid, aid, gid = lines[i][:-1].split(',') 123 | uid_remap = uid_remap_dict[uid] 124 | iid_remap = iid_remap_dict[iid] 125 | cid_remap = cid_remap_dict[cid] 126 | sid_remap = sid_remap_dict[sid] 127 | bid_remap = bid_remap_dict[bid] 128 | aid_remap = aid_remap_dict[aid] 129 | gid_remap = gid_remap_dict[gid] 130 | item_feat_dict[iid_remap] = [int(cid_remap), int(sid_remap), int(bid_remap)] 131 | user_feat_dict[uid_remap] = [int(aid_remap), int(gid_remap)] 132 | print('remapped user/item feature dicts generated') 133 | 134 | 135 | with open(user_feat_dict_file, 'wb') as f: 136 | 
pkl.dump(user_feat_dict, f) 137 | print('user feat dict dump completed') 138 | with open(item_feat_dict_file, 'wb') as f: 139 | pkl.dump(item_feat_dict, f) 140 | print('item feat dict dump completed') 141 | 142 | def get_season(month): 143 | if month >= 10: 144 | return 3 145 | elif month >= 7 and month <= 9: 146 | return 2 147 | elif month >= 4 and month <= 6: 148 | return 1 149 | else: 150 | return 0 151 | 152 | def get_ud(day): 153 | if day <= 15: 154 | return 0 155 | else: 156 | return 1 157 | 158 | 159 | def remap(raw_file, remap_dict_file, remap_file): 160 | with open(remap_dict_file, 'rb') as f: 161 | uid_remap_dict = pkl.load(f) 162 | iid_remap_dict = pkl.load(f) 163 | cid_remap_dict = pkl.load(f) 164 | sid_remap_dict = pkl.load(f) 165 | bid_remap_dict = pkl.load(f) 166 | aid_remap_dict = pkl.load(f) 167 | gid_remap_dict = pkl.load(f) 168 | 169 | newlines = [] 170 | with open(raw_file, 'r') as f: 171 | lines = f.readlines()[1:] 172 | for line in lines: 173 | uid, iid, cid, sid, bid, date, _, aid, gid = line[:-1].split(',') 174 | uid = uid_remap_dict[uid] 175 | iid = iid_remap_dict[iid] 176 | cid = cid_remap_dict[cid] 177 | sid = sid_remap_dict[sid] 178 | bid = bid_remap_dict[bid] 179 | aid = aid_remap_dict[aid] 180 | gid = gid_remap_dict[gid] 181 | 182 | month = int(date[:2]) 183 | day = int(date[2:]) 184 | sea_id = str(get_season(month) + ORI_FEATSIZE) 185 | ud_id = str(get_ud(day) + ORI_FEATSIZE + 4) 186 | 187 | date = '2015' + date 188 | time_stamp = str(int(time.mktime(datetime.datetime.strptime(date, "%Y%m%d").timetuple()))) 189 | newline = ','.join([uid, aid, gid, iid, cid, sid, bid, sea_id, ud_id, time_stamp]) + '\n' 190 | newlines.append(newline) 191 | 192 | with open(remap_file, 'w') as f: 193 | f.writelines(newlines) 194 | 195 | def sort_log(log_ts_file, sorted_log_ts_file): 196 | line_dict = {} 197 | with open(log_ts_file) as f: 198 | for line in f: 199 | line_items = line[:-1].split(',') 200 | uid = line_items[0] 201 | ts = int(line_items[-1]) 202 | if uid not in line_dict: 203 | line_dict[uid] = [[line, ts]] 204 | else: 205 | line_dict[uid].append([line, ts]) 206 | 207 | for uid in line_dict: 208 | line_dict[uid].sort(key = lambda x:x[1]) 209 | print('sort complete') 210 | newlines = [] 211 | for uid in line_dict: 212 | for tup in line_dict[uid]: 213 | newlines.append(tup[0]) 214 | with open(sorted_log_ts_file, 'w') as f: 215 | f.writelines(newlines) 216 | 217 | 218 | def random_sample(min = 424171, max = 1514560): 219 | return str(random.randint(min, max)) 220 | 221 | def neg_sample(user_seq): 222 | r = random.randint(0, 1) 223 | if r == 1: 224 | return random.choice(user_seq) 225 | else: 226 | return random_sample() 227 | 228 | def gen_target_seq(input_file, 229 | target_train_file, 230 | target_vali_file, 231 | target_test_file, 232 | user_seq_file, 233 | database_file, 234 | context_dict_train_file, 235 | context_dict_vali_file, 236 | context_dict_test_file): 237 | line_dict = {} 238 | user_seq_dict = {} 239 | context_dict_train = {} 240 | context_dict_vali = {} 241 | context_dict_test = {} 242 | 243 | with open(input_file, 'r') as f: 244 | for line in f: 245 | uid, aid, gid, iid, cid, sid, bid, sea_id, ud_id, time_stamp = line[:-1].split(',') 246 | if uid not in line_dict: 247 | line_dict[uid] = [line] 248 | user_seq_dict[uid] = [iid] 249 | else: 250 | line_dict[uid].append(line) 251 | user_seq_dict[uid].append(iid) 252 | 253 | 254 | target_train_lines = [] 255 | target_vali_lines = [] 256 | target_test_lines = [] 257 | user_seq_lines = [] 258 | database_lines = 
259 | 
260 |     for uid in user_seq_dict:
261 |         if len(user_seq_dict[uid]) > 3:  # users need more than 3 behaviors: the last 3 become the targets
262 |             target_train_lines += [','.join([uid, user_seq_dict[uid][-3]]) + '\n']  # train target: 3rd-to-last item
263 |             target_train_lines += [','.join([uid, neg_sample(user_seq_dict[uid][:-3])]) + '\n']
264 |             context_dict_train[uid] = list(map(int, line_dict[uid][-3][:-1].split(',')[-3:-1]))  # context = [sea_id, ud_id] of the target record
265 | 
266 |             target_vali_lines += [','.join([uid, user_seq_dict[uid][-2]]) + '\n']
267 |             target_vali_lines += [','.join([uid, neg_sample(user_seq_dict[uid][:-3])]) + '\n']
268 |             context_dict_vali[uid] = list(map(int, line_dict[uid][-2][:-1].split(',')[-3:-1]))
269 | 
270 |             target_test_lines += [','.join([uid, user_seq_dict[uid][-1]]) + '\n']
271 |             target_test_lines += [','.join([uid, neg_sample(user_seq_dict[uid][:-3])]) + '\n']
272 |             context_dict_test[uid] = list(map(int, line_dict[uid][-1][:-1].split(',')[-3:-1]))
273 | 
274 |             user_seq = user_seq_dict[uid][:-3]
275 |             user_seq_lines += [','.join(user_seq) + '\n'] * 2  # written twice: once for the positive target, once for the negative
276 | 
277 |             database_lines += line_dict[uid][:-3]  # everything before the last 3 behaviors goes into the retrieval database
278 | 
279 |     with open(target_train_file, 'w') as f:
280 |         f.writelines(target_train_lines)
281 |     with open(target_vali_file, 'w') as f:
282 |         f.writelines(target_vali_lines)
283 |     with open(target_test_file, 'w') as f:
284 |         f.writelines(target_test_lines)
285 | 
286 |     with open(user_seq_file, 'w') as f:
287 |         f.writelines(user_seq_lines)
288 |     with open(database_file, 'w') as f:
289 |         f.writelines(database_lines)
290 | 
291 |     with open(context_dict_train_file, 'wb') as f:
292 |         pkl.dump(context_dict_train, f)
293 |     with open(context_dict_vali_file, 'wb') as f:
294 |         pkl.dump(context_dict_vali, f)
295 |     with open(context_dict_test_file, 'wb') as f:
296 |         pkl.dump(context_dict_test, f)
297 | 
298 | def insert_elastic(input_file):
299 |     writer = ESWriter(input_file, 'tmall')
300 |     writer.write()
301 | 
302 | if __name__ == "__main__":
303 |     join_user_profile(RAW_DIR + 'user_info_format1.csv', RAW_DIR + 'user_log_format1.csv', FEATENG_DIR + 'joined_user_behavior.csv')
304 |     feateng(FEATENG_DIR + 'joined_user_behavior.csv', FEATENG_DIR + 'remap_dict.pkl', FEATENG_DIR + 'user_feat_dict.pkl', FEATENG_DIR + 'item_feat_dict.pkl')
305 |     remap(FEATENG_DIR + 'joined_user_behavior.csv', FEATENG_DIR + 'remap_dict.pkl', FEATENG_DIR + 'remap_joined_user_behavior.csv')
306 |     sort_log(FEATENG_DIR + 'remap_joined_user_behavior.csv', FEATENG_DIR + 'sorted_remap_joined_user_behavior.csv')
307 |     gen_target_seq(FEATENG_DIR + 'sorted_remap_joined_user_behavior.csv',
308 |             FEATENG_DIR + 'target_train.txt', FEATENG_DIR + 'target_vali.txt', FEATENG_DIR + 'target_test.txt', FEATENG_DIR + 'user_seq.txt', FEATENG_DIR + 'database.txt',
309 |             FEATENG_DIR + 'context_dict_train.pkl', FEATENG_DIR + 'context_dict_vali.pkl', FEATENG_DIR + 'context_dict_test.pkl')
310 |     insert_elastic(FEATENG_DIR + 'database.txt')
311 | 
--------------------------------------------------------------------------------
/code/rec.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.ops.rnn_cell import GRUCell
3 | import numpy as np
4 | 
5 | class RecBase(object):
6 |     def __init__(self, feature_size, eb_dim, hidden_size, b_num, record_fnum, emb_initializer):
7 |         # input placeholders
8 |         with tf.name_scope('rec/inputs'):
9 |             self.seq_ph = tf.placeholder(tf.int32, [None, b_num, record_fnum], name='seq_ph')
10 |             self.seq_length_ph = tf.placeholder(tf.int32, [None,], name='seq_length_ph')
11 |             self.target_ph = tf.placeholder(tf.int32, [None, record_fnum], name='target_ph')
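            # Shape conventions (B = batch size): seq_ph holds the b_num retrieved
            # behavior records, each a vector of record_fnum feature ids;
            # target_ph is the single target record to be scored.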
12 |             self.label_ph = tf.placeholder(tf.int32, [None,], name='label_ph')
13 | 
14 |         # lr
15 |         self.lr = tf.placeholder(tf.float32, [])
16 |         # reg lambda
17 |         self.reg_lambda = tf.placeholder(tf.float32, [])
18 |         # keep prob
19 |         self.keep_prob = tf.placeholder(tf.float32, [])
20 | 
21 |         # embedding
22 |         with tf.variable_scope('embedding', reuse=tf.AUTO_REUSE):
23 |             if emb_initializer is not None:
24 |                 self.emb_mtx = tf.get_variable('emb_mtx', initializer=emb_initializer)
25 |             else:
26 |                 self.emb_mtx = tf.get_variable('emb_mtx', [feature_size, eb_dim], initializer=tf.truncated_normal_initializer)
27 |             self.emb_mtx_mask = tf.constant(value=1., shape=[feature_size - 1, eb_dim])
28 |             self.emb_mtx_mask = tf.concat([tf.constant(value=0., shape=[1, eb_dim]), self.emb_mtx_mask], axis=0)
29 |             self.emb_mtx = self.emb_mtx * self.emb_mtx_mask  # zero out row 0 so padding ids embed to zero vectors
30 | 
31 |         self.seq = tf.nn.embedding_lookup(self.emb_mtx, self.seq_ph)
32 |         self.seq = tf.reshape(self.seq, [-1, b_num, record_fnum * eb_dim])
33 |         self.target = tf.nn.embedding_lookup(self.emb_mtx, self.target_ph)
34 |         self.target = tf.reshape(self.target, [-1, record_fnum * eb_dim])
35 | 
36 | 
37 |     def build_fc_net(self, inp):
38 |         bn1 = tf.layers.batch_normalization(inputs=inp, name='rec_bn1')  # note: the training flag is left at its default (False), so the never-updated moving statistics are used
39 |         fc1 = tf.layers.dense(bn1, 200, activation=tf.nn.relu, name='rec_fc1')
40 |         dp1 = tf.nn.dropout(fc1, self.keep_prob, name='rec_dp1')
41 |         fc2 = tf.layers.dense(dp1, 80, activation=tf.nn.relu, name='rec_fc2')
42 |         dp2 = tf.nn.dropout(fc2, self.keep_prob, name='rec_dp2')
43 |         fc3 = tf.layers.dense(dp2, 2, activation=None, name='rec_fc3')
44 |         score = tf.nn.softmax(fc3)
45 |         # output
46 |         self.y_pred = tf.reshape(score[:,0], [-1,])  # class index 0 is treated as the positive (click) class
47 | 
48 |     def build_reward(self):
49 |         # RIG (relative information gain) of the prediction is used as the reward
50 |         self.ground_truth = tf.cast(self.label_ph, tf.float32)
51 |         self.reward = self.ground_truth * tf.log(tf.clip_by_value(self.y_pred, 1e-10, 1)) + (1 - self.ground_truth) * tf.log(1 - tf.clip_by_value(self.y_pred, 1e-10, 1))  # per-example log-likelihood
52 |         self.reward = 1 - (self.reward / tf.log(0.5))  # use RIG as reward signal: 1 for a perfect prediction, 0 at the p=0.5 baseline
53 |         self.edge = -tf.ones_like(self.reward)
54 |         self.reward = tf.where(self.reward < -1, self.edge, self.reward)  # floor the reward at -1
55 | 
56 |     def build_logloss(self):
57 |         # loss
58 |         self.log_loss = tf.losses.log_loss(self.label_ph, self.y_pred)
59 |         self.loss = self.log_loss
60 |         for v in tf.trainable_variables():
61 |             if 'bias' not in v.name and 'emb' not in v.name:
62 |                 self.loss += self.reg_lambda * tf.nn.l2_loss(v)
63 | 
64 |     def build_optimizer(self):
65 |         # optimizer and training step
66 |         self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, name='rec_optimizer')
67 |         self.train_step = self.optimizer.minimize(self.loss)
68 | 
69 | 
70 |     def train(self, sess, batch_data, lr, reg_lambda):
71 |         loss, _ = sess.run([self.loss, self.train_step], feed_dict = {
72 |             self.seq_ph : batch_data[0],
73 |             self.seq_length_ph : batch_data[1],
74 |             self.target_ph : batch_data[2],
75 |             self.label_ph : batch_data[3],
76 |             self.lr : lr,
77 |             self.reg_lambda : reg_lambda,
78 |             self.keep_prob : 0.8  # dropout keep probability used during training
79 |         })
80 |         return loss
81 | 
82 |     def eval(self, sess, batch_data, reg_lambda):
83 |         pred, label, loss = sess.run([self.y_pred, self.label_ph, self.loss], feed_dict = {
84 |             self.seq_ph : batch_data[0],
85 |             self.seq_length_ph : batch_data[1],
86 |             self.target_ph : batch_data[2],
87 |             self.label_ph : batch_data[3],
88 |             self.reg_lambda : reg_lambda,
89 |             self.keep_prob : 1.
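            # keep_prob = 1.0: dropout is disabled at evaluation time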
90 | }) 91 | 92 | return pred.reshape([-1,]).tolist(), label.reshape([-1,]).tolist(), loss 93 | 94 | def get_reward(self, sess, batch_data): 95 | reward = sess.run(self.reward, feed_dict = { 96 | self.seq_ph : batch_data[0], 97 | self.seq_length_ph : batch_data[1], 98 | self.target_ph : batch_data[2], 99 | self.label_ph : batch_data[3], 100 | self.keep_prob : 1. 101 | }) 102 | return np.reshape(reward, [-1, 1]) #[B,1] 103 | 104 | 105 | def save(self, sess, path): 106 | saver = tf.train.Saver() 107 | saver.save(sess, save_path=path) 108 | 109 | def restore(self, sess, path): 110 | saver = tf.train.Saver() 111 | saver.restore(sess, save_path=path) 112 | print('model restored from {}'.format(path)) 113 | 114 | class RecSum(RecBase): 115 | def __init__(self, feature_size, eb_dim, hidden_size, b_num, record_fnum, emb_initializer): 116 | super(RecSum, self).__init__(feature_size, eb_dim, hidden_size, b_num, record_fnum, emb_initializer) 117 | 118 | # use sum pooling to model the user behaviors, padding is zero (embedding id is also zero) 119 | user_behavior_rep = tf.reduce_sum(self.seq, axis=1) 120 | 121 | inp = tf.concat([user_behavior_rep, self.target], axis=1) 122 | 123 | # fc layer 124 | self.build_fc_net(inp) 125 | self.build_reward() 126 | self.build_logloss() 127 | self.build_optimizer() 128 | 129 | class RecAtt(RecBase): 130 | def __init__(self, feature_size, eb_dim, hidden_size, b_num, record_fnum, emb_initializer): 131 | super(RecAtt, self).__init__(feature_size, eb_dim, hidden_size, b_num, record_fnum, emb_initializer) 132 | mask = tf.sequence_mask(self.seq_length_ph, b_num, dtype=tf.float32) 133 | self.atten, user_behavior_rep = self.attention(self.seq, self.seq, self.target, mask) 134 | self.atten = tf.reshape(self.atten, [-1, b_num]) 135 | inp = tf.concat([user_behavior_rep, self.target], axis=1) 136 | 137 | # fc layer 138 | self.build_fc_net(inp) 139 | self.build_reward() 140 | self.build_logloss() 141 | self.build_optimizer() 142 | 143 | 144 | def attention(self, key, value, query, mask): 145 | # key: [B, T, Dk], query: [B, Dq], mask: [B, T] 146 | _, max_len, k_dim = key.get_shape().as_list() 147 | query = tf.layers.dense(query, k_dim, activation=None) 148 | queries = tf.tile(tf.expand_dims(query, 1), [1, max_len, 1]) # [B, T, Dk] 149 | kq_inter = queries * key 150 | atten = tf.reduce_sum(kq_inter, axis=2) 151 | 152 | mask = tf.equal(mask, tf.ones_like(mask)) #[B, T] 153 | paddings = tf.ones_like(atten) * (-2 ** 32 + 1) 154 | atten = tf.nn.softmax(tf.where(mask, atten, paddings)) #[B, T] 155 | atten = tf.expand_dims(atten, 2) 156 | 157 | res = tf.reduce_sum(atten * value, axis=1) 158 | return atten, res 159 | 160 | -------------------------------------------------------------------------------- /code/rnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """RNN helpers for TensorFlow models. 17 | 18 | 19 | @@bidirectional_dynamic_rnn 20 | @@dynamic_rnn 21 | @@raw_rnn 22 | @@static_rnn 23 | @@static_state_saving_rnn 24 | @@static_bidirectional_rnn 25 | """ 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | from tensorflow.python.framework import constant_op 31 | from tensorflow.python.framework import dtypes 32 | from tensorflow.python.framework import ops 33 | from tensorflow.python.framework import tensor_shape 34 | from tensorflow.python.ops import array_ops 35 | from tensorflow.python.ops import control_flow_ops 36 | from tensorflow.python.ops import math_ops 37 | from tensorflow.python.ops import rnn_cell_impl 38 | from tensorflow.python.ops import tensor_array_ops 39 | from tensorflow.python.ops import variable_scope as vs 40 | from tensorflow.python.util import nest 41 | 42 | 43 | # pylint: disable=protected-access 44 | _concat = rnn_cell_impl._concat 45 | _like_rnncell = rnn_cell_impl._like_rnncell 46 | # pylint: enable=protected-access 47 | 48 | 49 | def _transpose_batch_time(x): 50 | """Transpose the batch and time dimensions of a Tensor. 51 | 52 | Retains as much of the static shape information as possible. 53 | 54 | Args: 55 | x: A tensor of rank 2 or higher. 56 | 57 | Returns: 58 | x transposed along the first two dimensions. 59 | 60 | Raises: 61 | ValueError: if `x` is rank 1 or lower. 62 | """ 63 | x_static_shape = x.get_shape() 64 | if x_static_shape.ndims is not None and x_static_shape.ndims < 2: 65 | raise ValueError( 66 | "Expected input tensor %s to have rank at least 2, but saw shape: %s" % 67 | (x, x_static_shape)) 68 | x_rank = array_ops.rank(x) 69 | x_t = array_ops.transpose( 70 | x, array_ops.concat( 71 | ([1, 0], math_ops.range(2, x_rank)), axis=0)) 72 | x_t.set_shape( 73 | tensor_shape.TensorShape([ 74 | x_static_shape[1].value, x_static_shape[0].value 75 | ]).concatenate(x_static_shape[2:])) 76 | return x_t 77 | 78 | 79 | def _best_effort_input_batch_size(flat_input): 80 | """Get static input batch size if available, with fallback to the dynamic one. 81 | 82 | Args: 83 | flat_input: An iterable of time major input Tensors of shape [max_time, 84 | batch_size, ...]. All inputs should have compatible batch sizes. 85 | 86 | Returns: 87 | The batch size in Python integer if available, or a scalar Tensor otherwise. 88 | 89 | Raises: 90 | ValueError: if there is any input with an invalid shape. 91 | """ 92 | for input_ in flat_input: 93 | shape = input_.shape 94 | if shape.ndims is None: 95 | continue 96 | if shape.ndims < 2: 97 | raise ValueError( 98 | "Expected input tensor %s to have rank at least 2" % input_) 99 | batch_size = shape[1].value 100 | if batch_size is not None: 101 | return batch_size 102 | # Fallback to the dynamic batch size of the first input. 103 | return array_ops.shape(flat_input[0])[1] 104 | 105 | 106 | def _infer_state_dtype(explicit_dtype, state): 107 | """Infer the dtype of an RNN state. 108 | 109 | Args: 110 | explicit_dtype: explicitly declared dtype or None. 111 | state: RNN's hidden state. Must be a Tensor or a nested iterable containing 112 | Tensors. 113 | 114 | Returns: 115 | dtype: inferred dtype of hidden state. 116 | 117 | Raises: 118 | ValueError: if `state` has heterogeneous dtypes or is empty. 
119 | """ 120 | if explicit_dtype is not None: 121 | return explicit_dtype 122 | elif nest.is_sequence(state): 123 | inferred_dtypes = [element.dtype for element in nest.flatten(state)] 124 | if not inferred_dtypes: 125 | raise ValueError("Unable to infer dtype from empty state.") 126 | all_same = all([x == inferred_dtypes[0] for x in inferred_dtypes]) 127 | if not all_same: 128 | raise ValueError( 129 | "State has tensors of different inferred_dtypes. Unable to infer a " 130 | "single representative dtype.") 131 | return inferred_dtypes[0] 132 | else: 133 | return state.dtype 134 | 135 | 136 | # pylint: disable=unused-argument 137 | def _rnn_step( 138 | time, sequence_length, min_sequence_length, max_sequence_length, 139 | zero_output, state, call_cell, state_size, skip_conditionals=False): 140 | """Calculate one step of a dynamic RNN minibatch. 141 | 142 | Returns an (output, state) pair conditioned on the sequence_lengths. 143 | When skip_conditionals=False, the pseudocode is something like: 144 | 145 | if t >= max_sequence_length: 146 | return (zero_output, state) 147 | if t < min_sequence_length: 148 | return call_cell() 149 | 150 | # Selectively output zeros or output, old state or new state depending 151 | # on if we've finished calculating each row. 152 | new_output, new_state = call_cell() 153 | final_output = np.vstack([ 154 | zero_output if time >= sequence_lengths[r] else new_output_r 155 | for r, new_output_r in enumerate(new_output) 156 | ]) 157 | final_state = np.vstack([ 158 | state[r] if time >= sequence_lengths[r] else new_state_r 159 | for r, new_state_r in enumerate(new_state) 160 | ]) 161 | return (final_output, final_state) 162 | 163 | Args: 164 | time: Python int, the current time step 165 | sequence_length: int32 `Tensor` vector of size [batch_size] 166 | min_sequence_length: int32 `Tensor` scalar, min of sequence_length 167 | max_sequence_length: int32 `Tensor` scalar, max of sequence_length 168 | zero_output: `Tensor` vector of shape [output_size] 169 | state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`, 170 | or a list/tuple of such tensors. 171 | call_cell: lambda returning tuple of (new_output, new_state) where 172 | new_output is a `Tensor` matrix of shape `[batch_size, output_size]`. 173 | new_state is a `Tensor` matrix of shape `[batch_size, state_size]`. 174 | state_size: The `cell.state_size` associated with the state. 175 | skip_conditionals: Python bool, whether to skip using the conditional 176 | calculations. This is useful for `dynamic_rnn`, where the input tensor 177 | matches `max_sequence_length`, and using conditionals just slows 178 | everything down. 179 | 180 | Returns: 181 | A tuple of (`final_output`, `final_state`) as given by the pseudocode above: 182 | final_output is a `Tensor` matrix of shape [batch_size, output_size] 183 | final_state is either a single `Tensor` matrix, or a tuple of such 184 | matrices (matching length and shapes of input `state`). 185 | 186 | Raises: 187 | ValueError: If the cell returns a state tuple whose length does not match 188 | that returned by `state_size`. 189 | """ 190 | 191 | # Convert state to a list for ease of use 192 | flat_state = nest.flatten(state) 193 | flat_zero_output = nest.flatten(zero_output) 194 | 195 | def _copy_one_through(output, new_output): 196 | # If the state contains a scalar value we simply pass it through. 
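    # Illustrative example (not from the upstream source): at time = 2 with
    # sequence_length = [2, 4], copy_cond = [True, False], so row 0 keeps its
    # previous output/state while row 1 takes the freshly computed one.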
197 | if output.shape.ndims == 0: 198 | return new_output 199 | copy_cond = (time >= sequence_length) 200 | with ops.colocate_with(new_output): 201 | return array_ops.where(copy_cond, output, new_output) 202 | 203 | def _copy_some_through(flat_new_output, flat_new_state): 204 | # Use broadcasting select to determine which values should get 205 | # the previous state & zero output, and which values should get 206 | # a calculated state & output. 207 | flat_new_output = [ 208 | _copy_one_through(zero_output, new_output) 209 | for zero_output, new_output in zip(flat_zero_output, flat_new_output)] 210 | flat_new_state = [ 211 | _copy_one_through(state, new_state) 212 | for state, new_state in zip(flat_state, flat_new_state)] 213 | return flat_new_output + flat_new_state 214 | 215 | def _maybe_copy_some_through(): 216 | """Run RNN step. Pass through either no or some past state.""" 217 | new_output, new_state = call_cell() 218 | 219 | nest.assert_same_structure(state, new_state) 220 | 221 | flat_new_state = nest.flatten(new_state) 222 | flat_new_output = nest.flatten(new_output) 223 | return control_flow_ops.cond( 224 | # if t < min_seq_len: calculate and return everything 225 | time < min_sequence_length, lambda: flat_new_output + flat_new_state, 226 | # else copy some of it through 227 | lambda: _copy_some_through(flat_new_output, flat_new_state)) 228 | 229 | # TODO(ebrevdo): skipping these conditionals may cause a slowdown, 230 | # but benefits from removing cond() and its gradient. We should 231 | # profile with and without this switch here. 232 | if skip_conditionals: 233 | # Instead of using conditionals, perform the selective copy at all time 234 | # steps. This is faster when max_seq_len is equal to the number of unrolls 235 | # (which is typical for dynamic_rnn). 236 | new_output, new_state = call_cell() 237 | nest.assert_same_structure(state, new_state) 238 | new_state = nest.flatten(new_state) 239 | new_output = nest.flatten(new_output) 240 | final_output_and_state = _copy_some_through(new_output, new_state) 241 | else: 242 | empty_update = lambda: flat_zero_output + flat_state 243 | final_output_and_state = control_flow_ops.cond( 244 | # if t >= max_seq_len: copy all state through, output zeros 245 | time >= max_sequence_length, empty_update, 246 | # otherwise calculation is required: copy some or all of it through 247 | _maybe_copy_some_through) 248 | 249 | if len(final_output_and_state) != len(flat_zero_output) + len(flat_state): 250 | raise ValueError("Internal error: state and output were not concatenated " 251 | "correctly.") 252 | final_output = final_output_and_state[:len(flat_zero_output)] 253 | final_state = final_output_and_state[len(flat_zero_output):] 254 | 255 | for output, flat_output in zip(final_output, flat_zero_output): 256 | output.set_shape(flat_output.get_shape()) 257 | for substate, flat_substate in zip(final_state, flat_state): 258 | substate.set_shape(flat_substate.get_shape()) 259 | 260 | final_output = nest.pack_sequence_as( 261 | structure=zero_output, flat_sequence=final_output) 262 | final_state = nest.pack_sequence_as( 263 | structure=state, flat_sequence=final_state) 264 | 265 | return final_output, final_state 266 | 267 | 268 | def _reverse_seq(input_seq, lengths): 269 | """Reverse a list of Tensors up to specified lengths. 270 | 271 | Args: 272 | input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features) 273 | or nested tuples of tensors. 
274 | lengths: A `Tensor` of dimension batch_size, containing lengths for each 275 | sequence in the batch. If "None" is specified, simply reverses 276 | the list. 277 | 278 | Returns: 279 | time-reversed sequence 280 | """ 281 | if lengths is None: 282 | return list(reversed(input_seq)) 283 | 284 | flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq) 285 | 286 | flat_results = [[] for _ in range(len(input_seq))] 287 | for sequence in zip(*flat_input_seq): 288 | input_shape = tensor_shape.unknown_shape( 289 | ndims=sequence[0].get_shape().ndims) 290 | for input_ in sequence: 291 | input_shape.merge_with(input_.get_shape()) 292 | input_.set_shape(input_shape) 293 | 294 | # Join into (time, batch_size, depth) 295 | s_joined = array_ops.stack(sequence) 296 | 297 | # Reverse along dimension 0 298 | s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1) 299 | # Split again into list 300 | result = array_ops.unstack(s_reversed) 301 | for r, flat_result in zip(result, flat_results): 302 | r.set_shape(input_shape) 303 | flat_result.append(r) 304 | 305 | results = [nest.pack_sequence_as(structure=input_, flat_sequence=flat_result) 306 | for input_, flat_result in zip(input_seq, flat_results)] 307 | return results 308 | 309 | 310 | def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, 311 | initial_state_fw=None, initial_state_bw=None, 312 | dtype=None, parallel_iterations=None, 313 | swap_memory=False, time_major=False, scope=None): 314 | """Creates a dynamic version of bidirectional recurrent neural network. 315 | 316 | Takes input and builds independent forward and backward RNNs. The input_size 317 | of forward and backward cell must match. The initial state for both directions 318 | is zero by default (but can be set optionally) and no intermediate states are 319 | ever returned -- the network is fully unrolled for the given (passed in) 320 | length(s) of the sequence(s) or completely unrolled if length(s) is not 321 | given. 322 | 323 | Args: 324 | cell_fw: An instance of RNNCell, to be used for forward direction. 325 | cell_bw: An instance of RNNCell, to be used for backward direction. 326 | inputs: The RNN inputs. 327 | If time_major == False (default), this must be a tensor of shape: 328 | `[batch_size, max_time, ...]`, or a nested tuple of such elements. 329 | If time_major == True, this must be a tensor of shape: 330 | `[max_time, batch_size, ...]`, or a nested tuple of such elements. 331 | sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 332 | containing the actual lengths for each of the sequences in the batch. 333 | If not provided, all batch entries are assumed to be full sequences; and 334 | time reversal is applied from time `0` to `max_time` for each sequence. 335 | initial_state_fw: (optional) An initial state for the forward RNN. 336 | This must be a tensor of appropriate type and shape 337 | `[batch_size, cell_fw.state_size]`. 338 | If `cell_fw.state_size` is a tuple, this should be a tuple of 339 | tensors having shapes `[batch_size, s] for s in cell_fw.state_size`. 340 | initial_state_bw: (optional) Same as for `initial_state_fw`, but using 341 | the corresponding properties of `cell_bw`. 342 | dtype: (optional) The data type for the initial states and expected output. 343 | Required if initial_states are not provided or RNN states have a 344 | heterogeneous dtype. 345 | parallel_iterations: (Default: 32). The number of iterations to run in 346 | parallel. 
Those operations which do not have any temporal dependency 347 | and can be run in parallel, will be. This parameter trades off 348 | time for space. Values >> 1 use more memory but take less time, 349 | while smaller values use less memory but computations take longer. 350 | swap_memory: Transparently swap the tensors produced in forward inference 351 | but needed for back prop from GPU to CPU. This allows training RNNs 352 | which would typically not fit on a single GPU, with very minimal (or no) 353 | performance penalty. 354 | time_major: The shape format of the `inputs` and `outputs` Tensors. 355 | If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. 356 | If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 357 | Using `time_major = True` is a bit more efficient because it avoids 358 | transposes at the beginning and end of the RNN calculation. However, 359 | most TensorFlow data is batch-major, so by default this function 360 | accepts input and emits output in batch-major form. 361 | scope: VariableScope for the created subgraph; defaults to 362 | "bidirectional_rnn" 363 | 364 | Returns: 365 | A tuple (outputs, output_states) where: 366 | outputs: A tuple (output_fw, output_bw) containing the forward and 367 | the backward rnn output `Tensor`. 368 | If time_major == False (default), 369 | output_fw will be a `Tensor` shaped: 370 | `[batch_size, max_time, cell_fw.output_size]` 371 | and output_bw will be a `Tensor` shaped: 372 | `[batch_size, max_time, cell_bw.output_size]`. 373 | If time_major == True, 374 | output_fw will be a `Tensor` shaped: 375 | `[max_time, batch_size, cell_fw.output_size]` 376 | and output_bw will be a `Tensor` shaped: 377 | `[max_time, batch_size, cell_bw.output_size]`. 378 | It returns a tuple instead of a single concatenated `Tensor`, unlike 379 | in the `bidirectional_rnn`. If the concatenated one is preferred, 380 | the forward and backward outputs can be concatenated as 381 | `tf.concat(outputs, 2)`. 382 | output_states: A tuple (output_state_fw, output_state_bw) containing 383 | the forward and the backward final states of bidirectional rnn. 384 | 385 | Raises: 386 | TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 
387 | """ 388 | 389 | if not _like_rnncell(cell_fw): 390 | raise TypeError("cell_fw must be an instance of RNNCell") 391 | if not _like_rnncell(cell_bw): 392 | raise TypeError("cell_bw must be an instance of RNNCell") 393 | 394 | with vs.variable_scope(scope or "bidirectional_rnn"): 395 | # Forward direction 396 | with vs.variable_scope("fw") as fw_scope: 397 | output_fw, output_state_fw = dynamic_rnn( 398 | cell=cell_fw, inputs=inputs, sequence_length=sequence_length, 399 | initial_state=initial_state_fw, dtype=dtype, 400 | parallel_iterations=parallel_iterations, swap_memory=swap_memory, 401 | time_major=time_major, scope=fw_scope) 402 | 403 | # Backward direction 404 | if not time_major: 405 | time_dim = 1 406 | batch_dim = 0 407 | else: 408 | time_dim = 0 409 | batch_dim = 1 410 | 411 | def _reverse(input_, seq_lengths, seq_dim, batch_dim): 412 | if seq_lengths is not None: 413 | return array_ops.reverse_sequence( 414 | input=input_, seq_lengths=seq_lengths, 415 | seq_dim=seq_dim, batch_dim=batch_dim) 416 | else: 417 | return array_ops.reverse(input_, axis=[seq_dim]) 418 | 419 | with vs.variable_scope("bw") as bw_scope: 420 | inputs_reverse = _reverse( 421 | inputs, seq_lengths=sequence_length, 422 | seq_dim=time_dim, batch_dim=batch_dim) 423 | tmp, output_state_bw = dynamic_rnn( 424 | cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length, 425 | initial_state=initial_state_bw, dtype=dtype, 426 | parallel_iterations=parallel_iterations, swap_memory=swap_memory, 427 | time_major=time_major, scope=bw_scope) 428 | 429 | output_bw = _reverse( 430 | tmp, seq_lengths=sequence_length, 431 | seq_dim=time_dim, batch_dim=batch_dim) 432 | 433 | outputs = (output_fw, output_bw) 434 | output_states = (output_state_fw, output_state_bw) 435 | 436 | return (outputs, output_states) 437 | 438 | 439 | def dynamic_rnn(cell, inputs, att_scores=None, sequence_length=None, initial_state=None, 440 | dtype=None, parallel_iterations=None, swap_memory=False, 441 | time_major=False, scope=None): 442 | """Creates a recurrent neural network specified by RNNCell `cell`. 443 | 444 | Performs fully dynamic unrolling of `inputs`. 445 | 446 | Example: 447 | 448 | ```python 449 | # create a BasicRNNCell 450 | rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) 451 | 452 | # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size] 453 | 454 | # defining initial state 455 | initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32) 456 | 457 | # 'state' is a tensor of shape [batch_size, cell_state_size] 458 | outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data, 459 | initial_state=initial_state, 460 | dtype=tf.float32) 461 | ``` 462 | 463 | ```python 464 | # create 2 LSTMCells 465 | rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 466 | 467 | # create a RNN cell composed sequentially of a number of RNNCells 468 | multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers) 469 | 470 | # 'outputs' is a tensor of shape [batch_size, max_time, 256] 471 | # 'state' is a N-tuple where N is the number of LSTMCells containing a 472 | # tf.contrib.rnn.LSTMStateTuple for each cell 473 | outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell, 474 | inputs=data, 475 | dtype=tf.float32) 476 | ``` 477 | 478 | 479 | Args: 480 | cell: An instance of RNNCell. 481 | inputs: The RNN inputs. 482 | If `time_major == False` (default), this must be a `Tensor` of shape: 483 | `[batch_size, max_time, ...]`, or a nested tuple of such 484 | elements. 
485 | If `time_major == True`, this must be a `Tensor` of shape: 486 | `[max_time, batch_size, ...]`, or a nested tuple of such 487 | elements. 488 | This may also be a (possibly nested) tuple of Tensors satisfying 489 | this property. The first two dimensions must match across all the inputs, 490 | but otherwise the ranks and other shape components may differ. 491 | In this case, input to `cell` at each time-step will replicate the 492 | structure of these tuples, except for the time dimension (from which the 493 | time is taken). 494 | The input to `cell` at each time step will be a `Tensor` or (possibly 495 | nested) tuple of Tensors each with dimensions `[batch_size, ...]`. 496 | sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. 497 | Used to copy-through state and zero-out outputs when past a batch 498 | element's sequence length. So it's more for correctness than performance. 499 | initial_state: (optional) An initial state for the RNN. 500 | If `cell.state_size` is an integer, this must be 501 | a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 502 | If `cell.state_size` is a tuple, this should be a tuple of 503 | tensors having shapes `[batch_size, s] for s in cell.state_size`. 504 | dtype: (optional) The data type for the initial state and expected output. 505 | Required if initial_state is not provided or RNN state has a heterogeneous 506 | dtype. 507 | parallel_iterations: (Default: 32). The number of iterations to run in 508 | parallel. Those operations which do not have any temporal dependency 509 | and can be run in parallel, will be. This parameter trades off 510 | time for space. Values >> 1 use more memory but take less time, 511 | while smaller values use less memory but computations take longer. 512 | swap_memory: Transparently swap the tensors produced in forward inference 513 | but needed for back prop from GPU to CPU. This allows training RNNs 514 | which would typically not fit on a single GPU, with very minimal (or no) 515 | performance penalty. 516 | time_major: The shape format of the `inputs` and `outputs` Tensors. 517 | If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. 518 | If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 519 | Using `time_major = True` is a bit more efficient because it avoids 520 | transposes at the beginning and end of the RNN calculation. However, 521 | most TensorFlow data is batch-major, so by default this function 522 | accepts input and emits output in batch-major form. 523 | scope: VariableScope for the created subgraph; defaults to "rnn". 524 | 525 | Returns: 526 | A pair (outputs, state) where: 527 | 528 | outputs: The RNN output `Tensor`. 529 | 530 | If time_major == False (default), this will be a `Tensor` shaped: 531 | `[batch_size, max_time, cell.output_size]`. 532 | 533 | If time_major == True, this will be a `Tensor` shaped: 534 | `[max_time, batch_size, cell.output_size]`. 535 | 536 | Note, if `cell.output_size` is a (possibly nested) tuple of integers 537 | or `TensorShape` objects, then `outputs` will be a tuple having the 538 | same structure as `cell.output_size`, containing Tensors having shapes 539 | corresponding to the shape data in `cell.output_size`. 540 | 541 | state: The final state. If `cell.state_size` is an int, this 542 | will be shaped `[batch_size, cell.state_size]`. If it is a 543 | `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 
544 | If it is a (possibly nested) tuple of ints or `TensorShape`, this will 545 | be a tuple having the corresponding shapes. If cells are `LSTMCells` 546 | `state` will be a tuple containing a `LSTMStateTuple` for each cell. 547 | 548 | Raises: 549 | TypeError: If `cell` is not an instance of RNNCell. 550 | ValueError: If inputs is None or an empty list. 551 | """ 552 | if not _like_rnncell(cell): 553 | raise TypeError("cell must be an instance of RNNCell") 554 | 555 | # By default, time_major==False and inputs are batch-major: shaped 556 | # [batch, time, depth] 557 | # For internal calculations, we transpose to [time, batch, depth] 558 | flat_input = nest.flatten(inputs) 559 | 560 | if not time_major: 561 | # (B,T,D) => (T,B,D) 562 | flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] 563 | flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) 564 | 565 | parallel_iterations = parallel_iterations or 32 566 | if sequence_length is not None: 567 | sequence_length = math_ops.to_int32(sequence_length) 568 | if sequence_length.get_shape().ndims not in (None, 1): 569 | raise ValueError( 570 | "sequence_length must be a vector of length batch_size, " 571 | "but saw shape: %s" % sequence_length.get_shape()) 572 | sequence_length = array_ops.identity( # Just to find it in the graph. 573 | sequence_length, name="sequence_length") 574 | 575 | # Create a new scope in which the caching device is either 576 | # determined by the parent scope, or is set to place the cached 577 | # Variable using the same placement as for the rest of the RNN. 578 | with vs.variable_scope(scope or "rnn") as varscope: 579 | if varscope.caching_device is None: 580 | varscope.set_caching_device(lambda op: op.device) 581 | batch_size = _best_effort_input_batch_size(flat_input) 582 | 583 | if initial_state is not None: 584 | state = initial_state 585 | else: 586 | if not dtype: 587 | raise ValueError("If there is no initial_state, you must give a dtype.") 588 | state = cell.zero_state(batch_size, dtype) 589 | 590 | def _assert_has_shape(x, shape): 591 | x_shape = array_ops.shape(x) 592 | packed_shape = array_ops.stack(shape) 593 | return control_flow_ops.Assert( 594 | math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), 595 | ["Expected shape for Tensor %s is " % x.name, 596 | packed_shape, " but saw shape: ", x_shape]) 597 | 598 | if sequence_length is not None: 599 | # Perform some shape validation 600 | with ops.control_dependencies( 601 | [_assert_has_shape(sequence_length, [batch_size])]): 602 | sequence_length = array_ops.identity( 603 | sequence_length, name="CheckSeqLen") 604 | 605 | inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) 606 | 607 | (outputs, final_state) = _dynamic_rnn_loop( 608 | cell, 609 | inputs, 610 | state, 611 | parallel_iterations=parallel_iterations, 612 | swap_memory=swap_memory, 613 | att_scores = att_scores, 614 | sequence_length=sequence_length, 615 | dtype=dtype) 616 | 617 | # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 
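  # Hedged usage sketch of the non-standard `att_scores` argument added to this
  # copy of dynamic_rnn (it assumes a cell whose __call__ accepts an extra
  # per-step attention score, e.g. a DIEN-style attentional GRU cell; the names
  # below are illustrative):
  #
  #   outputs, state = dynamic_rnn(att_gru_cell, inputs,   # inputs: [B, T, D]
  #                                att_scores=scores,      # scores: [B, T, 1]
  #                                sequence_length=seq_lens,
  #                                dtype=tf.float32)
  #
  # Internally, step t calls cell(input_t, state, att_scores[:, t, :]).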
618 |   # If we are performing batch-major calculations, transpose output back
619 |   # to shape [batch, time, depth]
620 |   if not time_major:
621 |     # (T,B,D) => (B,T,D)
622 |     outputs = nest.map_structure(_transpose_batch_time, outputs)
623 | 
624 |   return (outputs, final_state)
625 | 
626 | 
627 | def _dynamic_rnn_loop(cell,
628 |                       inputs,
629 |                       initial_state,
630 |                       parallel_iterations,
631 |                       swap_memory,
632 |                       att_scores=None,  # (non-standard) optional [batch, time, ...] attention scores fed to the cell step by step
633 |                       sequence_length=None,
634 |                       dtype=None):
635 |   """Internal implementation of Dynamic RNN.
636 | 
637 |   Args:
638 |     cell: An instance of RNNCell.
639 |     inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested
640 |       tuple of such elements.
641 |     initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if
642 |       `cell.state_size` is a tuple, then this should be a tuple of
643 |       tensors having shapes `[batch_size, s] for s in cell.state_size`.
644 |     parallel_iterations: Positive Python int.
645 |     swap_memory: A Python boolean
646 |     sequence_length: (optional) An `int32` `Tensor` of shape [batch_size].
647 |     dtype: (optional) Expected dtype of output. If not specified, inferred from
648 |       initial_state.
649 | 
650 |   Returns:
651 |     Tuple `(final_outputs, final_state)`.
652 |     final_outputs:
653 |       A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
654 |       `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
655 |       objects, then this returns a (possibly nested) tuple of Tensors matching
656 |       the corresponding shapes.
657 |     final_state:
658 |       A `Tensor`, or possibly nested tuple of Tensors, matching in length
659 |       and shapes to `initial_state`.
660 | 
661 |   Raises:
662 |     ValueError: If the input depth cannot be inferred via shape inference
663 |       from the inputs.
664 |   """
665 |   state = initial_state
666 |   assert isinstance(parallel_iterations, int), "parallel_iterations must be int"
667 | 
668 |   state_size = cell.state_size
669 | 
670 |   flat_input = nest.flatten(inputs)
671 |   flat_output_size = nest.flatten(cell.output_size)
672 | 
673 |   # Construct an initial output
674 |   input_shape = array_ops.shape(flat_input[0])
675 |   time_steps = input_shape[0]
676 |   batch_size = _best_effort_input_batch_size(flat_input)
677 | 
678 |   inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3)
679 |                            for input_ in flat_input)
680 | 
681 |   const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2]
682 | 
683 |   for shape in inputs_got_shape:
684 |     if not shape[2:].is_fully_defined():
685 |       raise ValueError(
686 |           "Input size (depth of inputs) must be accessible via shape inference,"
687 |           " but saw value None.")
688 |     got_time_steps = shape[0].value
689 |     got_batch_size = shape[1].value
690 |     if const_time_steps != got_time_steps:
691 |       raise ValueError(
692 |           "Time steps is not the same for all the elements in the input in a "
693 |           "batch.")
694 |     if const_batch_size != got_batch_size:
695 |       raise ValueError(
696 |           "Batch_size is not the same for all the elements in the input.")
697 | 
698 |   # Prepare dynamic conditional copying of state & output
699 |   def _create_zero_arrays(size):
700 |     size = _concat(batch_size, size)
701 |     return array_ops.zeros(
702 |         array_ops.stack(size), _infer_state_dtype(dtype, state))
703 | 
704 |   flat_zero_output = tuple(_create_zero_arrays(output)
705 |                            for output in flat_output_size)
706 |   zero_output = nest.pack_sequence_as(structure=cell.output_size,
707 |                                       flat_sequence=flat_zero_output)
708 | 
709 |   if sequence_length is not None:
710 |     min_sequence_length = math_ops.reduce_min(sequence_length)
711 | max_sequence_length = math_ops.reduce_max(sequence_length) 712 | 713 | time = array_ops.constant(0, dtype=dtypes.int32, name="time") 714 | 715 | with ops.name_scope("dynamic_rnn") as scope: 716 | base_name = scope 717 | 718 | def _create_ta(name, dtype): 719 | return tensor_array_ops.TensorArray(dtype=dtype, 720 | size=time_steps, 721 | tensor_array_name=base_name + name) 722 | 723 | output_ta = tuple(_create_ta("output_%d" % i, 724 | _infer_state_dtype(dtype, state)) 725 | for i in range(len(flat_output_size))) 726 | input_ta = tuple(_create_ta("input_%d" % i, flat_input[i].dtype) 727 | for i in range(len(flat_input))) 728 | 729 | input_ta = tuple(ta.unstack(input_) 730 | for ta, input_ in zip(input_ta, flat_input)) 731 | 732 | def _time_step(time, output_ta_t, state, att_scores=None): 733 | """Take a time step of the dynamic RNN. 734 | 735 | Args: 736 | time: int32 scalar Tensor. 737 | output_ta_t: List of `TensorArray`s that represent the output. 738 | state: nested tuple of vector tensors that represent the state. 739 | 740 | Returns: 741 | The tuple (time + 1, output_ta_t with updated flow, new_state). 742 | """ 743 | 744 | input_t = tuple(ta.read(time) for ta in input_ta) 745 | # Restore some shape information 746 | for input_, shape in zip(input_t, inputs_got_shape): 747 | input_.set_shape(shape[1:]) 748 | 749 | input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t) 750 | if att_scores is not None: 751 | att_score = att_scores[:, time, :] 752 | call_cell = lambda: cell(input_t, state, att_score) 753 | else: 754 | call_cell = lambda: cell(input_t, state) 755 | 756 | if sequence_length is not None: 757 | (output, new_state) = _rnn_step( 758 | time=time, 759 | sequence_length=sequence_length, 760 | min_sequence_length=min_sequence_length, 761 | max_sequence_length=max_sequence_length, 762 | zero_output=zero_output, 763 | state=state, 764 | call_cell=call_cell, 765 | state_size=state_size, 766 | skip_conditionals=True) 767 | else: 768 | (output, new_state) = call_cell() 769 | 770 | # Pack state if using state tuples 771 | output = nest.flatten(output) 772 | 773 | output_ta_t = tuple( 774 | ta.write(time, out) for ta, out in zip(output_ta_t, output)) 775 | if att_scores is not None: 776 | return (time + 1, output_ta_t, new_state, att_scores) 777 | else: 778 | return (time + 1, output_ta_t, new_state) 779 | 780 | if att_scores is not None: 781 | _, output_final_ta, final_state, _ = control_flow_ops.while_loop( 782 | cond=lambda time, *_: time < time_steps, 783 | body=_time_step, 784 | loop_vars=(time, output_ta, state, att_scores), 785 | parallel_iterations=parallel_iterations, 786 | swap_memory=swap_memory) 787 | else: 788 | _, output_final_ta, final_state = control_flow_ops.while_loop( 789 | cond=lambda time, *_: time < time_steps, 790 | body=_time_step, 791 | loop_vars=(time, output_ta, state), 792 | parallel_iterations=parallel_iterations, 793 | swap_memory=swap_memory) 794 | 795 | # Unpack final output if not using output tuples. 
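  # Clarifying note: ta.stack() turns each per-step TensorArray written inside
  # the while_loop back into a dense [time, batch, depth] Tensor; the static
  # shape information lost inside the loop is restored just below.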
796 | final_outputs = tuple(ta.stack() for ta in output_final_ta) 797 | 798 | # Restore some shape information 799 | for output, output_size in zip(final_outputs, flat_output_size): 800 | shape = _concat( 801 | [const_time_steps, const_batch_size], output_size, static=True) 802 | output.set_shape(shape) 803 | 804 | final_outputs = nest.pack_sequence_as( 805 | structure=cell.output_size, flat_sequence=final_outputs) 806 | 807 | return (final_outputs, final_state) 808 | 809 | 810 | def raw_rnn(cell, loop_fn, 811 | parallel_iterations=None, swap_memory=False, scope=None): 812 | """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`. 813 | 814 | **NOTE: This method is still in testing, and the API may change.** 815 | 816 | This function is a more primitive version of `dynamic_rnn` that provides 817 | more direct access to the inputs each iteration. It also provides more 818 | control over when to start and finish reading the sequence, and 819 | what to emit for the output. 820 | 821 | For example, it can be used to implement the dynamic decoder of a seq2seq 822 | model. 823 | 824 | Instead of working with `Tensor` objects, most operations work with 825 | `TensorArray` objects directly. 826 | 827 | The operation of `raw_rnn`, in pseudo-code, is basically the following: 828 | 829 | ```python 830 | time = tf.constant(0, dtype=tf.int32) 831 | (finished, next_input, initial_state, _, loop_state) = loop_fn( 832 | time=time, cell_output=None, cell_state=None, loop_state=None) 833 | emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype) 834 | state = initial_state 835 | while not all(finished): 836 | (output, cell_state) = cell(next_input, state) 837 | (next_finished, next_input, next_state, emit, loop_state) = loop_fn( 838 | time=time + 1, cell_output=output, cell_state=cell_state, 839 | loop_state=loop_state) 840 | # Emit zeros and copy forward state for minibatch entries that are finished. 841 | state = tf.where(finished, state, next_state) 842 | emit = tf.where(finished, tf.zeros_like(emit), emit) 843 | emit_ta = emit_ta.write(time, emit) 844 | # If any new minibatch entries are marked as finished, mark these. 845 | finished = tf.logical_or(finished, next_finished) 846 | time += 1 847 | return (emit_ta, state, loop_state) 848 | ``` 849 | 850 | with the additional properties that output and state may be (possibly nested) 851 | tuples, as determined by `cell.output_size` and `cell.state_size`, and 852 | as a result the final `state` and `emit_ta` may themselves be tuples. 
853 | 854 | A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this: 855 | 856 | ```python 857 | inputs = tf.placeholder(shape=(max_time, batch_size, input_depth), 858 | dtype=tf.float32) 859 | sequence_length = tf.placeholder(shape=(batch_size,), dtype=tf.int32) 860 | inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time) 861 | inputs_ta = inputs_ta.unstack(inputs) 862 | 863 | cell = tf.contrib.rnn.LSTMCell(num_units) 864 | 865 | def loop_fn(time, cell_output, cell_state, loop_state): 866 | emit_output = cell_output # == None for time == 0 867 | if cell_output is None: # time == 0 868 | next_cell_state = cell.zero_state(batch_size, tf.float32) 869 | else: 870 | next_cell_state = cell_state 871 | elements_finished = (time >= sequence_length) 872 | finished = tf.reduce_all(elements_finished) 873 | next_input = tf.cond( 874 | finished, 875 | lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32), 876 | lambda: inputs_ta.read(time)) 877 | next_loop_state = None 878 | return (elements_finished, next_input, next_cell_state, 879 | emit_output, next_loop_state) 880 | 881 | outputs_ta, final_state, _ = raw_rnn(cell, loop_fn) 882 | outputs = outputs_ta.stack() 883 | ``` 884 | 885 | Args: 886 | cell: An instance of RNNCell. 887 | loop_fn: A callable that takes inputs 888 | `(time, cell_output, cell_state, loop_state)` 889 | and returns the tuple 890 | `(finished, next_input, next_cell_state, emit_output, next_loop_state)`. 891 | Here `time` is an int32 scalar `Tensor`, `cell_output` is a 892 | `Tensor` or (possibly nested) tuple of tensors as determined by 893 | `cell.output_size`, and `cell_state` is a `Tensor` 894 | or (possibly nested) tuple of tensors, as determined by the `loop_fn` 895 | on its first call (and should match `cell.state_size`). 896 | The outputs are: `finished`, a boolean `Tensor` of 897 | shape `[batch_size]`, `next_input`: the next input to feed to `cell`, 898 | `next_cell_state`: the next state to feed to `cell`, 899 | and `emit_output`: the output to store for this iteration. 900 | 901 | Note that `emit_output` should be a `Tensor` or (possibly nested) 902 | tuple of tensors with shapes and structure matching `cell.output_size` 903 | and `cell_output` above. The parameter `cell_state` and output 904 | `next_cell_state` may be either a single or (possibly nested) tuple 905 | of tensors. The parameter `loop_state` and 906 | output `next_loop_state` may be either a single or (possibly nested) tuple 907 | of `Tensor` and `TensorArray` objects. This last parameter 908 | may be ignored by `loop_fn` and the return value may be `None`. If it 909 | is not `None`, then the `loop_state` will be propagated through the RNN 910 | loop, for use purely by `loop_fn` to keep track of its own state. 911 | The `next_loop_state` parameter returned may be `None`. 912 | 913 | The first call to `loop_fn` will be `time = 0`, `cell_output = None`, 914 | `cell_state = None`, and `loop_state = None`. For this call: 915 | The `next_cell_state` value should be the value with which to initialize 916 | the cell's state. It may be a final state from a previous RNN or it 917 | may be the output of `cell.zero_state()`. It should be a 918 | (possibly nested) tuple structure of tensors. 919 | If `cell.state_size` is an integer, this must be 920 | a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 921 | If `cell.state_size` is a `TensorShape`, this must be a `Tensor` of 922 | appropriate type and shape `[batch_size] + cell.state_size`. 
923 | If `cell.state_size` is a (possibly nested) tuple of ints or 924 | `TensorShape`, this will be a tuple having the corresponding shapes. 925 | The `emit_output` value may be either `None` or a (possibly nested) 926 | tuple structure of tensors, e.g., 927 | `(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`. 928 | If this first `emit_output` return value is `None`, 929 | then the `emit_ta` result of `raw_rnn` will have the same structure and 930 | dtypes as `cell.output_size`. Otherwise `emit_ta` will have the same 931 | structure, shapes (prepended with a `batch_size` dimension), and dtypes 932 | as `emit_output`. The actual values returned for `emit_output` at this 933 | initializing call are ignored. Note, this emit structure must be 934 | consistent across all time steps. 935 | 936 | parallel_iterations: (Default: 32). The number of iterations to run in 937 | parallel. Those operations which do not have any temporal dependency 938 | and can be run in parallel, will be. This parameter trades off 939 | time for space. Values >> 1 use more memory but take less time, 940 | while smaller values use less memory but computations take longer. 941 | swap_memory: Transparently swap the tensors produced in forward inference 942 | but needed for back prop from GPU to CPU. This allows training RNNs 943 | which would typically not fit on a single GPU, with very minimal (or no) 944 | performance penalty. 945 | scope: VariableScope for the created subgraph; defaults to "rnn". 946 | 947 | Returns: 948 | A tuple `(emit_ta, final_state, final_loop_state)` where: 949 | 950 | `emit_ta`: The RNN output `TensorArray`. 951 | If `loop_fn` returns a (possibly nested) set of Tensors for 952 | `emit_output` during initialization, (inputs `time = 0`, 953 | `cell_output = None`, and `loop_state = None`), then `emit_ta` will 954 | have the same structure, dtypes, and shapes as `emit_output` instead. 955 | If `loop_fn` returns `emit_output = None` during this call, 956 | the structure of `cell.output_size` is used: 957 | If `cell.output_size` is a (possibly nested) tuple of integers 958 | or `TensorShape` objects, then `emit_ta` will be a tuple having the 959 | same structure as `cell.output_size`, containing TensorArrays whose 960 | elements' shapes correspond to the shape data in `cell.output_size`. 961 | 962 | `final_state`: The final cell state. If `cell.state_size` is an int, this 963 | will be shaped `[batch_size, cell.state_size]`. If it is a 964 | `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 965 | If it is a (possibly nested) tuple of ints or `TensorShape`, this will 966 | be a tuple having the corresponding shapes. 967 | 968 | `final_loop_state`: The final loop state as returned by `loop_fn`. 969 | 970 | Raises: 971 | TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not 972 | a `callable`. 973 | """ 974 | 975 | if not _like_rnncell(cell): 976 | raise TypeError("cell must be an instance of RNNCell") 977 | if not callable(loop_fn): 978 | raise TypeError("loop_fn must be a callable") 979 | 980 | parallel_iterations = parallel_iterations or 32 981 | 982 | # Create a new scope in which the caching device is either 983 | # determined by the parent scope, or is set to place the cached 984 | # Variable using the same placement as for the rest of the RNN. 
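  # Note: the first loop_fn call below (time=0, cell_output=None,
  # cell_state=None, loop_state=None) is made only to obtain the initial input,
  # state and emit structure; any emit_output value returned at this call is
  # ignored.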
985 | with vs.variable_scope(scope or "rnn") as varscope: 986 | if varscope.caching_device is None: 987 | varscope.set_caching_device(lambda op: op.device) 988 | 989 | time = constant_op.constant(0, dtype=dtypes.int32) 990 | (elements_finished, next_input, initial_state, emit_structure, 991 | init_loop_state) = loop_fn( 992 | time, None, None, None) # time, cell_output, cell_state, loop_state 993 | flat_input = nest.flatten(next_input) 994 | 995 | # Need a surrogate loop state for the while_loop if none is available. 996 | loop_state = (init_loop_state if init_loop_state is not None 997 | else constant_op.constant(0, dtype=dtypes.int32)) 998 | 999 | input_shape = [input_.get_shape() for input_ in flat_input] 1000 | static_batch_size = input_shape[0][0] 1001 | 1002 | for input_shape_i in input_shape: 1003 | # Static verification that batch sizes all match 1004 | static_batch_size.merge_with(input_shape_i[0]) 1005 | 1006 | batch_size = static_batch_size.value 1007 | if batch_size is None: 1008 | batch_size = array_ops.shape(flat_input[0])[0] 1009 | 1010 | nest.assert_same_structure(initial_state, cell.state_size) 1011 | state = initial_state 1012 | flat_state = nest.flatten(state) 1013 | flat_state = [ops.convert_to_tensor(s) for s in flat_state] 1014 | state = nest.pack_sequence_as(structure=state, 1015 | flat_sequence=flat_state) 1016 | 1017 | if emit_structure is not None: 1018 | flat_emit_structure = nest.flatten(emit_structure) 1019 | flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else 1020 | array_ops.shape(emit) for emit in flat_emit_structure] 1021 | flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure] 1022 | else: 1023 | emit_structure = cell.output_size 1024 | flat_emit_size = nest.flatten(emit_structure) 1025 | flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size) 1026 | 1027 | flat_emit_ta = [ 1028 | tensor_array_ops.TensorArray( 1029 | dtype=dtype_i, dynamic_size=True, size=0, name="rnn_output_%d" % i) 1030 | for i, dtype_i in enumerate(flat_emit_dtypes)] 1031 | emit_ta = nest.pack_sequence_as(structure=emit_structure, 1032 | flat_sequence=flat_emit_ta) 1033 | flat_zero_emit = [ 1034 | array_ops.zeros(_concat(batch_size, size_i), dtype_i) 1035 | for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)] 1036 | zero_emit = nest.pack_sequence_as(structure=emit_structure, 1037 | flat_sequence=flat_zero_emit) 1038 | 1039 | def condition(unused_time, elements_finished, *_): 1040 | return math_ops.logical_not(math_ops.reduce_all(elements_finished)) 1041 | 1042 | def body(time, elements_finished, current_input, 1043 | emit_ta, state, loop_state): 1044 | """Internal while loop body for raw_rnn. 1045 | 1046 | Args: 1047 | time: time scalar. 1048 | elements_finished: batch-size vector. 1049 | current_input: possibly nested tuple of input tensors. 1050 | emit_ta: possibly nested tuple of output TensorArrays. 1051 | state: possibly nested tuple of state tensors. 1052 | loop_state: possibly nested tuple of loop state tensors. 1053 | 1054 | Returns: 1055 | Tuple having the same size as Args but with updated values. 
1056 | """ 1057 | (next_output, cell_state) = cell(current_input, state) 1058 | 1059 | nest.assert_same_structure(state, cell_state) 1060 | nest.assert_same_structure(cell.output_size, next_output) 1061 | 1062 | next_time = time + 1 1063 | (next_finished, next_input, next_state, emit_output, 1064 | next_loop_state) = loop_fn( 1065 | next_time, next_output, cell_state, loop_state) 1066 | 1067 | nest.assert_same_structure(state, next_state) 1068 | nest.assert_same_structure(current_input, next_input) 1069 | nest.assert_same_structure(emit_ta, emit_output) 1070 | 1071 | # If loop_fn returns None for next_loop_state, just reuse the 1072 | # previous one. 1073 | loop_state = loop_state if next_loop_state is None else next_loop_state 1074 | 1075 | def _copy_some_through(current, candidate): 1076 | """Copy some tensors through via array_ops.where.""" 1077 | def copy_fn(cur_i, cand_i): 1078 | with ops.colocate_with(cand_i): 1079 | return array_ops.where(elements_finished, cur_i, cand_i) 1080 | return nest.map_structure(copy_fn, current, candidate) 1081 | 1082 | emit_output = _copy_some_through(zero_emit, emit_output) 1083 | next_state = _copy_some_through(state, next_state) 1084 | 1085 | emit_ta = nest.map_structure( 1086 | lambda ta, emit: ta.write(time, emit), emit_ta, emit_output) 1087 | 1088 | elements_finished = math_ops.logical_or(elements_finished, next_finished) 1089 | 1090 | return (next_time, elements_finished, next_input, 1091 | emit_ta, next_state, loop_state) 1092 | 1093 | returned = control_flow_ops.while_loop( 1094 | condition, body, loop_vars=[ 1095 | time, elements_finished, next_input, 1096 | emit_ta, state, loop_state], 1097 | parallel_iterations=parallel_iterations, 1098 | swap_memory=swap_memory) 1099 | 1100 | (emit_ta, final_state, final_loop_state) = returned[-3:] 1101 | 1102 | if init_loop_state is None: 1103 | final_loop_state = None 1104 | 1105 | return (emit_ta, final_state, final_loop_state) 1106 | 1107 | 1108 | def static_rnn(cell, 1109 | inputs, 1110 | initial_state=None, 1111 | dtype=None, 1112 | sequence_length=None, 1113 | scope=None): 1114 | """Creates a recurrent neural network specified by RNNCell `cell`. 1115 | 1116 | The simplest form of RNN network generated is: 1117 | 1118 | ```python 1119 | state = cell.zero_state(...) 1120 | outputs = [] 1121 | for input_ in inputs: 1122 | output, state = cell(input_, state) 1123 | outputs.append(output) 1124 | return (outputs, state) 1125 | ``` 1126 | However, a few other options are available: 1127 | 1128 | An initial state can be provided. 1129 | If the sequence_length vector is provided, dynamic calculation is performed. 1130 | This method of calculation does not compute the RNN steps past the maximum 1131 | sequence length of the minibatch (thus saving computational time), 1132 | and properly propagates the state at an example's sequence length 1133 | to the final state output. 1134 | 1135 | The dynamic calculation performed is, at time `t` for batch row `b`, 1136 | 1137 | ```python 1138 | (output, state)(b, t) = 1139 | (t >= sequence_length(b)) 1140 | ? (zeros(cell.output_size), states(b, sequence_length(b) - 1)) 1141 | : cell(input(b, t), state(b, t - 1)) 1142 | ``` 1143 | 1144 | Args: 1145 | cell: An instance of RNNCell. 1146 | inputs: A length T list of inputs, each a `Tensor` of shape 1147 | `[batch_size, input_size]`, or a nested tuple of such elements. 1148 | initial_state: (optional) An initial state for the RNN. 
1149 | If `cell.state_size` is an integer, this must be 1150 | a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 1151 | If `cell.state_size` is a tuple, this should be a tuple of 1152 | tensors having shapes `[batch_size, s] for s in cell.state_size`. 1153 | dtype: (optional) The data type for the initial state and expected output. 1154 | Required if initial_state is not provided or RNN state has a heterogeneous 1155 | dtype. 1156 | sequence_length: Specifies the length of each sequence in inputs. 1157 | An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`. 1158 | scope: VariableScope for the created subgraph; defaults to "rnn". 1159 | 1160 | Returns: 1161 | A pair (outputs, state) where: 1162 | 1163 | - outputs is a length T list of outputs (one for each input), or a nested 1164 | tuple of such elements. 1165 | - state is the final state 1166 | 1167 | Raises: 1168 | TypeError: If `cell` is not an instance of RNNCell. 1169 | ValueError: If `inputs` is `None` or an empty list, or if the input depth 1170 | (column size) cannot be inferred from inputs via shape inference. 1171 | """ 1172 | 1173 | if not _like_rnncell(cell): 1174 | raise TypeError("cell must be an instance of RNNCell") 1175 | if not nest.is_sequence(inputs): 1176 | raise TypeError("inputs must be a sequence") 1177 | if not inputs: 1178 | raise ValueError("inputs must not be empty") 1179 | 1180 | outputs = [] 1181 | # Create a new scope in which the caching device is either 1182 | # determined by the parent scope, or is set to place the cached 1183 | # Variable using the same placement as for the rest of the RNN. 1184 | with vs.variable_scope(scope or "rnn") as varscope: 1185 | if varscope.caching_device is None: 1186 | varscope.set_caching_device(lambda op: op.device) 1187 | 1188 | # Obtain the first sequence of the input 1189 | first_input = inputs 1190 | while nest.is_sequence(first_input): 1191 | first_input = first_input[0] 1192 | 1193 | # Temporarily avoid EmbeddingWrapper and seq2seq badness 1194 | # TODO(lukaszkaiser): remove EmbeddingWrapper 1195 | if first_input.get_shape().ndims != 1: 1196 | 1197 | input_shape = first_input.get_shape().with_rank_at_least(2) 1198 | fixed_batch_size = input_shape[0] 1199 | 1200 | flat_inputs = nest.flatten(inputs) 1201 | for flat_input in flat_inputs: 1202 | input_shape = flat_input.get_shape().with_rank_at_least(2) 1203 | batch_size, input_size = input_shape[0], input_shape[1:] 1204 | fixed_batch_size.merge_with(batch_size) 1205 | for i, size in enumerate(input_size): 1206 | if size.value is None: 1207 | raise ValueError( 1208 | "Input size (dimension %d of inputs) must be accessible via " 1209 | "shape inference, but saw value None." 
% i) 1210 | else: 1211 | fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0] 1212 | 1213 | if fixed_batch_size.value: 1214 | batch_size = fixed_batch_size.value 1215 | else: 1216 | batch_size = array_ops.shape(first_input)[0] 1217 | if initial_state is not None: 1218 | state = initial_state 1219 | else: 1220 | if not dtype: 1221 | raise ValueError("If no initial_state is provided, " 1222 | "dtype must be specified") 1223 | state = cell.zero_state(batch_size, dtype) 1224 | 1225 | if sequence_length is not None: # Prepare variables 1226 | sequence_length = ops.convert_to_tensor( 1227 | sequence_length, name="sequence_length") 1228 | if sequence_length.get_shape().ndims not in (None, 1): 1229 | raise ValueError( 1230 | "sequence_length must be a vector of length batch_size") 1231 | 1232 | def _create_zero_output(output_size): 1233 | # convert int to TensorShape if necessary 1234 | size = _concat(batch_size, output_size) 1235 | output = array_ops.zeros( 1236 | array_ops.stack(size), _infer_state_dtype(dtype, state)) 1237 | shape = _concat(fixed_batch_size.value, output_size, static=True) 1238 | output.set_shape(tensor_shape.TensorShape(shape)) 1239 | return output 1240 | 1241 | output_size = cell.output_size 1242 | flat_output_size = nest.flatten(output_size) 1243 | flat_zero_output = tuple( 1244 | _create_zero_output(size) for size in flat_output_size) 1245 | zero_output = nest.pack_sequence_as( 1246 | structure=output_size, flat_sequence=flat_zero_output) 1247 | 1248 | sequence_length = math_ops.to_int32(sequence_length) 1249 | min_sequence_length = math_ops.reduce_min(sequence_length) 1250 | max_sequence_length = math_ops.reduce_max(sequence_length) 1251 | 1252 | for time, input_ in enumerate(inputs): 1253 | if time > 0: 1254 | varscope.reuse_variables() 1255 | # pylint: disable=cell-var-from-loop 1256 | call_cell = lambda: cell(input_, state) 1257 | # pylint: enable=cell-var-from-loop 1258 | if sequence_length is not None: 1259 | (output, state) = _rnn_step( 1260 | time=time, 1261 | sequence_length=sequence_length, 1262 | min_sequence_length=min_sequence_length, 1263 | max_sequence_length=max_sequence_length, 1264 | zero_output=zero_output, 1265 | state=state, 1266 | call_cell=call_cell, 1267 | state_size=cell.state_size) 1268 | else: 1269 | (output, state) = call_cell() 1270 | 1271 | outputs.append(output) 1272 | 1273 | return (outputs, state) 1274 | 1275 | 1276 | def static_state_saving_rnn(cell, 1277 | inputs, 1278 | state_saver, 1279 | state_name, 1280 | sequence_length=None, 1281 | scope=None): 1282 | """RNN that accepts a state saver for time-truncated RNN calculation. 1283 | 1284 | Args: 1285 | cell: An instance of `RNNCell`. 1286 | inputs: A length T list of inputs, each a `Tensor` of shape 1287 | `[batch_size, input_size]`. 1288 | state_saver: A state saver object with methods `state` and `save_state`. 1289 | state_name: Python string or tuple of strings. The name to use with the 1290 | state_saver. If the cell returns tuples of states (i.e., 1291 | `cell.state_size` is a tuple) then `state_name` should be a tuple of 1292 | strings having the same length as `cell.state_size`. Otherwise it should 1293 | be a single string. 1294 | sequence_length: (optional) An int32/int64 vector size [batch_size]. 1295 | See the documentation for rnn() for more details about sequence_length. 1296 | scope: VariableScope for the created subgraph; defaults to "rnn". 
1297 | 1298 | Returns: 1299 | A pair (outputs, state) where: 1300 | outputs is a length T list of outputs (one for each input) 1301 | states is the final state 1302 | 1303 | Raises: 1304 | TypeError: If `cell` is not an instance of RNNCell. 1305 | ValueError: If `inputs` is `None` or an empty list, or if the arity and 1306 | type of `state_name` does not match that of `cell.state_size`. 1307 | """ 1308 | state_size = cell.state_size 1309 | state_is_tuple = nest.is_sequence(state_size) 1310 | state_name_tuple = nest.is_sequence(state_name) 1311 | 1312 | if state_is_tuple != state_name_tuple: 1313 | raise ValueError("state_name should be the same type as cell.state_size. " 1314 | "state_name: %s, cell.state_size: %s" % (str(state_name), 1315 | str(state_size))) 1316 | 1317 | if state_is_tuple: 1318 | state_name_flat = nest.flatten(state_name) 1319 | state_size_flat = nest.flatten(state_size) 1320 | 1321 | if len(state_name_flat) != len(state_size_flat): 1322 | raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" % 1323 | (len(state_name_flat), len(state_size_flat))) 1324 | 1325 | initial_state = nest.pack_sequence_as( 1326 | structure=state_size, 1327 | flat_sequence=[state_saver.state(s) for s in state_name_flat]) 1328 | else: 1329 | initial_state = state_saver.state(state_name) 1330 | 1331 | (outputs, state) = static_rnn( 1332 | cell, 1333 | inputs, 1334 | initial_state=initial_state, 1335 | sequence_length=sequence_length, 1336 | scope=scope) 1337 | 1338 | if state_is_tuple: 1339 | flat_state = nest.flatten(state) 1340 | state_name = nest.flatten(state_name) 1341 | save_state = [ 1342 | state_saver.save_state(name, substate) 1343 | for name, substate in zip(state_name, flat_state) 1344 | ] 1345 | else: 1346 | save_state = [state_saver.save_state(state_name, state)] 1347 | 1348 | with ops.control_dependencies(save_state): 1349 | last_output = outputs[-1] 1350 | flat_last_output = nest.flatten(last_output) 1351 | flat_last_output = [ 1352 | array_ops.identity(output) for output in flat_last_output 1353 | ] 1354 | outputs[-1] = nest.pack_sequence_as( 1355 | structure=last_output, flat_sequence=flat_last_output) 1356 | 1357 | return (outputs, state) 1358 | 1359 | 1360 | def static_bidirectional_rnn(cell_fw, 1361 | cell_bw, 1362 | inputs, 1363 | initial_state_fw=None, 1364 | initial_state_bw=None, 1365 | dtype=None, 1366 | sequence_length=None, 1367 | scope=None): 1368 | """Creates a bidirectional recurrent neural network. 1369 | 1370 | Similar to the unidirectional case above (rnn) but takes input and builds 1371 | independent forward and backward RNNs with the final forward and backward 1372 | outputs depth-concatenated, such that the output will have the format 1373 | [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of 1374 | forward and backward cell must match. The initial state for both directions 1375 | is zero by default (but can be set optionally) and no intermediate states are 1376 | ever returned -- the network is fully unrolled for the given (passed in) 1377 | length(s) of the sequence(s) or completely unrolled if length(s) is not given. 1378 | 1379 | Args: 1380 | cell_fw: An instance of RNNCell, to be used for forward direction. 1381 | cell_bw: An instance of RNNCell, to be used for backward direction. 1382 | inputs: A length T list of inputs, each a tensor of shape 1383 | [batch_size, input_size], or a nested tuple of such elements. 1384 | initial_state_fw: (optional) An initial state for the forward RNN. 
1385 | This must be a tensor of appropriate type and shape 1386 | `[batch_size, cell_fw.state_size]`. 1387 | If `cell_fw.state_size` is a tuple, this should be a tuple of 1388 | tensors having shapes `[batch_size, s] for s in cell_fw.state_size`. 1389 | initial_state_bw: (optional) Same as for `initial_state_fw`, but using 1390 | the corresponding properties of `cell_bw`. 1391 | dtype: (optional) The data type for the initial state. Required if 1392 | either of the initial states are not provided. 1393 | sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 1394 | containing the actual lengths for each of the sequences. 1395 | scope: VariableScope for the created subgraph; defaults to 1396 | "bidirectional_rnn" 1397 | 1398 | Returns: 1399 | A tuple (outputs, output_state_fw, output_state_bw) where: 1400 | outputs is a length `T` list of outputs (one for each input), which 1401 | are depth-concatenated forward and backward outputs. 1402 | output_state_fw is the final state of the forward rnn. 1403 | output_state_bw is the final state of the backward rnn. 1404 | 1405 | Raises: 1406 | TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 1407 | ValueError: If inputs is None or an empty list. 1408 | """ 1409 | 1410 | if not _like_rnncell(cell_fw): 1411 | raise TypeError("cell_fw must be an instance of RNNCell") 1412 | if not _like_rnncell(cell_bw): 1413 | raise TypeError("cell_bw must be an instance of RNNCell") 1414 | if not nest.is_sequence(inputs): 1415 | raise TypeError("inputs must be a sequence") 1416 | if not inputs: 1417 | raise ValueError("inputs must not be empty") 1418 | 1419 | with vs.variable_scope(scope or "bidirectional_rnn"): 1420 | # Forward direction 1421 | with vs.variable_scope("fw") as fw_scope: 1422 | output_fw, output_state_fw = static_rnn( 1423 | cell_fw, 1424 | inputs, 1425 | initial_state_fw, 1426 | dtype, 1427 | sequence_length, 1428 | scope=fw_scope) 1429 | 1430 | # Backward direction 1431 | with vs.variable_scope("bw") as bw_scope: 1432 | reversed_inputs = _reverse_seq(inputs, sequence_length) 1433 | tmp, output_state_bw = static_rnn( 1434 | cell_bw, 1435 | reversed_inputs, 1436 | initial_state_bw, 1437 | dtype, 1438 | sequence_length, 1439 | scope=bw_scope) 1440 | 1441 | output_bw = _reverse_seq(tmp, sequence_length) 1442 | # Concat each of the forward/backward outputs 1443 | flat_output_fw = nest.flatten(output_fw) 1444 | flat_output_bw = nest.flatten(output_bw) 1445 | 1446 | flat_outputs = tuple( 1447 | array_ops.concat([fw, bw], 1) 1448 | for fw, bw in zip(flat_output_fw, flat_output_bw)) 1449 | 1450 | outputs = nest.pack_sequence_as( 1451 | structure=output_fw, flat_sequence=flat_outputs) 1452 | 1453 | return (outputs, output_state_fw, output_state_bw) -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import sys 4 | from dataloader import * 5 | from rec import * 6 | from ubr import * 7 | from sklearn.metrics import * 8 | import random 9 | import time 10 | import numpy as np 11 | import pickle as pkl 12 | import math 13 | 14 | random.seed(1111) 15 | 16 | EMBEDDING_SIZE = 16 17 | HIDDEN_SIZE = 16 * 2 18 | EVAL_BATCH_SIZE = 500 19 | 20 | # for TMALL 21 | FEAT_SIZE_TMALL = 1529672 + 6 #(6 is for time context) 22 | DATA_DIR_TMALL = '../data/tmall/feateng_data/' 23 | 24 | # for CCMR 25 | FEAT_SIZE_CCMR = 1 + 4920695 + 190129 + (80171 + 1) + (213481 + 1) + (62 + 
1) + (1043 + 1) + 4 26 | DATA_DIR_CCMR = '../data/ccmr/feateng_data/' 27 | 28 | # for TAOBAO 29 | FEAT_SIZE_TAOBAO = 5062314 30 | DATA_DIR_TAOBAO = '../data/taobao/feateng_data/' 31 | 32 | # for ALIPAY 33 | FEAT_SIZE_ALIPAY = 2836410 34 | DATA_DIR_ALIPAY = '../data/alipay/feateng_data/' 35 | 36 | def restore(data_set_name, target_test_file, user_feat_dict_file, item_feat_dict_file, context_dict_file, 37 | rec_model_type, ubr_model_type, b_num, train_batch_size, feature_size, eb_dim, hidden_size, 38 | rec_lr, ubr_lr, reg_lambda, record_fnum, emb_initializer, taker): 39 | print('restore begin') 40 | tf.reset_default_graph() 41 | 42 | if rec_model_type == 'RecSum': 43 | rec_model = RecSum(feature_size, eb_dim, hidden_size, b_num, record_fnum, emb_initializer) 44 | elif rec_model_type == 'RecAtt': 45 | rec_model = RecAtt(feature_size, eb_dim, hidden_size, b_num, record_fnum, emb_initializer) 46 | else: 47 | print('WRONG REC MODEL TYPE') 48 | exit(1) 49 | 50 | if ubr_model_type == 'UBR_SA': 51 | ubr_model = UBR_SA(feature_size, eb_dim, hidden_size, record_fnum, emb_initializer) 52 | else: 53 | print('WRONG UBR MODEL TYPE') 54 | exit(1) 55 | 56 | rec_model_name = '{}_{}_{}_{}'.format(rec_model_type, train_batch_size, rec_lr, reg_lambda) 57 | ubr_model_name = '{}_{}_{}'.format(ubr_model_type, train_batch_size, ubr_lr) 58 | 59 | gpu_options = tf.GPUOptions(allow_growth=True) 60 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 61 | rec_model.restore(sess, 'save_model_{}/{}/{}_{}/{}/ckpt'.format(data_set_name, b_num, rec_model_type, ubr_model_type, rec_model_name)) 62 | ubr_model.restore(sess, 'save_model_{}/{}/{}_{}/{}/ckpt'.format(data_set_name, b_num, rec_model_type, ubr_model_type, ubr_model_name)) 63 | print('restore eval begin') 64 | _, logloss, rig, auc = eval(rec_model, ubr_model, sess, target_test_file, user_feat_dict_file, item_feat_dict_file, context_dict_file, reg_lambda, taker, train_batch_size) 65 | 66 | print('RESTORE, LOGLOSS %.4f RIG: %.4f AUC: %.4f' % (logloss, rig, auc)) 67 | with open('logs_{}/{}/{}_{}/{}.txt'.format(data_set_name, b_num, rec_model_type, ubr_model_type, rec_model_type), 'a') as f: 68 | results = [train_batch_size, rec_lr, reg_lambda, logloss, rig, auc] 69 | results = [rec_model_type] + [str(res) for res in results] 70 | result_line = '\t'.join(results) + '\n' 71 | f.write(result_line) 72 | 73 | def eval(rec_model, ubr_model, sess, target_file, user_feat_dict_file, item_feat_dict_file, 74 | context_dict_file, reg_lambda, taker, batch_size): 75 | preds = [] 76 | labels = [] 77 | losses = [] 78 | 79 | data_loader = DataLoader_Target(batch_size, target_file, user_feat_dict_file, item_feat_dict_file, context_dict_file) 80 | 81 | t = time.time() 82 | for batch_data in data_loader: 83 | target_batch, label_batch = batch_data 84 | index_batch = ubr_model.get_index(sess, target_batch) 85 | seq_batch, seq_len_batch = taker.take_behave(target_batch, index_batch) 86 | 87 | pred, label, loss = rec_model.eval(sess, [seq_batch, seq_len_batch, target_batch, label_batch], reg_lambda) 88 | preds += pred 89 | labels += label 90 | losses.append(loss) 91 | 92 | logloss = log_loss(labels, preds) 93 | rig = 1 - (logloss / -(0.5 * math.log(0.5) + (1 - 0.5) * math.log(1 - 0.5))) 94 | auc = roc_auc_score(labels, preds) 95 | loss = sum(losses) / len(losses) 96 | 97 | print("EVAL TIME: %.4fs" % (time.time() - t)) 98 | return loss, logloss, rig, auc 99 | 100 | def train_rec_model(rec_training_monitor, epoch_num, sess, eval_iter_num, train_batch_size, 101 | 
taker, lr, reg_lambda, rec_model, ubr_model, 102 | target_train_file, user_feat_dict_file, 103 | item_feat_dict_file, context_dict_file, step, b_num): 104 | early_stop = False 105 | losses_step = [] 106 | auc_step = [] 107 | logloss_step = [] 108 | rig_step = [] 109 | 110 | for epoch in range(epoch_num): 111 | if early_stop: 112 | break 113 | # train rec model 114 | data_loader = DataLoader_Target(train_batch_size, target_train_file, user_feat_dict_file, 115 | item_feat_dict_file, context_dict_file) 116 | t = time.time() 117 | for batch_data in data_loader: 118 | if early_stop: 119 | break 120 | # get the retrieve data 121 | target_batch, label_batch = batch_data 122 | index_batch = ubr_model.get_index(sess, target_batch) 123 | seq_batch, seq_len_batch = taker.take_behave(target_batch, index_batch) 124 | new_batch_data = [seq_batch, seq_len_batch, target_batch, label_batch] 125 | 126 | # run train and eval 127 | loss = rec_model.train(sess, new_batch_data, lr, reg_lambda) 128 | pred, label, _ = rec_model.eval(sess, new_batch_data, reg_lambda) 129 | step += 1 130 | 131 | # calculate evaluation metrics 132 | logloss = log_loss(label, pred) 133 | rig = 1 -(logloss / -(0.5 * math.log(0.5) + (1 - 0.5) * math.log(1 - 0.5))) 134 | auc = roc_auc_score(label, pred) 135 | losses_step.append(loss) 136 | auc_step.append(auc) 137 | logloss_step.append(logloss) 138 | rig_step.append(rig) 139 | # print evaluation results 140 | if step % eval_iter_num == 0: 141 | train_loss = sum(losses_step) / len(losses_step) 142 | rec_training_monitor['loss'].append(train_loss) 143 | losses_step = [] 144 | 145 | train_auc = sum(auc_step) / len(auc_step) 146 | rec_training_monitor['auc'].append(train_auc) 147 | auc_step = [] 148 | 149 | train_logloss = sum(logloss_step) / len(logloss_step) 150 | rec_training_monitor['logloss'].append(train_logloss) 151 | logloss_step = [] 152 | 153 | train_rig = sum(rig_step) / len(rig_step) 154 | rec_training_monitor['rig'].append(train_rig) 155 | rig_step = [] 156 | 157 | print("TIME UNTIL EVAL: %.4f" % (time.time() - t)) 158 | print("REC MODEL STEP %d LOSS: %.4f LOGLOSS: %.4f RIG: %.4f AUC: %.4f" % (step, train_loss, train_logloss, train_rig, train_auc)) 159 | t = time.time() 160 | 161 | if len(rec_training_monitor['auc']) >= 2: 162 | if rec_training_monitor['auc'][-1] > max(rec_training_monitor['auc'][:-1]): 163 | # save model 164 | model_name = '{}_{}_{}_{}'.format(rec_model_type, train_batch_size, lr, reg_lambda) 165 | if not os.path.exists('save_model_{}/{}/{}_{}/{}/'.format(data_set_name, b_num, rec_model_type, ubr_model_type, model_name)): 166 | os.makedirs('save_model_{}/{}/{}_{}/{}/'.format(data_set_name, b_num, rec_model_type, ubr_model_type, model_name)) 167 | save_path = 'save_model_{}/{}/{}_{}/{}/ckpt'.format(data_set_name, b_num, rec_model_type, ubr_model_type, model_name) 168 | rec_model.save(sess, save_path) 169 | 170 | if len(rec_training_monitor['loss']) > 2: 171 | if (rec_training_monitor['loss'][-1] > rec_training_monitor['loss'][-2] and rec_training_monitor['loss'][-2] > rec_training_monitor['loss'][-3]): 172 | early_stop = True 173 | if (rec_training_monitor['loss'][-2] - rec_training_monitor['loss'][-1]) <= 0.001 and (rec_training_monitor['loss'][-3] - rec_training_monitor['loss'][-2]) <= 0.001: 174 | early_stop = True 175 | return step, early_stop 176 | 177 | def train_ubr_model(ubr_training_monitor, epoch_num, sess, eval_iter_num, taker, lr, 178 | train_batch_size, rec_model, ubr_model, target_train_file, 179 | user_feat_dict_file, item_feat_dict_file, 
context_dict_file, 180 | summary_writer, step, b_num): 181 | loss_step = [] 182 | reward_step = [] 183 | 184 | for _ in range(epoch_num): 185 | data_loader = DataLoader_Target(train_batch_size, target_train_file, user_feat_dict_file, 186 | item_feat_dict_file, context_dict_file) 187 | 188 | t = time.time() 189 | 190 | for batch_data in data_loader: 191 | target_batch, label_batch = batch_data 192 | index_batch = ubr_model.get_index(sess, target_batch) 193 | seq_batch, seq_len_batch = taker.take_behave(target_batch, index_batch) 194 | new_batch_data = [seq_batch, seq_len_batch, target_batch, label_batch] 195 | 196 | rewards = rec_model.get_reward(sess, new_batch_data) 197 | loss, reward, summary = ubr_model.train(sess, target_batch, lr, rewards) 198 | loss_step.append(loss) 199 | reward_step.append(reward) 200 | 201 | summary_writer.add_summary(summary, step) 202 | step += 1 203 | 204 | if step % eval_iter_num == 0: 205 | avg_loss = sum(loss_step) / len(loss_step) 206 | avg_reward = sum(reward_step) / len(reward_step) 207 | ubr_training_monitor['loss'].append(avg_loss) 208 | ubr_training_monitor['reward'].append(avg_reward) 209 | loss_step = [] 210 | reward_step = [] 211 | 212 | print("TIME UNTIL EVAL: %.4f" % (time.time() - t)) 213 | print("UBR MODEL STEP %d LOSS: %.4f REWARD: %.4f" % (step, avg_loss, avg_reward)) 214 | t = time.time() 215 | 216 | # save model 217 | model_name = '{}_{}_{}'.format(ubr_model_type, train_batch_size, lr) 218 | if not os.path.exists('save_model_{}/{}/{}_{}/{}/'.format(data_set_name, b_num, rec_model_type, ubr_model_type, model_name)): 219 | os.makedirs('save_model_{}/{}/{}_{}/{}/'.format(data_set_name, b_num, rec_model_type, ubr_model_type, model_name)) 220 | save_path = 'save_model_{}/{}/{}_{}/{}/ckpt'.format(data_set_name, b_num, rec_model_type, ubr_model_type, model_name) 221 | ubr_model.save(sess, save_path) 222 | return step 223 | 224 | def train(data_set_name, target_train_file, user_feat_dict_file, 225 | item_feat_dict_file, context_dict_file, rec_model_type, ubr_model_type, 226 | taker, b_num, train_batch_size, feature_size, eb_dim, hidden_size, 227 | rec_lr, ubr_lr, reg_lambda, dataset_size, record_fnum, emb_initializer): 228 | tf.reset_default_graph() 229 | 230 | if rec_model_type == 'RecSum': 231 | rec_model = RecSum(feature_size, eb_dim, hidden_size, b_num, record_fnum, emb_initializer) 232 | elif rec_model_type == 'RecAtt': 233 | rec_model = RecAtt(feature_size, eb_dim, hidden_size, b_num, record_fnum, emb_initializer) 234 | else: 235 | print('WRONG REC MODEL TYPE') 236 | exit(1) 237 | 238 | if ubr_model_type == 'UBR_SA': 239 | ubr_model = UBR_SA(feature_size, eb_dim, hidden_size, record_fnum, emb_initializer) 240 | else: 241 | print('WRONG UBR MODEL TYPE') 242 | exit(1) 243 | 244 | rec_training_monitor = { 245 | 'loss' : [], 246 | 'logloss' : [], 247 | 'rig' : [], 248 | 'auc' : [] 249 | } 250 | 251 | ubr_training_monitor = { 252 | 'loss' : [], 253 | 'reward' : [] 254 | } 255 | 256 | # gpu settings 257 | gpu_options = tf.GPUOptions(allow_growth=True) 258 | 259 | # training process 260 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 261 | sess.run(tf.global_variables_initializer()) 262 | sess.run(tf.local_variables_initializer()) 263 | rec_step = 0 264 | ubr_step = 0 265 | eval_iter_num = (dataset_size // 25) // train_batch_size 266 | 267 | # summary writer 268 | if not os.path.exists('summary_{}/{}/{}_{}/'.format(data_set_name, b_num, rec_model_type, ubr_model_type)): 269 | 
os.makedirs('summary_{}/{}/{}_{}/'.format(data_set_name, b_num, rec_model_type, ubr_model_type)) 270 | rec_model_name = '{}_{}_{}_{}'.format(rec_model_type, train_batch_size, rec_lr, reg_lambda) 271 | ubr_model_name = '{}_{}_{}'.format(ubr_model_type, train_batch_size, ubr_lr) 272 | summary_writer_ubr = tf.summary.FileWriter('summary_{}/{}/{}_{}/{}/'.format(data_set_name, b_num, rec_model_type, ubr_model_type, ubr_model_name)) 273 | 274 | # begin training process 275 | rec_step, early_stop = train_rec_model(rec_training_monitor, 1, sess, eval_iter_num, train_batch_size, 276 | taker, rec_lr, reg_lambda, rec_model, ubr_model, target_train_file, user_feat_dict_file, 277 | item_feat_dict_file, context_dict_file, rec_step, b_num) 278 | for i in range(10): 279 | ubr_step = train_ubr_model(ubr_training_monitor, 1, sess, eval_iter_num, taker, ubr_lr, train_batch_size, 280 | rec_model, ubr_model, target_train_file, user_feat_dict_file, 281 | item_feat_dict_file, context_dict_file, summary_writer_ubr, ubr_step, b_num) 282 | 283 | rec_step, early_stop = train_rec_model(rec_training_monitor, 1, sess, eval_iter_num, train_batch_size, 284 | taker, rec_lr, reg_lambda, rec_model, ubr_model, target_train_file, user_feat_dict_file, 285 | item_feat_dict_file, context_dict_file, rec_step, b_num) 286 | if early_stop: 287 | break 288 | 289 | # generate log 290 | if not os.path.exists('logs_{}/{}/{}_{}/'.format(data_set_name, b_num, rec_model_type, ubr_model_type)): 291 | os.makedirs('logs_{}/{}/{}_{}/'.format(data_set_name, b_num, rec_model_type, ubr_model_type)) 292 | 293 | with open('logs_{}/{}/{}_{}/{}.pkl'.format(data_set_name, b_num, rec_model_type, ubr_model_type, rec_model_name), 'wb') as f: 294 | pkl.dump(rec_training_monitor, f) 295 | with open('logs_{}/{}/{}_{}/{}.pkl'.format(data_set_name, b_num, rec_model_type, ubr_model_type, ubr_model_name), 'wb') as f: 296 | pkl.dump(ubr_training_monitor, f) 297 | 298 | 299 | 300 | if __name__ == '__main__': 301 | if len(sys.argv) < 5: 302 | print("PLEASE INPUT [REC MODEL TYPE] [UBR MODEL TYPE] [GPU] [DATASET]") 303 | sys.exit(0) 304 | rec_model_type = sys.argv[1] 305 | ubr_model_type = sys.argv[2] 306 | os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[3] 307 | data_set_name = sys.argv[4] 308 | 309 | if data_set_name == 'tmall': 310 | record_fnum = 9 311 | 312 | target_train_file = DATA_DIR_TMALL + 'target_train.txt' 313 | target_test_file = DATA_DIR_TMALL + 'target_test.txt' 314 | 315 | user_feat_dict_file = DATA_DIR_TMALL + 'user_feat_dict.pkl' 316 | item_feat_dict_file = DATA_DIR_TMALL + 'item_feat_dict.pkl' 317 | context_dict_train_file = DATA_DIR_TMALL + 'context_dict_train.pkl' 318 | context_dict_test_file = DATA_DIR_TMALL + 'context_dict_test.pkl' 319 | 320 | # model parameter 321 | feature_size = FEAT_SIZE_TMALL 322 | dataset_size = 847568 323 | 324 | emb_initializer = None 325 | b_num = 20 326 | reader = ESReader('tmall') 327 | 328 | 329 | elif data_set_name == 'taobao': 330 | record_fnum = 4 331 | 332 | target_train_file = DATA_DIR_TAOBAO + 'target_train.txt' 333 | target_test_file = DATA_DIR_TAOBAO + 'target_test.txt' 334 | 335 | user_feat_dict_file = None 336 | item_feat_dict_file = DATA_DIR_TAOBAO + 'item_feat_dict.pkl' 337 | context_dict_train_file = DATA_DIR_TAOBAO + 'context_dict_train.pkl' 338 | context_dict_test_file = DATA_DIR_TAOBAO + 'context_dict_test.pkl' 339 | 340 | # model parameter 341 | feature_size = FEAT_SIZE_TAOBAO 342 | dataset_size = 1962046 343 | 344 | emb_initializer = None 345 | b_num = 20 346 | reader = ESReader('taobao') 347 | 348 | elif 
data_set_name == 'alipay': 349 | record_fnum = 6 350 | 351 | target_train_file = DATA_DIR_ALIPAY + 'target_train.txt' 352 | target_test_file = DATA_DIR_ALIPAY + 'target_test.txt' 353 | 354 | user_feat_dict_file = None 355 | item_feat_dict_file = DATA_DIR_ALIPAY + 'item_feat_dict.pkl' 356 | context_dict_train_file = DATA_DIR_ALIPAY + 'context_dict_train.pkl' 357 | context_dict_test_file = DATA_DIR_ALIPAY + 'context_dict_test.pkl' 358 | 359 | # model parameter 360 | feature_size = FEAT_SIZE_ALIPAY 361 | dataset_size = 996616 362 | 363 | emb_initializer = None 364 | b_num = 12 365 | reader = ESReader('alipay') 366 | else: 367 | print('WRONG DATASET NAME: {}'.format(data_set_name)) 368 | exit(1) 369 | 370 | ################################## training hyper params ################################## 371 | 372 | reg_lambdas = [1e-4] 373 | batch_sizes = [100, 200] 374 | rec_lrs = [5e-4, 1e-3] 375 | ubr_lrs = [1e-6, 1e-5, 1e-4] 376 | 377 | for reg_lambda in reg_lambdas: 378 | for i in range(len(batch_sizes)): 379 | batch_size = batch_sizes[i] 380 | rec_lr = rec_lrs[i] 381 | taker = Taker(reader, batch_size, b_num, record_fnum) 382 | 383 | for ubr_lr in ubr_lrs: 384 | train(data_set_name, target_train_file, user_feat_dict_file, 385 | item_feat_dict_file, context_dict_train_file, rec_model_type, ubr_model_type, 386 | taker, b_num, batch_size, feature_size, EMBEDDING_SIZE, HIDDEN_SIZE, 387 | rec_lr, ubr_lr, reg_lambda, dataset_size, record_fnum, emb_initializer) 388 | restore(data_set_name, target_test_file, user_feat_dict_file, item_feat_dict_file, context_dict_test_file, 389 | rec_model_type, ubr_model_type, b_num, batch_size, feature_size, 390 | EMBEDDING_SIZE, HIDDEN_SIZE, rec_lr, ubr_lr, reg_lambda, 391 | record_fnum, emb_initializer, taker) 392 | 393 | -------------------------------------------------------------------------------- /code/train_baselines.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | from dataloader import * 5 | from baselines import * 6 | from rec import * 7 | from sklearn.metrics import * 8 | import random 9 | import time 10 | import numpy as np 11 | import pickle as pkl 12 | import math 13 | random.seed(1111) 14 | 15 | EMBEDDING_SIZE = 16 16 | HIDDEN_SIZE = 16 * 2 17 | EVAL_BATCH_SIZE = 500 18 | 19 | # for TMALL 20 | FEAT_SIZE_TMALL = 1529672 + 6 #(6 is for time context) 21 | DATA_DIR_TMALL = '../data/tmall/feateng_data/' 22 | MAX_LEN_TMALL = 20 23 | 24 | # for Taobao 25 | FEAT_SIZE_TAOBAO = 5062314 26 | DATA_DIR_TAOBAO = '../data/taobao/feateng_data/' 27 | MAX_LEN_TAOBAO = 20 28 | 29 | # for Alipay 30 | FEAT_SIZE_ALIPAY = 2836410 31 | DATA_DIR_ALIPAY = '../data/alipay/feateng_data/' 32 | MAX_LEN_ALIPAY = 12 33 | 34 | 35 | def restore(data_set_name, target_test_file, user_seq_file, user_feat_dict_file, item_feat_dict_file, 36 | model_type, train_batch_size, feature_size, eb_dim, hidden_size, max_time_len, 37 | lr, reg_lambda, user_fnum, item_fnum, emb_initializer): 38 | print('restore begin') 39 | tf.reset_default_graph() 40 | if model_type == 'GRU4Rec': 41 | model = GRU4Rec(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 42 | elif model_type == 'Caser': 43 | model = Caser(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 44 | elif model_type == 'DIN': 45 | model = DIN(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 46 | elif model_type == 'DIEN': 47 | model = DIEN(feature_size, eb_dim, 
hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 48 | elif model_type == 'SASRec': 49 | model = SASRec(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 50 | elif model_type == 'MIMN': 51 | model = MIMN(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 52 | elif model_type == 'HPMN': 53 | model = HPMN(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 54 | else: 55 | print('WRONG MODEL TYPE') 56 | exit(1) 57 | model_name = '{}_{}_{}_{}'.format(model_type, train_batch_size, lr, reg_lambda) 58 | 59 | gpu_options = tf.GPUOptions(allow_growth=True) 60 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 61 | model.restore(sess, 'save_model_{}/{}/{}/ckpt'.format(data_set_name, max_time_len, model_name)) 62 | print('restore eval begin') 63 | _, logloss, rig, auc = eval(model, sess, target_test_file, user_seq_file, user_feat_dict_file, item_feat_dict_file, max_time_len, reg_lambda) 64 | 65 | print('RESTORE, LOGLOSS %.4f RIG: %.4f AUC: %.4f' % (logloss, rig, auc)) 66 | with open('logs_{}/{}/{}.txt'.format(data_set_name, max_time_len, model_type), 'a') as f: 67 | results = [train_batch_size, lr, reg_lambda, logloss, rig, auc] 68 | results = [model_type] + [str(res) for res in results] 69 | result_line = '\t'.join(results) + '\n' 70 | f.write(result_line) 71 | 72 | 73 | def eval(model, sess, target_file, user_seq_file, user_feat_dict_file, item_feat_dict_file, max_time_len, reg_lambda): 74 | preds = [] 75 | labels = [] 76 | losses = [] 77 | 78 | data_loader = DataLoader(EVAL_BATCH_SIZE, user_seq_file, target_file, user_feat_dict_file, item_feat_dict_file, max_time_len) 79 | 80 | t = time.time() 81 | for batch_data in data_loader: 82 | pred, label, loss = model.eval(sess, batch_data, reg_lambda) 83 | preds += pred 84 | labels += label 85 | losses.append(loss) 86 | 87 | logloss = log_loss(labels, preds) 88 | rig = 1 - (logloss / -(0.5 * math.log(0.5) + (1 - 0.5) * math.log(1 - 0.5))) 89 | auc = roc_auc_score(labels, preds) 90 | loss = sum(losses) / len(losses) 91 | 92 | print("EVAL TIME: %.4fs" % (time.time() - t)) 93 | return loss, logloss, rig, auc 94 | 95 | 96 | def train(data_set_name, target_train_file, target_vali_file, user_seq_file, user_feat_dict_file, item_feat_dict_file, 97 | model_type, train_batch_size, feature_size, eb_dim, hidden_size, max_time_len, lr, reg_lambda, dataset_size, 98 | user_fnum, item_fnum, emb_initializer): 99 | tf.reset_default_graph() 100 | 101 | if model_type == 'GRU4Rec': 102 | model = GRU4Rec(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 103 | elif model_type == 'Caser': 104 | model = Caser(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 105 | elif model_type == 'DIN': 106 | model = DIN(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 107 | elif model_type == 'DIEN': 108 | model = DIEN(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 109 | elif model_type == 'SASRec': 110 | model = SASRec(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 111 | elif model_type == 'MIMN': 112 | model = MIMN(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 113 | elif model_type == 'HPMN': 114 | model = HPMN(feature_size, eb_dim, hidden_size, max_time_len, user_fnum, item_fnum, emb_initializer) 115 | 
else: 116 | print('WRONG MODEL TYPE') 117 | exit(1) 118 | 119 | training_monitor = { 120 | 'train_loss' : [], 121 | 'vali_loss' : [], 122 | 'logloss' : [], 123 | 'rig' : [], 124 | 'auc' : [] 125 | } 126 | 127 | # gpu settings 128 | gpu_options = tf.GPUOptions(allow_growth=True) 129 | 130 | # training process 131 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 132 | sess.run(tf.global_variables_initializer()) 133 | sess.run(tf.local_variables_initializer()) 134 | 135 | train_losses_step = [] 136 | 137 | # before training process 138 | step = 0 139 | vali_loss, logloss, rig, auc = eval(model, sess, target_vali_file, user_seq_file, user_feat_dict_file, item_feat_dict_file, max_time_len, reg_lambda) 140 | 141 | training_monitor['train_loss'].append(None) 142 | training_monitor['vali_loss'].append(vali_loss) 143 | training_monitor['logloss'].append(logloss) 144 | training_monitor['rig'].append(rig) 145 | training_monitor['auc'].append(auc) 146 | 147 | print("STEP %d LOSS TRAIN: NULL LOSS VALI: %.4f LOGLOSS: %.4f RIG: %.4f AUC: %.4f" % (step, vali_loss, logloss, rig, auc)) 148 | early_stop = False 149 | eval_iter_num = (dataset_size // 5) // train_batch_size 150 | # begin training process 151 | for epoch in range(10): 152 | if early_stop: 153 | break 154 | data_loader = DataLoader(train_batch_size, user_seq_file, target_train_file, user_feat_dict_file, item_feat_dict_file, max_time_len) 155 | 156 | for batch_data in data_loader: 157 | if early_stop: 158 | break 159 | loss = model.train(sess, batch_data, lr, reg_lambda) 160 | step += 1 161 | train_losses_step.append(loss) 162 | 163 | if step % eval_iter_num == 0: 164 | train_loss = sum(train_losses_step) / len(train_losses_step) 165 | training_monitor['train_loss'].append(train_loss) 166 | train_losses_step = [] 167 | 168 | vali_loss, logloss, rig, auc = eval(model, sess, target_vali_file, user_seq_file, user_feat_dict_file, item_feat_dict_file, max_time_len, reg_lambda) 169 | training_monitor['vali_loss'].append(vali_loss) 170 | training_monitor['logloss'].append(logloss) 171 | training_monitor['rig'].append(rig) 172 | training_monitor['auc'].append(auc) 173 | 174 | print("STEP %d LOSS TRAIN: %.4f LOSS VALI: %.4f LOGLOSS: %.4f RIG: %.4f AUC: %.4f" % (step, train_loss, vali_loss, logloss, rig, auc)) 175 | if training_monitor['auc'][-1] > max(training_monitor['auc'][:-1]): 176 | # save model 177 | model_name = '{}_{}_{}_{}'.format(model_type, train_batch_size, lr, reg_lambda) 178 | if not os.path.exists('save_model_{}/{}/{}/'.format(data_set_name, max_time_len, model_name)): 179 | os.makedirs('save_model_{}/{}/{}/'.format(data_set_name, max_time_len, model_name)) 180 | save_path = 'save_model_{}/{}/{}/ckpt'.format(data_set_name, max_time_len, model_name) 181 | model.save(sess, save_path) 182 | 183 | if len(training_monitor['vali_loss']) > 2 and epoch > 0: 184 | if (training_monitor['vali_loss'][-1] > training_monitor['vali_loss'][-2] and training_monitor['vali_loss'][-2] > training_monitor['vali_loss'][-3]): 185 | early_stop = True 186 | if (training_monitor['vali_loss'][-2] - training_monitor['vali_loss'][-1]) <= 0.001 and (training_monitor['vali_loss'][-3] - training_monitor['vali_loss'][-2]) <= 0.001: 187 | early_stop = True 188 | 189 | # generate log 190 | if not os.path.exists('logs_{}/{}/'.format(data_set_name, max_time_len)): 191 | os.makedirs('logs_{}/{}/'.format(data_set_name, max_time_len)) 192 | model_name = '{}_{}_{}_{}'.format(model_type, train_batch_size, lr, reg_lambda) 193 | 194 | with 
open('logs_{}/{}/{}.pkl'.format(data_set_name, max_time_len, model_name), 'wb') as f: 195 | pkl.dump(training_monitor, f) 196 | 197 | 198 | if __name__ == '__main__': 199 | if len(sys.argv) < 4: 200 | print("PLEASE INPUT [MODEL TYPE] [GPU] [DATASET]") 201 | sys.exit(0) 202 | model_type = sys.argv[1] 203 | os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[2] 204 | data_set_name = sys.argv[3] 205 | 206 | if data_set_name == 'tmall': 207 | user_fnum = 3 208 | item_fnum = 4 209 | 210 | target_train_file = DATA_DIR_TMALL + 'target_train.txt' 211 | target_vali_file = DATA_DIR_TMALL + 'target_vali.txt' 212 | target_test_file = DATA_DIR_TMALL + 'target_test.txt' 213 | user_seq_file = DATA_DIR_TMALL + 'user_seq.txt' 214 | 215 | user_feat_dict_file = DATA_DIR_TMALL + 'user_feat_dict.pkl' 216 | item_feat_dict_file = DATA_DIR_TMALL + 'item_feat_dict.pkl' 217 | 218 | # model parameter 219 | feature_size = FEAT_SIZE_TMALL 220 | max_time_len = MAX_LEN_TMALL #100 221 | dataset_size = 847568 222 | 223 | emb_initializer = None 224 | 225 | elif data_set_name == 'taobao': 226 | user_fnum = 1 227 | item_fnum = 2 228 | 229 | target_train_file = DATA_DIR_TAOBAO + 'target_train.txt' 230 | target_vali_file = DATA_DIR_TAOBAO + 'target_vali.txt' 231 | target_test_file = DATA_DIR_TAOBAO + 'target_test.txt' 232 | user_seq_file = DATA_DIR_TAOBAO + 'user_seq.txt' 233 | 234 | user_feat_dict_file = None 235 | item_feat_dict_file = DATA_DIR_TAOBAO + 'item_feat_dict.pkl' 236 | 237 | # model parameter 238 | feature_size = FEAT_SIZE_TAOBAO 239 | max_time_len = MAX_LEN_TAOBAO #100 240 | dataset_size = 1962046 241 | 242 | emb_initializer = None 243 | elif data_set_name == 'alipay': 244 | user_fnum = 1 245 | item_fnum = 3 246 | 247 | target_train_file = DATA_DIR_ALIPAY + 'target_train.txt' 248 | target_vali_file = DATA_DIR_ALIPAY + 'target_vali.txt' 249 | target_test_file = DATA_DIR_ALIPAY + 'target_test.txt' 250 | user_seq_file = DATA_DIR_ALIPAY + 'user_seq.txt' 251 | 252 | user_feat_dict_file = None 253 | item_feat_dict_file = DATA_DIR_ALIPAY + 'item_feat_dict.pkl' 254 | 255 | # model parameter 256 | feature_size = FEAT_SIZE_ALIPAY 257 | max_time_len = MAX_LEN_ALIPAY #60 258 | dataset_size = 996616 259 | 260 | emb_initializer = None 261 | 262 | else: 263 | print('WRONG DATASET NAME: {}'.format(data_set_name)) 264 | exit() 265 | 266 | ################################## training hyper params ################################## 267 | reg_lambdas = [1e-4, 5e-4] 268 | hyper_paras = [(100, 5e-4), (200, 1e-3)] 269 | 270 | for hyper in hyper_paras: 271 | train_batch_size, lr = hyper 272 | for reg_lambda in reg_lambdas: 273 | train(data_set_name, target_train_file, target_vali_file, user_seq_file, user_feat_dict_file, item_feat_dict_file, 274 | model_type, train_batch_size, feature_size, EMBEDDING_SIZE, HIDDEN_SIZE, max_time_len, lr, reg_lambda, dataset_size, 275 | user_fnum, item_fnum, emb_initializer) 276 | restore(data_set_name, target_test_file, user_seq_file, user_feat_dict_file, item_feat_dict_file, model_type, train_batch_size, 277 | feature_size, EMBEDDING_SIZE, HIDDEN_SIZE, max_time_len, lr, reg_lambda, user_fnum, item_fnum, emb_initializer) 278 | -------------------------------------------------------------------------------- /code/ubr.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class UBRBase(object): 4 | def __init__(self, feature_size, eb_dim, hidden_size, record_fnum, emb_initializer): 5 | self.record_fnum = record_fnum 6 | 7 | # input placeholders 8 | 
with tf.name_scope('ubr/inputs'): 9 | self.target_ph = tf.placeholder(tf.int32, [None, record_fnum], name='ubr_target_ph') 10 | 11 | self.rewards = tf.placeholder(tf.float32, [None, 1], name='rewards_ph') 12 | self.lr = tf.placeholder(tf.float32, []) 13 | 14 | # embedding 15 | with tf.variable_scope('embedding', reuse=tf.AUTO_REUSE): 16 | if emb_initializer is not None: 17 | self.emb_mtx = tf.get_variable('emb_mtx', initializer=emb_initializer) 18 | else: 19 | self.emb_mtx = tf.get_variable('emb_mtx', [feature_size, eb_dim], initializer=tf.truncated_normal_initializer) 20 | self.emb_mtx_mask = tf.constant(value=1., shape=[feature_size - 1, eb_dim]) 21 | self.emb_mtx_mask = tf.concat([tf.constant(value=0., shape=[1, eb_dim]), self.emb_mtx_mask], axis=0) 22 | self.emb_mtx = self.emb_mtx * self.emb_mtx_mask 23 | 24 | self.target = tf.nn.embedding_lookup(self.emb_mtx, self.target_ph) #[ B, F, EMB_DIM] 25 | self.target_input = self.target[:, 1:, :] # exclude uid 26 | 27 | def build_index_and_loss(self, probs): 28 | uniform = tf.random_uniform(tf.shape(probs), 0, 1) 29 | condition = probs - uniform 30 | self.index = tf.where(condition >= 0, tf.ones_like(probs), tf.zeros_like(probs)) 31 | log_probs = tf.log(tf.clip_by_value(probs, 1e-10, 1)) 32 | 33 | self.loss = -tf.reduce_mean(tf.reduce_sum(log_probs * self.index * self.rewards, axis=1)) 34 | self.reward = tf.reduce_mean(self.rewards) 35 | tf.summary.scalar('ubr_reward', self.reward) 36 | self.merged = tf.summary.merge_all() 37 | 38 | def build_optimizer(self): 39 | # optimizer and training step 40 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr, name='ubr_adam') 41 | gvs = self.optimizer.compute_gradients(self.loss) 42 | capped_gvs = [] 43 | for grad, var in gvs: 44 | if grad is not None: 45 | capped_gvs.append((tf.clip_by_norm(grad, 5.), var)) 46 | self.train_step = self.optimizer.apply_gradients(capped_gvs) 47 | # self.train_step = self.optimizer.minimize(self.loss) 48 | 49 | def train(self, sess, batch_data, lr, rewards): 50 | loss, reward, _, summary = sess.run([self.loss, self.reward, self.train_step, self.merged], feed_dict = { 51 | self.target_ph : batch_data, 52 | self.lr : lr, 53 | self.rewards : rewards 54 | }) 55 | return loss, reward, summary 56 | 57 | def get_distri(self, sess, batch_data): 58 | res = sess.run(self.probs, feed_dict={ 59 | self.target_ph : batch_data 60 | }) 61 | return res 62 | 63 | def get_index(self, sess, batch_data): 64 | res = sess.run(self.index, feed_dict={ 65 | self.target_ph : batch_data 66 | }) 67 | return res 68 | 69 | def save(self, sess, path): 70 | saver = tf.train.Saver() 71 | saver.save(sess, save_path=path) 72 | 73 | def restore(self, sess, path): 74 | saver = tf.train.Saver() 75 | saver.restore(sess, save_path=path) 76 | print('model restored from {}'.format(path)) 77 | 78 | 79 | class UBR_SA(UBRBase): 80 | def __init__(self, feature_size, eb_dim, hidden_size, record_fnum, emb_initializer): 81 | super(UBR_SA, self).__init__(feature_size, eb_dim, hidden_size, record_fnum, emb_initializer) 82 | self.probs = self.build_select_probs(self.target_input) 83 | self.build_index_and_loss(self.probs) 84 | self.build_optimizer() 85 | 86 | def build_select_probs(self, target_input): 87 | sa_target = self.multihead_attention(self.normalize(target_input), target_input) 88 | probs = tf.layers.dense(sa_target, 20, activation=tf.nn.relu, name='fc1') 89 | probs = tf.layers.dense(probs, 10, activation=tf.nn.relu, name='fc2') 90 | probs = tf.layers.dense(probs, 1, activation=tf.nn.sigmoid, name='fc3') 
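# Each of the record_fnum - 1 non-uid feature fields thus gets an independent
# selection probability in (0, 1); build_index_and_loss() samples a binary
# index from these Bernoulli probabilities and trains them with a
# REINFORCE-style loss weighted by the reward passed in from the rec model.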
91 | probs = tf.reshape(probs, [-1, self.record_fnum - 1]) 92 | return probs 93 | 94 | 95 | def multihead_attention(self, 96 | queries, 97 | keys, 98 | num_units=None, 99 | num_heads=2, 100 | scope="multihead_attention", 101 | reuse=None): 102 | '''Applies multihead attention. 103 | 104 | Args: 105 | queries: A 3d tensor with shape of [N, T_q, C_q]. 106 | keys: A 3d tensor with shape of [N, T_k, C_k]. 107 | num_units: A scalar. Attention size. 108 | num_heads: An int. Number of heads. 109 | scope: Optional scope for `variable_scope`. 110 | reuse: Boolean, whether to reuse the weights of a previous layer 111 | by the same name. 112 | 113 | Returns: 114 | A 3d tensor with shape of (N, T_q, C) 115 | ''' 116 | with tf.variable_scope(scope, reuse=reuse): 117 | # Set the fall back option for num_units 118 | if num_units is None: 119 | num_units = queries.get_shape().as_list()[-1] 120 | 121 | # Linear projections 122 | Q = tf.layers.dense(queries, num_units, activation=None) # (N, T_q, C) 123 | K = tf.layers.dense(keys, num_units, activation=None) # (N, T_k, C) 124 | V = tf.layers.dense(keys, num_units, activation=None) # (N, T_k, C) 125 | 126 | # Split and concat 127 | Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, C/h) 128 | K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, C/h) 129 | V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) # (h*N, T_k, C/h) 130 | 131 | # Multiplication 132 | outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1])) # (h*N, T_q, T_k) 133 | 134 | # Scale 135 | outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5) 136 | 137 | # Key Masking 138 | key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1))) # (N, T_k) 139 | key_masks = tf.tile(key_masks, [num_heads, 1]) # (h*N, T_k) 140 | key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1]) # (h*N, T_q, T_k) 141 | 142 | paddings = tf.ones_like(outputs)*(-2**32+1) 143 | outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs) # (h*N, T_q, T_k) 144 | 145 | # Activation 146 | outputs = tf.nn.softmax(outputs) # (h*N, T_q, T_k) 147 | 148 | # Query Masking 149 | query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1))) # (N, T_q) 150 | query_masks = tf.tile(query_masks, [num_heads, 1]) # (h*N, T_q) 151 | query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]]) # (h*N, T_q, T_k) 152 | outputs *= query_masks # elementwise mask, (h*N, T_q, T_k) 153 | 154 | # Dropouts 155 | outputs = tf.nn.dropout(outputs, 0.8) 156 | 157 | # Weighted sum 158 | outputs = tf.matmul(outputs, V_) # (h*N, T_q, C/h) 159 | 160 | # Restore shape 161 | outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2) # (N, T_q, C) 162 | 163 | # Residual connection 164 | outputs += queries 165 | 166 | # Normalize 167 | #outputs = normalize(outputs) # (N, T_q, C) 168 | 169 | return outputs 170 | 171 | def normalize(self, 172 | inputs, 173 | epsilon = 1e-8, 174 | scope="ln", 175 | reuse=None): 176 | '''Applies layer normalization. 177 | 178 | Args: 179 | inputs: A tensor with 2 or more dimensions, where the first dimension has 180 | `batch_size`. 181 | epsilon: A floating number. A very small number for preventing ZeroDivision Error. 182 | scope: Optional scope for `variable_scope`. 183 | reuse: Boolean, whether to reuse the weights of a previous layer 184 | by the same name. 185 | 186 | Returns: 187 | A tensor with the same shape and data dtype as `inputs`. 
188 | ''' 189 | with tf.variable_scope(scope, reuse=reuse): 190 | inputs_shape = inputs.get_shape() 191 | params_shape = inputs_shape[-1:] 192 | 193 | mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True) 194 | beta= tf.Variable(tf.zeros(params_shape)) 195 | gamma = tf.Variable(tf.ones(params_shape)) 196 | normalized = (inputs - mean) / ( (variance + epsilon) ** (.5) ) 197 | outputs = gamma * normalized + beta 198 | 199 | return outputs 200 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.ops import array_ops 4 | from tensorflow.python.ops import init_ops 5 | from tensorflow.python.ops import math_ops 6 | from tensorflow.python.ops import variable_scope as vs 7 | import gc 8 | import numpy as np 9 | from tensorflow.python.ops.rnn_cell import * 10 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell 11 | 12 | def expand(x, axis, N, dims=2): 13 | if dims != 2: 14 | return tf.tile(tf.expand_dims(x, axis), [N, 1, 1]) 15 | return tf.tile(tf.expand_dims(x, axis), [N, 1]) 16 | # return tf.concat([tf.expand_dims(x, dim) for _ in tf.range(N)], axis=dim) 17 | 18 | 19 | def create_linear_initializer(input_size, dtype=tf.float32): 20 | stddev = 1.0 / np.sqrt(input_size) 21 | return tf.truncated_normal_initializer(stddev=stddev, dtype=dtype) 22 | 23 | 24 | def learned_init(units): 25 | return tf.squeeze(tf.contrib.layers.fully_connected( 26 | tf.ones([1, 1]), units, activation_fn=None, biases_initializer=None)) 27 | 28 | 29 | class MIMNCell(tf.contrib.rnn.RNNCell): 30 | def __init__(self, controller_units, memory_vector_dim, batch_size=128, memory_size=4, 31 | read_head_num=1, write_head_num=1, reuse=False, output_dim=16, clip_value=20, sharp_value=2.): 32 | self.controller_units = controller_units 33 | self.memory_vector_dim = memory_vector_dim 34 | self.memory_size = memory_size 35 | self.batch_size = batch_size 36 | self.read_head_num = read_head_num 37 | self.write_head_num = write_head_num 38 | self.reuse = reuse 39 | self.clip_value = clip_value 40 | self.sharp_value = sharp_value 41 | 42 | def single_cell(num_units): 43 | return tf.nn.rnn_cell.GRUCell(num_units) 44 | 45 | self.controller = single_cell(self.controller_units) 46 | self.step = 0 47 | self.output_dim = output_dim 48 | 49 | # TODO: ? 50 | self.o2p_initializer = create_linear_initializer(self.controller_units) 51 | self.o2o_initializer = create_linear_initializer( 52 | self.controller_units + self.memory_vector_dim * self.read_head_num) 53 | 54 | def __call__(self, x, prev_state): 55 | prev_read_vector_list = prev_state["read_vector_list"] 56 | 57 | controller_input = tf.concat([x] + prev_read_vector_list, axis=1) 58 | with tf.variable_scope('controller', reuse=self.reuse): 59 | controller_output, controller_state = self.controller(controller_input, prev_state["controller_state"]) 60 | 61 | num_parameters_per_head = self.memory_vector_dim + 1 # TODO: why +1? sharp_value? 
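# The +1 reserves one scalar per head for the addressing sharpness: in the
# loop below, the first memory_vector_dim entries of each head's parameter
# slice become the key k, and the last entry becomes beta (softplus-ed and
# then scaled by sharp_value).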
62 | num_heads = self.read_head_num + self.write_head_num 63 | total_parameter_num = num_parameters_per_head * num_heads + self.memory_vector_dim * 2 * self.write_head_num 64 | 65 | with tf.variable_scope("o2p", reuse=(self.step > 0) or self.reuse): 66 | parameters = tf.contrib.layers.fully_connected( 67 | controller_output, total_parameter_num, activation_fn=None, 68 | weights_initializer=self.o2p_initializer) 69 | parameters = tf.clip_by_norm(parameters, self.clip_value) 70 | 71 | head_parameter_list = tf.split(parameters[:, :num_parameters_per_head * num_heads], num_heads, axis=1) 72 | erase_add_list = tf.split(parameters[:, num_parameters_per_head * num_heads:], 2 * self.write_head_num, axis=1) 73 | 74 | prev_M = prev_state["M"] 75 | key_M = prev_state["key_M"] 76 | w_list = [] 77 | for i, head_parameter in enumerate(head_parameter_list): 78 | k = tf.tanh(head_parameter[:, 0:self.memory_vector_dim]) 79 | beta = (tf.nn.softplus(head_parameter[:, self.memory_vector_dim]) + 1) * self.sharp_value 80 | with tf.variable_scope('addressing_head_%d' % i): 81 | w = self.addressing(k, beta, key_M, prev_M) # [batch_size, memory_size] 82 | w_list.append(w) 83 | 84 | read_w_list = w_list[:self.read_head_num] 85 | read_vector_list = [] 86 | for i in range(self.read_head_num): 87 | # [batch_size, fnum * eb_dim] 88 | read_vector = tf.reduce_sum(tf.expand_dims(read_w_list[i], dim=2) * prev_M, axis=1) 89 | read_vector = tf.reshape(read_vector, [-1, self.memory_vector_dim]) 90 | read_vector_list.append(read_vector) 91 | 92 | write_w_list = w_list[self.read_head_num:] 93 | 94 | M = prev_M 95 | sum_aggre = prev_state["sum_aggre"] 96 | 97 | for i in range(self.write_head_num): 98 | w = tf.expand_dims(write_w_list[i], axis=2) # [batch_size, memory_size, 1] 99 | erase_vector = tf.expand_dims(tf.sigmoid(erase_add_list[i * 2]), axis=1) # [batch_size, 1, fnum * eb_dim] 100 | add_vector = tf.expand_dims(tf.tanh(erase_add_list[i * 2 + 1]), axis=1) 101 | 102 | # [batch_size, memory_size, fnum * eb_dim] 103 | # M_t = (1 - E_t) * M_t + A_t 104 | ones = tf.ones([self.batch_size, self.memory_size, self.memory_vector_dim]) 105 | M = M * (ones - tf.matmul(w, erase_vector)) + tf.matmul(w, add_vector) 106 | sum_aggre += tf.matmul(tf.stop_gradient(w), add_vector) # [batch_size, memory_size, fnum * eb_dim] 107 | 108 | with tf.variable_scope("o2o", reuse=(self.step > 0) or self.reuse): 109 | read_output = tf.contrib.layers.fully_connected( 110 | tf.concat([controller_output] + read_vector_list, axis=1), self.output_dim, activation_fn=None, 111 | weights_initializer=self.o2o_initializer) 112 | read_output = tf.clip_by_norm(read_output, self.clip_value) 113 | 114 | self.step += 1 115 | return read_output, { 116 | "controller_state": controller_state, 117 | "read_vector_list": read_vector_list, 118 | "w_list": w_list, 119 | "M": M, 120 | "key_M": key_M, 121 | "sum_aggre": sum_aggre 122 | } 123 | 124 | def addressing(self, k, beta, key_M, prev_M): 125 | # Cosine Similarity 126 | def cosine_similarity(key, M): 127 | key = tf.expand_dims(key, axis=2) 128 | inner_product = tf.matmul(M, key) 129 | k_norm = tf.sqrt(tf.reduce_sum(tf.square(key), axis=1, keep_dims=True)) 130 | M_norm = tf.sqrt(tf.reduce_sum(tf.square(M), axis=2, keep_dims=True)) 131 | norm_product = M_norm * k_norm 132 | K = tf.squeeze(inner_product / (norm_product + 1e-8)) 133 | return K 134 | 135 | K = 0.5*(cosine_similarity(k,key_M) + cosine_similarity(k,prev_M)) 136 | K_amplified = tf.exp(tf.expand_dims(beta, axis=1) * K) 137 | w_c = K_amplified / 
tf.reduce_sum(K_amplified, axis=1, keep_dims=True) 138 | 139 | return w_c 140 | 141 | def zero_state(self, batch_size): 142 | with tf.variable_scope('init', reuse=self.reuse): 143 | read_vector_list = [expand(tf.tanh(learned_init(self.memory_vector_dim)), 0, batch_size) 144 | for _ in range(self.read_head_num)] 145 | 146 | w_list = [expand(tf.nn.softmax(learned_init(self.memory_size)), 0, batch_size) 147 | for _ in range(self.read_head_num + self.write_head_num)] 148 | 149 | controller_init_state = self.controller.zero_state(batch_size, tf.float32) 150 | 151 | M = expand(tf.tanh(tf.get_variable( 152 | 'init_M', [self.memory_size, self.memory_vector_dim], 153 | initializer=tf.random_normal_initializer(mean=0.0, stddev=1e-5), trainable=False)), 0, batch_size, 3) 154 | 155 | key_M = expand(tf.tanh(tf.get_variable( 156 | 'key_M', [self.memory_size, self.memory_vector_dim], 157 | initializer=tf.random_normal_initializer(mean=0.0, stddev=0.5))), 0, batch_size, 3) 158 | 159 | sum_aggre = tf.zeros([batch_size, self.memory_size, self.memory_vector_dim], dtype=tf.float32) 160 | 161 | state = { 162 | "controller_state": controller_init_state, 163 | "read_vector_list": read_vector_list, 164 | "w_list": w_list, 165 | "M": M, 166 | "key_M": key_M, 167 | "sum_aggre": sum_aggre 168 | } 169 | return state 170 | 171 | class VecAttGRUCell(RNNCell): 172 | """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078). 173 | Args: 174 | num_units: int, The number of units in the GRU cell. 175 | activation: Nonlinearity to use. Default: `tanh`. 176 | reuse: (optional) Python boolean describing whether to reuse variables 177 | in an existing scope. If not `True`, and the existing scope already has 178 | the given variables, an error is raised. 179 | kernel_initializer: (optional) The initializer to use for the weight and 180 | projection matrices. 181 | bias_initializer: (optional) The initializer to use for the bias. 182 | """ 183 | 184 | def __init__(self, 185 | num_units, 186 | activation=None, 187 | reuse=None, 188 | kernel_initializer=None, 189 | bias_initializer=None): 190 | super(VecAttGRUCell, self).__init__(_reuse=reuse) 191 | self._num_units = num_units 192 | self._activation = activation or math_ops.tanh 193 | self._kernel_initializer = kernel_initializer 194 | self._bias_initializer = bias_initializer 195 | self._gate_linear = None 196 | self._candidate_linear = None 197 | 198 | @property 199 | def state_size(self): 200 | return self._num_units 201 | 202 | @property 203 | def output_size(self): 204 | return self._num_units 205 | 206 | def __call__(self, inputs, state, att_score): 207 | return self.call(inputs, state, att_score) 208 | 209 | def call(self, inputs, state, att_score=None): 210 | """Gated recurrent unit (GRU) with nunits cells.""" 211 | if self._gate_linear is None: 212 | bias_ones = self._bias_initializer 213 | if self._bias_initializer is None: 214 | bias_ones = init_ops.constant_initializer(1.0, dtype=inputs.dtype) 215 | with vs.variable_scope("gates"): # Reset gate and update gate. 
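# _Linear projects the concatenated [inputs, state] to 2 * num_units in a
# single matmul; the sigmoid-ed result is split below into the reset gate r
# and the update gate u.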
216 | self._gate_linear = core_rnn_cell._Linear( 217 | [inputs, state], 218 | 2 * self._num_units, 219 | True, 220 | bias_initializer=bias_ones, 221 | kernel_initializer=self._kernel_initializer) 222 | 223 | value = math_ops.sigmoid(self._gate_linear([inputs, state])) 224 | r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) 225 | 226 | r_state = r * state 227 | if self._candidate_linear is None: 228 | with vs.variable_scope("candidate"): 229 | self._candidate_linear = core_rnn_cell._Linear( 230 | [inputs, r_state], 231 | self._num_units, 232 | True, 233 | bias_initializer=self._bias_initializer, 234 | kernel_initializer=self._kernel_initializer) 235 | c = self._activation(self._candidate_linear([inputs, r_state])) 236 | u = (1.0 - att_score) * u 237 | new_h = u * state + (1 - u) * c 238 | return new_h, new_h --------------------------------------------------------------------------------
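For reference, a minimal sketch of how a `VecAttGRUCell` can be unrolled with per-step attention scores, in the AUGRU style used by DIEN-like models. This is an illustrative snippet rather than code from the repository: it assumes TF 1.x, that `code/utils.py` is importable as `utils`, and the tensor shapes are hypothetical placeholders.

```python
import tensorflow as tf
from utils import VecAttGRUCell

max_len, eb_dim, hidden = 20, 16, 32
behaviors = tf.placeholder(tf.float32, [None, max_len, eb_dim])  # behavior embeddings
att_scores = tf.placeholder(tf.float32, [None, max_len, 1])      # attention weight per step

cell = VecAttGRUCell(hidden)
state = cell.zero_state(tf.shape(behaviors)[0], tf.float32)
outputs = []
with tf.variable_scope('augru'):
    for t in range(max_len):
        # The cell rescales its update gate as u' = (1 - att) * u, so a step
        # with attention 1 replaces the state with the candidate vector, while
        # attention 0 falls back to an ordinary GRU update.
        output, state = cell(behaviors[:, t, :], state, att_scores[:, t, :])
        outputs.append(output)
final_state = state  # the attention-weighted evolved interest state
```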