├── .gitignore ├── LICENSE ├── README.md ├── datasets └── delicious │ ├── edges.txt │ ├── stats.txt │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── preprocess.py ├── run.py └── srs ├── layers ├── narm.py ├── nextitnet.py ├── seframe.py ├── serec.py ├── srgnn.py ├── ssrm.py └── stamp.py ├── models ├── DGRec.py ├── NARM.py ├── NextItNet.py ├── SERec.py ├── SNARM.py ├── SNextItNet.py ├── SRGNN.py ├── SSRGNN.py ├── SSRM.py ├── SSSRM.py ├── SSTAMP.py └── STAMP.py └── utils ├── Dict.py ├── argparse.py ├── data ├── collate.py ├── load.py ├── preprocess.py └── transform.py ├── prepare_batch.py └── train_runner.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | .vscode/ 4 | /datasets 5 | /*.sh 6 | /*.txt 7 | /logs/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tianwen CHEN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SEFrame 2 | This repository contains the code for the paper "An Efficient and Effective Framework for Session-based Social Recommendation". 3 | 4 | ## Requirements 5 | - Python 3.8 6 | - CUDA 10.2 7 | - PyTorch 1.7.1 8 | - DGL 0.5.3 9 | - NumPy 1.19.2 10 | - Pandas 1.1.3 11 | 12 | ## Usage 13 | 1. Install all the requirements. 14 | 15 | 2. Download the datasets: 16 | - [Gowalla](https://snap.stanford.edu/data/loc-gowalla.html) 17 | - [Delicious](https://grouplens.org/datasets/hetrec-2011/) 18 | - [Foursquare](https://sites.google.com/site/yangdingqi/home/foursquare-dataset) 19 | 20 | 3. Create a folder called `datasets` and extract the raw data files to the folder. 21 | The folder should include the following files for each dataset: 22 | - Gowalla: `loc-gowalla_totalCheckins.txt` and `loc-gowalla_edges.txt` 23 | - Delicious: `user_taggedbookmarks-timestamps.dat` and `user_contacts-timestamps.dat` 24 | - Foursquare: `dataset_WWW_Checkins_anonymized.txt` and `dataset_WWW_friendship_new.txt` 25 | 26 | 4. Preprocess the datasets using the Python script [preprocess.py](preprocess.py). 
27 | For example, to preprocess the *Gowalla* dataset, run the following command: 28 | ```bash 29 | python preprocess.py --dataset gowalla 30 | ``` 31 | The above command will create a folder `datasets/gowalla` to store the preprocessed data files. 32 | Replace `gowalla` with `delicious` or `foursquare` to preprocess other datasets. 33 | 34 | To see the detailed usage of `preprocess.py`, run the following command: 35 | ```bash 36 | python preprocess.py -h 37 | ``` 38 | 39 | 5. Train and evaluate a model using the Python script [run.py](run.py). 40 | For example, to train and evaluate the model NARM on the *Gowalla* dataset, run the following command: 41 | ```bash 42 | python run.py --model NARM --dataset-dir datasets/gowalla 43 | ``` 44 | Other available models are NextItNet, STAMP, SRGNN, SSRM, SNARM, SNextItNet, SSTAMP, SSRGNN, SSSRM, DGRec, and SERec. 45 | You can also see all the available models in the [srs/models](srs/models) folder. 46 | 47 | To see the detailed usage of `run.py`, run the following command: 48 | ```bash 49 | python run.py -h 50 | ``` 51 | 52 | ## Dataset Format 53 | You can train the models using your own datasets. Each dataset should contain the following files: 54 | 55 | - `stats.txt`: A TSV file containing three fields, `num_users`, `num_items`, and `max_len` (the maximum length of sessions). The first row is the header and the second row contains the values. 56 | 57 | - `train.txt`: A TSV file containing all training sessions, where each session has three fields, namely, `sessionId`, `userId`, and `items`. Both `sessionId` and `userId` should be integers. A session with a larger `sessionId` means that it was generated later (this requirement can be ignored if the model does not depend on the order of sessions, i.e., for any model other than DGRec). The `userId` should be in the range of `[0, num_users)`. The `items` field of each session contains the clicked items in the session, given as a sequence of item IDs separated by commas. The item IDs should be in the range of `[0, num_items)`. 58 | 59 | - `valid.txt` and `test.txt`: TSV files containing all validation and test sessions, respectively. Both files have the same format as `train.txt`. Note that the session IDs in `valid.txt` and `test.txt` should be larger than those in `train.txt`. 60 | 61 | - `edges.txt`: A TSV file containing the relations in the social network. It has two columns, `follower` and `followee`. Both columns contain user IDs. 62 | 63 | You can see [datasets/delicious](datasets/delicious) for an example dataset. 
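If you want to sanity-check the expected layout before plugging in your own data, the following minimal sketch writes a toy dataset with the files described above. The IDs are made up, and the header rows are an assumption based on the files in [datasets/delicious](datasets/delicious); adjust them if your loader expects a different convention.
```python
from pathlib import Path

import pandas as pd

out = Path('datasets/toy')
out.mkdir(parents=True, exist_ok=True)

# stats.txt: a header row followed by a single row of values
pd.DataFrame({'num_users': [3], 'num_items': [5], 'max_len': [50]}) \
    .to_csv(out / 'stats.txt', sep='\t', index=False)

# train.txt: one session per row; items is a comma-separated list of item IDs in [0, num_items)
sessions = pd.DataFrame({
    'sessionId': [0, 1, 2],
    'userId': [0, 2, 1],
    'items': ['0,3,4', '1,2', '4,0,0,2'],
})
sessions.to_csv(out / 'train.txt', sep='\t', index=False)

# valid.txt and test.txt share the layout of train.txt, with larger session IDs
sessions.assign(sessionId=sessions.sessionId + 3).to_csv(out / 'valid.txt', sep='\t', index=False)
sessions.assign(sessionId=sessions.sessionId + 6).to_csv(out / 'test.txt', sep='\t', index=False)

# edges.txt: follower -> followee relations between user IDs
pd.DataFrame({'follower': [0, 1, 2], 'followee': [1, 2, 0]}) \
    .to_csv(out / 'edges.txt', sep='\t', index=False)
```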
64 | 65 | ## Citation 66 | If you use this code for your research, please cite our [paper](http://home.cse.ust.hk/~raywong/paper/wsdm21-SEFrame.pdf): 67 | ``` 68 | @inproceedings{chen2021seframe, 69 | title="An Efficient and Effective Framework for Session-based Social Recommendation", 70 | author="Tianwen {Chen} and Raymond Chi-Wing {Wong}", 71 | booktitle="Proceedings of the Fourteenth ACM International Conference on Web Search and Data Mining (WSDM '21)", 72 | pages="400--408", 73 | year="2021" 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /datasets/delicious/stats.txt: -------------------------------------------------------------------------------- 1 | num_users num_items max_len 2 | 1313 5793 50 3 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from srs.utils import argparse 3 | from pathlib import Path 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument( 7 | '--dataset', 8 | choices=['gowalla', 'delicious', 'foursquare'], 9 | required=True, 10 | help='the dataset name', 11 | ) 12 | parser.add_argument( 13 | '--input-dir', 14 | type=Path, 15 | default='datasets', 16 | help='the directory containing the raw data files', 17 | ) 18 | parser.add_argument( 19 | '--output-dir', 20 | type=Path, 21 | default='datasets', 22 | help='the directory to store the preprocessed dataset', 23 | ) 24 | parser.add_argument( 25 | '--train-split', type=float, default=0.6, help='the ratio of the training set' 26 | ) 27 | parser.add_argument( 28 | '--max-len', type=int, default=50, help='the maximum session length' 29 | ) 30 | args = parser.parse_args() 31 | 32 | FILENAMES = { 33 | 'gowalla': ['loc-gowalla_totalCheckins.txt', 'loc-gowalla_edges.txt'], 34 | 'delicious': [ 35 | 'user_taggedbookmarks-timestamps.dat', 36 | 'user_contacts-timestamps.dat', 37 | ], 38 | 'foursquare': [ 39 | 'dataset_WWW_Checkins_anonymized.txt', 40 | 'dataset_WWW_friendship_new.txt', 41 | 'raw_POIs.txt', 42 | ], 43 | } 44 | 45 | filenames = FILENAMES[args.dataset] 46 | for filename in filenames: 47 | if not (args.input_dir / filename).exists(): 48 | print(f'File {filename} not found in {args.input_dir}', file=sys.stderr) 49 | sys.exit(1) 50 | clicks = args.input_dir / filenames[0] 51 | edges = args.input_dir / filenames[1] 52 | 53 | import numpy as np 54 | import pandas as pd 55 | from srs.utils.data.preprocess import preprocess, update_id 56 | 57 | print('reading dataset...') 58 | if args.dataset == 'gowalla': 59 | args.interval = pd.Timedelta(days=1) 60 | args.max_items = 50000 61 | 62 | df = pd.read_csv( 63 | clicks, 64 | sep='\t', 65 | header=None, 66 | names=['userId', 'timestamp', 'latitude', 'longitude', 'itemId'], 67 | parse_dates=['timestamp'], 68 | infer_datetime_format=True, 69 | ) 70 | df_clicks = df[['userId', 'timestamp', 'itemId']] 71 | df_loc = df.groupby('itemId').agg({ 72 | 'latitude': lambda col: col.iloc[0], 73 | 'longitude': lambda col: col.iloc[0], 74 | }).reset_index() 75 | df_edges = pd.read_csv(edges, sep='\t', header=None, names=['follower', 'followee']) 76 | elif args.dataset == 'delicious': 77 | 78 | df_clicks = pd.read_csv( 79 | clicks, 80 | sep='\t', 81 | skiprows=1, 82 | header=None, 83 | names=['userId', 'sessionId', 'itemId', 'timestamp'], 84 | ) 85 | df_clicks['timestamp'] = pd.to_datetime(df_clicks.timestamp, unit='ms') 86 | df_loc = None 87 | df_edges = pd.read_csv( 88 | edges, 
89 | sep='\t', 90 | skiprows=1, 91 | header=None, 92 | usecols=[0, 1], 93 | names=['follower', 'followee'], 94 | ) 95 | elif args.dataset == 'foursquare': 96 | args.interval = pd.Timedelta(days=1) 97 | args.max_users = 50000 98 | args.max_items = 50000 99 | 100 | df_loc = pd.read_csv( 101 | args.input_dir / 'raw_POIs.txt', 102 | sep='\t', 103 | header=None, 104 | usecols=[0, 1, 2], 105 | names=['itemId', 'latitude', 'longitude'] 106 | ) 107 | 108 | df_clicks = pd.read_csv( 109 | clicks, 110 | sep='\t', 111 | header=None, 112 | usecols=[0, 1, 2], 113 | names=['userId', 'itemId', 'timestamp'], 114 | ) 115 | df_clicks['timestamp'] = pd.to_datetime( 116 | df_clicks.timestamp, format='%a %b %d %H:%M:%S %z %Y', errors='coerce' 117 | ) 118 | 119 | df_edges = pd.read_csv(edges, sep='\t', header=None, names=['follower', 'followee']) 120 | df_edges_rev = pd.DataFrame({ 121 | 'followee': df_edges.follower, 122 | 'follower': df_edges.followee 123 | }) 124 | df_edges = df_edges.append(df_edges_rev, ignore_index=True) 125 | else: 126 | print(f'Unsupported dataset {args.dataset}', file=sys.stderr) 127 | sys.exit(1) 128 | 129 | df_edges = df_edges[df_edges.follower != df_edges.followee] 130 | df_clicks = df_clicks.dropna() 131 | print('converting IDs to integers...') 132 | df_clicks, df_edges = update_id( 133 | df_clicks, df_edges, colnames=['userId', 'followee', 'follower'] 134 | ) 135 | if df_loc is None: 136 | df_clicks = update_id(df_clicks, colnames='itemId') 137 | else: 138 | df_clicks, df_loc = update_id(df_clicks, df_loc, colnames='itemId') 139 | df_clicks = df_clicks.sort_values(['userId', 'timestamp']) 140 | np.random.seed(123456) 141 | preprocess(df_clicks, df_edges, df_loc, args) 142 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from srs.utils.argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | parser = ArgumentParser() 5 | parser.add_argument('--model', required=True, help='the prediction model') 6 | parser.add_argument( 7 | '--dataset-dir', type=Path, required=True, help='the dataset set directory' 8 | ) 9 | parser.add_argument( 10 | '--embedding-dim', type=int, default=128, help='the dimensionality of embeddings' 11 | ) 12 | parser.add_argument( 13 | '--feat-drop', type=float, default=0.2, help='the dropout ratio for input features' 14 | ) 15 | parser.add_argument( 16 | '--num-layers', 17 | type=int, 18 | default=1, 19 | help='the number of HGNN layers in the KGE component', 20 | ) 21 | parser.add_argument( 22 | '--num-neighbors', 23 | default='10', 24 | help='the number of neighbors to sample at each layer.' 25 | ' Give an integer if the number is the same for all layers.' 26 | ' Give a list of integers separated by commas if this number is different at different layers, e.g., 10,10,5' 27 | ) 28 | parser.add_argument( 29 | '--model-args', 30 | type=str, 31 | default='{}', 32 | help="the extra arguments passed to the model's initializer." 
33 | ' Will be evaluated as a dictionary.', 34 | ) 35 | parser.add_argument('--batch-size', type=int, default=128, help='the batch size') 36 | parser.add_argument( 37 | '--epochs', type=int, default=30, help='the maximum number of training epochs' 38 | ) 39 | parser.add_argument('--lr', type=float, default=1e-3, help='the learning rate') 40 | parser.add_argument( 41 | '--weight-decay', 42 | type=float, 43 | default=1e-4, 44 | help='the weight decay for the optimizer', 45 | ) 46 | parser.add_argument( 47 | '--patience', 48 | type=int, 49 | default=2, 50 | help='stop training if the performance does not improve in this number of consecutive epochs', 51 | ) 52 | parser.add_argument( 53 | '--Ks', 54 | default='10,20', 55 | help='the values of K in evaluation metrics, separated by commas' 56 | ) 57 | parser.add_argument( 58 | '--ignore-list', 59 | default='bias,batch_norm,activation', 60 | help='the names of parameters excluded from being regularized', 61 | ) 62 | parser.add_argument( 63 | '--log-level', 64 | choices=['debug', 'info', 'warning', 'error'], 65 | default='debug', 66 | help='the log level', 67 | ) 68 | parser.add_argument( 69 | '--log-interval', 70 | type=int, 71 | default=1000, 72 | help='if log level is info or debug, print training information after every this number of iterations', 73 | ) 74 | parser.add_argument( 75 | '--device', type=int, default=0, help='the index of GPU device (-1 for CPU)' 76 | ) 77 | parser.add_argument( 78 | '--num-workers', 79 | type=int, 80 | default=1, 81 | help='the number of processes for data loaders', 82 | ) 83 | parser.add_argument( 84 | '--OTF', 85 | action='store_true', 86 | help='compute KG embeddings on the fly instead of precomputing them before inference to save memory', 87 | ) 88 | args = parser.parse_args() 89 | args.model_args = eval(args.model_args) 90 | args.num_neighbors = [int(x) for x in args.num_neighbors.split(',')] 91 | args.Ks = [int(K) for K in args.Ks.split(',')] 92 | args.ignore_list = [x.strip() for x in args.ignore_list.split(',') if x.strip() != ''] 93 | 94 | import logging 95 | import importlib 96 | 97 | module = importlib.import_module(f'srs.models.{args.model}') 98 | config = module.config 99 | for k, v in vars(args).items(): 100 | config[k] = v 101 | args = config 102 | 103 | log_level = getattr(logging, args.log_level.upper(), None) 104 | logging.basicConfig(format='%(message)s', level=log_level) 105 | logging.debug(args) 106 | 107 | import torch as th 108 | from torch.utils.data import DataLoader 109 | from srs.layers.seframe import SEFrame 110 | from srs.utils.data.load import read_dataset, AugmentedDataset, AnonymousAugmentedDataset 111 | from srs.utils.train_runner import TrainRunner 112 | 113 | args.device = ( 114 | th.device('cpu') if args.device < 0 else th.device(f'cuda:{args.device}') 115 | ) 116 | args.prepare_batch = args.prepare_batch_factory(args.device) 117 | 118 | logging.info(f'reading dataset {args.dataset_dir}...') 119 | df_train, df_valid, df_test, stats = read_dataset(args.dataset_dir) 120 | 121 | if issubclass(args.Model, SEFrame): 122 | from srs.utils.data.load import (read_social_network, build_knowledge_graph) 123 | 124 | social_network = read_social_network(args.dataset_dir / 'edges.txt') 125 | args.knowledge_graph = build_knowledge_graph(df_train, social_network) 126 | 127 | elif args.Model.__name__ == 'DGRec': 128 | from srs.utils.data.load import ( 129 | compute_visible_time_list_and_in_neighbors, 130 | filter_invalid_sessions, 131 | ) 132 | 133 | visible_time_list, in_neighbors = 
compute_visible_time_list_and_in_neighbors( 134 | df_train, args.dataset_dir, args.num_layers 135 | ) 136 | args.visible_time_list = visible_time_list 137 | args.in_neighbors = in_neighbors 138 | args.uid2sessions = [{ 139 | 'sids': df['sessionId'].values, 140 | 'sessions': df['items'].values 141 | } for _, df in df_train.groupby('userId')] 142 | L_hop_visible_time = visible_time_list[args.num_layers] 143 | 144 | df_train, df_valid, df_test = filter_invalid_sessions( 145 | df_train, df_valid, df_test, L_hop_visible_time=L_hop_visible_time 146 | ) 147 | 148 | args.num_users = getattr(stats, 'num_users', None) 149 | args.num_items = stats.num_items 150 | args.max_len = stats.max_len 151 | 152 | model = args.Model(**args, **args.model_args) 153 | model = model.to(args.device) 154 | logging.debug(model) 155 | 156 | if args.num_users is None: 157 | train_set = AnonymousAugmentedDataset(df_train) 158 | valid_set = AnonymousAugmentedDataset(df_valid) 159 | test_set = AnonymousAugmentedDataset(df_test) 160 | else: 161 | read_sid = args.Model.__name__ == 'DGRec' 162 | train_set = AugmentedDataset(df_train, read_sid) 163 | valid_set = AugmentedDataset(df_valid, read_sid) 164 | test_set = AugmentedDataset(df_test, read_sid) 165 | 166 | if 'CollateFn' in args: 167 | collate_fn = args.CollateFn(**args) 168 | collate_train = collate_fn.collate_train 169 | if args.OTF and issubclass(args.Model, SEFrame): 170 | print('compute KG embeddings on the fly') 171 | collate_test = collate_fn.collate_test_otf 172 | else: 173 | collate_test = collate_fn.collate_test 174 | else: 175 | collate_train = collate_test = args.collate_fn 176 | 177 | args.model = model 178 | 179 | if 'BatchSampler' in config: 180 | logging.debug('using batch sampler') 181 | batch_sampler = config.BatchSampler( 182 | train_set, batch_size=args.batch_size, drop_last=True, seed=0 183 | ) 184 | train_loader = DataLoader( 185 | train_set, 186 | batch_sampler=batch_sampler, 187 | collate_fn=collate_train, 188 | num_workers=args.num_workers, 189 | ) 190 | else: 191 | train_loader = DataLoader( 192 | train_set, 193 | batch_size=args.batch_size, 194 | collate_fn=collate_train, 195 | num_workers=args.num_workers, 196 | drop_last=True, 197 | shuffle=True, 198 | ) 199 | 200 | valid_loader = DataLoader( 201 | valid_set, 202 | batch_size=args.batch_size, 203 | collate_fn=collate_test, 204 | num_workers=args.num_workers, 205 | drop_last=False, 206 | shuffle=False, 207 | ) 208 | 209 | test_loader = DataLoader( 210 | test_set, 211 | batch_size=args.batch_size, 212 | collate_fn=collate_test, 213 | num_workers=args.num_workers, 214 | drop_last=False, 215 | shuffle=False, 216 | ) 217 | 218 | runner = TrainRunner(train_loader, valid_loader, test_loader, **args) 219 | logging.info('start training') 220 | results = runner.train(args.epochs, log_interval=args.log_interval) 221 | -------------------------------------------------------------------------------- /srs/layers/narm.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | 5 | 6 | class NARMLayer(nn.Module): 7 | def __init__(self, input_dim, feat_drop=0.0): 8 | super().__init__() 9 | self.feat_drop = nn.Dropout(feat_drop) if feat_drop > 0 else None 10 | self.gru = nn.GRU(input_dim, input_dim) 11 | self.attn_i = nn.Linear(input_dim, input_dim, bias=False) 12 | self.attn_t = nn.Linear(input_dim, input_dim, bias=False) 13 | self.attn_e = 
nn.Linear(input_dim, 1, bias=False) 14 | 15 | def forward(self, emb_seqs, lens): 16 | batch_size, max_len, _ = emb_seqs.size() 17 | mask = th.arange( 18 | max_len, device=lens.device 19 | ).unsqueeze(0).expand(batch_size, max_len) >= lens.unsqueeze(-1) 20 | 21 | if self.feat_drop is not None: 22 | emb_seqs = self.feat_drop(emb_seqs) 23 | packed_seqs = pack_padded_sequence(emb_seqs, lens.cpu(), batch_first=True) 24 | out, ht = self.gru(packed_seqs) 25 | out, _ = pad_packed_sequence(out, batch_first=True) # (batch_size, max_len, d) 26 | ht = ht.transpose(0, 1) # (batch_size, 1, d) 27 | 28 | ei = self.attn_i(out) 29 | et = self.attn_t(ht) 30 | e = self.attn_e(th.sigmoid(ei + et)) # (batch_size, max_len, 1) 31 | e = e.squeeze(-1) # (batch_size, max_len) 32 | alpha = th.masked_fill(e, mask, 0) 33 | 34 | ct_g = ht.squeeze(1) # (batch_size, d) 35 | ct_l = th.sum(out * alpha.unsqueeze(-1), dim=1) # (batch_size, d) 36 | sr = th.cat([ct_g, ct_l], dim=1) 37 | return sr 38 | -------------------------------------------------------------------------------- /srs/layers/nextitnet.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | 5 | class NextItNetLayer(nn.Module): 6 | def __init__(self, channels, dilations, one_masked, kernel_size, feat_drop=0.0): 7 | super().__init__() 8 | if one_masked: 9 | ResBlock = ResBlockOneMasked 10 | if dilations is None: 11 | dilations = [1, 2, 4] 12 | else: 13 | ResBlock = ResBlockTwoMasked 14 | if dilations is None: 15 | dilations = [1, 4] 16 | self.feat_drop = nn.Dropout(feat_drop) if feat_drop > 0 else None 17 | self.res_blocks = nn.ModuleList([ 18 | ResBlock(channels, kernel_size, dilation) for dilation in dilations 19 | ]) 20 | 21 | def forward(self, emb_seqs, lens): 22 | # emb_seqs: (B, L, C) 23 | batch_size, max_len, _ = emb_seqs.size() 24 | mask = th.arange( 25 | max_len, device=lens.device 26 | ).unsqueeze(0).expand(batch_size, max_len) >= lens.unsqueeze(-1) 27 | emb_seqs = th.masked_fill(emb_seqs, mask.unsqueeze(-1), 0) 28 | if self.feat_drop is not None: 29 | emb_seqs = self.feat_drop(emb_seqs) 30 | 31 | x = th.transpose(emb_seqs, 1, 2) # (B, C, L) 32 | for res_block in self.res_blocks: 33 | x = res_block(x) 34 | batch_idx = th.arange(len(lens)) 35 | last_idx = lens - 1 36 | sr = x[batch_idx, :, last_idx] # (B, C) 37 | return sr 38 | 39 | 40 | class MaskedConv1d(nn.Module): 41 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1): 42 | super().__init__() 43 | self.repr_str = ( 44 | f'{self.__class__.__name__}(in_channels={in_channels}, ' 45 | f'out_channels={out_channels}, kernel_size={kernel_size}, dilation={dilation})' 46 | ) 47 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation) 48 | self.padding = (kernel_size - 1) * dilation 49 | 50 | def forward(self, x): 51 | # x: (B, C, L) 52 | x = th.nn.functional.pad(x, [self.padding, 0]) # (B, C, L + self.padding) 53 | x = self.conv(x) 54 | return x 55 | 56 | def __repr__(self): 57 | return self.repr_str 58 | 59 | 60 | class LayerNorm(nn.Module): 61 | def __init__(self, channels, epsilon=1e-5): 62 | super().__init__() 63 | self.gamma = nn.Parameter(th.ones([1, channels, 1], dtype=th.float32)) 64 | self.beta = nn.Parameter(th.zeros([1, channels, 1], dtype=th.float32)) 65 | self.epsilon = epsilon 66 | 67 | def forward(self, x): 68 | # x: (B, C, L) 69 | var, mean = th.var_mean(x, dim=1, keepdim=True, unbiased=False) 70 | x = (x - mean) / th.sqrt(var + self.epsilon) 71 | return x * 
self.gamma + self.beta 72 | 73 | 74 | class ResBlockOneMasked(nn.Module): 75 | def __init__(self, channels, kernel_size, dilation): 76 | super().__init__() 77 | mid_channels = channels // 2 78 | self.layer_norm1 = LayerNorm(channels) 79 | self.conv1 = nn.Conv1d(channels, mid_channels, kernel_size=1) 80 | self.layer_norm2 = LayerNorm(mid_channels) 81 | self.conv2 = MaskedConv1d( 82 | mid_channels, mid_channels, kernel_size=kernel_size, dilation=dilation 83 | ) 84 | self.layer_norm3 = LayerNorm(mid_channels) 85 | self.conv3 = nn.Conv1d(mid_channels, channels, kernel_size=1) 86 | 87 | def forward(self, x): 88 | # x: (B, C, L) 89 | y = x 90 | y = th.relu(self.layer_norm1(y)) 91 | y = self.conv1(y) 92 | y = th.relu(self.layer_norm2(y)) 93 | y = self.conv2(y) 94 | y = th.relu(self.layer_norm3(y)) 95 | y = self.conv3(y) 96 | return y + x 97 | 98 | 99 | class ResBlockTwoMasked(nn.Module): 100 | def __init__(self, channels, kernel_size, dilation): 101 | super().__init__() 102 | self.conv1 = MaskedConv1d(channels, channels, kernel_size, dilation) 103 | self.layer_norm1 = LayerNorm(channels) 104 | self.conv2 = MaskedConv1d(channels, channels, kernel_size, 2 * dilation) 105 | self.layer_norm2 = LayerNorm(channels) 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.conv1(y) 110 | y = th.relu(self.layer_norm1(y)) 111 | y = self.conv2(y) 112 | y = th.relu(self.layer_norm2(y)) 113 | return y + x 114 | -------------------------------------------------------------------------------- /srs/layers/seframe.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | 4 | import torch as th 5 | from torch import nn 6 | import dgl 7 | import dgl.ops as F 8 | from dgl.nn.pytorch import edge_softmax 9 | 10 | 11 | class HomoAttentionAggregationLayer(nn.Module): 12 | def __init__( 13 | self, 14 | qry_feats, 15 | key_feats, 16 | val_feats, 17 | num_heads=1, 18 | feat_drop=0.0, 19 | attn_drop=0.0, 20 | activation=None, 21 | batch_norm=True, 22 | ): 23 | super().__init__() 24 | if batch_norm: 25 | self.batch_norm_q = nn.BatchNorm1d(qry_feats) 26 | self.batch_norm_k = nn.BatchNorm1d(key_feats) 27 | else: 28 | self.batch_norm_q = None 29 | self.batch_norm_k = None 30 | 31 | self.feat_drop = nn.Dropout(feat_drop) 32 | self.attn_drop = nn.Dropout(attn_drop) 33 | 34 | self.fc_q = nn.Linear(qry_feats, val_feats, bias=True) 35 | self.fc_k = nn.Linear(key_feats, val_feats, bias=False) 36 | self.fc_v = nn.Linear(qry_feats, val_feats, bias=False) 37 | self.attn_e = nn.Parameter( 38 | th.randn(1, val_feats, dtype=th.float), requires_grad=True 39 | ) 40 | self.activation = activation 41 | 42 | self.val_feats = val_feats 43 | self.num_heads = num_heads 44 | self.head_feats = val_feats // num_heads 45 | 46 | def extra_repr(self): 47 | return '\n'.join([ 48 | f'num_heads={self.num_heads}', f'(attn_e): Parameter(1, {self.val_feats})' 49 | ]) 50 | 51 | def forward(self, g, ft_q, ft_k, ft_e=None, return_ev=False): 52 | if self.batch_norm_q is not None: 53 | ft_q = self.batch_norm_q(ft_q) 54 | ft_k = self.batch_norm_k(ft_k) 55 | q = self.fc_q(self.feat_drop(ft_q)) 56 | k = self.fc_k(self.feat_drop(ft_k)) 57 | v = self.fc_v(self.feat_drop(ft_q)).view(-1, self.num_heads, self.head_feats) 58 | e = F.u_add_v(g, q, k) 59 | if ft_e is not None: 60 | e = e + ft_e 61 | e = (self.attn_e * th.sigmoid(e)).view(-1, self.num_heads, self.head_feats).sum( 62 | -1, keepdim=True 63 | ) 64 | if return_ev: 65 | return e, v 66 | a = self.attn_drop(edge_softmax(g, 
e)) 67 | rst = F.u_mul_e_sum(g, v, a).view(-1, self.val_feats) 68 | if self.activation is not None: 69 | rst = self.activation(rst) 70 | return rst 71 | 72 | 73 | class HeteroAttentionAggregationLayer(nn.Module): 74 | def __init__( 75 | self, 76 | kg, 77 | embedding_dim, 78 | num_heads=1, 79 | batch_norm=True, 80 | feat_drop=0.0, 81 | relu=False, 82 | ): 83 | super().__init__() 84 | self.batch_norm = nn.ModuleDict() if batch_norm else None 85 | self.feat_drop = nn.Dropout(feat_drop) if feat_drop > 0 else None 86 | self.edge_aggregate = nn.ModuleDict() 87 | self.edge_embedding = nn.ModuleDict() 88 | self.linear_agg = nn.ModuleDict() 89 | self.linear_self = nn.ModuleDict() 90 | self.activation = nn.ModuleDict() 91 | self.vtype2eutypes = defaultdict(list) 92 | for utype, etype, vtype in kg.canonical_etypes: 93 | self.edge_aggregate[etype] = HomoAttentionAggregationLayer( 94 | embedding_dim, 95 | embedding_dim, 96 | embedding_dim, 97 | num_heads=num_heads, 98 | batch_norm=False, 99 | feat_drop=0.0, 100 | activation=None, 101 | ) 102 | if 'cnt' in kg.edges[etype].data: 103 | num_cnt_embeddings = kg.edges[etype].data['cnt'].max() + 1 104 | self.edge_embedding[etype] = nn.Embedding( 105 | num_cnt_embeddings, embedding_dim 106 | ) 107 | self.vtype2eutypes[vtype].append((etype, utype)) 108 | for vtype in self.vtype2eutypes: 109 | self.linear_agg[vtype] = nn.Linear(embedding_dim, embedding_dim, bias=True) 110 | self.linear_self[vtype] = nn.Linear( 111 | embedding_dim, embedding_dim, bias=False 112 | ) 113 | self.activation[vtype] = nn.ReLU() if relu else nn.PReLU(embedding_dim) 114 | if self.batch_norm is not None: 115 | self.batch_norm.update({ 116 | vtype: nn.BatchNorm1d(embedding_dim) 117 | for vtype in self.vtype2eutypes 118 | }) 119 | 120 | def forward(self, g, ft_src): 121 | if self.batch_norm is not None: 122 | ft_src = {ntype: self.batch_norm[ntype](ft) for ntype, ft in ft_src.items()} 123 | if self.feat_drop is not None: 124 | ft_src = {ntype: self.feat_drop(ft) for ntype, ft in ft_src.items()} 125 | device = next(iter(ft_src.values())).device 126 | ft_dst = { 127 | vtype: ft_src[vtype][:g.number_of_dst_nodes(vtype)] 128 | for vtype in g.dsttypes 129 | } 130 | feats = {} 131 | for vtype, eutypes in self.vtype2eutypes.items(): 132 | src_nid = [] 133 | dst_nid = [] 134 | num_utypes_nodes = 0 135 | src_val = [] 136 | attn_score = [] 137 | for etype, utype in eutypes: 138 | sg = g[etype] 139 | ft_e = ( 140 | self.edge_embedding[etype](sg.edata['cnt'].to(device)) 141 | if etype in self.edge_embedding else None 142 | ) 143 | e, v = self.edge_aggregate[etype]( 144 | sg, 145 | ft_src[utype], 146 | ft_dst[vtype], 147 | ft_e=ft_e, 148 | return_ev=True, 149 | ) 150 | uid, vid = sg.all_edges(form='uv', order='eid') 151 | src_nid.append(uid + num_utypes_nodes) 152 | dst_nid.append(vid) 153 | num_utypes_nodes += sg.number_of_src_nodes() 154 | src_val.append(v) 155 | attn_score.append(e) 156 | src_nid = th.cat(src_nid, dim=0) 157 | dst_nid = th.cat(dst_nid, dim=0) 158 | edge_softmax_g = dgl.heterograph( 159 | data_dict={('utypes', 'etypes', 'vtype'): (src_nid, dst_nid)}, 160 | num_nodes_dict={ 161 | 'utypes': num_utypes_nodes, 162 | 'vtype': g.number_of_dst_nodes(vtype) 163 | }, 164 | device=device 165 | ) 166 | src_val = th.cat(src_val, dim=0) # (num_utypes_nodes, num_heads, num_feats) 167 | attn_score = th.cat(attn_score, dim=0) # (num_edges, num_heads, 1) 168 | attn_weight = F.edge_softmax(edge_softmax_g, attn_score) 169 | agg = F.u_mul_e_sum(edge_softmax_g, src_val, attn_weight) 170 | agg = 
agg.view(g.number_of_dst_nodes(vtype), -1) 171 | feats[vtype] = self.activation[vtype]( 172 | self.linear_agg[vtype](agg) + self.linear_self[vtype](ft_dst[vtype]) 173 | ) 174 | 175 | return feats 176 | 177 | 178 | class KnowledgeGraphEmbeddingLayer(nn.Module): 179 | def __init__( 180 | self, 181 | knowledge_graph, 182 | node_feats, 183 | num_layers, 184 | residual=True, 185 | batch_norm=True, 186 | feat_drop=0.0, 187 | ): 188 | super().__init__() 189 | self.layers = nn.ModuleList([ 190 | HeteroAttentionAggregationLayer( 191 | knowledge_graph, 192 | node_feats, 193 | batch_norm=batch_norm, 194 | feat_drop=feat_drop, 195 | ) for _ in range(num_layers) 196 | ]) 197 | self.residual = residual 198 | 199 | def forward(self, graphs, feats): 200 | for layer, g in zip(self.layers, graphs): 201 | out_feats = layer(g, feats) 202 | if self.residual: 203 | feats = { 204 | ntype: out_feats[ntype] + feat[:len(out_feats[ntype])] 205 | for ntype, feat in feats.items() 206 | } 207 | else: 208 | feats = out_feats 209 | return feats 210 | 211 | 212 | class SEFrame(nn.Module): 213 | def __init__( 214 | self, 215 | num_users, 216 | num_items, 217 | embedding_dim, 218 | knowledge_graph, 219 | num_layers, 220 | batch_norm=True, 221 | feat_drop=0.0, 222 | **kwargs, 223 | ): 224 | super().__init__() 225 | self.user_embedding = nn.Embedding(num_users, embedding_dim, max_norm=1) 226 | self.user_indices = nn.Parameter( 227 | th.arange(num_users, dtype=th.long), requires_grad=False 228 | ) 229 | self.item_embedding = nn.Embedding(num_items, embedding_dim, max_norm=1) 230 | self.item_indices = nn.Parameter( 231 | th.arange(num_items, dtype=th.long), requires_grad=False 232 | ) 233 | self.knowledge_graph = knowledge_graph 234 | self.KGE_layer = KnowledgeGraphEmbeddingLayer( 235 | knowledge_graph, 236 | embedding_dim, 237 | num_layers, 238 | batch_norm=batch_norm, 239 | feat_drop=feat_drop, 240 | ) 241 | 242 | def precompute_KG_embeddings(self): 243 | self.eval() 244 | kg_device = self.knowledge_graph.device 245 | ft_device = self.user_indices.device 246 | if kg_device != ft_device: 247 | logging.debug(f'Copying knowledge graph from {kg_device} to {ft_device}') 248 | self.knowledge_graph = self.knowledge_graph.to(ft_device) 249 | with th.no_grad(): 250 | graphs = [self.knowledge_graph] * len(self.KGE_layer.layers) 251 | feats = { 252 | 'user': self.user_embedding(self.user_indices), 253 | 'item': self.item_embedding(self.item_indices), 254 | } 255 | self.KG_embeddings = self.KGE_layer(graphs, feats) 256 | 257 | def forward(self, inputs): 258 | if inputs is None: 259 | return self.KG_embeddings 260 | else: 261 | graphs, used_nodes = inputs 262 | feats = { 263 | 'user': self.user_embedding(used_nodes['user']), 264 | 'item': self.item_embedding(used_nodes['item']), 265 | } 266 | return self.KGE_layer(graphs, feats) 267 | -------------------------------------------------------------------------------- /srs/layers/serec.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | import dgl 5 | import dgl.ops as F 6 | 7 | 8 | class UpdateCell(nn.Module): 9 | def __init__(self, input_dim, output_dim): 10 | super().__init__() 11 | self.x2i = nn.Linear(input_dim, 2 * output_dim, bias=True) 12 | self.h2h = nn.Linear(output_dim, 2 * output_dim, bias=False) 13 | 14 | def forward(self, x, hidden): 15 | i_i, i_n = self.x2i(x).chunk(2, 1) 16 | h_i, h_n = self.h2h(hidden).chunk(2, 1) 17 | input_gate = th.sigmoid(i_i + h_i) 18 | new_gate = th.tanh(i_n + h_n) 
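        # this coupled update is equivalent to input_gate * hidden + (1 - input_gate) * new_gate,
        # i.e. a GRU-style update that omits the reset gate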
19 | return new_gate + input_gate * (hidden - new_gate) 20 | 21 | 22 | class PWGGNN(nn.Module): 23 | def __init__( 24 | self, 25 | input_dim, 26 | hidden_dim, 27 | output_dim, 28 | num_steps=1, 29 | batch_norm=True, 30 | feat_drop=0.0, 31 | activation=None, 32 | ): 33 | super().__init__() 34 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 35 | self.feat_drop = nn.Dropout(feat_drop) if feat_drop > 0 else None 36 | self.fc_i2h = nn.Linear( 37 | input_dim, hidden_dim, bias=False 38 | ) if input_dim != hidden_dim else None 39 | self.fc_in = nn.Linear(hidden_dim, hidden_dim, bias=True) 40 | self.fc_out = nn.Linear(hidden_dim, hidden_dim, bias=True) 41 | # self.upd_cell = nn.GRUCell(2 * hidden_dim, hidden_dim) 42 | self.upd_cell = UpdateCell(2 * hidden_dim, hidden_dim) 43 | self.fc_h2o = nn.Linear( 44 | hidden_dim, output_dim, bias=False 45 | ) if hidden_dim != output_dim else None 46 | self.hidden_dim = hidden_dim 47 | self.num_steps = num_steps 48 | self.activation = activation 49 | 50 | def propagate(self, g, rg, feat): 51 | if g.number_of_edges() > 0: 52 | feat_in = self.fc_in(feat) 53 | feat_out = self.fc_out(feat) 54 | a_in = F.u_mul_e_sum(g, feat_in, g.edata['iw']) 55 | a_out = F.u_mul_e_sum(rg, feat_out, rg.edata['ow']) 56 | # a: (num_nodes, 2 * hidden_dim) 57 | a = th.cat((a_in, a_out), dim=1) 58 | else: 59 | num_nodes = g.number_of_nodes() 60 | a = feat.new_zeros((num_nodes, 2 * self.hidden_dim)) 61 | hn = self.upd_cell(a, feat) 62 | return hn 63 | 64 | def forward(self, g, rg, feat): 65 | if self.batch_norm is not None: 66 | feat = self.batch_norm(feat) 67 | if self.feat_drop is not None: 68 | feat = self.feat_drop(feat) 69 | if self.fc_i2h is not None: 70 | feat = self.fc_i2h(feat) 71 | for _ in range(self.num_steps): 72 | feat = self.propagate(g, rg, feat) 73 | if self.fc_h2o is not None: 74 | feat = self.fc_h2o(feat) 75 | if self.activation is not None: 76 | feat = self.activation(feat) 77 | return feat 78 | 79 | 80 | class PAttentionReadout(nn.Module): 81 | def __init__(self, embedding_dim, batch_norm=False, feat_drop=0.0, activation=None): 82 | super().__init__() 83 | if batch_norm: 84 | self.batch_norm = nn.ModuleDict({ 85 | 'user': nn.BatchNorm1d(embedding_dim), 86 | 'item': nn.BatchNorm1d(embedding_dim) 87 | }) 88 | else: 89 | self.batch_norm = None 90 | self.feat_drop = nn.Dropout(feat_drop) if feat_drop > 0 else None 91 | self.fc_user = nn.Linear(embedding_dim, embedding_dim, bias=True) 92 | self.fc_key = nn.Linear(embedding_dim, embedding_dim, bias=False) 93 | self.fc_last = nn.Linear(embedding_dim, embedding_dim, bias=False) 94 | self.fc_e = nn.Linear(embedding_dim, 1, bias=False) 95 | self.activation = activation 96 | 97 | def forward(self, g, feat_i, feat_u, last_nodes): 98 | if self.batch_norm is not None: 99 | feat_i = self.batch_norm['item'](feat_i) 100 | feat_u = self.batch_norm['user'](feat_u) 101 | if self.feat_drop is not None: 102 | feat_i = self.feat_drop(feat_i) 103 | feat_u = self.feat_drop(feat_u) 104 | feat_val = feat_i 105 | feat_key = self.fc_key(feat_i) 106 | feat_u = self.fc_user(feat_u) 107 | feat_last = self.fc_last(feat_i[last_nodes]) 108 | feat_qry = dgl.broadcast_nodes(g, feat_u + feat_last) 109 | e = self.fc_e(th.sigmoid(feat_qry + feat_key)) # (num_nodes, 1) 110 | e = e + g.ndata['cnt'].log().view_as(e) 111 | alpha = F.segment.segment_softmax(g.batch_num_nodes(), e) 112 | rst = F.segment.segment_reduce(g.batch_num_nodes(), alpha * feat_val, 'sum') 113 | if self.activation is not None: 114 | rst = self.activation(rst) 115 | 
return rst 116 | 117 | 118 | class SERecLayer(nn.Module): 119 | def __init__( 120 | self, 121 | embedding_dim, 122 | num_steps=1, 123 | batch_norm=True, 124 | feat_drop=0.0, 125 | relu=False, 126 | ): 127 | super().__init__() 128 | self.fc_i = nn.Linear(embedding_dim, embedding_dim, bias=False) 129 | self.fc_u = nn.Linear(embedding_dim, embedding_dim, bias=False) 130 | self.pwggnn = PWGGNN( 131 | embedding_dim, 132 | embedding_dim, 133 | embedding_dim, 134 | num_steps=num_steps, 135 | batch_norm=batch_norm, 136 | feat_drop=feat_drop, 137 | activation=nn.ReLU() if relu else nn.PReLU(embedding_dim), 138 | ) 139 | self.readout = PAttentionReadout( 140 | embedding_dim, 141 | batch_norm=batch_norm, 142 | feat_drop=feat_drop, 143 | activation=nn.ReLU() if relu else nn.PReLU(embedding_dim), 144 | ) 145 | 146 | def forward(self, g, feat, feat_u): 147 | rg = dgl.reverse(g, False, False) 148 | if g.number_of_edges() > 0: 149 | edge_weight = g.edata['w'] 150 | in_deg = F.copy_e_sum(g, edge_weight) 151 | g.edata['iw'] = F.e_div_v(g, edge_weight, in_deg) 152 | out_deg = F.copy_e_sum(rg, edge_weight) 153 | rg.edata['ow'] = F.e_div_v(rg, edge_weight, out_deg) 154 | 155 | feat = self.pwggnn(g, rg, feat) 156 | last_nodes = g.filter_nodes(lambda nodes: nodes.data['last'] == 1) 157 | ct_l = feat[last_nodes] 158 | ct_g = self.readout(g, feat, feat_u, last_nodes) 159 | sr = th.cat((ct_l, ct_g), dim=1) 160 | return sr 161 | -------------------------------------------------------------------------------- /srs/layers/srgnn.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | import dgl 4 | import dgl.ops as F 5 | 6 | 7 | class GGNN(nn.Module): 8 | def __init__( 9 | self, 10 | input_dim, 11 | hidden_dim=None, 12 | output_dim=None, 13 | num_steps=1, 14 | batch_norm=False, 15 | feat_drop=0.0, 16 | activation=None, 17 | ): 18 | super().__init__() 19 | if hidden_dim is None: 20 | hidden_dim = input_dim 21 | if output_dim is None: 22 | output_dim = input_dim 23 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 24 | self.feat_drop = nn.Dropout(feat_drop) if feat_drop > 0 else None 25 | self.fc_in = nn.Linear(hidden_dim, hidden_dim, bias=True) 26 | self.fc_out = nn.Linear(hidden_dim, hidden_dim, bias=True) 27 | self.gru_cell = nn.GRUCell(2 * hidden_dim, hidden_dim) 28 | self.hidden_dim = hidden_dim 29 | self.num_steps = num_steps 30 | self.activation = activation 31 | 32 | def propagate(self, g, rg, feat): 33 | if g.number_of_edges() > 0: 34 | feat_in = self.fc_in(feat) 35 | feat_out = self.fc_out(feat) 36 | a_in = F.copy_u_mean(g, feat_in) 37 | a_out = F.copy_u_mean(rg, feat_out) 38 | # a: (num_nodes, 2 * hidden_dim) 39 | a = th.cat((a_in, a_out), dim=1) 40 | else: 41 | num_nodes = g.number_of_nodes() 42 | a = feat.new_zeros((num_nodes, 2 * self.hidden_dim)) 43 | hn = self.gru_cell(a, feat) 44 | return hn 45 | 46 | def forward(self, g, rg, feat): 47 | if self.batch_norm is not None: 48 | feat = self.batch_norm(feat) 49 | if self.feat_drop is not None: 50 | feat = self.feat_drop(feat) 51 | for _ in range(self.num_steps): 52 | feat = self.propagate(g, rg, feat) 53 | if self.activation is not None: 54 | feat = self.activation(feat) 55 | return feat 56 | 57 | 58 | class AttentionReadout(nn.Module): 59 | def __init__( 60 | self, 61 | input_dim, 62 | hidden_dim=None, 63 | output_dim=None, 64 | batch_norm=False, 65 | feat_drop=0.0, 66 | activation=None, 67 | ): 68 | super().__init__() 69 | if hidden_dim is None: 70 | 
hidden_dim = input_dim 71 | if output_dim is None: 72 | output_dim = input_dim 73 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 74 | self.feat_drop = nn.Dropout(feat_drop) if feat_drop > 0 else None 75 | self.fc_u = nn.Linear(input_dim, hidden_dim, bias=False) 76 | self.fc_v = nn.Linear(input_dim, hidden_dim, bias=True) 77 | self.fc_e = nn.Linear(hidden_dim, 1, bias=False) 78 | self.fc_out = ( 79 | nn.Linear(input_dim, output_dim, bias=False) 80 | if input_dim != output_dim else None 81 | ) 82 | self.activation = activation 83 | 84 | def forward(self, g, feat, last_nodes): 85 | if self.batch_norm is not None: 86 | feat = self.batch_norm(feat) 87 | if self.feat_drop is not None: 88 | feat = self.feat_drop(feat) 89 | feat_u = self.fc_u(feat) 90 | feat_v = self.fc_v(feat[last_nodes]) 91 | feat_v = dgl.broadcast_nodes(g, feat_v) 92 | e = self.fc_e(th.sigmoid(feat_u + feat_v)) # (num_nodes, 1) 93 | alpha = e * g.ndata['cnt'].view_as(e) 94 | rst = F.segment.segment_reduce(g.batch_num_nodes(), feat * alpha, 'sum') 95 | if self.fc_out is not None: 96 | rst = self.fc_out(rst) 97 | if self.activation is not None: 98 | rst = self.activation(rst) 99 | return rst 100 | 101 | 102 | class SRGNNLayer(nn.Module): 103 | def __init__(self, embedding_dim, feat_drop=0.0): 104 | super().__init__() 105 | self.ggnn = GGNN(embedding_dim, num_steps=1, feat_drop=feat_drop, activation=None) 106 | self.readout = AttentionReadout(embedding_dim, embedding_dim, feat_drop=feat_drop) 107 | 108 | def forward(self, g, feat): 109 | rg = dgl.reverse(g, False, False) 110 | feat = self.ggnn(g, rg, feat) 111 | last_nodes = g.filter_nodes(lambda nodes: nodes.data['last'] == 1) 112 | ct_l = feat[last_nodes] 113 | ct_g = self.readout(g, feat, last_nodes) 114 | sr = th.cat([ct_g, ct_l], dim=1) 115 | return sr 116 | -------------------------------------------------------------------------------- /srs/layers/ssrm.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | 5 | 6 | class SSRMLayer(nn.Module): 7 | def __init__(self, embedding_dim, w=0.5, feat_drop=0.0): 8 | super().__init__() 9 | self.feat_drop = nn.Dropout(feat_drop) if feat_drop > 0 else None 10 | self.gru = nn.GRU(embedding_dim, embedding_dim, batch_first=True) 11 | self.B = nn.Linear(2 * embedding_dim, embedding_dim, bias=False) 12 | self.w = w 13 | 14 | def forward(self, emb_seqs, lens, feat_u): 15 | """ 16 | emb_seqs: (batch_size, max_len, d) 17 | """ 18 | if self.feat_drop is not None: 19 | emb_seqs = self.feat_drop(emb_seqs) 20 | feat_u = self.feat_drop(feat_u) 21 | 22 | batch_size, max_len, _ = emb_seqs.size() 23 | mask = th.arange( 24 | max_len, device=lens.device 25 | ).unsqueeze(0).expand(batch_size, max_len) >= lens.unsqueeze(-1) 26 | 27 | packed_seqs = pack_padded_sequence(emb_seqs, lens.cpu(), batch_first=True) 28 | 29 | out, hn = self.gru(packed_seqs) 30 | out, _ = pad_packed_sequence( 31 | out, batch_first=True 32 | ) # out: (batch_size, max_len, d) 33 | h_t = hn.squeeze(0) # (batch_size, d) 34 | 35 | alpha = (emb_seqs * feat_u.unsqueeze(1)).sum(dim=-1) # (batch_size, max_len) 36 | alpha = th.masked_fill(alpha, mask, float('-inf')) 37 | alpha = alpha.softmax(dim=1).unsqueeze(-1) # (batch_size, max_len, 1) 38 | h_sum = (alpha * out).sum(dim=1) # (batch_size, d) 39 | ct = th.cat([h_sum, h_t], dim=1) 40 | 41 | sr = self.w * feat_u + (1 - self.w) * self.B(ct) 42 | return sr 43 | 
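The layers in this package are normally driven by the collate functions in `srs/utils/data/collate.py`, but they can also be exercised in isolation. Below is a minimal, hypothetical smoke test for `SSRMLayer` with dummy tensors (illustrative only, not part of the repository); since `pack_padded_sequence` is used with its default `enforce_sorted=True`, the lengths must be given in descending order.
```python
import torch as th

from srs.layers.ssrm import SSRMLayer

batch_size, max_len, d = 4, 6, 128
layer = SSRMLayer(embedding_dim=d, w=0.5)

emb_seqs = th.randn(batch_size, max_len, d)  # padded item embedding sequences
lens = th.tensor([6, 5, 3, 2])               # true lengths, sorted in descending order
feat_u = th.randn(batch_size, d)             # user embeddings

sr = layer(emb_seqs, lens, feat_u)           # session representations
print(sr.shape)                              # torch.Size([4, 128])
```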
-------------------------------------------------------------------------------- /srs/layers/stamp.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | 5 | class STAMPLayer(nn.Module): 6 | def __init__(self, embedding_dim, feat_drop=0.0): 7 | super().__init__() 8 | self.feat_drop = nn.Dropout(feat_drop) if feat_drop > 0 else None 9 | self.fc_a = nn.Linear(embedding_dim, embedding_dim, bias=True) 10 | self.fc_t = nn.Linear(embedding_dim, embedding_dim, bias=True) 11 | self.attn_i = nn.Linear(embedding_dim, embedding_dim, bias=False) 12 | self.attn_t = nn.Linear(embedding_dim, embedding_dim, bias=True) 13 | self.attn_s = nn.Linear(embedding_dim, embedding_dim, bias=False) 14 | self.attn_e = nn.Linear(embedding_dim, 1, bias=False) 15 | 16 | def forward(self, emb_seqs, lens): 17 | # emb_seqs: (batch_size, max_len, d) 18 | if self.feat_drop is not None: 19 | emb_seqs = self.feat_drop(emb_seqs) 20 | batch_size, max_len, _ = emb_seqs.size() 21 | mask = th.arange( 22 | max_len, device=lens.device 23 | ).unsqueeze(0).expand(batch_size, max_len) >= lens.unsqueeze(-1) 24 | emb_seqs = th.masked_fill(emb_seqs, mask.unsqueeze(-1), 0) 25 | # emb_seqs = th.where(mask.unsqueeze(-1), emb_seqs, th.zeros_like(emb_seqs)) 26 | 27 | ms = emb_seqs.sum(dim=1) / lens.unsqueeze(-1) # (batch_size, d) 28 | 29 | xt = emb_seqs[th.arange(batch_size), lens - 1] # (batch_size, d) 30 | ei = self.attn_i(emb_seqs) # (batch_size, max_len, d) 31 | et = self.attn_t(xt).unsqueeze(1) # (batch_size, 1, d) 32 | es = self.attn_s(ms).unsqueeze(1) # (batch_size, 1, d) 33 | e = self.attn_e(th.sigmoid(ei + et + es)).squeeze(-1) # (batch_size, max_len) 34 | alpha = th.masked_fill(e, mask, 0) 35 | alpha = alpha.unsqueeze(-1) # (batch_size, max_len, 1) 36 | ma = th.sum(alpha * emb_seqs, dim=1) # (batch_size, d) 37 | 38 | ha = self.fc_a(ma) 39 | ht = self.fc_t(xt) 40 | 41 | sr = ha * ht 42 | return sr 43 | -------------------------------------------------------------------------------- /srs/models/DGRec.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | from torch.nn.utils.rnn import pack_padded_sequence 4 | 5 | import dgl.ops as F 6 | 7 | from srs.utils.Dict import Dict 8 | from srs.utils.data.collate import CollateFnDGRec 9 | from srs.utils.prepare_batch import prepare_batch_factory_recursive 10 | 11 | 12 | class GAT(nn.Module): 13 | def __init__( 14 | self, 15 | qry_feats, 16 | key_feats, 17 | val_feats, 18 | feat_drop=0.0, 19 | batch_norm=False, 20 | ): 21 | super().__init__() 22 | if batch_norm: 23 | self.batch_norm_q = nn.BatchNorm1d(qry_feats) 24 | self.batch_norm_k = nn.BatchNorm1d(key_feats) 25 | else: 26 | self.batch_norm_q = None 27 | self.batch_norm_k = None 28 | self.feat_drop = nn.Dropout(feat_drop) 29 | 30 | self.fc = nn.Linear(qry_feats, val_feats, bias=True) 31 | 32 | self.qry_feats = qry_feats 33 | 34 | def forward(self, g, feat_src, feat_dst): 35 | if self.batch_norm_q is not None: 36 | feat_src = self.batch_norm_q(feat_src) 37 | feat_dst = self.batch_norm_k(feat_dst) 38 | if self.feat_drop is not None: 39 | feat_src = self.feat_drop(feat_src) 40 | feat_dst = self.feat_drop(feat_dst) 41 | score = F.u_dot_v(g, feat_src, feat_dst) # (num_edges, 1) 42 | weight = F.edge_softmax(g, score) 43 | rst = F.u_mul_e_sum(g, feat_src, weight) 44 | rst = th.relu(self.fc(rst)) 45 | return rst 46 | 47 | 48 | class DGRec(nn.Module): 49 | def __init__( 50 | self, 51 | 
num_users, 52 | num_items, 53 | embedding_dim, 54 | num_layers, 55 | batch_norm=False, 56 | feat_drop=0.0, 57 | residual=True, 58 | **kwargs, 59 | ): 60 | super().__init__() 61 | self.user_embedding = nn.Embedding(num_users, embedding_dim, max_norm=1) 62 | self.item_embeeding = nn.Embedding( 63 | num_items + 1, embedding_dim, max_norm=1, padding_idx=0 64 | ) 65 | self.item_indices = nn.Parameter( 66 | th.arange(1, num_items + 1, dtype=th.long), requires_grad=False 67 | ) 68 | self.feat_drop = nn.Dropout(feat_drop) if feat_drop > 0 else None 69 | self.lstm = nn.LSTM(embedding_dim, embedding_dim) 70 | self.W1 = nn.Linear(2 * embedding_dim, embedding_dim, bias=False) 71 | self.layers = nn.ModuleList() 72 | input_dim = embedding_dim 73 | for _ in range(num_layers): 74 | layer = GAT( 75 | input_dim, 76 | input_dim, 77 | embedding_dim, 78 | batch_norm=batch_norm, 79 | feat_drop=feat_drop, 80 | ) 81 | if not residual: 82 | input_dim += embedding_dim 83 | self.layers.append(layer) 84 | self.residual = residual 85 | self.W2 = nn.Linear(input_dim + embedding_dim, embedding_dim, bias=False) 86 | 87 | def forward(self, graphs, idx_maps, uids, padded_seqs, lens, cur_sidx): 88 | emb_seqs = self.item_embeeding(padded_seqs) 89 | if self.feat_drop is not None: 90 | emb_seqs = self.feat_drop(emb_seqs) 91 | packed_seqs = pack_padded_sequence( 92 | emb_seqs, lens.cpu(), batch_first=True, enforce_sorted=False 93 | ) 94 | _, (hn, _) = self.lstm(packed_seqs) 95 | 96 | long_term = self.user_embedding(uids) 97 | short_term = hn.squeeze(0) 98 | cur_u_short_term = short_term[cur_sidx] 99 | feat = th.cat((long_term, short_term), dim=1) 100 | feat = th.relu(self.W1(feat)) 101 | # the node features of the current user are only the short-term interests 102 | # the node features of neighbors are the combination of short-term and long-term interests 103 | feat[cur_sidx] = cur_u_short_term 104 | for g, idx_map, layer in zip(graphs, idx_maps, self.layers): 105 | feat_src = feat 106 | feat_dst = feat[idx_map] 107 | feat = layer(g, feat_src, feat_dst) 108 | if self.residual: 109 | feat = feat_dst + feat 110 | else: 111 | feat = th.cat((feat_dst, feat), dim=1) 112 | sr = self.W2(th.cat((cur_u_short_term, feat), dim=1)) 113 | logits = sr @ self.item_embeeding(self.item_indices).t() 114 | 115 | return logits 116 | 117 | 118 | config = Dict({ 119 | 'Model': DGRec, 120 | 'CollateFn': CollateFnDGRec, 121 | 'prepare_batch_factory': prepare_batch_factory_recursive, 122 | }) 123 | -------------------------------------------------------------------------------- /srs/models/NARM.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | from srs.layers.narm import NARMLayer 5 | from srs.utils.data.collate import collate_fn_for_rnn_cnn 6 | from srs.utils.data.load import BatchSampler 7 | from srs.utils.Dict import Dict 8 | from srs.utils.prepare_batch import prepare_batch_factory 9 | 10 | 11 | class NARM(nn.Module): 12 | def __init__(self, num_items, embedding_dim, feat_drop=0.0, **kwargs): 13 | super().__init__() 14 | self.embedding = nn.Embedding(num_items, embedding_dim, max_norm=1) 15 | self.indices = nn.Parameter( 16 | th.arange(num_items, dtype=th.long), requires_grad=False 17 | ) 18 | self.narm_layer = NARMLayer(embedding_dim, feat_drop=feat_drop) 19 | self.fc_sr = nn.Linear(2 * embedding_dim, embedding_dim, bias=False) 20 | 21 | def forward(self, uids, padded_seqs, lens): 22 | emb_seqs = self.embedding(padded_seqs) 23 | sr = 
self.narm_layer(emb_seqs, lens) 24 | sr = self.fc_sr(sr) 25 | logits = sr @ self.embedding(self.indices).t() 26 | return logits 27 | 28 | 29 | config = Dict( 30 | { 31 | 'Model': NARM, 32 | 'collate_fn': collate_fn_for_rnn_cnn, 33 | 'BatchSampler': BatchSampler, 34 | 'prepare_batch_factory': prepare_batch_factory, 35 | } 36 | ) 37 | -------------------------------------------------------------------------------- /srs/models/NextItNet.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | from srs.layers.nextitnet import NextItNetLayer 5 | from srs.utils.data.collate import collate_fn_for_rnn_cnn 6 | from srs.utils.data.load import BatchSampler 7 | from srs.utils.Dict import Dict 8 | from srs.utils.prepare_batch import prepare_batch_factory 9 | 10 | 11 | class NextItNet(nn.Module): 12 | def __init__( 13 | self, 14 | num_items, 15 | embedding_dim, 16 | dilations=None, 17 | one_masked=False, 18 | kernel_size=3, 19 | feat_drop=0.0, 20 | **kwargs 21 | ): 22 | super().__init__() 23 | self.embedding = nn.Embedding(num_items, embedding_dim, max_norm=1) 24 | self.indices = nn.Parameter( 25 | th.arange(num_items, dtype=th.long), requires_grad=False 26 | ) 27 | self.layer = NextItNetLayer( 28 | embedding_dim, dilations, one_masked, kernel_size, feat_drop=feat_drop 29 | ) 30 | self.fc_sr = nn.Linear(embedding_dim, embedding_dim, bias=False) 31 | 32 | def forward(self, uids, padded_seqs, lens): 33 | # padded_seqs: (B, L) 34 | emb_seqs = self.embedding(padded_seqs) # (B, L, C) 35 | sr = self.layer(emb_seqs, lens) 36 | logits = self.fc_sr(sr) @ self.embedding(self.indices).t() 37 | return logits 38 | 39 | 40 | config = Dict({ 41 | 'Model': NextItNet, 42 | 'collate_fn': collate_fn_for_rnn_cnn, 43 | 'BatchSampler': BatchSampler, 44 | 'prepare_batch_factory': prepare_batch_factory, 45 | }) 46 | -------------------------------------------------------------------------------- /srs/models/SERec.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | import dgl 4 | 5 | from srs.layers.seframe import SEFrame 6 | from srs.layers.serec import SERecLayer 7 | from srs.utils.data.collate import CollateFnGNN 8 | from srs.utils.Dict import Dict 9 | from srs.utils.prepare_batch import prepare_batch_factory_recursive 10 | from srs.utils.data.transform import seq_to_weighted_graph 11 | 12 | 13 | class SERec(SEFrame): 14 | def __init__( 15 | self, 16 | num_users, 17 | num_items, 18 | embedding_dim, 19 | knowledge_graph, 20 | num_layers, 21 | relu=False, 22 | batch_norm=True, 23 | feat_drop=0.0, 24 | **kwargs 25 | ): 26 | super().__init__( 27 | num_users, 28 | num_items, 29 | embedding_dim, 30 | knowledge_graph, 31 | num_layers, 32 | batch_norm=batch_norm, 33 | feat_drop=feat_drop, 34 | ) 35 | self.fc_i = nn.Linear(embedding_dim, embedding_dim, bias=False) 36 | self.fc_u = nn.Linear(embedding_dim, embedding_dim, bias=False) 37 | self.PSE_layer = SERecLayer( 38 | embedding_dim, 39 | num_steps=1, 40 | batch_norm=batch_norm, 41 | feat_drop=feat_drop, 42 | relu=relu, 43 | ) 44 | input_dim = 3 * embedding_dim 45 | self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None 46 | self.fc_sr = nn.Linear(input_dim, embedding_dim, bias=False) 47 | 48 | def forward(self, inputs, extra_inputs=None): 49 | KG_embeddings = super().forward(extra_inputs) 50 | 51 | uid, g = inputs 52 | iid = g.ndata['iid'] # (num_nodes,) 53 | feat_i = KG_embeddings['item'][iid] 54 | 
feat_u = KG_embeddings['user'][uid] 55 | feat = self.fc_i(feat_i) + dgl.broadcast_nodes(g, self.fc_u(feat_u)) 56 | feat_i = self.PSE_layer(g, feat, feat_u) 57 | sr = th.cat([feat_i, feat_u], dim=1) 58 | if self.batch_norm is not None: 59 | sr = self.batch_norm(sr) 60 | logits = self.fc_sr(sr) @ self.item_embedding(self.item_indices).t() 61 | return logits 62 | 63 | 64 | seq_to_graph_fns = [seq_to_weighted_graph] 65 | 66 | config = Dict({ 67 | 'Model': SERec, 68 | 'seq_to_graph_fns': seq_to_graph_fns, 69 | 'CollateFn': CollateFnGNN, 70 | 'prepare_batch_factory': prepare_batch_factory_recursive, 71 | }) 72 | -------------------------------------------------------------------------------- /srs/models/SNARM.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | from srs.layers.seframe import SEFrame 5 | from srs.layers.narm import NARMLayer 6 | from srs.utils.data.collate import CollateFnRNNCNN 7 | from srs.utils.data.load import BatchSampler 8 | from srs.utils.Dict import Dict 9 | from srs.utils.prepare_batch import prepare_batch_factory_recursive 10 | 11 | 12 | class SNARM(SEFrame): 13 | def __init__( 14 | self, 15 | num_users, 16 | num_items, 17 | embedding_dim, 18 | knowledge_graph, 19 | num_layers, 20 | batch_norm=True, 21 | feat_drop=0.0, 22 | **kwargs 23 | ): 24 | super().__init__( 25 | num_users, 26 | num_items, 27 | embedding_dim, 28 | knowledge_graph, 29 | num_layers, 30 | batch_norm=batch_norm, 31 | feat_drop=feat_drop, 32 | **kwargs, 33 | ) 34 | self.fc_i = nn.Linear(embedding_dim, embedding_dim, bias=False) 35 | self.fc_u = nn.Linear(embedding_dim, embedding_dim, bias=False) 36 | self.PSE_layer = NARMLayer(embedding_dim, feat_drop=feat_drop) 37 | self.fc_sr = nn.Linear(3 * embedding_dim, embedding_dim, bias=False) 38 | 39 | def forward(self, inputs, extra_inputs=None): 40 | KG_embeddings = super().forward(extra_inputs) 41 | 42 | uids, padded_seqs, lens = inputs 43 | emb_seqs = KG_embeddings['item'][padded_seqs] 44 | feat_u = KG_embeddings['user'][uids] 45 | feat = self.fc_i(emb_seqs) + self.fc_u(feat_u).unsqueeze(1) 46 | feat_i = self.PSE_layer(feat, lens) 47 | sr = th.cat([feat_i, feat_u], dim=1) 48 | logits = self.fc_sr(sr) @ self.item_embedding(self.item_indices).t() 49 | return logits 50 | 51 | 52 | config = Dict({ 53 | 'Model': SNARM, 54 | 'CollateFn': CollateFnRNNCNN, 55 | 'BatchSampler': BatchSampler, 56 | 'prepare_batch_factory': prepare_batch_factory_recursive, 57 | }) 58 | -------------------------------------------------------------------------------- /srs/models/SNextItNet.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | from srs.layers.seframe import SEFrame 5 | from srs.layers.nextitnet import NextItNetLayer 6 | from srs.utils.data.collate import CollateFnRNNCNN 7 | from srs.utils.data.load import BatchSampler 8 | from srs.utils.Dict import Dict 9 | from srs.utils.prepare_batch import prepare_batch_factory_recursive 10 | 11 | 12 | class SNextItNet(SEFrame): 13 | def __init__( 14 | self, 15 | num_users, 16 | num_items, 17 | embedding_dim, 18 | knowledge_graph, 19 | num_layers, 20 | batch_norm=True, 21 | feat_drop=0.0, 22 | **kwargs 23 | ): 24 | super().__init__( 25 | num_users, 26 | num_items, 27 | embedding_dim, 28 | knowledge_graph, 29 | num_layers, 30 | batch_norm=batch_norm, 31 | feat_drop=feat_drop, 32 | ) 33 | self.fc_i = nn.Linear(embedding_dim, embedding_dim, bias=False) 34 | self.fc_u 
= nn.Linear(embedding_dim, embedding_dim, bias=False) 35 | self.PSE_layer = NextItNetLayer( 36 | embedding_dim, 37 | dilations=None, 38 | one_masked=False, 39 | kernel_size=3, 40 | feat_drop=feat_drop, 41 | ) 42 | self.fc_sr = nn.Linear(2 * embedding_dim, embedding_dim, bias=False) 43 | 44 | def forward(self, inputs, extra_inputs=None): 45 | KG_embeddings = super().forward(extra_inputs) 46 | 47 | uids, padded_seqs, lens = inputs 48 | emb_seqs = KG_embeddings['item'][padded_seqs] 49 | feat_u = KG_embeddings['user'][uids] 50 | feat = self.fc_i(emb_seqs) + self.fc_u(feat_u).unsqueeze(1) 51 | feat_i = self.PSE_layer(feat, lens) 52 | sr = th.cat([feat_i, feat_u], dim=1) 53 | logits = self.fc_sr(sr) @ self.item_embedding(self.item_indices).t() 54 | return logits 55 | 56 | 57 | config = Dict( 58 | { 59 | 'Model': SNextItNet, 60 | 'CollateFn': CollateFnRNNCNN, 61 | 'BatchSampler': BatchSampler, 62 | 'prepare_batch_factory': prepare_batch_factory_recursive, 63 | } 64 | ) 65 | -------------------------------------------------------------------------------- /srs/models/SRGNN.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | from srs.layers.srgnn import SRGNNLayer 5 | from srs.utils.data.collate import collate_fn_for_gnn_factory 6 | from srs.utils.Dict import Dict 7 | from srs.utils.prepare_batch import prepare_batch_factory 8 | from srs.utils.data.transform import seq_to_unweighted_graph 9 | 10 | 11 | class SRGNN(nn.Module): 12 | def __init__(self, num_items, embedding_dim, feat_drop=0.0, **kwargs): 13 | super().__init__() 14 | self.item_embedding = nn.Embedding(num_items, embedding_dim, max_norm=1) 15 | self.item_indices = nn.Parameter( 16 | th.arange(num_items, dtype=th.long), requires_grad=False 17 | ) 18 | self.layer = SRGNNLayer(embedding_dim, feat_drop=feat_drop) 19 | self.fc_sr = nn.Linear(2 * embedding_dim, embedding_dim, bias=False) 20 | 21 | def forward(self, uid, g): 22 | iid = g.ndata['iid'] # (num_nodes,) 23 | feat = self.item_embedding(iid) 24 | sr = self.layer(g, feat) 25 | logits = self.fc_sr(sr) @ self.item_embedding(self.item_indices).t() 26 | return logits 27 | 28 | 29 | seq_to_graph_fns = [seq_to_unweighted_graph] 30 | collate_fn = collate_fn_for_gnn_factory(*seq_to_graph_fns) 31 | 32 | config = Dict({ 33 | 'Model': SRGNN, 34 | 'seq_to_graph_fns': seq_to_graph_fns, 35 | 'collate_fn': collate_fn, 36 | 'prepare_batch_factory': prepare_batch_factory, 37 | }) 38 | -------------------------------------------------------------------------------- /srs/models/SSRGNN.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | import dgl 4 | 5 | from srs.layers.seframe import SEFrame 6 | from srs.layers.srgnn import SRGNNLayer 7 | from srs.utils.data.collate import CollateFnGNN 8 | from srs.utils.Dict import Dict 9 | from srs.utils.prepare_batch import prepare_batch_factory_recursive 10 | from srs.utils.data.transform import seq_to_unweighted_graph 11 | 12 | 13 | class SSRGNN(SEFrame): 14 | def __init__( 15 | self, 16 | num_users, 17 | num_items, 18 | embedding_dim, 19 | knowledge_graph, 20 | num_layers, 21 | batch_norm=True, 22 | feat_drop=0.0, 23 | **kwargs 24 | ): 25 | super().__init__( 26 | num_users, 27 | num_items, 28 | embedding_dim, 29 | knowledge_graph, 30 | num_layers, 31 | batch_norm=batch_norm, 32 | feat_drop=feat_drop, 33 | ) 34 | self.fc_i = nn.Linear(embedding_dim, embedding_dim, bias=False) 35 | self.fc_u = 
nn.Linear(embedding_dim, embedding_dim, bias=False) 36 | self.PSE_layer = SRGNNLayer(embedding_dim, feat_drop=feat_drop) 37 | self.fc_sr = nn.Linear(3 * embedding_dim, embedding_dim, bias=False) 38 | 39 | def forward(self, inputs, extra_inputs=None): 40 | KG_embeddings = super().forward(extra_inputs) 41 | 42 | uid, g = inputs 43 | iid = g.ndata['iid'] # (num_nodes,) 44 | feat_i = KG_embeddings['item'][iid] 45 | feat_u = KG_embeddings['user'][uid] 46 | feat = self.fc_i(feat_i) + dgl.broadcast_nodes(g, self.fc_u(feat_u)) 47 | feat_i = self.PSE_layer(g, feat) 48 | sr = th.cat([feat_i, feat_u], dim=1) 49 | logits = self.fc_sr(sr) @ self.item_embedding(self.item_indices).t() 50 | return logits 51 | 52 | 53 | seq_to_graph_fns = [seq_to_unweighted_graph] 54 | 55 | config = Dict({ 56 | 'Model': SSRGNN, 57 | 'seq_to_graph_fns': seq_to_graph_fns, 58 | 'CollateFn': CollateFnGNN, 59 | 'prepare_batch_factory': prepare_batch_factory_recursive, 60 | }) 61 | -------------------------------------------------------------------------------- /srs/models/SSRM.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | from srs.layers.ssrm import SSRMLayer 5 | from srs.utils.data.collate import collate_fn_for_rnn_cnn 6 | from srs.utils.data.load import BatchSampler 7 | from srs.utils.Dict import Dict 8 | from srs.utils.prepare_batch import prepare_batch_factory 9 | 10 | 11 | class SSRM(nn.Module): 12 | def __init__( 13 | self, num_users, num_items, embedding_dim, w=0.5, feat_drop=0.0, **kwargs 14 | ): 15 | super().__init__() 16 | self.user_embedding = nn.Embedding(num_users, embedding_dim, max_norm=1) 17 | self.item_embedding = nn.Embedding(num_items, embedding_dim, max_norm=1) 18 | self.indices = nn.Parameter( 19 | th.arange(num_items, dtype=th.long), requires_grad=False 20 | ) 21 | self.layer = SSRMLayer(embedding_dim, w, feat_drop=feat_drop) 22 | 23 | def forward(self, uids, padded_seqs, lens): 24 | feat_u = self.user_embedding(uids) 25 | emb_seqs = self.item_embedding(padded_seqs) 26 | sr = self.layer(emb_seqs, lens, feat_u) 27 | logits = sr @ self.item_embedding(self.indices).t() 28 | return logits 29 | 30 | 31 | config = Dict({ 32 | 'Model': SSRM, 33 | 'collate_fn': collate_fn_for_rnn_cnn, 34 | 'BatchSampler': BatchSampler, 35 | 'prepare_batch_factory': prepare_batch_factory, 36 | }) 37 | -------------------------------------------------------------------------------- /srs/models/SSSRM.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | from srs.layers.seframe import SEFrame 5 | from srs.layers.ssrm import SSRMLayer 6 | from srs.utils.data.collate import CollateFnRNNCNN 7 | from srs.utils.data.load import BatchSampler 8 | from srs.utils.Dict import Dict 9 | from srs.utils.prepare_batch import prepare_batch_factory_recursive 10 | 11 | 12 | class SSSRM(SEFrame): 13 | def __init__( 14 | self, 15 | num_users, 16 | num_items, 17 | embedding_dim, 18 | knowledge_graph, 19 | num_layers, 20 | w=0.5, 21 | batch_norm=True, 22 | feat_drop=0.0, 23 | **kwargs 24 | ): 25 | super().__init__( 26 | num_users, 27 | num_items, 28 | embedding_dim, 29 | knowledge_graph, 30 | num_layers, 31 | batch_norm=batch_norm, 32 | feat_drop=feat_drop, 33 | ) 34 | self.PSE_layer = SSRMLayer(embedding_dim, w, feat_drop=feat_drop) 35 | self.fc_sr = nn.Linear(2 * embedding_dim, embedding_dim, bias=False) 36 | 37 | def forward(self, inputs, extra_inputs=None): 38 | 
KG_embeddings = super().forward(extra_inputs) 39 | 40 | uids, padded_seqs, lens = inputs 41 | emb_seqs = KG_embeddings['item'][padded_seqs] 42 | feat_u = KG_embeddings['user'][uids] 43 | sr = self.PSE_layer(emb_seqs, lens, feat_u) 44 | sr = th.cat([sr, feat_u], dim=1) 45 | logits = self.fc_sr(sr) @ self.item_embedding(self.item_indices).t() 46 | 47 | return logits 48 | 49 | 50 | config = Dict({ 51 | 'Model': SSSRM, 52 | 'CollateFn': CollateFnRNNCNN, 53 | 'BatchSampler': BatchSampler, 54 | 'prepare_batch_factory': prepare_batch_factory_recursive, 55 | }) 56 | -------------------------------------------------------------------------------- /srs/models/SSTAMP.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | from srs.layers.seframe import SEFrame 5 | from srs.layers.stamp import STAMPLayer 6 | from srs.utils.data.collate import CollateFnRNNCNN 7 | from srs.utils.data.load import BatchSampler 8 | from srs.utils.Dict import Dict 9 | from srs.utils.prepare_batch import prepare_batch_factory_recursive 10 | 11 | 12 | class SSTAMP(SEFrame): 13 | def __init__( 14 | self, 15 | num_users, 16 | num_items, 17 | embedding_dim, 18 | knowledge_graph, 19 | num_layers, 20 | batch_norm=True, 21 | feat_drop=0.0, 22 | **kwargs 23 | ): 24 | super().__init__( 25 | num_users, 26 | num_items, 27 | embedding_dim, 28 | knowledge_graph, 29 | num_layers, 30 | batch_norm=batch_norm, 31 | feat_drop=feat_drop, 32 | ) 33 | self.fc_i = nn.Linear(embedding_dim, embedding_dim, bias=False) 34 | self.fc_u = nn.Linear(embedding_dim, embedding_dim, bias=False) 35 | self.PSE_layer = STAMPLayer(embedding_dim, feat_drop=feat_drop) 36 | self.fc_sr = nn.Linear(2 * embedding_dim, embedding_dim, bias=False) 37 | 38 | def forward(self, inputs, extra_inputs=None): 39 | KG_embeddings = super().forward(extra_inputs) 40 | 41 | uids, padded_seqs, lens = inputs 42 | emb_seqs = KG_embeddings['item'][padded_seqs] 43 | feat_u = KG_embeddings['user'][uids] 44 | feat = self.fc_i(emb_seqs) + self.fc_u(feat_u).unsqueeze(1) 45 | feat_i = self.PSE_layer(feat, lens) 46 | sr = th.cat([feat_i, feat_u], dim=1) 47 | logits = self.fc_sr(sr) @ self.item_embedding(self.item_indices).t() 48 | return logits 49 | 50 | 51 | config = Dict({ 52 | 'Model': SSTAMP, 53 | 'CollateFn': CollateFnRNNCNN, 54 | 'BatchSampler': BatchSampler, 55 | 'prepare_batch_factory': prepare_batch_factory_recursive, 56 | }) 57 | -------------------------------------------------------------------------------- /srs/models/STAMP.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import nn 3 | 4 | from srs.layers.stamp import STAMPLayer 5 | from srs.utils.data.collate import collate_fn_for_rnn_cnn 6 | from srs.utils.data.load import BatchSampler 7 | from srs.utils.Dict import Dict 8 | from srs.utils.prepare_batch import prepare_batch_factory 9 | 10 | 11 | class STAMP(nn.Module): 12 | def __init__(self, num_items, embedding_dim, feat_drop=0.0, **kwargs): 13 | super().__init__() 14 | self.embedding = nn.Embedding(num_items, embedding_dim, max_norm=1) 15 | self.indices = nn.Parameter( 16 | th.arange(num_items, dtype=th.long), requires_grad=False 17 | ) 18 | self.layer = STAMPLayer(embedding_dim, feat_drop=feat_drop) 19 | 20 | def forward(self, uids, padded_seqs, lens): 21 | emb_seqs = self.embedding(padded_seqs) # (batch_size, max_len, d) 22 | sr = self.layer(emb_seqs, lens) 23 | logits = sr @ self.embedding(self.indices).t() 24 | 
return logits 25 | 26 | 27 | config = Dict( 28 | { 29 | 'Model': STAMP, 30 | 'collate_fn': collate_fn_for_rnn_cnn, 31 | 'BatchSampler': BatchSampler, 32 | 'prepare_batch_factory': prepare_batch_factory, 33 | } 34 | ) 35 | -------------------------------------------------------------------------------- /srs/utils/Dict.py: -------------------------------------------------------------------------------- 1 | class Dict(dict): 2 | def __getattr__(self, key): 3 | return self.__getitem__(key) 4 | 5 | def __setattr__(self, key, val): 6 | return self.__setitem__(key, val) 7 | 8 | def __setitem__(self, key, val): 9 | if type(val) is dict: 10 | val = Dict(val) 11 | super().__setitem__(key, val) 12 | 13 | def __str__(self): 14 | import re 15 | 16 | _str = '' 17 | for key, val in sorted(self.items()): 18 | if type(key) is not str: 19 | continue 20 | val_str = str(val) 21 | if len(val_str) > 80: 22 | val_str = re.sub(r'\s+', ' ', val_str.strip())[:60] 23 | val_str = f'"{val_str}..."' 24 | _str += f"{key}: {val_str}\n" 25 | return _str 26 | -------------------------------------------------------------------------------- /srs/utils/argparse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | class ArgumentParser(argparse.ArgumentParser): 5 | def __init__(self, **kwargs): 6 | super().__init__( 7 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, **kwargs 8 | ) 9 | self.optional = self._action_groups.pop() 10 | self.required = self.add_argument_group('required arguments') 11 | self._action_groups.append(self.optional) 12 | 13 | def add_argument(self, *args, **kwargs): 14 | if kwargs.get('required', False): 15 | return self.required.add_argument(*args, **kwargs) 16 | else: 17 | return super().add_argument(*args, **kwargs) 18 | -------------------------------------------------------------------------------- /srs/utils/data/collate.py: -------------------------------------------------------------------------------- 1 | import random 2 | from bisect import bisect_left 3 | import numpy as np 4 | import pandas as pd 5 | import torch as th 6 | import dgl 7 | 8 | 9 | def collate_fn_for_rnn_cnn(samples): 10 | uids, seqs, labels = zip(*samples) 11 | uids = th.LongTensor(uids) 12 | labels = th.LongTensor(labels) 13 | 14 | seqs = list(map(lambda seq: th.LongTensor(seq), seqs)) 15 | padded_seqs = th.nn.utils.rnn.pad_sequence(seqs, batch_first=True) 16 | lens = th.LongTensor(list(map(len, seqs))) 17 | inputs = uids, padded_seqs, lens 18 | return inputs, labels 19 | 20 | 21 | def collate_fn_for_gnn_factory(*seq_to_graph_fns): 22 | def collate_fn(samples): 23 | uids, seqs, labels = zip(*samples) 24 | uids = th.LongTensor(uids) 25 | labels = th.LongTensor(labels) 26 | 27 | inputs = [uids] 28 | for seq_to_graph in seq_to_graph_fns: 29 | graphs = [seq_to_graph(seq) for seq in seqs] 30 | bg = dgl.batch(graphs) 31 | inputs.append(bg) 32 | return inputs, labels 33 | 34 | return collate_fn 35 | 36 | 37 | def sample_blocks(g, uniq_uids, uniq_iids, fanouts, steps): 38 | seeds = {'user': th.LongTensor(uniq_uids), 'item': th.LongTensor(uniq_iids)} 39 | blocks = [] 40 | for fanout in fanouts: 41 | if fanout <= 0: 42 | frontier = dgl.in_subgraph(g, seeds) 43 | else: 44 | frontier = dgl.sampling.sample_neighbors( 45 | g, seeds, fanout, copy_ndata=False, copy_edata=True 46 | ) 47 | block = dgl.to_block(frontier, seeds) 48 | seeds = {ntype: block.srcnodes[ntype].data[dgl.NID] for ntype in block.srctypes} 49 | blocks.insert(0, block) 50 | return blocks, seeds 51 | 52 | 
53 | class CollateFnRNNCNN: 54 | def __init__(self, knowledge_graph, num_layers, num_neighbors, **kwargs): 55 | self.knowledge_graph = knowledge_graph 56 | self.num_layers = num_layers 57 | # num_neighbors is a list of integers 58 | if len(num_neighbors) != num_layers: 59 | assert len(num_neighbors) == 1 60 | self.fanouts = num_neighbors * num_layers 61 | else: 62 | self.fanouts = num_neighbors 63 | 64 | def _collate_fn(self, samples, fanouts): 65 | uids, seqs, labels = zip(*samples) 66 | 67 | batch_size = len(seqs) 68 | lens = list(map(len, seqs)) 69 | max_len = max(lens) 70 | 71 | iids = np.concatenate(seqs) 72 | new_iids, uniq_iids = pd.factorize(iids, sort=True) 73 | padded_seqs = np.zeros((batch_size, max_len), dtype=np.long) 74 | cur_idx = 0 75 | for i, seq in enumerate(seqs): 76 | padded_seqs[i, :len(seq)] = new_iids[cur_idx:cur_idx + len(seq)] 77 | cur_idx += len(seq) 78 | 79 | new_uids, uniq_uids = pd.factorize(uids, sort=True) 80 | 81 | extra_inputs = sample_blocks( 82 | self.knowledge_graph, uniq_uids, uniq_iids, fanouts, self.num_layers 83 | ) 84 | 85 | new_uids = th.LongTensor(new_uids) 86 | padded_seqs = th.from_numpy(padded_seqs) 87 | lens = th.LongTensor(lens) 88 | labels = th.LongTensor(labels) 89 | inputs = new_uids, padded_seqs, lens 90 | return (inputs, extra_inputs), labels 91 | 92 | def collate_train(self, samples): 93 | return self._collate_fn(samples, self.fanouts) 94 | 95 | def collate_test(self, samples): 96 | inputs, labels = collate_fn_for_rnn_cnn(samples) 97 | return (inputs, ), labels 98 | 99 | def collate_test_otf(self, samples): 100 | return self._collate_fn(samples, [0] * self.num_layers) 101 | 102 | 103 | class CollateFnGNN: 104 | def __init__( 105 | self, knowledge_graph, num_layers, num_neighbors, seq_to_graph_fns, **kwargs 106 | ): 107 | self.knowledge_graph = knowledge_graph 108 | self.num_layers = num_layers 109 | self.seq_to_graph_fns = seq_to_graph_fns 110 | # num_neighbors is a list of integers 111 | if len(num_neighbors) != num_layers: 112 | assert len(num_neighbors) == 1 113 | self.fanouts = num_neighbors * num_layers 114 | else: 115 | self.fanouts = num_neighbors 116 | 117 | def _collate_fn(self, samples, fanouts): 118 | uids, seqs, labels = zip(*samples) 119 | 120 | new_uids, uniq_uids = pd.factorize(uids, sort=True) 121 | new_uids = th.LongTensor(new_uids) 122 | labels = th.LongTensor(labels) 123 | 124 | iids = np.concatenate(seqs) 125 | new_iids, uniq_iids = pd.factorize(iids, sort=True) 126 | cur_idx = 0 127 | new_seqs = [] 128 | for i, seq in enumerate(seqs): 129 | new_seq = new_iids[cur_idx:cur_idx + len(seq)] 130 | cur_idx += len(seq) 131 | new_seqs.append(new_seq) 132 | 133 | inputs = [new_uids] 134 | for seq_to_graph in self.seq_to_graph_fns: 135 | graphs = [seq_to_graph(seq) for seq in new_seqs] 136 | bg = dgl.batch(graphs) 137 | inputs.append(bg) 138 | 139 | extra_inputs = sample_blocks( 140 | self.knowledge_graph, uniq_uids, uniq_iids, fanouts, self.num_layers 141 | ) 142 | return (inputs, extra_inputs), labels 143 | 144 | def collate_train(self, samples): 145 | return self._collate_fn(samples, self.fanouts) 146 | 147 | def collate_test(self, samples): 148 | uids, seqs, labels = zip(*samples) 149 | 150 | uids = th.LongTensor(uids) 151 | labels = th.LongTensor(labels) 152 | 153 | inputs = [uids] 154 | for seq_to_graph in self.seq_to_graph_fns: 155 | graphs = [seq_to_graph(seq) for seq in seqs] 156 | bg = dgl.batch(graphs) 157 | inputs.append(bg) 158 | 159 | return (inputs, ), labels 160 | 161 | def collate_test_otf(self, samples): 162 
| return self._collate_fn(samples, [0] * self.num_layers) 163 | 164 | 165 | class CollateFnDGRec: 166 | def __init__( 167 | self, visible_time_list, in_neighbors, uid2sessions, num_layers, num_neighbors, 168 | **kwargs 169 | ): 170 | """ 171 | Args 172 | ---- 173 | visible_time_list: 174 | `visible_time_list[l][i]` is the time t when user i has an l-hop neighbor 175 | such that every user along the path from the neighbor to user i has 176 | generated a session at or before time t. 177 | in_neighbors: 178 | `in_neighbors[i]` contains the user ids of the incoming neighbors of user i. 179 | uid2sessions: 180 | `uid2sessions[i]` is a list of all training sessions generated by user i, 181 | sorted in ascending order by session id. Since a session with a smaller 182 | session id has an earlier end time, the sessions are also sorted in 183 | ascending order by end time. 184 | num_layers: int 185 | The number of graph attention layers 186 | num_neighbors: list[int] 187 | `len(num_neighbors)` should be either num_layers or 1 188 | `num_neighbors[l]` is the number of sampled neighbors at layer l. 189 | If `len(num_neighbors)` is 1, then the number of sampled neighbors is the 190 | same in all layers. 191 | """ 192 | self.visible_time_list = visible_time_list[:-1] 193 | self.in_neighbors = in_neighbors 194 | self.uid2sessions = uid2sessions 195 | # num_neighbors is a list of integers 196 | if len(num_neighbors) != num_layers: 197 | assert len(num_neighbors) == 1 198 | self.fanouts = num_neighbors * num_layers 199 | # repeat num_neighbors[0] num_layers times. 200 | else: 201 | self.fanouts = num_neighbors 202 | 203 | def sample_sessions(self, sid, uid, seq, all_uids): 204 | """ 205 | Args 206 | ---- 207 | sid: int 208 | The session id of the current session 209 | uid: int 210 | The user id of the current session 211 | seq: list[int] 212 | The prefix of the current session 213 | all_uids: list[int] 214 | All user ids in the (sampled) L-hop neighborhood of the current 215 | user, including the current user id 216 | 217 | Returns 218 | ------- 219 | sessions: list 220 | A list of sessions, where the i-th session is the latest session, 221 | i.e., the last session that happened before the current session, of 222 | the i-th user in all_uids 223 | """ 224 | sessions = [] 225 | for neigh_uid in all_uids: 226 | if neigh_uid == uid: 227 | sessions.append((uid, seq)) 228 | else: 229 | sids = self.uid2sessions[neigh_uid]['sids'] 230 | idx = bisect_left(sids, sid) 231 | assert idx > 0 232 | session = self.uid2sessions[neigh_uid]['sessions'][idx - 1] 233 | sessions.append((neigh_uid, session)) 234 | return sessions 235 | 236 | def sample_blocks(self, sid, uid): 237 | """ 238 | Args 239 | ---- 240 | sid : int 241 | The session id of the current session 242 | uid : int 243 | The user id of the current session 244 | 245 | Returns 246 | ------- 247 | blocks : list[DGLGraph] 248 | A list of bipartite graphs. `blocks[-i]` is a graph from the sampled 249 | users at the i-th layer to the sampled users at the (i-1)-th layer, 250 | for 1 <= i <= L. (the 0-th layer contains only uid) 251 | The target nodes in `blocks[i]` are included at the beginning of 252 | the source nodes in `blocks[i]`. 253 | all_uids : list[int] 254 | All user ids in the (sampled) L-hop neighborhood of the current user. 255 | The first entry is `uid`. 
256 | """ 257 | blocks = [] 258 | seeds = [uid] 259 | nid_map = {uid: 0} 260 | for fanout, visible_time in zip(self.fanouts, self.visible_time_list[::-1]): 261 | src = [] 262 | dst = [] 263 | for i, node in enumerate(seeds): 264 | candidates = [ 265 | neigh for neigh in self.in_neighbors[node] 266 | if visible_time[neigh] <= sid 267 | ] 268 | assert len(candidates) > 0 269 | if len(candidates) <= fanout: 270 | sampled_neighs = candidates 271 | else: 272 | sampled_neighs = random.sample(candidates, fanout) 273 | for neigh in sampled_neighs: 274 | if neigh not in nid_map: 275 | nid_map[neigh] = len(nid_map) 276 | src.append(nid_map[neigh]) 277 | dst += [i] * len(sampled_neighs) 278 | block = dgl.heterograph( 279 | data_dict={('followee', 'followedby', 'follower'): (src, dst)}, 280 | num_nodes_dict={ 281 | 'followee': len(nid_map), 282 | 'follower': len(seeds) 283 | } 284 | ) 285 | blocks.insert(0, block) 286 | seeds = list(nid_map.keys()) 287 | return blocks, seeds 288 | 289 | def collate_train(self, samples): 290 | batch_blocks = [] 291 | batch_sessions = [] 292 | labels = [] 293 | # the uid is the first src node of block. 294 | cur_sidx = [0] 295 | # cur_sidx[b]: the index (in batch_sessions) of the b-th ongoing session 296 | # in the current batch 297 | for sid, uid, seq, label in samples: 298 | blocks, all_uids = self.sample_blocks(sid, uid) 299 | sessions = self.sample_sessions(sid, uid, seq, all_uids) 300 | batch_blocks.append(blocks) 301 | batch_sessions += sessions 302 | labels.append(label) 303 | cur_sidx.append(cur_sidx[-1] + len(sessions)) 304 | graphs = [] 305 | # graphs[i]: a graph that batches the blocks of all samples at the i-th layer 306 | idx_maps = [] 307 | # idx_maps[i]: the indices of the target nodes of graphs[i] 308 | # in the source nodes of graphs[i] 309 | for blocks in zip(*batch_blocks): 310 | graphs.append(dgl.batch(blocks)) 311 | idx_map = [] 312 | total_number_of_src_nodes = 0 313 | for block in blocks: 314 | idx = th.arange(block.number_of_dst_nodes()) + total_number_of_src_nodes 315 | idx_map.append(idx) 316 | total_number_of_src_nodes += block.number_of_src_nodes() 317 | idx_map = th.cat(idx_map) 318 | idx_maps.append(idx_map) 319 | uids, seqs = zip(*batch_sessions) 320 | tensor_seqs = list(map(lambda seq: th.LongTensor(seq) + 1, seqs)) 321 | padded_seqs = th.nn.utils.rnn.pad_sequence(tensor_seqs, batch_first=True) 322 | lens = list(map(len, tensor_seqs)) 323 | 324 | uids = th.LongTensor(uids) 325 | lens = th.LongTensor(lens) 326 | cur_sidx = th.LongTensor(cur_sidx[:-1]) 327 | inputs = [graphs, idx_maps, uids, padded_seqs, lens, cur_sidx] 328 | labels = th.LongTensor(labels) 329 | return inputs, labels 330 | 331 | collate_test = collate_train 332 | -------------------------------------------------------------------------------- /srs/utils/data/load.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import itertools 4 | from collections import Counter 5 | 6 | import pandas as pd 7 | import numpy as np 8 | 9 | import torch as th 10 | import dgl 11 | 12 | 13 | class AugmentedDataset: 14 | def __init__(self, df, read_sid=False, sort_by_length=True): 15 | if read_sid: 16 | df = df[['sessionId', 'userId', 'items']] 17 | else: 18 | df = df[['userId', 'items']] 19 | self.sessions = df.values 20 | session_lens = df['items'].apply(len) 21 | index = create_index(session_lens) 22 | if sort_by_length: 23 | # sort by labelIndex in descending order 24 | # it is to be used with BatchSampler to make data 
loading of RNN models faster 25 | ind = np.argsort(index[:, 1])[::-1] 26 | index = index[ind] 27 | self.index = index 28 | 29 | def __getitem__(self, idx): 30 | sidx, lidx = self.index[idx] 31 | sess = self.sessions[sidx] 32 | seq = sess[-1][:lidx] 33 | label = sess[-1][lidx] 34 | item = (*sess[:-1], seq, label) 35 | return item 36 | 37 | def __len__(self): 38 | return len(self.index) 39 | 40 | 41 | class AnonymousAugmentedDataset: 42 | def __init__(self, df, sort_by_length=True): 43 | self.sessions = df['items'].values 44 | session_lens = np.fromiter(map(len, self.sessions), dtype=np.long) 45 | index = create_index(session_lens) 46 | if sort_by_length: 47 | ind = np.argsort(index[:, 1])[::-1] 48 | index = index[ind] 49 | self.index = index 50 | 51 | def __getitem__(self, idx): 52 | sidx, lidx = self.index[idx] 53 | sess = self.sessions[sidx] 54 | seq = sess[:lidx] 55 | label = sess[lidx] 56 | item = (0, seq, label) 57 | return item 58 | 59 | def __len__(self): 60 | return len(self.index) 61 | 62 | 63 | class BatchSampler: 64 | """ 65 | First, the sequences of the same length are grouped into the same batch. 66 | Then, the remaining sequences of similar lengths are grouped into the same batch. 67 | The sequences in a batch are sorted by length in descending order. 68 | """ 69 | def __init__(self, augmented_dataset, batch_size, drop_last=False, seed=None): 70 | df_index = pd.DataFrame( 71 | augmented_dataset.index, columns=['sessionId', 'labelIdx'] 72 | ) 73 | self.groups = [df for _, df in df_index.groupby('labelIdx')] 74 | self.groups.sort(key=lambda g: g.iloc[0].labelIdx, reverse=True) 75 | self.batch_size = batch_size 76 | num_batches = len(augmented_dataset) / batch_size 77 | if drop_last: 78 | self.num_batches = math.floor(num_batches) 79 | else: 80 | self.num_batches = math.ceil(num_batches) 81 | self.drop_last = drop_last 82 | self.seed = seed 83 | 84 | def _create_batch_indices(self): 85 | # shuffle sequences of the same length 86 | groups = [df.sample(frac=1, random_state=self.seed) for df in self.groups] 87 | df_index = pd.concat(groups) 88 | # shuffle batches 89 | batch_indices = [ 90 | df.index 91 | for _, df in df_index.groupby(np.arange(len(df_index)) // self.batch_size) 92 | ][:self.num_batches] 93 | random.seed(self.seed) 94 | random.shuffle(batch_indices) 95 | self.batch_indices = batch_indices 96 | if self.seed is not None: 97 | self.seed += 1 98 | 99 | def __iter__(self): 100 | self._create_batch_indices() 101 | return iter(self.batch_indices) 102 | 103 | def __len__(self): 104 | return self.num_batches 105 | 106 | 107 | def create_index(session_lens): 108 | num_sessions = len(session_lens) 109 | session_idx = np.repeat(np.arange(num_sessions), session_lens - 1) 110 | label_idx = map(lambda l: range(1, l), session_lens) 111 | label_idx = itertools.chain.from_iterable(label_idx) 112 | label_idx = np.fromiter(label_idx, dtype=np.long) 113 | idx = np.column_stack((session_idx, label_idx)) 114 | return idx 115 | 116 | 117 | def read_sessions(filepath): 118 | df = pd.read_csv(filepath, sep='\t') 119 | df['items'] = df['items'].apply(lambda x: [int(i) for i in x.split(',')]) 120 | return df 121 | 122 | 123 | def read_dataset(dataset_dir): 124 | stats = pd.read_csv(dataset_dir / 'stats.txt', sep='\t').iloc[0] 125 | df_train = read_sessions(dataset_dir / 'train.txt') 126 | df_valid = read_sessions(dataset_dir / 'valid.txt') 127 | df_test = read_sessions(dataset_dir / 'test.txt') 128 | return df_train, df_valid, df_test, stats 129 | 130 | 131 | def read_social_network(csv_file): 132 | 
df = pd.read_csv(csv_file, sep='\t') 133 | g = dgl.graph((df.followee.values, df.follower.values)) 134 | return g 135 | 136 | 137 | def build_knowledge_graph(df_train, social_network, do_count_clipping=True): 138 | print('building heterogeneous knowledge graph...') 139 | followed_edges = social_network.edges() 140 | clicks = Counter() 141 | transits = Counter() 142 | for _, row in df_train.iterrows(): 143 | uid = row['userId'] 144 | seq = row['items'] 145 | for iid in seq: 146 | clicks[(uid, iid)] += 1 147 | transits.update(zip(seq, seq[1:])) 148 | clicks_u, clicks_i = zip(*clicks.keys()) 149 | prev_i, next_i = zip(*transits.keys()) 150 | kg = dgl.heterograph({ 151 | ('user', 'followedby', 'user'): followed_edges, 152 | ('user', 'clicks', 'item'): (clicks_u, clicks_i), 153 | ('item', 'clickedby', 'user'): (clicks_i, clicks_u), 154 | ('item', 'transitsto', 'item'): (prev_i, next_i), 155 | }) 156 | click_cnts = np.array(list(clicks.values())) 157 | transit_cnts = np.array(list(transits.values())) 158 | if do_count_clipping: 159 | click_cnts = clip_counts(click_cnts) 160 | transit_cnts = clip_counts(transit_cnts) 161 | click_cnts = th.LongTensor(click_cnts) - 1 162 | transit_cnts = th.LongTensor(transit_cnts) - 1 163 | kg.edges['clicks'].data['cnt'] = click_cnts 164 | kg.edges['clickedby'].data['cnt'] = click_cnts 165 | kg.edges['transitsto'].data['cnt'] = transit_cnts 166 | return kg 167 | 168 | 169 | def find_max_count(counts): 170 | max_cnt = np.max(counts) 171 | density = np.histogram( 172 | counts, bins=np.arange(1, max_cnt + 2), range=(1, max_cnt + 1), density=True 173 | )[0] 174 | cdf = np.cumsum(density) 175 | for i in range(max_cnt): 176 | if cdf[i] > 0.95: 177 | return i + 1 178 | return max_cnt 179 | 180 | 181 | def clip_counts(counts): 182 | """ 183 | Truncate the counts to the maximum value of the smallest 95% counts. 184 | This could avoid outliers and reduce the number of count embeddings. 
185 | """ 186 | max_cnt = find_max_count(counts) 187 | counts = np.minimum(counts, max_cnt) 188 | return counts 189 | 190 | 191 | def compute_visible_time_list(in_neighbors, zero_hop_visible_time, num_layers): 192 | visible_time_list = [zero_hop_visible_time] 193 | num_nodes = len(zero_hop_visible_time) 194 | for n in range(1, num_layers + 1): 195 | prev_visible_time = visible_time_list[-1] 196 | n_hop_visible_time = [] 197 | for node in range(num_nodes): 198 | if len(in_neighbors[node]) == 0: 199 | neigh_vis_time = float('inf') 200 | else: 201 | neigh_vis_time = min([ 202 | prev_visible_time[neigh] for neigh in in_neighbors[node] 203 | ]) 204 | node_vis_time = max(neigh_vis_time, zero_hop_visible_time[node]) 205 | n_hop_visible_time.append(node_vis_time) 206 | visible_time_list.append(n_hop_visible_time) 207 | return visible_time_list 208 | 209 | 210 | def compute_visible_time_list_and_in_neighbors(df_train, dataset_dir, num_layers): 211 | df_edges = pd.read_csv(dataset_dir / 'edges.txt', sep='\t') 212 | num_nodes = df_edges.values.max() + 1 213 | in_neighbors = [[] for i in range(num_nodes)] 214 | pd_series = df_edges.groupby('follower').followee.apply(list) 215 | for follower, followees in pd_series.items(): 216 | in_neighbors[follower] = followees 217 | 218 | visible_time = df_train.groupby('userId').sessionId.min().values 219 | visible_time_list = compute_visible_time_list( 220 | in_neighbors, visible_time, num_layers 221 | ) 222 | return visible_time_list, in_neighbors 223 | 224 | 225 | def filter_invalid_sessions(*dfs, L_hop_visible_time): 226 | print('filtering invalid sessions') 227 | dfs_filtered = [] 228 | for df in dfs: 229 | sids_to_keep = [] 230 | for _, row in df.iterrows(): 231 | sid = row['sessionId'] 232 | uid = row['userId'] 233 | if L_hop_visible_time[uid] <= sid: 234 | sids_to_keep.append(sid) 235 | df_filtered = df[df.sessionId.isin(sids_to_keep)] 236 | dfs_filtered.append(df_filtered) 237 | return dfs_filtered 238 | -------------------------------------------------------------------------------- /srs/utils/data/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def group_sessions(df, interval): 5 | df_prev = df.shift() 6 | is_new_session = (df.userId != 7 | df_prev.userId) | (df.timestamp - df_prev.timestamp > interval) 8 | sessionId = is_new_session.cumsum() - 1 9 | df = df.assign(sessionId=sessionId) 10 | return df 11 | 12 | 13 | def filter_short_sessions(df, min_len=2): 14 | session_len = df.groupby('sessionId', sort=False).size() 15 | long_sessions = session_len[session_len >= min_len].index 16 | df_long = df[df.sessionId.isin(long_sessions)] 17 | print( 18 | f'removed {len(session_len) - len(long_sessions)}/{len(session_len)} sessions shorter than {min_len}' 19 | ) 20 | return df_long 21 | 22 | 23 | def filter_infreq_items(df, min_support=5): 24 | item_support = df.groupby('itemId', sort=False).size() 25 | freq_items = item_support[item_support >= min_support].index 26 | df_freq = df[df.itemId.isin(freq_items)] 27 | print( 28 | f'removed {len(item_support) - len(freq_items)}/{len(item_support)} items with support < {min_support}' 29 | ) 30 | return df_freq 31 | 32 | 33 | def filter_isolated_users(df_clicks, df_edges): 34 | num_sessions = df_clicks.sessionId.nunique() 35 | num_edges = len(df_edges) 36 | while True: 37 | sess_users = df_clicks.userId.unique() 38 | soci_users = np.unique(df_edges[['follower', 'followee']].values) 39 | # users must be followed by or follow at least one 
user 40 | df_clicks_new = df_clicks[df_clicks.userId.isin(soci_users)] 41 | # users must have at least one session 42 | df_edges_new = df_edges[df_edges.follower.isin(sess_users) 43 | & df_edges.followee.isin(sess_users)] 44 | if len(df_clicks_new) == len(df_clicks) and len(df_edges) == len(df_edges_new): 45 | break 46 | df_clicks = df_clicks_new 47 | df_edges = df_edges_new 48 | print( 49 | f'removed {num_sessions - df_clicks.sessionId.nunique()}/{num_sessions}' 50 | f' sessions and {num_edges - len(df_edges)}/{num_edges} edges of isolated users' 51 | ) 52 | return df_clicks, df_edges 53 | 54 | 55 | def filter_loop(df_clicks, df_edges, args): 56 | while True: 57 | df_long = filter_short_sessions(df_clicks) 58 | df_freq = filter_infreq_items(df_long) 59 | df_conn, df_edges = filter_isolated_users(df_freq, df_edges) 60 | if len(df_conn) == len(df_clicks): 61 | break 62 | df_clicks = df_conn 63 | return df_clicks, df_edges 64 | 65 | 66 | def truncate_long_sessions(df, max_len, is_sorted=False): 67 | if not is_sorted: 68 | df = df.sort_values(['sessionId', 'timestamp']) 69 | itemIdx = df.groupby('sessionId').cumcount() 70 | df_t = df[itemIdx < max_len] 71 | print( 72 | f'removed {len(df) - len(df_t)}/{len(df)} clicks in sessions longer than {max_len}' 73 | ) 74 | return df_t 75 | 76 | 77 | def update_id(*dataframes, colnames, mapping=None): 78 | """ 79 | Map the values in the columns `colnames` of `dataframes` according to `mapping`. 80 | If `mapping` is `None`, a dictionary that maps the values in column `colnames[0]` 81 | of `dataframes[0]` to unique integers will be used. 82 | Note that values that do not appear in `mapping` will be mapped to `NaN`. 83 | 84 | Args 85 | ---- 86 | dataframes : list[DataFrame] 87 | A list of dataframes. 88 | colnames: str, list[str] 89 | The names of columns. 90 | mapping: function, dict, optional 91 | Mapping correspondence. 92 | 93 | Returns 94 | ------- 95 | DataFrame, list[DataFrame] 96 | A dataframe (if there is only one input dataframe) or a list of dataframes 97 | with columns in `colnames` updated according to `mapping`. 
98 | """ 99 | if type(colnames) is str: 100 | colnames = [colnames] 101 | if mapping is None: 102 | uniques = dataframes[0][colnames[0]].unique() 103 | mapping = {oid: i for i, oid in enumerate(uniques)} 104 | results = [] 105 | for df in dataframes: 106 | columns = {} 107 | for name in colnames: 108 | if name in df.columns: 109 | columns[name] = df[name].map(mapping) 110 | df = df.assign(**columns) 111 | results.append(df) 112 | if len(results) == 1: 113 | return results[0] 114 | else: 115 | return results 116 | 117 | 118 | def remove_immediate_repeats(df): 119 | df_prev = df.shift() 120 | is_not_repeat = (df.sessionId != df_prev.sessionId) | (df.itemId != df_prev.itemId) 121 | df_no_repeat = df[is_not_repeat] 122 | print( 123 | f'removed {len(df) - len(df_no_repeat)}/{len(df)} immediate repeat consumptions' 124 | ) 125 | return df_no_repeat 126 | 127 | 128 | def reorder_sessions_by_endtime(df): 129 | endtime = df.groupby('sessionId', sort=False).timestamp.max() 130 | df_endtime = endtime.sort_values().reset_index() 131 | oid2nid = dict(zip(df_endtime.sessionId, df_endtime.index)) 132 | sessionId_new = df.sessionId.map(oid2nid) 133 | df = df.assign(sessionId=sessionId_new) 134 | df = df.sort_values(['sessionId', 'timestamp']) 135 | return df 136 | 137 | 138 | def train_test_split(df, test_split=0.2): 139 | endtime = df.groupby('sessionId', sort=False).timestamp.max() 140 | endtime = endtime.sort_values() 141 | num_tests = int(len(endtime) * test_split) 142 | test_session_ids = endtime.index[-num_tests:] 143 | df_train = df[~df.sessionId.isin(test_session_ids)] 144 | df_test = df[df.sessionId.isin(test_session_ids)] 145 | return df_train, df_test 146 | 147 | 148 | def save_sessions(df_clicks, filepath): 149 | df_clicks = df_clicks.groupby('sessionId').agg({ 150 | 'userId': 151 | lambda col: col.iloc[0], 152 | 'itemId': 153 | lambda col: ','.join(col.astype(str)), 154 | }) 155 | df_clicks.to_csv(filepath, sep='\t', header=['userId', 'items'], index=True) 156 | 157 | 158 | def keep_valid_sessions(df_train, df_test, train_split): 159 | print('\nprocessing test sets...') 160 | uid = df_train.userId.unique() 161 | iid = df_train.itemId.unique() 162 | df_test = df_test[df_test.userId.isin(uid) & df_test.itemId.isin(iid)] 163 | df_test = filter_short_sessions(df_test) 164 | return df_test 165 | 166 | 167 | def keep_top_n(df_clicks, n, colname): 168 | print(f'keeping top {n} most frequent values in column {colname}') 169 | supports = df_clicks.groupby(colname, sort=False).size() 170 | top_values = supports.nlargest(n).index 171 | df_top = df_clicks[df_clicks[colname].isin(top_values)] 172 | print(f'removed {len(supports) - len(top_values)}/{len(supports)} values') 173 | return df_top 174 | 175 | 176 | def update_session_id(df_train, df_test): 177 | df_train = reorder_sessions_by_endtime(df_train) 178 | df_test = reorder_sessions_by_endtime(df_test) 179 | num_train_sessions = df_train.sessionId.max() + 1 180 | df_test = df_test.assign(sessionId=df_test.sessionId + num_train_sessions) 181 | return df_train, df_test 182 | 183 | 184 | def print_stats(df_clicks, name): 185 | print( 186 | f'{name}:\n' 187 | f'No. of clicks: {len(df_clicks)}\n' 188 | f'No. of sessions: {df_clicks.sessionId.nunique()}\n' 189 | f'No. of users: {df_clicks.userId.nunique()}\n' 190 | f'No. of items: {df_clicks.itemId.nunique()}\n' 191 | f'Avg. 
session length: {len(df_clicks) / df_clicks.sessionId.nunique():.3f}\n' 192 | ) 193 | 194 | 195 | def save_dataset(df_train, df_test, df_edges, df_loc, args): 196 | df_test = keep_valid_sessions(df_train, df_test, args.train_split) 197 | 198 | print(f'No. of Clicks: {len(df_train) + len(df_test)}') 199 | print_stats(df_train, 'Training set') 200 | print_stats(df_test, 'Test set') 201 | print(f'No. of Connections: {len(df_edges)}') 202 | print(f'No. of Followers: {df_edges.follower.nunique()}') 203 | print(f'No. of Followees: {df_edges.followee.nunique()}') 204 | num_users = df_train.userId.nunique() 205 | print(f'Avg. Followers: {len(df_edges) / num_users:.3f}') 206 | 207 | df_train, df_test = update_session_id(df_train, df_test) 208 | 209 | # update userId 210 | df_train, df_test, df_edges = update_id( 211 | df_train, df_test, df_edges, colnames=['userId', 'followee', 'follower'] 212 | ) 213 | 214 | # update itemId 215 | if df_loc is None: 216 | df_train, df_test = update_id(df_train, df_test, colnames='itemId') 217 | else: 218 | df_loc = df_loc[df_loc.itemId.isin(df_train.itemId.unique())] 219 | df_train, df_test, df_loc = update_id( 220 | df_train, df_test, df_loc, colnames='itemId' 221 | ) 222 | df_loc = df_loc.sort_values('itemId') 223 | 224 | dataset_dir = args.output_dir / args.dataset 225 | print(f'saving dataset to {dataset_dir}') 226 | # save sessions 227 | dataset_dir.mkdir(parents=True, exist_ok=True) 228 | save_sessions(df_train, dataset_dir / 'train.txt') 229 | # randomly and evenly split df_test into df_valid and df_test 230 | valid_test_sids = df_test.sessionId.unique() 231 | num_valid_sessions = len(valid_test_sids) // 2 232 | valid_sids = np.random.choice(valid_test_sids, num_valid_sessions, replace=False) 233 | df_valid = df_test[df_test.sessionId.isin(valid_sids)] 234 | df_test = df_test[~df_test.sessionId.isin(valid_sids)] 235 | save_sessions(df_valid, dataset_dir / 'valid.txt') 236 | save_sessions(df_test, dataset_dir / 'test.txt') 237 | 238 | if df_loc is not None: 239 | df_loc.to_csv( 240 | dataset_dir / 'loc.txt', 241 | sep='\t', 242 | index=False, 243 | header=True, 244 | float_format='%.2f' 245 | ) 246 | 247 | # save social network 248 | df_edges = df_edges.sort_values(['followee', 'follower']) 249 | df_edges.to_csv(dataset_dir / 'edges.txt', sep='\t', header=True, index=False) 250 | 251 | # save stats 252 | num_users = df_train.userId.nunique() 253 | num_items = df_train.itemId.nunique() 254 | with open(dataset_dir / 'stats.txt', 'w') as f: 255 | f.write('num_users\tnum_items\tmax_len\n') 256 | f.write(f'{num_users}\t{num_items}\t{args.max_len}') 257 | 258 | 259 | def preprocess(df_clicks, df_edges, df_loc, args): 260 | print('arguments: ', args) 261 | if 'sessionId' in df_clicks.columns: 262 | print('clicks are already grouped into sessions') 263 | df_clicks = df_clicks.sort_values(['userId', 'sessionId', 'timestamp']) 264 | sessionId = df_clicks.userId.astype(str) + '_' + df_clicks.sessionId.astype(str) 265 | df_clicks = df_clicks.assign(sessionId=sessionId) 266 | df_clicks = update_id(df_clicks, colnames='sessionId') 267 | else: 268 | df_clicks = group_sessions(df_clicks, args.interval) 269 | df_clicks = remove_immediate_repeats(df_clicks) 270 | if args.max_len > 0: 271 | df_clicks = truncate_long_sessions(df_clicks, args.max_len, is_sorted=True) 272 | if 'max_users' in args: 273 | df_clicks = keep_top_n(df_clicks, args.max_users, 'userId') 274 | if 'max_items' in args: 275 | df_clicks = keep_top_n(df_clicks, args.max_items, 'itemId') 276 | df_train, df_test 
= train_test_split(df_clicks, test_split=1 - args.train_split) 277 | df_train, df_edges = filter_loop(df_train, df_edges, args) 278 | save_dataset(df_train, df_test, df_edges, df_loc, args) 279 | -------------------------------------------------------------------------------- /srs/utils/data/transform.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import dgl 3 | import numpy as np 4 | from collections import Counter 5 | 6 | 7 | def label_last(g, last_nid): 8 | is_last = th.zeros(g.number_of_nodes(), dtype=th.int32) 9 | is_last[last_nid] = 1 10 | g.ndata['last'] = is_last 11 | return g 12 | 13 | 14 | def seq_to_unweighted_graph(seq): 15 | iid, seq_nid, cnt = np.unique(seq, return_inverse=True, return_counts=True) 16 | num_nodes = len(iid) 17 | 18 | if len(seq_nid) > 1: 19 | edges = zip(seq_nid, seq_nid[1:]) 20 | counter = Counter(edges) 21 | unique_edges = counter.keys() 22 | src, dst = zip(*unique_edges) 23 | else: 24 | src = th.LongTensor([]) 25 | dst = th.LongTensor([]) 26 | 27 | g = dgl.graph((src, dst), num_nodes=num_nodes) 28 | g.ndata['iid'] = th.LongTensor(iid) 29 | g.ndata['cnt'] = th.FloatTensor(cnt) 30 | label_last(g, seq_nid[-1]) 31 | return g 32 | 33 | 34 | def seq_to_weighted_graph(seq): 35 | iid, seq_nid, cnt = np.unique(seq, return_inverse=True, return_counts=True) 36 | num_nodes = len(iid) 37 | 38 | if len(seq_nid) > 1: 39 | counter = Counter(zip(seq_nid, seq_nid[1:])) 40 | src, dst = zip(*counter.keys()) 41 | weight = th.FloatTensor(list(counter.values())) 42 | else: 43 | src = th.LongTensor([]) 44 | dst = th.LongTensor([]) 45 | weight = th.FloatTensor([]) 46 | 47 | g = dgl.graph((src, dst), num_nodes=num_nodes) 48 | g.ndata['iid'] = th.LongTensor(iid) 49 | g.ndata['cnt'] = th.FloatTensor(cnt) 50 | g.edata['w'] = weight.view(g.num_edges(), 1) 51 | label_last(g, seq_nid[-1]) 52 | return g 53 | -------------------------------------------------------------------------------- /srs/utils/prepare_batch.py: -------------------------------------------------------------------------------- 1 | def prepare_batch_factory(device): 2 | def prepare_batch(batch): 3 | inputs, labels = batch 4 | inputs_gpu = [x.to(device) for x in inputs] 5 | labels_gpu = labels.to(device) 6 | return inputs_gpu, labels_gpu 7 | 8 | return prepare_batch 9 | 10 | 11 | def prepare_batch_factory_recursive(device): 12 | def prepare_batch_recursive(batch): 13 | if type(batch) is list or type(batch) is tuple: 14 | return [prepare_batch_recursive(x) for x in batch] 15 | elif type(batch) is dict: 16 | return {k: prepare_batch_recursive(v) for k, v in batch.items()} 17 | else: 18 | return batch.to(device) 19 | 20 | return prepare_batch_recursive 21 | -------------------------------------------------------------------------------- /srs/utils/train_runner.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import logging 3 | import time 4 | 5 | import torch as th 6 | from torch import nn, optim 7 | 8 | 9 | def evaluate(model, data_loader, prepare_batch, Ks=[20]): 10 | model.eval() 11 | results = defaultdict(float) 12 | max_K = max(Ks) 13 | num_samples = 0 14 | with th.no_grad(): 15 | for batch in data_loader: 16 | inputs, labels = prepare_batch(batch) 17 | logits = model(*inputs) 18 | batch_size = logits.size(0) 19 | num_samples += batch_size 20 | topk = th.topk(logits, k=max_K, sorted=True)[1] 21 | labels = labels.unsqueeze(-1) 22 | for K in Ks: 23 | hit_ranks = th.where(topk[:, 
:K] == labels)[1] + 1 24 | hit_ranks = hit_ranks.float().cpu() 25 | results[f'HR@{K}'] += hit_ranks.numel() 26 | results[f'MRR@{K}'] += hit_ranks.reciprocal().sum().item() 27 | results[f'NDCG@{K}'] += th.log2(1 + hit_ranks).reciprocal().sum().item() 28 | for metric in results: 29 | results[metric] /= num_samples 30 | return results 31 | 32 | 33 | def fix_weight_decay(model, ignore_list=['bias', 'batch_norm']): 34 | decay = [] 35 | no_decay = [] 36 | logging.debug('ignore weight decay for ' + ', '.join(ignore_list)) 37 | for name, param in model.named_parameters(): 38 | if not param.requires_grad: 39 | continue 40 | if any(map(lambda x: x in name, ignore_list)): 41 | no_decay.append(param) 42 | else: 43 | decay.append(param) 44 | params = [{'params': decay}, {'params': no_decay, 'weight_decay': 0}] 45 | return params 46 | 47 | 48 | def print_results(*results_list): 49 | metrics = list(results_list[0][1].keys()) 50 | logging.warning('Metric\t' + '\t'.join(metrics)) 51 | for name, results in results_list: 52 | logging.warning( 53 | name + '\t' + 54 | '\t'.join([f'{round(results[metric] * 100, 2):.2f}' for metric in metrics]) 55 | ) 56 | 57 | 58 | class TrainRunner: 59 | def __init__( 60 | self, 61 | train_loader, 62 | valid_loader, 63 | test_loader, 64 | model, 65 | prepare_batch, 66 | Ks=[20], 67 | lr=1e-3, 68 | weight_decay=0, 69 | ignore_list=None, 70 | patience=2, 71 | OTF=False, 72 | **kwargs, 73 | ): 74 | self.model = model 75 | if weight_decay > 0: 76 | if ignore_list is not None: 77 | params = fix_weight_decay(model, ignore_list) 78 | else: 79 | params = model.parameters() 80 | self.optimizer = optim.AdamW(params, lr=lr, weight_decay=weight_decay) 81 | self.criterion = nn.CrossEntropyLoss() 82 | self.train_loader = train_loader 83 | self.valid_loader = valid_loader 84 | self.test_loader = test_loader 85 | self.prepare_batch = prepare_batch 86 | self.Ks = Ks 87 | self.epoch = 0 88 | self.batch = 0 89 | self.patience = patience if patience > 0 else 2 90 | self.precompute = hasattr(model, 'KGE_layer') and not OTF 91 | 92 | def train(self, epochs, log_interval=100): 93 | best_results = defaultdict(float) 94 | report_results = defaultdict(float) 95 | bad_counter = 0 96 | t = time.time() 97 | mean_loss = 0 98 | for epoch in range(epochs): 99 | self.model.train() 100 | train_ts = time.time() 101 | for batch in self.train_loader: 102 | inputs, labels = self.prepare_batch(batch) 103 | self.optimizer.zero_grad() 104 | logits = self.model(*inputs) 105 | loss = self.criterion(logits, labels) 106 | loss.backward() 107 | self.optimizer.step() 108 | mean_loss += loss.item() / log_interval 109 | if self.batch > 0 and self.batch % log_interval == 0: 110 | logging.info( 111 | f'Batch {self.batch}: Loss = {mean_loss:.4f}, Elapsed Time = {time.time() - t:.2f}s' 112 | ) 113 | t = time.time() 114 | mean_loss = 0 115 | self.batch += 1 116 | eval_ts = time.time() 117 | logging.debug( 118 | f'Training time per {log_interval} batches: ' 119 | f'{(eval_ts - train_ts) / len(self.train_loader) * log_interval:.2f}s' 120 | ) 121 | if self.precompute: 122 | ts = time.time() 123 | self.model.precompute_KG_embeddings() 124 | te = time.time() 125 | logging.debug(f'Precomputing KG embeddings took {te - ts:.2f}s') 126 | 127 | ts = time.time() 128 | valid_results = evaluate( 129 | self.model, self.valid_loader, self.prepare_batch, self.Ks 130 | ) 131 | test_results = evaluate( 132 | self.model, self.test_loader, self.prepare_batch, self.Ks 133 | ) 134 | if self.precompute: 135 | # release precomputed KG embeddings 136 | 
self.model.KG_embeddings = None 137 | th.cuda.empty_cache() 138 | te = time.time() 139 | num_batches = len(self.valid_loader) + len(self.test_loader) 140 | logging.debug( 141 | f'Evaluation time per {log_interval} batches: ' 142 | f'{(te - ts) / num_batches * log_interval:.2f}s' 143 | ) 144 | 145 | logging.warning(f'Epoch {self.epoch}:') 146 | print_results(('Valid', valid_results), ('Test', test_results)) 147 | 148 | any_better_result = False 149 | for metric in valid_results: 150 | if valid_results[metric] > best_results[metric]: 151 | best_results[metric] = valid_results[metric] 152 | report_results[metric] = test_results[metric] 153 | any_better_result = True 154 | 155 | if any_better_result: 156 | bad_counter = 0 157 | else: 158 | bad_counter += 1 159 | if bad_counter == self.patience: 160 | break 161 | self.epoch += 1 162 | eval_te = time.time() 163 | t += eval_te - eval_ts 164 | print_results(('Report', report_results)) 165 | return report_results 166 | --------------------------------------------------------------------------------
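
For readers who want to see how the pieces in `srs/utils` and `srs/models` fit together outside of `run.py`, the sketch below wires up the data pipeline for the plain STAMP model: `AugmentedDataset` expands each training session into one (user, prefix, next-item) sample per prefix, `BatchSampler` groups prefixes of equal length to keep padding minimal, `collate_fn_for_rnn_cnn` pads and batches them, and the model's `config` bundle supplies both the `Model` class and the `prepare_batch` factory. This is an editor's illustrative sketch, not part of the repository: the values `embedding_dim=64`, `batch_size=128`, and `seed=0` are arbitrary, and it assumes the preprocessed `datasets/delicious` files exist and the repository root is on the Python import path.

```python
# Illustrative sketch only (not repository code): manually assembling the data
# pipeline that run.py normally sets up, here for the non-social STAMP model.
from pathlib import Path

import torch as th
from torch.utils.data import DataLoader

from srs.models.STAMP import config
from srs.utils.data.collate import collate_fn_for_rnn_cnn
from srs.utils.data.load import AugmentedDataset, BatchSampler, read_dataset

dataset_dir = Path('datasets/delicious')
df_train, df_valid, df_test, stats = read_dataset(dataset_dir)

# One sample per session prefix: (userId, prefix, next item to predict).
train_set = AugmentedDataset(df_train)
train_loader = DataLoader(
    train_set,
    # Same-length prefixes go into the same batch, so padding work is minimal.
    batch_sampler=BatchSampler(train_set, batch_size=128, seed=0),
    collate_fn=collate_fn_for_rnn_cnn,  # pads sequences and stacks lengths
)

device = th.device('cuda' if th.cuda.is_available() else 'cpu')
model = config.Model(num_items=int(stats['num_items']), embedding_dim=64).to(device)
prepare_batch = config.prepare_batch_factory(device)

# One forward pass on a single batch.
inputs, labels = prepare_batch(next(iter(train_loader)))
logits = model(*inputs)  # (batch_size, num_items) scores over all items
loss = th.nn.functional.cross_entropy(logits, labels)
print(logits.shape, loss.item())
```

The social variants (SNARM, SSTAMP, SSRGNN, and so on) follow the same pattern but, as the code above in `srs/utils/data/collate.py` shows, use `CollateFnRNNCNN` or `CollateFnGNN` constructed with the knowledge graph, so each batch also carries the sampled blocks consumed by `SEFrame.forward` via `prepare_batch_factory_recursive`.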