├── .gitignore
├── README.md
└── User-as-graph.ipynb

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # User-as-Graph
2 | Source code for our IJCAI 2021 paper "User-as-Graph: User Modeling with Heterogeneous Graph Pooling for News Recommendation"
3 | 
4 | Environments:
5 | tensorflow==1.12
6 | Keras==2.2.4
7 | 
8 | Notes:
9 | The original experiments were conducted on an earlier version of the MIND dataset (part of the officially released MIND data).
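Setup check:
The first notebook cells download MINDlarge into ./train and ./val and fetch GloVe 840B-300d; word_tokenize additionally needs the NLTK punkt tokenizer. A minimal pre-flight sketch (file and directory names follow the notebook's own download/unzip cells; this is a convenience check, not part of the original pipeline):

```python
import os
import nltk

nltk.download('punkt')  # tokenizer model required by word_tokenize

# Files the notebook reads from the unpacked MIND splits and the GloVe archive.
expected = [
    'train/behaviors.tsv', 'train/news.tsv', 'train/entity_embedding.vec',
    'val/behaviors.tsv', 'val/news.tsv', 'val/entity_embedding.vec',
    'glove.840B.300d.txt',
]
missing = [p for p in expected if not os.path.exists(p)]
print('missing files:', missing or 'none')
```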
10 | -------------------------------------------------------------------------------- /User-as-graph.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "!wget https://mind201910small.blob.core.windows.net/release/MINDlarge_train.zip\n", 10 | "!wget https://mind201910small.blob.core.windows.net/release/MINDlarge_dev.zip \n", 11 | "!mkdir train\n", 12 | "!mkdir val\n", 13 | "!unzip MINDlarge_train.zip -d ./train\n", 14 | "!unzip MINDlarge_dev.zip -d ./val" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import csv\n", 24 | "import random\n", 25 | "import json\n", 26 | "import numpy as np\n", 27 | "from numpy.linalg import cholesky\n", 28 | "from keras.utils.np_utils import to_categorical\n", 29 | "from keras.layers import *\n", 30 | "from keras.models import Model, load_model \n", 31 | "from keras import activations, constraints, initializers, regularizers\n", 32 | "from keras import backend as K \n", 33 | "from keras.engine.topology import Layer, InputSpec\n", 34 | "from keras import initializers #keras2\n", 35 | "from sklearn.metrics import *\n", 36 | "from keras.optimizers import *\n", 37 | "import keras \n", 38 | "from nltk.tokenize import word_tokenize " 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "with open('train/behaviors.tsv')as f:\n", 48 | " trainuser=f.readlines()\n", 49 | "\n", 50 | "with open('val/behaviors.tsv') as f:\n", 51 | " valuser=f.readlines()\n", 52 | "\n", 53 | "with open('train/news.tsv')as f:\n", 54 | " data=f.readlines()\n", 55 | "\n", 56 | "with open('val/news.tsv')as f:\n", 57 | " data+=f.readlines()\n", 58 | "\n", 59 | "with open('train/entity_embedding.vec')as f:\n", 60 | " entity_emb=f.readlines()\n", 61 | " \n", 62 | "with open('val/entity_embedding.vec')as f:\n", 63 | " entity_emb+=f.readlines()\n", 64 | "\n", 65 | "entity_emb_dict={}\n", 66 | "for i in entity_emb:\n", 67 | " entity_emb_dict[i.strip().split('\\t')[0]]=[float(x) for x in i.strip().split('\\t')[1:]]\n", 68 | "\n", 69 | "entityidlist={'NULL':0}\n", 70 | "entity_emb_table=[[0.]*100]\n", 71 | "cnt=0.\n", 72 | "\n", 73 | "\n", 74 | "news={}\n", 75 | "category={'NULL':0}\n", 76 | "subcategory={'NULL':0}\n", 77 | "newsnumber=0\n", 78 | "for i in data:\n", 79 | " line=i.strip('\\n').split('\\t')\n", 80 | " if line[0] not in news:\n", 81 | " news[line[0]]=[line[1],line[2],word_tokenize(line[3].lower()),[x[\"WikidataId\"] for x in json.loads(line[6])]]\n", 82 | " if line[1] not in category:\n", 83 | " category[line[1]]=len(category)\n", 84 | " if line[2] not in subcategory:\n", 85 | " subcategory[line[2]]=len(subcategory)\n", 86 | " newsnumber+=1\n", 87 | " if newsnumber%1000==0:\n", 88 | " print(newsnumber)\n", 89 | "\n", 90 | "for i in news:\n", 91 | " for j in news[i][3]:\n", 92 | " if j not in entityidlist:\n", 93 | " entityidlist[j]=len(entityidlist)\n", 94 | " entity_emb_table.append(entity_emb_dict.get(j,np.random.uniform(-0.03,0.03,(100,))))\n", 95 | "entity_emb_table=np.array(entity_emb_table,dtype=np.float32)\n", 96 | "print(entity_emb_table.shape)\n", 97 | "newsindex={'NULL':0}\n", 98 | "for i in news:\n", 99 | " newsindex[i]=len(newsindex)\n", 100 | "\n", 101 | "word_dict={'PADDING':[0,999999]}\n", 102 | "\n", 103 | "for i in 
news:\n", 104 | " for j in news[i][2]:\n", 105 | " if j in word_dict:\n", 106 | " word_dict[j][1]+=1\n", 107 | " else:\n", 108 | " word_dict[j]=[len(word_dict),1]" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "!wget https://nlp.stanford.edu/data/glove.840B.300d.zip\n", 118 | "!unzip glove.840B.300d.zip" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | " \n", 128 | "embdict={}\n", 129 | "with open('glove.840B.300d.txt','rb')as f:\n", 130 | " linenb=0\n", 131 | " while True:\n", 132 | " line=f.readline()\n", 133 | " if len(line)==0:\n", 134 | " break\n", 135 | " line = line.split()\n", 136 | " word=line[0].decode()\n", 137 | " linenb+=1\n", 138 | " if len(word) != 0:\n", 139 | " vec=[float(x) for x in line[1:]]\n", 140 | " if word in word_dict:\n", 141 | " embdict[word]=vec\n", 142 | "\n", 143 | "\n", 144 | "emb_mat=[0]*len(word_dict)\n", 145 | "in_dict_emb=[]\n", 146 | "for i in embdict.keys():\n", 147 | " emb_mat[word_dict[i][0]]=np.array(embdict[i],dtype='float32')\n", 148 | " in_dict_emb.append(emb_mat[word_dict[i][0]])\n", 149 | "in_dict_emb=np.array(in_dict_emb,dtype='float32')\n", 150 | "\n", 151 | "mu=np.mean(in_dict_emb, axis=0)\n", 152 | "Sigma=np.cov(in_dict_emb.T)\n", 153 | "\n", 154 | "norm=np.random.multivariate_normal(mu, Sigma, 1)\n", 155 | "print(mu.shape,Sigma.shape,norm.shape)\n", 156 | "\n", 157 | "for i in range(len(emb_mat)):\n", 158 | " if type(emb_mat[i])==int:\n", 159 | " emb_mat[i]=np.reshape(norm, 300)\n", 160 | "emb_mat[0]=np.zeros(300,dtype='float32')\n", 161 | "emb_mat=np.array(emb_mat,dtype='float32')\n", 162 | "print(emb_mat.shape)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "news_words=[[0]*30]\n", 172 | "\n", 173 | "for i in news:\n", 174 | " words=[]\n", 175 | " for word in news[i][2]:\n", 176 | " if word in word_dict:\n", 177 | " words.append(word_dict[word][0])\n", 178 | " words=words[:30]\n", 179 | " news_words.append(words+[0]*(30-len(words)))\n", 180 | "news_words=np.array(news_words,dtype='int32') \n", 181 | "\n", 182 | "news_entity=[[0]*5]\n", 183 | "\n", 184 | "for i in news:\n", 185 | " entities=[]\n", 186 | " for entity in news[i][3]:\n", 187 | " if entity in entityidlist:\n", 188 | " entities.append(entityidlist[entity])\n", 189 | " entities=entities[:5]\n", 190 | " entities=entities+[0]*(5-len(entities))\n", 191 | " news_entity.append(entities)\n", 192 | "news_entity=np.array(news_entity,dtype='int32') \n", 193 | "\n", 194 | "news_topic=[0]\n", 195 | "for i in news:\n", 196 | " \n", 197 | " news_topic.append(category[news[i][0]])\n", 198 | "news_topic=np.array(news_topic,dtype='int32') " 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "def newsample(array,ratio):\n", 208 | " if ratio >len(array):\n", 209 | " return random.sample(array*(ratio//len(array)+1),ratio)\n", 210 | " else:\n", 211 | " return random.sample(array,ratio)\n", 212 | " \n", 213 | "npratio=4\n", 214 | "train_candidate=[] \n", 215 | "train_label=[]\n", 216 | "train_user_his=[]\n", 217 | "\n", 218 | "for user in trainuser:\n", 219 | " userline=user.replace('\\n','').split('\\t')\n", 220 | " clickids=[newsindex[x] for x in userline[3].split()][-50:]\n", 221 | " 
pdoc=[newsindex[x.split('-')[0]] for x in userline[4].split() if x.split('-')[1]=='1']\n", 222 | " ndoc=[newsindex[x.split('-')[0]] for x in userline[4].split() if x.split('-')[1]=='0']\n", 223 | " \n", 224 | " for doc in pdoc:\n", 225 | " negd=newsample(ndoc,npratio)\n", 226 | " negd.append(doc)\n", 227 | " candidate_label=[0]*npratio+[1]\n", 228 | " candidate_order=list(range(npratio+1))\n", 229 | " random.shuffle(candidate_order)\n", 230 | " candidate_shuffle=[]\n", 231 | " candidate_label_shuffle=[]\n", 232 | " for i in candidate_order:\n", 233 | " candidate_shuffle.append(negd[i])\n", 234 | " candidate_label_shuffle.append(candidate_label[i])\n", 235 | " train_candidate.append(candidate_shuffle)\n", 236 | " train_label.append(candidate_label_shuffle)\n", 237 | " train_user_his.append(clickids+[0]*(50-len(clickids))) \n", 238 | "\n", 239 | "\n", 240 | "# In[33]:\n", 241 | "\n", 242 | "\n", 243 | "\n", 244 | "\n", 245 | "\n", 246 | "test_candidate=[] \n", 247 | "test_user_his=[]\n", 248 | "test_index=[]\n", 249 | "test_label=[]\n", 250 | "\n", 251 | "for user in valuser:\n", 252 | " userline=user.replace('\\n','').split('\\t')\n", 253 | " clickids=[newsindex[x] for x in userline[3].split()][-50:]\n", 254 | " docs=[newsindex[x.split('-')[0]] for x in userline[4].split()]\n", 255 | " index=[]\n", 256 | " index.append(len(test_candidate))\n", 257 | " \n", 258 | " test_user_his.append(clickids+[0]*(50-len(clickids)))\n", 259 | " for x in userline[4].split():\n", 260 | " test_label.append(int(x.split('-')[1]))\n", 261 | " for doc in docs:\n", 262 | " test_candidate.append(doc)\n", 263 | " index.append(len(test_candidate))\n", 264 | " test_index.append(index)\n", 265 | "\n", 266 | "\n", 267 | "train_candidate=np.array(train_candidate,dtype='int32')\n", 268 | "train_label=np.array(train_label,dtype='int32')\n", 269 | "train_user_his=np.array(train_user_his,dtype='int32')\n", 270 | "\n", 271 | "test_candidate=np.array(test_candidate,dtype='int32') \n", 272 | "test_user_his=np.array(test_user_his,dtype='int32')\n", 273 | "test_label=np.array(test_label,dtype='int32')" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "def generate_batch_data_random(batch_size):\n", 283 | " idx = np.arange(len(train_label))\n", 284 | " np.random.shuffle(idx)\n", 285 | " y=train_label\n", 286 | " batches = [idx[range(batch_size*i, min(len(y), batch_size*(i+1)))] for i in range(len(y)//batch_size+1)]\n", 287 | " \n", 288 | " while (True):\n", 289 | " for i in batches:\n", 290 | " \n", 291 | " item_words = news_words[train_candidate[i]]\n", 292 | " \n", 293 | " item_topic=news_topic[train_candidate[i]]\n", 294 | " item_entity=news_entity[train_candidate[i]]\n", 295 | " \n", 296 | " user_his=news_words[train_user_his[i]]\n", 297 | " user_topic=news_topic[train_user_his[i]]\n", 298 | " user_entity=news_entity[train_user_his[i]]\n", 299 | " user_entity_feature=[]\n", 300 | " all_A=[]\n", 301 | " for s in range(len(i)):\n", 302 | " Asize=len(train_user_his[i][s])+len(category)+50\n", 303 | " newsA=np.zeros((Asize,Asize))\n", 304 | " entityid_set={}\n", 305 | " entityid_set_ids=[]\n", 306 | " for el in range(len(user_entity[s])):\n", 307 | " for e in range(len(user_entity[s][el])):\n", 308 | " if user_entity[s][el][e] not in entityid_set:\n", 309 | " entityid_set[user_entity[s][el][e]]=len(entityid_set)\n", 310 | " entityid_set_ids.append(user_entity[s][el][e])\n", 311 | " if user_entity[s][el][e]!=0 and 
entityid_set[user_entity[s][el][e]]<50:\n", 312 | " newsA[el][len(train_user_his[i][s])+len(category)+entityid_set[user_entity[s][el][e]]]=1\n", 313 | " newsA[len(train_user_his[i][s])+len(category)+entityid_set[user_entity[s][el][e]]][el]=1\n", 314 | " newsA[len(train_user_his[i][s])+len(category)+entityid_set[user_entity[s][el][e]]][len(train_user_his[i][s])+len(category)+entityid_set[user_entity[s][el][e]]]=1\n", 315 | " entityid_set_ids=entityid_set_ids[:50]\n", 316 | " entityid_set_ids+=[0]*(50-len(entityid_set_ids))\n", 317 | " \n", 318 | " for m in range(len(train_user_his[i][s])):\n", 319 | " if train_user_his[i][s][m]!=0:\n", 320 | " newsA[m][m]=1\n", 321 | " if m>=1 and train_user_his[i][s][m-1]!=0:\n", 322 | " newsA[m][m-1]=1\n", 323 | " if m= 2\n", 435 | " F = input_shape[0][-1]\n", 436 | "\n", 437 | " # Initialize weights for each attention head\n", 438 | " for head in range(self.attn_heads):\n", 439 | " # Layer kernel\n", 440 | " kernel = self.add_weight(shape=(F, self.F_),\n", 441 | " initializer=self.kernel_initializer,\n", 442 | " regularizer=self.kernel_regularizer,\n", 443 | " constraint=self.kernel_constraint,\n", 444 | " name='kernel_{}'.format(head))\n", 445 | " self.kernels.append(kernel)\n", 446 | "\n", 447 | " # # Layer bias\n", 448 | " if self.use_bias:\n", 449 | " bias = self.add_weight(shape=(self.F_, ),\n", 450 | " initializer=self.bias_initializer,\n", 451 | " regularizer=self.bias_regularizer,\n", 452 | " constraint=self.bias_constraint,\n", 453 | " name='bias_{}'.format(head))\n", 454 | " self.biases.append(bias)\n", 455 | "\n", 456 | " # Attention kernels\n", 457 | " attn_kernel_self = self.add_weight(shape=(self.F_, 1),\n", 458 | " initializer=self.attn_kernel_initializer,\n", 459 | " regularizer=self.attn_kernel_regularizer,\n", 460 | " constraint=self.attn_kernel_constraint,\n", 461 | " name='attn_kernel_self_{}'.format(head),)\n", 462 | " attn_kernel_neighs = self.add_weight(shape=(self.F_, 1),\n", 463 | " initializer=self.attn_kernel_initializer,\n", 464 | " regularizer=self.attn_kernel_regularizer,\n", 465 | " constraint=self.attn_kernel_constraint,\n", 466 | " name='attn_kernel_neigh_{}'.format(head))\n", 467 | " self.attn_kernels.append([attn_kernel_self, attn_kernel_neighs])\n", 468 | " self.built = True\n", 469 | "\n", 470 | " def call(self, inputs):\n", 471 | " X = inputs[0] # Node features (N x F)\n", 472 | " A = inputs[1] # Adjacency matrix (N x N)\n", 473 | "\n", 474 | " outputs = []\n", 475 | " for head in range(self.attn_heads):\n", 476 | " kernel = self.kernels[head] # W in the paper (F x F')\n", 477 | " attention_kernel = self.attn_kernels[head] # Attention kernel a in the paper (2F' x 1)\n", 478 | "\n", 479 | " # Compute inputs to attention network\n", 480 | " features = K.dot(X, kernel) # (N x F')\n", 481 | "\n", 482 | " # Compute feature combinations\n", 483 | " # Note: [[a_1], [a_2]]^T [[Wh_i], [Wh_2]] = [a_1]^T [Wh_i] + [a_2]^T [Wh_j]\n", 484 | " attn_for_self = K.dot(features, attention_kernel[0]) # (N x 1), [a_1]^T [Wh_i]\n", 485 | " attn_for_neighs = K.dot(features, attention_kernel[1]) # (N x 1), [a_2]^T [Wh_j]\n", 486 | " # Attention head a(Wh_i, Wh_j) = a^T [[Wh_i], [Wh_j]]\n", 487 | " dense = attn_for_self + K.permute_dimensions(attn_for_neighs,(0,2,1)) # (N x N) via broadcasting\n", 488 | "\n", 489 | " # Add nonlinearty\n", 490 | " dense = LeakyReLU(alpha=0.2)(dense)\n", 491 | "\n", 492 | " # Mask values before activation (Vaswani et al., 2017)\n", 493 | " mask = -10e9 * (1.0 - A)\n", 494 | " dense += mask\n", 495 | "\n", 496 
| " # Apply softmax to get attention coefficients\n", 497 | " dense = K.softmax(dense) # (N x N)\n", 498 | "\n", 499 | " # Apply dropout to features and attention coefficients\n", 500 | " dropout_attn = Dropout(self.dropout_rate)(dense) # (N x N)\n", 501 | " dropout_feat = Dropout(self.dropout_rate)(features) # (N x F')\n", 502 | "\n", 503 | " # Linear combination with neighbors' features\n", 504 | " node_features = K.batch_dot(dropout_attn, dropout_feat,axes=[2,1]) # (N x F')\n", 505 | " \n", 506 | " if self.use_bias:\n", 507 | " node_features = K.bias_add(node_features, self.biases[head])\n", 508 | " \n", 509 | " # Add output of attention head to final output\n", 510 | " outputs.append(node_features)\n", 511 | "\n", 512 | " # Aggregate the heads' output according to the reduction method\n", 513 | " if self.attn_heads_reduction == 'concat':\n", 514 | " output = K.concatenate(outputs) # (N x KF')\n", 515 | " else:\n", 516 | " output = K.mean(K.stack(outputs), axis=0) # N x F')\n", 517 | " \n", 518 | " output = self.activation(output)\n", 519 | " return output\n", 520 | "\n", 521 | " def compute_output_shape(self, input_shape):\n", 522 | " output_shape = (input_shape[0][0], input_shape[0][1], self.output_dim)\n", 523 | " return output_shape" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": null, 529 | "metadata": {}, 530 | "outputs": [], 531 | "source": [ 532 | "class Attention(Layer):\n", 533 | "\n", 534 | " def __init__(self, nb_head, size_per_head, **kwargs):\n", 535 | " self.nb_head = nb_head\n", 536 | " self.size_per_head = size_per_head\n", 537 | " self.output_dim = nb_head*size_per_head\n", 538 | " super(Attention, self).__init__(**kwargs)\n", 539 | "\n", 540 | " def build(self, input_shape):\n", 541 | " self.WQ = self.add_weight(name='WQ', \n", 542 | " shape=(input_shape[0][-1], self.output_dim),\n", 543 | " initializer='glorot_uniform',\n", 544 | " trainable=True)\n", 545 | " self.WK = self.add_weight(name='WK', \n", 546 | " shape=(input_shape[1][-1], self.output_dim),\n", 547 | " initializer='glorot_uniform',\n", 548 | " trainable=True)\n", 549 | " self.WV = self.add_weight(name='WV', \n", 550 | " shape=(input_shape[2][-1], self.output_dim),\n", 551 | " initializer='glorot_uniform',\n", 552 | " trainable=True)\n", 553 | " super(Attention, self).build(input_shape)\n", 554 | " \n", 555 | " def Mask(self, inputs, seq_len, mode='mul'):\n", 556 | " if seq_len == None:\n", 557 | " return inputs\n", 558 | " else:\n", 559 | " mask = K.one_hot(seq_len[:,0], K.shape(inputs)[1])\n", 560 | " mask = 1 - K.cumsum(mask, 1)\n", 561 | " for _ in range(len(inputs.shape)-2):\n", 562 | " mask = K.expand_dims(mask, 2)\n", 563 | " if mode == 'mul':\n", 564 | " return inputs * mask\n", 565 | " if mode == 'add':\n", 566 | " return inputs - (1 - mask) * 1e12\n", 567 | " \n", 568 | " def call(self, x): \n", 569 | " if len(x) == 3:\n", 570 | " Q_seq,K_seq,V_seq = x\n", 571 | " Q_len,V_len = None,None\n", 572 | " elif len(x) == 5:\n", 573 | " Q_seq,K_seq,V_seq,Q_len,V_len = x \n", 574 | " Q_seq = K.dot(Q_seq, self.WQ)\n", 575 | " Q_seq = K.reshape(Q_seq, (-1, K.shape(Q_seq)[1], self.nb_head, self.size_per_head))\n", 576 | " Q_seq = K.permute_dimensions(Q_seq, (0,2,1,3))\n", 577 | " K_seq = K.dot(K_seq, self.WK)\n", 578 | " K_seq = K.reshape(K_seq, (-1, K.shape(K_seq)[1], self.nb_head, self.size_per_head))\n", 579 | " K_seq = K.permute_dimensions(K_seq, (0,2,1,3))\n", 580 | " V_seq = K.dot(V_seq, self.WV)\n", 581 | " V_seq = K.reshape(V_seq, (-1, K.shape(V_seq)[1], self.nb_head, 
self.size_per_head))\n", 582 | " V_seq = K.permute_dimensions(V_seq, (0,2,1,3)) \n", 583 | " A = K.batch_dot(Q_seq, K_seq, axes=[3,3]) / self.size_per_head**0.5\n", 584 | " A = K.permute_dimensions(A, (0,3,2,1))\n", 585 | " A = self.Mask(A, V_len, 'add')\n", 586 | " A = K.permute_dimensions(A, (0,3,2,1)) \n", 587 | " A = K.softmax(A) \n", 588 | " O_seq = K.batch_dot(A, V_seq, axes=[3,2])\n", 589 | " O_seq = K.permute_dimensions(O_seq, (0,2,1,3))\n", 590 | " O_seq = K.reshape(O_seq, (-1, K.shape(O_seq)[1], self.output_dim))\n", 591 | " O_seq = self.Mask(O_seq, Q_len, 'mul')\n", 592 | " return O_seq\n", 593 | " \n", 594 | " def compute_output_shape(self, input_shape):\n", 595 | " return (input_shape[0][0], input_shape[0][1], self.output_dim)" 596 | ] 597 | }, 598 | { 599 | "cell_type": "code", 600 | "execution_count": null, 601 | "metadata": {}, 602 | "outputs": [], 603 | "source": [ 604 | "class AttLayer(Layer):\n", 605 | " def __init__(self, dim=200,**kwargs):\n", 606 | " self.init = initializers.get('normal')\n", 607 | " self.dim = dim\n", 608 | " super(AttLayer, self).__init__(**kwargs)\n", 609 | " def build(self, input_shape):\n", 610 | " dim = self.dim\n", 611 | " self.W1 = K.variable(self.init((input_shape[-1], dim)))\n", 612 | " self.b1 = K.variable(self.init((dim,)))\n", 613 | " self.q1 = K.variable(self.init((dim, 1)))\n", 614 | " \n", 615 | " self.trainable_weights = [self.W1, self.b1, self.q1]\n", 616 | " super(AttLayer,self).build(input_shape) \n", 617 | "\n", 618 | " def call(self, inputs, **kwargs): \n", 619 | " attention1 = K.tanh(K.dot(inputs, self.W1) + self.b1)\n", 620 | " attention1 = K.dot(attention1, self.q1)\n", 621 | " attention1 = K.squeeze(attention1, axis=2)\n", 622 | " \n", 623 | " attention = attention1\n", 624 | " attention = K.exp(attention)\n", 625 | " attention_weight = attention / (K.sum(attention, axis=-1, keepdims=True) + K.epsilon())\n", 626 | "\n", 627 | " attention_weight = K.expand_dims(attention_weight)\n", 628 | " weighted_input = inputs * attention_weight\n", 629 | " return K.sum(weighted_input, axis=1)\n", 630 | "\n", 631 | " def compute_mask(self, input, input_mask=None):\n", 632 | " return None\n", 633 | "\n", 634 | " def compute_output_shape(self, input_shape):\n", 635 | " return input_shape[0], input_shape[-1]" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": null, 641 | "metadata": {}, 642 | "outputs": [], 643 | "source": [ 644 | "MAX_SENT_LENGTH=30\n", 645 | "MAX_SENTS=50\n", 646 | "keras.backend.clear_session() \n", 647 | "sentence_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')\n", 648 | "topic_input = Input(shape=(1,), dtype='int32')\n", 649 | "\n", 650 | "embedding_layer = Embedding(len(word_dict), 300, weights=[emb_mat],trainable=True)\n", 651 | "\n", 652 | "topic_embedding_layer = Embedding(len(category),256,trainable=True)\n", 653 | "entity_embedding_layer = Embedding(len(entity_emb_table),100, weights=[entity_emb_table],trainable=True)\n", 654 | "\n", 655 | "embedded_sequences = embedding_layer(sentence_input)\n", 656 | "\n", 657 | "embedded_sequences=Dropout(0.2)(embedded_sequences)\n", 658 | "embedded_sequences=Dropout(0.2)(Attention(20,20)([embedded_sequences,embedded_sequences,embedded_sequences]))\n", 659 | "\n", 660 | "textrep=Dense(256,activation='relu')(AttLayer()(embedded_sequences))\n", 661 | "\n", 662 | "text_encoder = Model([sentence_input], textrep)\n", 663 | "\n", 664 | "review_input = Input((MAX_SENTS,MAX_SENT_LENGTH,), dtype='int32')\n", 665 | "review_encoders= 
TimeDistributed(text_encoder)(review_input)\n", 666 | "\n", 667 | "review_encoders = Dropout(0.2)(review_encoders)\n", 668 | "A_input = Input(shape=(MAX_SENTS+len(category)+50,MAX_SENTS+len(category)+50), dtype='float32')\n", 669 | "\n", 670 | "all_topic_input = Input((len(category),), dtype='int32')\n", 671 | "all_entity_input = Input((50,), dtype='int32')\n", 672 | "\n", 673 | "all_topic_emb=topic_embedding_layer(all_topic_input)\n", 674 | "all_entity_emb=Dense(256)(entity_embedding_layer(all_entity_input))\n", 675 | "xinput=concatenate([review_encoders ,all_topic_emb, all_entity_emb],axis=1)\n", 676 | "\n", 677 | "graph_attention_basis = GraphAttention(16,\n", 678 | " attn_heads=16,\n", 679 | " attn_heads_reduction='concat',\n", 680 | " dropout_rate=0.2,\n", 681 | " activation='elu')([xinput, A_input])\n", 682 | "\n", 683 | "\n", 684 | "graph_pool_n1 = GraphAttention(16,\n", 685 | " attn_heads=1,\n", 686 | " attn_heads_reduction='concat',\n", 687 | " dropout_rate=0.2,\n", 688 | " activation='elu')([graph_attention_basis, A_input])\n", 689 | "graph_pool_n1 =Activation('softmax')(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(Dense(50)(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(graph_pool_n1))))\n", 690 | "graph_pool_n1=concatenate([graph_pool_n1,Lambda(lambda x:K.zeros_like(x[:,:,:16]))(all_topic_emb),Lambda(lambda x:K.zeros_like(x[:,:,:16]))(all_entity_emb)],axis=1)\n", 691 | "\n", 692 | "\n", 693 | "graph_pool_t1 = GraphAttention(3,\n", 694 | " attn_heads=1,\n", 695 | " attn_heads_reduction='concat',\n", 696 | " dropout_rate=0.2,\n", 697 | " activation='elu')([graph_attention_basis, A_input])\n", 698 | "graph_pool_t1 =Activation('softmax')(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(Dense(len(category))(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(graph_pool_t1))))\n", 699 | "graph_pool_t1=concatenate([Lambda(lambda x:K.zeros_like(x[:,:,:3]))(review_encoders),graph_pool_t1,Lambda(lambda x:K.zeros_like(x[:,:,:3]))(all_entity_emb)],axis=1)\n", 700 | "\n", 701 | "\n", 702 | "graph_pool_e1 = GraphAttention(9,\n", 703 | " attn_heads=1,\n", 704 | " attn_heads_reduction='concat',\n", 705 | " dropout_rate=0.2,\n", 706 | " activation='elu')([graph_attention_basis, A_input])\n", 707 | "graph_pool_e1 =Activation('softmax')(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(Dense(50)(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(graph_pool_e1))))\n", 708 | "graph_pool_e1=concatenate([Lambda(lambda x:K.zeros_like(x[:,:,:9]))(review_encoders),Lambda(lambda x:K.zeros_like(x[:,:,:9]))(all_topic_emb),graph_pool_e1],axis=1)\n", 709 | "\n", 710 | "\n", 711 | "Apool_nn1=Dot((2,1))([Permute((2,1))(graph_pool_n1),Dot((1,1))([A_input,graph_pool_n1])])\n", 712 | "xpool_n1=Dot((1,1))([graph_pool_n1,graph_attention_basis])\n", 713 | "\n", 714 | "Apool_tt1=Dot((2,1))([Permute((2,1))(graph_pool_t1),Dot((1,1))([A_input,graph_pool_t1])])\n", 715 | "xpool_t1=Dot((1,1))([graph_pool_t1,graph_attention_basis])\n", 716 | "\n", 717 | "Apool_ee1=Dot((2,1))([Permute((2,1))(graph_pool_e1),Dot((1,1))([A_input,graph_pool_e1])])\n", 718 | "xpool_e1=Dot((1,1))([graph_pool_e1,graph_attention_basis])\n", 719 | "\n", 720 | "Apool_nt1=Dot((2,1))([Permute((2,1))(graph_pool_n1),Dot((1,1))([A_input,graph_pool_t1])])\n", 721 | "Apool_tn1=Dot((2,1))([Permute((2,1))(graph_pool_t1),Dot((1,1))([A_input,graph_pool_n1])])\n", 722 | "Apool_ne1=Dot((2,1))([Permute((2,1))(graph_pool_n1),Dot((1,1))([A_input,graph_pool_e1])])\n", 723 | "Apool_en1=Dot((2,1))([Permute((2,1))(graph_pool_e1),Dot((1,1))([A_input,graph_pool_n1])])\n", 
724 | "Apool_te1=Dot((2,1))([Permute((2,1))(graph_pool_t1),Dot((1,1))([A_input,graph_pool_e1])])\n", 725 | "Apool_et1=Dot((2,1))([Permute((2,1))(graph_pool_e1),Dot((1,1))([A_input,graph_pool_t1])])\n", 726 | "Ap1=concatenate([concatenate([Apool_nn1,Apool_tn1,Apool_en1],axis=1),concatenate([Apool_nt1,Apool_tt1,Apool_et1],axis=1),concatenate([Apool_ne1,Apool_te1,Apool_ee1],axis=1)])\n", 727 | "xp1=concatenate([xpool_n1,xpool_t1,xpool_e1],axis=1)\n", 728 | "\n", 729 | "graph_pool_n2 = GraphAttention(1,\n", 730 | " attn_heads=1,\n", 731 | " attn_heads_reduction='concat',\n", 732 | " dropout_rate=0.2,\n", 733 | " activation='elu')([xp1, Ap1])\n", 734 | "graph_pool_n2 =Activation('softmax')(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(Dense(16)(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(graph_pool_n2))))\n", 735 | "graph_pool_n2=concatenate([graph_pool_n2,Lambda(lambda x:K.zeros_like(x[:,:,:1]))(xpool_t1),Lambda(lambda x:K.zeros_like(x[:,:,:1]))(xpool_e1)],axis=1)\n", 736 | "\n", 737 | "graph_pool_t2 = GraphAttention(1,\n", 738 | " attn_heads=1,\n", 739 | " attn_heads_reduction='concat',\n", 740 | " dropout_rate=0.2,\n", 741 | " activation='elu')([xp1, Ap1])\n", 742 | "graph_pool_t2 =Activation('softmax')(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(Dense(3)(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(graph_pool_t2))))\n", 743 | "graph_pool_t2=concatenate([Lambda(lambda x:K.zeros_like(x[:,:,:1]))(xpool_n1),graph_pool_t2,Lambda(lambda x:K.zeros_like(x[:,:,:1]))(xpool_e1)],axis=1)\n", 744 | "\n", 745 | "graph_pool_e2 = GraphAttention(1,\n", 746 | " attn_heads=1,\n", 747 | " attn_heads_reduction='concat',\n", 748 | " dropout_rate=0.2,\n", 749 | " activation='elu')([xp1, Ap1])\n", 750 | "graph_pool_e2 =Activation('softmax')(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(Dense(9)(Lambda(lambda x:K.permute_dimensions(x,(0,2,1)))(graph_pool_e2))))\n", 751 | "graph_pool_e2=concatenate([Lambda(lambda x:K.zeros_like(x[:,:,:1]))(xpool_n1),Lambda(lambda x:K.zeros_like(x[:,:,:1]))(xpool_t1),graph_pool_e2],axis=1)\n", 752 | "\n", 753 | "\n", 754 | "Apool_nn2=Dot((2,1))([Permute((2,1))(graph_pool_n2),Dot((1,1))([Ap1,graph_pool_n2])])\n", 755 | "xpool_n2=Dot((1,1))([graph_pool_n2,xp1])\n", 756 | "\n", 757 | "Apool_tt2=Dot((2,1))([Permute((2,1))(graph_pool_t2),Dot((1,1))([Ap1,graph_pool_t2])])\n", 758 | "xpool_t2=Dot((1,1))([graph_pool_t2,xp1])\n", 759 | "\n", 760 | "Apool_ee2=Dot((2,1))([Permute((2,1))(graph_pool_e2),Dot((1,1))([Ap1,graph_pool_e2])])\n", 761 | "xpool_e2=Dot((1,1))([graph_pool_e2,xp1])\n", 762 | "\n", 763 | "Apool_nt2=Dot((2,1))([Permute((2,1))(graph_pool_n2),Dot((1,1))([Ap1,graph_pool_t2])])\n", 764 | "Apool_tn2=Dot((2,1))([Permute((2,1))(graph_pool_t2),Dot((1,1))([Ap1,graph_pool_n2])])\n", 765 | "Apool_ne2=Dot((2,1))([Permute((2,1))(graph_pool_n2),Dot((1,1))([Ap1,graph_pool_e2])])\n", 766 | "Apool_en2=Dot((2,1))([Permute((2,1))(graph_pool_e2),Dot((1,1))([Ap1,graph_pool_n2])])\n", 767 | "Apool_te2=Dot((2,1))([Permute((2,1))(graph_pool_t2),Dot((1,1))([Ap1,graph_pool_e2])])\n", 768 | "Apool_et2=Dot((2,1))([Permute((2,1))(graph_pool_e2),Dot((1,1))([Ap1,graph_pool_t2])])\n", 769 | "Ap2=concatenate([concatenate([Apool_nn2,Apool_tn2,Apool_en2],axis=1),concatenate([Apool_nt2,Apool_tt2,Apool_et2],axis=1),concatenate([Apool_ne2,Apool_te2,Apool_ee2],axis=1)])\n", 770 | "xp2=concatenate([xpool_n2,xpool_t2,xpool_e2],axis=1)\n", 771 | "\n", 772 | "diff_pool = GraphAttention(1,\n", 773 | " attn_heads=1,\n", 774 | " attn_heads_reduction='concat',\n", 775 | " dropout_rate=0.2,\n", 776 | " 
activation='elu')([xp2, Ap2])\n", 777 | "\n", 778 | "diff_pool =Flatten()(Activation('softmax')(diff_pool))\n", 779 | "u_att=keras.layers.Dot((1, 1))([xp2, diff_pool])\n", 780 | "\n", 781 | "\n", 782 | "candidates=Input((1+npratio,MAX_SENT_LENGTH,), dtype='int32')\n", 783 | "candidate_vecs=TimeDistributed(text_encoder)(candidates)\n", 784 | "\n", 785 | "candidatestopic=Input((1+npratio,), dtype='int32')\n", 786 | "candidatesentity=Input((1+npratio,5), dtype='int32')\n", 787 | "candidate_topicemb=topic_embedding_layer(candidatestopic)\n", 788 | "entity_dim_dense=Dense(256)\n", 789 | "entity_att=AttLayer()\n", 790 | "candidate_entityemb = TimeDistributed(entity_att)(TimeDistributed(TimeDistributed(entity_dim_dense))(entity_embedding_layer(candidatesentity)))\n", 791 | "\n", 792 | "view_emb=concatenate([Lambda(lambda y:K.expand_dims(y,axis=2))(x) for x in [candidate_vecs,candidate_topicemb,candidate_entityemb]],axis=2)\n", 793 | "\n", 794 | "view_att=AttLayer()\n", 795 | "\n", 796 | "candidate_emb=TimeDistributed(view_att)(view_emb)\n", 797 | "\n", 798 | "logits = Lambda(lambda x:K.clip(x,-5,5))(keras.layers.dot([u_att, candidate_emb], axes=-1))\n", 799 | "logits = keras.layers.Activation(keras.activations.softmax)(logits)\n", 800 | "\n", 801 | "\n", 802 | "model = Model([candidates,review_input,candidatestopic,candidatesentity,all_entity_input,all_topic_input ,A_input], [logits])\n", 803 | "model.compile(loss=['categorical_crossentropy'], optimizer=Adam(lr=0.0001), metrics=['acc'])\n", 804 | "\n", 805 | "\n", 806 | "candidate_one = keras.Input((MAX_SENT_LENGTH,))\n", 807 | "candidate_one_topic_input = Input(shape=(1,), dtype='int32')\n", 808 | "candidate_one_entity_input = Input(shape=(5,), dtype='int32')\n", 809 | "candidate_one_vec = text_encoder([candidate_one])\n", 810 | "candidate_one_topicemb=topic_embedding_layer(candidate_one_topic_input)\n", 811 | "candidate_one_entityemb=entity_att(TimeDistributed(entity_dim_dense)(entity_embedding_layer(candidate_one_entity_input)))\n", 812 | "\n", 813 | "candidate_one_view_emb=concatenate([Lambda(lambda y:K.expand_dims(y,axis=1))(x) for x in [candidate_one_vec,candidate_one_entityemb]]+[candidate_one_topicemb],axis=1)\n", 814 | "\n", 815 | "candidate_one_vec = view_att(candidate_one_view_emb)\n", 816 | "candidate_encoder = keras.Model([candidate_one,candidate_one_topic_input,candidate_one_entity_input], candidate_one_vec)" 817 | ] 818 | }, 819 | { 820 | "cell_type": "code", 821 | "execution_count": null, 822 | "metadata": {}, 823 | "outputs": [], 824 | "source": [ 825 | "def generate_batch_data_user(batch_size):\n", 826 | " idlist = np.arange(len(test_user_his)) \n", 827 | " batches = [idlist[range(batch_size*i, min(len(test_user_his), batch_size*(i+1)))] for i in range(len(test_user_his)//batch_size+1)]\n", 828 | " while (True):\n", 829 | " for i in batches: \n", 830 | "\n", 831 | " user_his=news_words[test_user_his[i]]\n", 832 | " user_topic=news_topic[test_user_his[i]]\n", 833 | " user_entity=news_entity[test_user_his[i]]\n", 834 | " user_entity_feature=[]\n", 835 | " all_A=[]\n", 836 | " for s in range(len(i)):\n", 837 | " Asize=len(test_user_his[i][s])+len(category)+50\n", 838 | " newsA=np.zeros((Asize,Asize))\n", 839 | " entityid_set={}\n", 840 | " entityid_set_ids=[]\n", 841 | " for el in range(len(user_entity[s])):\n", 842 | " for e in range(len(user_entity[s][el])):\n", 843 | " if user_entity[s][el][e] not in entityid_set:\n", 844 | " entityid_set[user_entity[s][el][e]]=len(entityid_set)\n", 845 | " 
entityid_set_ids.append(user_entity[s][el][e])\n", 846 | " if user_entity[s][el][e]!=0 and entityid_set[user_entity[s][el][e]]<50:\n", 847 | " newsA[el][len(test_user_his[i][s])+len(category)+entityid_set[user_entity[s][el][e]]]=1\n", 848 | " newsA[len(test_user_his[i][s])+len(category)+entityid_set[user_entity[s][el][e]]][el]=1\n", 849 | " newsA[len(test_user_his[i][s])+len(category)+entityid_set[user_entity[s][el][e]]][len(test_user_his[i][s])+len(category)+entityid_set[user_entity[s][el][e]]]=1\n", 850 | " entityid_set_ids=entityid_set_ids[:50]\n", 851 | " entityid_set_ids+=[0]*(50-len(entityid_set_ids))\n", 852 | " \n", 853 | " for m in range(len(test_user_his[i][s])):\n", 854 | " if test_user_his[i][s][m]!=0:\n", 855 | " newsA[m][m]=1\n", 856 | " if m>=1 and test_user_his[i][s][m-1]!=0:\n", 857 | " newsA[m][m-1]=1\n", 858 | " if m1:\n", 915 | " all_auc.append(roc_auc_score(test_label[m[0]:m[1]],predictsession[t]))\n", 916 | " all_mrr.append(mrr_score(test_label[m[0]:m[1]],predictsession[t]))\n", 917 | " all_ndcg.append(ndcg_score(test_label[m[0]:m[1]],predictsession[t],k=5))\n", 918 | " all_ndcg2.append(ndcg_score(test_label[m[0]:m[1]],predictsession[t],k=10))\n", 919 | " if len(all_auc)%10000==0:\n", 920 | " print(len(all_auc))\n", 921 | " results.append([np.mean(all_auc),np.mean(all_mrr),np.mean(all_ndcg),np.mean(all_ndcg2)])\n", 922 | " print(np.mean(all_auc),np.mean(all_mrr),np.mean(all_ndcg),np.mean(all_ndcg2))\n", 923 | " " 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "execution_count": null, 929 | "metadata": {}, 930 | "outputs": [], 931 | "source": [] 932 | } 933 | ], 934 | "metadata": { 935 | "kernelspec": { 936 | "display_name": "Python 3", 937 | "language": "python", 938 | "name": "python3" 939 | }, 940 | "language_info": { 941 | "codemirror_mode": { 942 | "name": "ipython", 943 | "version": 3 944 | }, 945 | "file_extension": ".py", 946 | "mimetype": "text/x-python", 947 | "name": "python", 948 | "nbconvert_exporter": "python", 949 | "pygments_lexer": "ipython3", 950 | "version": "3.6.8" 951 | } 952 | }, 953 | "nbformat": 4, 954 | "nbformat_minor": 2 955 | } 956 | --------------------------------------------------------------------------------
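Note on the evaluation cell: it scores each test impression with sklearn's roc_auc_score together with mrr_score and ndcg_score (k=5 and k=10), whose definitions sit in an elided portion of the notebook. As a reading aid, a minimal sketch of the standard MIND-style formulations these helpers presumably follow (a hypothetical reimplementation, not the authors' exact code):

```python
import numpy as np

def dcg_score(y_true, y_score, k=10):
    # Discounted cumulative gain of the top-k ranked candidates.
    order = np.argsort(y_score)[::-1]
    gains = np.take(np.asarray(y_true), order[:k])
    discounts = np.log2(np.arange(len(gains)) + 2)
    return np.sum(gains / discounts)

def ndcg_score(y_true, y_score, k=10):
    # Normalize by the DCG of the ideal (label-sorted) ranking.
    return dcg_score(y_true, y_score, k) / dcg_score(y_true, y_true, k)

def mrr_score(y_true, y_score):
    # Mean reciprocal rank of the clicked candidates in one impression.
    order = np.argsort(y_score)[::-1]
    y_true = np.take(np.asarray(y_true), order)
    rr = y_true / (np.arange(len(y_true)) + 1)
    return np.sum(rr) / np.sum(y_true)
```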