├── data
│   └── read_me
├── images
│   ├── zenodo.png
│   ├── homepage.png
│   └── results.png
├── options.py
├── model.py
├── data_generation.py
├── main.py
├── clustering.py
├── model_training.py
├── README.md
└── model_evaluation.py
/data/read_me: -------------------------------------------------------------------------------- 1 | Download the data from https://zenodo.org/record/5121674#.YQuiitMzaIZ 2 | --------------------------------------------------------------------------------
/images/zenodo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deezer/semi_perso_user_cold_start/HEAD/images/zenodo.png --------------------------------------------------------------------------------
/images/homepage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deezer/semi_perso_user_cold_start/HEAD/images/homepage.png --------------------------------------------------------------------------------
/images/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deezer/semi_perso_user_cold_start/HEAD/images/results.png --------------------------------------------------------------------------------
/options.py: -------------------------------------------------------------------------------- 1 | config = { 2 | # song 3 | 'nb_songs': 50000, 4 | # user 5 | 'embeddings_version': "svd", # can be "svd" for TT-SVD, or "mf" for UT-ALS in the Deezer data case 6 | 'embeddings_dim': 128, # 128 for TT-SVD, or 256 for UT-ALS 7 | # cuda setting 8 | 'use_cuda': True, 9 | 'device_number': 0, 10 | # model setting 11 | 'input_dim': 2579, # 2579 & 5139 for TT-SVD and UT-ALS train features respectively 12 | 'nb_epochs': 130, 13 | 'learning_rate': 0.00001, 14 | 'batch_size': 512, 15 | 'reg_param': 0, 16 | 'drop_out': 0, 17 | # model training 18 | 'eval_every': 10, 19 | 'k_val': 50, 20 | # clustering for semi-personalization strategy 21 | 'nb_clusters': 1000, 22 | 'max_iter': 20, 23 | 'random_state': 0, 24 | # clustering for input features baseline 25 | 'nb_clusters_inputfeatures': 100, 26 | # the max_iter and random_state values above are also used for this clustering 27 | 28 | # model evaluation 29 | 'k_val_list': [50], 30 | 'nb_iterations_eval_stddev': 2, 31 | 'nb_sub_iterations_eval_stddev': 5, 32 | 'indic_eval_evolution': 1000, 33 | } 34 | 35 | dataset_eval = ["validation", "test"] 36 | --------------------------------------------------------------------------------
/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn 4 | 5 | class RegressionTripleHidden(torch.nn.Module): 6 | def __init__(self, input_dim, output_dim, first_hidden_dim = 400, second_hidden_dim = 300, third_hidden_dim = 200, drop_out = 0): 7 | super(RegressionTripleHidden, self).__init__() 8 | self.input_dim = input_dim 9 | self.first_hidden_dim = first_hidden_dim 10 | self.second_hidden_dim = second_hidden_dim 11 | self.third_hidden_dim = third_hidden_dim 12 | self.output_dim = output_dim 13 | self.dpin = torch.nn.Dropout(drop_out) 14 | 15 | self.fc1 = torch.nn.Linear(self.input_dim, self.first_hidden_dim) 16 | self.fc1_bn = torch.nn.BatchNorm1d(self.first_hidden_dim) 17 | 18 | self.fc2 = torch.nn.Linear(self.first_hidden_dim, self.second_hidden_dim) 19 | self.fc2_bn = torch.nn.BatchNorm1d(self.second_hidden_dim) 20 | 21 | self.fc3 = torch.nn.Linear(self.second_hidden_dim, self.third_hidden_dim) 22 | 
self.fc3_bn = torch.nn.BatchNorm1d(self.third_hidden_dim) 23 | 24 | self.fc4 = torch.nn.Linear(self.third_hidden_dim, self.output_dim) 25 | 26 | def forward(self, x): 27 | hidden1 = self.fc1_bn(F.relu((self.fc1(self.dpin(x))))) 28 | hidden2 = self.fc2_bn(F.relu(self.fc2(hidden1))) 29 | hidden3 = self.fc3_bn(F.relu(self.fc3(hidden2))) 30 | output = F.normalize(self.fc4(hidden3), dim = 1) 31 | return output 32 | -------------------------------------------------------------------------------- /data_generation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import torch 4 | import torch.nn 5 | import pickle 6 | from options import dataset_eval 7 | 8 | def generate(dataset_path, master_path, embeddings_version): 9 | 10 | #songs 11 | 12 | print("--- songs data ---") 13 | 14 | song_embeddings_path = dataset_path + "/song_embeddings.parquet" 15 | song_embeddings = pd.read_parquet(song_embeddings_path, engine = 'fastparquet') 16 | 17 | if not os.path.exists(master_path+"/m_song_dict.pkl"): 18 | song_dict = {} 19 | for idx, row in song_embeddings.iterrows(): 20 | song_dict[row['song_index']] = idx 21 | pickle.dump(song_dict, open("{}/m_song_dict.pkl".format(master_path), "wb")) 22 | else: 23 | song_dict = pickle.load(open("{}/m_song_dict.pkl".format(master_path), "rb")) 24 | 25 | 26 | # user embeddings (target = only for train users) 27 | 28 | print("--- user embeddings - training dataset ---") 29 | 30 | user_embeddings = pd.read_parquet(dataset_path + "/user_embeddings.parquet", engine = 'fastparquet') 31 | list_embeddings = ["embedding_"+str(i) for i in range(len(user_embeddings[embeddings_version + "_embeddings"][0]))] 32 | user_embeddings[list_embeddings] = pd.DataFrame(user_embeddings[embeddings_version + "_embeddings"].tolist(), index= user_embeddings.index) 33 | 34 | # user features train 35 | 36 | print("--- user features - training dataset ---") 37 | 38 | features_train_path = dataset_path + "/user_features_train_" + embeddings_version + ".parquet" 39 | features_train = pd.read_parquet(features_train_path, engine = 'fastparquet').fillna(0) 40 | features_train = features_train.sort_values("user_index") 41 | features_train = features_train.reset_index(drop=True)#to check it is ok for train data 42 | 43 | # training dataset creation 44 | 45 | dataset = "train" 46 | if not os.path.exists(master_path+"/"): 47 | os.mkdir(master_path+"/") 48 | if not os.path.exists(master_path+"/"+embeddings_version+"/"): 49 | os.mkdir(master_path+"/"+embeddings_version+"/") 50 | if not os.path.exists(master_path+"/"+embeddings_version+"/"+dataset+"/"): 51 | os.mkdir(master_path+"/"+embeddings_version+"/"+dataset+"/") 52 | for idx in range(len(features_train)): 53 | x_train = torch.FloatTensor(features_train.iloc[idx,2:]) 54 | y_train = torch.FloatTensor(user_embeddings[list_embeddings].iloc[idx,:]) 55 | pickle.dump(x_train, open("{}/{}/{}/x_train_{}.pkl".format(master_path, embeddings_version, dataset, idx), "wb")) 56 | pickle.dump(y_train, open("{}/{}/{}/y_train_{}.pkl".format(master_path, embeddings_version, dataset, idx), "wb")) 57 | 58 | # user features validation & test 59 | 60 | print("--- user features - evaluation datasets ---") 61 | 62 | for dataset in dataset_eval : 63 | 64 | print("--- "+dataset+" ---") 65 | 66 | features_validation_path = dataset_path + "/user_features_" + dataset + "_" + embeddings_version + ".parquet" 67 | features_validation = pd.read_parquet(features_validation_path, engine = 'fastparquet').fillna(0) 68 
| features_validation = features_validation.sort_values("user_index") 69 | features_validation = features_validation.reset_index(drop=True) 70 | 71 | # validation & test dataset creation 72 | 73 | if not os.path.exists(master_path+"/"+embeddings_version+"/"): 74 | os.mkdir(master_path+"/"+embeddings_version+"/") 75 | if not os.path.exists(master_path+"/"+embeddings_version+"/"+dataset+"/"): 76 | os.mkdir(master_path+"/"+embeddings_version+"/"+dataset+"/") 77 | for i in range(len(features_validation)): 78 | x_validation = torch.FloatTensor(features_validation.iloc[i,2:]) 79 | y_validation = [song_dict[song_index] for song_index in features_validation["d1d30_songs"][i]] 80 | groundtruth_validation_list = [1.0 * (song in y_validation) for song in range(len(song_embeddings))] 81 | pickle.dump(x_validation, open("{}/{}/{}/x_{}.pkl".format(master_path, embeddings_version, dataset, i), "wb")) 82 | pickle.dump(y_validation, open("{}/{}/{}/y_listened_songs_{}.pkl".format(master_path, embeddings_version, dataset, i), "wb")) 83 | pickle.dump(groundtruth_validation_list, open("{}/{}/{}/groundtruth_list_{}.pkl".format(master_path, embeddings_version, dataset, i), "wb")) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data_generation import generate 3 | from model_training import training 4 | from model_evaluation import evaluation 5 | from clustering import train_kmeans, train_inputfeatureskmeans 6 | from options import config 7 | import time 8 | import pandas as pd 9 | import torch 10 | 11 | if __name__ == "__main__": 12 | 13 | try: 14 | pd.io.parquet.get_engine('fastparquet') 15 | except ImportError as e: 16 | print("please install fastparquet first") 17 | 18 | if config['use_cuda'] and not torch.cuda.is_available(): 19 | print("please make cuda gpu available, or set use_cuda=False") 20 | assert torch.cuda.is_available() 21 | 22 | master_path= "./deezer" 23 | dataset_path = os.getcwd() + "/data" 24 | embeddings_version = config["embeddings_version"] 25 | model_filename = "regression_model_" + embeddings_version 26 | clustering_path = "clustering_" + embeddings_version 27 | clusters_filename = "clustering_model" + embeddings_version 28 | inputfeaturesclustering_path = "inputfeaturesclustering_" + embeddings_version 29 | inputfeaturesclusters_filename = "inputfeaturesclustering_model" + embeddings_version 30 | print("--- running for embeddings version " + embeddings_version + " ---") 31 | 32 | if not os.path.exists("{}/".format(master_path)): 33 | os.mkdir("{}/".format(master_path)) 34 | if not os.path.exists(master_path + "/" + embeddings_version + "/"): 35 | print("--- the data has not been generated yet for the embeddings version " + embeddings_version + " : generation running ---") 36 | os.mkdir(master_path + "/" + embeddings_version + "/") 37 | # preparing dataset 38 | print("--- data generation ---") 39 | start_time_data_generation = time.time() 40 | generate(dataset_path, master_path, config['embeddings_version']) 41 | print("--- data generation done ---") 42 | print("--- seconds ---" + str(time.time() - start_time_data_generation)) 43 | else: 44 | print("--- the data has already been generated : no need to regenerate it ---") 45 | 46 | # training model. 
47 | print("--- training prediction model ---") 48 | start_time_prediction_model = time.time() 49 | training(dataset_path, master_path, embeddings_version = embeddings_version, eval = True, model_save = True, model_filename = model_filename) 50 | print("--- training prediction model done ---") 51 | print("--- seconds ---" + str(time.time() - start_time_prediction_model)) 52 | 53 | 54 | # evaluation of the model - full personalization strategy. 55 | print("--- full personalisation evaluation ---") 56 | start_time_fullperso_eval = time.time() 57 | evaluation(dataset_path, master_path, eval_type = "full_perso", embeddings_version = embeddings_version, model_filename = model_filename) 58 | print("--- full personalisation evaluation done ---") 59 | print("--- seconds ---" + str(time.time() - start_time_fullperso_eval)) 60 | 61 | # evaluation of the model - semi personalization strategy. 62 | if not os.path.exists("{}/{}/".format(master_path, clustering_path)): 63 | os.mkdir("{}/{}/".format(master_path, clustering_path)) 64 | if not os.path.exists(master_path + "/" + clustering_path + "/" + clusters_filename): 65 | print("--- clustering running ---") 66 | start_time_clustering = time.time() 67 | train_kmeans(dataset_path, master_path, clustering_path, config['nb_clusters'], config['max_iter'], config['random_state'], embeddings_version = embeddings_version, clusters_filename = clusters_filename) 68 | print("--- clustering done ---") 69 | print("--- seconds ---" + str(time.time() - start_time_clustering)) 70 | else: 71 | print("--- no need to do the clustering again ---") 72 | 73 | print("--- semi personalisation evaluation ---") 74 | start_time_semiperso_eval = time.time() 75 | evaluation(dataset_path, master_path, eval_type = "semi_perso", embeddings_version = embeddings_version, model_filename = model_filename, clustering_path=clustering_path, clusters_filename = clusters_filename, nb_clusters=config["nb_clusters"]) 76 | print("--- semi personalisation evaluation done ---") 77 | print("--- seconds ---" + str(time.time() - start_time_semiperso_eval)) 78 | 79 | # popularity baseline. 80 | print("--- popularity baseline evaluation ---") 81 | start_time_popbaseline_eval = time.time() 82 | evaluation(dataset_path, master_path, eval_type = "popularity", embeddings_version = embeddings_version, model_filename = model_filename) 83 | print("--- popularity baseline evaluation done ---") 84 | print("--- seconds ---" + str(time.time() - start_time_popbaseline_eval)) 85 | 86 | # avg d0 stream baseline. 87 | print("--- avg d0 stream baseline evaluation ---") 88 | start_time_avgd0streambaseline_eval = time.time() 89 | evaluation(dataset_path, master_path, eval_type = "avgd0stream", embeddings_version = embeddings_version, model_filename = model_filename) 90 | print("--- avg d0 stream baseline evaluation done ---") 91 | print("--- seconds ---" + str(time.time() - start_time_avgd0streambaseline_eval)) 92 | 93 | # input features clustering baseline. 
94 | print("--- input features clustering baseline evaluation ---") 95 | start_time_inputfeatures_clustering = time.time() 96 | if not os.path.exists("{}/{}/".format(master_path, inputfeaturesclustering_path)): 97 | os.mkdir("{}/{}/".format(master_path, inputfeaturesclustering_path)) 98 | if not os.path.exists(master_path + "/" + inputfeaturesclustering_path + "/" + inputfeaturesclusters_filename): 99 | print("--- input features clustering running ---") 100 | start_time_inputfeaturesclustering = time.time() 101 | train_inputfeatureskmeans(dataset_path, master_path, inputfeaturesclustering_path, config['nb_clusters_inputfeatures'], config['max_iter'], config['random_state'], config["nb_songs"], embeddings_version = embeddings_version, clusters_filename = inputfeaturesclusters_filename) 102 | print("--- clustering done ---") 103 | print("--- seconds ---" + str(time.time() - start_time_inputfeatures_clustering)) 104 | else: 105 | print("--- no need to do the clustering again ---") 106 | 107 | print("--- input features clustering baseline evaluation ---") 108 | start_time_inputfeaturesbaseline_eval = time.time() 109 | evaluation(dataset_path, master_path, eval_type = "inputfeatures", embeddings_version = embeddings_version, model_filename = model_filename, clustering_path = inputfeaturesclustering_path, clusters_filename = inputfeaturesclusters_filename, nb_clusters = config['nb_clusters_inputfeatures']) 110 | print("--- input features clustering baseline evaluation done ---") 111 | print("--- seconds ---" + str(time.time() - start_time_inputfeaturesbaseline_eval)) 112 | -------------------------------------------------------------------------------- /clustering.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from sklearn.cluster import KMeans 5 | import pickle 6 | from options import config 7 | from sklearn.preprocessing import Normalizer 8 | 9 | def train_kmeans(dataset_path, master_path, clustering_path, nb_clusters, max_iter, random_state, embeddings_version="svd", clusters_filename=None): 10 | 11 | # user embeddings (target = only for train users) 12 | 13 | user_embeddings = pd.read_parquet(dataset_path + "/user_embeddings.parquet", engine = 'fastparquet') 14 | list_embeddings = ["embedding_"+str(i) for i in range(len(user_embeddings[embeddings_version + "_embeddings"][0]))] 15 | user_embeddings[list_embeddings] = pd.DataFrame(user_embeddings[embeddings_version + "_embeddings"].tolist(), index= user_embeddings.index) 16 | user_embeddings_values = user_embeddings[list_embeddings].values 17 | 18 | # clustering 19 | 20 | kmeans = KMeans(n_clusters=nb_clusters, random_state=random_state, max_iter = max_iter, n_jobs=None, precompute_distances='auto').fit(user_embeddings_values) 21 | with open(master_path + "/" + clustering_path + "/" + clusters_filename, "wb") as f: 22 | pickle.dump(kmeans, f) 23 | 24 | # top songs by cluster 25 | 26 | user_embeddings["cluster"] = kmeans.labels_ 27 | user_clusters = user_embeddings[["user_index", "cluster"]] 28 | 29 | # load songs listened between D1 and D30 for train users 30 | 31 | features_train_path = dataset_path + "/user_features_train_" + embeddings_version + ".parquet" 32 | features_train = pd.read_parquet(features_train_path, engine = 'fastparquet').fillna(0) 33 | features_train = features_train.sort_values("user_index") 34 | features_train = features_train.reset_index(drop=True)#to check it is ok for train data 35 | 36 | listd1d30 = 
features_train[["user_index", "d1d30_songs"]] 37 | listd1d30 = pd.merge(listd1d30, user_clusters, left_on = "user_index", right_on = "user_index") 38 | listd1d30_exploded = listd1d30.explode('d1d30_songs') 39 | listd1d30_exploded["count"] = np.ones(len(listd1d30_exploded)) 40 | listd1d30_by_cluster = pd.DataFrame(listd1d30_exploded.groupby(["cluster", "d1d30_songs"])['count'].count()) 41 | 42 | # most popular songs by cluster 43 | 44 | nb_songs = config["nb_songs"] 45 | arrays = (np.repeat(np.arange(nb_clusters), repeats = nb_songs), np.tile(np.arange(nb_songs), nb_clusters)) 46 | tuples = list(zip(*arrays)) 47 | index_perso = pd.MultiIndex.from_tuples(tuples, names=["cluster", "song_index"]) 48 | df = pd.DataFrame(index=["default"], columns=index_perso).T 49 | both = pd.concat([listd1d30_by_cluster, df], axis=1)[["count"]].fillna(0) 50 | both = both.reset_index(drop=False) 51 | both.columns = ["cluster", "song_index", "nb_streams"] 52 | data_by_cluster = pd.DataFrame(both.groupby("cluster")['nb_streams'].sum()) 53 | data_by_cluster.columns = ["nb_streams_by_cluster"] 54 | data_by_cluster_and_song = pd.merge(both, data_by_cluster, left_on = "cluster", right_on = "cluster") 55 | data_by_cluster_and_song["segment_proba"]=data_by_cluster_and_song["nb_streams"]/data_by_cluster_and_song["nb_streams_by_cluster"] 56 | 57 | if not os.path.exists("{}/{}/".format(master_path, clustering_path + "_probas_" + embeddings_version)): 58 | os.mkdir("{}/{}/".format(master_path, clustering_path + "_probas_" + embeddings_version)) 59 | for cluster_id in range(nb_clusters): 60 | if cluster_id % 100 == 0: 61 | print("probas by cluster and by song computed for cluster : "+ str(cluster_id)) 62 | list_proba = [] 63 | for song_index in range(nb_songs): 64 | list_proba.append(data_by_cluster_and_song.iloc[cluster_id*nb_songs+song_index]["segment_proba"]) 65 | pickle.dump(list_proba, open("{}/{}/list_proba_{}.pkl".format(master_path, clustering_path + "_probas_" + embeddings_version, cluster_id), "wb")) 66 | 67 | 68 | def train_inputfeatureskmeans(dataset_path, master_path, clustering_path, nb_clusters, max_iter, random_state, nb_songs, embeddings_version="svd", clusters_filename=None): 69 | 70 | # train features 71 | 72 | user_features_train = pd.read_parquet(dataset_path+"/user_features_train_"+embeddings_version+".parquet", engine = 'fastparquet') 73 | features_train = user_features_train.fillna(0).sort_values("user_index") 74 | features_train_ = features_train.values[:,2:] 75 | transformer = Normalizer().fit(features_train_) 76 | X_train = transformer.transform(features_train_) 77 | 78 | # clustering 79 | 80 | kmeans = KMeans(n_clusters=nb_clusters, random_state=random_state, max_iter = max_iter, n_jobs=None, precompute_distances='auto').fit(X_train) 81 | with open(master_path + "/" + clustering_path + "/" + clusters_filename, "wb") as f: 82 | pickle.dump(kmeans, f) 83 | 84 | # top songs by cluster 85 | 86 | features_train["cluster"] = kmeans.labels_ 87 | listd1d30 = features_train[["user_index", "d1d30_songs", "cluster"]] 88 | listd1d30_exploded = listd1d30.explode('d1d30_songs') 89 | listd1d30_exploded["count"] = np.ones(len(listd1d30_exploded)) 90 | listd1d30_by_cluster = pd.DataFrame(listd1d30_exploded.groupby(["cluster", "d1d30_songs"])['count'].count()) 91 | 92 | #we need to create empty dataframe with all combination of songs + clusters 93 | arrays = (np.repeat(np.arange(nb_clusters), repeats = nb_songs), 94 | np.tile(np.arange(nb_songs), nb_clusters)) 95 | tuples = list(zip(*arrays)) 96 | index_perso = 
pd.MultiIndex.from_tuples(tuples, names=["cluster", "song_index"]) 97 | 98 | df = pd.DataFrame(index=["test"], columns=index_perso).T 99 | 100 | both = pd.concat([listd1d30_by_cluster, df], axis=1)[["count"]].fillna(0) 101 | both = both.reset_index(drop=False) 102 | both.columns = ["cluster", "song_index", "nb_streams"] 103 | 104 | data_by_cluster = pd.DataFrame(both.groupby("cluster")['nb_streams'].sum()) 105 | data_by_cluster.columns = ["nb_streams_by_cluster"] 106 | data_by_cluster_and_song = pd.merge(both, data_by_cluster, left_on = "cluster", 107 | right_on = "cluster") 108 | data_by_cluster_and_song["segment_proba"]=data_by_cluster_and_song["nb_streams"]/data_by_cluster_and_song["nb_streams_by_cluster"] 109 | 110 | if not os.path.exists("{}/{}/".format(master_path, clustering_path + "_probas_" + embeddings_version)): 111 | os.mkdir("{}/{}/".format(master_path, clustering_path + "_probas_" + embeddings_version)) 112 | for cluster_id in range(nb_clusters): 113 | if cluster_id % 10 == 0: 114 | print("probas by cluster and by song computed for cluster : "+ str(cluster_id)) 115 | list_proba = [] 116 | for song_index in range(nb_songs): 117 | list_proba.append(data_by_cluster_and_song.iloc[cluster_id*nb_songs+song_index]["segment_proba"]) 118 | pickle.dump(list_proba, open("{}/{}/list_proba_{}.pkl".format(master_path, clustering_path + "_probas_" + embeddings_version, cluster_id), "wb")) -------------------------------------------------------------------------------- /model_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import torch 5 | import torch.nn 6 | import time 7 | import pickle 8 | import random 9 | from model import RegressionTripleHidden 10 | from options import config 11 | 12 | 13 | def training(dataset_path, master_path, embeddings_version="svd", eval=True, model_save=True, model_filename=None): 14 | 15 | use_cuda = config['use_cuda'] 16 | cuda_number = config['device_number'] 17 | cuda = torch.device(cuda_number) 18 | target_dim = config['embeddings_dim'] 19 | input_dim = config['input_dim'] 20 | nb_epochs = config['nb_epochs'] 21 | learning_rate = config['learning_rate'] 22 | reg_param = config['reg_param'] 23 | drop_out = config['drop_out'] 24 | batch_size = config['batch_size'] 25 | eval_every = config['eval_every'] 26 | k_val = config['k_val'] 27 | 28 | if not os.path.exists(master_path + "/" + model_filename + ".pt"): 29 | 30 | print("--- no model pre-existing for "+embeddings_version+" : training regression model running ---") 31 | 32 | # Load training dataset. 33 | training_set_size = int(len(os.listdir("{}/{}/train".format(master_path, embeddings_version))) / 2) 34 | train_xs = [] 35 | train_ys = [] 36 | for idx in range(training_set_size): 37 | train_xs.append(pickle.load(open("{}/{}/train/x_train_{}.pkl".format(master_path, embeddings_version, idx), "rb"))) 38 | train_ys.append(pickle.load(open("{}/{}/train/y_train_{}.pkl".format(master_path, embeddings_version, idx), "rb"))) 39 | total_dataset = list(zip(train_xs, train_ys)) 40 | del(train_xs, train_ys) 41 | 42 | if eval: 43 | 44 | # Load validation dataset. 
45 | 46 | validation_set_size = int(len(os.listdir("{}/{}/validation".format(master_path, embeddings_version))) / 3) 47 | validation_xs = [] 48 | listened_songs_validation_ys = [] 49 | for idx in range(validation_set_size): 50 | validation_xs.append(pickle.load(open("{}/{}/validation/x_{}.pkl".format(master_path, embeddings_version, idx), "rb"))) 51 | listened_songs_validation_ys.append(pickle.load(open("{}/{}/validation/y_listened_songs_{}.pkl".format(master_path, embeddings_version, idx), "rb"))) 52 | total_validation_dataset = list(zip(validation_xs, listened_songs_validation_ys)) 53 | del(validation_xs, listened_songs_validation_ys) 54 | 55 | # Load song embeddings for evaluation 56 | 57 | song_embeddings_path = dataset_path + "/song_embeddings.parquet" 58 | song_embeddings = pd.read_parquet(song_embeddings_path, engine = 'fastparquet') 59 | list_features = ["feature_" + str(i) for i in range(len(song_embeddings["features_" + embeddings_version][0]))] 60 | song_embeddings[list_features] = pd.DataFrame(song_embeddings["features_" + embeddings_version].tolist(), index= song_embeddings.index) 61 | song_embeddings_values = song_embeddings[list_features].values 62 | song_embeddings_values_ = torch.FloatTensor(song_embeddings_values.astype(np.float32)) 63 | 64 | if use_cuda: 65 | regression_model = RegressionTripleHidden(input_dim = input_dim, output_dim = target_dim, drop_out = drop_out).cuda(device = cuda) 66 | else: 67 | regression_model = RegressionTripleHidden(input_dim = input_dim, output_dim = target_dim, drop_out = drop_out) 68 | criterion = torch.nn.MSELoss() 69 | optimizer = torch.optim.Adam(regression_model.parameters(), lr = learning_rate, weight_decay=reg_param ) 70 | 71 | print("training set size : "+str(training_set_size)) 72 | print("validation set size : "+str(validation_set_size)) 73 | print("input dimension : " + str(input_dim)) 74 | print("regression model : "+ str(regression_model)) 75 | print("training running") 76 | 77 | loss_train = [] 78 | 79 | for nb in range(nb_epochs): 80 | print("nb epoch : "+str(nb)) 81 | start_time_epoch = time.time() 82 | random.Random(nb).shuffle(total_dataset) 83 | a,b = zip(*total_dataset) 84 | num_batch = int(training_set_size / batch_size) 85 | if use_cuda: 86 | regression_model = regression_model.to(device = cuda) 87 | for i in range(num_batch): 88 | optimizer.zero_grad() 89 | if use_cuda: 90 | batch_features_tensor = torch.stack(a[batch_size*i:batch_size*(i+1)]).cuda(device = cuda) 91 | batch_target_tensor = torch.stack(b[batch_size*i:batch_size*(i+1)]).cuda(device = cuda) 92 | else: 93 | batch_features_tensor = torch.stack(a[batch_size*i:batch_size*(i+1)]) 94 | batch_target_tensor = torch.stack(b[batch_size*i:batch_size*(i+1)]) 95 | output_tensor = regression_model(batch_features_tensor) 96 | loss = criterion(output_tensor, batch_target_tensor) 97 | loss.backward() 98 | optimizer.step() 99 | loss_train.append(loss.item()) 100 | print('epoch ' + str(nb) + " training loss : "+ str(sum(loss_train)/float(len(loss_train)))) 101 | print("--- seconds ---" + str(time.time() - start_time_epoch)) 102 | 103 | if nb != 0 and (nb % eval_every == 0 or nb == nb_epochs - 1): 104 | print('testing model') 105 | start_time_eval = time.time() 106 | reg = regression_model.eval() 107 | if use_cuda: 108 | reg = reg.to(device=cuda) 109 | validation_set_size = len(total_validation_dataset) 110 | a,b = zip(*total_validation_dataset) 111 | num_batch_validation = int(validation_set_size / batch_size) 112 | current_precisions = [] 113 | with 
torch.set_grad_enabled(False): 114 | for i in range(num_batch_validation): 115 | if use_cuda: 116 | batch_features_tensor_validation = torch.stack(a[batch_size*i:batch_size*(i+1)]).cuda(device = cuda) 117 | else: 118 | batch_features_tensor_validation = torch.stack(a[batch_size*i:batch_size*(i+1)]) 119 | predictions_validation = reg(batch_features_tensor_validation) 120 | groundtruth_validation = list(b[batch_size*i:batch_size*(i+1)]) 121 | predictions_songs_validation = torch.mm(predictions_validation.cpu(), song_embeddings_values_.transpose(0, 1)) 122 | recommendations_validation = (predictions_songs_validation.topk(k= k_val, dim = 1)[1]).tolist() 123 | precisions = list(map(lambda x, y: len(set(x) & set(y))/float(min(len(x), k_val)), groundtruth_validation, recommendations_validation)) 124 | current_precisions.extend(precisions) 125 | print('epoch ' + str(nb) + " precision test : "+ str(sum(current_precisions) / float(len(current_precisions))) ) 126 | print("--- seconds ---" + str(time.time() - start_time_eval)) 127 | print("--- training finished ---") 128 | 129 | if model_save: 130 | print("--- saving model ---") 131 | torch.save(regression_model.state_dict(), master_path + "/" + model_filename + ".pt") 132 | print(regression_model) 133 | print("--- model saved ---") 134 | 135 | else: 136 | print("--- there is already a model pre-existing for "+embeddings_version+" : no need to run training again ---") 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Semi-Personalized System for User Cold Start Recommendation on Music Streaming Apps 2 | 3 | This repository provides Python code to reproduce experiments from the article [A Semi-Personalized System for User Cold Start Recommendation on Music Streaming Apps](https://arxiv.org/pdf/2106.03819.pdf) published in the proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery and Data Mining ([KDD 2021](https://virtual.2021.kdd.org/index.html)). 4 | 5 | ## Recommending Musical Content on Deezer 6 | 7 | Music streaming services heavily rely on recommender systems to improve their users’ experience, by helping them navigate through a large musical catalog and discover new songs, playlists, albums or artists. 8 | 9 | To recommend personalized musical content to users, [Deezer](https://www.deezer.com/) leverages **latent models for collaborative filtering**. In a nutshell, these models aim at learning vector space representations of users and items, a.k.a. **embedding representations**, where proximity should reflect user preferences. More specifically, two different methods, referred to as **UT-ALS** and **TT-SVD** embeddings, are considered in this work. We refer to [our paper](https://arxiv.org/pdf/2106.03819.pdf) for technical details. 10 | 11 |
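Concretely, recommending from such an embedding space amounts to ranking songs by their inner product with a user vector, which is how `model_training.py` and `model_evaluation.py` score the catalog. Below is a minimal, self-contained sketch of that step; the tensor names and random values are purely illustrative.

```Python
import torch

# Illustrative shapes: 50 000 songs and 128-dimensional TT-SVD vectors,
# as in the released datasets (the tensors below are random placeholders).
song_embeddings = torch.randn(50000, 128)  # one row per song
user_embedding = torch.randn(1, 128)       # embedding of a single user

# Proximity in the embedding space is measured with the inner product:
# the higher the score, the more relevant the song for this user.
scores = torch.mm(user_embedding, song_embeddings.t())  # shape (1, 50000)

# Keep the 50 best-scored songs, as done in model_training.py / model_evaluation.py.
top_scores, top_song_indices = scores.topk(k=50, dim=1)
print(top_song_indices[0, :10].tolist())
```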

![Deezer homepage](images/homepage.png)

14 | 15 | #### The User Cold Start Problem 16 | 17 | However, the performance of such models tends to degrade significantly for new users who have had only a few interactions with the catalog. They might even become unsuitable for users with no interaction at all, who are absent from user-item matrices in standard algorithms. This is commonly referred to as the **user cold start problem**. Yet, recommending relevant content to these new users is crucial for online services such as Deezer. Indeed, a new user facing low-quality recommendations might have a bad first impression and decide to stop using the service. 18 | 19 | More precisely, our work aims at addressing the following problem: **given an existing latent model for collaborative filtering learning an embedding space from a set of "warm" users, how can we effectively include new "cold" users into this same space, by the end of their registration day on Deezer?** 20 | 21 | Our [KDD paper](https://arxiv.org/pdf/2106.03819.pdf) presents the industrial-scale semi-personalized system recently deployed on Deezer to address this problem. 22 | 23 | #### Offline Evaluation 24 | 25 | This repository provides code and data to reproduce our **offline experiments**. In the paper, an online A/B test on Deezer complements these experiments. 26 | 27 | **Setting:** Through experiments on an offline dataset of 100 000 active users described thereafter, we evaluate to what extent the proposed recommendations at registration day would have matched the actual musical preferences of a set of users over their first month on the service. Specifically, we compute the 50 most relevant music tracks for each user of the dataset, from our proposed model and registration day’s input features (described in [the paper](https://arxiv.org/pdf/2106.03819.pdf)). We compare them to the tracks listened to by each user during their next 30 days on Deezer, using three standard recommendation metrics: **Precision@K**, **Recall@K**, and the Normalized Discounted Cumulative Gain (**NDCG@K**), a measure of ranking quality. 28 | 29 | ## Installation 30 | 31 | #### Code 32 | 33 | ```Bash 34 | git clone https://github.com/deezer/semi_perso_user_cold_start 35 | cd semi_perso_user_cold_start 36 | ``` 37 | 38 | _Requirements: python 3, torch, numpy, pandas, matplotlib, fastparquet, sklearn, pickle, statistics, random_ 39 | 40 | #### Data 41 | 42 | We release eight datasets: 43 | - `song_embeddings.parquet`: a dataset of 50 000 fully anonymized Deezer songs. Each song is described by: 44 | - a song index (field `song_index`) 45 | - a 128-dimensional TT-SVD embedding vector (field `features_svd`) 46 | - a 256-dimensional UT-ALS embedding vector (field `features_mf`) 47 | - `user_embeddings.parquet`: a dataset of 70 000 fully anonymized Deezer users representing our training dataset.
Each user _u_ is described by: 48 | - a user index (field `user_index`) 49 | - a 128-dimensional embedding vector (field `svd_embeddings`) summarizing the user's musical preferences in the TT-SVD space 50 | - a 256-dimensional embedding vector (field `mf_embeddings`) summarizing the user's musical preferences in the UT-ALS space 51 | For each embedding space (TT-SVD or UT-ALS), we release: 52 | - `user_features_train_svd.parquet` (or `user_features_train_mf.parquet`) with 70 000 fully anonymized Deezer users of the training dataset 53 | - `user_features_validation_svd.parquet` (or `user_features_validation_mf.parquet`) with 20 000 fully anonymized Deezer users of the validation dataset 54 | - `user_features_test_svd.parquet` (or `user_features_test_mf.parquet`) with 10 000 fully anonymized Deezer users of the test dataset 55 | These last three datasets contain the features used to predict the user embedding, from which we deduce the top K songs most likely to be listened to in the 30 days after the registration day. Each of these datasets has the following feature fields: 56 | - a user index (field `user_index`) 57 | - a list of the songs listened to during the 30 days following the registration day, i.e. the groundtruth songs (field `d1d30_songs`) 58 | - a 128- or 256-dimensional embedding vector for each feature (fields, for example for age: `age_embedding0` to `age_embedding128` or `age_embedding256`). Features can be age, country, favorite_songs, favorite_artists, favorite_albums, favorite_playlists, onboarding_like_artists, skipped_songs, skipped_artists, skipped_albums, streamed_songs, streamed_artists, streamed_albums, banned_songs, banned_artists, banned_albums, search_songs, search_artists, search_albums, search_playlists 59 | - the number of interactions for each action (field, for example for the number of favorite songs entered: `nb_favorite_songs`) 60 | - the age of the user (field `age_value`) 61 | 62 | 63 | #### Download complete datasets 64 | 65 | Due to size restrictions, the datasets are [available for download on Zenodo](https://zenodo.org/record/5121674#.YQuiitMzaIZ). 66 | 67 | Please download them there and subsequently place them in the `data` folder. 68 | 69 |
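All released files are standard Parquet tables. As a quick sanity check after downloading, they can be opened with pandas and the `fastparquet` engine used throughout this repository (a minimal sketch, assuming the files sit in `./data` as described above; variable names are illustrative):

```Python
import pandas as pd

# Assumes the Zenodo files have been placed in ./data, as described above.
songs = pd.read_parquet("data/song_embeddings.parquet", engine="fastparquet")
users = pd.read_parquet("data/user_embeddings.parquet", engine="fastparquet")
features = pd.read_parquet("data/user_features_train_svd.parquet", engine="fastparquet")

print(songs[["song_index", "features_svd"]].head())    # 128-d TT-SVD song vectors
print(users[["user_index", "svd_embeddings"]].head())  # 128-d TT-SVD user vectors
print(features[["user_index", "d1d30_songs"]].head())  # groundtruth songs, days 1-30
```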

![Zenodo record of the released datasets](images/zenodo.png)

72 | 73 | ## Run Offline Experiment 74 | 75 | Simulations proceed as detailed in [Section 4.1](https://arxiv.org/pdf/2106.03819.pdf). 76 | 77 | One can choose to run the experiment on either the TT-SVD or the UT-ALS embedding space by changing the parameters in the `options.py` file, which is set up by default for the TT-SVD space. Please set the `use_cuda` option to `False` in the absence of a GPU. 78 | 79 | Type in the following command to run offline experiments with hyperparameters similar to those used in the paper: 80 | 81 | ```Bash 82 | python main.py 83 | ``` 84 | 85 | Estimated running times with TT-SVD embeddings on a standard machine: 86 | - data generation: 5084 seconds 87 | - model training (with GPU): 9 seconds 88 | - full perso evaluation: 1418 seconds 89 | - clustering: 14672 seconds 90 | - semi perso evaluation: 991 seconds 91 | 92 | One should obtain results (up to slight variations due to randomness) similar to Table 1 from the paper: 93 | 94 |

![Offline evaluation results (Table 1 of the paper)](images/results.png)
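These metrics are computed user by user in `model_evaluation.py`. The following condensed sketch reproduces the same per-user computations (the function and its arguments are illustrative, not part of the repository):

```Python
import numpy as np
from sklearn.metrics import ndcg_score

def user_metrics(groundtruth_song_ids, scores, k=50):
    """Per-user Precision@K, Recall@K and NDCG@K, mirroring model_evaluation.py."""
    recommended = np.argsort(scores)[::-1][:k]  # top-K songs by predicted score
    hits = len(set(groundtruth_song_ids) & set(recommended))
    precision = hits / float(min(len(groundtruth_song_ids), k))
    recall = hits / float(len(groundtruth_song_ids))
    relevance = np.zeros(len(scores))           # binary relevance over the catalog
    relevance[list(groundtruth_song_ids)] = 1.0
    ndcg = ndcg_score(relevance.reshape(1, -1), np.asarray(scores).reshape(1, -1), k=k)
    return precision, recall, ndcg

# Toy example on a 10-song catalog with random scores (illustrative values only).
print(user_metrics(groundtruth_song_ids=[1, 4, 7], scores=np.random.rand(10), k=5))
```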

97 | 98 | #### Next Steps 99 | 100 | The current repository runs our proposed **Deezer Full-Personalized** and **Deezer Semi-Personalized** methods. 101 | 102 | Other baselines will be added in the upcoming days. 103 | 104 | 105 | 106 | ## Cite 107 | 108 | Please cite our paper if you use this code or data in your own work: 109 | 110 | ```BibTeX 111 | @inproceedings{briand2021semipersousercoldstart, 112 | title={A Semi-Personalized System for User Cold Start Recommendation on Music Streaming Apps}, 113 | author={Briand, Lea and Salha-Galvan, Guillaume and Bendada, Walid and Morlon, Mathieu and Tran, Viet-Anh}, 114 | booktitle={27th ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD 2021)}, 115 | year={2021} 116 | } 117 | ``` 118 | -------------------------------------------------------------------------------- /model_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import torch 5 | import torch.nn 6 | import pickle 7 | from model import RegressionTripleHidden 8 | from options import config 9 | from sklearn.metrics import ndcg_score, dcg_score 10 | import statistics 11 | from sklearn.preprocessing import Normalizer 12 | 13 | def evaluation(dataset_path, master_path, eval_type = "full_perso", embeddings_version="svd", model_filename=None, clustering_path=None, clusters_filename=None, nb_clusters=config["nb_clusters"]): 14 | 15 | use_cuda = config['use_cuda'] 16 | target_dim = config['embeddings_dim'] 17 | input_dim = config['input_dim'] 18 | k_val_list = config["k_val_list"] 19 | indic_eval_evolution = config["indic_eval_evolution"] 20 | cuda = torch.device(0) 21 | model_filename = master_path + "/" + model_filename + ".pt" 22 | 23 | # Load testing dataset. 
24 | print("--- Load testing dataset ---") 25 | testing_set_size = int((len(os.listdir("{}/{}/test".format(master_path, embeddings_version)))) / 3) 26 | test_xs = [] 27 | listened_songs_test_ys = [] 28 | goundtruth_list_test = [] 29 | for idx in range(testing_set_size): 30 | if eval_type in ["full_perso", "semi_perso", "popularity"] : 31 | test_xs.append(pickle.load(open("{}/{}/test/x_{}.pkl".format(master_path, embeddings_version, idx), "rb"))) 32 | elif eval_type in ["inputfeatures"] : 33 | vector = pickle.load(open("{}/{}/test/x_{}.pkl".format(master_path, embeddings_version, idx), "rb")) 34 | transformer = Normalizer().fit(vector.reshape(1, -1)) 35 | norm_vector = torch.FloatTensor(transformer.transform(vector.reshape(1, -1))[0]) 36 | test_xs.append(norm_vector) 37 | listened_songs_test_ys.append(pickle.load(open("{}/{}/test/y_listened_songs_{}.pkl".format(master_path, embeddings_version, idx), "rb"))) 38 | goundtruth_list_test.append(pickle.load(open("{}/{}/test/groundtruth_list_{}.pkl".format(master_path, embeddings_version, idx), "rb"))) 39 | if eval_type in ["avgd0stream"] : 40 | listd1d30streams = pd.read_parquet(dataset_path+"/user_features_test_"+embeddings_version+".parquet", engine ='fastparquet') 41 | colavgd0stream = list(listd1d30streams)[2+target_dim*10:2+target_dim*10+target_dim] 42 | avgd0stream = listd1d30streams[["user_index"]+colavgd0stream] 43 | avgd0stream_df = avgd0stream.set_index("user_index", drop = True).sort_index() 44 | test_xs = avgd0stream_df.values 45 | total_test_dataset = list(zip(test_xs, listened_songs_test_ys, goundtruth_list_test)) 46 | del(test_xs, listened_songs_test_ys, goundtruth_list_test) 47 | print("--- nb of test samples : "+str(len(total_test_dataset))+" ---") 48 | 49 | if eval_type in ["full_perso", "semi_perso", "avgd0stream"] : 50 | 51 | # Load song embeddings 52 | 53 | print("--- Load song embeddings ---") 54 | song_embeddings_path = dataset_path + "/song_embeddings.parquet" 55 | song_embeddings = pd.read_parquet(song_embeddings_path, engine = 'fastparquet').fillna(0) 56 | list_features = ["feature_"+str(i) for i in range(len(song_embeddings["features_" + embeddings_version][0]))] 57 | song_embeddings[list_features] = pd.DataFrame(song_embeddings["features_" + embeddings_version].tolist(), index= song_embeddings.index) 58 | song_embeddings_values = song_embeddings[list_features].values 59 | song_embeddings_values_ = torch.FloatTensor(song_embeddings_values.astype(np.float32)) 60 | print("--- nb of songs : "+str(len(song_embeddings_values_))+" ---") 61 | 62 | if eval_type in ["full_perso", "semi_perso"] : 63 | 64 | # Load model saved 65 | 66 | print("--- Load model ---") 67 | regression_model = RegressionTripleHidden(input_dim = input_dim, output_dim = target_dim) 68 | regression_model.load_state_dict(torch.load(model_filename)) 69 | reg = regression_model.eval() 70 | if use_cuda: 71 | reg = reg.to(device=cuda) 72 | print(reg) 73 | 74 | # if evaluation semi perso : 75 | if eval_type in ["semi_perso"]: 76 | 77 | print("--- Load centroids for semi perso evaluation ---") 78 | #centroids to assign segment 79 | with open(master_path + "/" + clustering_path + "/" + clusters_filename, "rb") as f: 80 | kmeans = pickle.load(f) 81 | centroids = kmeans.cluster_centers_ 82 | centroids_df = pd.DataFrame(centroids) 83 | if use_cuda: 84 | centroid_ = torch.FloatTensor(centroids_df.values).to(device=cuda) 85 | else: 86 | centroid_ = torch.FloatTensor(centroids_df.values) 87 | print("--- nb of centroids : "+str(len(centroid_))+" ---") 88 | 89 | #proba by 
segment for all song ids 90 | print("--- Load proba by segment for all song ids ---") 91 | song_proba_by_segment = [] 92 | for cluster_id in range(nb_clusters): 93 | song_proba_by_segment.append(pickle.load(open("{}/{}/list_proba_{}.pkl".format(master_path, clustering_path + "_probas_" + embeddings_version, cluster_id), "rb"))) 94 | print("--- nb of proba by segment for all song ids : "+str(len(song_proba_by_segment))+" ---") 95 | 96 | elif eval_type in ["popularity"] : 97 | list_proba = generate_for_popularity_evaluation(dataset_path, embeddings_version="svd") 98 | print("list of probabilities for each song for popularity baseline loaded") 99 | 100 | elif eval_type in ["inputfeatures"]: 101 | 102 | print("--- Load centroids for inputfeatures evaluation ---") 103 | #centroids to assign segment 104 | with open(master_path + "/" + clustering_path + "/" + clusters_filename, "rb") as f: 105 | kmeans = pickle.load(f) 106 | centroids = kmeans.cluster_centers_ 107 | centroids_df = pd.DataFrame(centroids) 108 | if use_cuda: 109 | centroid_ = torch.FloatTensor(centroids_df.values).to(device=cuda) 110 | else: 111 | centroid_ = torch.FloatTensor(centroids_df.values) 112 | cuda = torch.device(0) 113 | print("--- nb of centroids : "+str(len(centroid_))+" ---") 114 | 115 | #proba by segment for all song ids 116 | print("--- Load proba by segment for all song ids ---") 117 | song_proba_by_segment = [] 118 | for cluster_id in range(nb_clusters): 119 | song_proba_by_segment.append(pickle.load(open("{}/{}/list_proba_{}.pkl".format(master_path, clustering_path + "_probas_" + embeddings_version, cluster_id), "rb"))) 120 | print("--- nb of proba by segment for all song ids : "+str(len(song_proba_by_segment))+" ---") 121 | 122 | # Compute evaluation metrics : avg precision, recall and ndcg 123 | 124 | testing_set_size = len(total_test_dataset) 125 | a,b,c = zip(*total_test_dataset) 126 | batch_size = 1 127 | num_batch_test = int(testing_set_size / batch_size) 128 | current_ndcg = {} 129 | current_recalls = {} 130 | current_precisions = {} 131 | for k_val in k_val_list: 132 | current_ndcg[k_val] = [] 133 | for k_val in k_val_list: 134 | current_recalls[k_val] = [] 135 | current_precisions[k_val] = [] 136 | print("--- Evaluation running : average precision, recall and ndcg ---") 137 | print(eval_type) 138 | with torch.set_grad_enabled(False): 139 | for i in range(num_batch_test): 140 | if i % indic_eval_evolution == 0 & i != 0 : 141 | print("eval done for "+str(i)+" users") 142 | if eval_type in ["full_perso", "semi_perso"] : 143 | if use_cuda: 144 | batch_features_tensor_test = torch.stack(a[batch_size*i:batch_size*(i+1)]).cuda(device = cuda) 145 | else: 146 | batch_features_tensor_test = torch.stack(a[batch_size*i:batch_size*(i+1)]) 147 | predictions_test = reg(batch_features_tensor_test) 148 | elif eval_type in ["avgd0stream"]: 149 | predictions_test = torch.FloatTensor(a[batch_size*i:batch_size*(i+1)]) 150 | elif eval_type in ["inputfeatures"]: 151 | if use_cuda: 152 | predictions_test = torch.stack(a[batch_size*i:batch_size*(i+1)]).cuda(device = cuda) 153 | else: 154 | predictions_test = torch.stack(a[batch_size*i:batch_size*(i+1)]) 155 | # list of song indexes listened by user - index 156 | groundtruth_test_list_id = list(b[batch_size*i:batch_size*(i+1)])[0] 157 | groundtruth_test_list = list(c[batch_size*i:batch_size*(i+1)]) 158 | k_val_max = max(k_val_list) 159 | 160 | if eval_type in ["full_perso", "avgd0stream"] : 161 | proba_values = torch.mm(predictions_test.cpu(), song_embeddings_values_.transpose(0, 
1)) 162 | recommended_songs = (proba_values.topk(k= k_val_max, dim = 1)[1]).tolist()[0] 163 | 164 | elif eval_type in ["semi_perso", "inputfeatures"] : 165 | predicted_segment = segment_pred(predictions_test, centroid_, k = 1, cuda_name = cuda)[0] 166 | proba_values = song_proba_by_segment[int(predicted_segment)-1] 167 | recommended_songs = np.argsort(proba_values)[::-1] 168 | 169 | elif eval_type == "popularity" : 170 | proba_values = list_proba 171 | recommended_songs = np.argsort(proba_values)[::-1] 172 | 173 | else : 174 | "error eval_type unknown" 175 | 176 | for k_val in k_val_list: 177 | intersection = set(groundtruth_test_list_id) & set(recommended_songs[:k_val]) 178 | denom_precision = float(len(groundtruth_test_list_id)) if len(groundtruth_test_list_id) < k_val else float(k_val) 179 | precision = len(intersection)/denom_precision 180 | current_precisions[k_val].append(precision) 181 | denom_recall = float(len(groundtruth_test_list_id)) 182 | recall = len(intersection)/denom_recall 183 | current_recalls[k_val].append(recall) 184 | groundtruth_array = np.array(groundtruth_test_list, int) 185 | if eval_type in ["full_perso", "avgd0stream"] : 186 | scores = np.asarray([proba_values.numpy()[0].tolist()]) 187 | elif eval_type in ["semi_perso", "popularity", "inputfeatures"] : 188 | scores = np.asarray([proba_values]) 189 | else : 190 | "error eval_type unknown" 191 | for k_val in k_val_list: 192 | ndcg = ndcg_score(groundtruth_array, scores, k=k_val) 193 | current_ndcg[k_val].append(ndcg) 194 | 195 | print('length dataset : '+str(num_batch_test)) 196 | for keys in current_ndcg.keys(): 197 | print("ndcg at "+ str(keys) +" is : " 198 | + str(sum(current_ndcg[keys])/float(len(current_ndcg[keys])))) 199 | for keys in current_recalls.keys(): 200 | print("recall at "+ str(keys) +" is : " 201 | + str(sum(current_recalls[keys])/float(len(current_recalls[keys])))) 202 | for keys in current_precisions.keys(): 203 | print("precision at "+ str(keys) +" is : " 204 | + str(sum(current_precisions[keys])/float(len(current_precisions[keys])))) 205 | 206 | # standard deviation estimation 207 | 208 | print("--- Evaluation running : standard deviation estimation ---") 209 | print(eval_type) 210 | 211 | max_loc = num_batch_test 212 | nb_iterations_eval_stddev = config["nb_iterations_eval_stddev"] 213 | nb_sub_iterations_eval_stddev = config["nb_sub_iterations_eval_stddev"] 214 | batch_size = int(len(total_test_dataset)/float(nb_sub_iterations_eval_stddev)) 215 | batch_ndcg_list = {} 216 | batch_recall_list = {} 217 | batch_precision_list = {} 218 | for k_val in k_val_list: 219 | batch_ndcg_list[k_val] = [] 220 | batch_recall_list[k_val] = [] 221 | batch_precision_list[k_val] = [] 222 | 223 | for iteration in range(nb_iterations_eval_stddev): 224 | torch.manual_seed(iteration) 225 | randInd = torch.randperm(max_loc) 226 | current_position = 0 227 | for i in range(nb_sub_iterations_eval_stddev): 228 | ending_position = min(current_position + batch_size, max_loc) 229 | for k_val in k_val_list: 230 | batch_recall = pd.DataFrame(current_recalls[k_val]).values[randInd[current_position : ending_position]] 231 | batch_recall_mean = sum(batch_recall)/float(len(batch_recall)) 232 | batch_recall_list[k_val].append(batch_recall_mean[0]) 233 | batch_precision = pd.DataFrame(current_precisions[k_val]).values[randInd[current_position : ending_position]] 234 | batch_precision_mean = sum(batch_precision)/float(len(batch_precision)) 235 | batch_precision_list[k_val].append(batch_precision_mean[0]) 236 | batch_ndcg = 
pd.DataFrame(current_ndcg[k_val]).values[randInd[current_position : ending_position]] 237 | batch_ndcg_mean = sum(batch_ndcg)/float(len(batch_ndcg)) 238 | batch_ndcg_list[k_val].append(batch_ndcg_mean[0]) 239 | current_position += batch_size 240 | 241 | print('length dataset : '+str(num_batch_test)) 242 | for keys in batch_ndcg_list.keys(): 243 | print("stddev ndcg at "+ str(keys) +" is : " 244 | + str(statistics.stdev(batch_ndcg_list[keys]))) 245 | for keys in batch_recall_list.keys(): 246 | print("stddev recall at "+ str(keys) +" is : " 247 | + str(statistics.stdev(batch_recall_list[keys]))) 248 | for keys in batch_precision_list.keys(): 249 | print("stddev precision at "+ str(keys) +" is : " 250 | + str(statistics.stdev(batch_precision_list[keys]))) 251 | 252 | def segment_pred(target_validation_estimated, centroid_, k = 10, cuda_name = torch.device(0)): 253 | use_cuda = config['use_cuda'] 254 | n1, n2 = target_validation_estimated.size(0), centroid_.size(0) 255 | target_validation_norm_ = torch.sum(target_validation_estimated**2, dim=1) 256 | centroid_norm_ = torch.sum(centroid_**2, dim=1) 257 | centroid_norm_expand = centroid_norm_.expand(n1, n2).t() 258 | target_validation_norm_expand = target_validation_norm_.expand(n2, n1) 259 | product_ = centroid_.mm(target_validation_estimated.t()) 260 | distance = - target_validation_norm_expand - centroid_norm_expand + 2 * product_ 261 | idx = torch.topk(distance, k=k, dim=0)[1].float() 262 | if use_cuda: 263 | results = (idx+ torch.ones(k, target_validation_norm_.size(0)).to(device = cuda_name)).cpu().numpy() 264 | else: 265 | results = (idx+ torch.ones(k, target_validation_norm_.size(0))).numpy() 266 | return results 267 | 268 | def generate_for_popularity_evaluation(dataset_path, embeddings_version="svd"): 269 | 270 | listd1d30streams = pd.read_parquet(dataset_path+"/user_features_train_"+embeddings_version+".parquet", engine = 'fastparquet') 271 | exploded_data = listd1d30streams[["user_index", "d1d30_songs"]].explode('d1d30_songs').set_index('d1d30_songs') 272 | grouped_data = exploded_data.groupby(['d1d30_songs']).size() 273 | popularity_df = pd.DataFrame(grouped_data / float(sum(grouped_data))) 274 | popularity_df.columns = ["proba"] 275 | list_proba = [] 276 | for song_index in range(config["nb_songs"]): 277 | if song_index in popularity_df.index : 278 | list_proba.append(popularity_df.loc[song_index]["proba"]) 279 | else : 280 | list_proba.append(0) 281 | 282 | return list_proba 283 | 284 | --------------------------------------------------------------------------------
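Taken together, the semi-personalized strategy evaluated above reduces to three steps at inference time: predict a cold user's embedding from registration-day features, assign it to the nearest K-means centroid, and recommend the most-streamed songs of that cluster. A condensed sketch of this flow is given below; the function and argument names are illustrative (the actual script loads the model, centroids and per-cluster probabilities from the saved files), and `regression_model` is assumed to be in `eval()` mode.

```Python
import numpy as np
import torch

def semi_perso_recommendations(user_features, regression_model, centroids, cluster_song_probas, k=50):
    """Condensed view of the semi-personalized path in model_evaluation.py.
    `centroids` is a (nb_clusters, dim) tensor and `cluster_song_probas[c]` holds the
    per-song probabilities precomputed by clustering.py for cluster c."""
    with torch.no_grad():
        predicted_embedding = regression_model(user_features.unsqueeze(0))  # (1, dim)
    # Assign the predicted embedding to its nearest centroid (Euclidean distance),
    # as segment_pred does above in batched form.
    distances = torch.cdist(predicted_embedding, centroids)                 # (1, nb_clusters)
    cluster_id = int(distances.argmin(dim=1))
    # Recommend the k songs with the highest streaming probability in that cluster.
    proba_values = np.asarray(cluster_song_probas[cluster_id])
    return np.argsort(proba_values)[::-1][:k]
```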