├── data
│   └── read_me
├── images
│   ├── zenodo.png
│   ├── homepage.png
│   └── results.png
├── options.py
├── model.py
├── data_generation.py
├── main.py
├── clustering.py
├── model_training.py
├── README.md
└── model_evaluation.py

/data/read_me:
--------------------------------------------------------------------------------
Download the data from https://zenodo.org/record/5121674#.YQuiitMzaIZ and place the files in this data/ folder.
--------------------------------------------------------------------------------
/images/zenodo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deezer/semi_perso_user_cold_start/HEAD/images/zenodo.png
--------------------------------------------------------------------------------
/images/homepage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deezer/semi_perso_user_cold_start/HEAD/images/homepage.png
--------------------------------------------------------------------------------
/images/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deezer/semi_perso_user_cold_start/HEAD/images/results.png
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
config = {
    # song
    'nb_songs': 50000,
    # user
    'embeddings_version': "svd",  # can be "svd" for TT-SVD, or "mf" for UT-ALS in the Deezer data case
    'embeddings_dim': 128,  # 128 for TT-SVD, or 256 for UT-ALS
    # cuda setting
    'use_cuda': True,
    'device_number': 0,
    # model setting
    'input_dim': 2579,  # 2579 & 5139 for TT-SVD and UT-ALS train features respectively
    'nb_epochs': 130,
    'learning_rate': 0.00001,
    'batch_size': 512,
    'reg_param': 0,
    'drop_out': 0,
    # model training
    'eval_every': 10,
    'k_val': 50,
    # clustering for semi-personalization strategy
    'nb_clusters': 1000,
    'max_iter': 20,
    'random_state': 0,
    # clustering for input features baseline (same 'max_iter' and 'random_state' as above)
    'nb_clusters_inputfeatures': 100,
    # model evaluation
    'k_val_list': [50],
    'nb_iterations_eval_stddev': 2,
    'nb_sub_iterations_eval_stddev': 5,
    'indic_eval_evolution': 1000,
}

dataset_eval = ["validation", "test"]
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import torch.nn

class RegressionTripleHidden(torch.nn.Module):
    def __init__(self, input_dim, output_dim, first_hidden_dim = 400, second_hidden_dim = 300, third_hidden_dim = 200, drop_out = 0):
        super(RegressionTripleHidden, self).__init__()
        self.input_dim = input_dim
        self.first_hidden_dim = first_hidden_dim
        self.second_hidden_dim = second_hidden_dim
        self.third_hidden_dim = third_hidden_dim
        self.output_dim = output_dim
        self.dpin = torch.nn.Dropout(drop_out)

        self.fc1 = torch.nn.Linear(self.input_dim, self.first_hidden_dim)
        self.fc1_bn = torch.nn.BatchNorm1d(self.first_hidden_dim)

        self.fc2 = torch.nn.Linear(self.first_hidden_dim, self.second_hidden_dim)
        self.fc2_bn = torch.nn.BatchNorm1d(self.second_hidden_dim)

        self.fc3 = torch.nn.Linear(self.second_hidden_dim, self.third_hidden_dim)
        self.fc3_bn = torch.nn.BatchNorm1d(self.third_hidden_dim)

        self.fc4 = torch.nn.Linear(self.third_hidden_dim, self.output_dim)

    def forward(self, x):
        hidden1 = self.fc1_bn(F.relu(self.fc1(self.dpin(x))))
        hidden2 = self.fc2_bn(F.relu(self.fc2(hidden1)))
        hidden3 = self.fc3_bn(F.relu(self.fc3(hidden2)))
        output = F.normalize(self.fc4(hidden3), dim = 1)
        return output
--------------------------------------------------------------------------------
/data_generation.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
import torch
import torch.nn
import pickle
from options import dataset_eval

def generate(dataset_path, master_path, embeddings_version):

    # songs

    print("--- songs data ---")

    song_embeddings_path = dataset_path + "/song_embeddings.parquet"
    song_embeddings = pd.read_parquet(song_embeddings_path, engine = 'fastparquet')

    if not os.path.exists(master_path + "/m_song_dict.pkl"):
        song_dict = {}
        for idx, row in song_embeddings.iterrows():
            song_dict[row['song_index']] = idx
        pickle.dump(song_dict, open("{}/m_song_dict.pkl".format(master_path), "wb"))
    else:
        song_dict = pickle.load(open("{}/m_song_dict.pkl".format(master_path), "rb"))

    # user embeddings (target = only for train users)

    print("--- user embeddings - training dataset ---")

    user_embeddings = pd.read_parquet(dataset_path + "/user_embeddings.parquet", engine = 'fastparquet')
    list_embeddings = ["embedding_" + str(i) for i in range(len(user_embeddings[embeddings_version + "_embeddings"][0]))]
    user_embeddings[list_embeddings] = pd.DataFrame(user_embeddings[embeddings_version + "_embeddings"].tolist(), index = user_embeddings.index)

    # user features train

    print("--- user features - training dataset ---")

    features_train_path = dataset_path + "/user_features_train_" + embeddings_version + ".parquet"
    features_train = pd.read_parquet(features_train_path, engine = 'fastparquet').fillna(0)
    features_train = features_train.sort_values("user_index")
    features_train = features_train.reset_index(drop=True)  # to check it is ok for train data

    # training dataset creation

    dataset = "train"
    if not os.path.exists(master_path + "/"):
        os.mkdir(master_path + "/")
    if not os.path.exists(master_path + "/" + embeddings_version + "/"):
        os.mkdir(master_path + "/" + embeddings_version + "/")
    if not os.path.exists(master_path + "/" + embeddings_version + "/" + dataset + "/"):
        os.mkdir(master_path + "/" + embeddings_version + "/" + dataset + "/")
    for idx in range(len(features_train)):
        x_train = torch.FloatTensor(features_train.iloc[idx, 2:])
        y_train = torch.FloatTensor(user_embeddings[list_embeddings].iloc[idx, :])
        pickle.dump(x_train, open("{}/{}/{}/x_train_{}.pkl".format(master_path, embeddings_version, dataset, idx), "wb"))
        pickle.dump(y_train, open("{}/{}/{}/y_train_{}.pkl".format(master_path, embeddings_version, dataset, idx), "wb"))

    # user features validation & test

    print("--- user features - evaluation datasets ---")

    for dataset in dataset_eval:

        print("--- " + dataset + " ---")

        features_validation_path = dataset_path + "/user_features_" + dataset + "_" + embeddings_version + ".parquet"
        features_validation = pd.read_parquet(features_validation_path, engine = 'fastparquet').fillna(0)
        features_validation = features_validation.sort_values("user_index")
        features_validation = features_validation.reset_index(drop=True)

        # validation & test dataset creation

        if not os.path.exists(master_path + "/" + embeddings_version + "/"):
            os.mkdir(master_path + "/" + embeddings_version + "/")
        if not os.path.exists(master_path + "/" + embeddings_version + "/" + dataset + "/"):
            os.mkdir(master_path + "/" + embeddings_version + "/" + dataset + "/")
        for i in range(len(features_validation)):
            x_validation = torch.FloatTensor(features_validation.iloc[i, 2:])
            y_validation = [song_dict[song_index] for song_index in features_validation["d1d30_songs"][i]]
            groundtruth_validation_list = [1.0 * (song in y_validation) for song in range(len(song_embeddings))]
            pickle.dump(x_validation, open("{}/{}/{}/x_{}.pkl".format(master_path, embeddings_version, dataset, i), "wb"))
            pickle.dump(y_validation, open("{}/{}/{}/y_listened_songs_{}.pkl".format(master_path, embeddings_version, dataset, i), "wb"))
            pickle.dump(groundtruth_validation_list, open("{}/{}/{}/groundtruth_list_{}.pkl".format(master_path, embeddings_version, dataset, i), "wb"))
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
from data_generation import generate
from model_training import training
from model_evaluation import evaluation
from clustering import train_kmeans, train_inputfeatureskmeans
from options import config
import time
import pandas as pd
import torch

if __name__ == "__main__":

    try:
        pd.io.parquet.get_engine('fastparquet')
    except ImportError as e:
        print("please install fastparquet first")

    if config['use_cuda'] and not torch.cuda.is_available():
        print("please make cuda gpu available, or set use_cuda=False")
        assert torch.cuda.is_available()

    master_path = "./deezer"
    dataset_path = os.getcwd() + "/data"
    embeddings_version = config["embeddings_version"]
    model_filename = "regression_model_" + embeddings_version
    clustering_path = "clustering_" + embeddings_version
    clusters_filename = "clustering_model" + embeddings_version
    inputfeaturesclustering_path = "inputfeaturesclustering_" + embeddings_version
    inputfeaturesclusters_filename = "inputfeaturesclustering_model" + embeddings_version
    print("--- running for embeddings version " + embeddings_version + " ---")

    if not os.path.exists("{}/".format(master_path)):
        os.mkdir("{}/".format(master_path))
    if not os.path.exists(master_path + "/" + embeddings_version + "/"):
        print("--- the data has not been generated yet for the embeddings version " + embeddings_version + " : generation running ---")
        os.mkdir(master_path + "/" + embeddings_version + "/")
        # preparing dataset
        print("--- data generation ---")
        start_time_data_generation = time.time()
        generate(dataset_path, master_path, config['embeddings_version'])
        print("--- data generation done ---")
        print("--- seconds ---" + str(time.time() - start_time_data_generation))
    else:
        print("--- the data has already been generated : no need to regenerate it ---")

    # training model.
    print("--- training prediction model ---")
    start_time_prediction_model = time.time()
    training(dataset_path, master_path, embeddings_version = embeddings_version, eval = True, model_save = True, model_filename = model_filename)
    print("--- training prediction model done ---")
    print("--- seconds ---" + str(time.time() - start_time_prediction_model))

    # evaluation of the model - full personalization strategy.
    print("--- full personalisation evaluation ---")
    start_time_fullperso_eval = time.time()
    evaluation(dataset_path, master_path, eval_type = "full_perso", embeddings_version = embeddings_version, model_filename = model_filename)
    print("--- full personalisation evaluation done ---")
    print("--- seconds ---" + str(time.time() - start_time_fullperso_eval))

    # evaluation of the model - semi personalization strategy.
    if not os.path.exists("{}/{}/".format(master_path, clustering_path)):
        os.mkdir("{}/{}/".format(master_path, clustering_path))
    if not os.path.exists(master_path + "/" + clustering_path + "/" + clusters_filename):
        print("--- clustering running ---")
        start_time_clustering = time.time()
        train_kmeans(dataset_path, master_path, clustering_path, config['nb_clusters'], config['max_iter'], config['random_state'], embeddings_version = embeddings_version, clusters_filename = clusters_filename)
        print("--- clustering done ---")
        print("--- seconds ---" + str(time.time() - start_time_clustering))
    else:
        print("--- no need to do the clustering again ---")

    print("--- semi personalisation evaluation ---")
    start_time_semiperso_eval = time.time()
    evaluation(dataset_path, master_path, eval_type = "semi_perso", embeddings_version = embeddings_version, model_filename = model_filename, clustering_path = clustering_path, clusters_filename = clusters_filename, nb_clusters = config["nb_clusters"])
    print("--- semi personalisation evaluation done ---")
    print("--- seconds ---" + str(time.time() - start_time_semiperso_eval))

    # popularity baseline.
    print("--- popularity baseline evaluation ---")
    start_time_popbaseline_eval = time.time()
    evaluation(dataset_path, master_path, eval_type = "popularity", embeddings_version = embeddings_version, model_filename = model_filename)
    print("--- popularity baseline evaluation done ---")
    print("--- seconds ---" + str(time.time() - start_time_popbaseline_eval))

    # avg d0 stream baseline.
    print("--- avg d0 stream baseline evaluation ---")
    start_time_avgd0streambaseline_eval = time.time()
    evaluation(dataset_path, master_path, eval_type = "avgd0stream", embeddings_version = embeddings_version, model_filename = model_filename)
    print("--- avg d0 stream baseline evaluation done ---")
    print("--- seconds ---" + str(time.time() - start_time_avgd0streambaseline_eval))

    # input features clustering baseline.
    print("--- input features clustering baseline evaluation ---")
    start_time_inputfeatures_clustering = time.time()
    if not os.path.exists("{}/{}/".format(master_path, inputfeaturesclustering_path)):
        os.mkdir("{}/{}/".format(master_path, inputfeaturesclustering_path))
    if not os.path.exists(master_path + "/" + inputfeaturesclustering_path + "/" + inputfeaturesclusters_filename):
        print("--- input features clustering running ---")
        start_time_inputfeaturesclustering = time.time()
        train_inputfeatureskmeans(dataset_path, master_path, inputfeaturesclustering_path, config['nb_clusters_inputfeatures'], config['max_iter'], config['random_state'], config["nb_songs"], embeddings_version = embeddings_version, clusters_filename = inputfeaturesclusters_filename)
        print("--- clustering done ---")
        print("--- seconds ---" + str(time.time() - start_time_inputfeatures_clustering))
    else:
        print("--- no need to do the clustering again ---")

    print("--- input features clustering baseline evaluation ---")
    start_time_inputfeaturesbaseline_eval = time.time()
    evaluation(dataset_path, master_path, eval_type = "inputfeatures", embeddings_version = embeddings_version, model_filename = model_filename, clustering_path = inputfeaturesclustering_path, clusters_filename = inputfeaturesclusters_filename, nb_clusters = config['nb_clusters_inputfeatures'])
    print("--- input features clustering baseline evaluation done ---")
    print("--- seconds ---" + str(time.time() - start_time_inputfeaturesbaseline_eval))
--------------------------------------------------------------------------------
/clustering.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
import pickle
from options import config
from sklearn.preprocessing import Normalizer

def train_kmeans(dataset_path, master_path, clustering_path, nb_clusters, max_iter, random_state, embeddings_version="svd", clusters_filename=None):

    # user embeddings (target = only for train users)

    user_embeddings = pd.read_parquet(dataset_path + "/user_embeddings.parquet", engine = 'fastparquet')
    list_embeddings = ["embedding_" + str(i) for i in range(len(user_embeddings[embeddings_version + "_embeddings"][0]))]
    user_embeddings[list_embeddings] = pd.DataFrame(user_embeddings[embeddings_version + "_embeddings"].tolist(), index = user_embeddings.index)
    user_embeddings_values = user_embeddings[list_embeddings].values

    # clustering

    kmeans = KMeans(n_clusters=nb_clusters, random_state=random_state, max_iter = max_iter, n_jobs=None, precompute_distances='auto').fit(user_embeddings_values)
    with open(master_path + "/" + clustering_path + "/" + clusters_filename, "wb") as f:
        pickle.dump(kmeans, f)

    # top songs by cluster

    user_embeddings["cluster"] = kmeans.labels_
    user_clusters = user_embeddings[["user_index", "cluster"]]

    # load songs listened between D1 and D30 for train users

    features_train_path = dataset_path + "/user_features_train_" + embeddings_version + ".parquet"
    features_train = pd.read_parquet(features_train_path, engine = 'fastparquet').fillna(0)
    features_train = features_train.sort_values("user_index")
    features_train = features_train.reset_index(drop=True)  # to check it is ok for train data

    listd1d30 = features_train[["user_index", "d1d30_songs"]]
    listd1d30 = pd.merge(listd1d30, user_clusters, left_on = "user_index", right_on = "user_index")
    listd1d30_exploded = listd1d30.explode('d1d30_songs')
    listd1d30_exploded["count"] = np.ones(len(listd1d30_exploded))
    listd1d30_by_cluster = pd.DataFrame(listd1d30_exploded.groupby(["cluster", "d1d30_songs"])['count'].count())

    # most popular songs by cluster

    nb_songs = config["nb_songs"]
    arrays = (np.repeat(np.arange(nb_clusters), repeats = nb_songs), np.tile(np.arange(nb_songs), nb_clusters))
    tuples = list(zip(*arrays))
    index_perso = pd.MultiIndex.from_tuples(tuples, names=["cluster", "song_index"])
    df = pd.DataFrame(index=["default"], columns=index_perso).T
    both = pd.concat([listd1d30_by_cluster, df], axis=1)[["count"]].fillna(0)
    both = both.reset_index(drop=False)
    both.columns = ["cluster", "song_index", "nb_streams"]
    data_by_cluster = pd.DataFrame(both.groupby("cluster")['nb_streams'].sum())
    data_by_cluster.columns = ["nb_streams_by_cluster"]
    data_by_cluster_and_song = pd.merge(both, data_by_cluster, left_on = "cluster", right_on = "cluster")
    data_by_cluster_and_song["segment_proba"] = data_by_cluster_and_song["nb_streams"] / data_by_cluster_and_song["nb_streams_by_cluster"]

    if not os.path.exists("{}/{}/".format(master_path, clustering_path + "_probas_" + embeddings_version)):
        os.mkdir("{}/{}/".format(master_path, clustering_path + "_probas_" + embeddings_version))
    for cluster_id in range(nb_clusters):
        if cluster_id % 100 == 0:
            print("probas by cluster and by song computed for cluster : " + str(cluster_id))
        list_proba = []
        for song_index in range(nb_songs):
            list_proba.append(data_by_cluster_and_song.iloc[cluster_id*nb_songs + song_index]["segment_proba"])
        pickle.dump(list_proba, open("{}/{}/list_proba_{}.pkl".format(master_path, clustering_path + "_probas_" + embeddings_version, cluster_id), "wb"))


def train_inputfeatureskmeans(dataset_path, master_path, clustering_path, nb_clusters, max_iter, random_state, nb_songs, embeddings_version="svd", clusters_filename=None):

    # train features

    user_features_train = pd.read_parquet(dataset_path + "/user_features_train_" + embeddings_version + ".parquet", engine = 'fastparquet')
    features_train = user_features_train.fillna(0).sort_values("user_index")
    features_train_ = features_train.values[:, 2:]
    transformer = Normalizer().fit(features_train_)
    X_train = transformer.transform(features_train_)

    # clustering

    kmeans = KMeans(n_clusters=nb_clusters, random_state=random_state, max_iter = max_iter, n_jobs=None, precompute_distances='auto').fit(X_train)
    with open(master_path + "/" + clustering_path + "/" + clusters_filename, "wb") as f:
        pickle.dump(kmeans, f)

    # top songs by cluster

    features_train["cluster"] = kmeans.labels_
    listd1d30 = features_train[["user_index", "d1d30_songs", "cluster"]]
    listd1d30_exploded = listd1d30.explode('d1d30_songs')
    listd1d30_exploded["count"] = np.ones(len(listd1d30_exploded))
    listd1d30_by_cluster = pd.DataFrame(listd1d30_exploded.groupby(["cluster", "d1d30_songs"])['count'].count())

    # we need to create an empty dataframe with every (cluster, song) combination
    arrays = (np.repeat(np.arange(nb_clusters), repeats = nb_songs),
              np.tile(np.arange(nb_songs), nb_clusters))
    tuples = list(zip(*arrays))
    index_perso = pd.MultiIndex.from_tuples(tuples, names=["cluster", "song_index"])
    df = pd.DataFrame(index=["test"], columns=index_perso).T

    both = pd.concat([listd1d30_by_cluster, df], axis=1)[["count"]].fillna(0)
    both = both.reset_index(drop=False)
    both.columns = ["cluster", "song_index", "nb_streams"]

    data_by_cluster = pd.DataFrame(both.groupby("cluster")['nb_streams'].sum())
    data_by_cluster.columns = ["nb_streams_by_cluster"]
    data_by_cluster_and_song = pd.merge(both, data_by_cluster, left_on = "cluster",
                                        right_on = "cluster")
    data_by_cluster_and_song["segment_proba"] = data_by_cluster_and_song["nb_streams"] / data_by_cluster_and_song["nb_streams_by_cluster"]

    if not os.path.exists("{}/{}/".format(master_path, clustering_path + "_probas_" + embeddings_version)):
        os.mkdir("{}/{}/".format(master_path, clustering_path + "_probas_" + embeddings_version))
    for cluster_id in range(nb_clusters):
        if cluster_id % 10 == 0:
            print("probas by cluster and by song computed for cluster : " + str(cluster_id))
        list_proba = []
        for song_index in range(nb_songs):
            list_proba.append(data_by_cluster_and_song.iloc[cluster_id*nb_songs + song_index]["segment_proba"])
        pickle.dump(list_proba, open("{}/{}/list_proba_{}.pkl".format(master_path, clustering_path + "_probas_" + embeddings_version, cluster_id), "wb"))
--------------------------------------------------------------------------------
/model_training.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
import numpy as np
import torch
import torch.nn
import time
import pickle
import random
from model import RegressionTripleHidden
from options import config


def training(dataset_path, master_path, embeddings_version="svd", eval=True, model_save=True, model_filename=None):

    use_cuda = config['use_cuda']
    cuda_number = config['device_number']
    cuda = torch.device(cuda_number)
    target_dim = config['embeddings_dim']
    input_dim = config['input_dim']
    nb_epochs = config['nb_epochs']
    learning_rate = config['learning_rate']
    reg_param = config['reg_param']
    drop_out = config['drop_out']
    batch_size = config['batch_size']
    eval_every = config['eval_every']
    k_val = config['k_val']

    if not os.path.exists(master_path + "/" + model_filename + ".pt"):

        print("--- no model pre-existing for " + embeddings_version + " : training regression model running ---")

        # Load training dataset.
        training_set_size = int(len(os.listdir("{}/{}/train".format(master_path, embeddings_version))) / 2)
        train_xs = []
        train_ys = []
        for idx in range(training_set_size):
            train_xs.append(pickle.load(open("{}/{}/train/x_train_{}.pkl".format(master_path, embeddings_version, idx), "rb")))
            train_ys.append(pickle.load(open("{}/{}/train/y_train_{}.pkl".format(master_path, embeddings_version, idx), "rb")))
        total_dataset = list(zip(train_xs, train_ys))
        del(train_xs, train_ys)

        if eval:

            # Load validation dataset.
            validation_set_size = int(len(os.listdir("{}/{}/validation".format(master_path, embeddings_version))) / 3)
            validation_xs = []
            listened_songs_validation_ys = []
            for idx in range(validation_set_size):
                validation_xs.append(pickle.load(open("{}/{}/validation/x_{}.pkl".format(master_path, embeddings_version, idx), "rb")))
                listened_songs_validation_ys.append(pickle.load(open("{}/{}/validation/y_listened_songs_{}.pkl".format(master_path, embeddings_version, idx), "rb")))
            total_validation_dataset = list(zip(validation_xs, listened_songs_validation_ys))
            del(validation_xs, listened_songs_validation_ys)

            # Load song embeddings for evaluation

            song_embeddings_path = dataset_path + "/song_embeddings.parquet"
            song_embeddings = pd.read_parquet(song_embeddings_path, engine = 'fastparquet')
            list_features = ["feature_" + str(i) for i in range(len(song_embeddings["features_" + embeddings_version][0]))]
            song_embeddings[list_features] = pd.DataFrame(song_embeddings["features_" + embeddings_version].tolist(), index = song_embeddings.index)
            song_embeddings_values = song_embeddings[list_features].values
            song_embeddings_values_ = torch.FloatTensor(song_embeddings_values.astype(np.float32))

        if use_cuda:
            regression_model = RegressionTripleHidden(input_dim = input_dim, output_dim = target_dim, drop_out = drop_out).cuda(device = cuda)
        else:
            regression_model = RegressionTripleHidden(input_dim = input_dim, output_dim = target_dim, drop_out = drop_out)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(regression_model.parameters(), lr = learning_rate, weight_decay = reg_param)

        print("training set size : " + str(training_set_size))
        print("validation set size : " + str(validation_set_size))
        print("input dimension : " + str(input_dim))
        print("regression model : " + str(regression_model))
        print("training running")

        loss_train = []

        for nb in range(nb_epochs):
            print("nb epoch : " + str(nb))
            start_time_epoch = time.time()
            random.Random(nb).shuffle(total_dataset)
            a, b = zip(*total_dataset)
            num_batch = int(training_set_size / batch_size)
            if use_cuda:
                regression_model = regression_model.to(device = cuda)
            for i in range(num_batch):
                optimizer.zero_grad()
                if use_cuda:
                    batch_features_tensor = torch.stack(a[batch_size*i:batch_size*(i+1)]).cuda(device = cuda)
                    batch_target_tensor = torch.stack(b[batch_size*i:batch_size*(i+1)]).cuda(device = cuda)
                else:
                    batch_features_tensor = torch.stack(a[batch_size*i:batch_size*(i+1)])
                    batch_target_tensor = torch.stack(b[batch_size*i:batch_size*(i+1)])
                output_tensor = regression_model(batch_features_tensor)
                loss = criterion(output_tensor, batch_target_tensor)
                loss.backward()
                optimizer.step()
                loss_train.append(loss.item())
            print('epoch ' + str(nb) + " training loss : " + str(sum(loss_train)/float(len(loss_train))))
            print("--- seconds ---" + str(time.time() - start_time_epoch))

            if nb != 0 and (nb % eval_every == 0 or nb == nb_epochs - 1):
                print('testing model')
                start_time_eval = time.time()
                reg = regression_model.eval()
                if use_cuda:
                    reg = reg.to(device=cuda)
                validation_set_size = len(total_validation_dataset)
                a, b = zip(*total_validation_dataset)
                num_batch_validation = int(validation_set_size / batch_size)
                current_precisions = []
                with torch.set_grad_enabled(False):
                    for i in range(num_batch_validation):
                        if use_cuda:
                            batch_features_tensor_validation = torch.stack(a[batch_size*i:batch_size*(i+1)]).cuda(device = cuda)
                        else:
                            batch_features_tensor_validation = torch.stack(a[batch_size*i:batch_size*(i+1)])
                        predictions_validation = reg(batch_features_tensor_validation)
                        groundtruth_validation = list(b[batch_size*i:batch_size*(i+1)])
                        predictions_songs_validation = torch.mm(predictions_validation.cpu(), song_embeddings_values_.transpose(0, 1))
                        recommendations_validation = (predictions_songs_validation.topk(k = k_val, dim = 1)[1]).tolist()
                        precisions = list(map(lambda x, y: len(set(x) & set(y)) / float(min(len(x), k_val)), groundtruth_validation, recommendations_validation))
                        current_precisions.extend(precisions)
                print('epoch ' + str(nb) + " precision test : " + str(sum(current_precisions) / float(len(current_precisions))))
                print("--- seconds ---" + str(time.time() - start_time_eval))
        print("--- training finished ---")

        if model_save:
            print("--- saving model ---")
            torch.save(regression_model.state_dict(), master_path + "/" + model_filename + ".pt")
            print(regression_model)
            print("--- model saved ---")

    else:
        print("--- there is already a model pre-existing for " + embeddings_version + " : no need to run training again ---")
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A Semi-Personalized System for User Cold Start Recommendation on Music Streaming Apps

This repository provides Python code to reproduce experiments from the article [A Semi-Personalized System for User Cold Start Recommendation on Music Streaming Apps](https://arxiv.org/pdf/2106.03819.pdf), published in the proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery and Data Mining ([KDD 2021](https://virtual.2021.kdd.org/index.html)).

## Recommending Musical Content on Deezer

Music streaming services heavily rely on recommender systems to improve their users’ experience, by helping them navigate through a large musical catalog and discover new songs, playlists, albums or artists.

To recommend personalized musical content to users, [Deezer](https://www.deezer.com/) leverages **latent models for collaborative filtering**. In a nutshell, these models aim at learning vector space representations of users and items, a.k.a. **embedding representations**, where proximity should reflect user preferences. More specifically, two different methods, referred to as **UT-ALS** and **TT-SVD** embeddings, are considered in this work. We refer to [our paper](https://arxiv.org/pdf/2106.03819.pdf) for technical details.

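As a concrete illustration of how these embeddings drive recommendation in the code above, the sketch below scores one user embedding against every song embedding with an inner product and keeps the top-k songs, mirroring the evaluation step of `model_training.py`. It is a minimal sketch rather than part of the pipeline itself: the helper name `recommend_top_k`, the hard-coded `"features_svd"` column and the relative `data/` path are illustrative choices, and `user_embedding` is assumed to be a 1-D `FloatTensor` such as one row of the regression model's output.

```python
import pandas as pd
import torch

# Song embeddings released with the dataset (see /data/read_me); the
# "features_svd" column corresponds to the TT-SVD setting used in options.py.
song_embeddings = pd.read_parquet("data/song_embeddings.parquet", engine="fastparquet")
song_matrix = torch.FloatTensor(song_embeddings["features_svd"].tolist())  # (nb_songs, embedding_dim)

def recommend_top_k(user_embedding, song_matrix, k=50):
    # Inner-product scores between one user embedding and all song embeddings,
    # then the indices of the k highest-scoring songs.
    scores = torch.mv(song_matrix, user_embedding)
    return scores.topk(k).indices.tolist()
```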
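Putting the pieces above together, the semi-personalized strategy, in outline, predicts an embedding for a cold-start user from their input features (the regression model trained in `model_training.py`), assigns that embedding to its nearest user cluster (the k-means model trained in `clustering.py`), and recommends the cluster's most-streamed songs (the `list_proba_*.pkl` files written by `clustering.py`). The exact evaluation logic lives in `model_evaluation.py`, which is not reproduced above, so the following is only a hedged sketch of that flow; it reuses the file names produced by `main.py` and `clustering.py`, and the test user `x_0.pkl` is an arbitrary example.

```python
import pickle
import numpy as np
import torch
from model import RegressionTripleHidden
from options import config

master_path = "./deezer"
embeddings_version = config["embeddings_version"]
clustering_path = "clustering_" + embeddings_version          # as in main.py
clusters_filename = "clustering_model" + embeddings_version   # as in main.py

# 1) Predict an embedding for one cold-start user from their input features.
model = RegressionTripleHidden(input_dim=config["input_dim"], output_dim=config["embeddings_dim"])
model.load_state_dict(torch.load(master_path + "/regression_model_" + embeddings_version + ".pt", map_location="cpu"))
model.eval()
x_user = pickle.load(open("{}/{}/test/x_0.pkl".format(master_path, embeddings_version), "rb"))
with torch.no_grad():
    predicted_embedding = model(x_user.unsqueeze(0)).numpy()

# 2) Assign the predicted embedding to its nearest user cluster.
kmeans = pickle.load(open(master_path + "/" + clustering_path + "/" + clusters_filename, "rb"))
cluster_id = int(kmeans.predict(predicted_embedding)[0])

# 3) Recommend the cluster's most-streamed songs, precomputed by clustering.py.
probas = pickle.load(open("{}/{}/list_proba_{}.pkl".format(master_path, clustering_path + "_probas_" + embeddings_version, cluster_id), "rb"))
top_songs = np.argsort(probas)[::-1][:config["k_val"]]
print(top_songs)
```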