├── .gitignore
├── MLproject
├── README.md
├── artefacts
│   └── .gitignore
├── conda.yaml
├── data
│   ├── articles_df.csv
│   ├── interactions_full_df.csv
│   ├── interactions_test_df.csv
│   ├── interactions_train_df.csv
│   ├── predict.csv
│   └── raw
│       └── rating.csv
├── data_preparation.py
├── docs
│   ├── all_metrics.png
│   ├── all_metrics_plot.png
│   ├── background.png
│   ├── deep_autoenc1.png
│   ├── deep_autoenc2.png
│   ├── hist_metrics.png
│   ├── metrics.png
│   ├── model_summary_1.png
│   ├── model_summary_2.png
│   ├── model_summary_3.png
│   └── train_hist.png
├── evaluation
│   ├── __init__.py
│   ├── metrics.py
│   └── model_evaluator.py
├── model
│   ├── AutoEncContentModel.py
│   ├── AutoEncModel.py
│   ├── BaseModel.py
│   ├── CDAEModel.py
│   └── __init__.py
├── notebooks
│   ├── DeepAutoEncoder - Simple Train.ipynb
│   ├── DeepAutoEncoderContent - Simple Train.ipynb
│   ├── Group String Input.ipynb
│   └── Plot Metrics.ipynb
├── popularity_train.py
├── recommender.py
├── report.txt
├── train.py
├── train_all.sh
└── util.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | mlruns/*
3 | .ipynb_checkpoints
4 | derby.log
5 | metastore_db
--------------------------------------------------------------------------------
/MLproject:
--------------------------------------------------------------------------------
 1 | # deep_recsys/
 2 | #  |__ MLproject
 3 | #  |__ conda.yaml
 4 | #  |__ train.py
 5 | #  |__ ...
 6 | 
 7 | name: recsys_autoencoders
 8 | conda_env: conda.yaml
 9 | 
10 | entry_points:
11 |   main:
12 |     parameters:
13 |       name: {type: string, default: 'auto_enc'}
14 |       factors: {type: int, default: 15}
15 |       layers: {type: string, default: '[128,256,128]'}
16 |       epochs: {type: int, default: 100}
17 |       batch: {type: int, default: 64}
18 |       activation: {type: string, default: 'selu'}
19 |       dropout: {type: float, default: 0.8}
20 |       lr: {type: float, default: 0.0001}
21 |       reg: {type: float, default: 0.01}
22 |     command: "python train.py
23 |               --name {name}
24 |               --factors {factors}
25 |               --layers {layers}
26 |               --epochs {epochs}
27 |               --batch {batch}
28 |               --activation {activation}
29 |               --dropout {dropout}
30 |               --lr {lr}
31 |               --reg {reg}"
32 | 
33 |   data_preparation:
34 |     parameters:
35 |       min_interactions: {type: int, default: 5}
36 |       factor_negative_sample: {type: int, default: 0}
37 |       test_size: {type: float, default: 0.2}
38 |     command: "python data_preparation.py
39 |               --min_interactions {min_interactions}
40 |               --test_size {test_size}
41 |               --factor_negative_sample {factor_negative_sample}"
42 | 
43 |   popularity_train:
44 |     command: "python popularity_train.py"
45 | 
46 |   recommender:
47 |     parameters:
48 |       name: {type: string, default: 'selu'}
49 |       model_path: {type: string, default: 'selu'}
50 |       user_id: {type: int, default: 1}
51 |       topn: {type: int, default: 10}
52 |       view: {type: int, default: 0}
53 |       output: {type: string, default: './data/predict.csv'}
54 |     command: "python recommender.py
55 |               --name {name}
56 |               --model_path {model_path}
57 |               --user_id {user_id}
58 |               --topn {topn}
59 |               --view {view}
60 |               --output {output}"
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Deep AutoEncoders for Collaborative Filtering
 2 | 
 3 | Collaborative Filtering is a method used by recommender systems to make predictions about the interests of a specific user by collecting taste or preference information from many other users.
The technique of Collaborative Filtering rests on the underlying assumption that if user A has the same taste or opinion on an issue as user B, then A is more likely to share B's opinion on a different issue.
 4 | 
 5 | **This project implements different Deep Autoencoder architectures for Collaborative Filtering in Keras**, based on several articles. The test case uses the Steam platform interactions dataset to recommend games to users.
 6 | 
 7 | ![Steam](docs/background.png)
 8 | 
 9 | - [Deep AutoEncoders for Collaborative Filtering](#deep-autoencoders-for-collaborative-filtering)
10 | - [Requirements](#requirements)
11 | - [Getting Started](#getting-started)
12 | - [Datasets](#datasets)
13 | - [Data Preparation](#data-preparation)
14 | - [Model Training](#model-training)
15 | - [Implemented Recommender Models](#implemented-recommender-models)
16 | - [1. Popularity Model](#1-popularity-model)
17 | - [2. CDAE - Collaborative Denoising Auto-Encoders for Top-N Recommender Systems](#2-cdae---collaborative-denoising-auto-encoders-for-top-n-recommender-systems)
18 | - [3. Deep AutoEncoder for Collaborative Filtering](#3-deep-autoencoder-for-collaborative-filtering)
19 | - [4. Deep AutoEncoder for Collaborative Filtering With Content Information](#4-deep-autoencoder-for-collaborative-filtering-with-content-information)
20 | - [Training Results](#training-results)
21 | - [Evaluation](#evaluation)
22 | - [Recommender](#recommender)
23 | - [References](#references)
24 | 
25 | ## Requirements
26 | 
27 | Create a conda env from ```conda.yaml```:
28 | 
29 | * python=3.6
30 | * cloudpickle=0.6.1
31 | * numpy=1.16.4
32 | * pandas=0.24.2
33 | * scikit-learn=0.20.1
34 | * seaborn=0.9
35 | * click=6.7
36 | * tensorflow-gpu==2.0
37 | * scipy=1.2.1
38 | * graphviz
39 | * pydotplus
40 | 
41 | ## Getting Started
42 | 
43 | This project uses [MLflow](https://www.mlflow.org) for reproducibility, model training and dependency management. Install MLflow first:
44 | 
45 | ```
46 | $ pip install mlflow
47 | ```
48 | 
49 | ```
50 | $ git clone https://github.com/marlesson/recsys_autoencoders.git
51 | ```
52 | 
53 | ### Datasets
54 | 
55 | The dataset used in this project is Steam Video Games, obtained from https://www.kaggle.com/tamber/steam-video-games.
56 | 
57 | This dataset is a list of user behaviors with the columns `user_id`, `game`, `type`, `hours`, `none`. The types included are 'purchase' and 'play'. The value indicates the degree to which the behavior was performed: in the case of 'purchase' the value is always 1, and in the case of 'play' it represents the number of hours the user has played the game.
58 | 
59 | `./data/raw/rating.csv`
60 | 
61 | | user_id | game | type | hours | none |
62 | | -------- | ---------------- | ---------------- | ----------- | ------------ |
63 | | 151603712 | "The Elder Scrolls V Skyrim" | purchase | 1.0 | 0 |
64 | | 151603712 | "The Elder Scrolls V Skyrim" | play | 273.0 | 0 |
65 | | 151603712 | "Fallout 4" | purchase | 1.0 | 0 |
66 | | ... | ... | ... | ... | ... |
67 | 
68 | 
69 | ### Data Preparation
70 | 
71 | The data preparation process transforms the original dataset, groups the implicit feedback into interactions, and creates the specific datasets used for training and testing the models (see the sketch below).
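
At its core, the preparation step collapses the raw purchase/play events into a single implicit-feedback row per user-game pair. A minimal sketch of that aggregation (simplified from `data_preparation.py`; the filtering and ID-encoding steps are omitted):

```
import pandas as pd

# raw Steam events: one row per (user, game, event type)
raw = pd.read_csv('./data/raw/rating.csv', header=None,
                  names=['user_id', 'game', 'type', 'hours', 'none'])

# sum the hours over the 'purchase' and 'play' events of each user-game pair
interactions = raw.groupby(['user_id', 'game'])['hours'].sum().reset_index()
interactions['view'] = 1  # each remaining pair is one implicit positive
```

The full pipeline runs as an MLflow entry point: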
72 | 
73 | ```
74 | $ mlflow run . -e data_preparation -P min_interactions=5 -P test_size=0.2
75 | ```
76 | 
77 | Datasets created:
78 | * ./data/articles_df.csv
79 | * ./data/interactions_full_df.csv
80 | * ./data/interactions_train_df.csv (subset of 'interactions_full_df.csv' for training)
81 | * ./data/interactions_test_df.csv (subset of 'interactions_full_df.csv' for testing)
82 | 
83 | `articles_df.csv` contains the item (game) data only.
84 | 
85 | | content_id | game | total_users | total_hours |
86 | | -------- | ---------------- | ---------------- | ----------- |
87 | | 0 | 007 Legends | 1 | 1.7 |
88 | | 1 | 0RBITALIS | 3 | 4.2 |
89 | 
90 | `interactions_full_df.csv` contains the user × item interaction data, with the number of hours played (`hours`) and whether the game was played (`view`) as implicit feedback.
91 | 
92 | | user_id | content_id | game | hours | view |
93 | | -------- | ---------------- | ---------------- | ----------- | ----------- |
94 | | 134 | 1680 | Far Cry 3 Blood Dragon | 2.2 | 1 |
95 | | 2219 | 1938 | Gone Home | 1.2 | 1 |
96 | | 3315 | 3711 | Serious Sam 3 BFE | 3.7 | 1 |
97 | 
98 | ### Model Training
99 | 
100 | The ```--name``` parameter selects the model to be trained. Depending on the model, some parameters have no effect.
101 | 
102 | ```
103 | Usage: mlflow run . [OPTIONS]
104 | 
105 | Train Autoencoder Matrix Factorization Model
106 | 
107 | Options:
108 |   --name [auto_enc|cdae|auto_enc_content]
109 |   --factors INTEGER
110 |   --layers TEXT
111 |   --epochs INTEGER
112 |   --batch INTEGER
113 |   --activation [relu|elu|selu|sigmoid]
114 |   --dropout FLOAT
115 |   --lr FLOAT
116 |   --reg FLOAT
117 | ```
118 | 
119 | #### Implemented Recommender Models
120 | 
121 | **These are adapted implementations of the original articles, simplifying some features for a better understanding of the models.**
122 | 
123 | ##### 1. Popularity Model
124 | 
125 | This model recommends the most popular games, i.e. the ones with the most purchases in a period. The recommendation is not personalized; that is, it is the same for all users.
126 | 
127 | This is a baseline model used for comparison with the autoencoder models.
128 | 
129 | Run training:
130 | ```
131 | $ mlflow run . -e popularity_train
132 | ```
133 | 
134 | ##### 2. CDAE - Collaborative Denoising Auto-Encoders for Top-N Recommender Systems
135 | 
136 | > Yao Wu, Christopher DuBois, Alice X. Zheng, Martin Ester.
137 | > Collaborative Denoising Auto-Encoders for Top-N Recommender Systems.
138 | > The 9th ACM International Conference on Web Search and Data Mining (WSDM'16), p153--162, 2016.
139 | > http://alicezheng.org/papers/wsdm16-cdae.pdf
140 | 
141 | Run training:
142 | ```
143 | $ mlflow run . \
144 |   -P activation=selu \
145 |   -P batch=64 \
146 |   -P dropout=0.8 \
147 |   -P epochs=50 \
148 |   -P factors=500 \
149 |   -P lr=0.0001 \
150 |   -P name=cdae \
151 |   -P reg=0.0001
152 | ```
153 | 
154 | ![Steam](docs/model_summary_1.png)
155 | 
156 | ##### 3. Deep AutoEncoder for Collaborative Filtering
157 | 
158 | > KUCHAIEV, Oleksii; GINSBURG, Boris.
159 | > Training deep autoencoders for collaborative filtering.
160 | > arXiv preprint arXiv:1708.01715, 2017.
161 | > https://arxiv.org/pdf/1708.01715.pdf
162 | 
163 | Run training:
164 | ```
165 | $ mlflow run . \
166 |   -P activation=selu \
167 |   -P batch=64 \
168 |   -P dropout=0.8 \
169 |   -P epochs=50 \
170 |   -P layers='[512,256,512]' \
171 |   -P lr=0.0001 \
172 |   -P name=auto_enc \
173 |   -P reg=0.01
174 | ```
175 | 
176 | ![Steam](docs/model_summary_2.png)
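
The network built by `model/AutoEncModel.py` is a symmetric encoder–decoder over each row of the user × item matrix. A condensed sketch of its `build_model` logic (the function name `build_autoencoder` is illustrative; hyperparameters map to the training options above):

```
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras.models import Model

def build_autoencoder(n_items, layers=(512, 256, 512),
                      activation='selu', dropout=0.8):
    # input: one row of the user-item matrix (a user's interaction vector)
    inp = x = Input(shape=(n_items,))
    k = len(layers) // 2
    for units in layers[:k]:                            # encoder
        x = Dense(units, activation=activation)(x)
    x = Dense(layers[k], activation=activation)(x)      # latent space
    x = Dropout(dropout)(x)                             # denoising regularization
    for units in layers[k + 1:]:                        # decoder
        x = Dense(units, activation=activation)(x)
    out = Dense(n_items, activation='linear')(x)        # reconstructed scores
    return Model(inp, out)
```

Training minimizes the MSE between the reconstructed and the observed interaction vector; at prediction time the reconstruction scores for unseen items are used as recommendations.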
177 | 
178 | ##### 4. Deep AutoEncoder for Collaborative Filtering With Content Information
179 | 
180 | This model is an adaptation of the previous one, with content information added; this makes it a hybrid implementation.
181 | 
182 | In this model, the names of all games that the user has already played are added as extra input for collaborative filtering. This is a way to add content information at the user level.
183 | 
184 | ```
185 | $ mlflow run . \
186 |   -P activation=selu \
187 |   -P batch=64 \
188 |   -P dropout=0.8 \
189 |   -P epochs=50 \
190 |   -P layers='[512,256,512]' \
191 |   -P lr=0.0001 \
192 |   -P name=auto_enc_content \
193 |   -P reg=0.01
194 | ```
195 | 
196 | ![Steam](docs/model_summary_3.png)
197 | 
198 | #### Training Results
199 | 
200 | After training, the artifacts (model, metrics, graphics, logs) are saved in `./mlruns/0//`
201 | 
202 | ![Hist](docs/hist_metrics.png)
203 | 
204 | If you want to run the training for all models, run the script `$ ./train_all.sh`
205 | 
206 | ### Evaluation
207 | 
208 | All models were evaluated with different RecSys metrics. After training, use MLflow to view the metrics in its UI.
209 | 
210 | ```
211 | $ mlflow ui
212 | ```
213 | ![Metrics](docs/all_metrics.png)
214 | 
215 | ![Metrics](docs/all_metrics_plot.png)
216 | 
217 | ### Recommender
218 | 
219 | Uses a trained AutoEncoder (`--model_path`) to recommend games for a user (`--user_id`).
220 | 
221 | ```
222 | Usage: mlflow run . -e recommender [OPTIONS]
223 | 
224 | Recommender Matrix Factorization Model
225 | 
226 | Options:
227 |   --name [auto_enc|cdae|auto_enc_content]
228 |   --model_path TEXT
229 |   --user_id INTEGER
230 |   --topn INTEGER
231 |   --view INTEGER (recommend items already viewed)
232 |   --output TEXT
233 | ```
234 | 
235 | ```
236 | $ mlflow run . -e recommender \
237 |   -P name='auto_enc' \
238 |   -P model_path='mlruns/0//artifacts/auto_enc' \
239 |   -P topn=10 \
240 |   -P view=0 \
241 |   -P user_id=25
242 | 
243 | ...
244 | 
245 |       score  content_id  game
246 | 0  1.029705        2457  Left 4 Dead 2
247 | 1  1.013235        2326  Just Cause 2
248 | 2  0.979504         975  Counter-Strike Global Offensive
249 | 3  0.979452        4675  Trine 2
250 | 4  0.909918        2355  Killing Floor
251 | 5  0.904026        2690  Metro 2033
252 | 6  0.897048        2063  Half-Life Opposing Force
253 | 7  0.892517        2061  Half-Life Blue Shift
254 | 8  0.857984        4240  Terraria
255 | 9  0.840588        2055  Half-Life 2
256 | ```
257 | 
258 | ## References
259 | 
260 | * https://www.kaggle.com/gspmoreira/recommender-systems-in-python-101
261 | * https://github.com/statisticianinstilettos/recmetrics/blob/master/recmetrics/metrics.py
262 | * https://github.com/benfred/implicit/
263 | * https://github.com/NVIDIA/DeepRecommender
--------------------------------------------------------------------------------
/artefacts/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | 
4 | 
--------------------------------------------------------------------------------
/conda.yaml:
--------------------------------------------------------------------------------
 1 | name: recsys_autoencoders
 2 | channels:
 3 |   - defaults
 4 | dependencies:
 5 |   - python=3.6
 6 |   - cloudpickle=0.6.1
 7 |   - numpy=1.16.4
 8 |   - pandas=0.24.2
 9 |   - scikit-learn=0.20.1
10 |   - seaborn=0.9
11 |   - click=6.7
12 |   - tensorflow-gpu==2.0
13 |   - scipy=1.2.1
14 |   - graphviz
15 |   - pydotplus
16 |   - pip:
17 |     - mlflow
18 | 
--------------------------------------------------------------------------------
/data/predict.csv:
--------------------------------------------------------------------------------
 1 | score,content_id,game
 2 | 0.6950495,445,Bastion
 3 | 0.67773813,2429,LIMBO
 4 | 0.6305973,2057,Half-Life 2 Episode One
 5 | 0.5916952,1146,Dead Space
 6 | 0.5745701,2590,Magicka
 7 | 0.5396258,1108,Darksiders
 8 | 0.53595227,928,Company of Heroes Tales of Valor
 9 | 0.53326076,2738,Monaco
10 | 0.5249776,4222,Team Fortress Classic
11 | 0.51937866,1623,FTL Faster Than Light
--------------------------------------------------------------------------------
/data_preparation.py:
--------------------------------------------------------------------------------
 1 | # Dataset - Steam Video Games
 2 | # https://www.kaggle.com/tamber/steam-video-games
 3 | #
 4 | 
 5 | import click
 6 | import pandas as pd
 7 | from util import *
 8 | from sklearn.model_selection import train_test_split
 9 | 
10 | # event_type_strength = {
11 | #   'purchase': 1,
12 | #   'play': 1
13 | # }
14 | 
15 | @click.command(help="Processes the source data to create new data for the model")
16 | @click.option("--min_interactions", type=click.INT, default=5,
17 |               help="Minimum number of interactions per user.")
18 | @click.option("--test_size", type=click.FLOAT, default=0.2)
19 | @click.option("--factor_negative_sample", type=click.INT, default=0)
20 | def run(min_interactions, test_size, factor_negative_sample):
21 |     base_path = './data/raw/'
22 | 
23 |     # Raw logs of user-game interactions (Steam Video Games dataset)
24 |     interactions_df = pd.read_csv(base_path+'/rating.csv', index_col=None, header=None)
25 |     interactions_df.columns=['user_id', 'game', 'type', 'hours', 'none']
26 | 
27 |     # Group interactions by user_id and game
28 |     interactions_full_df = interactions_df.groupby(['user_id', 'game'])\
29 |                               .sum()['hours'].reset_index()
30 | 
31 |     interactions_full_df['view'] = 1 # define every grouped pair as one implicit view
32 | 
33 |     # Filter interactions
34 |     interactions_full_df = filter_interactions(interactions_full_df, min_interactions)
35 | 
36 |     # Define dummy IDs
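    # ('category' codes map each game/user to a dense integer in 0..N-1;
    #  these integers later double as the row/column indices of the
    #  user-item matrix that the autoencoders consume)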
37 |     interactions_full_df['content_id'] = interactions_full_df['game'].astype('category').cat.codes
38 |     interactions_full_df['user_id'] = interactions_full_df['user_id'].astype('category').cat.codes
39 | 
40 |     # Create a DataFrame with Content Information
41 |     articles_df = interactions_full_df.groupby(['game', 'content_id'])\
42 |                     .agg({'user_id': 'count', 'hours': np.sum})[['user_id','hours']]\
43 |                     .reset_index()\
44 |                     .rename(columns={'user_id': 'total_users', 'hours': 'total_hours'})
45 | 
46 |     print('# of unique user/item interactions: %d' % len(interactions_full_df))
47 | 
48 |     # Split dataset in Train/Test
49 |     interactions_train_df, interactions_test_df = train_test_split(interactions_full_df,
50 |                                                       stratify=interactions_full_df['user_id'],
51 |                                                       test_size=test_size,
52 |                                                       random_state=42)
53 | 
54 | 
55 |     print("# size of train dataset before negative sample: %d" % len(interactions_train_df))
56 | 
57 |     # If negative sampling is enabled, create the negative samples
58 |     interactions_train_df = interactions_with_negative_sample(interactions_train_df,
59 |                                                       factor_negative_sample=factor_negative_sample)
60 | 
61 |     print("# size of train dataset after negative sample: %d" % len(interactions_train_df))
62 |     print("# size of test dataset: %d" % len(interactions_test_df))
63 | 
64 | 
65 |     interactions_full_df[['user_id','content_id','game','hours','view']].to_csv('./data/interactions_full_df.csv', index = False)
66 |     interactions_train_df[['user_id','content_id','game','hours','view']].to_csv('./data/interactions_train_df.csv', index = False)
67 |     interactions_test_df[['user_id','content_id','game','hours','view']].to_csv('./data/interactions_test_df.csv', index = False)
68 |     articles_df[['content_id', 'game','total_users','total_hours']].to_csv('./data/articles_df.csv', index = False)
69 | 
70 | def filter_interactions(interactions_df, min_interactions):
71 |     '''
72 |     Keep only interactions of users with at least {min_interactions} interactions
73 |     '''
74 |     users_interactions_count_df = interactions_df.groupby('user_id').size()
75 |     print('# users: %d' % len(users_interactions_count_df))
76 | 
77 |     users_with_enough_interactions_df = users_interactions_count_df[users_interactions_count_df >= min_interactions]\
78 |                                             .reset_index()[['user_id']]
79 | 
80 |     print('# users with at least %d interactions: %d' % (min_interactions, len(users_with_enough_interactions_df)))
81 | 
82 |     print('# of interactions: %d' % len(interactions_df))
83 |     interactions_from_selected_users_df = interactions_df.merge(users_with_enough_interactions_df,
84 |                                               how = 'right',
85 |                                               left_on = 'user_id',
86 |                                               right_on = 'user_id')
87 | 
88 |     print('# of interactions from users with at least %d interactions: %d' % (min_interactions, len(interactions_from_selected_users_df)))
89 | 
90 |     return interactions_from_selected_users_df
91 | 
92 | 
93 | def interactions_with_negative_sample(interactions_train_df, factor_negative_sample=3):
94 |     '''
95 |     Create negative interactions for content the user has not viewed
96 | 
97 |     factor_negative_sample: create K negative samples per positive interaction
98 |     '''
99 | 
100 |     if factor_negative_sample == 0:
101 |         return interactions_train_df
102 | 
103 |     # Top content Views
104 |     top_content_views = interactions_train_df.groupby('content_id').count()['user_id']\
105 |                             .reset_index().sort_values('user_id', ascending=False).head(1000)
106 | 
107 |     interactions_train_df = interactions_train_df.set_index('user_id')
108 | 
109 |     all_df = []
110 |     for user_id in np.unique(interactions_train_df.index.values):
111 |         content_views = interactions_train_df.loc[user_id].content_id.unique()
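        # negative sampling: from the 1000 most-viewed games, pick those this
        # user has NOT interacted with, capped at factor_negative_sample times
        # the number of the user's positives, and label them view=-1 so the
        # model also sees explicit "not viewed" signal during training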
112 |         content_not_view = top_content_views[~top_content_views.content_id.isin(content_views)]\
113 |                                .content_id.values[:int(len(content_views)*factor_negative_sample)]
114 | 
115 |         df_view = pd.DataFrame(data={'user_id': [user_id]*len(content_views),
116 |                                      'content_id': content_views,
117 |                                      'view': [1]*len(content_views)})
118 | 
119 |         df_not_view = pd.DataFrame(data={'user_id': [user_id]*len(content_not_view),
120 |                                          'content_id': content_not_view,
121 |                                          'view': [-1]*len(content_not_view)})
122 | 
123 |         df = pd.concat([df_view,df_not_view]).sample(frac=1)
124 |         all_df.append(df)
125 | 
126 |     return pd.concat(all_df)
127 | 
128 | if __name__ == '__main__':
129 |     run()
--------------------------------------------------------------------------------
/docs/all_metrics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/all_metrics.png
--------------------------------------------------------------------------------
/docs/all_metrics_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/all_metrics_plot.png
--------------------------------------------------------------------------------
/docs/background.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/background.png
--------------------------------------------------------------------------------
/docs/deep_autoenc1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/deep_autoenc1.png
--------------------------------------------------------------------------------
/docs/deep_autoenc2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/deep_autoenc2.png
--------------------------------------------------------------------------------
/docs/hist_metrics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/hist_metrics.png
--------------------------------------------------------------------------------
/docs/metrics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/metrics.png
--------------------------------------------------------------------------------
/docs/model_summary_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/model_summary_1.png
--------------------------------------------------------------------------------
/docs/model_summary_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/model_summary_2.png -------------------------------------------------------------------------------- /docs/model_summary_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/model_summary_3.png -------------------------------------------------------------------------------- /docs/train_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/docs/train_hist.png -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/marlesson/recsys_autoencoders/a4f2fce504509bbc755d790d3faeec34addae47b/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | # Metrics 2 | # 3 | # These metrics are derived from the 4 | # 5 | # https://github.com/statisticianinstilettos/recmetrics/blob/master/recmetrics/metrics.py 6 | # http://github.com/benfred/implicit/ 7 | 8 | import pandas as pd 9 | from sklearn.metrics.pairwise import cosine_similarity 10 | from sklearn.metrics import mean_squared_error 11 | import scipy.sparse as sp 12 | from math import sqrt 13 | import numpy as np 14 | import warnings 15 | 16 | __all__ = [ 17 | 'mean_average_precision', 18 | 'ndcg_at', 19 | 'precision_at', 20 | ] 21 | 22 | def _require_positive_k(k): 23 | """Helper function to avoid copy/pasted code for validating K""" 24 | if k <= 0: 25 | raise ValueError("ranking position k should be positive") 26 | 27 | 28 | def _mean_ranking_metric(predictions, labels, metric): 29 | """Helper function for precision_at_k and mean_average_precision""" 30 | # do not zip, as this will require an extra pass of O(N). Just assert 31 | # equal length and index (compute in ONE pass of O(N)). 32 | # if len(predictions) != len(labels): 33 | # raise ValueError("dim mismatch in predictions and labels!") 34 | # return np.mean([ 35 | # metric(np.asarray(predictions[i]), np.asarray(labels[i])) 36 | # for i in xrange(len(predictions)) 37 | # ]) 38 | 39 | # Actually probably want lazy evaluation in case preds is a 40 | # generator, since preds can be very dense and could blow up 41 | # memory... but how to assert lengths equal? FIXME 42 | return np.mean([ 43 | metric(np.asarray(prd), np.asarray(labels[i])) 44 | for i, prd in enumerate(predictions) # lazy eval if generator 45 | ]) 46 | 47 | 48 | def _warn_for_empty_labels(): 49 | """Helper for missing ground truth sets""" 50 | warnings.warn("Empty ground truth set! Check input data") 51 | return 0. 52 | 53 | 54 | def precision_at(predictions, labels, k=10, assume_unique=True): 55 | """Compute the precision at K. 56 | Compute the average precision of all the queries, truncated at 57 | ranking position k. If for a query, the ranking algorithm returns 58 | n (n is less than k) results, the precision value will be computed 59 | as #(relevant items retrieved) / k. This formula also applies when 60 | the size of the ground truth set is less than k. 
61 | If a query has an empty ground truth set, zero will be used as 62 | precision together with a warning. 63 | Parameters 64 | ---------- 65 | predictions : array-like, shape=(n_predictions,) 66 | The prediction array. The items that were predicted, in descending 67 | order of relevance. 68 | labels : array-like, shape=(n_ratings,) 69 | The labels (positively-rated items). 70 | k : int, optional (default=10) 71 | The rank at which to measure the precision. 72 | assume_unique : bool, optional (default=True) 73 | Whether to assume the items in the labels and predictions are each 74 | unique. That is, the same item is not predicted multiple times or 75 | rated multiple times. 76 | Examples 77 | -------- 78 | >>> # predictions for 3 users 79 | >>> preds = [[1, 6, 2, 7, 8, 3, 9, 10, 4, 5], 80 | ... [4, 1, 5, 6, 2, 7, 3, 8, 9, 10], 81 | ... [1, 2, 3, 4, 5]] 82 | >>> # labels for the 3 users 83 | >>> labels = [[1, 2, 3, 4, 5], [1, 2, 3], []] 84 | >>> precision_at(preds, labels, 1) 85 | 0.33333333333333331 86 | >>> precision_at(preds, labels, 5) 87 | 0.26666666666666666 88 | >>> precision_at(preds, labels, 15) 89 | 0.17777777777777778 90 | """ 91 | # validate K 92 | _require_positive_k(k) 93 | 94 | def _inner_pk(pred, lab): 95 | # need to compute the count of the number of values in the predictions 96 | # that are present in the labels. We'll use numpy in1d for this (set 97 | # intersection in O(1)) 98 | if lab.shape[0] > 0: 99 | n = min(pred.shape[0], k) 100 | cnt = np.in1d(pred[:n], lab, assume_unique=assume_unique).sum() 101 | return float(cnt) / k 102 | else: 103 | return _warn_for_empty_labels() 104 | 105 | return _mean_ranking_metric(predictions, labels, _inner_pk) 106 | 107 | 108 | def mean_average_precision(predictions, labels, assume_unique=True): 109 | """Compute the mean average precision on predictions and labels. 110 | Returns the mean average precision (MAP) of all the queries. If a query 111 | has an empty ground truth set, the average precision will be zero and a 112 | warning is generated. 113 | Parameters 114 | ---------- 115 | predictions : array-like, shape=(n_predictions,) 116 | The prediction array. The items that were predicted, in descending 117 | order of relevance. 118 | labels : array-like, shape=(n_ratings,) 119 | The labels (positively-rated items). 120 | assume_unique : bool, optional (default=True) 121 | Whether to assume the items in the labels and predictions are each 122 | unique. That is, the same item is not predicted multiple times or 123 | rated multiple times. 124 | Examples 125 | -------- 126 | >>> # predictions for 3 users 127 | >>> preds = [[1, 6, 2, 7, 8, 3, 9, 10, 4, 5], 128 | ... [4, 1, 5, 6, 2, 7, 3, 8, 9, 10], 129 | ... [1, 2, 3, 4, 5]] 130 | >>> # labels for the 3 users 131 | >>> labels = [[1, 2, 3, 4, 5], [1, 2, 3], []] 132 | >>> mean_average_precision(preds, labels) 133 | 0.35502645502645497 134 | """ 135 | def _inner_map(pred, lab): 136 | if lab.shape[0]: 137 | # compute the number of elements within the predictions that are 138 | # present in the actual labels, and get the cumulative sum weighted 139 | # by the index of the ranking 140 | n = pred.shape[0] 141 | 142 | # Scala code from Spark source: 143 | # var i = 0 144 | # var cnt = 0 145 | # var precSum = 0.0 146 | # val n = pred.length 147 | # while (i < n) { 148 | # if (labSet.contains(pred(i))) { 149 | # cnt += 1 150 | # precSum += cnt.toDouble / (i + 1) 151 | # } 152 | # i += 1 153 | # } 154 | # precSum / labSet.size 155 | 156 | arange = np.arange(n, dtype=np.float32) + 1. 
# this is the denom
157 |             present = np.in1d(pred[:n], lab, assume_unique=assume_unique)
158 |             prec_sum = np.ones(present.sum()).cumsum()
159 |             denom = arange[present]
160 |             return (prec_sum / denom).sum() / lab.shape[0]
161 | 
162 |         else:
163 |             return _warn_for_empty_labels()
164 | 
165 |     return _mean_ranking_metric(predictions, labels, _inner_map)
166 | 
167 | 
168 | def ndcg_at(predictions, labels, k=10, assume_unique=True):
169 |     """Compute the normalized discounted cumulative gain at K.
170 |     Compute the average NDCG value of all the queries, truncated at ranking
171 |     position k. The discounted cumulative gain at position k is computed as:
172 |         DCG@k = sum_{i=1}^{k} (2^{rel_i} - 1) / log2(i + 1)
173 |     and the NDCG is obtained by dividing it by the ideal DCG of the ground truth set.
174 |     In the current implementation, the relevance value is binary.
175 |     If a query has an empty ground truth set, zero will be used as
176 |     NDCG together with a warning.
177 |     Parameters
178 |     ----------
179 |     predictions : array-like, shape=(n_predictions,)
180 |         The prediction array. The items that were predicted, in descending
181 |         order of relevance.
182 |     labels : array-like, shape=(n_ratings,)
183 |         The labels (positively-rated items).
184 |     k : int, optional (default=10)
185 |         The rank at which to measure the NDCG.
186 |     assume_unique : bool, optional (default=True)
187 |         Whether to assume the items in the labels and predictions are each
188 |         unique. That is, the same item is not predicted multiple times or
189 |         rated multiple times.
190 |     Examples
191 |     --------
192 |     >>> # predictions for 3 users
193 |     >>> preds = [[1, 6, 2, 7, 8, 3, 9, 10, 4, 5],
194 |     ...          [4, 1, 5, 6, 2, 7, 3, 8, 9, 10],
195 |     ...          [1, 2, 3, 4, 5]]
196 |     >>> # labels for the 3 users
197 |     >>> labels = [[1, 2, 3, 4, 5], [1, 2, 3], []]
198 |     >>> ndcg_at(preds, labels, 3)
199 |     0.3333333432674408
200 |     >>> ndcg_at(preds, labels, 10)
201 |     0.48791273434956867
202 |     References
203 |     ----------
204 |     .. [1] K. Jarvelin and J. Kekalainen, "IR evaluation methods for
205 |            retrieving highly relevant documents."
206 |     """
207 |     # validate K
208 |     _require_positive_k(k)
209 | 
210 |     def _inner_ndcg(pred, lab):
211 |         if lab.shape[0]:
212 |             # if we do NOT assume uniqueness, the set is a bit different here
213 |             if not assume_unique:
214 |                 lab = np.unique(lab)
215 | 
216 |             n_lab = lab.shape[0]
217 |             n_pred = pred.shape[0]
218 |             n = min(max(n_pred, n_lab), k)  # min(min(p, l), k)?
219 | 
220 |             # similar to mean_average_precision, we need an arange, but this
221 |             # time +2, since python is zero-indexed and the denom typically
222 |             # needs +1. Also need the log base2...
223 |             arange = np.arange(n, dtype=np.float32)  # length n
224 | 
225 |             # since we are only interested in the arange up to n_pred, truncate
226 |             # if necessary
227 |             arange = arange[:n_pred]
228 |             denom = np.log2(arange + 2.)  # length n
229 |             gains = 1. / denom  # length n
230 | 
231 |             # compute the gains where the prediction is present in the labels
232 |             dcg_mask = np.in1d(pred[:n], lab, assume_unique=assume_unique)
233 |             dcg = gains[dcg_mask].sum()
234 | 
235 |             # the max DCG is sum of gains where the index < the label set size
236 |             max_dcg = gains[arange < n_lab].sum()
237 |             return dcg / max_dcg
238 | 
239 |         else:
240 |             return _warn_for_empty_labels()
241 | 
242 |     return _mean_ranking_metric(predictions, labels, _inner_ndcg)
243 | 
244 | def coverage(predicted, catalog):
245 |     """
246 |     Computes the coverage for a list of recommendations
247 |     Parameters
248 |     ----------
249 |     predicted : a list of lists
250 |         Ordered predictions
251 |         example: [['X', 'Y', 'Z'], ['X', 'Y', 'Z']]
252 |     catalog: list
253 |         A list of all unique items in the training data
254 |         example: ['A', 'B', 'C', 'X', 'Y', 'Z']
255 |     Returns
256 |     ----------
257 |     coverage:
258 |         The coverage of the recommendations as a percent
259 |         rounded to 2 decimal places
260 |     """
261 |     predicted_flattened = [p for sublist in predicted for p in sublist]
262 |     unique_predictions = len(set(predicted_flattened))
263 |     coverage = round(unique_predictions/(len(catalog)* 1.0)*100,2)
264 |     return coverage
265 | 
266 | def _ark(actual, predicted, k=10):
267 |     """
268 |     Computes the average recall at k.
269 |     Parameters
270 |     ----------
271 |     actual : list
272 |         A list of actual items to be predicted
273 |     predicted : list
274 |         An ordered list of predicted items
275 |     k : int, default = 10
276 |         Number of predictions to consider
277 |     Returns:
278 |     -------
279 |     score : float
280 |         The average recall at k.
281 |     """
282 |     if len(predicted)>k:
283 |         predicted = predicted[:k]
284 | 
285 |     score = 0.0
286 |     num_hits = 0.0
287 | 
288 |     for i,p in enumerate(predicted):
289 |         if p in actual and p not in predicted[:i]:
290 |             num_hits += 1.0
291 |             score += num_hits / (i+1.0)
292 | 
293 |     if not actual:
294 |         return 0.0
295 | 
296 |     return score / len(actual)
297 | 
298 | def mark(actual, predicted, k=10):
299 |     """
300 |     Computes the mean average recall at k.
301 |     Parameters
302 |     ----------
303 |     actual : a list of lists
304 |         Actual items to be predicted
305 |         example: [['A', 'B', 'X'], ['A', 'B', 'Y']]
306 |     predicted : a list of lists
307 |         Ordered predictions
308 |         example: [['X', 'Y', 'Z'], ['X', 'Y', 'Z']]
309 |     Returns:
310 |     -------
311 |     mark: float
312 |         The mean average recall at k (mar@k)
313 |     """
314 |     return np.mean([_ark(a,p,k) for a,p in zip(actual, predicted)])
315 | 
316 | def personalization(predicted):
317 |     """
318 |     Personalization measures recommendation similarity across users.
319 |     A high score indicates good personalization (user's lists of recommendations are different).
320 |     A low score indicates poor personalization (user's lists of recommendations are very similar).
321 |     A model is "personalizing" well if the set of recommendations for each user is different.
322 |     Parameters:
323 |     ----------
324 |     predicted : a list of lists
325 |         Ordered predictions
326 |         example: [['X', 'Y', 'Z'], ['X', 'Y', 'Z']]
327 |     Returns:
328 |     -------
329 |         The personalization score for all recommendations.
330 |     """
331 | 
332 |     def make_rec_matrix(predicted, unique_recs):
333 |         rec_matrix = pd.DataFrame(index = range(len(predicted)),columns=unique_recs)
334 |         rec_matrix.fillna(0, inplace=True)
335 |         for i in rec_matrix.index:
336 |             rec_matrix.loc[i, predicted[i]] = 1
337 |         return rec_matrix
338 | 
339 |     #get all unique items recommended
340 |     predicted_flattened = [p for sublist in predicted for p in sublist]
341 |     unique_recs = list(set(predicted_flattened))
342 | 
343 |     #create matrix for recommendations
344 |     rec_matrix = make_rec_matrix(predicted, unique_recs)
345 |     rec_matrix_sparse = sp.csr_matrix(rec_matrix.values)
346 | 
347 |     #calculate similarity for every user's recommendation list
348 |     similarity = cosine_similarity(X=rec_matrix_sparse, dense_output=False)
349 | 
350 |     #get indices for the upper-right triangle w/o diagonal
351 |     upper_right = np.triu_indices(similarity.shape[0], k=1)
352 | 
353 |     #calculate average similarity
354 |     personalization = np.mean(similarity[upper_right])
355 |     return 1-personalization
356 | 
357 | def _single_list_similarity(predicted, feature_df):
358 |     """
359 |     Computes the intra-list similarity for a single list of recommendations.
360 |     Parameters
361 |     ----------
362 |     predicted : a list
363 |         Ordered predictions
364 |         Example: ['X', 'Y', 'Z']
365 |     feature_df: dataframe
366 |         A dataframe with one hot encoded or latent features.
367 |         The dataframe should be indexed by the id used in the recommendations.
368 |     Returns:
369 |     -------
370 |     ils_single_user: float
371 |         The intra-list similarity for a single list of recommendations.
372 |     """
373 |     #get features for all recommended items
374 |     recs_content = feature_df.loc[predicted]
375 |     recs_content = recs_content.dropna()
376 |     recs_content = sp.csr_matrix(recs_content.values)
377 | 
378 |     #calculate similarity scores for all items in list
379 |     similarity = cosine_similarity(X=recs_content, dense_output=False)
380 | 
381 |     #get indices for the upper-right triangle w/o diagonal
382 |     upper_right = np.triu_indices(similarity.shape[0], k=1)
383 | 
384 |     #calculate average similarity score of all recommended items in list
385 |     ils_single_user = np.mean(similarity[upper_right])
386 |     return ils_single_user
387 | 
388 | def intra_list_similarity(predicted, feature_df):
389 |     """
390 |     Computes the average intra-list similarity of all recommendations.
391 |     This metric can be used to measure diversity of the list of recommended items.
392 |     Parameters
393 |     ----------
394 |     predicted : a list of lists
395 |         Ordered predictions
396 |         Example: [['X', 'Y', 'Z'], ['X', 'Y', 'Z']]
397 |     feature_df: dataframe
398 |         A dataframe with one hot encoded or latent features.
399 |         The dataframe should be indexed by the id used in the recommendations.
400 |     Returns:
401 |     -------
402 |         The average intra-list similarity for recommendations.
403 |     """
404 |     feature_df = feature_df.fillna(0)
405 |     Users = range(len(predicted))
406 |     ils = [_single_list_similarity(predicted[u], feature_df) for u in Users]
407 |     return np.mean(ils)
408 | 
409 | def mse(y, yhat):
410 |     """
411 |     Computes the mean square error (MSE)
412 |     Parameters
413 |     ----------
414 |     yhat : Series or array. Reconstructed (predicted) ratings or values.
415 |     y: original true ratings or values.
416 |     Returns:
417 |     -------
418 |         The mean square error (MSE)
419 |     """
420 |     mse = mean_squared_error(y, yhat)
421 |     return mse
422 | 
423 | def rmse(y, yhat):
424 |     """
425 |     Computes the root mean square error (RMSE)
426 |     Parameters
427 |     ----------
428 |     yhat : Series or array. Reconstructed (predicted) ratings or values
429 |     y: original true ratings or values.
430 |     Returns:
431 |     -------
432 |         The root mean square error (RMSE)
433 |     """
434 |     rmse = sqrt(mean_squared_error(y, yhat))
435 |     return rmse
436 | 
437 | 
--------------------------------------------------------------------------------
/evaluation/model_evaluator.py:
--------------------------------------------------------------------------------
  1 | # Evaluation
  2 | #
  3 | # This evaluation protocol is derived from
  4 | # https://www.kaggle.com/gspmoreira/recommender-systems-in-python-101
  5 | 
  6 | import pandas as pd
  7 | import random
  8 | import numpy as np
  9 | from . import metrics
 10 | 
 11 | class ModelEvaluator:
 12 |     METRICS = ['ndcg_at.5','ndcg_at.10', 'coverage','personalization']
 13 | 
 14 |     def __init__(self, articles_df, interactions_full_df,
 15 |                        interactions_train_df, interactions_test_df):
 16 | 
 17 |         #Indexing by user_id to speed up the searches during evaluation
 18 |         self.articles_df = articles_df
 19 |         self.interactions_full_indexed_df = interactions_full_df[interactions_full_df.view > 0].set_index('user_id')
 20 |         self.interactions_train_indexed_df = interactions_train_df[interactions_train_df.view > 0].set_index('user_id')
 21 |         self.interactions_test_indexed_df = interactions_test_df.set_index('user_id')
 22 | 
 23 | 
 24 |     def get_items_interacted(self, person_id, interactions_df):
 25 |         # Get the user's data and merge in the game information.
 26 |         interacted_items = interactions_df.loc[person_id]['content_id']
 27 |         return set(interacted_items if type(interacted_items) == pd.Series else [interacted_items])
 28 | 
 29 |     def get_not_interacted_items_sample(self, person_id, sample_size, seed=42):
 30 |         interacted_items = self.get_items_interacted(person_id, self.interactions_full_indexed_df)
 31 |         all_items = set(self.articles_df['content_id'])
 32 |         non_interacted_items = all_items - interacted_items
 33 | 
 34 |         random.seed(seed)
 35 |         non_interacted_items_sample = random.sample(non_interacted_items, sample_size)
 36 |         return set(non_interacted_items_sample)
 37 | 
 38 |     def _verify_hit_top_n(self, item_id, recommended_items, topn):
 39 |         try:
 40 |             index = next(i for i, c in enumerate(recommended_items) if c == item_id)
 41 |         except StopIteration:
 42 |             index = -1
 43 |         hit = int(index in range(0, topn))
 44 |         return hit, index
 45 | 
 46 |     def recommender_model_for_user(self, model, person_id):
 47 |         #Getting the items in the test set
 48 |         interacted_values_testset = self.interactions_test_indexed_df.loc[person_id]
 49 |         if type(interacted_values_testset['content_id']) == pd.Series:
 50 |             person_interacted_items_testset = np.array(interacted_values_testset['content_id'])
 51 |         else:
 52 |             person_interacted_items_testset = np.array([int(interacted_values_testset['content_id'])])
 53 |         interacted_items_count_testset = len(person_interacted_items_testset)
 54 | 
 55 |         #Getting a ranked recommendation list from a model for a given user
 56 |         person_recs_df = model.recommend_items(person_id,
 57 |                                                items_to_ignore=self.get_items_interacted(person_id,
 58 |                                                                   self.interactions_train_indexed_df),
 59 |                                                topn=20)
 60 |         # Recommended content_ids
 61 |         #recs_content_id = person_recs_df['content_id'].values[:10]
 62 | 
 63 | 
 64 |         person_metrics = {'recommender': person_recs_df['content_id'].values,
 65 |                           'labels': person_interacted_items_testset}
 66 | 
 67 |         return person_metrics
 68 | 
 69 |     def evaluate_model(self, model):
 70 |         print('Running evaluation for users')
 71 |         people_recs = []
 72 |         recs_content_id = []
 73 |         users = list(self.interactions_test_indexed_df.sample(frac=1).index.unique().values)[:500]
 74 |         len_users = len(users)
 75 | 
 76 |         for idx, person_id in enumerate(users):
 77 |             if idx % 500 == 0 and idx > 0:
 78 |                 print('%.2f%% of users processed' % (100 * idx / len_users))
 79 | 
 80 |             person_recs = self.recommender_model_for_user(model, person_id)
 81 |             person_recs['person_id'] = person_id
 82 |             people_recs.append(person_recs)
 83 | 
 84 |         print('%d users processed' % idx)
 85 |         print("evaluations...")
 86 | 
 87 |         detailed_results_df = pd.DataFrame(people_recs)
 88 |         predictions = detailed_results_df['recommender'].values
 89 |         labels = detailed_results_df['labels'].values
 90 | 
 91 |         ndcg_at_5 = metrics.ndcg_at(predictions, labels, k=5)
 92 |         ndcg_at_10 = metrics.ndcg_at(predictions, labels, k=10)
 93 |         mean_average_precision = metrics.mean_average_precision(predictions, labels)
 94 |         coverage = metrics.coverage(predictions, self.articles_df['content_id'].values)/100
 95 |         personalization = metrics.personalization(predictions)
 96 | 
 97 |         global_metrics = {'ndcg_at.5': ndcg_at_5,
 98 |                           'ndcg_at.10': ndcg_at_10,
 99 |                           'MAP': mean_average_precision,
100 |                           'coverage': coverage,
101 |                           'personalization': personalization }
102 | 
103 | 
104 | 
105 |         return global_metrics, detailed_results_df
106 | 
107 | 
108 | class CFRecommender:
109 |     '''
110 |     Baseline collaborative-filtering recommender over a precomputed user x item score matrix.
111 |     '''
112 | 
113 |     MODEL_NAME = 'Collaborative Filtering'
114 | 
115 |     def __init__(self, cf_predictions_df, items_df=None):
116 |         self.cf_predictions_df = cf_predictions_df
117 |         self.items_df = items_df
118 | 
119 |     def get_model_name(self):
120 |         return self.MODEL_NAME
121 | 
122 |     def recommend_items(self, user_id, items_to_ignore=[], topn=10, verbose=False):
123 |         # Get and sort the user's predictions
124 |         sorted_user_predictions = self.cf_predictions_df[user_id].sort_values(ascending=False) \
125 |                                       .reset_index().rename(columns={user_id: 'score'})
126 | 
127 |         # Recommend the highest-scoring games that the user hasn't seen yet.
128 |         recommendations_df = sorted_user_predictions[~sorted_user_predictions['content_id'].isin(items_to_ignore)] \
129 |                                  .sort_values('score', ascending = False) \
130 |                                  .head(topn)
131 | 
132 |         if verbose:
133 |             if self.items_df is None:
134 |                 raise Exception('"items_df" is required in verbose mode')
135 | 
136 |             recommendations_df = recommendations_df.merge(self.items_df, how = 'left',
137 |                                                           left_on = 'content_id',
138 |                                                           right_on = 'content_id')[['score', 'content_id', 'game']]
139 | 
140 | 
141 |         return recommendations_df
--------------------------------------------------------------------------------
/model/AutoEncContentModel.py:
--------------------------------------------------------------------------------
  1 | from tensorflow.keras.optimizers import Adam, RMSprop
  2 | from tensorflow.keras.layers import Input, Dense, Embedding, Flatten, Dropout, Activation
  3 | from tensorflow.keras.models import Model
  4 | from tensorflow.keras.regularizers import l2
  5 | from tensorflow.keras import backend as K
  6 | from tensorflow.keras import regularizers
  7 | from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
  8 | from tensorflow.keras import initializers
  9 | from tensorflow.keras.layers import add
 10 | from model.BaseModel import BaseModel
 11 | import numpy as np
 12 | from tensorflow.keras.preprocessing.text import one_hot
 13 | from tensorflow.keras.preprocessing.sequence import pad_sequences
 14 | 
 15 | class AutoEncContentModel(BaseModel):
 16 |     '''
 17 |     Deep autoencoder adapted to include content-based information
 18 | 
 19 |     Reference:
 20 |       KUCHAIEV, Oleksii; GINSBURG, Boris. Training deep autoencoders for collaborative filtering.
 21 |       arXiv preprint arXiv:1708.01715, 2017.
 22 |       https://github.com/NVIDIA/DeepRecommender
 23 |       https://arxiv.org/pdf/1708.01715.pdf
 24 |     '''
 25 | 
 26 |     def __init__(self, layers = '[]', epochs = None, batch = None,
 27 |                        activation = None, dropout = None, lr = None, reg = None):
 28 |         self.layers = eval(layers)  # parse the layer-size list, e.g. '[512,256,512]'
 29 |         self.epochs = epochs
 30 |         self.batch = batch
 31 |         self.activation = activation
 32 |         self.dropout = dropout
 33 |         self.lr = lr
 34 |         self.reg = reg
 35 |         self.model = None
 36 | 
 37 | 
 38 |     def data_preparation(self, interactions, user_item_matrix):
 39 |         '''
 40 |         Create the input for the model
 41 |         '''
 42 | 
 43 |         # Params
 44 |         # integer encode the documents (hashing vocabulary size)
 45 |         vocab_size = 100
 46 |         # pad documents to a max length of 50 tokens
 47 |         max_length = 50
 48 | 
 49 | 
 50 |         def split_str(val):
 51 |             '''
 52 |             Split each game name into tokens and join all tokens into one string
 53 |             '''
 54 |             tokens = []
 55 |             for v in val:
 56 |                 tokens.extend(v.split(' '))
 57 |             return ' '.join(tokens)
 58 | 
 59 |         # Order users in matrix interactions
 60 |         users_ids = list(user_item_matrix.index)
 61 | 
 62 |         # Dataset with User X Content information
 63 |         user_games = interactions.groupby('user_id')['game'].apply(list).loc[users_ids].reset_index()
 64 |         user_games['tokens'] = user_games['game'].apply(split_str)
 65 | 
 66 |         # Prepare input layer
 67 |         encoded_tokens = [one_hot(d, vocab_size) for d in user_games.tokens]
 68 |         padded_tokens = pad_sequences(encoded_tokens, maxlen=max_length, padding='post')
 69 | 
 70 |         # Input
 71 |         X = [user_item_matrix.values, padded_tokens]
 72 |         y = user_item_matrix.values
 73 | 
 74 |         return X, y
 75 | 
 76 |     def fit(self, X, y):
 77 |         '''
 78 |         Train Model
 79 |         '''
 80 | 
 81 |         # Build model
 82 |         model = self.build_model(X)
 83 | 
 84 |         model.compile(optimizer = Adam(lr=self.lr),
 85 |                       loss='mse') #'mean_absolute_error'
 86 | 
 87 |         # train
 88 |         hist = model.fit(x=X, y=y,
 89 |                          epochs=self.epochs,
 90 |                          batch_size=self.batch,
 91 |                          shuffle=True,
 92 |                          validation_split=0.1,
 93 |                          callbacks=self.callbacks_list())
 94 | 
 95 |         # Load the best weights saved by the ModelCheckpoint callback
 96 |         model.load_weights(self.WEIGHT_MODEL)
 97 |         self.model = model
 98 | 
 99 |         return model, hist
100 | 
101 |     def predict(self, X):
102 | 
103 |         # Predict
104 |         pred = self.model.predict(X)
105 | 
106 |         # remove watched items from predictions
107 |         pred = pred * (X[0] == 0)
108 | 
109 |         return pred
110 | 
111 |     def build_model(self, X):
112 |         '''
113 |         Autoencoder for Collaborative Filtering Model
114 |         '''
115 | 
116 |         # Params
117 |         users_items_matrix, content_info = X
118 | 
119 |         # Input
120 |         input_layer = x = Input(shape=(users_items_matrix.shape[1],), name='UserScore')
121 |         input_content = Input(shape=(content_info.shape[1],), name='ItemContent')
122 | 
123 |         # Encoder
124 |         k = int(len(self.layers)/2)
125 |         i = 0
126 |         for l in self.layers[:k]:
127 |             x = Dense(l, activation=self.activation,
128 |                       name='EncLayer{}'.format(i))(x)
129 |             i = i+1
130 | 
131 |         # Latent Space
132 |         x = Dense(self.layers[k], activation=self.activation,
133 |                   name='UserLatentSpace')(x)
134 | 
135 |         # Content Information
136 |         x_content = Embedding(100, self.layers[k],
137 |                               input_length=content_info.shape[1])(input_content)
138 |         x_content = Flatten()(x_content)
139 |         x_content = Dense(self.layers[k], activation=self.activation,
140 |                           name='ItemLatentSpace')(x_content)
141 |         # Merge user and item latent spaces (element-wise add)
142 |         x = add([x, x_content], name='LatentSpace')
143 | 
144 |         # Dropout
145 |         x = Dropout(self.dropout)(x)
146 | 
147 |         # Decoder
148 |         for l in self.layers[k+1:]:
149 |             i = i-1
150 |             x = Dense(l, activation=self.activation,
151 |                       name='DecLayer{}'.format(i))(x)
152 | 
153 |         # Output
154 |         output_layer = Dense(users_items_matrix.shape[1], activation='linear', name='UserScorePred')(x)
155 | 
156 | 
157 |         # this model maps an input to its reconstruction
158 |         model = Model([input_layer, input_content], output_layer)
159 | 
160 |         return model
--------------------------------------------------------------------------------
/model/AutoEncModel.py:
--------------------------------------------------------------------------------
  1 | #
  2 | #
  3 | #
  4 | #
  5 | #
  6 | 
  7 | from tensorflow.keras.optimizers import Adam, RMSprop
  8 | from tensorflow.keras.layers import Input, Dense, Embedding, Flatten, Dropout, Activation
  9 | from model.BaseModel import BaseModel
 10 | from tensorflow.keras.models import Model
 11 | 
 12 | import numpy as np
 13 | 
 14 | class AutoEncModel(BaseModel):
 15 |     '''
 16 |     Deep AutoEncoder for Collaborative Filtering
 17 |     Reference:
 18 |       KUCHAIEV, Oleksii; GINSBURG, Boris. Training deep autoencoders for collaborative filtering.
 19 |       arXiv preprint arXiv:1708.01715, 2017.
 20 |       https://github.com/NVIDIA/DeepRecommender
 21 |       https://arxiv.org/pdf/1708.01715.pdf
 22 |     '''
 23 | 
 24 |     def __init__(self, layers = '[]', epochs = None, batch = None,
 25 |                        activation = None, dropout = None, lr = None, reg = None):
 26 |         self.layers = eval(layers)  # parse the layer-size list, e.g. '[512,256,512]'
 27 |         self.epochs = epochs
 28 |         self.batch = batch
 29 |         self.activation = activation
 30 |         self.dropout = dropout
 31 |         self.lr = lr
 32 |         self.reg = reg
 33 |         self.model = None
 34 | 
 35 |     def data_preparation(self, interactions, user_item_matrix):
 36 |         '''
 37 |         Create the input for the model
 38 |         '''
 39 | 
 40 |         X, y = user_item_matrix.values, user_item_matrix.values
 41 | 
 42 |         return X, y
 43 | 
 44 | 
 45 |     def fit(self, X, y):
 46 |         '''
 47 |         Train Model
 48 |         '''
 49 | 
 50 |         # Build model
 51 |         model = self.build_model(X)
 52 | 
 53 |         model.compile(optimizer = Adam(lr=self.lr),
 54 |                       loss='mse') #'mean_absolute_error'
 55 | 
 56 |         # train
 57 |         hist = model.fit(x=X, y=y,
 58 |                          epochs=self.epochs,
 59 |                          batch_size=self.batch,
 60 |                          shuffle=True,
 61 |                          validation_split=0.1,
 62 |                          callbacks=self.callbacks_list())
 63 | 
 64 |         # Load the best weights saved by the ModelCheckpoint callback
 65 |         model.load_weights(self.WEIGHT_MODEL)
 66 |         self.model = model
 67 | 
 68 |         return model, hist
 69 | 
 70 |     def predict(self, X):
 71 | 
 72 |         # Predict
 73 |         pred = self.model.predict(X)
 74 | 
 75 |         # remove watched items from predictions (X here is the matrix itself)
 76 |         pred = pred * (X == 0)
 77 | 
 78 |         return pred
 79 | 
 80 |     def build_model(self, X):
 81 |         '''
 82 |         Autoencoder for Collaborative Filtering Model
 83 |         '''
 84 | 
 85 |         # Input
 86 |         input_layer = x = Input(shape=(X.shape[1],), name='UserScore')
 87 | 
 88 |         # Encoder
 89 |         # -----------------------------
 90 |         k = int(len(self.layers)/2)
 91 |         i = 0
 92 |         for l in self.layers[:k]:
 93 |             x = Dense(l, activation=self.activation,
 94 |                       name='EncLayer{}'.format(i))(x)
 95 |             i = i+1
 96 | 
 97 |         # Latent Space
 98 |         # -----------------------------
 99 |         x = Dense(self.layers[k], activation=self.activation,
100 |                   name='LatentSpace')(x)
101 |         # Dropout
102 |         x = Dropout(self.dropout, name='Dropout')(x)
103 | 
104 |         # Decoder
105 |         # -----------------------------
106 |         for l in self.layers[k+1:]:
107 |             i = i-1
108 |             x = Dense(l, activation=self.activation,
109 |                       name='DecLayer{}'.format(i))(x)
110 | 
111 |         # Output
112 |         output_layer = Dense(X.shape[1], activation='linear', name='UserScorePred')(x)
113 | 
114 | 
115 |         # this model maps an input to its reconstruction
116 |         model = Model(input_layer, output_layer)
117 | 
118 |         return model
--------------------------------------------------------------------------------
/model/BaseModel.py:
--------------------------------------------------------------------------------
 1 | from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
 2 | 
 3 | 
 4 | class BaseModel(object):
 5 |     WEIGHT_MODEL = "./artefacts/weights-best-model.hdf5"
 6 | 
 7 |     def callbacks_list(self, monitor='val_loss', path_model = WEIGHT_MODEL, patience=15):
 8 |         '''
 9 |         Callbacks for model training
10 |         '''
11 |         checkpoint = ModelCheckpoint(path_model, monitor=monitor, verbose=1, save_best_only=True, mode='min')
12 |         early_stop = EarlyStopping(monitor=monitor,min_delta=0, patience=patience, verbose=0, mode='auto')
13 | 
14 |         callbacks_list = [checkpoint, early_stop]
15 |         return callbacks_list
--------------------------------------------------------------------------------
/model/CDAEModel.py:
--------------------------------------------------------------------------------
  1 | #
  2 | #
  3 | #
  4 | #
  5 | #
  6 | import tensorflow as tf
  7 | 
  8 | from tensorflow.keras.optimizers import Adam, RMSprop
  9 | from tensorflow.keras.layers import Input, Dense, Embedding, Flatten, Dropout, Activation
 10 | from tensorflow.keras.regularizers import l2
 11 | from tensorflow.keras import backend as K
 12 | from tensorflow.keras.layers import add
 13 | from tensorflow.keras.models import Model
 14 | from model.BaseModel import BaseModel
 15 | import numpy as np
 16 | 
 17 | class CDAEModel(BaseModel):
 18 |     '''
 19 |     Collaborative Denoising Auto-Encoder (CDAE)
 20 |     Reference:
 21 |       Yao Wu, Christopher DuBois, Alice X. Zheng, Martin Ester.
 22 |       Collaborative Denoising Auto-Encoders for Top-N Recommender Systems.
 23 |       The 9th ACM International Conference on Web Search and Data Mining (WSDM'16), p153--162, 2016.
 24 |     '''
 25 | 
 26 |     def __init__(self, factors = None, epochs = None, batch = None,
 27 |                        activation = None, dropout = None, lr = None, reg = None):
 28 |         self.factors = factors
 29 |         self.epochs = epochs
 30 |         self.batch = batch
 31 |         self.activation = activation
 32 |         self.dropout = dropout
 33 |         self.lr = lr
 34 |         self.reg = reg
 35 |         self.model = None
 36 | 
 37 | 
 38 |     def data_preparation(self, interactions, user_item_matrix):
 39 |         '''
 40 |         Create the input for the model
 41 |         '''
 42 |         users_ids = list(user_item_matrix.index)
 43 |         x_user_ids = np.array(users_ids).reshape(len(users_ids), 1)
 44 | 
 45 |         X = [user_item_matrix.values, x_user_ids]
 46 |         y = user_item_matrix.values
 47 | 
 48 |         return X, y
 49 | 
 50 |     def fit(self, X, y):
 51 |         # Build model
 52 |         model = self.build_model(X)
 53 | 
 54 |         model.compile(optimizer = Adam(lr=self.lr),
 55 |                       loss='mse') #'mean_absolute_error'
 56 | 
 57 |         # train
 58 |         hist = model.fit(x=X, y=y,
 59 |                          epochs=self.epochs,
 60 |                          batch_size=self.batch,
 61 |                          shuffle=True,
 62 |                          validation_split=0.1,
 63 |                          callbacks=self.callbacks_list())
 64 | 
 65 |         # Load the best weights saved by the ModelCheckpoint callback
 66 |         model.load_weights(self.WEIGHT_MODEL)
 67 |         self.model = model
 68 | 
 69 |         return model, hist
 70 | 
 71 | 
 72 |     def predict(self, X):
 73 | 
 74 |         # Predict
 75 |         pred = self.model.predict(X)
 76 | 
 77 |         # remove watched items from predictions
 78 |         pred = pred * (X[0] == 0)
 79 | 
 80 |         return pred
 81 | 
 82 |     def build_model(self, X):
 83 | 
 84 |         # Params
 85 |         users_items_matrix, x_user_ids = X
 86 | 
 87 |         # Model
 88 |         x_item = Input((users_items_matrix.shape[1],), name='UserScore')
 89 |         h_item = Dropout(self.dropout)(x_item)
 90 |         h_item = Dense(self.factors,
 91 |                        kernel_regularizer=l2(self.reg),
 92 |                        bias_regularizer=l2(self.reg),
 93 |                        activation=self.activation)(h_item)
 94 | 
 95 |         # dtype should be int to connect to Embedding layer
 96 |         x_user = Input((1,), dtype='int32', name='UserContent')
 97 |         h_user = Embedding(len(np.unique(x_user_ids))+1,self.factors,
 98 |                            input_length=1,
 99 |                            embeddings_regularizer=l2(self.reg))(x_user)
100 |         h_user = Flatten()(h_user)
101 | 
102 |         h = add([h_item, h_user], name='LatentSpace')
103 |         y = Dense(users_items_matrix.shape[1], activation='linear', name='UserScorePred')(h)
104 | 
105 |         return Model(inputs=[x_item, x_user], outputs=y)
106 | 
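# A minimal usage sketch of these model classes (interface as defined above;
# the DataFrame/matrix variables are assumed to come from the prepared CSVs):
#
#   model = CDAEModel(factors=500, epochs=50, batch=64,
#                     activation='selu', dropout=0.8, lr=0.0001, reg=0.0001)
#   X, y = model.data_preparation(interactions_train_df, user_item_matrix)
#   keras_model, history = model.fit(X, y)
#   scores = model.predict(X)   # user x item score matrix, seen items zeroed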
import CDAEModel 2 | 3 | # __all__ = [ 4 | # 'CDAEModel', 5 | # ] -------------------------------------------------------------------------------- /notebooks/DeepAutoEncoder - Simple Train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd\n", 10 | "import matplotlib.pyplot as plt\n", 11 | "import math\n", 12 | "import numpy as np\n", 13 | "import seaborn as sns\n", 14 | "%matplotlib inline \n", 15 | "\n", 16 | "import warnings\n", 17 | "warnings.filterwarnings('ignore')" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "## Data Preparation" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 2, 30 | "metadata": {}, 31 | "outputs": [ 32 | { 33 | "data": { 34 | "text/html": [ 35 | "
\n", 36 | "\n", 49 | "\n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | "
151603712The Elder Scrolls V Skyrimpurchase1.00
0151603712The Elder Scrolls V Skyrimplay273.00
1151603712Fallout 4purchase1.00
2151603712Fallout 4play87.00
3151603712Sporepurchase1.00
4151603712Sporeplay14.90
\n", 103 | "
" 104 | ], 105 | "text/plain": [ 106 | " 151603712 The Elder Scrolls V Skyrim purchase 1.0 0\n", 107 | "0 151603712 The Elder Scrolls V Skyrim play 273.0 0\n", 108 | "1 151603712 Fallout 4 purchase 1.0 0\n", 109 | "2 151603712 Fallout 4 play 87.0 0\n", 110 | "3 151603712 Spore purchase 1.0 0\n", 111 | "4 151603712 Spore play 14.9 0" 112 | ] 113 | }, 114 | "execution_count": 2, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "# Raw Data\n", 121 | "df = pd.read_csv('../data/raw/rating.csv')\n", 122 | "df.head(5)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 3, 128 | "metadata": { 129 | "scrolled": true 130 | }, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/html": [ 135 | "
\n", 136 | "\n", 149 | "\n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | "
user_idcontent_idgameview
01341680Far Cry 3 Blood Dragon1
122191938Gone Home1
233153711Serious Sam 3 BFE1
334404784Velvet Sundown1
48704104Super Hexagon1
\n", 197 | "
" 198 | ], 199 | "text/plain": [ 200 | " user_id content_id game view\n", 201 | "0 134 1680 Far Cry 3 Blood Dragon 1\n", 202 | "1 2219 1938 Gone Home 1\n", 203 | "2 3315 3711 Serious Sam 3 BFE 1\n", 204 | "3 3440 4784 Velvet Sundown 1\n", 205 | "4 870 4104 Super Hexagon 1" 206 | ] 207 | }, 208 | "execution_count": 3, 209 | "metadata": {}, 210 | "output_type": "execute_result" 211 | } 212 | ], 213 | "source": [ 214 | "# Train Data (subset of all interactions)\n", 215 | "df = pd.read_csv('../data/interactions_train_df.csv')\n", 216 | "df = df[['user_id', 'content_id', 'game', 'view']]\n", 217 | "df.head(5)" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 4, 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "data": { 227 | "text/html": [ 228 | "
\n", 229 | "\n", 242 | "\n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | "
content_idgametotal_userstotal_hours
00007 Legends11.7
110RBITALIS34.2
221... 2... 3... KICK IT! (Drop That Beat Like a...727.0
3310 Second Ninja611.9
\n", 283 | "
" 284 | ], 285 | "text/plain": [ 286 | " content_id game total_users \\\n", 287 | "0 0 007 Legends 1 \n", 288 | "1 1 0RBITALIS 3 \n", 289 | "2 2 1... 2... 3... KICK IT! (Drop That Beat Like a... 7 \n", 290 | "3 3 10 Second Ninja 6 \n", 291 | "\n", 292 | " total_hours \n", 293 | "0 1.7 \n", 294 | "1 4.2 \n", 295 | "2 27.0 \n", 296 | "3 11.9 " 297 | ] 298 | }, 299 | "execution_count": 4, 300 | "metadata": {}, 301 | "output_type": "execute_result" 302 | } 303 | ], 304 | "source": [ 305 | "# Content Data of Games\n", 306 | "df_game = pd.read_csv('../data/articles_df.csv')\n", 307 | "df_game.head(4)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "### Create a Matrix of Interactions" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": 5, 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "data": { 324 | "text/html": [ 325 | "
\n", 326 | "\n", 339 | "\n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 | " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 
611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | "
content_id01235678910...5103510451055106510751085109511051115112
user_id
00.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
10.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
20.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
30.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
40.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
50.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
60.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
70.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
80.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
90.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
\n", 633 | "

10 rows × 4862 columns

\n", 634 | "
" 635 | ], 636 | "text/plain": [ 637 | "content_id 0 1 2 3 5 6 7 8 9 10 ... \\\n", 638 | "user_id ... \n", 639 | "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... \n", 640 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... \n", 641 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... \n", 642 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... \n", 643 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... \n", 644 | "5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... \n", 645 | "6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... \n", 646 | "7 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... \n", 647 | "8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... \n", 648 | "9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 ... \n", 649 | "\n", 650 | "content_id 5103 5104 5105 5106 5107 5108 5109 5110 5111 5112 \n", 651 | "user_id \n", 652 | "0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 653 | "1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 654 | "2 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 655 | "3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 656 | "4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 657 | "5 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 658 | "6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 659 | "7 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 660 | "8 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 661 | "9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", 662 | "\n", 663 | "[10 rows x 4862 columns]" 664 | ] 665 | }, 666 | "execution_count": 5, 667 | "metadata": {}, 668 | "output_type": "execute_result" 669 | } 670 | ], 671 | "source": [ 672 | "# Creating a sparse pivot table with users in rows and items in columns\n", 673 | "users_items_matrix_df = df.pivot(index = 'user_id', \n", 674 | " columns = 'content_id', \n", 675 | " values = 'view').fillna(0)\n", 676 | "users_items_matrix_df.head(10)" 677 | ] 678 | }, 679 | { 680 | "cell_type": "code", 681 | "execution_count": 6, 682 | "metadata": {}, 683 | "outputs": [ 684 | { 685 | "data": { 686 | "text/plain": [ 687 | "(3757, 4862)" 688 | ] 689 | }, 690 | "execution_count": 6, 691 | "metadata": {}, 692 | "output_type": "execute_result" 693 | } 694 | ], 695 | "source": [ 696 | "users_items_matrix_df.shape" 697 | ] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": 7, 702 | "metadata": {}, 703 | "outputs": [ 704 | { 705 | "data": { 706 | "text/plain": [ 707 | "0.5042609616033342" 708 | ] 709 | }, 710 | "execution_count": 7, 711 | "metadata": {}, 712 | "output_type": "execute_result" 713 | } 714 | ], 715 | "source": [ 716 | "users_items_matrix_df.values.mean()*100" 717 | ] 718 | }, 719 | { 720 | "cell_type": "markdown", 721 | "metadata": {}, 722 | "source": [ 723 | "## Model" 724 | ] 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": null, 729 | "metadata": {}, 730 | "outputs": [], 731 | "source": [] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "execution_count": 8, 736 | "metadata": {}, 737 | "outputs": [ 738 | { 739 | "name": "stderr", 740 | "output_type": "stream", 741 | "text": [ 742 | "Using TensorFlow backend.\n" 743 | ] 744 | } 745 | ], 746 | "source": [ 747 | "from keras.optimizers import Adam\n", 748 | "from keras.layers import Input, Dense, Dropout\n", 749 | "from keras.models import Model\n", 750 | "\n", 751 | "def autoEncoder(X):\n", 752 | " '''\n", 753 | " Autoencoder for Collaborative Filter Model\n", 754 | " '''\n", 755 | "\n", 756 | " # Input\n", 757 | " input_layer = Input(shape=(X.shape[1],), name='UserScore')\n", 758 | " \n", 759 | " # Encoder\n", 760 | " # -----------------------------\n", 761 | " enc = Dense(512, activation='selu', 
name='EncLayer1')(input_layer)\n", 762 | "\n", 763 | " # Latent Space\n", 764 | " # -----------------------------\n", 765 | " lat_space = Dense(256, activation='selu', name='LatentSpace')(enc)\n", 766 | " lat_space = Dropout(0.8, name='Dropout')(lat_space) # Dropout\n", 767 | "\n", 768 | " # Decoder\n", 769 | " # -----------------------------\n", 770 | " dec = Dense(512, activation='selu', name='DecLayer1')(lat_space)\n", 771 | "\n", 772 | " # Output\n", 773 | " output_layer = Dense(X.shape[1], activation='linear', name='UserScorePred')(dec)\n", 774 | "\n", 775 | " # this model maps an input to its reconstruction\n", 776 | " model = Model(input_layer, output_layer) \n", 777 | " \n", 778 | " return model" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": 9, 784 | "metadata": {}, 785 | "outputs": [], 786 | "source": [ 787 | "# input\n", 788 | "X = users_items_matrix_df.values\n", 789 | "y = users_items_matrix_df.values" 790 | ] 791 | }, 792 | { 793 | "cell_type": "code", 794 | "execution_count": 10, 795 | "metadata": {}, 796 | "outputs": [ 797 | { 798 | "name": "stdout", 799 | "output_type": "stream", 800 | "text": [ 801 | "_________________________________________________________________\n", 802 | "Layer (type) Output Shape Param # \n", 803 | "=================================================================\n", 804 | "UserScore (InputLayer) (None, 4862) 0 \n", 805 | "_________________________________________________________________\n", 806 | "EncLayer1 (Dense) (None, 512) 2489856 \n", 807 | "_________________________________________________________________\n", 808 | "LatentSpace (Dense) (None, 256) 131328 \n", 809 | "_________________________________________________________________\n", 810 | "Dropout (Dropout) (None, 256) 0 \n", 811 | "_________________________________________________________________\n", 812 | "DecLayer1 (Dense) (None, 512) 131584 \n", 813 | "_________________________________________________________________\n", 814 | "UserScorePred (Dense) (None, 4862) 2494206 \n", 815 | "=================================================================\n", 816 | "Total params: 5,246,974\n", 817 | "Trainable params: 5,246,974\n", 818 | "Non-trainable params: 0\n", 819 | "_________________________________________________________________\n" 820 | ] 821 | } 822 | ], 823 | "source": [ 824 | "# Build model\n", 825 | "model = autoEncoder(X)\n", 826 | "\n", 827 | "model.compile(optimizer = Adam(lr=0.0001), loss='mse')\n", 828 | " \n", 829 | "model.summary()" 830 | ] 831 | }, 832 | { 833 | "cell_type": "markdown", 834 | "metadata": {}, 835 | "source": [ 836 | "### Train Model" 837 | ] 838 | }, 839 | { 840 | "cell_type": "code", 841 | "execution_count": 11, 842 | "metadata": {}, 843 | "outputs": [ 844 | { 845 | "name": "stdout", 846 | "output_type": "stream", 847 | "text": [ 848 | "Train on 3381 samples, validate on 376 samples\n", 849 | "Epoch 1/50\n", 850 | "3381/3381 [==============================] - 6s 2ms/step - loss: 0.0282 - val_loss: 0.0038\n", 851 | "Epoch 2/50\n", 852 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0183 - val_loss: 0.0031\n", 853 | "Epoch 3/50\n", 854 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0137 - val_loss: 0.0027\n", 855 | "Epoch 4/50\n", 856 | "3381/3381 [==============================] - 3s 1ms/step - loss: 0.0111 - val_loss: 0.0025\n", 857 | "Epoch 5/50\n", 858 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0095 - val_loss: 0.0023\n", 859 | "Epoch 6/50\n", 860 | 
"3381/3381 [==============================] - 3s 1ms/step - loss: 0.0083 - val_loss: 0.0022\n", 861 | "Epoch 7/50\n", 862 | "3381/3381 [==============================] - 3s 1ms/step - loss: 0.0075 - val_loss: 0.0021\n", 863 | "Epoch 8/50\n", 864 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0069 - val_loss: 0.0021\n", 865 | "Epoch 9/50\n", 866 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0065 - val_loss: 0.0020\n", 867 | "Epoch 10/50\n", 868 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0061 - val_loss: 0.0020\n", 869 | "Epoch 11/50\n", 870 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0059 - val_loss: 0.0019\n", 871 | "Epoch 12/50\n", 872 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0056 - val_loss: 0.0019\n", 873 | "Epoch 13/50\n", 874 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0055 - val_loss: 0.0019\n", 875 | "Epoch 14/50\n", 876 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0053 - val_loss: 0.0019\n", 877 | "Epoch 15/50\n", 878 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0052 - val_loss: 0.0018\n", 879 | "Epoch 16/50\n", 880 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0051 - val_loss: 0.0018\n", 881 | "Epoch 17/50\n", 882 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0050 - val_loss: 0.0018\n", 883 | "Epoch 18/50\n", 884 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0049 - val_loss: 0.0018\n", 885 | "Epoch 19/50\n", 886 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0049 - val_loss: 0.0018\n", 887 | "Epoch 20/50\n", 888 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0048 - val_loss: 0.0018\n", 889 | "Epoch 21/50\n", 890 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0047 - val_loss: 0.0017\n", 891 | "Epoch 22/50\n", 892 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0047 - val_loss: 0.0017\n", 893 | "Epoch 23/50\n", 894 | "3381/3381 [==============================] - 3s 1ms/step - loss: 0.0047 - val_loss: 0.0017\n", 895 | "Epoch 24/50\n", 896 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0046 - val_loss: 0.0017\n", 897 | "Epoch 25/50\n", 898 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0046 - val_loss: 0.0017\n", 899 | "Epoch 26/50\n", 900 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0045 - val_loss: 0.0017\n", 901 | "Epoch 27/50\n", 902 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0045 - val_loss: 0.0017\n", 903 | "Epoch 28/50\n", 904 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0045 - val_loss: 0.0017\n", 905 | "Epoch 29/50\n", 906 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0044 - val_loss: 0.0017\n", 907 | "Epoch 30/50\n", 908 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0044 - val_loss: 0.0017\n", 909 | "Epoch 31/50\n", 910 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0044 - val_loss: 0.0017\n", 911 | "Epoch 32/50\n", 912 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0044 - val_loss: 0.0017\n", 913 | "Epoch 33/50\n", 914 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0044 - val_loss: 0.0016\n", 915 | "Epoch 34/50\n", 916 | "3381/3381 [==============================] - 3s 1ms/step 
- loss: 0.0044 - val_loss: 0.0016\n", 917 | "Epoch 35/50\n", 918 | "3381/3381 [==============================] - 3s 960us/step - loss: 0.0043 - val_loss: 0.0016\n", 919 | "Epoch 36/50\n", 920 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0043 - val_loss: 0.0016\n", 921 | "Epoch 37/50\n", 922 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0043 - val_loss: 0.0016\n", 923 | "Epoch 38/50\n", 924 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0043 - val_loss: 0.0016\n", 925 | "Epoch 39/50\n", 926 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0043 - val_loss: 0.0016\n", 927 | "Epoch 40/50\n", 928 | "3381/3381 [==============================] - 3s 1ms/step - loss: 0.0043 - val_loss: 0.0016\n", 929 | "Epoch 41/50\n", 930 | "3381/3381 [==============================] - 3s 984us/step - loss: 0.0043 - val_loss: 0.0016\n", 931 | "Epoch 42/50\n", 932 | "3381/3381 [==============================] - 3s 963us/step - loss: 0.0043 - val_loss: 0.0016\n", 933 | "Epoch 43/50\n", 934 | "3381/3381 [==============================] - 3s 993us/step - loss: 0.0042 - val_loss: 0.0016\n", 935 | "Epoch 44/50\n", 936 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0042 - val_loss: 0.0016\n", 937 | "Epoch 45/50\n", 938 | "3381/3381 [==============================] - 5s 1ms/step - loss: 0.0042 - val_loss: 0.0016\n", 939 | "Epoch 46/50\n", 940 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0042 - val_loss: 0.0016\n", 941 | "Epoch 47/50\n", 942 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0042 - val_loss: 0.0016\n", 943 | "Epoch 48/50\n", 944 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0042 - val_loss: 0.0016\n", 945 | "Epoch 49/50\n", 946 | "3381/3381 [==============================] - 3s 1ms/step - loss: 0.0042 - val_loss: 0.0016\n", 947 | "Epoch 50/50\n", 948 | "3381/3381 [==============================] - 4s 1ms/step - loss: 0.0042 - val_loss: 0.0016\n" 949 | ] 950 | } 951 | ], 952 | "source": [ 953 | "hist = model.fit(x=X, y=y,\n", 954 | " epochs=50,\n", 955 | " batch_size=64,\n", 956 | " shuffle=True,\n", 957 | " validation_split=0.1)" 958 | ] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "execution_count": 21, 963 | "metadata": {}, 964 | "outputs": [ 965 | { 966 | "data": { 967 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAZIAAAEWCAYAAABMoxE0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzt3Xt8XOV95/HPb26SRhfrYsn4BjZYOJhLILiQhFwgKSwNJM4WEqCE0JSGpK/wSrpN04ZuSbdsslt2u02bhm25FkK5FkLwNhAC4VZCMNjgEMAYhLGxbOOrbMuydRnNb/84R/JYHkkjj45H0nzfr8xrzjnznKPnicV89ZznnOeYuyMiInKoYqWugIiITG4KEhERKYqCREREiqIgERGRoihIRESkKAoSEREpioJEJEJmdpuZfbfAsmvN7LeLPY7I4aYgERGRoihIRESkKAoSKXvhKaVvmdkrZtZlZreY2Qwze8TMOs3scTNryCn/GTN7zcx2mtlTZnZczmenmNlL4X73ApVDftb5ZrYy3Pc5MzvpEOv8ZTNrM7MdZrbUzGaF283Mvm9mW8xsV9imE8LPPmVmr4d122Bmf3pI/4eJDKEgEQlcAJwNHAt8GngE+AtgOsF/J18HMLNjgbuBPwaagYeB/2dmKTNLAT8B7gAagX8Lj0u47weAW4GvAE3ADcBSM6sYS0XN7BPA/wQ+D8wE1gH3hB+fA3wsbEc9cBGwPfzsFuAr7l4LnAA8MZafKzIcBYlI4B/dfbO7bwD+A1jm7i+7ew/wIHBKWO4i4Kfu/pi79wF/C1QBHwY+CCSBv3f3Pne/H3gx52d8GbjB3Ze5e7+73w70hPuNxaXAre7+Uli/q4EPmdk8oA+oBd4HmLuvcvdN4X59wCIzq3P3Dnd/aYw/VyQvBYlIYHPO8r486zXh8iyCHgAA7p4F1gOzw882+IEzoa7LWT4K+GZ4Wmunme0E5ob7jcXQOuwh6HXMdvcngB8C1wObzexGM6sLi14AfApYZ2ZPm9mHxvhzRfJSkIiMzUaCQACCMQmCMNgAbAJmh9sGHJmzvB74nrvX57zS7n53kXWoJjhVtgHA3X/g7qcCxxOc4vpWuP1Fd18CtBCcgrtvjD9XJC8FicjY3AecZ2afNLMk8E2C01PPAb8CMsDXzSxhZr8LnJaz703AV83s9HBQvNrMzjOz2jHW4S7gS2Z2cji+8j8ITsWtNbPfCo+fBLqAbqA/HMO51MymhafkdgP9Rfz/IDJIQSIyBu6+GvgC8I/ANoKB+U+7e6+79wK/C/w+0EEwnvLjnH2XE4yT/DD8vC0sO9Y6/AK4BniAoBd0DHBx+HEdQWB1EJz+2k4wjgNwGbDWzHYDXw3bIVI004OtRESkGOqRiIhIURQkIiJSFAWJiIgURUEiIiJFSZS6AofD9OnTfd68eaWuhojIpLJixYpt7t48WrmyCJJ58+axfPnyUldDRGRSMbN1o5fSqS0RESmSgkRERIqiIBERkaKUxRhJPn19fbS3t9Pd3X3QZ5WVlcyZM4dkMlmCmomITC5lGyTt7e3U1tYyb948cidrdXe2b99Oe3s78+fPL2ENRUQmh7I9tdXd3U1TUxMHzvgNZkZTU1PenoqIiBysbIMEOChERtsuIiIHK+sgGc3tz61l6a83lroaIiITmoJkBHe/8C5LVypIRERGUtZBMtyzWAa2N6RT7NzbezirJCIy6ZRtkFRWVrJ9+/aDwmTgqq3Kykoaq1PsUJCIiIyobC//nTNnDu3t7WzduvWgzwbuI6lP72Hn3r4S1E5EZPIo2yBJJpOj3ifSWB2c2spmnVhMV3KJiORTtqe2CtGQTpF12N2tXomIyHAUJCNoqA6mSNnRpXESEZHhKEhG0JBOAdChcRIRkWEpSEYwGCTqkYiIDEtBMoLG6iBIdAmwiMjwFCQjqE8HYyS6KVFEZHgKkhHUVCRIxo0dXRojEREZjoJkBGZGvaZJEREZkYJkFI3plC7/FREZgYJkFA3VSU2TIiIyAgXJKBrSmrhRRGQkCpJRNFRrjEREZCQKklE0pJN07O0jm83/7BIRkXKnIBlFQzpFf9bp7M6UuioiIhOSgmQU++fb0uktEZF8FCSj0DQpIiIjU5CMQtOkiIiMTEEyisEeiaZJERHJK9IgMbNzzWy1mbWZ2bfzfF5hZveGny8zs3nh9rPNbIWZ/SZ8/0TOPk+Fx1wZvlqibEN9OEaiHomISH6RPbPdzOLA9cDZQDvwopktdffXc4pdAXS4+wIzuxi4DrgI2AZ82t03mtkJwKPA7Jz9LnX35VHVPVddZYJ4zDRNiojIMKLskZwGtLn7GnfvBe4BlgwpswS4PVy+H/ikmZm7v+zuG8PtrwGVZlYRYV2HZWY0pFN6SqKIyDCiDJLZwPqc9XYO7FUcUMbdM8AuoGlImQuAl929J2fbv4Snta4xM8v3w83sSjNbbmbLt27dWkw7gpsS1SMREckryiDJ9wU/9PbwEcuY2fEEp7u+kvP5pe5+IvDR8HVZvh/u7je6+2J3X9zc3Dymig/VUK35tkREhhNlkLQDc3PW5wAbhytjZglgGrAjXJ8DPAh80d3fHtjB3TeE753AXQSn0CLVkE5qsF1EZBhRBsmLQKuZzTezFHAxsHRImaXA5eHyhcAT7u5mVg/8FLja3X85UNjMEmY2PVxOAucDr0bYBiC4BFiX/4qI5BdZkIRjHlcRXHG1CrjP3V8zs2vN7DNhsVuAJjNrA/4EGLhE+CpgAXDNkMt8K4BHzewVYCWwAbgpqjYMGHhKorsmbhQRGSqyy38B3P1h4OEh276Ts9wNfC7Pft8FvjvMYU8dzzoWojGdIpN1Onsy1FUmD/ePFxGZ0HRnewEGp0nR6S0RkYMoSAqgiRtFRIanIClAQ7WmkhcRGY6CpACDzyTRTYkiIgdRkBSgMT0wA7CCRERkKAVJAWorE8QMdmq+LRGRgyhIChCLBRM3arBdRORgCpIC1WuaFBGRvBQkBQqmSVGQiIgMpSApUDBNisZIRESGUpAUqDGtHomISD4KkgI1VAc9Ek3cKCJyIAVJgRrSSXr7s3T19pe6KiIiE4qCpECD06To9JaIyAEUJAUanCZFlwCLiBxAQVKgxupgKnkNuIuIHEhBUqD6sEeiS4BFRA6kICmQJm4UEclPQVKguqokZmiaFBGRIRQkBYrHjPqqpCZuFBEZQkEyBg3VKTo0RiIicgAFyRg0pFO6j0REZAgFyRg0aL4tEZGDKEjGoCGd1OW/IiJDKEjGoLE6eEqiJm4UEdlPQTIG9ekUvZks+/o0caOIyAAFyRhomhQRkYMpSMZA06SIiBxMQTIGjdWaJkVEZCgFyRhoKnkRkYMpSMagIR2MkeimRBGR/RQkYzAtnLhxh8ZIREQGRRokZnauma02szYz+3aezyvM7N7w82VmNi/cfraZrTCz34Tvn8jZ59Rwe5uZ/cDMLMo25ErEY9RVJjUDsIhIjsiCxMziwPXA7wCLgEvMbNGQYlcAHe6+APg+cF24fRvwaXc/Eb
gcuCNnn38CrgRaw9e5UbUhn8ZqTZMiIpIryh7JaUCbu69x917gHmDJkDJLgNvD5fuBT5qZufvL7r4x3P4aUBn2XmYCde7+Kw9uL/8R8NkI23CQek2TIiJygCiDZDawPme9PdyWt4y7Z4BdQNOQMhcAL7t7T1i+fZRjRqpREzeKiBwgyiDJN3YxdJKqEcuY2fEEp7u+MoZjDux7pZktN7PlW7duLaC6halPpzRGIiKSI8ogaQfm5qzPATYOV8bMEsA0YEe4Pgd4EPiiu7+dU37OKMcEwN1vdPfF7r64ubm5yKbs11itpySKiOSKMkheBFrNbL6ZpYCLgaVDyiwlGEwHuBB4wt3dzOqBnwJXu/svBwq7+yag08w+GF6t9UXgoQjbcJCG6hTdfVn29WriRhERiDBIwjGPq4BHgVXAfe7+mplda2afCYvdAjSZWRvwJ8DAJcJXAQuAa8xsZfhqCT/7I+BmoA14G3gkqjbko7vbRUQOlIjy4O7+MPDwkG3fyVnuBj6XZ7/vAt8d5pjLgRPGt6aFGwiSHV29zKqvKlU1REQmDN3ZPkYD06ToEmARkYCCZIwGZwDWqS0REUBBMmb7n0miIBERAQXJmNWn9ZREEZFcCpIxSsZj1FYmNEYiIhJSkBwCTdwoIrKfguQQ1KdTuo9ERCSkIDkEjemkgkREJKQgOQSN1RVs61SQiIiAguSQHN1czXu7u9nTkyl1VURESk5BcggWtNQA8PaWPSWuiYhI6SlIDkFrGCRvKUhERBQkh+LIxjSpeIy3tnSWuioiIiWnIDkEiXiMo5uradusHomIiILkEC1oqdGpLRERFCSHrLWllvUde/WkRBEpewqSQ9Q6owZ3eHureiUiUt4UJIdo4MqtNp3eEpEypyA5REc1VZOIma7cEpGypyA5RKlEjHnTq3lLV26JSJkrKEjM7BtmVmeBW8zsJTM7J+rKTXStLTU6tSUiZa/QHskfuPtu4BygGfgS8DeR1WqSaG2pYe32LnoyunJLRMpXoUFi4fungH9x91/nbCtbC2bUknV4Z1tXqasiIlIyhQbJCjP7OUGQPGpmtUA2umpNDoNzbmmcRETKWKLAclcAJwNr3H2vmTUSnN4qa/OnVxMzTd4oIuWt0B7Jh4DV7r7TzL4A/CWwK7pqTQ6VyThHNVXTpkuARaSMFRok/wTsNbP3A38GrAN+FFmtJpEFLTU6tSUiZa3QIMm4uwNLgH9w938AaqOr1uTR2lLDO9u66Osv+yEjESlThQZJp5ldDVwG/NTM4kAyumpNHq0zashknXXbdeWWiJSnQoPkIqCH4H6S94DZwP+OrFaTSGtL0DHT6S0RKVcFBUkYHncC08zsfKDb3TVGAhzTXIPpyi0RKWOFTpHyeeAF4HPA54FlZnZhlBWbLKpSceY0VClIRKRsFXofyX8FfsvdtwCYWTPwOHB/VBWbTFpbanlrsy4BFpHyVOgYSWwgRELbC9nXzM41s9Vm1mZm387zeYWZ3Rt+vszM5oXbm8zsSTPbY2Y/HLLPU+ExV4avlgLbEJnWlhrWbOsioyu3RKQMFdoj+ZmZPQrcHa5fBDw80g7hlV3XA2cD7cCLZrbU3V/PKXYF0OHuC8zsYuC68NjdwDXACeFrqEvdfXmBdY/cgpYaejNZ1nfsY/706lJXR0TksCp0sP1bwI3AScD7gRvd/c9H2e00oM3d17h7L3APwX0ouZYAt4fL9wOfNDNz9y53f5YgUCa81hkDV27p9JaIlJ9CeyS4+wPAA2M49mxgfc56O3D6cGXcPWNmu4AmYNsox/4XM+sP6/Pd8GbJA5jZlcCVAEceeeQYqj12CwYmb9yyh3OOj/RHiYhMOCP2SMys08x253l1mtnuUY6db5r5oV/4hZQZ6lJ3PxH4aPi6LF8hd7/R3Re7++Lm5uZRDlmcmooEs6ZV6iFXIlKWRgwSd69197o8r1p3rxvl2O3A3Jz1OcDG4cqYWQKYBuwYpU4bwvdO4C6CU2glt2BGrZ7fLiJlKcpntr8ItJrZfDNLARcDS4eUWQpcHi5fCDyR7zTVADNLmNn0cDkJnA+8Ou41PwQDj93NZkfrUImITC0Fj5GMVTjmcRXwKBAHbnX318zsWmC5uy8FbgHuMLM2gp7IxQP7m9laoA5ImdlnCR7zu47gwVrJ8JiPAzdF1YaxaG2pobsvy4ad+5jbmC51dUREDpvIggTA3R9myGXC7v6dnOVugrvl8+07b5jDnjpe9RtPrTMGBtw7FSQiUlaiPLVVVhY0a/JGESlPCpJxMi2dpKW2QnNuiUjZUZCMo9YZNQoSESk7CpJx1NpSS9vmTka48ExEZMpRkIyjhUfU0tXbz5ptelqiiJQPBck4+siC6QA8tXpriWsiInL4KEjG0dzGNAtaanhq9ZbRC4uITBEKknF21sJmlq3ZQVdPptRVERE5LBQk4+yshS309md57u3tpa6KiMhhoSAZZ4vnNVKdivOkTm+JSJlQkIyzVCLGR1qn89QbW3QZsIiUBQVJBM5a2MLGXd28qelSRKQMKEgicObCFgCd3hKRsqAgicAR0yo5bmYdT76hIBGRqU9BEpGzFjazfF0Hu7v7Sl0VEZFIKUgictb7WujPOs++ta3UVRERiZSCJCKnzK2nrjKh01siMuUpSCKSiMf42LHNPPXmVj3HXUSmNAVJhM5a2MLWzh5e37S71FUREYmMgiRCH1/YjBk6vSUiU5qCJELTayo4aU49T+h+EhGZwhQkETtrYTMr1+9kR1dvqasiIhIJBUnEzlrYgjs886YediUiU5OCJGInzp5GU3VK06WIyJSlIIlYLGZ8fGEzT7+5lX5dBiwiU5CC5DA4a2ELO/f2sWJdR6mrIiIy7hQkh8FZ72uhtjLBbc+9U+qqiIiMOwXJYVBTkeALHzyKR159j3e2dZW6OiIi40pBcph86Yx5JGMxbnxmTamrIiIyrhQkh0lLbSUXnDqbB15qZ0tnd6mrIyIybhQkh9GXP3o0ff1Zbvvl2lJXRURk3ChIDqOjm2s49/gjuOP5dXTqgVciMkVEGiRmdq6ZrTazNjP7dp7PK8zs3vDzZWY2L9zeZGZPmtkeM/vhkH1ONbPfhPv8wMwsyjaMt69+/Bg6uzPc/cK7pa6KiMi4iCxIzCwOXA/8DrAIuMTMFg0pdgXQ4e4LgO8D14Xbu4FrgD/Nc+h/Aq4EWsPXueNf++i8f249Hzq6iVuefYeeTH+pqyMiUrQoeySnAW3uvsbde4F7gCVDyiwBbg+X7wc+aWbm7l3u/ixBoAwys5lAnbv/yt0d+BHw2QjbEImvnnkMm3f38NDLG0tdFRGRokUZJLOB9Tnr7eG2vGXcPQPsAppGOWb7KMcEwMyuNLPlZrZ869aJNWHix1qns2hmHf/8zNt6eqKITHpRBkm+sYuh35qFlDmk8u5+o7svdvfFzc3NIxzy8DMzvvLxo1mztYvHVm0udXVERIoSZZC0A3Nz1ucAQ8/lDJYxswQwDdgxyjHnjHLMSeG8E2cyp6GKf376bYKzdCIik1OUQfIi0Gpm880sBVwML
B1SZilwebh8IfCEj/Ct6u6bgE4z+2B4tdYXgYfGv+rRS8RjfPmjR/Pyuzt54Z2RslNEZGKLLEjCMY+rgEeBVcB97v6amV1rZp8Ji90CNJlZG/AnwOAlwma2Fvg74PfNrD3niq8/Am4G2oC3gUeiakPUPr94LtNrKvjew6s0xbyITFpWDqdVFi9e7MuXLy91NfJa+uuNfP3ul/nO+Yv4g4/ML3V1REQGmdkKd188Wjnd2V5inz5pJmcubOZvf76aDTv3lbo6IiJjpiApMTPjvy85AXf4q4de1cC7iEw6CpIJYG5jmv9ydiuPr9rCz159r9TVEREZEwXJBPEHZ8xn0cw6/mrpa+zWhI4iMokoSCaIRDzG31xwItv29PC/fvZGqasjIlIwBckEctKcei7/8DzuXPYuK9bp3hIRmRwUJBPMN89ZyMy6Sq7+8W/ozWRLXR0RkVEpSCaYmooE1y45gTc37+H/PtVW6uqIiIxKQTIB/faiGSw5eRZ///hbPLCiffQdRERKKFHqCkh+111wEtv29PBnD7xCbWWCc44/otRVEhHJSz2SCaoyGeeGyxZzwuxpXHX3yzz39rZSV0lEJC8FyQRWU5Hgtt//LY5qTPPl25fzSvvOUldJROQgCpIJrqE6xR1XnE5DdYrLb32Bti2dpa6SiMgBFCSTwBHTKvnXK04nHovxhZtfoL1jb6mrJCIySEEyScybXs0dV5zG3t4Ml9z0PG+8t7vUVRIRARQkk8pxM+v40RWn09OX5bPX/5KHVm4odZVERBQkk83Jc+v5969/hJPm1PONe1by35a+pjvgRaSkFCSTUEttJXf+4en84Ufmc9tza7nkpufZvLu71NUSkTKlIJmkkvEYf3n+Iv7xklNYtWk35/3gWZat2V7qaolIGVKQTHKffv8sfvK1M6irTHDJTc9zzU9epaOrt9TVEpEyoiCZAo6dUctDV53BZR88irteeJcz//Ypbn9uLZl+jZ2ISPQUJFNEbWWSv15yAg9//aOcMDt40uJ5P3iW59o0tYqIREtBMsUsPKKWf73idG647FT29mX4vZuX8ZU7lrP6Pd0RLyLR0Oy/U5CZ8Z+OP4KPH9vMLc++w/VPtvHoa5v5+LHNXPmxo/nwMU2YWamrKSJThLl7qesQucWLF/vy5ctLXY2S6ejq5c5l67jtuXVs29PDopl1fPlj8zn/pFkk4+qUikh+ZrbC3RePWk5BUj66+/p5aOUGbvqPd2jbsoeZ0yr5/OK5nH/STFpn1Ja6eiIywShIcihIDpTNOk+/uZWbn13Dc29vxx0WzqjlvJNmct5JMzmmuabUVRSRCUBBkkNBMrwtu7t55NX3+Okrm3hx3Q7c4X1H1PKpE2dy9qIZvO+IWo2niJQpBUkOBUlh3tvVzSOvbuKnr2xixbsduMPcxirOPu4Izjl+BouPaiChMRWRsqEgyaEgGbstnd38YtUWHnt9M8+2baM3k6U+neTMY5tZPK+RU49q4NgZtcRj6q2ITFUKkhwKkuJ09WR45s2tPPb6Zp55axvb9vQAwaOATzmyng8c2cAHjmrgxNnTaKxOlbi2IjJeFCQ5FCTjx91Zv2MfK97dwYp1HaxYt5PV7+0mG/4aza6v4vhZdZwwe9rge0tthcZZRCahQoMk0hsSzexc4B+AOHCzu//NkM8rgB8BpwLbgYvcfW342dXAFUA/8HV3fzTcvhboDLdnCmmkjB8z48imNEc2pfnPp8wBoLO7j1fad/Haxl28umE3r27cxWOrNjPwN0pDOsmxM2pZeET4mlFL64xaplUlS9gSERkvkQWJmcWB64GzgXbgRTNb6u6v5xS7Auhw9wVmdjFwHXCRmS0CLgaOB2YBj5vZse7eH+53lrtrEqkJorYyyRkLpnPGgumD27p6MqzatJtXN+xi9eY9vLm5kwdf2kBnT2awzPSaCo5qSnNUY5q5jelguSlYbq5RL0ZksoiyR3Ia0ObuawDM7B5gCZAbJEuA/xYu3w/80IJvjyXAPe7eA7xjZm3h8X4VYX1lHFVXJFg8r5HF8xoHt7k7G3d18+Z7naze3Mk7W7tYt6OL59ds58GVG8g9y1qZjDGnIc2chirmhu8D67MbqmiqTiloRCaIKINkNrA+Z70dOH24Mu6eMbNdQFO4/fkh+84Olx34uZk5cIO735jvh5vZlcCVAEceeWRxLZFxYWbMrq9idn0VZ72v5YDPuvv6ae/Yx7s7uli/Yx/rd+ylvWMf6zv28vK7O9m1r++A8pXJGLPCY82ur6KxOkVDOkV9OklDOkVDdZL6dIr6qiTTqpK6bFkkQlEGSb4/F4eO7A9XZqR9z3D3jWbWAjxmZm+4+zMHFQ4C5kYIBtsLr7aUQmUyzoKWGha05L+rfnd3H+079rFh5z42dOwN3nfuo71jH6s2ddKxt5f+7PD/zDUVCaaFoVKfTtJQnWJ6dYrpNRVMr62gqTo1+D6tKkltZVKXNosUKMogaQfm5qzPATYOU6bdzBLANGDHSPu6+8D7FjN7kOCU10FBIlNLXWWSRbOSLJpVl/dzd6ezJ8POrj469vbSsbeXnXv72LWvb//7vl527e1j574+Vm3czdY9PXR2Z/IeD6C2IkFdGD51VQlqKhJUpRJUp+KkUwmqK4L3msoEdZXB57WVyfA9QV1lktrKBDEFkkxxUQbJi0Crmc0HNhAMnv/ekDJLgcsJxj4uBJ5wdzezpcBdZvZ3BIPtrcALZlYNxNy9M1w+B7g2wjbIJGFm1FUmqatMcmRTuuD9ejL9bN/Ty7Y9PWzf08v2rl527wuCZ9e+Pnbv62N3dx+792XYtKubvb397O3NsLenn67eDCN0gsJ6DQmkyiTpVJxE3EjEYyRj4XvcSMVj1FQmqKlIUlOZoLYiCKfqigSVyRgViTgVyRgVif3L6WRcp+2k5CILknDM4yrgUYLLf29199fM7FpgubsvBW4B7ggH03cQhA1hufsIBuYzwNfcvd/MZgAPhoOsCeAud/9ZVG2Qqa8iEWdWfRWz6qvGvK+7092XpbOnjz3dGTq7M+zpydDZ3cfucH0wjHLC6b3dfWT6nb5slky/k+nP0pd1evr66ertH/EUXT6peIx0RZx0Mk66IkE6FScVjxGPWRBYsRiJmBGPGalEjKpknKpUnKpknMpwuTIRI5mIkYzFSCaCfZLxGKmEBe/x4PNUPAiyVCL4fOCzRNzCddNFEGVINySKTCBDw2lPT4Y93Rl6+rP09GXpyfTTk8kGr75+9vX2s7evn709mbC3FPSYevuDkOrPOn1Zpz8Mrd5Mlu6+YJ99vcGxxlsyPhBEQXAlcsIpHjPiZsRiRjwGcQsCLgit2GAYDYQUQNadrAfvhO+pRIx0Kk5VMkFVKkY6laAqGSeZCPYxgt6gYQzkWu628H/EYxb07hIxKpIxKpPxwR5fIm6D9UvEbX/dwxcGMQt6wwPHHqine/BvmfWgzMBxJ1vITogbEkVkbMws6C2k4rQchkfEZLNOdyYIlUzW6evP0hf2knrD5b7+LL2ZYL03s/+VyWbp7Xf6Mtlwv2A9058d
PM7A9qAH5mSzQbj1e7jsHnzWn6WrJ7P/54X7GUbMyPniDr60e/uz7O0Ng7SAU4wTRWUyNtgTrEjEBoMnmw1Cp9+d4I97C0OKweCynPUguMJt4fJAgAXB62SzQZg9+a0zqUjEI22XgkSkjMViRjqVIJ2avF8F7k5vf5Z9vf309TtO0HNxGLw3yQl7CWH5ge392WDf7r6wp9e3fzmTzQahF74y4XvuFzbk9EAIelgW9lJi4Zd81qE70093eOzunN5gzIJ/g5gN9NQY7LX4QBgQ9siyPlj/bPjzBnppjg8GzkDw7g+a6HtBk/e3R0SE4Is3OD0V7V/dMjxd7iEiIkVRkIiISFEUJCIiUhQFiYiIFEVBIiIiRVGQiIhIURQkIiJSFAWJiIgUpSzm2jKzrcC6Q9x9OlCOj/VVu8uL2l1eCm33Ue7ePFqhsgiSYpjZ8kImLZtq1O5i0iaIAAAFGElEQVTyonaXl/Fut05tiYhIURQkIiJSFAXJ6G4sdQVKRO0uL2p3eRnXdmuMREREiqIeiYiIFEVBIiIiRVGQDMPMzjWz1WbWZmbfLnV9omRmt5rZFjN7NWdbo5k9ZmZvhe8NpaxjFMxsrpk9aWarzOw1M/tGuH1Kt93MKs3sBTP7ddjuvw63zzezZWG77zWzVKnrGgUzi5vZy2b27+H6lG+3ma01s9+Y2UozWx5uG7ffcwVJHmYWB64HfgdYBFxiZotKW6tI3QacO2Tbt4FfuHsr8ItwfarJAN909+OADwJfC/+dp3rbe4BPuPv7gZOBc83sg8B1wPfDdncAV5SwjlH6BrAqZ71c2n2Wu5+cc//IuP2eK0jyOw1oc/c17t4L3AMsKXGdIuPuzwA7hmxeAtweLt8OfPawVuowcPdN7v5SuNxJ8OUymynedg/sCVeT4cuBTwD3h9unXLsBzGwOcB5wc7hulEG7hzFuv+cKkvxmA+tz1tvDbeVkhrtvguALF2gpcX0iZWbzgFOAZZRB28PTOyuBLcBjwNvATnfPhEWm6u/83wN/BmTD9SbKo90O/NzMVpjZleG2cfs9T4xDBaciy7NN10lPUWZWAzwA/LG77w7+SJ3a3L0fONnM6oEHgePyFTu8tYqWmZ0PbHH3FWZ25sDmPEWnVLtDZ7j7RjNrAR4zszfG8+DqkeTXDszNWZ8DbCxRXUpls5nNBAjft5S4PpEwsyRBiNzp7j8ON5dF2wHcfSfwFMEYUb2ZDfxxORV/588APmNmawlOV3+CoIcy1duNu28M37cQ/OFwGuP4e64gye9FoDW8miMFXAwsLXGdDrelwOXh8uXAQyWsSyTC8+O3AKvc/e9yPprSbTez5rAngplVAb9NMD70JHBhWGzKtdvdr3b3Oe4+j+C/6Sfc/VKmeLvNrNrMageWgXOAVxnH33Pd2T4MM/sUwV8rceBWd/9eiasUGTO7GziTYGrpzcBfAT8B7gOOBN4FPufuQwfkJzUz+wjwH8Bv2H/O/C8IxkmmbNvN7CSCwdU4wR+T97n7tWZ2NMFf6o3Ay8AX3L2ndDWNTnhq60/d/fyp3u6wfQ+GqwngLnf/npk1MU6/5woSEREpik5tiYhIURQkIiJSFAWJiIgURUEiIiJFUZCIiEhRFCQiE5iZnTkwS63IRKUgERGRoihIRMaBmX0hfMbHSjO7IZwUcY+Z/R8ze8nMfmFmzWHZk83seTN7xcweHHgOhJktMLPHw+eEvGRmx4SHrzGz+83sDTO708phMjCZVBQkIkUys+OAiwgmxjsZ6AcuBaqBl9z9A8DTBDMGAPwI+HN3P4ngrvqB7XcC14fPCfkwsCncfgrwxwTPxjmaYM4okQlDs/+KFO+TwKnAi2FnoYpgArwscG9Y5l+BH5vZNKDe3Z8Ot98O/Fs4F9Jsd38QwN27AcLjveDu7eH6SmAe8Gz0zRIpjIJEpHgG3O7uVx+w0eyaIeVGmo9opNNVufM+9aP/bmWC0aktkeL9ArgwfNbDwLOwjyL472tgVtnfA551911Ah5l9NNx+GfC0u+8G2s3ss+ExKswsfVhbIXKI9JeNSJHc/XUz+0uCJ9DFgD7ga0AXcLyZrQB2EYyjQDBl9z+HQbEG+FK4/TLgBjO7NjzG5w5jM0QOmWb/FYmIme1x95pS10Mkajq1JSIiRVGPREREiqIeiYiIFEVBIiIiRVGQiIhIURQkIiJSFAWJiIgU5f8DyChMwJojCvUAAAAASUVORK5CYII=\n", 968 | "text/plain": [ 969 | "" 970 | ] 971 | }, 972 | "metadata": {}, 973 | "output_type": "display_data" 974 | } 975 | ], 976 | "source": [ 977 | "def plot_hist(hist):\n", 978 | " # summarize history for loss\n", 979 | " fig, ax = plt.subplots() # create figure & 1 axis\n", 980 | "\n", 981 | " plt.title('model loss')\n", 982 | " plt.ylabel('loss')\n", 983 | " plt.xlabel('epoch')\n", 984 | " plt.legend(['train', 'test'], loc='upper left')\n", 985 | "\n", 986 | " plt.plot(hist.history['loss'])\n", 987 | " #plt.plot(hist.history['val_loss'])\n", 988 | "\n", 989 | "plot_hist(hist)" 990 | ] 991 | }, 992 | { 993 | "cell_type": "markdown", 994 | "metadata": {}, 995 | "source": [ 996 | "## Recommender" 997 | ] 998 | }, 999 | { 1000 | "cell_type": "code", 1001 | "execution_count": 13, 1002 | "metadata": {}, 1003 | "outputs": [], 1004 | "source": [ 1005 | "# Predict new Matrix Interactions, set score zero on visualized games\n", 1006 | "new_matrix = model.predict(X) * (X == 0)" 1007 | ] 1008 | }, 1009 | { 1010 | "cell_type": "code", 1011 | "execution_count": 14, 1012 | "metadata": { 1013 | "scrolled": true 1014 | }, 1015 | "outputs": [ 1016 | { 1017 | "data": { 1018 | "text/html": [ 1019 | "
\n", 1020 | "\n", 1033 | "\n", 1034 | " \n", 1035 | " \n", 1036 | " \n", 1037 | " \n", 1038 | " \n", 1039 | " \n", 1040 | " \n", 1041 | " \n", 1042 | " \n", 1043 | " \n", 1044 | " \n", 1045 | " \n", 1046 | " \n", 1047 | " \n", 1048 | " \n", 1049 | " \n", 1050 | " \n", 1051 | " \n", 1052 | " \n", 1053 | " \n", 1054 | " \n", 1055 | " \n", 1056 | " \n", 1057 | " \n", 1058 | " \n", 1059 | " \n", 1060 | " \n", 1061 | " \n", 1062 | " \n", 1063 | " \n", 1064 | " \n", 1065 | " \n", 1066 | " \n", 1067 | " \n", 1068 | " \n", 1069 | " \n", 1070 | " \n", 1071 | " \n", 1072 | " \n", 1073 | " \n", 1074 | " \n", 1075 | " \n", 1076 | " \n", 1077 | " \n", 1078 | " \n", 1079 | " \n", 1080 | " \n", 1081 | " \n", 1082 | " \n", 1083 | " \n", 1084 | " \n", 1085 | " \n", 1086 | " \n", 1087 | " \n", 1088 | " \n", 1089 | " \n", 1090 | " \n", 1091 | " \n", 1092 | " \n", 1093 | " \n", 1094 | " \n", 1095 | " \n", 1096 | " \n", 1097 | " \n", 1098 | " \n", 1099 | " \n", 1100 | " \n", 1101 | " \n", 1102 | " \n", 1103 | " \n", 1104 | " \n", 1105 | " \n", 1106 | " \n", 1107 | " \n", 1108 | " \n", 1109 | " \n", 1110 | " \n", 1111 | " \n", 1112 | " \n", 1113 | " \n", 1114 | " \n", 1115 | " \n", 1116 | " \n", 1117 | " \n", 1118 | " \n", 1119 | " \n", 1120 | " \n", 1121 | " \n", 1122 | " \n", 1123 | " \n", 1124 | " \n", 1125 | " \n", 1126 | " \n", 1127 | " \n", 1128 | " \n", 1129 | " \n", 1130 | " \n", 1131 | " \n", 1132 | " \n", 1133 | " \n", 1134 | " \n", 1135 | " \n", 1136 | " \n", 1137 | " \n", 1138 | " \n", 1139 | " \n", 1140 | " \n", 1141 | " \n", 1142 | " \n", 1143 | " \n", 1144 | " \n", 1145 | " \n", 1146 | " \n", 1147 | " \n", 1148 | " \n", 1149 | " \n", 1150 | " \n", 1151 | " \n", 1152 | " \n", 1153 | " \n", 1154 | " \n", 1155 | " \n", 1156 | " \n", 1157 | " \n", 1158 | " \n", 1159 | " \n", 1160 | " \n", 1161 | " \n", 1162 | " \n", 1163 | " \n", 1164 | " \n", 1165 | " \n", 1166 | " \n", 1167 | " \n", 1168 | " \n", 1169 | " \n", 1170 | " \n", 1171 | " \n", 1172 | " \n", 1173 | " \n", 1174 | " \n", 1175 | " \n", 1176 | " \n", 1177 | " \n", 1178 | " \n", 1179 | " \n", 1180 | " \n", 1181 | " \n", 1182 | " \n", 1183 | " \n", 1184 | " \n", 1185 | " \n", 1186 | " \n", 1187 | " \n", 1188 | " \n", 1189 | " \n", 1190 | " \n", 1191 | " \n", 1192 | " \n", 1193 | " \n", 1194 | " \n", 1195 | " \n", 1196 | " \n", 1197 | " \n", 1198 | " \n", 1199 | " \n", 1200 | " \n", 1201 | " \n", 1202 | " \n", 1203 | " \n", 1204 | " \n", 1205 | " \n", 1206 | "
content_id01235678910...5103510451055106510751085109511051115112
user_id
0-0.002197-0.001979-0.0008050.002296-0.000611-0.008840-0.0070280.005476-0.003611-0.001567...-0.007693-0.002160-0.001630-0.0072510.003084-0.0199690.004092-0.004482-0.0156260.000907
10.0006520.0029330.007499-0.008718-0.004078-0.0017390.0003830.018262-0.0137930.018117...-0.009904-0.012572-0.009873-0.005511-0.000534-0.003477-0.008922-0.0036470.029548-0.002062
20.004266-0.0044230.0041110.011979-0.0138280.016206-0.0133070.017889-0.0088610.005887...-0.014347-0.0028230.0067620.0287510.0068270.032240-0.001731-0.0003660.0030560.017550
30.0040460.0011160.002487-0.000940-0.0002100.002233-0.0003040.005891-0.0042610.001885...0.005553-0.0003220.002608-0.003137-0.0025500.0077190.000484-0.007889-0.000531-0.002881
40.005502-0.000663-0.001771-0.001596-0.0001340.0058940.0026740.004908-0.000602-0.003510...0.0032610.0013400.0005160.000465-0.0000460.0067020.002666-0.0044890.011630-0.004099
\n", 1207 | "

5 rows × 4862 columns

\n", 1208 | "
" 1209 | ], 1210 | "text/plain": [ 1211 | "content_id 0 1 2 3 5 6 \\\n", 1212 | "user_id \n", 1213 | "0 -0.002197 -0.001979 -0.000805 0.002296 -0.000611 -0.008840 \n", 1214 | "1 0.000652 0.002933 0.007499 -0.008718 -0.004078 -0.001739 \n", 1215 | "2 0.004266 -0.004423 0.004111 0.011979 -0.013828 0.016206 \n", 1216 | "3 0.004046 0.001116 0.002487 -0.000940 -0.000210 0.002233 \n", 1217 | "4 0.005502 -0.000663 -0.001771 -0.001596 -0.000134 0.005894 \n", 1218 | "\n", 1219 | "content_id 7 8 9 10 ... 5103 \\\n", 1220 | "user_id ... \n", 1221 | "0 -0.007028 0.005476 -0.003611 -0.001567 ... -0.007693 \n", 1222 | "1 0.000383 0.018262 -0.013793 0.018117 ... -0.009904 \n", 1223 | "2 -0.013307 0.017889 -0.008861 0.005887 ... -0.014347 \n", 1224 | "3 -0.000304 0.005891 -0.004261 0.001885 ... 0.005553 \n", 1225 | "4 0.002674 0.004908 -0.000602 -0.003510 ... 0.003261 \n", 1226 | "\n", 1227 | "content_id 5104 5105 5106 5107 5108 5109 \\\n", 1228 | "user_id \n", 1229 | "0 -0.002160 -0.001630 -0.007251 0.003084 -0.019969 0.004092 \n", 1230 | "1 -0.012572 -0.009873 -0.005511 -0.000534 -0.003477 -0.008922 \n", 1231 | "2 -0.002823 0.006762 0.028751 0.006827 0.032240 -0.001731 \n", 1232 | "3 -0.000322 0.002608 -0.003137 -0.002550 0.007719 0.000484 \n", 1233 | "4 0.001340 0.000516 0.000465 -0.000046 0.006702 0.002666 \n", 1234 | "\n", 1235 | "content_id 5110 5111 5112 \n", 1236 | "user_id \n", 1237 | "0 -0.004482 -0.015626 0.000907 \n", 1238 | "1 -0.003647 0.029548 -0.002062 \n", 1239 | "2 -0.000366 0.003056 0.017550 \n", 1240 | "3 -0.007889 -0.000531 -0.002881 \n", 1241 | "4 -0.004489 0.011630 -0.004099 \n", 1242 | "\n", 1243 | "[5 rows x 4862 columns]" 1244 | ] 1245 | }, 1246 | "execution_count": 14, 1247 | "metadata": {}, 1248 | "output_type": "execute_result" 1249 | } 1250 | ], 1251 | "source": [ 1252 | "# converting the reconstructed matrix back to a Pandas dataframe\n", 1253 | "new_users_items_matrix_df = pd.DataFrame(new_matrix, \n", 1254 | " columns = users_items_matrix_df.columns, \n", 1255 | " index = users_items_matrix_df.index)\n", 1256 | "new_users_items_matrix_df.head()" 1257 | ] 1258 | }, 1259 | { 1260 | "cell_type": "code", 1261 | "execution_count": 15, 1262 | "metadata": {}, 1263 | "outputs": [ 1264 | { 1265 | "name": "stdout", 1266 | "output_type": "stream", 1267 | "text": [ 1268 | "-0.49707025 1.4381734\n" 1269 | ] 1270 | } 1271 | ], 1272 | "source": [ 1273 | "print(new_users_items_matrix_df.values.min(), new_users_items_matrix_df.values.max())" 1274 | ] 1275 | }, 1276 | { 1277 | "cell_type": "code", 1278 | "execution_count": 16, 1279 | "metadata": {}, 1280 | "outputs": [], 1281 | "source": [ 1282 | "def recommender_for_user(user_id, interact_matrix, df_content, topn = 10):\n", 1283 | " '''\n", 1284 | " Recommender Games for UserWarning\n", 1285 | " '''\n", 1286 | " pred_scores = interact_matrix.loc[user_id].values\n", 1287 | "\n", 1288 | " df_scores = pd.DataFrame({'content_id': list(users_items_matrix_df.columns), \n", 1289 | " 'score': pred_scores})\n", 1290 | "\n", 1291 | " df_rec = df_scores.set_index('content_id')\\\n", 1292 | " .join(df_content.set_index('content_id'))\\\n", 1293 | " .sort_values('score', ascending=False)\\\n", 1294 | " .head(topn)[['score', 'game']]\n", 1295 | " \n", 1296 | " return df_rec[df_rec.score > 0]" 1297 | ] 1298 | }, 1299 | { 1300 | "cell_type": "markdown", 1301 | "metadata": {}, 1302 | "source": [ 1303 | "Recommender for **user_id = 1011**. 
This user prefers games from the Half-Life series" 1304 | ] 1305 | }, 1306 | { 1307 | "cell_type": "code", 1308 | "execution_count": 17, 1309 | "metadata": { 1310 | "cell_style": "split" 1311 | }, 1312 | "outputs": [ 1313 | { 1314 | "data": { 1315 | "text/html": [ 1316 | "
\n", 1317 | "\n", 1330 | "\n", 1331 | " \n", 1332 | " \n", 1333 | " \n", 1334 | " \n", 1335 | " \n", 1336 | " \n", 1337 | " \n", 1338 | " \n", 1339 | " \n", 1340 | " \n", 1341 | " \n", 1342 | " \n", 1343 | " \n", 1344 | " \n", 1345 | " \n", 1346 | " \n", 1347 | " \n", 1348 | " \n", 1349 | " \n", 1350 | " \n", 1351 | " \n", 1352 | " \n", 1353 | " \n", 1354 | " \n", 1355 | " \n", 1356 | " \n", 1357 | " \n", 1358 | " \n", 1359 | " \n", 1360 | " \n", 1361 | " \n", 1362 | " \n", 1363 | " \n", 1364 | " \n", 1365 | " \n", 1366 | " \n", 1367 | " \n", 1368 | " \n", 1369 | " \n", 1370 | "
scoregame
content_id
31961.0Portal
42211.0Team Fortress 2
20551.0Half-Life 2
20571.0Half-Life 2 Episode One
20591.0Half-Life 2 Lost Coast
\n", 1371 | "
" 1372 | ], 1373 | "text/plain": [ 1374 | " score game\n", 1375 | "content_id \n", 1376 | "3196 1.0 Portal\n", 1377 | "4221 1.0 Team Fortress 2\n", 1378 | "2055 1.0 Half-Life 2\n", 1379 | "2057 1.0 Half-Life 2 Episode One\n", 1380 | "2059 1.0 Half-Life 2 Lost Coast" 1381 | ] 1382 | }, 1383 | "execution_count": 17, 1384 | "metadata": {}, 1385 | "output_type": "execute_result" 1386 | } 1387 | ], 1388 | "source": [ 1389 | "# Games previously purchased by the user\n", 1390 | "recommender_for_user(\n", 1391 | " user_id = 1011, \n", 1392 | " interact_matrix = users_items_matrix_df, \n", 1393 | " df_content = df_game)" 1394 | ] 1395 | }, 1396 | { 1397 | "cell_type": "code", 1398 | "execution_count": 18, 1399 | "metadata": { 1400 | "cell_style": "split" 1401 | }, 1402 | "outputs": [ 1403 | { 1404 | "data": { 1405 | "text/html": [ 1406 | "
\n", 1407 | "\n", 1420 | "\n", 1421 | " \n", 1422 | " \n", 1423 | " \n", 1424 | " \n", 1425 | " \n", 1426 | " \n", 1427 | " \n", 1428 | " \n", 1429 | " \n", 1430 | " \n", 1431 | " \n", 1432 | " \n", 1433 | " \n", 1434 | " \n", 1435 | " \n", 1436 | " \n", 1437 | " \n", 1438 | " \n", 1439 | " \n", 1440 | " \n", 1441 | " \n", 1442 | " \n", 1443 | " \n", 1444 | " \n", 1445 | " \n", 1446 | " \n", 1447 | " \n", 1448 | " \n", 1449 | " \n", 1450 | " \n", 1451 | " \n", 1452 | " \n", 1453 | " \n", 1454 | " \n", 1455 | " \n", 1456 | " \n", 1457 | " \n", 1458 | " \n", 1459 | " \n", 1460 | " \n", 1461 | " \n", 1462 | " \n", 1463 | " \n", 1464 | " \n", 1465 | " \n", 1466 | " \n", 1467 | " \n", 1468 | " \n", 1469 | " \n", 1470 | " \n", 1471 | " \n", 1472 | " \n", 1473 | " \n", 1474 | " \n", 1475 | " \n", 1476 | " \n", 1477 | " \n", 1478 | " \n", 1479 | " \n", 1480 | " \n", 1481 | " \n", 1482 | " \n", 1483 | " \n", 1484 | " \n", 1485 | "
scoregame
content_id
20580.267165Half-Life 2 Episode Two
20560.251573Half-Life 2 Deathmatch
9780.175254Counter-Strike Source
20620.161243Half-Life Deathmatch Source
31970.134333Portal 2
11260.098420Day of Defeat Source
20640.095502Half-Life Source
24570.093733Left 4 Dead 2
24560.074682Left 4 Dead
20630.070074Half-Life Opposing Force
\n", 1486 | "
" 1487 | ], 1488 | "text/plain": [ 1489 | " score game\n", 1490 | "content_id \n", 1491 | "2058 0.267165 Half-Life 2 Episode Two\n", 1492 | "2056 0.251573 Half-Life 2 Deathmatch\n", 1493 | "978 0.175254 Counter-Strike Source\n", 1494 | "2062 0.161243 Half-Life Deathmatch Source\n", 1495 | "3197 0.134333 Portal 2\n", 1496 | "1126 0.098420 Day of Defeat Source\n", 1497 | "2064 0.095502 Half-Life Source\n", 1498 | "2457 0.093733 Left 4 Dead 2\n", 1499 | "2456 0.074682 Left 4 Dead\n", 1500 | "2063 0.070074 Half-Life Opposing Force" 1501 | ] 1502 | }, 1503 | "execution_count": 18, 1504 | "metadata": {}, 1505 | "output_type": "execute_result" 1506 | } 1507 | ], 1508 | "source": [ 1509 | "# Recommended User Games\n", 1510 | "recommender_for_user(\n", 1511 | " user_id = 1011, \n", 1512 | " interact_matrix = new_users_items_matrix_df, \n", 1513 | " df_content = df_game)" 1514 | ] 1515 | }, 1516 | { 1517 | "cell_type": "markdown", 1518 | "metadata": {}, 1519 | "source": [ 1520 | "Recommender for **user_id = 1319**. This user prefers games in the same line as the RPG or strategy" 1521 | ] 1522 | }, 1523 | { 1524 | "cell_type": "code", 1525 | "execution_count": 19, 1526 | "metadata": { 1527 | "cell_style": "split" 1528 | }, 1529 | "outputs": [ 1530 | { 1531 | "data": { 1532 | "text/html": [ 1533 | "
\n", 1534 | "\n", 1547 | "\n", 1548 | " \n", 1549 | " \n", 1550 | " \n", 1551 | " \n", 1552 | " \n", 1553 | " \n", 1554 | " \n", 1555 | " \n", 1556 | " \n", 1557 | " \n", 1558 | " \n", 1559 | " \n", 1560 | " \n", 1561 | " \n", 1562 | " \n", 1563 | " \n", 1564 | " \n", 1565 | " \n", 1566 | " \n", 1567 | " \n", 1568 | " \n", 1569 | " \n", 1570 | " \n", 1571 | " \n", 1572 | " \n", 1573 | " \n", 1574 | " \n", 1575 | " \n", 1576 | " \n", 1577 | " \n", 1578 | " \n", 1579 | " \n", 1580 | " \n", 1581 | " \n", 1582 | " \n", 1583 | " \n", 1584 | " \n", 1585 | " \n", 1586 | " \n", 1587 | " \n", 1588 | " \n", 1589 | " \n", 1590 | " \n", 1591 | " \n", 1592 | " \n", 1593 | " \n", 1594 | " \n", 1595 | " \n", 1596 | " \n", 1597 | " \n", 1598 | " \n", 1599 | " \n", 1600 | " \n", 1601 | " \n", 1602 | " \n", 1603 | " \n", 1604 | " \n", 1605 | " \n", 1606 | " \n", 1607 | " \n", 1608 | " \n", 1609 | " \n", 1610 | " \n", 1611 | " \n", 1612 | "
scoregame
content_id
1731.0Age of Empires II HD Edition
26621.0Medieval II Total War Kingdoms
43311.0The Elder Scrolls V Skyrim - Hearthfire
43291.0The Elder Scrolls V Skyrim - Dawnguard
46261.0Total War ATTILA
17521.0Football Manager 2015
42721.0The Bard's Tale
46311.0Total War SHOGUN 2
43281.0The Elder Scrolls V Skyrim
26611.0Medieval II Total War
\n", 1613 | "
" 1614 | ], 1615 | "text/plain": [ 1616 | " score game\n", 1617 | "content_id \n", 1618 | "173 1.0 Age of Empires II HD Edition\n", 1619 | "2662 1.0 Medieval II Total War Kingdoms\n", 1620 | "4331 1.0 The Elder Scrolls V Skyrim - Hearthfire\n", 1621 | "4329 1.0 The Elder Scrolls V Skyrim - Dawnguard\n", 1622 | "4626 1.0 Total War ATTILA\n", 1623 | "1752 1.0 Football Manager 2015\n", 1624 | "4272 1.0 The Bard's Tale\n", 1625 | "4631 1.0 Total War SHOGUN 2\n", 1626 | "4328 1.0 The Elder Scrolls V Skyrim\n", 1627 | "2661 1.0 Medieval II Total War" 1628 | ] 1629 | }, 1630 | "execution_count": 19, 1631 | "metadata": {}, 1632 | "output_type": "execute_result" 1633 | } 1634 | ], 1635 | "source": [ 1636 | "# Games previously purchased by the user\n", 1637 | "recommender_for_user(\n", 1638 | " user_id = 1319, \n", 1639 | " interact_matrix = users_items_matrix_df, \n", 1640 | " df_content = df_game)" 1641 | ] 1642 | }, 1643 | { 1644 | "cell_type": "code", 1645 | "execution_count": 22, 1646 | "metadata": { 1647 | "cell_style": "split" 1648 | }, 1649 | "outputs": [ 1650 | { 1651 | "data": { 1652 | "text/html": [ 1653 | "
\n", 1654 | "\n", 1667 | "\n", 1668 | " \n", 1669 | " \n", 1670 | " \n", 1671 | " \n", 1672 | " \n", 1673 | " \n", 1674 | " \n", 1675 | " \n", 1676 | " \n", 1677 | " \n", 1678 | " \n", 1679 | " \n", 1680 | " \n", 1681 | " \n", 1682 | " \n", 1683 | " \n", 1684 | " \n", 1685 | " \n", 1686 | " \n", 1687 | " \n", 1688 | " \n", 1689 | " \n", 1690 | " \n", 1691 | " \n", 1692 | " \n", 1693 | " \n", 1694 | " \n", 1695 | " \n", 1696 | " \n", 1697 | " \n", 1698 | " \n", 1699 | " \n", 1700 | " \n", 1701 | " \n", 1702 | " \n", 1703 | " \n", 1704 | " \n", 1705 | " \n", 1706 | " \n", 1707 | " \n", 1708 | " \n", 1709 | " \n", 1710 | " \n", 1711 | " \n", 1712 | " \n", 1713 | " \n", 1714 | " \n", 1715 | " \n", 1716 | " \n", 1717 | " \n", 1718 | " \n", 1719 | " \n", 1720 | " \n", 1721 | " \n", 1722 | " \n", 1723 | " \n", 1724 | " \n", 1725 | " \n", 1726 | " \n", 1727 | " \n", 1728 | " \n", 1729 | " \n", 1730 | " \n", 1731 | " \n", 1732 | "
scoregame
content_id
43300.292329The Elder Scrolls V Skyrim - Dragonborn
37920.151435Sid Meier's Civilization V
38370.146657Skyrim High Resolution Texture Pack
46300.108521Total War ROME II - Emperor Edition
14900.105798Empire Total War
16670.093142Fallout New Vegas
28720.091419Napoleon Total War
37930.083783Sid Meier's Civilization V Brave New World
16680.077532Fallout New Vegas Courier's Stash
16700.074456Fallout New Vegas Honest Hearts
\n", 1733 | "
" 1734 | ], 1735 | "text/plain": [ 1736 | " score game\n", 1737 | "content_id \n", 1738 | "4330 0.292329 The Elder Scrolls V Skyrim - Dragonborn\n", 1739 | "3792 0.151435 Sid Meier's Civilization V\n", 1740 | "3837 0.146657 Skyrim High Resolution Texture Pack\n", 1741 | "4630 0.108521 Total War ROME II - Emperor Edition\n", 1742 | "1490 0.105798 Empire Total War\n", 1743 | "1667 0.093142 Fallout New Vegas\n", 1744 | "2872 0.091419 Napoleon Total War\n", 1745 | "3793 0.083783 Sid Meier's Civilization V Brave New World\n", 1746 | "1668 0.077532 Fallout New Vegas Courier's Stash\n", 1747 | "1670 0.074456 Fallout New Vegas Honest Hearts" 1748 | ] 1749 | }, 1750 | "execution_count": 22, 1751 | "metadata": {}, 1752 | "output_type": "execute_result" 1753 | } 1754 | ], 1755 | "source": [ 1756 | "# Recommended User Games\n", 1757 | "recommender_for_user(\n", 1758 | " user_id = 1319, \n", 1759 | " interact_matrix = new_users_items_matrix_df, \n", 1760 | " df_content = df_game)" 1761 | ] 1762 | }, 1763 | { 1764 | "cell_type": "code", 1765 | "execution_count": null, 1766 | "metadata": {}, 1767 | "outputs": [], 1768 | "source": [] 1769 | } 1770 | ], 1771 | "metadata": { 1772 | "kernelspec": { 1773 | "display_name": "Python 3", 1774 | "language": "python", 1775 | "name": "python3" 1776 | }, 1777 | "language_info": { 1778 | "codemirror_mode": { 1779 | "name": "ipython", 1780 | "version": 3 1781 | }, 1782 | "file_extension": ".py", 1783 | "mimetype": "text/x-python", 1784 | "name": "python", 1785 | "nbconvert_exporter": "python", 1786 | "pygments_lexer": "ipython3", 1787 | "version": "3.6.4" 1788 | }, 1789 | "toc": { 1790 | "base_numbering": 1, 1791 | "nav_menu": {}, 1792 | "number_sections": true, 1793 | "sideBar": true, 1794 | "skip_h1_title": false, 1795 | "title_cell": "Table of Contents", 1796 | "title_sidebar": "Contents", 1797 | "toc_cell": false, 1798 | "toc_position": {}, 1799 | "toc_section_display": true, 1800 | "toc_window_display": false 1801 | } 1802 | }, 1803 | "nbformat": 4, 1804 | "nbformat_minor": 2 1805 | } 1806 | -------------------------------------------------------------------------------- /popularity_train.py: -------------------------------------------------------------------------------- 1 | import click 2 | import pandas as pd 3 | import math 4 | from util import * 5 | from sklearn.model_selection import train_test_split 6 | from evaluation.model_evaluator import * 7 | import mlflow 8 | 9 | # Const 10 | METRICS_LOG_PATH = './artefacts/metrics.png' 11 | SCORE_VALUES_PATH = './artefacts/score_plot.png' 12 | 13 | class PopularityRecommender: 14 | MODEL_NAME = 'Popularity' 15 | 16 | def __init__(self, popularity_df, items_df=None): 17 | self.popularity_df = popularity_df 18 | self.items_df = items_df 19 | 20 | def get_model_name(self): 21 | return self.MODEL_NAME 22 | 23 | def recommend_items(self, user_id, items_to_ignore=[], topn=10, verbose=False): 24 | # Recommend the more popular items that the user hasn't seen yet. 
25 |         recommendations_df = self.popularity_df[~self.popularity_df['content_id'].isin(items_to_ignore)] \
26 |                                 .sort_values('view', ascending = False) \
27 |                                 .head(topn)
28 |         return recommendations_df
29 | 
30 | def run():
31 |     print("run()")
32 |     # Load Dataset
33 |     articles_df, interactions_full_df, \
34 |     interactions_train_df, interactions_test_df, \
35 |     cf_preds_df = load_dataset()
36 | 
37 |     #interactions_full_df['eventStrength'] = interactions_full_df['eventStrength'].apply(lambda x: 1 if x > 0 else 0)
38 |     print('# interactions on Train set: %d' % len(interactions_train_df))
39 |     print('# interactions on Test set: %d' % len(interactions_test_df))
40 | 
41 |     # Train
42 |     ## Computes the most popular items
43 |     item_popularity_df = interactions_full_df.groupby('content_id')['view'].sum()\
44 |                             .sort_values(ascending=False).reset_index()
45 | 
46 |     popularity_model = PopularityRecommender(item_popularity_df, articles_df)
47 | 
48 | 
49 |     # Evaluation
50 |     model_evaluator = ModelEvaluator(articles_df, interactions_full_df, 
51 |                                      interactions_train_df, interactions_test_df)
52 | 
53 |     print('Evaluating Popularity recommendation model...')
54 |     pop_global_metrics, pop_detailed_results_df = model_evaluator.evaluate_model(popularity_model)
55 |     print('\nGlobal metrics:\n%s' % pop_global_metrics)
56 | 
57 |     # Plot Metrics
58 |     plot_metrics_disc(pop_global_metrics, METRICS_LOG_PATH)
59 | 
60 |     # Tracking
61 |     with mlflow.start_run():
62 |         for metric in ModelEvaluator.METRICS:
63 |             mlflow.log_metric(metric, pop_global_metrics[metric])
64 |         mlflow.log_artifact(METRICS_LOG_PATH, "evaluation")
65 | 
66 | 
67 | if __name__ == '__main__':
68 |     run()
-------------------------------------------------------------------------------- /recommender.py: --------------------------------------------------------------------------------
1 | import click
2 | import pandas as pd
3 | import numpy as np
4 | import mlflow
5 | import mlflow.keras
6 | from util import *
7 | from model.CDAEModel import *
8 | from model.AutoEncModel import *
9 | from model.AutoEncContentModel import *
10 | from evaluation.model_evaluator import *
11 | 
12 | # main
13 | # ----------------------------------------------
14 | @click.command(help="Recommender Matrix Factorization Model")
15 | @click.option("--name", type=click.Choice(['auto_enc', 'cdae', 'auto_enc_content']))
16 | @click.option("--model_path", type=click.STRING)
17 | @click.option("--user_id", type=click.INT, default=1)
18 | @click.option("--topn", type=click.INT, default=10)
19 | @click.option("--view", type=click.INT, default=0)
20 | @click.option("--output", type=click.STRING, default='./data/predict.csv')
21 | def run(name, model_path, user_id, topn, view, output):
22 | 
23 |     # Load Dataset
24 |     articles_df, _n, interactions_hist, _n2, _n3 = load_dataset()
25 | 
26 | 
27 |     # Creating a sparse pivot table with users in rows and items in columns
28 |     users_items_matrix_df = interactions_hist.pivot(index = 'user_id', 
29 |                                                     columns = 'content_id', 
30 |                                                     values = 'view').fillna(0)
31 | 
32 | 
33 |     # Data
34 |     users_items_matrix = users_items_matrix_df.values
35 |     users_ids = list(users_items_matrix_df.index)
36 | 
37 |     if name == 'cdae':
38 |         model = CDAEModel()
39 |     elif name == 'auto_enc':
40 |         model = AutoEncModel()
41 |     elif name == 'auto_enc_content':
42 |         model = AutoEncContentModel()
43 | 
44 |     # Input - Prepare input layer
45 |     X, y = model.data_preparation(interactions_hist, users_items_matrix_df)
46 | 
47 |     # Keras Model
48 |     model = mlflow.keras.load_model(model_path+name)
49 | 
50 |     # Predict
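    # data_preparation returns a list [interaction_matrix, user_ids] for the
    # CDAE model but the bare interaction matrix for the plain autoencoder,
    # so the mask below first recovers the matrix before zeroing out items
    # the user has already interacted with.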
51 |     if view == 0:  # new predictions: (X[0] == view) keeps scores only for unseen items
52 |         pred_score = model.predict(X) * (X[0] == view)
53 |     else:  # rank over the user's interaction history itself
54 |         pred_score = users_items_matrix
55 | 
56 | 
57 |     # converting the reconstructed matrix back to a Pandas dataframe
58 |     cf_preds_df = pd.DataFrame(pred_score,
59 |                                columns=users_items_matrix_df.columns,
60 |                                index=users_ids).transpose()
61 | 
62 | 
63 |     # Evaluation Model
64 |     cf_recommender_model = CFRecommender(cf_preds_df, articles_df)
65 | 
66 |     # Recommender
67 |     rec_list = cf_recommender_model.recommend_items(user_id=user_id,
68 |                                                     items_to_ignore=[],
69 |                                                     topn=topn,
70 |                                                     verbose=True)
71 |     rec_list = rec_list[rec_list.score > 0]
72 | 
73 |     print(rec_list)
74 |     rec_list.to_csv(output, index=False)
75 | 
76 | 
77 | if __name__ == '__main__':
78 |     run()
--------------------------------------------------------------------------------
/report.txt:
--------------------------------------------------------------------------------
 1 | TensorFlow 2.0 Upgrade Script
 2 | -----------------------------
 3 | Converted 1 files
 4 | Detected 0 issues that require attention
 5 | --------------------------------------------------------------------------------
 6 | ================================================================================
 7 | Detailed log follows:
 8 | 
 9 | ================================================================================
10 | --------------------------------------------------------------------------------
11 | Processing file 'train.py'
12 | outputting to 'train_v2.py'
13 | --------------------------------------------------------------------------------
14 | 
15 | 
16 | --------------------------------------------------------------------------------
17 | 
18 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
 1 | import click
 2 | import pandas as pd
 3 | import math
 4 | import numpy as np
 5 | import json
 6 | import mlflow
 7 | import mlflow.keras
 8 | from util import *
 9 | from contextlib import redirect_stdout
10 | 
11 | import tensorflow as tf
12 | from tensorflow import keras
13 | 
14 | from evaluation.model_evaluator import *
15 | from model.CDAEModel import *
16 | from model.AutoEncModel import *
17 | from model.AutoEncContentModel import *
18 | 
19 | # Const
20 | TRAIN_HIST_PATH = './artefacts/train_hist.png'
21 | TRAIN_HIST_LOG_PATH = './artefacts/train_hist.json'
22 | MODEL_SUMMARY_PATH = './artefacts/model_summary.txt'
23 | IMG_MODEL_SUMMARY_PATH = './artefacts/model_summary.png'
24 | METRICS_LOG_PATH = './artefacts/metrics.png'
25 | SCORE_VALUES_PATH = './artefacts/score_plot.png'
26 | 
27 | # ----------------------------------------------
28 | # main
29 | # ----------------------------------------------
30 | @click.command(help="Autoencoder Matrix Factorization Model")
31 | @click.option("--name", type=click.Choice(['auto_enc', 'cdae', 'auto_enc_content']))
32 | @click.option("--factors", type=click.INT, default=10)
33 | @click.option("--layers", type=click.STRING, default='[128,256,128]')
34 | @click.option("--epochs", type=click.INT, default=10)
35 | @click.option("--batch", type=click.INT, default=64)
36 | @click.option("--activation", type=click.Choice(['relu', 'elu', 'selu', 'sigmoid']))
37 | @click.option("--dropout", type=click.FLOAT, default=0.6)
38 | @click.option("--lr", type=click.FLOAT, default=0.001)
39 | @click.option("--reg", type=click.FLOAT, default=0.001)
40 | def run(name, factors, layers, epochs, batch,
        activation, dropout, lr, reg):
41 | 
42 |     # Load Dataset
43 |     articles_df, interactions_full_df, \
44 |         interactions_train_df, interactions_test_df, \
45 |         cf_preds_df = load_dataset()
46 | 
47 |     print('# interactions on Train set: %d' % len(interactions_train_df))
48 |     print('# interactions on Test set: %d' % len(interactions_test_df))
49 | 
50 |     # Creating a sparse pivot table with users in rows and items in columns
51 |     users_items_matrix_df = interactions_train_df.pivot(index='user_id',
52 |                                                         columns='content_id',
53 |                                                         values='view').fillna(0)
54 |     # Data
55 |     users_items_matrix = users_items_matrix_df.values
56 |     users_ids = list(users_items_matrix_df.index)
57 | 
58 |     if name == 'cdae':
59 |         model = CDAEModel(factors, epochs, batch, activation, dropout, lr, reg)
60 |     elif name == 'auto_enc':
61 |         model = AutoEncModel(layers, epochs, batch, activation, dropout, lr, reg)
62 |     elif name == 'auto_enc_content':
63 |         model = AutoEncContentModel(layers, epochs, batch, activation, dropout, lr, reg)
64 | 
65 |     # ---------------------------------------------
66 |     # Input - Prepare input layer
67 |     X, y = model.data_preparation(interactions_train_df, users_items_matrix_df)
68 | 
69 |     # Train
70 |     k_model, hist = model.fit(X, y)
71 | 
72 |     # Predict
73 |     pred_score = model.predict(X)
74 | 
75 |     # converting the reconstructed matrix back to a Pandas dataframe
76 |     cf_preds_df = pd.DataFrame(pred_score,
77 |                                columns=users_items_matrix_df.columns,
78 |                                index=users_ids).transpose()
79 | 
80 |     # Plot the distribution of predicted scores (on a 10% sample)
81 |     plot_scores_values(cf_preds_df.sample(frac=0.1).values.reshape(-1), SCORE_VALUES_PATH)
82 | 
83 |     print("Sample Scores")
84 |     print(cf_preds_df.iloc[0].values[:10])
85 |     print(cf_preds_df.iloc[1].values[:10])
86 | 
87 |     # Evaluation Model
88 |     cf_recommender_model = CFRecommender(cf_preds_df, articles_df)
89 |     model_evaluator = ModelEvaluator(articles_df, interactions_full_df,
90 |                                      interactions_train_df, interactions_test_df)
91 | 
92 |     # Save model summary (text and architecture plot)
93 |     print_model_summary(k_model)
94 | 
95 |     # Save training history (JSON log and loss plot)
96 |     print_hist_log(hist)
97 | 
98 |     # Evaluation
99 |     print('Evaluating Collaborative Filtering model...')
100 |     metrics, detailed_metrics = model_evaluator.evaluate_model(cf_recommender_model)
101 |     print('\nGlobal metrics:\n%s' % metrics)
102 | 
103 |     # Plot Metrics
104 |     plot_metrics_disc(metrics, METRICS_LOG_PATH)
105 | 
106 |     # Tracking
107 |     with mlflow.start_run():
108 |         # metrics
109 |         for metric in ModelEvaluator.METRICS:
110 |             mlflow.log_metric(metric, metrics[metric])
111 | 
112 |         # artefacts
113 |         mlflow.log_artifact(TRAIN_HIST_PATH, "history")
114 |         mlflow.log_artifact(TRAIN_HIST_LOG_PATH, "history")
115 |         mlflow.log_artifact(MODEL_SUMMARY_PATH)
116 |         mlflow.log_artifact(IMG_MODEL_SUMMARY_PATH)
117 |         mlflow.log_artifact(METRICS_LOG_PATH, "evaluation")
118 |         mlflow.log_artifact(SCORE_VALUES_PATH, "evaluation")
119 | 
120 |         # model
121 |         mlflow.keras.log_model(k_model, name)
122 | 
123 | def print_model_summary(model):
124 |     # Save model summary (summary() prints directly; wrapping it in print() also prints its None return value)
125 |     model.summary()
126 |     keras.utils.plot_model(model, to_file=IMG_MODEL_SUMMARY_PATH, show_shapes=True)
127 | 
128 |     with open(MODEL_SUMMARY_PATH, 'w') as f:
129 |         with redirect_stdout(f):
130 |             model.summary()
131 | 
132 | def print_hist_log(hist):
133 |     # save history as JSON
134 |     with open(TRAIN_HIST_LOG_PATH, 'w') as f:
135 |         json.dump(hist.history, f)
136 | 
137 |     # save loss curve image
138 |     plot_hist(hist).savefig(TRAIN_HIST_PATH)
139 | 
140 | 
141 | if __name__ == '__main__':
142 |     run()
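143 | 
144 | # Example invocation (a sketch; these flags mirror the click options above and
145 | # the hyperparameters used in train_all.sh):
146 | #   python train.py --name auto_enc --layers '[512,256,512]' --epochs 50 \
147 | #       --batch 64 --activation selu --dropout 0.8 --lr 0.0001 --reg 0.01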
--------------------------------------------------------------------------------
/train_all.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | # 1. Popularity Model
 4 | mlflow run . -e popularity_train
 5 | 
 6 | # 2. CDAE - Collaborative Denoising Auto-Encoders for Top-N Recommender Systems
 7 | mlflow run . -P activation=selu -P batch=64 -P dropout=0.8 -P epochs=50 -P factors=500 -P lr=0.0001 -P name=cdae -P reg=0.0001
 8 | 
 9 | # 3. Deep AutoEncoder for Collaborative Filtering
10 | mlflow run . -P activation=selu -P batch=64 -P dropout=0.8 -P epochs=50 -P layers='[512,256,512]' -P lr=0.0001 -P name=auto_enc -P reg=0.01
11 | 
12 | # 4. Deep AutoEncoder for Collaborative Filtering With Content Information
13 | mlflow run . -P activation=selu -P batch=64 -P dropout=0.8 -P epochs=50 -P layers='[512,256,512]' -P lr=0.0001 -P name=auto_enc_content -P reg=0.01
14 | 
15 | 
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
 1 | import pandas as pd
 2 | import matplotlib.pyplot as plt
 3 | import math
 4 | import numpy as np
 5 | import seaborn as sns
 6 | 
 7 | def load_dataset(base_path='./data/', with_cartesian=False):
 8 |     articles_df = pd.read_csv(base_path+'articles_df.csv')
 9 |     interactions_full_df = pd.read_csv(base_path+'interactions_full_df.csv')
10 |     interactions_train_df = pd.read_csv(base_path+'interactions_train_df.csv')
11 |     interactions_test_df = pd.read_csv(base_path+'interactions_test_df.csv')
12 | 
13 |     if with_cartesian:
14 |         cartesian_product_df = pd.read_csv(base_path+'cartesian_product_df.csv')
15 |     else:
16 |         cartesian_product_df = None
17 | 
18 |     return articles_df, interactions_full_df, interactions_train_df, interactions_test_df, cartesian_product_df
19 | 
20 | def export_figure_matplotlib(arr, f_name, dpi=200, resize_fact=1, plt_show=False):
21 |     """
22 |     Export array as figure in original resolution
23 |     :param arr: array of image to save in original resolution
24 |     :param f_name: name of file where to save figure
25 |     :param resize_fact: resize factor w.r.t. shape of arr, in (0, np.infty)
26 |     :param dpi: dpi of your screen
27 |     :param plt_show: show plot or not
28 |     """
29 |     fig = plt.figure(frameon=False)
30 |     fig.set_size_inches(arr.shape[1]/dpi, arr.shape[0]/dpi)
31 |     ax = plt.Axes(fig, [0., 0., 1., 1.])
32 |     ax.set_axis_off()
33 |     fig.add_axes(ax)
34 |     ax.imshow(arr, cmap='hot')
35 |     plt.savefig(f_name, dpi=(dpi * resize_fact))
36 |     if plt_show:
37 |         plt.show()
38 |     else:
39 |         plt.close()
40 | 
41 | 
42 | def plot_scores_values(values, f_name, plt_show=False):
43 | 
44 |     fig = plt.figure(figsize=(10, 4))
45 | 
46 |     #ax = sns.boxplot(x=values)
47 |     ax = sns.distplot(values)
48 |     #ax = sns.swarmplot(x=values, color=".25")
49 |     plt.savefig(f_name, dpi=(200))
50 |     if plt_show:
51 |         plt.show()
52 |     else:
53 |         plt.close()
54 | 
55 | def plot_metrics_disc(metrics, f_name, plt_show=False):
56 |     '''
57 |     Plot metrics on a radar (polar) chart
58 |     '''
59 |     labels = list(metrics.keys())
60 |     stats = list(metrics.values())
61 | 
62 |     angles = np.linspace(0, 2*np.pi, len(labels), endpoint=False)
63 |     # repeat the first point to close the polygon
64 |     stats = np.concatenate((stats, [stats[0]]))
65 |     angles = np.concatenate((angles, [angles[0]]))
66 | 
67 |     fig = plt.figure(figsize=(6, 6))
68 |     ax = fig.add_subplot(111, polar=True)
69 |     ax.plot(angles, stats, linewidth=1, linestyle='solid')
70 |     ax.fill(angles, stats, 'b', alpha=0.1)
71 |     ax.set_thetagrids(angles[:-1] * 180/np.pi, labels)  # drop the duplicated closing angle so grid positions match the labels
72 |     plt.yticks([0.25, 0.5, 0.75], ["0.25", "0.5", "0.75"], color="grey", size=7)
73 |     plt.ylim(0, 1)
74 | 
75 |     plt.savefig(f_name, dpi=(200))
76 |     if plt_show:
77 |         plt.show()
78 |     else:
79 |         plt.close()
80 | 
81 | def plot_hist(hist):
82 |     # summarize history for loss
83 |     fig, ax = plt.subplots()  # create figure & 1 axis
84 | 
85 |     plt.plot(hist.history['loss'])
86 |     plt.plot(hist.history['val_loss'])
87 |     plt.title('model loss')
88 |     plt.ylabel('loss')
89 |     plt.xlabel('epoch')
90 |     plt.legend(['train', 'validation'], loc='upper left')
91 |     return fig
92 | 
93 | def smooth_user_preference(x):
94 |     return math.log(1+x, 2)
95 | 
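96 | # Example (a sketch of how smooth_user_preference behaves): it maps a raw
97 | # interaction strength x to log2(1+x), so smooth_user_preference(1) == 1.0
98 | # and smooth_user_preference(3) == 2.0, damping large values such as hours
99 | # played so heavy players do not dominate the implicit-feedback signal.
--------------------------------------------------------------------------------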