├── .gitignore ├── LICENSE ├── README.md ├── data └── ecomm │ └── OnlineRetail_sessions.pkl ├── notebooks ├── Analyze_HPO_results.ipynb └── Explore_Online_Retail_Dataset.ipynb ├── recsys ├── __init__.py ├── data.py ├── metrics.py ├── models.py └── utils.py ├── requirements.txt ├── requirements3.6.txt ├── scripts ├── baseline_analysis.py ├── setup_ray_cluster.py ├── train_w2v_with_logging.py └── tune_w2v_with_ray.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # data directory 141 | #data/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Session-based Recommender Systems 2 | 3 | This repo accompanies the Cloudera Fast Forward report [Session-based Recommender Systems](https://session-based-recommenders.fastforwardlabs.com/). It provides small library to train Word2Vec as a means of learning product or item representations in the context of user sessions (browsing histories, transaction histories, music playlists, etc.). These dense representations can then be used for item recommendation. We formulate this under the Next Event Prediction task, that is, given a user's recent interaction, predict the next item they interact with (click on, purchase, listen to, etc.). 4 | 5 | Instructions are given both for general use (on a laptop, say), and for Cloudera CML and CDSW. We'll first describe what's here, then go through how to run everything. 
6 | 7 | ## Structure 8 | ``` 9 | . 10 | ├── data # This folder contains starter data. 11 | ├── scripts # This contains scripts for *doing* things -- training models, analysing results. 12 | ├── notebooks # This contains Jupyter notebooks that accompany the report and demonstrate basic usage. 13 | └── recsys # A small library of useful functions. 14 | ``` 15 | Let's examine each of the important folders in turn. 16 | 17 | 18 | ### `recsys` 19 | ``` 20 | ├── data.py # Contains functions for loading and processing data into sessions 21 | ├── metrics.py # Contains metrics for evaluation 22 | ├── models.py # Contains wrappers for training Word2Vec both alone and with Ray Tune 23 | └── utils.py # Helper functions for serialization and I/O 24 | ``` 25 | 26 | 27 | ### `scripts` 28 | ``` 29 | ├── baseline_analysis.py 30 | ├── setup_ray_cluster.py 31 | ├── train_w2v_with_logging.py 32 | └── tune_w2v_with_ray.py 33 | ``` 34 | An overview of what each of these scripts does is discussed below. 35 | 36 | ### `notebooks` 37 | ``` 38 | ├── Analyze_HPO_results.ipynb 39 | └── Explore_Online_Retail_Dataset.ipynb 40 | ``` 41 | These notebooks provide additional exploration and analysis. Please note that `Analyze_HPO_results.ipynb` is expressly for demonstration purposes as HPO output results explored within are not included in this repo. 42 | 43 | ## Learning representations for session-based recommendations 44 | To go from a fresh clone of the repo to the final state, follow these instructions in order. 45 | 46 | ### Installation 47 | The code and applications within were developed against Python 3.8.8, and are likely also to function with more recent versions of Python. 48 | 49 | To install dependencies, first create and activate a new virtual environment through your preferred means, then pip install from the requirements file. 
I recommend: 50 | 51 | ``` 52 | python3 -m venv .venv 53 | source .venv/bin/activate 54 | pip install -r requirements.txt 55 | ``` 56 | 57 | In CML or CDSW, no virtual env is necessary. Instead, inside a Python 3 session (with at least 2 vCPU / 4 GiB Memory), simply run 58 | 59 | ``` 60 | !pip3 install -r requirements.txt # notice `pip3`, not `pip` 61 | ``` 62 | 63 | Note: if your session has an older Python image (3.6) use the alternative `requirements3.6.txt`: 64 | ``` 65 | !pip3 install -r requirements3.6.txt 66 | ``` 67 | 68 | ### Data 69 | 70 | While we explored several datasets (and code exists in `recsys/data.py` to interact with those datasets), the analysis in this repo is focused on the [Online Retail](https://www.kaggle.com/vijayuv/onlineretail) dataset. This dataset is open source though you will need to create an account on Kaggle before downloading the data. In this repo we include a version of this dataset post-processed into customer sessions. These sessions represent all customer transactions from a UK-based online boutique selling specialty gifts collected between 12/01/2010 and 12/09/2011. In total there are purchase histories for 4,372 customers and 3,684 unique products. 71 | 72 | ### Model training and analysis 73 | 74 | The `scripts` directory contains scripts to train models in various formats and analyze results. Here we provide a high-level overview: 75 | 76 | * `scripts/baseline_analysis.py`: a common baseline for recommendation systems is to simply recommend the most popular items. This script computes the "Association Rules" baseline which considers how frequently each item co-occurrs with all other items in a session for each session in the training set. 77 | * `scripts/train_w2v_with_logging.py`: This script trains Gensim's implementation of the Word2Vec algorithm to learn representations for each item in a session. Identifying "similar" items then serves as the method for generating recommendations. 
Includes callbacks for monitoring metrics (Recall@K, training loss) as a function of training time (epochs). 78 | * `scripts/tune_w2v_with_ray.py`: The Word2Vec algorithm has a large hyperparameter space and the default values are subpar for the task of generating good item representations for recommendation systems. This scripts performs hyperparameter optimization (HPO) with [Ray Tune](https://docs.ray.io/en/master/tune/index.html). 79 | * [CDSW/CML only] `setup_ray_cluster.py`: Hyperparameter optimization can be computationally expensive but this expense can be mitigated, in part, through distribution. This script initializes (and tears down) a Ray Cluster for distributed hyperparameter optimization. If using, follow the instructions in this script to setup the cluster, then run `tune_w2v_with_ray.py` with the appropriate arguments, and finally shutdown the cluster after HPO is complete. 80 | 81 | 82 | These scripts are not intended to be run in any particular order (with the exception noted above). Instead, they provide functionality for different use cases. 
To run scripts, follow this procedure in the terminal or a Session with at least 1vCPUs and 2GiBs of memory: 83 | 84 | ``` 85 | !python3 scripts/baseline_analysis.py 86 | !python3 scripts/train_w2v_with_logging.py # see optional arguments 87 | !python3 scriptstune_w2v_with_ray.py # see optional arguments for distributed HPO 88 | ``` 89 | 90 | -------------------------------------------------------------------------------- /data/ecomm/OnlineRetail_sessions.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fastforwardlabs/session_based_recommenders/c438dd1334fcefc6bedea69b0cd67f779a5de5d3/data/ecomm/OnlineRetail_sessions.pkl -------------------------------------------------------------------------------- /notebooks/Explore_Online_Retail_Dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# The \"Online Retail\" Dataset \n", 8 | "\n", 9 | "This is a transnational data set which contains all purchase transactions occurring between 01/12/2010 and 09/12/2011 for a UK-based and registered non-store online retail. The company mainly sells unique all-occasion gifts. Many customers of the company are wholesalers. \n", 10 | "\n", 11 | "The dataset is composed of the following columns:\n", 12 | "\n", 13 | "\n", 14 | "* **InvoiceNo**: Invoice number. Nominal, a 6-digit integral number uniquely assigned to each transaction. If this code starts with letter 'c', it indicates a cancellation.\n", 15 | "* **StockCode**: Product (item) code. Nominal, a 5-digit integral number uniquely assigned to each distinct product.\n", 16 | "* **Description**: Product (item) name. Nominal.\n", 17 | "* **Quantity**: The quantities of each product (item) per transaction. Numeric.\n", 18 | "* **InvoiceDate**: Invice Date and time. 
Numeric, the day and time when each transaction was generated.\n", 19 | "* **UnitPrice**: Unit price. Numeric, Product price per unit in sterling.\n", 20 | "* **CustomerID**: Customer number. Nominal, a 5-digit integral number uniquely assigned to each customer.\n", 21 | "* **Country**: Country name. Nominal, the name of the country where each customer resides." 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 22, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import pandas as pd\n", 31 | "import numpy as np\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "import seaborn as sns\n", 34 | "\n", 35 | "from recsys.data import load_original_ecomm, preprocess_ecomm, construct_session_sequences " 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 23, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "import seaborn as sns\n", 45 | "plt.style.use(\"seaborn-white\")\n", 46 | "cldr_colors = ['#00b6b5', '#f7955b','#6c8cc7', '#828282']#\n", 47 | "cldr_green = '#a4d65d'\n", 48 | "color_palette = \"viridis\"" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "## Getting the dataset\n", 56 | "\n", 57 | "We obtained the Online Retail dataset from the Kaggle website found [here](https://www.kaggle.com/vijayuv/onlineretail). \n", 58 | "The data is open source but you will need to register on Kaggle's website before downloading. \n", 59 | "\n", 60 | "Once downloaded, we created some helper functions for opening and processing this dataset. " 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 39, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "df = load_original_ecomm()" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 25, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "data": { 79 | "text/html": [ 80 | "
\n", 81 | "\n", 94 | "\n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 
247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | "
InvoiceNoStockCodeDescriptionQuantityInvoiceDateUnitPriceCustomerIDCountry
053636585123AWHITE HANGING HEART T-LIGHT HOLDER62010-12-01 08:26:002.5517850.0United Kingdom
153636571053WHITE METAL LANTERN62010-12-01 08:26:003.3917850.0United Kingdom
253636584406BCREAM CUPID HEARTS COAT HANGER82010-12-01 08:26:002.7517850.0United Kingdom
353636584029GKNITTED UNION FLAG HOT WATER BOTTLE62010-12-01 08:26:003.3917850.0United Kingdom
453636584029ERED WOOLLY HOTTIE WHITE HEART.62010-12-01 08:26:003.3917850.0United Kingdom
553636522752SET 7 BABUSHKA NESTING BOXES22010-12-01 08:26:007.6517850.0United Kingdom
653636521730GLASS STAR FROSTED T-LIGHT HOLDER62010-12-01 08:26:004.2517850.0United Kingdom
753636622633HAND WARMER UNION JACK62010-12-01 08:28:001.8517850.0United Kingdom
853636622632HAND WARMER RED POLKA DOT62010-12-01 08:28:001.8517850.0United Kingdom
953636784879ASSORTED COLOUR BIRD ORNAMENT322010-12-01 08:34:001.6913047.0United Kingdom
1053636722745POPPY'S PLAYHOUSE BEDROOM62010-12-01 08:34:002.1013047.0United Kingdom
1153636722748POPPY'S PLAYHOUSE KITCHEN62010-12-01 08:34:002.1013047.0United Kingdom
1253636722749FELTCRAFT PRINCESS CHARLOTTE DOLL82010-12-01 08:34:003.7513047.0United Kingdom
1353636722310IVORY KNITTED MUG COSY62010-12-01 08:34:001.6513047.0United Kingdom
1453636784969BOX OF 6 ASSORTED COLOUR TEASPOONS62010-12-01 08:34:004.2513047.0United Kingdom
1553636722623BOX OF VINTAGE JIGSAW BLOCKS32010-12-01 08:34:004.9513047.0United Kingdom
1653636722622BOX OF VINTAGE ALPHABET BLOCKS22010-12-01 08:34:009.9513047.0United Kingdom
1753636721754HOME BUILDING BLOCK WORD32010-12-01 08:34:005.9513047.0United Kingdom
1853636721755LOVE BUILDING BLOCK WORD32010-12-01 08:34:005.9513047.0United Kingdom
1953636721777RECIPE BOX WITH METAL HEART42010-12-01 08:34:007.9513047.0United Kingdom
\n", 331 | "
" 332 | ], 333 | "text/plain": [ 334 | " InvoiceNo StockCode Description Quantity \\\n", 335 | "0 536365 85123A WHITE HANGING HEART T-LIGHT HOLDER 6 \n", 336 | "1 536365 71053 WHITE METAL LANTERN 6 \n", 337 | "2 536365 84406B CREAM CUPID HEARTS COAT HANGER 8 \n", 338 | "3 536365 84029G KNITTED UNION FLAG HOT WATER BOTTLE 6 \n", 339 | "4 536365 84029E RED WOOLLY HOTTIE WHITE HEART. 6 \n", 340 | "5 536365 22752 SET 7 BABUSHKA NESTING BOXES 2 \n", 341 | "6 536365 21730 GLASS STAR FROSTED T-LIGHT HOLDER 6 \n", 342 | "7 536366 22633 HAND WARMER UNION JACK 6 \n", 343 | "8 536366 22632 HAND WARMER RED POLKA DOT 6 \n", 344 | "9 536367 84879 ASSORTED COLOUR BIRD ORNAMENT 32 \n", 345 | "10 536367 22745 POPPY'S PLAYHOUSE BEDROOM 6 \n", 346 | "11 536367 22748 POPPY'S PLAYHOUSE KITCHEN 6 \n", 347 | "12 536367 22749 FELTCRAFT PRINCESS CHARLOTTE DOLL 8 \n", 348 | "13 536367 22310 IVORY KNITTED MUG COSY 6 \n", 349 | "14 536367 84969 BOX OF 6 ASSORTED COLOUR TEASPOONS 6 \n", 350 | "15 536367 22623 BOX OF VINTAGE JIGSAW BLOCKS 3 \n", 351 | "16 536367 22622 BOX OF VINTAGE ALPHABET BLOCKS 2 \n", 352 | "17 536367 21754 HOME BUILDING BLOCK WORD 3 \n", 353 | "18 536367 21755 LOVE BUILDING BLOCK WORD 3 \n", 354 | "19 536367 21777 RECIPE BOX WITH METAL HEART 4 \n", 355 | "\n", 356 | " InvoiceDate UnitPrice CustomerID Country \n", 357 | "0 2010-12-01 08:26:00 2.55 17850.0 United Kingdom \n", 358 | "1 2010-12-01 08:26:00 3.39 17850.0 United Kingdom \n", 359 | "2 2010-12-01 08:26:00 2.75 17850.0 United Kingdom \n", 360 | "3 2010-12-01 08:26:00 3.39 17850.0 United Kingdom \n", 361 | "4 2010-12-01 08:26:00 3.39 17850.0 United Kingdom \n", 362 | "5 2010-12-01 08:26:00 7.65 17850.0 United Kingdom \n", 363 | "6 2010-12-01 08:26:00 4.25 17850.0 United Kingdom \n", 364 | "7 2010-12-01 08:28:00 1.85 17850.0 United Kingdom \n", 365 | "8 2010-12-01 08:28:00 1.85 17850.0 United Kingdom \n", 366 | "9 2010-12-01 08:34:00 1.69 13047.0 United Kingdom \n", 367 | "10 2010-12-01 08:34:00 2.10 13047.0 United 
Kingdom \n", 368 | "11 2010-12-01 08:34:00 2.10 13047.0 United Kingdom \n", 369 | "12 2010-12-01 08:34:00 3.75 13047.0 United Kingdom \n", 370 | "13 2010-12-01 08:34:00 1.65 13047.0 United Kingdom \n", 371 | "14 2010-12-01 08:34:00 4.25 13047.0 United Kingdom \n", 372 | "15 2010-12-01 08:34:00 4.95 13047.0 United Kingdom \n", 373 | "16 2010-12-01 08:34:00 9.95 13047.0 United Kingdom \n", 374 | "17 2010-12-01 08:34:00 5.95 13047.0 United Kingdom \n", 375 | "18 2010-12-01 08:34:00 5.95 13047.0 United Kingdom \n", 376 | "19 2010-12-01 08:34:00 7.95 13047.0 United Kingdom " 377 | ] 378 | }, 379 | "execution_count": 25, 380 | "metadata": {}, 381 | "output_type": "execute_result" 382 | } 383 | ], 384 | "source": [ 385 | "df.head(20)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": {}, 391 | "source": [ 392 | "These purchase histories record transactions for each customer and detail the items that were purchased in each transaction. Each transaction has a unique InvoiceNo, a time stamp (InvoiceDate) and the CustomerID of the purchaser. \n", 393 | "\n", 394 | "## Preprocessing\n", 395 | "\n", 396 | "There are some rows with missing information, so we'll filter those out. Since we want to define customer sessions, we'll use group by CustomerID field and filter out any customer entries that have fewer than three purchased items." 
397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": 47, 402 | "metadata": {}, 403 | "outputs": [ 404 | { 405 | "data": { 406 | "text/plain": [ 407 | "InvoiceNo 0\n", 408 | "StockCode 0\n", 409 | "Description 1454\n", 410 | "Quantity 0\n", 411 | "InvoiceDate 0\n", 412 | "UnitPrice 0\n", 413 | "CustomerID 135080\n", 414 | "Country 0\n", 415 | "dtype: int64" 416 | ] 417 | }, 418 | "execution_count": 47, 419 | "metadata": {}, 420 | "output_type": "execute_result" 421 | } 422 | ], 423 | "source": [ 424 | "df.isnull().sum()" 425 | ] 426 | }, 427 | { 428 | "cell_type": "code", 429 | "execution_count": 48, 430 | "metadata": {}, 431 | "outputs": [], 432 | "source": [ 433 | "df.dropna(inplace=True)" 434 | ] 435 | }, 436 | { 437 | "cell_type": "code", 438 | "execution_count": 49, 439 | "metadata": {}, 440 | "outputs": [], 441 | "source": [ 442 | "# filter out sessions that have fewer than 3 items\n", 443 | "item_counts = df.groupby([\"CustomerID\"]).count()[\"StockCode\"]\n", 444 | "df = df[\n", 445 | " df[\"CustomerID\"].isin(item_counts[item_counts >= 3].index)\n", 446 | "].reset_index(drop=True)" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 50, 452 | "metadata": {}, 453 | "outputs": [ 454 | { 455 | "data": { 456 | "text/plain": [ 457 | "CustomerID\n", 458 | "12346.0 2\n", 459 | "12347.0 182\n", 460 | "12348.0 31\n", 461 | "12349.0 73\n", 462 | "12350.0 17\n", 463 | " ... 
\n", 464 | "18280.0 10\n", 465 | "18281.0 7\n", 466 | "18282.0 13\n", 467 | "18283.0 756\n", 468 | "18287.0 70\n", 469 | "Name: StockCode, Length: 4372, dtype: int64" 470 | ] 471 | }, 472 | "execution_count": 50, 473 | "metadata": {}, 474 | "output_type": "execute_result" 475 | } 476 | ], 477 | "source": [ 478 | "item_counts" 479 | ] 480 | }, 481 | { 482 | "cell_type": "markdown", 483 | "metadata": {}, 484 | "source": [ 485 | "## Dataset Statistics" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 29, 491 | "metadata": {}, 492 | "outputs": [ 493 | { 494 | "data": { 495 | "text/plain": [ 496 | "4234" 497 | ] 498 | }, 499 | "execution_count": 29, 500 | "metadata": {}, 501 | "output_type": "execute_result" 502 | } 503 | ], 504 | "source": [ 505 | "# Number of unique customers after preprocessing\n", 506 | "df.CustomerID.nunique()" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 30, 512 | "metadata": {}, 513 | "outputs": [ 514 | { 515 | "data": { 516 | "text/plain": [ 517 | "3684" 518 | ] 519 | }, 520 | "execution_count": 30, 521 | "metadata": {}, 522 | "output_type": "execute_result" 523 | } 524 | ], 525 | "source": [ 526 | "# Number of unique stock codes (products)\n", 527 | "df.StockCode.nunique()" 528 | ] 529 | }, 530 | { 531 | "cell_type": "markdown", 532 | "metadata": {}, 533 | "source": [ 534 | "### Product popularity\n", 535 | "\n", 536 | "Here we plot the frequency by which each product is purchased (occurs in a transaction). " 537 | ] 538 | }, 539 | { 540 | "cell_type": "code", 541 | "execution_count": 31, 542 | "metadata": {}, 543 | "outputs": [ 544 | { 545 | "data": { 546 | "image/png": "\n", 547 | "text/plain": [ 548 | "
" 549 | ] 550 | }, 551 | "metadata": {}, 552 | "output_type": "display_data" 553 | } 554 | ], 555 | "source": [ 556 | "plt.style.use(\"seaborn-white\")\n", 557 | "\n", 558 | "# Number of unique customer IDs\n", 559 | "product_counts = df.groupby(['StockCode']).count()['InvoiceNo'].values\n", 560 | "\n", 561 | "fig = plt.figure(figsize=(8,6))\n", 562 | "plt.yticks(fontsize=14)\n", 563 | "plt.xticks(fontsize=14)\n", 564 | "\n", 565 | "plt.semilogy(sorted(product_counts))\n", 566 | "plt.ylabel(\"Product counts\", fontsize=16);\n", 567 | "plt.xlabel(\"Product index\", fontsize=16);\n", 568 | "\n", 569 | "plt.tight_layout()" 570 | ] 571 | }, 572 | { 573 | "cell_type": "markdown", 574 | "metadata": {}, 575 | "source": [ 576 | "The left side of the figure corresponds to products that are not very popular (because they aren't purchased very often), while the far right side indicates that some products are *extremely* popular and have been purchased hundreds of times. \n", 577 | "\n", 578 | "### Customer session lengths \n", 579 | "\n", 580 | "We define a customer's \"session\" as all the products they purchased in each transaction, in the order in which they were purchased (ordered InvoiceDate). We can then examine statistics regarding the length of these sessions. Below is a boxplot of all customer session lengths. " 581 | ] 582 | }, 583 | { 584 | "cell_type": "code", 585 | "execution_count": 78, 586 | "metadata": {}, 587 | "outputs": [ 588 | { 589 | "data": { 590 | "image/png": "\n", 591 | "text/plain": [ 592 | "
" 593 | ] 594 | }, 595 | "metadata": {}, 596 | "output_type": "display_data" 597 | } 598 | ], 599 | "source": [ 600 | "session_lengths = df.groupby(\"CustomerID\").count()['InvoiceNo'].values\n", 601 | "\n", 602 | "fig = plt.figure(figsize=(8,6))\n", 603 | "plt.xticks(fontsize=14)\n", 604 | "\n", 605 | "ax = sns.boxplot(x=session_lengths, color=cldr_colors[2])\n", 606 | "\n", 607 | "for patch in ax.artists:\n", 608 | " r, g, b, a = patch.get_facecolor()\n", 609 | " patch.set_facecolor((r, g, b, .7))\n", 610 | " \n", 611 | "plt.xlim(0,600)\n", 612 | "plt.xlabel(\"Session length (# of products purchased)\", fontsize=16);\n", 613 | "\n", 614 | "plt.tight_layout()\n", 615 | "plt.savefig(\"../../recommendations/docs/images/session_lengths.png\", transparent=True, dpi=150)" 616 | ] 617 | }, 618 | { 619 | "cell_type": "code", 620 | "execution_count": 57, 621 | "metadata": {}, 622 | "outputs": [ 623 | { 624 | "name": "stdout", 625 | "output_type": "stream", 626 | "text": [ 627 | "Minimum session length: \t 3\n", 628 | "Maximum session length: \t 7983\n", 629 | "Mean session length: \t \t 96.03967879074162\n", 630 | "Median session length: \t \t 44.0\n", 631 | "Total number of purchases: \t 406632\n" 632 | ] 633 | } 634 | ], 635 | "source": [ 636 | "print(\"Minimum session length: \\t\", min(session_lengths))\n", 637 | "print(\"Maximum session length: \\t\", max(session_lengths))\n", 638 | "print(\"Mean session length: \\t \\t\", np.mean(session_lengths))\n", 639 | "print(\"Median session length: \\t \\t\", np.median(session_lengths))\n", 640 | "print(\"Total number of purchases: \\t\", np.sum(session_lengths))" 641 | ] 642 | }, 643 | { 644 | "cell_type": "markdown", 645 | "metadata": {}, 646 | "source": [ 647 | "## Misc " 648 | ] 649 | }, 650 | { 651 | "cell_type": "code", 652 | "execution_count": 63, 653 | "metadata": {}, 654 | "outputs": [ 655 | { 656 | "data": { 657 | "text/plain": [ 658 | "Timedelta('69 days 23:11:00')" 659 | ] 660 | }, 661 | "execution_count": 63, 662 
| "metadata": {}, 663 | "output_type": "execute_result" 664 | } 665 | ], 666 | "source": [ 667 | "customer_grps.get_group(custIDs[0])['InvoiceDate'].diff().max()" 668 | ] 669 | }, 670 | { 671 | "cell_type": "code", 672 | "execution_count": 66, 673 | "metadata": {}, 674 | "outputs": [ 675 | { 676 | "data": { 677 | "text/plain": [ 678 | "312" 679 | ] 680 | }, 681 | "execution_count": 66, 682 | "metadata": {}, 683 | "output_type": "execute_result" 684 | } 685 | ], 686 | "source": [ 687 | "len(customer_grps.get_group(custIDs[0])['StockCode'])" 688 | ] 689 | }, 690 | { 691 | "cell_type": "code", 692 | "execution_count": 66, 693 | "metadata": {}, 694 | "outputs": [], 695 | "source": [ 696 | "customer_grps = df.groupby(\"CustomerID\")\n", 697 | "max_time_diff = []\n", 698 | "\n", 699 | "for custID, history in customer_grps:\n", 700 | " max_time_diff.append(history['InvoiceDate'].diff().max().days)" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": 72, 706 | "metadata": {}, 707 | "outputs": [ 708 | { 709 | "data": { 710 | "text/plain": [ 711 | "" 712 | ] 713 | }, 714 | "execution_count": 72, 715 | "metadata": {}, 716 | "output_type": "execute_result" 717 | }, 718 | { 719 | "data": { 720 | "image/png": "\n", 721 | "text/plain": [ 722 | "
" 723 | ] 724 | }, 725 | "metadata": {}, 726 | "output_type": "display_data" 727 | } 728 | ], 729 | "source": [ 730 | "max_time_diff\n", 731 | "\n", 732 | "plt.scatter(np.arange(len(max_time_diff)), sorted(max_time_diff))" 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": 74, 738 | "metadata": {}, 739 | "outputs": [ 740 | { 741 | "data": { 742 | "text/html": [ 743 | "
\n", 744 | "\n", 757 | "\n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | "
InvoiceNoDescriptionInvoiceDate
10652537626BLACK CANDELABRA T-LIGHT HOLDER2010-12-07 14:57:00
10653537626AIRLINE BAG VINTAGE JET SET BROWN2010-12-07 14:57:00
10654537626COLOUR GLASS. STAR T-LIGHT HOLDER2010-12-07 14:57:00
10655537626MINI PAINT SET VINTAGE2010-12-07 14:57:00
10656537626CLEAR DRAWER KNOB ACRYLIC EDWARDIAN2010-12-07 14:57:00
............
403109581180WOODLAND CHARLOTTE BAG2011-12-07 15:52:00
403110581180PINK GOOSE FEATHER TREE 60CM2011-12-07 15:52:00
403111581180CHRISTMAS TABLE SILVER CANDLE SPIKE2011-12-07 15:52:00
403112581180MINI PLAYING CARDS SPACEBOY2011-12-07 15:52:00
403113581180MINI PLAYING CARDS DOLLY GIRL2011-12-07 15:52:00
\n", 835 | "

182 rows × 3 columns

\n", 836 | "
" 837 | ], 838 | "text/plain": [ 839 | " InvoiceNo Description InvoiceDate\n", 840 | "10652 537626 BLACK CANDELABRA T-LIGHT HOLDER 2010-12-07 14:57:00\n", 841 | "10653 537626 AIRLINE BAG VINTAGE JET SET BROWN 2010-12-07 14:57:00\n", 842 | "10654 537626 COLOUR GLASS. STAR T-LIGHT HOLDER 2010-12-07 14:57:00\n", 843 | "10655 537626 MINI PAINT SET VINTAGE 2010-12-07 14:57:00\n", 844 | "10656 537626 CLEAR DRAWER KNOB ACRYLIC EDWARDIAN 2010-12-07 14:57:00\n", 845 | "... ... ... ...\n", 846 | "403109 581180 WOODLAND CHARLOTTE BAG 2011-12-07 15:52:00\n", 847 | "403110 581180 PINK GOOSE FEATHER TREE 60CM 2011-12-07 15:52:00\n", 848 | "403111 581180 CHRISTMAS TABLE SILVER CANDLE SPIKE 2011-12-07 15:52:00\n", 849 | "403112 581180 MINI PLAYING CARDS SPACEBOY 2011-12-07 15:52:00\n", 850 | "403113 581180 MINI PLAYING CARDS DOLLY GIRL 2011-12-07 15:52:00\n", 851 | "\n", 852 | "[182 rows x 3 columns]" 853 | ] 854 | }, 855 | "execution_count": 74, 856 | "metadata": {}, 857 | "output_type": "execute_result" 858 | } 859 | ], 860 | "source": [ 861 | "test_grp = customer_grps.get_group(12347.0)\n", 862 | "\n", 863 | "test_grp[['InvoiceNo','Description', 'InvoiceDate']]" 864 | ] 865 | }, 866 | { 867 | "cell_type": "code", 868 | "execution_count": 77, 869 | "metadata": { 870 | "scrolled": false 871 | }, 872 | "outputs": [ 873 | { 874 | "name": "stdout", 875 | "output_type": "stream", 876 | "text": [ 877 | "2010-12-07 14:57:00\n", 878 | "10652 BLACK CANDELABRA T-LIGHT HOLDER\n", 879 | "10653 AIRLINE BAG VINTAGE JET SET BROWN\n", 880 | "10654 COLOUR GLASS. 
STAR T-LIGHT HOLDER\n", 881 | "10655 MINI PAINT SET VINTAGE \n", 882 | "10656 CLEAR DRAWER KNOB ACRYLIC EDWARDIAN\n", 883 | "10657 PINK DRAWER KNOB ACRYLIC EDWARDIAN\n", 884 | "10658 GREEN DRAWER KNOB ACRYLIC EDWARDIAN\n", 885 | "10659 RED DRAWER KNOB ACRYLIC EDWARDIAN\n", 886 | "10660 PURPLE DRAWERKNOB ACRYLIC EDWARDIAN\n", 887 | "10661 BLUE DRAWER KNOB ACRYLIC EDWARDIAN\n", 888 | "10662 ALARM CLOCK BAKELIKE CHOCOLATE\n", 889 | "10663 ALARM CLOCK BAKELIKE GREEN\n", 890 | "10664 ALARM CLOCK BAKELIKE RED \n", 891 | "10665 ALARM CLOCK BAKELIKE PINK\n", 892 | "10666 ALARM CLOCK BAKELIKE ORANGE\n", 893 | "10667 FOUR HOOK WHITE LOVEBIRDS\n", 894 | "10668 BLACK GRAND BAROQUE PHOTO FRAME\n", 895 | "10669 BATHROOM METAL SIGN \n", 896 | "10670 LARGE HEART MEASURING SPOONS\n", 897 | "10671 BOX OF 6 ASSORTED COLOUR TEASPOONS\n", 898 | "10672 BLUE 3 PIECE POLKADOT CUTLERY SET\n", 899 | "10673 RED 3 PIECE RETROSPOT CUTLERY SET\n", 900 | "10674 PINK 3 PIECE POLKADOT CUTLERY SET\n", 901 | "10675 EMERGENCY FIRST AID TIN \n", 902 | "10676 SET OF 2 TINS VINTAGE BATHROOM \n", 903 | "10677 SET/3 DECOUPAGE STACKING TINS\n", 904 | "10678 BOOM BOX SPEAKER BOYS\n", 905 | "10679 RED TOADSTOOL LED NIGHT LIGHT\n", 906 | "10680 3D DOG PICTURE PLAYING CARDS\n", 907 | "10681 BLACK EAR MUFF HEADPHONES\n", 908 | "10682 CAMOUFLAGE EAR MUFF HEADPHONES\n", 909 | "Name: Description, dtype: object\n", 910 | "2011-01-26 14:30:00\n", 911 | "44582 PINK NEW BAROQUECANDLESTICK CANDLE\n", 912 | "44583 BLUE NEW BAROQUE CANDLESTICK CANDLE\n", 913 | "44584 BLACK CANDELABRA T-LIGHT HOLDER\n", 914 | "44585 WOODLAND CHARLOTTE BAG\n", 915 | "44586 AIRLINE BAG VINTAGE JET SET BROWN\n", 916 | "44587 AIRLINE BAG VINTAGE JET SET WHITE\n", 917 | "44588 SANDWICH BATH SPONGE\n", 918 | "44589 ALARM CLOCK BAKELIKE CHOCOLATE\n", 919 | "44590 ALARM CLOCK BAKELIKE GREEN\n", 920 | "44591 ALARM CLOCK BAKELIKE RED \n", 921 | "44592 ALARM CLOCK BAKELIKE PINK\n", 922 | "44593 ALARM CLOCK BAKELIKE ORANGE\n", 923 | "44594 SMALL 
HEART MEASURING SPOONS\n", 924 | "44595 72 SWEETHEART FAIRY CAKE CASES\n", 925 | "44596 60 TEATIME FAIRY CAKE CASES\n", 926 | "44597 PACK OF 60 MUSHROOM CAKE CASES\n", 927 | "44598 PACK OF 60 SPACEBOY CAKE CASES\n", 928 | "44599 TEA TIME OVEN GLOVE\n", 929 | "44600 RED RETROSPOT OVEN GLOVE \n", 930 | "44601 RED RETROSPOT OVEN GLOVE DOUBLE\n", 931 | "44602 SET/2 RED RETROSPOT TEA TOWELS \n", 932 | "44603 REGENCY CAKESTAND 3 TIER\n", 933 | "44604 BOX OF 6 ASSORTED COLOUR TEASPOONS\n", 934 | "44605 MINI LADLE LOVE HEART RED \n", 935 | "44606 CHOCOLATE CALCULATOR\n", 936 | "44607 TOOTHPASTE TUBE PEN\n", 937 | "44608 SET OF 2 TINS VINTAGE BATHROOM \n", 938 | "44609 RED TOADSTOOL LED NIGHT LIGHT\n", 939 | "44610 3D DOG PICTURE PLAYING CARDS\n", 940 | "Name: Description, dtype: object\n", 941 | "2011-04-07 10:43:00\n", 942 | "101930 AIRLINE BAG VINTAGE JET SET WHITE\n", 943 | "101931 AIRLINE BAG VINTAGE JET SET RED\n", 944 | "101932 AIRLINE BAG VINTAGE TOKYO 78\n", 945 | "101933 AIRLINE BAG VINTAGE JET SET BROWN\n", 946 | "101934 RED RETROSPOT PURSE \n", 947 | "101935 ICE CREAM SUNDAE LIP GLOSS\n", 948 | "101936 VINTAGE HEADS AND TAILS CARD GAME \n", 949 | "101937 HOLIDAY FUN LUDO\n", 950 | "101938 TREASURE ISLAND BOOK BOX\n", 951 | "101939 WATERING CAN PINK BUNNY\n", 952 | "101940 RED DRAWER KNOB ACRYLIC EDWARDIAN\n", 953 | "101941 LARGE HEART MEASURING SPOONS\n", 954 | "101942 SMALL HEART MEASURING SPOONS\n", 955 | "101943 PACK OF 60 DINOSAUR CAKE CASES\n", 956 | "101944 RED RETROSPOT OVEN GLOVE DOUBLE\n", 957 | "101945 REGENCY CAKESTAND 3 TIER\n", 958 | "101946 ROSES REGENCY TEACUP AND SAUCER \n", 959 | "101947 RED TOADSTOOL LED NIGHT LIGHT\n", 960 | "101948 MINI PAINT SET VINTAGE \n", 961 | "101949 3D SHEET OF DOG STICKERS\n", 962 | "101950 3D SHEET OF CAT STICKERS\n", 963 | "101951 SMALL FOLDING SCISSOR(POINTED EDGE)\n", 964 | "101952 GIFT BAG PSYCHEDELIC APPLES\n", 965 | "101953 SET OF 2 TINS VINTAGE BATHROOM \n", 966 | "Name: Description, dtype: object\n", 967 | 
"2011-06-09 13:01:00\n", 968 | "157619 RABBIT NIGHT LIGHT\n", 969 | "157620 REGENCY TEA STRAINER\n", 970 | "157621 REGENCY TEA PLATE GREEN \n", 971 | "157622 REGENCY TEA PLATE PINK\n", 972 | "157623 REGENCY TEA PLATE ROSES \n", 973 | "157624 REGENCY TEAPOT ROSES \n", 974 | "157625 REGENCY SUGAR BOWL GREEN\n", 975 | "157626 REGENCY MILK JUG PINK \n", 976 | "157627 AIRLINE BAG VINTAGE TOKYO 78\n", 977 | "157628 AIRLINE BAG VINTAGE JET SET BROWN\n", 978 | "157629 VICTORIAN SEWING KIT\n", 979 | "157630 NAMASTE SWAGAT INCENSE\n", 980 | "157631 TRIPLE HOOK ANTIQUE IVORY ROSE\n", 981 | "157632 SMALL HEART MEASURING SPOONS\n", 982 | "157633 3D DOG PICTURE PLAYING CARDS\n", 983 | "157634 FEATHER PEN,COAL BLACK\n", 984 | "157635 ALARM CLOCK BAKELIKE RED \n", 985 | "157636 ALARM CLOCK BAKELIKE CHOCOLATE\n", 986 | "Name: Description, dtype: object\n", 987 | "2011-08-02 08:48:00\n", 988 | "205271 SET OF 60 VINTAGE LEAF CAKE CASES \n", 989 | "205272 SET 40 HEART SHAPE PETIT FOUR CASES\n", 990 | "205273 AIRLINE BAG VINTAGE JET SET BROWN\n", 991 | "205274 AIRLINE BAG VINTAGE JET SET RED\n", 992 | "205275 AIRLINE BAG VINTAGE JET SET WHITE\n", 993 | "205276 AIRLINE BAG VINTAGE TOKYO 78\n", 994 | "205277 AIRLINE BAG VINTAGE WORLD CHAMPION \n", 995 | "205278 WOODLAND DESIGN COTTON TOTE BAG\n", 996 | "205279 WOODLAND CHARLOTTE BAG\n", 997 | "205280 ALARM CLOCK BAKELIKE RED \n", 998 | "205281 TRIPLE HOOK ANTIQUE IVORY ROSE\n", 999 | "205282 SINGLE ANTIQUE ROSE HOOK IVORY\n", 1000 | "205283 TEA TIME OVEN GLOVE\n", 1001 | "205284 72 SWEETHEART FAIRY CAKE CASES\n", 1002 | "205285 60 TEATIME FAIRY CAKE CASES\n", 1003 | "205286 PACK OF 60 DINOSAUR CAKE CASES\n", 1004 | "205287 REGENCY CAKESTAND 3 TIER\n", 1005 | "205288 REGENCY MILK JUG PINK \n", 1006 | "205289 3D DOG PICTURE PLAYING CARDS\n", 1007 | "205290 REVOLVER WOODEN RULER \n", 1008 | "205291 VINTAGE HEADS AND TAILS CARD GAME \n", 1009 | "205292 RED REFECTORY CLOCK \n", 1010 | "Name: Description, dtype: object\n", 1011 | "2011-10-31 
12:25:00\n", 1012 | "322193 MINI LIGHTS WOODLAND MUSHROOMS\n", 1013 | "322194 PINK GOOSE FEATHER TREE 60CM\n", 1014 | "322195 MADRAS NOTEBOOK MEDIUM\n", 1015 | "322196 AIRLINE BAG VINTAGE WORLD CHAMPION \n", 1016 | "322197 AIRLINE BAG VINTAGE JET SET BROWN\n", 1017 | "322198 AIRLINE BAG VINTAGE TOKYO 78\n", 1018 | "322199 AIRLINE BAG VINTAGE JET SET RED\n", 1019 | "322200 BIRDCAGE DECORATION TEALIGHT HOLDER\n", 1020 | "322201 CHRISTMAS METAL TAGS ASSORTED \n", 1021 | "322202 REGENCY CAKESTAND 3 TIER\n", 1022 | "322203 REGENCY TEAPOT ROSES \n", 1023 | "322204 TEA TIME DES TEA COSY\n", 1024 | "322205 TEA TIME KITCHEN APRON\n", 1025 | "322206 TEA TIME OVEN GLOVE\n", 1026 | "322207 PINK REGENCY TEACUP AND SAUCER\n", 1027 | "322208 GREEN REGENCY TEACUP AND SAUCER\n", 1028 | "322209 3D DOG PICTURE PLAYING CARDS\n", 1029 | "322210 RABBIT NIGHT LIGHT\n", 1030 | "322211 RED TOADSTOOL LED NIGHT LIGHT\n", 1031 | "322212 TREASURE ISLAND BOOK BOX\n", 1032 | "322213 VINTAGE HEADS AND TAILS CARD GAME \n", 1033 | "322214 MINI PLAYING CARDS DOLLY GIRL \n", 1034 | "322215 MINI PLAYING CARDS SPACEBOY \n", 1035 | "322216 PLAYING CARDS KEEP CALM & CARRY ON\n", 1036 | "322217 REVOLVER WOODEN RULER \n", 1037 | "322218 WOODEN SCHOOL COLOURING SET\n", 1038 | "322219 MINI PAINT SET VINTAGE \n", 1039 | "322220 TRADITIONAL KNITTING NANCY\n", 1040 | "322221 TRIPLE HOOK ANTIQUE IVORY ROSE\n", 1041 | "322222 PANTRY HOOK SPATULA\n", 1042 | "322223 PANTRY HOOK BALLOON WHISK \n", 1043 | "322224 PANTRY HOOK TEA STRAINER \n", 1044 | "322225 ROSES REGENCY TEACUP AND SAUCER \n", 1045 | "322226 ALARM CLOCK BAKELIKE CHOCOLATE\n", 1046 | "322227 ALARM CLOCK BAKELIKE PINK\n", 1047 | "322228 ALARM CLOCK BAKELIKE GREEN\n", 1048 | "322229 ALARM CLOCK BAKELIKE RED \n", 1049 | "322230 PACK OF 60 MUSHROOM CAKE CASES\n", 1050 | "322231 PACK OF 60 SPACEBOY CAKE CASES\n", 1051 | "322232 SET OF 60 VINTAGE LEAF CAKE CASES \n", 1052 | "322233 60 TEATIME FAIRY CAKE CASES\n", 1053 | "322234 72 SWEETHEART FAIRY CAKE 
CASES\n", 1054 | "322235 SMALL HEART MEASURING SPOONS\n", 1055 | "322236 LARGE HEART MEASURING SPOONS\n", 1056 | "322237 WOODLAND CHARLOTTE BAG\n", 1057 | "322238 REGENCY TEA STRAINER\n", 1058 | "322239 FOOD CONTAINER SET 3 LOVE HEART \n", 1059 | "Name: Description, dtype: object\n", 1060 | "2011-12-07 15:52:00\n", 1061 | "403103 CLASSIC CHROME BICYCLE BELL \n", 1062 | "403104 BICYCLE PUNCTURE REPAIR KIT \n", 1063 | "403105 BOOM BOX SPEAKER BOYS\n", 1064 | "403106 PINK NEW BAROQUECANDLESTICK CANDLE\n", 1065 | "403107 RED TOADSTOOL LED NIGHT LIGHT\n", 1066 | "403108 RABBIT NIGHT LIGHT\n", 1067 | "403109 WOODLAND CHARLOTTE BAG\n", 1068 | "403110 PINK GOOSE FEATHER TREE 60CM\n", 1069 | "403111 CHRISTMAS TABLE SILVER CANDLE SPIKE\n", 1070 | "403112 MINI PLAYING CARDS SPACEBOY \n", 1071 | "403113 MINI PLAYING CARDS DOLLY GIRL \n", 1072 | "Name: Description, dtype: object\n" 1073 | ] 1074 | } 1075 | ], 1076 | "source": [ 1077 | "for invoice_date, grp in test_grp.groupby('InvoiceDate'):\n", 1078 | " print(invoice_date)\n", 1079 | " print(grp['Description'])" 1080 | ] 1081 | }, 1082 | { 1083 | "cell_type": "code", 1084 | "execution_count": null, 1085 | "metadata": {}, 1086 | "outputs": [], 1087 | "source": [] 1088 | } 1089 | ], 1090 | "metadata": { 1091 | "kernelspec": { 1092 | "display_name": "Python 3", 1093 | "language": "python", 1094 | "name": "python3" 1095 | }, 1096 | "language_info": { 1097 | "codemirror_mode": { 1098 | "name": "ipython", 1099 | "version": 3 1100 | }, 1101 | "file_extension": ".py", 1102 | "mimetype": "text/x-python", 1103 | "name": "python", 1104 | "nbconvert_exporter": "python", 1105 | "pygments_lexer": "ipython3", 1106 | "version": "3.8.8" 1107 | } 1108 | }, 1109 | "nbformat": 4, 1110 | "nbformat_minor": 4 1111 | } 1112 | -------------------------------------------------------------------------------- /recsys/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/fastforwardlabs/session_based_recommenders/c438dd1334fcefc6bedea69b0cd67f779a5de5d3/recsys/__init__.py -------------------------------------------------------------------------------- /recsys/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | from numpy.random import default_rng 5 | 6 | from recsys.utils import pickle_load, pickle_save, absolute_filename, create_path 7 | 8 | rng = default_rng(123) 9 | 10 | RECSYS15_PATH = "data/recsys15/" 11 | RECSYS15_FILENAME = "yoochoose-clicks.dat" 12 | 13 | ECOMM_PATH = "data/ecomm/" 14 | ECOMM_FILENAME = "OnlineRetail.csv" 15 | 16 | AOTM_PATH = "data/aotm/" 17 | AOTM_FILENAME = "aotm_list_ids.txt" 18 | AOTM_NUMPYFILENAME = "aotm_sessions.pkl" 19 | 20 | 21 | def load_recsys15(filename=None): 22 | """ 23 | Checks to see if the processed recsys15 session sequence file exists 24 | If True: loads and returns the session sequences 25 | If False: creates and returns the session sequences constructed from the original data file 26 | """ 27 | original_filename = absolute_filename(RECSYS15_PATH, RECSYS15_FILENAME) 28 | if filename is None: 29 | processed_filename = original_filename.replace(".dat", "_sessions.pkl") 30 | if os.path.exists(processed_filename): 31 | return pickle_load(processed_filename) 32 | else: 33 | if os.path.exists(absolute_filename(filename)): 34 | return pickle_load(absolute_filename(filename)) 35 | 36 | df = load_original_recsys15(original_filename) 37 | session_sequences = preprocess_recsys15(df) 38 | return session_sequences 39 | 40 | 41 | def load_original_recsys15(pathname=RECSYS15_PATH): 42 | """ 43 | Reads in the original RecSys15 Challenge "clicks" data file and returns as a Pandas DF 44 | """ 45 | df = pd.read_csv( 46 | absolute_filename(pathname, RECSYS15_FILENAME), 47 | names=["sessionID", "timestamp", "itemID", "category"], 48 | date_parser=["timestamp"], 49 | 
dtype={"category": str, "itemID": str, "sessionID": str}, 50 | ) 51 | return df 52 | 53 | 54 | def preprocess_recsys15(df, min_session_count=3): 55 | """ 56 | Given the recsys15 data in pandas df format, clean and sample to only those 57 | sessions that contain at least min_session_count number of interactions 58 | """ 59 | session_counts = df.groupby(["sessionID"]).count() 60 | df = df[ 61 | df["sessionID"].isin( 62 | session_counts[session_counts["itemID"] >= min_session_count].index 63 | ) 64 | ].reset_index(drop=True) 65 | 66 | # TODO: track preprocessed version by appending the filename with min_session_count 67 | filename = absolute_filename( 68 | RECSYS15_PATH, RECSYS15_FILENAME.replace(".dat", f"_sessions.pkl") 69 | ) 70 | sessions = construct_session_sequences( 71 | df, "sessionID", "itemID", save_filename=filename 72 | ) 73 | return sessions 74 | 75 | 76 | def load_ecomm(filename=None): 77 | """ 78 | Checks to see if the processed Online Retail ecommerce session sequence file exists 79 | If True: loads and returns the session sequences 80 | If False: creates and returns the session sequences constructed from the original data file 81 | """ 82 | original_filename = absolute_filename(ECOMM_PATH, ECOMM_FILENAME) 83 | if filename is None: 84 | processed_filename = original_filename.replace(".csv", "_sessions.pkl") 85 | if os.path.exists(processed_filename): 86 | return pickle_load(processed_filename) 87 | else: 88 | if os.path.exists(absolute_filename(filename)): 89 | return pickle_load(absolute_filename(filename)) 90 | 91 | df = load_original_ecomm(original_filename) 92 | session_sequences = preprocess_ecomm(df) 93 | return session_sequences 94 | 95 | 96 | def load_original_ecomm(pathname=ECOMM_PATH): 97 | df = pd.read_csv( 98 | absolute_filename(pathname, ECOMM_FILENAME), 99 | encoding="ISO-8859-1", 100 | parse_dates=["InvoiceDate"], 101 | ) 102 | return df 103 | 104 | 105 | def preprocess_ecomm(df, min_session_count=3): 106 | df.dropna(inplace=True) 107 | 
item_counts = df.groupby(["CustomerID"]).count()["StockCode"] 108 | df = df[ 109 | df["CustomerID"].isin(item_counts[item_counts >= min_session_count].index) 110 | ].reset_index(drop=True) 111 | 112 | # TODO: track preprocessed version by appending the filename with min_session_count 113 | filename = absolute_filename( 114 | ECOMM_PATH, ECOMM_FILENAME.replace(".csv", "_sessions.pkl") 115 | ) 116 | sessions = construct_session_sequences( 117 | df, "CustomerID", "StockCode", save_filename=filename 118 | ) 119 | return sessions 120 | 121 | 122 | def load_aotm(filename=None): 123 | """ 124 | Checks to see if the processed aotm session sequence file exists 125 | If True: loads and returns the session sequences 126 | If False: creates and returns the session sequences constructed from the original data file 127 | """ 128 | processed_filename = absolute_filename(AOTM_PATH, AOTM_NUMPYFILENAME) 129 | 130 | if os.path.exists(processed_filename): 131 | return pickle_load(processed_filename) 132 | 133 | original_filename = absolute_filename(AOTM_PATH, AOTM_FILENAME) 134 | df = load_original_aotm(original_filename) 135 | session_sequences = preprocess_aotm(df, save_path=AOTM_PATH) 136 | # session_sequences = construct_session_sequences(df, save_path=processed_filename) 137 | return session_sequences 138 | 139 | 140 | def load_original_aotm(pathname=AOTM_PATH): 141 | """ 142 | Reads in the original AOTM file with all 29,164 playlists in numerical format. 143 | Each line defines a playlist in the form #num# artnum: songnum artnum: songnum ... where num is the playlist index 144 | Returns a Pandas DF 145 | """ 146 | df = pd.read_csv( 147 | absolute_filename(pathname, AOTM_FILENAME), 148 | delimiter="# ", 149 | header=None, 150 | names=["list_id", "artists_tracks"], 151 | dtype={"category": int, "artists_tracks": str}, 152 | engine="python", 153 | ) 154 | # some sessions have missing entries... 
155 | df.dropna(inplace=True) 156 | return df 157 | 158 | 159 | def preprocess_aotm(df, min_session_count=3, save_path=None): 160 | """ 161 | Given the aotm data in pandas df format, clean and sample to only those 162 | sessions that contain at least min_session_count number of interactions 163 | """ 164 | 165 | # separate out the artists and tracks within the sessions and use only the tracks for word2vec modeling 166 | artists_tracks = df["artists_tracks"].tolist() 167 | artists_tracks_tokens = [a.split() for a in artists_tracks] 168 | track_tokens = [ 169 | [x for x in token if not x.endswith(":")] for token in artists_tracks_tokens 170 | ] 171 | 172 | # exclude tracks with only one entry, remove if there are tokens with len < min_session_count 173 | track_tokens = [token for token in track_tokens if len(token) >= min_session_count] 174 | 175 | if save_path: 176 | create_path(save_path) 177 | pickle_save( 178 | track_tokens, filename=absolute_filename(save_path, AOTM_NUMPYFILENAME) 179 | ) 180 | return track_tokens 181 | 182 | 183 | def construct_session_sequences(df, sessionID, itemID, save_filename): 184 | """ 185 | Given a dataset in pandas df format, construct a list of lists where each sublist 186 | represents the interactions relevant to a specific session, for each sessionID. 187 | These sublists are composed of a series of itemIDs (str) and are the core training 188 | data used in the Word2Vec algorithm. 189 | 190 | This is performed by first grouping over the SessionID column, then casting to list 191 | each group's series of values in the ItemID column. 
192 | 193 | INPUTS 194 | ------------ 195 | df: pandas dataframe 196 | sessionID: str column name in the df that represents invididual sessions 197 | itemID: str column name in the df that represents the items within a session 198 | save_filename: str output filename 199 | 200 | Example: 201 | Given a df that looks like 202 | 203 | SessionID | ItemID 204 | ---------------------- 205 | 1 | 111 206 | 1 | 123 207 | 1 | 345 208 | 2 | 045 209 | 2 | 334 210 | 2 | 342 211 | 2 | 8970 212 | 2 | 345 213 | 214 | Retrun a list of lists like this: 215 | 216 | sessions = [ 217 | ['111', '123', '345'], 218 | ['045', '334', '342', '8970', '345'], 219 | ] 220 | """ 221 | grp_by_session = df.groupby([sessionID]) 222 | 223 | session_sequences = [] 224 | for name, group in grp_by_session: 225 | session_sequences.append(list(group[itemID].values)) 226 | 227 | filename = absolute_filename(save_filename) 228 | create_path(filename) 229 | pickle_save(session_sequences, filename=save_filename) 230 | return session_sequences 231 | 232 | 233 | def train_test_split(session_sequences, test_size: int = 10000, rng=rng): 234 | """ 235 | Next Event Prediction (NEP) does not necessarily follow the traditional train/test split. 236 | 237 | Instead training is perform on the first n-1 items in a session sequence of n items. 238 | The test set is constructed of (n-1, n) "query" pairs where the n-1 item is used to generate 239 | recommendation predictions and it is checked whether the nth item is included in those recommendations. 240 | 241 | Example: 242 | Given a session sequence ['045', '334', '342', '8970', '128'] 243 | Training is done on ['045', '334', '342', '8970'] 244 | Testing (and validation) is done on ['8970', '128'] 245 | 246 | Test and Validation sets are constructed to be disjoint. 
247 | """ 248 | #np.random.seed(123) 249 | #rng = np.random.default_rng(123) 250 | 251 | ### Construct training set 252 | # use (1 st, ..., n-1 th) items from each session sequence to form the train set (drop last item) 253 | train = [sess[:-1] for sess in session_sequences] 254 | 255 | if test_size > len(train): 256 | print( 257 | f"Test set cannot be larger than train set. Train set contains {len(train)} sessions." 258 | ) 259 | return 260 | 261 | ### Construct test and validation sets 262 | # sub-sample 10k sessions, and use (n-1 th, n th) pairs of items from session_squences to form the 263 | # disjoint validaton and test sets 264 | test_validation = [sess[-2:] for sess in session_sequences] 265 | # TODO: set numpy random seed! NM: added it at the top 266 | index = rng.choice(range(len(test_validation)), test_size * 2, replace=False) 267 | test = np.array(test_validation)[index[:test_size]].tolist() 268 | validation = np.array(test_validation)[index[test_size:]].tolist() 269 | 270 | return train, test, validation 271 | 272 | 273 | #""" 274 | 275 | if __name__ == "__main__": 276 | # load data 277 | sessions = load_ecomm() 278 | 279 | #df = load_original_ecomm() 280 | #sessions = preprocess_ecomm(df) 281 | #print(sessions[0]) 282 | 283 | print(len(sessions)) 284 | #train, test, valid = train_test_split(sessions) 285 | 286 | train, test, valid = train_test_split(sessions, test_size=1000) 287 | #print(train[0]) 288 | print("validation set:", valid[0]) 289 | print("test set", test[0]) 290 | #""" 291 | 292 | -------------------------------------------------------------------------------- /recsys/metrics.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | 4 | 5 | def recall_at_k(test, embeddings, k: int = 10) -> float: 6 | """ 7 | test must be a list of (query, ground truth) pairs 8 | embeddings must be a gensim.word2vec.wv thingy 9 | """ 10 | ratk_score = 0 11 | for query_item, 
def recall_at_k_baseline(test, comatrix, k: int = 10) -> float:
    """
    Recall@K for the co-occurrence (association rules) baseline.

    test must be a list of (query, ground truth) pairs
    comatrix must be a dict-like mapping each item to the collection of items it
        co-occurred with (duplicates encode the co-occurrence counts)

    Returns the fraction of pairs whose ground truth is among the k items that most
    frequently co-occur with the query item. Query items missing from comatrix count
    as misses, and an empty test set scores 0.0.
    """
    if len(test) == 0:
        return 0.0

    hits = 0
    for query_item, ground_truth in test:
        # .get() treats unseen query items as "no recommendations" without raising on a
        # plain dict and without inserting empty entries into a defaultdict
        co_occ = collections.Counter(comatrix.get(query_item, ()))
        # keep only the k most frequently co-occurring items
        recommendations = [item for item, _ in co_occ.most_common(k)]
        if ground_truth in recommendations:
            hits += 1
    return hits / len(test)
def mrr_at_k_baseline(test, comatrix, k: int = 10) -> float:
    """
    Mean Reciprocal Rank at K for the co-occurrence (association rules) baseline.

    test must be a list of (query, ground truth) pairs
    comatrix must be a dict-like mapping each item to the collection of items it
        co-occurred with (duplicates encode the co-occurrence counts)

    Each pair contributes 1 / rank when the ground truth is among the k items that most
    frequently co-occur with the query item (rank 1 = most frequent) and 0 otherwise.
    Query items missing from comatrix contribute 0, and an empty test set scores 0.0.
    """
    if len(test) == 0:
        return 0.0

    reciprocal_rank_sum = 0.0
    for query_item, ground_truth in test:
        # .get() treats unseen query items as "no recommendations" without raising on a
        # plain dict and without inserting empty entries into a defaultdict
        co_occ = collections.Counter(comatrix.get(query_item, ()))
        recommendations = [item for item, _ in co_occ.most_common(k)]
        if ground_truth in recommendations:
            # score higher-ranked ground truth higher than lower-ranked ground truth
            rank = recommendations.index(ground_truth) + 1
            reciprocal_rank_sum += 1 / rank
    return reciprocal_rank_sum / len(test)
def tune_w2v(config):
    """
    Train Word2Vec on one dataset with the hyperparameters in `config`, reporting
    Recall@K on the validation set after every epoch through a RecallAtKLogger
    created with ray_tune=True.

    config must contain
        'dataset': one of 'recsys15', 'aotm' or 'ecomm'
        'k': number of recommendations used when computing Recall@K
    Every other key is passed to Word2Vec as a keyword argument.

    `config` itself is not modified. If the dataset name is not recognised a message
    is printed and the function returns without training.
    """
    # map dataset names to their loader functions
    loaders = {
        'recsys15': load_recsys15,
        'aotm': load_aotm,
        'ecomm': load_ecomm,
    }
    load_sessions = loaders.get(config['dataset'])
    if load_sessions is None:
        print(f"{config['dataset']} is not a valid dataset name. Please choose from recsys15, aotm or ecomm")
        return
    sessions = load_sessions()

    train, test, valid = train_test_split(sessions, test_size=1000)
    ratk_logger = RecallAtKLogger(valid, k=config['k'], ray_tune=True)

    # keep only the word2vec hyperparameters; build a new dict rather than popping
    # from `config` so the caller's dictionary is left intact and can be reused
    w2v_params = {
        key: value for key, value in config.items() if key not in ('dataset', 'k')
    }
    train_w2v(train, params=w2v_params, callbacks=[ratk_logger])
class LossLogger(CallbackAny2Vec):
    '''Record and print the training loss of each epoch'''

    def __init__(self):
        # epoch counter, loss total seen at the previous epoch, per-epoch loss history
        self.epoch, self.previous_loss, self.training_loss = 0, 0, []

    def on_epoch_end(self, model):
        # Word2Vec reports a running total that grows every epoch, so the loss of
        # the epoch that just finished is the increase since the previous reading
        running_total = model.get_latest_training_loss()
        epoch_loss = running_total - self.previous_loss
        self.training_loss.append(epoch_loss)
        self.previous_loss = running_total
        print(f' Loss: {epoch_loss}')
        self.epoch = self.epoch + 1
def pickle_save(vector, filename, overwrite=False):
    """Pickle `vector` to `filename`, creating the parent directory if needed.

    If `filename` already exists it is left untouched and a message is printed,
    unless `overwrite` is True.
    """
    if os.path.exists(filename) and not overwrite:
        print(f"{filename} already exists! Please use overwrite flag.")
        return
    create_path(filename)
    # the context manager guarantees the file is flushed and closed
    with open(filename, "wb") as f:
        pickle.dump(vector, f)


def pickle_load(filename):
    """Return the object pickled in `filename`.

    If the file does not exist a message is printed and None is returned.
    """
    if not os.path.exists(filename):
        print(f"{filename} does not exist!")
        return None
    # NOTE: unpickling can execute arbitrary code -- only load files you trust
    with open(filename, "rb") as f:
        return pickle.load(f)


def create_path(pathname: str) -> None:
    """Creates the directory for the given path if it doesn't already exist."""
    parent_dir = str(pathlib.Path(pathname).parent)
    # exist_ok avoids the check-then-create race when several workers start together
    os.makedirs(parent_dir, exist_ok=True)
39 | result_path = pathlib.Path(__file__).resolve().parent.parent 40 | for path in paths: 41 | result_path /= path 42 | return str(result_path) 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.1.5 2 | numpy==1.19.2 3 | scikit-learn==0.24.1 4 | gensim==3.8.3 5 | matplotlib==3.3.4 6 | black==19.10b0 7 | dataclasses==0.8 8 | ray==1.2.0 9 | aioredis<2 10 | ray[tune] 11 | -e . # installs local recsys module in "edit" mode -------------------------------------------------------------------------------- /requirements3.6.txt: -------------------------------------------------------------------------------- 1 | pandas==1.1.5 2 | numpy==1.19.2 3 | scikit-learn==0.24.1 4 | gensim==3.8.3 5 | matplotlib==3.3.4 6 | black==19.10b0 7 | -e . # installs local recsys module in "edit" mode 8 | 9 | # The following two lines will install ray for Python 3.6 10 | https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp36-cp36m-manylinux2014_x86_64.whl 11 | ray[tune] -------------------------------------------------------------------------------- /scripts/baseline_analysis.py: -------------------------------------------------------------------------------- 1 | from recsys.data import load_ecomm, train_test_split 2 | from recsys.models import association_rules_baseline 3 | from recsys.metrics import recall_at_k_baseline, mrr_at_k_baseline 4 | 5 | # load data 6 | sessions = load_ecomm() 7 | train, test, valid = train_test_split(sessions, test_size=1000) 8 | 9 | # Construct a co-occurrence matrix containing how frequently 10 | # each item is found in the same session as any other item 11 | comatrix = association_rules_baseline(train) 12 | 13 | # Recommendations are generated as the top K most frequently co-occurring items 14 | # Compute metrics on these recommendations for each (query item, ground truth item) 15 | # pair in the 
import argparse
import numpy as np
import matplotlib.pyplot as plt

from ray.tune import Analysis

from recsys.data import load_ecomm, train_test_split
from recsys.models import train_w2v, RecallAtKLogger, LossLogger
from recsys.metrics import recall_at_k, mrr_at_k
from recsys.utils import absolute_filename, pickle_save

# Train a Word2Vec model on the ecomm sessions while logging Recall@K and training
# loss after every epoch, then save the metrics and embeddings and report test scores.

parser = argparse.ArgumentParser()
parser.add_argument(
    "--name",
    help="Directory for HPO experiment results -- providing this will result in a W2V model "
    "trained with the best hyperparameters from that experiment. (Note: if not provided, "
    "default W2V hyperparameters are used instead. These can be modified directly in this script.)",
)
parser.add_argument(
    "-k",
    type=int,  # without this a value given on the command line would stay a string
    default=10,
    help="Number of recommendations to generate for model evaluation. Default is 10.",
)
parser.add_argument(
    "--outdir",
    help="Directory in which to save trained model embeddings and training metrics. Default is `output/`",
    default=absolute_filename("output/"),
)
# parse_known_args returns (namespace, unrecognised_args); keep only the namespace
args, _ = parser.parse_known_args()


# load data
sessions = load_ecomm()
train, test, valid = train_test_split(sessions, test_size=1000)

# determine word2vec parameters to train with
if args.name:
    analysis = Analysis(
        absolute_filename("ray_results", args.name),
        default_metric="recall_at_k",
        default_mode="max",
    )
    best_config = analysis.get_best_config()
    # The tuning search space also carries the bookkeeping keys "dataset" and "k";
    # they are not Word2Vec arguments, so drop them before training
    w2v_params = {
        key: value for key, value in best_config.items() if key not in ("dataset", "k")
    }
else:
    # These are the few required parameters for training Word2Vec for this use case.
    # All other parameters will rely on Gensim defaults.
    w2v_params = {
        "min_count": 1,
        "iter": 5,
        "workers": 10,
        "sg": 1,
    }

# Instantiate callback to measure Recall@K on the validation set after each epoch of training
ratk_logger = RecallAtKLogger(valid, k=args.k, save_model=True)
# Instantiate callback to compute Word2Vec's training loss on the training set after each epoch of training
loss_logger = LossLogger()
# Train Word2Vec model and retrieve trained embeddings
embeddings = train_w2v(train, w2v_params, [ratk_logger, loss_logger])

# Save results
pickle_save(ratk_logger.recall_scores, absolute_filename(args.outdir, "recall@k_per_epoch.pkl"))
pickle_save(loss_logger.training_loss, absolute_filename(args.outdir, "trainloss_per_epoch.pkl"))

# Save trained embeddings
embeddings.save(absolute_filename(args.outdir, "embeddings.wv"))

# Visualize metrics as a function of epoch (each curve is scaled by its own maximum
# so that both fit on the same axis)
plt.plot(np.array(ratk_logger.recall_scores) / np.max(ratk_logger.recall_scores))
plt.plot(np.array(loss_logger.training_loss) / np.max(loss_logger.training_loss))
plt.show()

# Print results on the test set
print(recall_at_k(test, embeddings, k=args.k))
print(mrr_at_k(test, embeddings, k=args.k))
parser.add_argument( 21 | "--smoke-test", action="store_true", help="Finish quickly for testing" 22 | ) 23 | parser.add_argument( 24 | "--ray-address", 25 | help="Address of Ray cluster for seamless distributed execution.", 26 | required=False, 27 | ) 28 | parser.add_argument( 29 | "-cml", 30 | help="Set this flag if using CDSW or CML for seamless distributed execution", 31 | action="store_true" 32 | ) 33 | parser.add_argument( 34 | "--asha", 35 | help="Enable an ASHA Scheduler to stop underperforming trials early during hyperparameter sweep", 36 | action="store_true" 37 | ) 38 | args, _ = parser.parse_known_args() 39 | 40 | 41 | # If necessary, connect to an existing Ray Cluster for distributed execution 42 | if args.ray_address: 43 | ray.init(address=args.ray_address) 44 | if args.cml: 45 | ray.init(address=RAY_CLUSTER_ADDRESS) 46 | 47 | # Define the hyperparameter search space for Word2Vec algorithm 48 | search_space = { 49 | "dataset": "ecomm", 50 | "k": 10, 51 | #"size": tune.grid_search(list(np.arange(10,106, 6))), 52 | #"window": tune.grid_search(list(np.arange(1,22, 3))), 53 | #"ns_exponent": tune.grid_search(list(np.arange(-1, 1.2, .2))), 54 | #"alpha": tune.grid_search([0.001, 0.01, 0.1]), 55 | "negative": tune.grid_search(list(np.arange(1,22, 3))), 56 | "iter": 10, 57 | "min_count": 1, 58 | "workers": 6, 59 | "sg": 1, 60 | } 61 | 62 | # The ASHA Scheduler will stop underperforming trials in a principled fashion 63 | asha_scheduler = ASHAScheduler(max_t=100, grace_period=10) if args.asha else None 64 | 65 | # Set the stopping critera -- use the smoke-test arg to test the system 66 | stopping_criteria = {"training_iteration": 1 if args.smoke_test else 9999} 67 | 68 | # Perform hyperparamter sweep with Ray Tune 69 | analysis = tune.run( 70 | tune_w2v, 71 | name=args.name, 72 | local_dir=absolute_filename("ray_results"), 73 | metric="recall_at_k", 74 | mode="max", 75 | scheduler=asha_scheduler, 76 | stop=stopping_criteria, 77 | num_samples=1, 78 | verbose=1, 
79 | resources_per_trial={ 80 | "cpu": 1, 81 | "gpu": 0 82 | }, 83 | config=search_space, 84 | ) 85 | print("Best hyperparameters found were: ", analysis.best_config) 86 | 87 | """ 88 | # Plot all trials as a function of epochs 89 | dfs = analysis.trial_dataframes 90 | ax = None 91 | for d in dfs.values(): 92 | ax = d.recall_at_k.plot(ax=ax, legend=False) 93 | ax.set_xlabel("Epochs"); 94 | ax.set_ylabel("Recall@10"); 95 | """ 96 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="recsys", 5 | version="0.0.1", 6 | description=""" 7 | Utilities for a session-based recommendation system using Word2Vec. 8 | """, 9 | author="Melanie Beck & Nisha Muktewar", 10 | ) --------------------------------------------------------------------------------