├── DataLoader.py
├── README.md
├── data.tar.gz
├── environment.yml
├── layer.py
├── main.ipynb
├── model.py
└── rnn.py

/DataLoader.py:
--------------------------------------------------------------------------------
import torch

class MyDataSet( torch.utils.data.Dataset):
    def __init__( self, data_path, user_map, material_map, category_map, max_length):

        user = []; material = []; category = []
        material_historical = []; category_historical = []
        material_historical_neg = []; category_historical_neg = []
        mask = []; sequential_length = []
        target = []

        with open( data_path, 'r') as fin:

            for line in fin:
                item = line.strip('\n').split('\t')
                if not item: continue

                user.append( user_map.get( item[1], 0 ) )
                material.append( material_map.get( item[2], 0 ) )
                category.append( category_map.get( item[3], 0 ) )

                # Historical behaviour fields are '\x02'-separated in the preprocessed
                # data from mouna99/dien; keep the most recent `max_length` items and
                # left-align them in a fixed-size row padded with 0.
                material_historical_item = [0] * max_length
                temp = item[4].split('\x02')
                if len( temp) > max_length: temp = temp[ -max_length:]
                for i, m in enumerate( temp):
                    material_historical_item[i] = material_map.get( m, 0 )
                material_historical.append( material_historical_item)

                category_historical_item = [0] * max_length
                temp = item[5].split('\x02')
                if len( temp) > max_length: temp = temp[ -max_length:]
                for i, c in enumerate( temp):
                    category_historical_item[i] = category_map.get( c, 0 )
                category_historical.append( category_historical_item)

                temp = min( len( temp), max_length)
                mask_item = [1] * temp + [0] * ( max_length - temp)

                mask.append( mask_item)
                sequential_length.append( temp)

                target.append( int( item[0]))

        self.user = torch.tensor( user)

        self.material = torch.tensor( material)
        self.category = torch.tensor( category)

        self.material_historical = torch.tensor( material_historical)
        self.category_historical = torch.tensor( category_historical)

        self.mask = torch.tensor( mask).type( torch.FloatTensor)
        self.sequential_length = torch.tensor( sequential_length)

        self.target = torch.tensor( target)


    def __len__( self):
        return len( self.user)

    def __getitem__(self, index):
        if torch.is_tensor( index):
            index = index.tolist()

        user = self.user[ index]

        material_historical = self.material_historical[ index, :]
        category_historical = self.category_historical[ index, :]
        mask = self.mask[ index, :]
        sequential_length = self.sequential_length[ index]

        material = self.material[ index]
        category = self.category[ index]

        target = self.target[ index]

        # Negative-sample histories are not loaded yet; the two zeros are
        # placeholders that keep the tuple layout expected by the models.
        return user, material_historical, category_historical, mask, sequential_length , \
            material, category, 0, 0, target
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DIN/DIEN

A PyTorch implementation of the DIN and DIEN click-through-rate prediction models.


## Notes

1. For convenience, and following the authors' TensorFlow implementation, all features share a single embedding dimension.
2. No L1/L2 regularization or dropout is applied, so a suitable model checkpoint has to be chosen manually from the evaluation results.
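
## Quick start

A minimal end-to-end sketch, mirroring the session in `main.ipynb` (it assumes `data.tar.gz` has already been extracted into `./data`; hyperparameter values follow the notebook):

```python
import pickle as pk
import torch
from DataLoader import MyDataSet
from model import DIN

# Vocabulary dictionaries shipped with the preprocessed data
user_map = pk.load( open( './data/uid_voc.pkl', 'rb'))
material_map = pk.load( open( './data/mid_voc.pkl', 'rb'))
category_map = pk.load( open( './data/cat_voc.pkl', 'rb'))

dataset = MyDataSet( './data/local_train_splitByUser', user_map, material_map, category_map, max_length = 100)
loader = torch.utils.data.DataLoader( dataset, batch_size = 128, shuffle = True)

model = DIN( len( user_map), len( material_map), len( category_map), embedding_dim = 18)
optimizer = torch.optim.Adam( model.parameters(), lr = 1e-3, betas = ( 0.5, 0.99))

model.train()
for data in loader:
    target = data.pop(-1)
    scores = model( data, neg_sample = False)
    loss = torch.nn.functional.binary_cross_entropy_with_logits( scores, target.float())
    model.zero_grad(); loss.backward(); optimizer.step()
    break  # one step shown; see main.ipynb for the full training/evaluation session
```
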
## File description
|file name|description|
|--|----|
|main.ipynb|Training and evaluation session|
|model.py|Definition of the target models|
|layer.py|Embedding, MLP, and attention layers|
|rnn.py|AUGRU cell and `DynamicGRU` used by DIEN|
|DataLoader.py|Custom data loader|
|environment.yml|Conda environment YAML|

## Original papers
[Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf)

[Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1809.03672.pdf)

## Source data
[meta_Books.json.gz](http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Books.json.gz)

[reviews_Books.json.gz](http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Books.json.gz)

The preprocessed data wrapped in `data.tar.gz` comes from [mouna99/dien](https://github.com/mouna99/dien).

## References

[mouna99/dien](https://github.com/mouna99/dien)

[alibaba/x-deeplearning](https://github.com/alibaba/x-deeplearning)

[shenweichen/DeepCTR-Torch](https://github.com/shenweichen/DeepCTR-Torch)


## To-do list

- [x] DIN
- [x] AUGRU
- [ ] DICE activation layer
- [ ] Auxiliary loss with negative sampling
--------------------------------------------------------------------------------
/data.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ydhuyong/DIN-pytorch/72da8a5505d5855679f464b66343d07135125546/data.tar.gz
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: my_env
channels:
  - pytorch
  - https://conda.anaconda.org/anaconda
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - astroid=2.4.1=py38_0
  - attrs=19.3.0=py_0
  - backcall=0.1.0=py38_0
  - blas=1.0=mkl
  - bleach=3.1.4=py_0
  - bzip2=1.0.8=h516909a_2
  - ca-certificates=2020.6.24=0
  - cairo=1.16.0=hcf35c78_1003
  - certifi=2020.6.20=py38_0
  - cloudpickle=1.4.1=py_0
  - cudatoolkit=10.2.89=hfd86e86_1
  - cycler=0.10.0=py_2
  - cytoolz=0.10.1=py38h7b6447c_0
  - dask-core=2.17.2=py_0
  - dbus=1.13.14=hb2f20db_0
  - decorator=4.4.2=py_0
  - defusedxml=0.6.0=py_0
  - entrypoints=0.3=py38_0
  - expat=2.2.9=he1b5a44_2
  - ffmpeg=4.2=h167e202_0
  - fontconfig=2.13.1=h86ecdb6_1001
  - freetype=2.9.1=h8a8886c_1
  - gettext=0.19.8.1=hc5be6a0_1002
  - giflib=5.2.1=h516909a_2
  - glib=2.64.3=h6f030ca_0
  - gmp=6.1.2=h6c8ec71_1
  - gnutls=3.6.13=h79a8f9a_0
  - graphite2=1.3.13=he1b5a44_1001
  - gst-plugins-base=1.14.5=h0935bb2_2
  - gstreamer=1.14.5=h36ae1b5_2
  - harfbuzz=2.4.0=h9f30f68_3
  - hdf5=1.10.6=nompi_h3c11f04_100
  - icu=64.2=he1b5a44_1
  - imageio=2.8.0=py_0
  - importlib_metadata=1.5.0=py38_0
  - intel-openmp=2020.0=166
  - ipykernel=5.1.4=py38h39e3cac_0
  - ipython=7.13.0=py38h5ca1d4c_0
  - ipython_genutils=0.2.0=py38_0
  - isort=4.3.21=py38_0
  - jasper=1.900.1=h07fcdf6_1006
  - jedi=0.17.0=py38_0
  - jinja2=2.11.2=py_0
  - joblib=0.16.0=py_0
  - jpeg=9d=h516909a_0
  - jsonschema=3.2.0=py38_0
  - jupyter_client=6.1.3=py_0
  - jupyter_core=4.6.3=py38_0
  - kiwisolver=1.2.0=py38hbf85e49_0
  - lame=3.100=h14c3975_1001
  - lazy-object-proxy=1.4.3=py38h7b6447c_0
  - ld_impl_linux-64=2.34=h53a641e_4
  - libblas=3.8.0=15_mkl
  - libcblas=3.8.0=15_mkl
  - libclang=9.0.1=default_hde54327_0
  - libedit=3.1.20181209=hc058e9b_0
  - libffi=3.2.1=he1b5a44_1007
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libiconv=1.15=h516909a_1006
  - liblapack=3.8.0=15_mkl
  - liblapacke=3.8.0=15_mkl
  - libllvm9=9.0.1=he513fc3_1
  - libopencv=4.2.0=py38_6
  - libpng=1.6.37=hbc83047_0
  - libsodium=1.0.16=h1bed415_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.1.0=h2733197_0
  - libuuid=2.32.1=h14c3975_1000
  - libwebp=1.0.2=h56121f0_5
  - libxcb=1.13=h14c3975_1002
  - libxkbcommon=0.10.0=he1b5a44_0
  - libxml2=2.9.9=hea5a465_1
  - markupsafe=1.1.1=py38h7b6447c_0
  - matplotlib=3.2.1=0
  - matplotlib-base=3.2.1=py38h2af1d28_0
  - mccabe=0.6.1=py38_1
  - mistune=0.8.4=py38h7b6447c_1000
  - mkl=2020.0=166
  - mkl-service=2.3.0=py38he904b0f_0
  - mkl_fft=1.0.15=py38ha843d7b_0
  - mkl_random=1.1.0=py38h962f231_0
  - nb_conda_kernels=2.2.3=py38_0
  - nbconvert=5.6.1=py38_0
  - nbformat=5.0.6=py_0
  - ncurses=6.2=he6710b0_1
  - nettle=3.4.1=h1bed415_1002
  - networkx=2.4=py_0
  - ninja=1.9.0=py38hfd86e86_0
  - notebook=6.0.3=py38_0
  - nspr=4.25=he1b5a44_0
  - nss=3.47=he751ad9_0
  - numpy=1.18.1=py38h4f9e942_0
  - numpy-base=1.18.1=py38hde5b4d6_1
  - olefile=0.46=py_0
  - opencv=4.2.0=py38_6
  - openh264=1.8.0=hdbcaa40_1000
  - openssl=1.1.1g=h7b6447c_0
  - pandas=1.0.4=py38hcb8c335_0
  - pandoc=2.2.3.2=0
  - pandocfilters=1.4.2=py38_1
  - parso=0.7.0=py_0
  - pcre=8.44=he1b5a44_0
  - pexpect=4.8.0=py38_0
  - pickleshare=0.7.5=py38_1000
  - pillow=7.1.2=py38hb39fc2d_0
  - pip=20.0.2=py38_1
  - pixman=0.38.0=h516909a_1003
  - prometheus_client=0.7.1=py_0
  - prompt-toolkit=3.0.4=py_0
  - prompt_toolkit=3.0.4=0
  - pthread-stubs=0.4=h14c3975_1001
  - ptyprocess=0.6.0=py38_0
  - py-opencv=4.2.0=py38h23f93f0_6
  - pygments=2.6.1=py_0
  - pylint=2.5.2=py38_0
  - pyparsing=2.4.7=pyh9f0ad1d_0
  - pyqt=5.12.3=py38ha8c2ead_3
  - pyrsistent=0.16.0=py38h7b6447c_0
  - python=3.8.3=cpython_he5300dc_0
  - python-dateutil=2.8.1=py_0
  - python_abi=3.8=1_cp38
  - pytorch=1.5.0=py3.8_cuda10.2.89_cudnn7.6.5_0
  - pytz=2020.1=pyh9f0ad1d_0
  - pywavelets=1.1.1=py38h7b6447c_0
  - pyyaml=5.3.1=py38h7b6447c_0
  - pyzmq=18.1.1=py38he6710b0_0
  - qt=5.12.5=hd8c4c69_1
  - readline=8.0=h7b6447c_0
  - redis=5.0.3=h7b6447c_0
  - scikit-image=0.16.2=py38h0573a6f_0
  - scikit-learn=0.23.1=py38h423224d_0
  - scipy=1.4.1=py38h0b6359f_0
  - send2trash=1.5.0=py38_0
  - setuptools=46.1.3=py38_0
  - sip=4.19.13=py38he6710b0_0
  - six=1.14.0=py38_0
  - sqlite=3.31.1=h62c20be_1
  - terminado=0.8.3=py38_0
  - testpath=0.4.4=py_0
  - threadpoolctl=2.1.0=pyh5ca1d4c_0
  - tk=8.6.10=hed695b0_0
  - toml=0.10.0=pyh91ea838_0
  - toolz=0.10.0=py_0
  - torchvision=0.6.0=py38_cu102
  - tornado=6.0.4=py38h7b6447c_1
  - traitlets=4.3.3=py38_0
  - wcwidth=0.1.9=py_0
  - webencodings=0.5.1=py38_1
  - wheel=0.34.2=py38_0
  - wrapt=1.11.2=py38h7b6447c_0
  - x264=1!152.20180806=h14c3975_0
  - xorg-kbproto=1.0.7=h14c3975_1002
  - xorg-libice=1.0.10=h516909a_0
  - xorg-libsm=1.2.3=h84519dc_1000
  - xorg-libx11=1.6.9=h516909a_0
  - xorg-libxau=1.0.9=h14c3975_0
  - xorg-libxdmcp=1.1.3=h516909a_0
  - xorg-libxext=1.3.4=h516909a_0
  - xorg-libxrender=0.9.10=h516909a_1002
  - xorg-renderproto=0.11.1=h14c3975_1002
  - xorg-xextproto=7.3.0=h14c3975_1002
  - xorg-xproto=7.0.31=h14c3975_1007
  - xz=5.2.5=h7b6447c_0
  - yaml=0.1.7=h96e3832_1
  - zeromq=4.3.1=he6710b0_3
  - zipp=3.1.0=py_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.3.7=h0b5b093_0
  - pip:
    - pyqt5-sip==4.19.18
    - pyqtchart==5.12
    - pyqtwebengine==5.12.1
prefix: /home/juboge/opt/anaconda3/envs/my_env

--------------------------------------------------------------------------------
/layer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP( nn.Module):

    def __init__(self, input_dimension, hidden_size, target_dimension = 1, activation_layer = 'LeakyReLU'):
        super().__init__()

        Activation = nn.LeakyReLU

        # if activation_layer == 'DICE': pass
        # elif activation_layer == 'LeakyReLU': pass

        def _dense( in_dim, out_dim, bias = False):
            return nn.Sequential(
                nn.Linear( in_dim, out_dim, bias = bias),
                nn.BatchNorm1d( out_dim),
                Activation( 0.1 ))

        dimension_pair = [input_dimension] + hidden_size
        layers = [ _dense( dimension_pair[i], dimension_pair[i+1]) for i in range( len( hidden_size))]

        layers.append( nn.Linear( hidden_size[-1], target_dimension))
        layers.insert( 0, nn.BatchNorm1d( input_dimension) )

        self.model = nn.Sequential( *layers )

    def forward( self, X): return self.model( X)


class InputEmbedding( nn.Module):

    def __init__(self, n_uid, n_mid, n_cid, embedding_dim ):
        super().__init__()
        self.user_embedding_unit = nn.Embedding( n_uid, embedding_dim)
        self.material_embedding_unit = nn.Embedding( n_mid, embedding_dim)
        self.category_embedding_unit = nn.Embedding( n_cid, embedding_dim)

    def forward( self, user, material, category, material_historical, category_historical,
            material_historical_neg, category_historical_neg, neg_sample = False ):

        user_embedding = self.user_embedding_unit( user)

        material_embedding = self.material_embedding_unit( material)
        material_historical_embedding = self.material_embedding_unit( material_historical)

        category_embedding = self.category_embedding_unit( category)
        category_historical_embedding = self.category_embedding_unit( category_historical)

        material_historical_neg_embedding = self.material_embedding_unit( material_historical_neg) if neg_sample else None
        category_historical_neg_embedding = self.category_embedding_unit( category_historical_neg) if neg_sample else None

        ans = [ user_embedding, material_historical_embedding, category_historical_embedding,
            material_embedding, category_embedding, material_historical_neg_embedding, category_historical_neg_embedding ]
        return tuple( map( lambda x: x.squeeze() if x is not None else None , ans) )



class AttentionLayer( nn.Module):

    def __init__(self, embedding_dim, hidden_size, activation_layer = 'sigmoid'):
        super().__init__()

        Activation = nn.Sigmoid
        if activation_layer == 'Dice': pass

        def _dense( in_dim, out_dim):
            return nn.Sequential( nn.Linear( in_dim, out_dim), Activation() )

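        # The attention MLP input width is embedding_dim * 8: `fact` concatenates
        # material and category embeddings (2 * embedding_dim per step), and the
        # input stacks [fact, query, fact * query, query - fact], i.e. 4 * 2 * embedding_dim.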
        dimension_pair = [embedding_dim * 8] + hidden_size
        layers = [ _dense( dimension_pair[i], dimension_pair[i+1]) for i in range( len( hidden_size))]
        layers.append( nn.Linear( hidden_size[-1], 1) )
        self.model = nn.Sequential( *layers)

    def forward( self, query, fact, mask, return_scores = False):
        B, T, D = fact.shape

        # Broadcast the query over the time dimension: (B, D) -> (B, T, D)
        query = torch.ones((B, T, 1) ).type( query.type() ) * query.view( (B, 1, D))
        # query = query.view(-1).expand( T, -1).view( T, B, D).permute( 1, 0, 2)

        combination = torch.cat( [ fact, query, fact * query, query - fact ], dim = 2)

        scores = self.model( combination).squeeze()
        # Mask padded positions with a large negative value before the softmax
        scores = torch.where( mask == 1, scores, torch.ones_like( scores) * ( -2 ** 31 ) )

        scores = ( scores.softmax( dim = -1) * mask ).view( (B , 1, T))

        if return_scores: return scores.squeeze()
        return torch.matmul( scores, fact).squeeze()
--------------------------------------------------------------------------------
/main.ipynb:
--------------------------------------------------------------------------------
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "orig_nbformat": 2,
    "kernelspec": {
      "name": "python38264bitmyenvconda79347d8e286349d2b920a35841b643b7",
      "display_name": "Python 3.8.2 64-bit ('my_env': conda)"
    },
    "colab": {
      "name": "main.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "bXDFpyeik6N3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import os\n",
        "import sys\n",
        "import pickle as pk\n",
        "import numpy as np\n",
        "import random\n",
        "\n",
        "from sklearn.metrics import roc_auc_score\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ouEBXD1uk6N8",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 151
        },
        "outputId": "8334bd71-1517-40ea-cab0-9db4b7add7bf"
      },
      "source": [
        "workspace_dir = '.'\n",
        "try:\n",
        "    from google.colab import drive\n",
        "    drive.mount( '/content/drive/' )\n",
        "\n",
        "    workspace_dir = os.path.join( '.' , 'drive', 'My Drive', 'DIN-pytorch')\n",
        "    sys.path.append( workspace_dir)\n",
        "    ! rm -rf data\n",
        "    ! tar zxf \"{workspace_dir}/data.tar.gz\" -C ./\n",
        "    ! tar zxf \"{workspace_dir}/loader.tar.gz\" -C ./\n",
        "    ! ls -al data\n",
        "except ImportError:\n",
        "    pass"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "tags": [],
        "id": "EMQbWEQBk6N_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from model import DIN, DIEN, DynamicGRU\n",
        "from DataLoader import MyDataSet\n",
        "\n",
        "%load_ext autoreload\n",
        "%autoreload 2"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KyIw70o-k6OC",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Model hyperparameters\n",
        "MAX_LEN = 100\n",
        "EMBEDDING_DIM = 18\n",
        "# HIDDEN_SIZE_ATTENTION = [80, 40]\n",
        "# HIDDEN_SIZE_FC = [200, 80]\n",
        "# ACTIVATION_LAYER = 'LeakyReLU'  # lr = 0.01\n",
        "\n",
        "\n",
        "# Adam\n",
        "LR = 1e-3\n",
        "BETA1 = 0.5\n",
        "BETA2 = 0.99\n",
        "\n",
        "# Train\n",
        "BATCH_SIZE = 128\n",
        "EPOCH_TIME = 20\n",
        "TEST_ITER = 1000\n",
        "\n",
        "RANDOM_SEED = 19940808\n",
        "\n",
        "USE_CUDA = True"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7NQM5lkgk6OF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_file = os.path.join( './data', \"local_train_splitByUser\")\n",
        "test_file = os.path.join( './data', \"local_test_splitByUser\")\n",
        "uid_voc = os.path.join( './data', \"uid_voc.pkl\")\n",
        "mid_voc = os.path.join( './data', \"mid_voc.pkl\")\n",
        "cat_voc = os.path.join( './data', \"cat_voc.pkl\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "tags": [],
        "id": "YnNy6DAqk6OH",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "if USE_CUDA and torch.cuda.is_available():\n",
        "    print( \"CUDA is available\" )\n",
        "    device = torch.device('cuda')\n",
        "    dtype = torch.cuda.FloatTensor\n",
        "else:\n",
        "    device = torch.device( 'cpu')\n",
        "    dtype = torch.FloatTensor"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qC6I-EKmk6OK",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Fix the random seeds for reproducibility\n",
        "def same_seeds(seed = RANDOM_SEED):\n",
        "    torch.manual_seed(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed(seed)\n",
        "        torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU\n",
        "    np.random.seed(seed)\n",
        "    random.seed(seed)\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "\n",
        "# Initialize parameters\n",
        "def weights_init( m):\n",
        "    try:\n",
        "        classname = m.__class__.__name__\n",
        "        if classname.find( 'BatchNorm') != -1:\n",
        "            nn.init.normal_( m.weight.data, 1.0, 0.02)\n",
        "            nn.init.constant_( m.bias.data, 0)\n",
        "        elif classname.find( 'Linear') != -1:\n",
        "            nn.init.normal_( m.weight.data, 0.0, 0.02)\n",
        "        elif classname.find( 'Embedding') != -1:\n",
        "            m.weight.data.uniform_(-1, 1)\n",
        "    except AttributeError:\n",
        "        print( \"AttributeError:\", classname)\n",
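        "\n",
        "\n",
        "# eval_output: BCE-with-logits loss, accuracy of the rounded sigmoid scores, and ROC-AUC\n",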
"\n", 191 | "def eval_output( scores, target, loss_function = torch.nn.functional.binary_cross_entropy_with_logits):\n", 192 | " loss = loss_function( scores.type( dtype) , target.type( dtype))\n", 193 | "\n", 194 | " y_pred = scores.sigmoid().round()\n", 195 | " accuracy = ( y_pred == target).type( dtype).mean()\n", 196 | "\n", 197 | " auc = roc_auc_score( target.cpu().detach(), scores.cpu().detach() )\n", 198 | " return loss, accuracy, auc" 199 | ], 200 | "execution_count": null, 201 | "outputs": [] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "metadata": { 206 | "id": "WajPHmvzk6ON", 207 | "colab_type": "code", 208 | "colab": {} 209 | }, 210 | "source": [ 211 | "# The dict mapping description(string) to type index(int) \n", 212 | "# A more graceful api https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html#sklearn.preprocessing.LabelEncoder not used in this project\n", 213 | "\n", 214 | "user_map = pk.load( open( uid_voc, 'rb')); n_uid = len( user_map)\n", 215 | "material_map = pk.load( open( mid_voc, 'rb')); n_mid = len( material_map)\n", 216 | "category_map = pk.load( open( cat_voc, 'rb')); n_cat = len( category_map)" 217 | ], 218 | "execution_count": null, 219 | "outputs": [] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "metadata": { 224 | "tags": [], 225 | "id": "D0dxFeCNk6OP", 226 | "colab_type": "code", 227 | "colab": {} 228 | }, 229 | "source": [ 230 | "same_seeds( RANDOM_SEED)\n", 231 | "\n", 232 | "dataset_train = MyDataSet( train_file, user_map, material_map, category_map, max_length = MAX_LEN)\n", 233 | "dataset_test = MyDataSet( test_file, user_map, material_map, category_map, max_length = MAX_LEN)\n", 234 | "\n", 235 | "loader_train = torch.utils.data.DataLoader( dataset_train, batch_size = BATCH_SIZE, shuffle = True)\n", 236 | "loader_test = torch.utils.data.DataLoader( dataset_test, batch_size = BATCH_SIZE, shuffle = False)\n", 237 | "\n", 238 | "# with open( 'data/loader.pk', 'rb') as fin:\n", 239 | "# loader_train, loader_test = pk.load(fin) " 240 | ], 241 | "execution_count": null, 242 | "outputs": [] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "metadata": { 247 | "tags": [], 248 | "id": "2XTRK9jlk6OS", 249 | "colab_type": "code", 250 | "colab": {} 251 | }, 252 | "source": [ 253 | "# Get model and initialize it\n", 254 | "# model = DIEN( n_uid, n_mid, n_cat, EMBEDDING_DIM).to( device)\n", 255 | "model = DIN( n_uid, n_mid, n_cat, EMBEDDING_DIM ).to( device)\n", 256 | "model.apply( weights_init)\n", 257 | "\n", 258 | "# Set loss function and optimizer\n", 259 | "optimizer = torch.optim.Adam( model.parameters(), LR, ( BETA1, BETA2))\n", 260 | "\n", 261 | "model.train(); iter = 0\n", 262 | "for epoch in range( EPOCH_TIME):\n", 263 | "\n", 264 | " for i, data in enumerate( loader_train):\n", 265 | " iter += 1\n", 266 | "\n", 267 | " # transform data to target device\n", 268 | " \n", 269 | " data = [ item.to( device) if item != None else None for item in data]\n", 270 | " target = data.pop(-1) \n", 271 | " \n", 272 | " model.zero_grad()\n", 273 | "\n", 274 | " scores = model( data, neg_sample = False)\n", 275 | " \n", 276 | " loss, accuracy, auc = eval_output( scores, target)\n", 277 | "\n", 278 | " loss.backward()\n", 279 | " optimizer.step( )\n", 280 | " \n", 281 | " print( \"\\r[%d/%d][%d/%d]\\tloss:%.5f\\tacc:%.5f\\tauc:%.5f\"%( epoch + 1, EPOCH_TIME, i + 1, len( loader_train), loss.item(), accuracy.item(), auc.item() ) ,end='')\n", 282 | "\n", 283 | " if iter % TEST_ITER == 0:\n", 284 | " model.eval()\n", 285 | " with 
        "            with torch.no_grad():\n",
        "                score_list = []; target_list = []\n",
        "                for data in loader_test:\n",
        "                    data = [ item.to( device) if item is not None else None for item in data]\n",
        "\n",
        "                    target = data.pop(-1)\n",
        "\n",
        "                    scores = model( data, neg_sample = False)\n",
        "                    score_list.append( scores)\n",
        "                    target_list.append( target)\n",
        "                scores = torch.cat( score_list, dim = -1)\n",
        "                target = torch.cat( target_list, dim = -1)\n",
        "                loss, accuracy, auc = eval_output( scores, target)\n",
        "                print( \"\\tTest Set\\tloss:%.5f\\tacc:%.5f\\tauc:%.5f\"%( loss.item(), accuracy.item(), auc.item() ) )\n",
        "            model.train()"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from layer import *
from rnn import *

class DIN( nn.Module):
    def __init__(self, n_uid, n_mid, n_cid, embedding_dim):
        super().__init__()

        self.embedding_layer = InputEmbedding( n_uid, n_mid, n_cid, embedding_dim )
        self.attention_layer = AttentionLayer( embedding_dim, hidden_size = [ 80, 40], activation_layer = 'sigmoid')
        # self.output_layer = MLP( embedding_dim * 9, [ 200, 80], 1, 'ReLU')
        self.output_layer = MLP( embedding_dim * 7, [ 200, 80], 1, 'ReLU')

    def forward( self, data, neg_sample = False):

        user, material_historical, category_historical, mask, sequential_length , material, category, \
            material_historical_neg, category_historical_neg = data

        user_embedding, material_historical_embedding, category_historical_embedding, \
            material_embedding, category_embedding, material_historical_neg_embedding, category_historical_neg_embedding = \
            self.embedding_layer( user, material, category, material_historical, category_historical, material_historical_neg, category_historical_neg, neg_sample)

        item_embedding = torch.cat( [ material_embedding, category_embedding], dim = 1)
        item_historical_embedding = torch.cat( [ material_historical_embedding, category_historical_embedding], dim = 2 )

        # Masked mean-pooling over the historical sequence
        item_historical_embedding_sum = torch.matmul( mask.unsqueeze( dim = 1), item_historical_embedding).squeeze() / sequential_length.type( mask.type() ).unsqueeze( dim = 1)

        attention_feature = self.attention_layer( item_embedding, item_historical_embedding, mask)

        # combination = torch.cat( [ user_embedding, item_embedding, item_historical_embedding_sum, attention_feature ], dim = 1)
        combination = torch.cat( [ user_embedding, item_embedding, item_historical_embedding_sum,
            # item_embedding * item_historical_embedding_sum,
            attention_feature ], dim = 1)

        scores = self.output_layer( combination)

        return scores.squeeze()

class DIEN( nn.Module):
    def __init__(self, n_uid, n_mid, n_cid, embedding_dim):
        super().__init__()

        self.embedding_layer = InputEmbedding( n_uid, n_mid, n_cid, embedding_dim )
        self.gru_based_layer = nn.GRU( embedding_dim * 2 , embedding_dim * 2, batch_first = True)
        self.attention_layer = AttentionLayer( embedding_dim, hidden_size = [ 80, 40], activation_layer = 'sigmoid')
        self.gru_customized_layer = DynamicGRU( embedding_dim * 2, embedding_dim * 2)

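        # The final MLP sees embedding_dim * 9 features: user embedding (1x),
        # candidate item (2x: material + category), masked mean of the history (2x),
        # their elementwise product (2x), and the final AUGRU interest state (2x)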
        self.output_layer = MLP( embedding_dim * 9, [ 200, 80], 1, 'ReLU')

    def forward( self, data, neg_sample = False):

        user, material_historical, category_historical, mask, sequential_length , material, category, \
            material_historical_neg, category_historical_neg = data

        user_embedding, material_historical_embedding, category_historical_embedding, \
            material_embedding, category_embedding, material_historical_neg_embedding, category_historical_neg_embedding = \
            self.embedding_layer( user, material, category, material_historical, category_historical, material_historical_neg, category_historical_neg, neg_sample)

        item_embedding = torch.cat( [ material_embedding, category_embedding], dim = 1)
        item_historical_embedding = torch.cat( [ material_historical_embedding, category_historical_embedding], dim = 2 )

        # Masked mean-pooling over the historical sequence
        item_historical_embedding_sum = torch.matmul( mask.unsqueeze( dim = 1), item_historical_embedding).squeeze() / sequential_length.type( mask.type() ).unsqueeze( dim = 1)

        # Interest extraction (plain GRU) followed by interest evolution (attention-gated AUGRU)
        output_based_gru, _ = self.gru_based_layer( item_historical_embedding)
        attention_scores = self.attention_layer( item_embedding, output_based_gru, mask, return_scores = True)
        output_customized_gru = self.gru_customized_layer( output_based_gru, attention_scores)

        # Final hidden state of each (variable-length) sequence
        attention_feature = output_customized_gru[ range( len( sequential_length)), sequential_length - 1]

        combination = torch.cat( [ user_embedding, item_embedding, item_historical_embedding_sum, item_embedding * item_historical_embedding_sum, attention_feature ], dim = 1)

        scores = self.output_layer( combination)

        return scores.squeeze()
--------------------------------------------------------------------------------
/rnn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class AUGRUCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, bias = True):
        super(AUGRUCell, self).__init__()

        in_dim = input_dim + hidden_dim
        self.reset_gate = nn.Sequential( nn.Linear( in_dim, hidden_dim, bias = bias), nn.Sigmoid())
        self.update_gate = nn.Sequential( nn.Linear( in_dim, hidden_dim, bias = bias), nn.Sigmoid())
        self.h_hat_gate = nn.Sequential( nn.Linear( in_dim, hidden_dim, bias = bias), nn.Tanh())


    def forward(self, X, h_prev, attention_score):
        temp_input = torch.cat( [ h_prev, X ] , dim = -1)
        r = self.reset_gate( temp_input)
        u = self.update_gate( temp_input)

        h_hat = self.h_hat_gate( torch.cat( [ h_prev * r, X], dim = -1) )

        # The attention score rescales the update gate (GRU with attentional update gate, AUGRU)
        u = attention_score.unsqueeze(1) * u
        h_cur = (1. - u) * h_prev + u * h_hat

        return h_cur


class DynamicGRU(nn.Module):
    def __init__(self, input_dim, hidden_dim, bias = True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.rnn_cell = AUGRUCell( input_dim, hidden_dim, bias = bias)

    def forward(self, X, attention_scores, h0 = None):
        B, T, D = X.shape
        H = self.hidden_dim

        output = torch.zeros( B, T, H ).type( X.type() )
        h_prev = torch.zeros( B, H ).type( X.type() ) if h0 is None else h0
        for t in range( T):
            h_prev = output[ : , t, :] = self.rnn_cell( X[ : , t, :], h_prev, attention_scores[ :, t] )
        return output
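

# A minimal forward-pass smoke test (not part of the original training pipeline;
# the shapes are illustrative only): it checks that the per-step attention scores
# broadcast correctly through the AUGRU recurrence.
if __name__ == '__main__':
    B, T, D = 2, 5, 4
    gru = DynamicGRU( D, D)
    X = torch.randn( B, T, D)
    scores = torch.softmax( torch.randn( B, T), dim = -1)
    print( gru( X, scores).shape)  # expected: torch.Size([2, 5, 4])
--------------------------------------------------------------------------------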