├── .gitignore
├── requirements.txt
├── config.yaml
├── config.py
├── readme.md
├── train.py
├── utils.py
├── eval_catboost.py
├── prepare_catboost_dataset.py
├── prepare_brightkite_dataset.py
├── prepare_gowalla_dataset.py
├── fit_catboost.py
├── metrics.py
├── dataloader.py
└── model.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea/*
3 | .ipynb_checkpoints/*
4 | __pycache__/*
5 | *.log
6 | runs/*
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | PyYAML==5.3.1
2 | torch==1.8.0
3 | loguru==0.5.3
4 | tqdm==4.51.0
5 | scipy==1.6.1
6 | faiss==1.7.0
7 | numpy==1.16.6
8 | pandas==1.2.2
9 | wget==3.2
10 | tensorboard==1.15.0
11 | implicit==0.4.0
12 | haversine==2.3.0
13 | catboost==0.24.4
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | # LightGCN, TopNModel, TopNPersonalized, TopNNearestModel, iALS
2 | MODEL: LightGCN
3 | DATASET: gowalla
4 | USE_TENSORBOARD: True
5 | 
6 | NUM_RAT_FOR_ITEM: 5
7 | NUM_RAT_FOR_USER: 5
8 | N_NEGATIVES: 20
9 | 
10 | METRICS_REPORT: [1, 10, 20]
11 | EVAL_EPOCHS: 0  # 0 disables evaluation during training
12 | 
13 | TRAIN_EPOCHS: 1000
14 | BATCH_SIZE: 1024
15 | 
16 | # LightGCN config
17 | LATENT_DIM: 64
18 | N_LAYERS: 3
19 | BPR_REG_ALPHA: 0.001
20 | 
21 | # TopNModel config
22 | TOP_N: 100
23 | 
24 | # iALS config
25 | ALS_N_ITERATIONS: 50
26 | ALS_N_FACTORS: 64
27 | 
28 | # dataset split config (used by both prepare scripts)
29 | SPLIT_DATE: 2010-09-04
30 | TEST_DAYS: 15
31 | VAL_DAYS: 15
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | from loguru import logger
4 | from time import gmtime, strftime
5 | from utils import TensorboardWriter
6 | 
7 | 
8 | if not os.path.isdir('logs'):
9 |     os.mkdir('logs')
10 | current_time = strftime("%Y-%m-%d_%H:%M:%S", gmtime())
11 | logger.add(f'logs/train_{current_time}.log')
12 | 
13 | # workaround for the duplicate-OpenMP-runtime crash on macOS
14 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
15 | 
16 | with open('config.yaml') as f:
17 |     config = yaml.safe_load(f)
18 | 
19 | tensorboard_writer = None
20 | if 'USE_TENSORBOARD' in config and config['USE_TENSORBOARD']:
21 |     tensorboard_writer = TensorboardWriter(f'runs/{current_time}')
22 | 
23 | logger.info(f'config loaded: {config}')
24 | 
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ### LightGCN on PyTorch
2 | 
3 | This is an implementation of the LightGCN neural net ([paper on arXiv](https://arxiv.org/abs/2002.02126)) from SIGIR 2020
4 | 
5 | ### Supported datasets:
6 | - [gowalla](https://snap.stanford.edu/data/loc-gowalla.html)
7 | - [brightkite](https://snap.stanford.edu/data/loc-brightkite.html)
8 | 
9 | Use the `prepare_<dataset>_dataset.py` scripts to download a dataset and split it by time
10 | 
11 | ### Supported models:
12 | - [iALS](https://implicit.readthedocs.io/en/latest/als.html) is a matrix factorization model from the `implicit` open-source library
13 | - TopNModel recommends the globally most popular items
14 | - TopNPersonalized recommends each user's own most frequent items
15 | - TopNNearestModel recommends the items nearest to the user's last location (domain-specific for geo features)
16 | - [LightGCN](https://arxiv.org/abs/2002.02126)
17 | - [CatBoost](https://catboost.ai) fitting with LogLoss/YetiRank to rank the generated candidates
18 | 
19 | ### Training:
20 | 
21 | The main script is `train.py`; it trains the model selected by the `MODEL` setting in `config.yaml`
22 | 
23 | There is also a `fit_catboost.py` script that trains a CatBoost ranking model on the candidates produced by `prepare_catboost_dataset.py`
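24 | 
25 | For example, a minimal run (a sketch mirroring `train.py`, assuming `prepare_gowalla_dataset.py` has been run first):
26 | 
27 | ```python
28 | from model import TopNModel
29 | from dataloader import GowallaTopNDataset
30 | 
31 | train_dataset = GowallaTopNDataset('dataset/gowalla/gowalla.train')
32 | test_dataset = GowallaTopNDataset('dataset/gowalla/gowalla.test', train=False)
33 | 
34 | model = TopNModel(top_n=100)
35 | model.fit(train_dataset)
36 | model.eval(test_dataset)  # logs Hitrate/Precision/Recall/MAP/NDCG/MRR for each k in METRICS_REPORT
37 | ```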
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import implicit
3 | import pandas as pd
4 | from pathlib import Path
5 | from config import config
6 | from model import LightGCN, TopNModel, TopNPersonalized, TopNNearestModel
7 | from dataloader import GowallaLightGCNDataset, GowallaTopNDataset, GowallaALSDataset
8 | 
9 | if __name__ == '__main__':
10 |     dataset_path = Path('dataset') / config['DATASET'] / config['DATASET']
11 |     if config['MODEL'] == 'LightGCN':
12 |         train_dataset = GowallaLightGCNDataset(f'{dataset_path}_custom.train')
13 |         test_dataset = GowallaLightGCNDataset(f'{dataset_path}_custom.test', train=False)
14 | 
15 |         model = LightGCN(train_dataset)
16 |         model.fit(config['TRAIN_EPOCHS'], test_dataset)
17 |     elif config['MODEL'] == 'TopNModel':
18 |         train_dataset = GowallaTopNDataset(f'{dataset_path}.train')
19 |         test_dataset = GowallaTopNDataset(f'{dataset_path}.test', train=False)
20 | 
21 |         model = TopNModel(config['TOP_N'])
22 |         model.fit(train_dataset)
23 |         model.eval(test_dataset)
24 |     elif config['MODEL'] == 'TopNPersonalized':
25 |         train_dataset = GowallaTopNDataset(f'{dataset_path}.train')
26 |         test_dataset = GowallaTopNDataset(f'{dataset_path}.test', train=False)
27 | 
28 |         model = TopNPersonalized(config['TOP_N'])
29 |         model.fit(train_dataset)
30 |         model.eval(test_dataset)
31 |     elif config['MODEL'] == 'TopNNearestModel':
32 |         train_dataset = GowallaTopNDataset(f'{dataset_path}.train')
33 |         test_dataset = GowallaTopNDataset(f'{dataset_path}.test', train=False)
34 | 
35 |         df = pd.concat([train_dataset.df, test_dataset.df])
36 |         calc_nearest = utils.calc_nearest(df)
37 |         model = TopNNearestModel(config['TOP_N'], calc_nearest)
38 |         model.fit(train_dataset)
39 |         model.eval(test_dataset)
40 |     elif config['MODEL'] == 'iALS':
41 |         gowalla_train, user_item_data, item_user_data = GowallaALSDataset(
42 |             f'{dataset_path}.train').get_dataset()
43 |         gowalla_test = GowallaALSDataset(f'{dataset_path}.test', train=False).get_dataset()
44 |         model = implicit.als.AlternatingLeastSquares(
45 |             iterations=config['ALS_N_ITERATIONS'], factors=config['ALS_N_FACTORS'])
46 |         model.fit_callback = utils.eval_als_model(model, user_item_data, gowalla_test)
47 |         model.fit(item_user_data)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import metrics
4 | import haversine
5 | from loguru import logger
6 | from collections import defaultdict
7 | from torch.utils.tensorboard import SummaryWriter
8 | 
9 | 
10 | def print_progressbar(current, total, width=80):
11 |     progress_message = "Downloading: %d%% [%d / %d] bytes" % (current / total * 100, current, total)
12 |     # Don't use print() as it would print on a new line every time.
13 |     sys.stdout.write("\r" + progress_message)
14 |     sys.stdout.flush()
15 | 
16 | 
17 | class TensorboardWriter(SummaryWriter):
18 |     def __init__(self, *args, **kwargs):
19 |         super().__init__(*args, **kwargs)
20 |         self.n_iter = defaultdict(lambda: 0)
21 | 
22 |     def add_scalar(self, tag, scalar_value, global_step=None, walltime=None):
23 |         if not global_step:
24 |             global_step = self.n_iter[tag]
25 |         self.n_iter[tag] += 1
26 |         super().add_scalar(tag, scalar_value, global_step, walltime)
27 | 
28 | 
29 | def eval_als_model(model, user_item_data, gowalla_test):
30 |     from config import config
31 | 
32 |     def inner(iteration, elapsed):
33 |         preds = []
34 |         ground_truth = []
35 |         n_recommend = max(config['METRICS_REPORT'])
36 |         test_interactions = gowalla_test.groupby('userId')['loc_id'].apply(list).to_dict()
37 |         for userId in gowalla_test['userId'].unique():
38 |             preds.append(
39 |                 list(map(lambda x: x[0], model.recommend(userId, user_item_data, n_recommend))))
40 |             ground_truth.append(test_interactions[userId])
41 | 
42 |         logger.info(f'{iteration} iteration:')
43 |         max_length = max(map(len, metrics.metric_dict.keys())) + max(
44 |             map(lambda x: len(str(x)), config['METRICS_REPORT']))
45 |         for metric_name, metric_func in metrics.metric_dict.items():
46 |             for k in config['METRICS_REPORT']:
47 |                 metric_name_total = f'{metric_name}@{k}'
48 |                 metric_value = metric_func(preds, ground_truth, k).mean()
49 |                 logger.info(f'{metric_name_total: >{max_length + 1}} = {metric_value}')
50 | 
51 |     return inner
52 | 
53 | 
54 | def calc_nearest(df):
55 |     df = df.set_index('loc_id')
56 |     item_lat = df['lat'].to_dict()
57 |     item_long = df['long'].to_dict()
58 |     locations = {item: (item_lat[item], item_long[item]) for item in item_lat}  # haversine expects (lat, lon)
59 | 
60 |     def inner(item_id, k=20):
61 |         loc = locations[item_id]
62 |         distances = [
63 |             (item, haversine.haversine(loc, location)) for item, location in locations.items()]
64 |         return list(map(lambda x: x[0], sorted(distances, key=lambda x: x[1])[:k]))
65 |     return inner
66 | 
67 | 
68 | def collate_function(batch):
69 |     users = []
70 |     pos_items = []
71 |     neg_items = []
72 |     for user, pos, neg in batch:
73 |         users.extend([user for _ in pos])
74 |         pos_items.extend(pos)
75 |         neg_items.extend(neg)
76 |     return list(map(torch.tensor, [users, pos_items, neg_items]))
--------------------------------------------------------------------------------
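`utils.calc_nearest` builds a brute-force geo lookup: it caches every item's coordinates once and, per query, sorts all items by haversine distance. A tiny self-contained sketch (the three locations are made up for illustration):

```python
import pandas as pd
import utils

df = pd.DataFrame({
    'loc_id': [0, 1, 2],
    'lat': [52.52, 48.86, 41.90],   # Berlin, Paris, Rome
    'long': [13.40, 2.35, 12.50],
})
nearest = utils.calc_nearest(df)
print(nearest(0, k=2))  # [0, 1]: the query item itself (distance 0) comes first
```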
/eval_catboost.py:
--------------------------------------------------------------------------------
1 | import metrics
2 | import numpy as np
3 | import pandas as pd
4 | from collections import defaultdict
5 | from catboost import CatBoost
6 | 
7 | if __name__ == '__main__':
8 |     candidates_dataset = pd.read_csv('catboost_eval_dataset.csv', names=['userId', 'itemId'])
9 | 
10 |     gowalla_train = pd.read_csv('dataset/gowalla/gowalla.traintest',
11 |                                 names=['userId', 'timestamp', 'long', 'lat', 'itemId'])
12 | 
13 |     gowalla_val = pd.read_csv('dataset/gowalla/gowalla.val',
14 |                               names=['userId', 'timestamp', 'long', 'lat', 'itemId'])
15 |     gowalla_dataset = pd.concat([gowalla_train, gowalla_val])
16 | 
17 |     # collect geopositions of items
18 |     tmp = gowalla_dataset.drop_duplicates('itemId').set_index('itemId')
19 |     item_lat = tmp['lat'].to_dict()
20 |     item_long = tmp['long'].to_dict()
21 |     del tmp
22 | 
23 |     user_num_interactions = defaultdict(
24 |         lambda: 0, gowalla_train.groupby('userId')['itemId'].count().to_dict())
25 |     item_num_interactions = defaultdict(
26 |         lambda: 0, gowalla_train.groupby('itemId')['userId'].count().to_dict())
27 |     user_item_interactions = defaultdict(
28 |         lambda: 0, gowalla_train.groupby(['userId', 'itemId'])['timestamp'].count().to_dict())
29 |     candidates_user_item_interactions = list(
30 |         map(lambda x: user_item_interactions[(x[0], x[1])],
31 |             candidates_dataset.loc[:, ['userId', 'itemId']].values))
32 | 
33 |     gowalla_friendships = pd.read_csv('dataset/gowalla/gowalla.friends', header=None,
34 |                                       names=['user1', 'user2'])
35 |     user_num_friends = gowalla_friendships.groupby('user1')['user2'].count().to_dict()
36 | 
37 |     candidates_dataset['num_friends'] = candidates_dataset['userId'].map(user_num_friends)
38 |     candidates_dataset['user_num_interactions'] = \
39 |         candidates_dataset['userId'].map(user_num_interactions)
40 |     candidates_dataset['item_num_interactions'] = \
41 |         candidates_dataset['itemId'].map(item_num_interactions)
42 |     candidates_dataset['user_item_interactions'] = candidates_user_item_interactions
43 |     candidates_dataset['long'] = candidates_dataset['itemId'].map(item_long)
44 |     candidates_dataset['lat'] = candidates_dataset['itemId'].map(item_lat)
45 | 
46 |     # fit_catboost.py saves a YetiRank ranker, so load a generic CatBoost
47 |     # model and rank candidates by its raw scores
48 |     model = CatBoost()
49 |     model.load_model('catboost.cbm')
50 |     candidates_dataset['target_pred'] = model.predict(
51 |         candidates_dataset.drop(['userId', 'itemId'], axis=1))
52 | 
53 |     k = 20
54 |     catboost_hitrates = []
55 |     for user, df in candidates_dataset.groupby('userId'):
56 |         preds = df.sort_values('target_pred', ascending=False).head(k)['itemId'].values
57 |         ground_truth = gowalla_val[gowalla_val['userId'] == user]['itemId'].values
58 |         if len(ground_truth) > 0:
59 |             catboost_hitrates.append(metrics.user_hitrate(preds, ground_truth))
60 | 
61 |     als_hitrates = []
62 |     als_candidates = pd.read_csv('als_candidates.csv', names=['userId', 'itemId'])
63 |     for user, df in als_candidates.groupby('userId'):
64 |         preds = df.head(k)['itemId'].values
65 |         ground_truth = gowalla_val[gowalla_val['userId'] == user]['itemId'].values
66 |         if len(ground_truth) > 0:
67 |             als_hitrates.append(metrics.user_hitrate(preds, ground_truth))
68 | 
69 |     topn_hitrates = []
70 |     topn_candidates = pd.read_csv('topn_candidates.csv', names=['userId', 'itemId'])
71 |     for user, df in topn_candidates.groupby('userId'):
72 |         preds = df.head(k)['itemId'].values
73 |         ground_truth = gowalla_val[gowalla_val['userId'] == user]['itemId'].values
74 |         if len(ground_truth) > 0:
75 |             topn_hitrates.append(metrics.user_hitrate(preds, ground_truth))
76 | 
77 |     print(f'als hitrate@{k}: {np.mean(als_hitrates)}')
78 |     print(f'topn hitrate@{k}: {np.mean(topn_hitrates)}')
79 |     print(f'catboost hitrate@{k}: {np.mean(catboost_hitrates)}')
--------------------------------------------------------------------------------
/prepare_catboost_dataset.py:
--------------------------------------------------------------------------------
1 | import utils
2 | import torch
3 | import implicit
4 | import pandas as pd
5 | from model import LightGCN, TopNModel, TopNPersonalized, TopNNearestModel
6 | from dataloader import GowallaLightGCNDataset, GowallaTopNDataset, GowallaALSDataset
7 | 
8 | 
9 | def get_als_recommendations(path):
10 |     gowalla_train, user_item_data, item_user_data = GowallaALSDataset(path) \
11 |         .get_dataset(5739, 56261)  # n_users and m_items from list.txt
12 | 
13 |     model = implicit.als.AlternatingLeastSquares(iterations=5, factors=64)
14 |     model.fit(item_user_data)
15 | 
16 |     users = []
17 |     items_pred = []
18 |     for user in range(5738 + 1):
19 |         preds = list(map(lambda x: x[0], model.recommend(user, user_item_data, 20)))
20 |         users.extend([user for _ in preds])
21 |         items_pred.extend(preds)
22 |     return pd.DataFrame({'userId': users, 'itemId': items_pred}) \
23 |         .drop_duplicates(['userId', 'itemId'])
24 | 
25 | 
26 | def get_topn_recommendations(path):
27 |     train_dataset = GowallaTopNDataset(path)
28 | 
29 |     model = TopNPersonalized(15)
30 |     model.fit(train_dataset)
31 | 
32 |     users = []
33 |     items_pred = []
34 |     for user in range(5738 + 1):
35 |         preds = list(model.recommend([user])[0])
36 |         users.extend([user for _ in preds])
37 |         items_pred.extend(preds)
38 |     return pd.DataFrame({'userId': users, 'itemId': items_pred}) \
39 |         .drop_duplicates(['userId', 'itemId'])
40 | 
41 | 
42 | def get_top_nearest_recommendations(train_path, locations_path):
43 |     train_dataset = GowallaTopNDataset(train_path)
44 |     df = pd.read_csv(locations_path, names=['loc_id', 'long', 'lat'])
45 |     calc_nearest = utils.calc_nearest(df)
46 |     model = TopNNearestModel(15, calc_nearest)
47 |     model.fit(train_dataset)
48 | 
49 |     users = []
50 |     items_pred = []
51 |     for user in range(5738 + 1):
52 |         preds = list(model.recommend([user])[0])
53 |         users.extend([user for _ in preds])
54 |         items_pred.extend(preds)
55 |     return pd.DataFrame({'userId': users, 'itemId': items_pred}) \
56 |         .drop_duplicates(['userId', 'itemId'])
57 | 
58 | 
59 | def get_lightgcn_recommendations(path):
60 |     train_dataset = GowallaLightGCNDataset(path)
61 |     model = LightGCN(train_dataset)
62 |     model.fit(100)
63 | 
64 |     users = []
65 |     items_pred = []
66 |     preds = list(model.recommend(torch.tensor([user for user in range(5738 + 1)])))
67 |     for user in range(5738 + 1):
68 |         users.extend([user for _ in preds[user]])
69 |         items_pred.extend(preds[user])
70 |     return pd.DataFrame({'userId': users, 'itemId': items_pred}) \
71 |         .drop_duplicates(['userId', 'itemId'])
72 | 
73 | 
74 | def get_recommendations(train_path, locations_path):
75 |     als_recommendations = get_als_recommendations(train_path)
76 |     topn_recommendations = get_topn_recommendations(train_path)
77 |     # top_nearest_recommendations = get_top_nearest_recommendations(train_path, locations_path)
78 |     lightgcn_recommendations = get_lightgcn_recommendations(train_path)
79 |     return pd.concat([als_recommendations, topn_recommendations, lightgcn_recommendations]) \
80 |         .drop_duplicates(['userId', 'itemId'])
81 | 
82 | 
83 | if __name__ == '__main__':
84 |     # make candidates for catboost training
85 |     get_recommendations('dataset/gowalla/gowalla.train', 'dataset/gowalla/gowalla.locations') \
86 |         .to_csv('catboost_train_dataset.csv', index=False, header=False)
87 | 
88 |     # make candidates for catboost eval
89 |     get_recommendations('dataset/gowalla/gowalla.traintest', 'dataset/gowalla/gowalla.locations') \
90 |         .to_csv('catboost_eval_dataset.csv', index=False, header=False)
91 | 
92 |     get_als_recommendations('dataset/gowalla/gowalla.traintest') \
93 |         .to_csv('als_candidates.csv', index=False, header=False)
94 |     get_topn_recommendations('dataset/gowalla/gowalla.traintest') \
95 |         .to_csv('topn_candidates.csv', index=False, header=False)
96 |     get_lightgcn_recommendations('dataset/gowalla/gowalla.traintest') \
97 |         .to_csv('lightgcn_candidates.csv', index=False, header=False)
--------------------------------------------------------------------------------
/prepare_brightkite_dataset.py:
--------------------------------------------------------------------------------
1 | import wget
2 | import pandas as pd
3 | from pathlib import Path
4 | from config import config
5 | from utils import print_progressbar
6 | 
7 | if __name__ == '__main__':
8 |     dataset_dir = Path('dataset')
9 |     if not dataset_dir.exists():
10 |         dataset_dir.mkdir(parents=True, exist_ok=True)
11 | 
12 |     dataset_path = dataset_dir / 'brightkite'
13 |     if not dataset_path.exists():
14 |         dataset_path.mkdir(parents=True, exist_ok=True)
15 |         # the edges and check-ins archives need distinct output paths
16 |         wget.download('https://snap.stanford.edu/data/loc-brightkite_edges.txt.gz',
17 |                       out=str(dataset_path / 'loc-brightkite_edges.txt.gz'), bar=print_progressbar)
18 |         wget.download('https://snap.stanford.edu/data/loc-brightkite_totalCheckins.txt.gz',
19 |                       out=str(dataset_path / 'loc-brightkite_totalCheckins.txt.gz'), bar=print_progressbar)
20 |     brightkite_dataset = pd.read_csv(
21 |         dataset_path / 'loc-brightkite_totalCheckins.txt.gz',
22 |         sep='\t', names=['userId', 'timestamp', 'long', 'lat', 'loc_id'])
23 |     brightkite_dataset['timestamp'] = pd.to_datetime(brightkite_dataset['timestamp']).dt.tz_localize(None)
24 | 
25 |     split_date = pd.to_datetime(config['SPLIT_DATE'])
26 |     start_date = brightkite_dataset['timestamp'].min() \
27 |         if 'TRAIN_DAYS' not in config \
28 |         else pd.to_datetime(split_date - pd.DateOffset(days=config['TRAIN_DAYS']))
29 |     end_test_date = split_date + pd.DateOffset(days=config['TEST_DAYS'])
30 |     end_date = end_test_date + pd.DateOffset(days=config['VAL_DAYS']) \
31 |         if 'VAL_DAYS' in config else end_test_date
32 | 
33 |     timestamp_filter = (brightkite_dataset['timestamp'] >= start_date) & (
34 |             brightkite_dataset['timestamp'] <= end_date)
35 |     brightkite_dataset = brightkite_dataset[timestamp_filter]
36 |     brightkite_dataset.sort_values('timestamp', inplace=True)
37 | 
38 |     new_user_ids = {k: v for v, k in enumerate(brightkite_dataset['userId'].unique())}
39 |     new_item_ids = {k: v for v, k in enumerate(brightkite_dataset['loc_id'].unique())}
40 | 
41 |     brightkite_dataset['userId'] = brightkite_dataset['userId'].map(new_user_ids)
42 |     brightkite_dataset['loc_id'] = brightkite_dataset['loc_id'].map(new_item_ids)
43 | 
44 |     with open(dataset_dir / 'user_list.txt', 'w') as f:
45 |         f.write('org_id remap_id\n')
46 |         for org_id, remap_id in new_user_ids.items():
47 |             f.write(f'{org_id} {remap_id}\n')
48 | 
49 |     print('user_list.txt saved')
50 | 
51 |     with open(dataset_dir / 'item_list.txt', 'w') as f:
52 |         f.write('org_id remap_id\n')
53 |         for org_id, remap_id in new_item_ids.items():
54 |             f.write(f'{org_id} {remap_id}\n')
55 | 
56 |     print('item_list.txt saved')
57 | 
58 |     train_filter = (brightkite_dataset['timestamp'] >= start_date) & (
59 |             brightkite_dataset['timestamp'] <= split_date)
60 |     brightkite_train = brightkite_dataset[train_filter]
61 | 
62 |     test_filter = (brightkite_dataset['timestamp'] > split_date) & (
63 |             brightkite_dataset['timestamp'] <= end_test_date)
64 |     brightkite_test = brightkite_dataset[test_filter]
65 | 
66 |     if 'VAL_DAYS' in config:
67 |         val_filter = (brightkite_dataset['timestamp'] > end_test_date) & (
68 |                 brightkite_dataset['timestamp'] <= end_date)
69 |         brightkite_val = brightkite_dataset[val_filter]
70 |         pd.concat([brightkite_train, brightkite_test]).to_csv(
71 |             dataset_path / 'brightkite.traintest', index=None, header=None)
72 |         brightkite_val.to_csv(dataset_path / 'brightkite.val', index=None, header=None)
73 | 
74 |     brightkite_train.to_csv(dataset_path / 'brightkite.train', index=None, header=None)
75 |     brightkite_test.to_csv(dataset_path / 'brightkite.test', index=None, header=None)
76 |     brightkite_dataset.loc[:, ['loc_id', 'long', 'lat']] \
77 |         .to_csv(dataset_path / 'brightkite.locations', index=None, header=None)
78 | 
79 |     print('dataset splits saved')
80 | 
81 |     unique_users = set(brightkite_dataset['userId'].unique())
82 |     brightkite_friendships = pd.read_csv(
83 |         dataset_path / 'loc-brightkite_edges.txt.gz', sep='\t', names=['user1', 'user2'])
84 |     brightkite_friendships[(brightkite_friendships['user1'].isin(unique_users)) &
85 |                            (brightkite_friendships['user2'].isin(unique_users))] \
86 |         .to_csv(dataset_path / 'brightkite.friends', index=None, header=None)
--------------------------------------------------------------------------------
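Both prepare scripts carve the check-in timeline into windows around `SPLIT_DATE`. With the values from `config.yaml` (`SPLIT_DATE: 2010-09-04`, `TEST_DAYS: 15`, `VAL_DAYS: 15`), the boundaries work out as in this illustrative sketch:

```python
import pandas as pd

split_date = pd.to_datetime('2010-09-04')
end_test_date = split_date + pd.DateOffset(days=15)  # 2010-09-19
end_date = end_test_date + pd.DateOffset(days=15)    # 2010-10-04

# train: [start_date, 2010-09-04]
# test:  (2010-09-04, 2010-09-19]
# val:   (2010-09-19, 2010-10-04]
```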
/prepare_gowalla_dataset.py:
--------------------------------------------------------------------------------
1 | import wget
2 | import pandas as pd
3 | from pathlib import Path
4 | from config import config
5 | from utils import print_progressbar
6 | 
7 | if __name__ == '__main__':
8 |     dataset_dir = Path('dataset')
9 |     if not dataset_dir.exists():
10 |         dataset_dir.mkdir(parents=True, exist_ok=True)
11 | 
12 |     dataset_path = dataset_dir / 'gowalla'
13 |     if not dataset_path.exists():
14 |         dataset_path.mkdir(parents=True, exist_ok=True)
15 |         wget.download(
16 |             'https://snap.stanford.edu/data/loc-gowalla_edges.txt.gz',
17 |             out=str(dataset_path / 'loc-gowalla_edges.txt.gz'), bar=print_progressbar)
18 |         wget.download(
19 |             'https://snap.stanford.edu/data/loc-gowalla_totalCheckins.txt.gz',
20 |             out=str(dataset_path / 'loc-gowalla_totalCheckins.txt.gz'), bar=print_progressbar)
21 | 
22 |     gowalla_dataset = pd.read_csv(
23 |         dataset_path / 'loc-gowalla_totalCheckins.txt.gz',
24 |         sep='\t', names=['userId', 'timestamp', 'long', 'lat', 'loc_id'])
25 |     gowalla_dataset['timestamp'] = pd.to_datetime(gowalla_dataset['timestamp']).dt.tz_localize(None)
26 | 
27 |     split_date = pd.to_datetime(config['SPLIT_DATE'])
28 |     start_date = gowalla_dataset['timestamp'].min() \
29 |         if 'TRAIN_DAYS' not in config \
30 |         else pd.to_datetime(split_date - pd.DateOffset(days=config['TRAIN_DAYS']))
31 |     end_test_date = split_date + pd.DateOffset(days=config['TEST_DAYS'])
32 |     end_date = end_test_date + pd.DateOffset(days=config['VAL_DAYS']) \
33 |         if 'VAL_DAYS' in config else end_test_date
34 | 
35 |     timestamp_filter = (gowalla_dataset['timestamp'] >= start_date) & (
36 |             gowalla_dataset['timestamp'] <= end_date)
37 |     gowalla_dataset = gowalla_dataset[timestamp_filter]
38 |     gowalla_dataset.sort_values('timestamp', inplace=True)
39 | 
40 |     new_user_ids = {k: v for v, k in enumerate(gowalla_dataset['userId'].unique())}
41 |     new_item_ids = {k: v for v, k in enumerate(gowalla_dataset['loc_id'].unique())}
42 | 
43 |     gowalla_dataset['userId'] = gowalla_dataset['userId'].map(new_user_ids)
44 |     gowalla_dataset['loc_id'] = gowalla_dataset['loc_id'].map(new_item_ids)
45 | 
46 |     with open(dataset_dir / 'user_list.txt', 'w') as f:
47 |         f.write('org_id remap_id\n')
48 |         for org_id, remap_id in new_user_ids.items():
49 |             f.write(f'{org_id} {remap_id}\n')
50 | 
51 |     print('user_list.txt saved')
52 | 
53 |     with open(dataset_dir / 'item_list.txt', 'w') as f:
54 |         f.write('org_id remap_id\n')
55 |         for org_id, remap_id in new_item_ids.items():
56 |             f.write(f'{org_id} {remap_id}\n')
57 | 
58 |     print('item_list.txt saved')
59 | 
60 |     train_filter = (gowalla_dataset['timestamp'] >= start_date) & (
61 |             gowalla_dataset['timestamp'] <= split_date)
62 |     gowalla_train = gowalla_dataset[train_filter]
63 | 
64 |     test_filter = (gowalla_dataset['timestamp'] > split_date) & (
65 |             gowalla_dataset['timestamp'] <= end_test_date)
66 |     gowalla_test = gowalla_dataset[test_filter]
67 | 
68 |     if 'VAL_DAYS' in config:
69 |         val_filter = (gowalla_dataset['timestamp'] > end_test_date) & (
70 |                 gowalla_dataset['timestamp'] <= end_date)
71 |         gowalla_val = gowalla_dataset[val_filter]
72 |         pd.concat([gowalla_train, gowalla_test]).to_csv(
73 |             dataset_path / 'gowalla.traintest', index=None, header=None)
74 |         gowalla_val.to_csv(dataset_path / 'gowalla.val', index=None, header=None)
75 | 
76 |     gowalla_train.to_csv(dataset_path / 'gowalla.train', index=None, header=None)
77 |     gowalla_test.to_csv(dataset_path / 'gowalla.test', index=None, header=None)
78 |     gowalla_dataset.loc[:, ['loc_id', 'long', 'lat']] \
79 |         .to_csv(dataset_path / 'gowalla.locations', index=None, header=None)
80 | 
81 |     print('dataset splits saved')
82 | 
83 |     unique_users = set(gowalla_dataset['userId'].unique())
84 |     gowalla_friendships = pd.read_csv(
85 |         dataset_path / 'loc-gowalla_edges.txt.gz', sep='\t', names=['user1', 'user2'])
86 |     gowalla_friendships[(gowalla_friendships['user1'].isin(unique_users)) &
87 |                         (gowalla_friendships['user2'].isin(unique_users))] \
88 |         .to_csv(dataset_path / 'gowalla.friends', index=None, header=None)
89 | 
90 |     print('dataset friendships saved')
--------------------------------------------------------------------------------
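A note on the script below: CatBoost ranking losses such as YetiRank expect each query group (here, a user) to occupy contiguous rows, which is why `fit_catboost.py` sorts by `userId` before building its `Pool`s. A minimal sketch of that grouping, with toy values:

```python
import pandas as pd
from catboost import Pool

df = pd.DataFrame({'userId': [1, 2, 1, 2],
                   'feature': [0.1, 0.2, 0.4, 0.9],
                   'target': [0, 0, 1, 1]})
df = df.sort_values('userId')  # groups must be contiguous
pool = Pool(df[['feature']], label=df['target'], group_id=df['userId'])
```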
/fit_catboost.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from collections import defaultdict
3 | from catboost import CatBoostClassifier, Pool, CatBoost
4 | from sklearn.model_selection import train_test_split
5 | 
6 | if __name__ == '__main__':
7 |     candidates_dataset = pd.read_csv('catboost_train_dataset.csv', names=['userId', 'itemId'])
8 | 
9 |     gowalla_train = pd.read_csv('dataset/gowalla/gowalla.train',
10 |                                 names=['userId', 'timestamp', 'long', 'lat', 'itemId'])
11 |     gowalla_test = pd.read_csv('dataset/gowalla/gowalla.test',
12 |                                names=['userId', 'timestamp', 'long', 'lat', 'itemId'])
13 |     gowalla_val = pd.read_csv('dataset/gowalla/gowalla.val',
14 |                               names=['userId', 'timestamp', 'long', 'lat', 'itemId'])
15 |     gowalla_dataset = pd.concat([gowalla_train, gowalla_test, gowalla_val])
16 | 
17 |     # add test candidates as positives for better training
18 |     candidates_dataset = pd.concat([candidates_dataset, gowalla_test.loc[:, ['userId', 'itemId']]])
19 | 
20 |     test_user_item_pairs = set(
21 |         map(lambda x: (x[0], x[1]), gowalla_test.loc[:, ['userId', 'itemId']].values))
22 | 
23 |     # collect geopositions of items
24 |     tmp = gowalla_dataset.drop_duplicates('itemId').set_index('itemId')
25 |     item_lat = tmp['lat'].to_dict()
26 |     item_long = tmp['long'].to_dict()
27 |     del tmp
28 | 
29 |     user_num_interactions = defaultdict(
30 |         lambda: 0, gowalla_train.groupby('userId')['itemId'].count().to_dict())
31 |     item_num_interactions = defaultdict(
32 |         lambda: 0, gowalla_train.groupby('itemId')['userId'].count().to_dict())
33 |     user_item_interactions = defaultdict(
34 |         lambda: 0, gowalla_train.groupby(['userId', 'itemId'])['timestamp'].count().to_dict())
35 |     candidates_user_item_interactions = list(
36 |         map(lambda x: user_item_interactions[(x[0], x[1])],
37 |             candidates_dataset.loc[:, ['userId', 'itemId']].values))
38 | 
39 |     gowalla_friendships = pd.read_csv('dataset/gowalla/gowalla.friends', header=None,
40 |                                       names=['user1', 'user2'])
41 |     user_num_friends = gowalla_friendships.groupby('user1')['user2'].count().to_dict()
42 | 
43 |     candidates_target_values = list(
44 |         map(lambda x: int((x[0], x[1]) in test_user_item_pairs),
45 |             candidates_dataset.loc[:, ['userId', 'itemId']].values))
46 | 
47 |     candidates_dataset['num_friends'] = candidates_dataset['userId'].map(user_num_friends)
48 |     candidates_dataset['user_num_interactions'] = \
49 |         candidates_dataset['userId'].map(user_num_interactions)
50 |     candidates_dataset['item_num_interactions'] = \
51 |         candidates_dataset['itemId'].map(item_num_interactions)
52 | 
candidates_dataset['user_item_interactions'] = candidates_user_item_interactions 53 | candidates_dataset['long'] = candidates_dataset['itemId'].map(item_long) 54 | candidates_dataset['lat'] = candidates_dataset['itemId'].map(item_lat) 55 | candidates_dataset['target'] = candidates_target_values 56 | candidates_dataset.drop_duplicates(['userId', 'itemId'], inplace=True) 57 | 58 | print(candidates_dataset['target'].value_counts()) 59 | 60 | train_df, eval_df = train_test_split(candidates_dataset, test_size=0.2) 61 | train_df.sort_values('userId', inplace=True) 62 | eval_df.sort_values('userId', inplace=True) 63 | 64 | train_pool = Pool( 65 | train_df.drop(['userId', 'itemId', 'target'], axis=1), 66 | train_df['target'], 67 | group_id=train_df['userId'] 68 | ) 69 | eval_pool = Pool( 70 | eval_df.drop(['userId', 'itemId', 'target'], axis=1), 71 | eval_df['target'], 72 | group_id=eval_df['userId'] 73 | ) 74 | 75 | # model = CatBoostClassifier( 76 | # iterations=2000, loss_function='Logloss', eval_metric='AUC', verbose=10) 77 | # model.fit(train_pool, eval_set=eval_pool) 78 | # model.save_model('catboost.cbm') 79 | 80 | # YetiRank for ranking works better 81 | parameters = { 82 | 'iterations': 500, 83 | 'loss_function': 'YetiRank', 84 | 'eval_metric': 'AUC', 85 | 'verbose': 10, 86 | } 87 | model = CatBoost(parameters) 88 | model.fit(train_pool, eval_set=eval_pool) 89 | model.save_model('catboost.cbm') 90 | 91 | train_df['preds'] = model.predict(train_pool) 92 | print(len(train_df[(train_df['preds'] == 0) & (train_df['target'] == 1)])) 93 | print(train_df) 94 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def user_hitrate(rank, ground_truth, k=20): 5 | """ 6 | :param rank: shape [n_recommended_items] 7 | :param ground_truth: shape [n_relevant_items] 8 | :param k: number of top recommended items 9 | :return: single hitrate 10 | """ 11 | return len(set(rank[:k]).intersection(set(ground_truth))) 12 | 13 | 14 | def hitrate(rank, ground_truth, k=20): 15 | """ 16 | :param rank: shape [n_users, n_recommended_items] 17 | :param ground_truth: shape [n_users, n_relevant_items] 18 | :param k: number of top recommended items 19 | :return: shape [n_users] 20 | """ 21 | return np.array([ 22 | user_hitrate(user_rank, user_ground_truth, k) 23 | for user_rank, user_ground_truth in zip(rank, ground_truth) 24 | ]) 25 | 26 | 27 | def user_precision(rank, ground_truth, k=20): 28 | """ 29 | :param rank: shape [n_recommended_items] 30 | :param ground_truth: shape [n_relevant_items] 31 | :param k: number of top recommended items 32 | :return: single precision 33 | """ 34 | return user_hitrate(rank, ground_truth, k) / len(rank[:k]) 35 | 36 | 37 | def precision(rank, ground_truth, k=20): 38 | """ 39 | :param rank: shape [n_users, n_recommended_items] 40 | :param ground_truth: shape [n_users, n_relevant_items] 41 | :param k: number of top recommended items 42 | :return: shape [n_users] 43 | """ 44 | return np.array([ 45 | user_precision(user_rank, user_ground_truth, k) 46 | for user_rank, user_ground_truth in zip(rank, ground_truth) 47 | ]) 48 | 49 | 50 | def user_recall(rank, ground_truth, k=20): 51 | """ 52 | :param rank: shape [n_recommended_items] 53 | :param ground_truth: shape [n_relevant_items] 54 | :param k: number of top recommended items 55 | :return: single recall 56 | """ 57 | return user_hitrate(rank, ground_truth, k) / len(set(ground_truth)) 58 | 
59 | 60 | def recall(rank, ground_truth, k=20): 61 | """ 62 | :param rank: shape [n_users, n_recommended_items] 63 | :param ground_truth: shape [n_users, n_relevant_items] 64 | :param k: number of top recommended items 65 | :return: shape [n_users] 66 | """ 67 | return np.array([ 68 | user_recall(user_rank, user_ground_truth, k) 69 | for user_rank, user_ground_truth in zip(rank, ground_truth) 70 | ]) 71 | 72 | 73 | def user_ap(rank, ground_truth, k=20): 74 | """ 75 | :param rank: shape [n_recommended_items] 76 | :param ground_truth: shape [n_relevant_items] 77 | :param k: number of top recommended items 78 | :return: single ap 79 | """ 80 | return np.sum([ 81 | user_precision(rank, ground_truth, idx + 1) 82 | for idx, item in enumerate(rank[:k]) if item in ground_truth 83 | ]) / len(rank[:k]) 84 | 85 | 86 | def ap(rank, ground_truth, k=20): 87 | """ 88 | :param rank: shape [n_users, n_recommended_items] 89 | :param ground_truth: shape [n_users, n_relevant_items] 90 | :param k: number of top recommended items 91 | :return: shape [n_users] 92 | """ 93 | return np.array([ 94 | user_ap(user_rank, user_ground_truth, k) 95 | for user_rank, user_ground_truth in zip(rank, ground_truth) 96 | ]) 97 | 98 | 99 | def map(rank, ground_truth, k=20): 100 | """ 101 | :param rank: shape [n_users, n_recommended_items] 102 | :param ground_truth: shape [n_users, n_relevant_items] 103 | :param k: number of top recommended items 104 | :return: single map 105 | """ 106 | return np.mean([ap(rank, ground_truth, k)]) 107 | 108 | 109 | def user_ndcg(rank, ground_truth, k=20): 110 | """ 111 | :param rank: shape [n_recommended_items] 112 | :param ground_truth: shape [n_relevant_items] 113 | :param k: number of top recommended items 114 | :return: single ndcg 115 | """ 116 | dcg = 0 117 | idcg = 0 118 | for idx, item in enumerate(rank[:k]): 119 | dcg += 1.0 / np.log2(idx + 2) if item in ground_truth else 0.0 120 | idcg += 1.0 / np.log2(idx + 2) 121 | return dcg / idcg 122 | 123 | 124 | def ndcg(rank, ground_truth, k=20): 125 | """ 126 | :param rank: shape [n_users, n_recommended_items] 127 | :param ground_truth: shape [n_users, n_relevant_items] 128 | :param k: number of top recommended items 129 | :return: shape [n_users] 130 | """ 131 | return np.array([ 132 | user_ndcg(user_rank, user_ground_truth, k) 133 | for user_rank, user_ground_truth in zip(rank, ground_truth) 134 | ]) 135 | 136 | 137 | def user_mrr(rank, ground_truth, k=20): 138 | """ 139 | :param rank: shape [n_recommended_items] 140 | :param ground_truth: shape [n_relevant_items] 141 | :param k: number of top recommended items 142 | :return: single mrr 143 | """ 144 | for idx, item in enumerate(rank[:k]): 145 | if item in ground_truth: 146 | return 1 / (idx + 1) 147 | return 0 148 | 149 | 150 | def mrr(rank, ground_truth, k=20): 151 | """ 152 | :param rank: shape [n_users, n_recommended_items] 153 | :param ground_truth: shape [n_users, n_relevant_items] 154 | :param k: number of top recommended items 155 | :return: shape [n_users] 156 | """ 157 | return np.array([ 158 | user_mrr(user_rank, user_ground_truth, k) 159 | for user_rank, user_ground_truth in zip(rank, ground_truth) 160 | ]) 161 | 162 | 163 | metric_dict = { 164 | 'Hitrate': hitrate, 165 | 'Precision': precision, 166 | 'Recall': recall, 167 | 'MAP': map, 168 | 'NDCG': ndcg, 169 | 'MRR': mrr} 170 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import numpy as np
3 | import pandas as pd
4 | import scipy.sparse as sp
5 | from torch.utils.data import Dataset
6 | 
7 | 
8 | class GowallaDataset(Dataset):
9 |     def __init__(self, train, path='dataset'):
10 |         print('init ' + ('train' if train else 'test') + ' dataset')
11 |         self.n_users_ = int(open(f'{path}/user_list.txt').readlines()[-1][:-1].split(' ')[1]) + 1
12 |         self.m_items_ = int(open(f'{path}/item_list.txt').readlines()[-1][:-1].split(' ')[1]) + 1
13 | 
14 |     def get_all_users(self):
15 |         raise NotImplementedError
16 | 
17 |     def get_user_positives(self, user):
18 |         raise NotImplementedError
19 | 
20 |     def get_user_negatives(self, user, k):
21 |         raise NotImplementedError
22 | 
23 |     @property
24 |     def n_users(self):
25 |         return self.n_users_
26 | 
27 |     @property
28 |     def m_items(self):
29 |         return self.m_items_
30 | 
31 |     def __len__(self):
32 |         return self.n_users_
33 | 
34 |     def __getitem__(self, idx):
35 |         raise NotImplementedError
36 | 
37 | 
38 | class GowallaTopNDataset(GowallaDataset):
39 |     def __init__(self, path, train=True):
40 |         super().__init__(train)
41 |         self.df = pd.read_csv(path, names=['userId', 'timestamp', 'long', 'lat', 'loc_id'])
42 | 
43 |         self.unique_users = self.df['userId'].unique()
44 |         self.user_positive_items = self.df.groupby('userId')['loc_id'].apply(list).to_dict()
45 | 
46 |     def get_all_users(self):
47 |         return self.unique_users
48 | 
49 |     def get_user_positives(self, user):
50 |         if user not in self.user_positive_items:
51 |             return []
52 |         return self.user_positive_items[user]
53 | 
54 |     def get_user_negatives(self, user, k=10):
55 |         neg = []
56 |         positives = set(self.get_user_positives(user))
57 |         while len(neg) < k:
58 |             candidate = np.random.randint(0, self.m_items)  # items are 0-indexed after remapping
59 |             if candidate not in positives:
60 |                 neg.append(candidate)
61 |         return neg
62 | 
63 | 
64 | class GowallaLightGCNDataset(GowallaDataset):
65 |     def __init__(self, path, train=True, n_negatives: int = 10):
66 |         super().__init__(train)
67 |         self.n_negatives = n_negatives
68 | 
69 |         # dataset = pd.read_csv(path, names=['userId', 'timestamp', 'long', 'lat', 'loc_id'])
70 |         dataset = pd.read_csv(path, names=['userId', 'loc_id'])
71 | 
72 |         dataset['feed'] = 1
73 |         users = dataset['userId']
74 |         items = dataset['loc_id']
75 |         feed = dataset['feed']
76 |         self.unique_users = users.unique()
77 |         self.user_positive_items = dataset.groupby('userId')['loc_id'].apply(list).to_dict()
78 |         del dataset
79 | 
80 |         n_nodes = self.n_users + self.m_items
81 | 
82 |         # build scipy sparse matrix
83 |         user_np = np.array(users.values, dtype=np.int32)
84 |         item_np = np.array(items.values, dtype=np.int32)
85 |         ratings = np.array(feed.values, dtype=np.int32)
86 | 
87 |         tmp_adj = sp.csr_matrix((ratings, (user_np, item_np + self.n_users)),
88 |                                 shape=(n_nodes, n_nodes))
89 |         adj_mat = tmp_adj + tmp_adj.T
90 | 
91 |         # normalize matrix
92 |         rowsum = np.array(adj_mat.sum(1))
93 |         d_inv = np.power(rowsum, -0.5).flatten()
94 |         d_inv[np.isinf(d_inv)] = 0.
95 |         d_mat_inv = sp.diags(d_inv)
96 | 
97 |         # normalize by user counts
98 |         norm_adj_tmp = d_mat_inv.dot(adj_mat)
99 |         # normalize by item counts
100 |         normalized_adj_matrix = norm_adj_tmp.dot(d_mat_inv)
101 | 
102 |         # convert to torch sparse matrix
103 |         adj_mat_coo = normalized_adj_matrix.tocoo()
104 | 
105 |         values = adj_mat_coo.data
106 |         indices = np.vstack((adj_mat_coo.row, adj_mat_coo.col))
107 | 
108 |         i = torch.LongTensor(indices)
109 |         v = torch.FloatTensor(values)
110 |         shape = adj_mat_coo.shape
111 | 
112 |         self.adj_matrix = torch.sparse_coo_tensor(i, v, torch.Size(shape))
113 | 
114 |     def get_all_users(self):
115 |         return self.unique_users
116 | 
117 |     def get_user_positives(self, user):
118 |         if user not in self.user_positive_items:
119 |             return []
120 |         return self.user_positive_items[user]
121 | 
122 |     def get_user_negatives(self, user, k=10):
123 |         neg = []
124 |         positives = set(self.get_user_positives(user))
125 |         while len(neg) < k:
126 |             candidate = np.random.randint(0, self.m_items)  # items are 0-indexed after remapping
127 |             if candidate not in positives:
128 |                 neg.append(candidate)
129 |         return neg
130 | 
131 |     def get_sparse_graph(self):
132 |         """
133 |         Returns the graph as a torch.sparse_coo_tensor.
134 |         A = |0,   R|
135 |             |R^T, 0|
136 |         """
137 |         return self.adj_matrix
138 | 
139 |     def __len__(self):
140 |         return len(self.unique_users)
141 | 
142 |     def __getitem__(self, idx):
143 |         """
144 |         Returns user, pos_items, neg_items.
145 | 
146 |         :param idx: index of user from unique_users
147 |         :return:
148 |         """
149 |         user = self.unique_users[idx]
150 |         pos = np.random.choice(self.get_user_positives(user), self.n_negatives)
151 |         neg = self.get_user_negatives(user, self.n_negatives)
152 |         return user, pos, neg
153 | 
154 | 
155 | class GowallaALSDataset(GowallaDataset):
156 |     def __init__(self, path, train=True):
157 |         super().__init__(train)
158 |         self.path = path
159 |         self.train = train
160 |         self.df = pd.read_csv(path, names=['userId', 'timestamp', 'long', 'lat', 'loc_id'])
161 | 
162 |     def get_dataset(self, n_users=None, m_items=None):
163 |         if self.train:
164 |             users = self.df['userId'].values
165 |             items = self.df['loc_id'].values
166 |             ratings = np.ones(len(users))
167 | 
168 |             n_users = self.n_users if n_users is None else n_users
169 |             m_items = self.m_items if m_items is None else m_items
170 |             user_item_data = sp.csr_matrix((ratings, (users, items)),
171 |                                            shape=(n_users, m_items))
172 |             item_user_data = user_item_data.T.tocsr()
173 |             return self.df, user_item_data, item_user_data
174 |         else:
175 |             return self.df
--------------------------------------------------------------------------------
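A toy check of the adjacency construction in `GowallaLightGCNDataset` (hypothetical two-user, two-item interactions), showing the symmetric normalization `D^-1/2 * A * D^-1/2` that LightGCN propagates over:

```python
import numpy as np
import scipy.sparse as sp

n_users, m_items = 2, 2
users = np.array([0, 0, 1])  # user 0 visited items 0 and 1, user 1 visited item 1
items = np.array([0, 1, 1])

n_nodes = n_users + m_items
tmp = sp.csr_matrix((np.ones(3), (users, items + n_users)), shape=(n_nodes, n_nodes))
adj = tmp + tmp.T  # A = [[0, R], [R^T, 0]]

rowsum = np.array(adj.sum(1)).flatten()
d_inv = np.power(rowsum, -0.5)  # every node has degree >= 1 in this toy graph
norm_adj = sp.diags(d_inv).dot(adj).dot(sp.diags(d_inv))
print(norm_adj.toarray())  # entry (u, i) equals 1 / sqrt(deg(u) * deg(i))
```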
/model.py:
--------------------------------------------------------------------------------
1 | import faiss
2 | import torch
3 | import utils
4 | import metrics
5 | import torch.nn as nn
6 | from tqdm import tqdm
7 | from loguru import logger
8 | from collections import defaultdict
9 | from torch.utils.data import DataLoader
10 | from config import config, tensorboard_writer
11 | from dataloader import GowallaLightGCNDataset, GowallaTopNDataset
12 | 
13 | 
14 | class TopNModel:
15 |     def __init__(self, top_n):
16 |         self.top_n = top_n
17 |         self.top_items = []
18 | 
19 |     def fit(self, dataset: GowallaTopNDataset):
20 |         item_counts = dataset.df.groupby('loc_id')['userId'].count().reset_index(name='count')
21 |         self.top_items = item_counts.sort_values('count', ascending=False).head(self.top_n)[
22 |             'loc_id'].values
23 | 
24 |     def recommend(self, users: list, k: int = 20):
25 |         return [self.top_items[:k] for _ in users]
26 | 
27 |     def eval(self, test_dataset: GowallaTopNDataset):
28 |         users = []
29 |         ground_truth = []
30 | 
31 |         for user in test_dataset.get_all_users():
32 |             user_positive_items = test_dataset.get_user_positives(user)
33 |             if user_positive_items:
34 |                 users.append(user)
35 |                 ground_truth.append(user_positive_items)
36 | 
37 |         preds = self.recommend(users)
38 |         max_length = max(map(len, metrics.metric_dict.keys())) + max(
39 |             map(lambda x: len(str(x)), config['METRICS_REPORT']))
40 |         for metric_name, metric_func in metrics.metric_dict.items():
41 |             for k in config['METRICS_REPORT']:
42 |                 metric_name_total = f'{metric_name}@{k}'
43 |                 metric_value = metric_func(preds, ground_truth, k).mean()
44 |                 logger.info(f'{metric_name_total: >{max_length + 1}} = {metric_value}')
45 | 
46 | 
47 | class TopNPersonalized:
48 |     def __init__(self, top_n):
49 |         self.top_n = top_n
50 |         self.top_items = []
51 |         self.top_user_items = defaultdict(lambda: self.top_items)
52 | 
53 |     def fit(self, dataset: GowallaTopNDataset):
54 |         item_counts = dataset.df.groupby('loc_id')['userId'].count().reset_index(name='count')
55 |         self.top_items = item_counts.sort_values('count', ascending=False) \
56 |             .head(self.top_n)['loc_id'].values
57 |         for user_id, df in dataset.df.groupby('userId'):
58 |             item_counts = df.groupby('loc_id')['timestamp'].count().reset_index(name='count')
59 |             self.top_user_items[user_id] = item_counts.sort_values('count', ascending=False) \
60 |                 .head(self.top_n)['loc_id'].values
61 | 
62 |     def recommend(self, users: list, k: int = 20):
63 |         return [self.top_user_items[user][:k] for user in users]
64 | 
65 |     def eval(self, test_dataset: GowallaTopNDataset):
66 |         users = []
67 |         ground_truth = []
68 | 
69 |         for user in test_dataset.get_all_users():
70 |             user_positive_items = test_dataset.get_user_positives(user)
71 |             if user_positive_items:
72 |                 users.append(user)
73 |                 ground_truth.append(user_positive_items)
74 | 
75 |         preds = self.recommend(users)
76 |         max_length = max(map(len, metrics.metric_dict.keys())) + max(
77 |             map(lambda x: len(str(x)), config['METRICS_REPORT']))
78 |         for metric_name, metric_func in metrics.metric_dict.items():
79 |             for k in config['METRICS_REPORT']:
80 |                 metric_name_total = f'{metric_name}@{k}'
81 |                 metric_value = metric_func(preds, ground_truth, k).mean()
82 |                 logger.info(f'{metric_name_total: >{max_length + 1}} = {metric_value}')
83 | 
84 | 
85 | class TopNNearestModel:
86 |     def __init__(self, top_n, calc_nearest):
87 |         self.top_n = top_n
88 |         self.top_items = []
89 |         self.calc_nearest = calc_nearest
90 |         self.user_last_item = defaultdict(lambda: -1)
91 |         self.top_nearest_items = defaultdict(lambda: self.top_items)
92 | 
93 |     def fit(self, dataset: GowallaTopNDataset):
94 |         item_counts = dataset.df.groupby('loc_id')['userId'].count().reset_index(name='count')
95 |         self.top_items = item_counts.sort_values('count', ascending=False) \
96 |             .head(self.top_n)['loc_id'].values
97 |         for user, item in dataset.df.sort_values('timestamp') \
98 |                 .drop_duplicates('userId', keep='last').set_index('userId')['loc_id'] \
99 |                 .to_dict().items():
100 |             self.user_last_item[user] = item
101 |         for user, last_item in tqdm(self.user_last_item.items()):
102 |             self.top_nearest_items[user] = self.calc_nearest(last_item, self.top_n)
103 | 
104 |     def recommend(self, users: list, k: int = 20):
105 |         return [self.top_nearest_items[user][:k] for user in users]  # top_nearest_items is keyed by user
106 | 
107 |     def eval(self, test_dataset: GowallaTopNDataset):
108 |         users = []
109 |         ground_truth = []
110 | 
111 |         for user in test_dataset.get_all_users():
112 |             user_positive_items = test_dataset.get_user_positives(user)
113 |             if user_positive_items:
114 |                 users.append(user)
115 |                 ground_truth.append(user_positive_items)
116 | 
117 |         preds = self.recommend(users)
118 |         max_length = max(map(len, metrics.metric_dict.keys())) + max(
119 |             map(lambda x: len(str(x)), config['METRICS_REPORT']))
120 |         for metric_name, metric_func in metrics.metric_dict.items():
121 |             for k in config['METRICS_REPORT']:
122 |                 metric_name_total = f'{metric_name}@{k}'
123 |                 metric_value = metric_func(preds, ground_truth, k).mean()
124 |                 logger.info(f'{metric_name_total: >{max_length + 1}} = {metric_value}')
125 | 
126 | 
127 | class LightGCN(nn.Module):
128 |     def __init__(self, dataset: GowallaLightGCNDataset):
129 |         """
130 |         :param dataset: dataset derived from GowallaDataset
131 |         """
132 |         super(LightGCN, self).__init__()
133 |         self.dataset: GowallaLightGCNDataset = dataset
134 |         self.num_users = self.dataset.n_users
135 |         self.num_items = self.dataset.m_items
136 |         self.latent_dim = config['LATENT_DIM']
137 |         self.n_layers = config['N_LAYERS']
138 |         self.__init_weight()
139 | 
140 |     def __init_weight(self):
141 |         """
142 |         Initialize embeddings with a normal distribution
143 |         :return:
144 |         """
145 |         self.embedding_user = torch.nn.Embedding(
146 |             num_embeddings=self.num_users, embedding_dim=self.latent_dim)
147 |         self.embedding_item = torch.nn.Embedding(
148 |             num_embeddings=self.num_items, embedding_dim=self.latent_dim)
149 | 
150 |         if 'PRETRAIN' not in config or not config['PRETRAIN']:
151 |             nn.init.normal_(self.embedding_user.weight, std=0.1)
152 |             nn.init.normal_(self.embedding_item.weight, std=0.1)
153 |         else:
154 |             self.embedding_user.weight.data.copy_(torch.from_numpy(config['USER_EMB_FILE']))
155 |             self.embedding_item.weight.data.copy_(torch.from_numpy(config['ITEM_EMB_FILE']))
156 |             print('use pretrained embeddings')
157 | 
158 |         self.Graph = self.dataset.get_sparse_graph()
159 |         print('LightGCN is ready to go')
160 | 
161 |     def computer(self) -> tuple:
162 |         """
163 |         Propagate embeddings over the graph for n_layers hops (LightGCN propagation)
164 |         :return: user embeddings, item embeddings
165 |         """
166 |         users_emb = self.embedding_user.weight
167 |         items_emb = self.embedding_item.weight
168 |         all_emb = torch.cat([users_emb, items_emb])
169 | 
170 |         layer_embeddings = [all_emb]
171 |         for _ in range(self.n_layers):
172 |             all_emb = torch.sparse.mm(self.Graph, all_emb)
173 |             layer_embeddings.append(all_emb)
174 |         layer_embeddings = torch.stack(layer_embeddings, dim=1)
175 | 
176 |         final_embeddings = layer_embeddings.mean(dim=1)  # output is the mean of all layers
177 |         users, items = torch.split(final_embeddings, [self.num_users, self.num_items])
178 |         return users, items
179 | 
180 |     def get_users_rating(self, users: torch.tensor) -> torch.tensor:
181 |         """
182 |         Compute item ratings for the given users
183 |         :param users: user ids for which to compute ratings
184 |         :return:
185 |         """
186 |         all_users, all_items = self.computer()
187 |         users_emb = all_users[users.long()]
188 |         items_emb = all_items
189 |         rating = torch.matmul(users_emb, items_emb.t())
190 |         return rating
191 | 
192 |     def get_embedding(self, users: torch.tensor, pos_items: torch.tensor,
193 |                       neg_items: torch.tensor) -> tuple:
194 |         all_users, all_items = self.computer()
195 |         users_emb = all_users[users]
196 |         pos_emb = all_items[pos_items]
197 |         neg_emb = all_items[neg_items]
198 |         users_emb_ego = self.embedding_user(users)
199 |         pos_emb_ego = self.embedding_item(pos_items)
200 |         neg_emb_ego = self.embedding_item(neg_items)
201 |         return users_emb, pos_emb, neg_emb, users_emb_ego, pos_emb_ego, neg_emb_ego
202 | 
203 |     def bpr_loss(self, users: torch.tensor, pos: torch.tensor, neg: torch.tensor) -> tuple:
204 |         """
205 |         Calculate BPR loss as - sum ln(sigma(pos_scores - neg_scores)) + L2 norm
206 |         :param users: users for which to calculate the loss
207 |         :param pos: positive items
208 |         :param neg: negative items
209 |         :return: loss, reg_loss
210 |         """
211 |         (users_emb, pos_emb, neg_emb,
212 |          userEmb0, posEmb0, negEmb0) = self.get_embedding(users.long(), pos.long(), neg.long())
213 |         reg_loss = (1 / 2) * (userEmb0.norm(2).pow(2) +
214 |                               posEmb0.norm(2).pow(2) +
215 |                               negEmb0.norm(2).pow(2)) / float(len(users))
216 |         pos_scores = torch.mul(users_emb, pos_emb)
217 |         pos_scores = torch.sum(pos_scores, dim=1)
218 |         neg_scores = torch.mul(users_emb, neg_emb)
219 |         neg_scores = torch.sum(neg_scores, dim=1)
220 | 
221 |         loss = - (pos_scores - neg_scores).sigmoid().log().mean()
222 |         return loss, reg_loss
223 | 
224 |     def forward(self, users: torch.tensor, items: torch.tensor):
225 |         # compute embedding
226 |         all_users, all_items = self.computer()
227 | 
228 |         users_emb = all_users[users]
229 |         items_emb = all_items[items]
230 |         inner_prod = torch.mul(users_emb, items_emb)
231 |         return torch.sum(inner_prod, dim=1).sigmoid()
232 | 
233 |     def fit(self, n_epochs: int = 10, dataset: GowallaLightGCNDataset = None):
234 |         """
235 |         Fit the model with BPR loss.
236 |         :param n_epochs: number of epochs to fit the model
237 |         :param dataset: dataset for periodic evaluation (training uses self.dataset)
238 |         :return:
239 |         """
240 |         optimizer = torch.optim.Adam(self.parameters())
241 |         pbar = tqdm(range(n_epochs))
242 |         dataloader = DataLoader(
243 |             self.dataset, batch_size=config['BATCH_SIZE'],
244 |             shuffle=True, collate_fn=utils.collate_function)
245 |         for epoch in pbar:
246 |             for users, pos, neg in dataloader:
247 |                 optimizer.zero_grad()
248 |                 loss, reg_loss = self.bpr_loss(users, pos, neg)
249 |                 total_loss = loss + config['BPR_REG_ALPHA'] * reg_loss
250 | 
251 |                 total_loss.backward()
252 |                 optimizer.step()
253 | 
254 |                 if tensorboard_writer:
255 |                     tensorboard_writer.add_scalar('Train/bpr_loss', loss.item())
256 |                     tensorboard_writer.add_scalar('Train/bpr_reg_loss', reg_loss.item())
257 |                     tensorboard_writer.add_scalar('Train/bpr_total_loss', total_loss.item())
258 |                 pbar.set_postfix({'bpr_total_loss': total_loss.item()})
259 | 
260 |             if dataset and config['EVAL_EPOCHS'] > 0 and epoch % config['EVAL_EPOCHS'] == 0:
261 |                 self.eval(dataset)
262 | 
263 |     @torch.no_grad()
264 |     def recommend(self, users: torch.tensor, k: int = 20):
265 |         all_users, all_items = self.computer()
266 |         users_emb = all_users[users.long()].numpy()
267 |         items_emb = all_items.numpy()
268 | 
269 |         index = faiss.IndexFlatIP(self.latent_dim)
270 |         index.add(items_emb)
271 |         return index.search(users_emb, k)[1]
272 | 
273 |     @torch.no_grad()
274 |     def eval(self, test_dataset: GowallaLightGCNDataset):
275 |         users = []
276 |         ground_truth = []
277 | 
278 |         for user in test_dataset.get_all_users():
279 |             user_positive_items = test_dataset.get_user_positives(user)
280 |             if user_positive_items:
281 |                 users.append(user)
282 |                 ground_truth.append(user_positive_items)
283 | 
284 |         preds = self.recommend(torch.tensor(users), max(config['METRICS_REPORT']))
285 |         max_length = max(map(len, metrics.metric_dict.keys())) + max(
286 |             map(lambda x: len(str(x)), config['METRICS_REPORT']))
287 |         for metric_name, metric_func in metrics.metric_dict.items():
288 |             for k in config['METRICS_REPORT']:
289 |                 metric_name_total = f'{metric_name}@{k}'
290 | 
metric_value = metric_func(preds, ground_truth, k).mean() 291 | logger.info(f'{metric_name_total: >{max_length + 1}} = {metric_value}') 292 | if tensorboard_writer: 293 | tensorboard_writer.add_scalar(f'Eval/{metric_name_total}', metric_value) 294 | --------------------------------------------------------------------------------