├── .dvc
│   ├── .gitignore
│   └── config
├── .flake8
├── .github
│   └── workflows
│       └── github-actions-demo.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Makefile
├── README.md
├── bentofile.yaml
├── config
│   ├── main.yaml
│   ├── process
│   │   ├── process_1.yaml
│   │   ├── process_2.yaml
│   │   ├── process_3.yaml
│   │   └── process_4.yaml
│   └── segment
│       ├── AffinityPropagation.yaml
│       ├── AgglomerativeClustering.yaml
│       ├── Birch.yaml
│       ├── DBSCAN.yaml
│       ├── KMeans.yaml
│       ├── MeanShift.yaml
│       ├── OPTICS.yaml
│       └── SpectralClustering.yaml
├── data
│   ├── .gitignore
│   └── raw.dvc
├── dvc.lock
├── dvc.yaml
├── model
│   └── cluster.pkl
├── notebook
│   └── analyze_data.ipynb
├── poetry.lock
├── process_data.log
├── pyproject.toml
├── src
│   ├── __init__.py
│   ├── helper.py
│   ├── main.py
│   ├── process_data.py
│   ├── segment.py
│   └── streamlit_app.py
└── tests
    ├── __init__.py
    └── test_process_data.py

/.dvc/.gitignore:
--------------------------------------------------------------------------------
1 | /config.local
2 | /tmp
3 | /cache
4 | 
--------------------------------------------------------------------------------
/.dvc/config:
--------------------------------------------------------------------------------
1 | [core]
2 |     remote = ocelot
3 | ['remote "ocelot"']
4 |     url = https://dagshub.com/khuyentran1401/customer_segmentation.dvc
5 | ['remote "origin"']
6 |     url = https://dagshub.com/khuyentran1401/customer_segmentation_demo.dvc
7 | 
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203, E266, E501, W503, F403, F401
3 | max-line-length = 79
4 | max-complexity = 18
5 | select = B,C,E,F,W,T4,B9
--------------------------------------------------------------------------------
/.github/workflows/github-actions-demo.yml:
--------------------------------------------------------------------------------
1 | name: GitHub Actions Demo
2 | on: [push]
3 | jobs:
4 |   Explore-GitHub-Actions:
5 |     runs-on: ubuntu-latest
6 |     steps:
7 |       - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event."
8 |       - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by GitHub!"
9 |       - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}."
10 |       - name: Check out repository code
11 |         uses: actions/checkout@v3
12 |       - run: echo "💡 The ${{ github.repository }} repository has been cloned to the runner."
13 |       - run: echo "🖥️ The workflow is now ready to test your code on the runner."
14 |       - name: List files in the repository
15 |         run: |
16 |           ls ${{ github.workspace }}
17 |       - run: echo "🍏 This job's status is ${{ job.status }}."
18 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .pytest_cache
2 | .mypy_cache
3 | .vscode
4 | __pycache__
5 | outputs
6 | processors
7 | .DS_Store
8 | .ipynb_checkpoints
9 | customer-segmentation-a-AULLE--py3.8
10 | *-workspace
11 | .tox
12 | image/*.png
13 | wandb
14 | multirun
15 | main.log
16 | .venv
17 | /image
18 | segment.log
19 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/ambv/black
3 |     rev: 20.8b1
4 |     hooks:
5 |       - id: black
6 |   - repo: https://gitlab.com/pycqa/flake8
7 |     rev: 3.8.4
8 |     hooks:
9 |       - id: flake8
10 |   - repo: https://github.com/timothycrosley/isort
11 |     rev: 5.7.0
12 |     hooks:
13 |       - id: isort
14 |   - repo: https://github.com/kynan/nbstripout
15 |     rev: 0.5.0
16 |     hooks:
17 |       - id: nbstripout
18 | #  - repo: https://github.com/pre-commit/mirrors-mypy
19 | #    rev: v0.782
20 | #    hooks:
21 | #      - id: mypy
22 | #        args: [--ignore-missing-imports]
23 | 
24 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: notebook
2 | .EXPORT_ALL_VARIABLES:
3 | 
4 | PREFECT__FLOWS__CHECKPOINTING = true
5 | 
6 | install:
7 | 	@echo "Installing..."
8 | 	poetry install
9 | 
10 | activate:
11 | 	@echo "Activating virtual environment"
12 | 	poetry shell
13 | 
14 | env:
15 | 	@echo "Please set the environment variable 'PREFECT__FLOWS__CHECKPOINTING=true' to persist the output of Prefect's flow"
16 | 
17 | pull_data:
18 | 	@echo "Pulling data..."
19 | 	poetry run dvc pull
20 | 
21 | setup: activate
22 | install_all: install pull_data env
23 | 
24 | test:
25 | 	pytest
26 | 
27 | clean:
28 | 	@echo "Deleting log files..."
29 | 	find . -name "*.log" -type f -not -path "./wandb/*" -delete
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # End-to-end Customer Segmentation Project
2 | 
3 | ## Tools Used in This Project
4 | * [Prefect](https://www.prefect.io/): Orchestrate workflows - [article](https://towardsdatascience.com/orchestrate-a-data-science-project-in-python-with-prefect-e69c61a49074)
5 | * [hydra](https://hydra.cc/): Manage configuration files - [article](https://towardsdatascience.com/introduction-to-hydra-cc-a-powerful-framework-to-configure-your-data-science-projects-ed65713a53c6)
6 | * [Weights & Biases](https://wandb.ai/): Track and monitor experiments - [article](https://towardsdatascience.com/introduction-to-weight-biases-track-and-visualize-your-machine-learning-experiments-in-3-lines-9c9553b0f99d)
7 | * [pre-commit plugins](https://pre-commit.com/): Automate code reviewing and formatting - [article](https://towardsdatascience.com/4-pre-commit-plugins-to-automate-code-reviewing-and-formatting-in-python-c80c6d2e9f5?sk=2388804fb174d667ee5b680be22b8b1f)
8 | * [poetry](https://python-poetry.org/): Python dependency management - [article](https://towardsdatascience.com/how-to-effortlessly-publish-your-python-package-to-pypi-using-poetry-44b305362f9f)
9 | * [DVC](https://dvc.org/): Data version control - [article](https://towardsdatascience.com/introduction-to-dvc-data-version-control-tool-for-machine-learning-projects-7cb49c229fe0)
10 | * [BentoML](https://docs.bentoml.org/en/latest/): Deploy and serve machine learning models - [article](https://towardsdatascience.com/bentoml-create-an-ml-powered-prediction-service-in-minutes-23d135d6ca76)
11 | 
12 | ## Variations of This Project
13 | - [workshop branch](https://github.com/khuyentran1401/customer_segmentation/tree/workshop) focuses on Hydra, Prefect, and Weights & Biases, along with explanations.
14 | - [bentoml_demo branch](https://github.com/khuyentran1401/customer_segmentation/tree/bentoml_demo) focuses on BentoML, along with explanations.
15 | 
16 | ## Project Structure
17 | * `src`: consists of Python scripts
18 | * `config`: consists of configuration files
19 | * `data`: consists of data
20 | * `notebook`: consists of Jupyter Notebooks
21 | * `tests`: consists of test files
22 | 
23 | ## Set Up the Project
24 | 1. Install [Poetry](https://python-poetry.org/docs/#installation)
25 | 2. Set up the environment:
26 | ```bash
27 | make setup
28 | make install_all
29 | ```
30 | 3. To persist the output of Prefect's flow, run
31 | ```bash
32 | export PREFECT__FLOWS__CHECKPOINTING=true
33 | ```
34 | 
35 | ## Run the Project
36 | To run all flows, type:
37 | ```bash
38 | python src/main.py
39 | ```
40 | 
41 | To run the `process` flow, type:
42 | ```bash
43 | python src/main.py flow=process
44 | ```
45 | 
46 | To run the `segment` flow, type:
47 | ```bash
48 | python src/main.py flow=segment
49 | ```
50 | 
51 | 
52 | 
53 | 
54 | 
--------------------------------------------------------------------------------
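The `flow=...` and `segment=...` arguments in the README above are Hydra overrides: Hydra starts from `config/main.yaml`, pulls in the files listed under its `defaults:` key (`process: process_2`, `segment: KMeans`), and lets the command line swap any of them. A minimal, self-contained sketch of that mechanism — `show_config.py` is a hypothetical script used for illustration, not a file in this repo:

```python
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="../config", config_name="main")
def show_config(config: DictConfig) -> None:
    # By the time this runs, Hydra has merged main.yaml with the defaults
    # (process_2.yaml and KMeans.yaml) and with any CLI overrides.
    print("flow:", config.flow)
    print("algorithm:", config.segment.algorithm)
    print(OmegaConf.to_yaml(config))


if __name__ == "__main__":
    # Example overrides:
    #   python src/show_config.py flow=process
    #   python src/show_config.py segment=DBSCAN process=process_3
    show_config()
```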
/bentofile.yaml:
--------------------------------------------------------------------------------
1 | service: "src/bentoml_app.py:service"
2 | include:
3 |   - "src/bentoml_app.py"
4 |   - "src/streamlit_app.py"
5 |   - "processors/*"
6 | python:
7 |   packages:
8 |     - numpy==1.19.5
9 |     - pandas==1.3.3
10 |     - scikit-learn==1.0.2
11 |     - pydantic==1.9.0
12 |     - feature-engine==1.2.0
13 |     - streamlit==1.5.1
--------------------------------------------------------------------------------
/config/main.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 |   output_subdir: null
3 |   run:
4 |     dir: .
5 | 6 | defaults: 7 | - process: process_2 8 | - segment: KMeans 9 | - _self_ 10 | 11 | raw_data: 12 | path: data/raw/marketing_campaign.csv 13 | 14 | intermediate: 15 | dir: data/intermediate 16 | name: scale_features.csv 17 | path: ${intermediate.dir}/${intermediate.name} 18 | 19 | segmented: 20 | dir: data/final 21 | name: segmented.csv 22 | path: ${segmented.dir}/${segmented.name} 23 | 24 | flow: all 25 | 26 | pca: 27 | n_components: 3 28 | columns: 29 | - col1 30 | - col2 31 | - col3 32 | 33 | elbow_metric: silhouette 34 | 35 | image: 36 | elbow: image/elbow.png 37 | clusters: image/cluster.png 38 | 39 | wandb_mode: -------------------------------------------------------------------------------- /config/process/process_1.yaml: -------------------------------------------------------------------------------- 1 | name: process_1 2 | keep_columns: 3 | - Income 4 | - Recency 5 | - NumWebVisitsMonth 6 | - AcceptedCmp3 7 | - AcceptedCmp4 8 | - AcceptedCmp5 9 | - AcceptedCmp1 10 | - AcceptedCmp2 11 | - Complain 12 | - Response 13 | - age 14 | - total_purchases 15 | - enrollment_years 16 | - family_size 17 | 18 | remove_outliers_threshold: 19 | age: 90 20 | Income: 600000 21 | 22 | encode: 23 | family_size: 24 | Married: 2 25 | Together: 2 26 | Absurd: 1 27 | Widow: 1 28 | YOLO: 1 29 | Divorced: 1 30 | Single: 1 31 | Alone: 1 32 | 33 | -------------------------------------------------------------------------------- /config/process/process_2.yaml: -------------------------------------------------------------------------------- 1 | name: process_2 2 | keep_columns: 3 | - Income 4 | - Recency 5 | - NumWebVisitsMonth 6 | - Complain 7 | - age 8 | - total_purchases 9 | - enrollment_years 10 | - family_size 11 | 12 | remove_outliers_threshold: 13 | age: 90 14 | Income: 600000 15 | 16 | encode: 17 | family_size: 18 | Married: 2 19 | Together: 2 20 | Absurd: 1 21 | Widow: 1 22 | YOLO: 1 23 | Divorced: 1 24 | Single: 1 25 | Alone: 1 26 | 27 | -------------------------------------------------------------------------------- /config/process/process_3.yaml: -------------------------------------------------------------------------------- 1 | name: process_3 2 | keep_columns: 3 | - Income 4 | - Recency 5 | - NumWebVisitsMonth 6 | - NumDealsPurchases 7 | - NumWebPurchases 8 | - NumCatalogPurchases 9 | - NumStorePurchases 10 | - Complain 11 | - Response 12 | - age 13 | - enrollment_years 14 | - family_size 15 | 16 | remove_outliers_threshold: 17 | age: 90 18 | Income: 600000 19 | 20 | encode: 21 | family_size: 22 | Married: 2 23 | Together: 2 24 | Absurd: 1 25 | Widow: 1 26 | YOLO: 1 27 | Divorced: 1 28 | Single: 1 29 | Alone: 1 30 | 31 | -------------------------------------------------------------------------------- /config/process/process_4.yaml: -------------------------------------------------------------------------------- 1 | name: process_4 2 | keep_columns: 3 | - Income 4 | - Recency 5 | - MntWines 6 | - MntFruits 7 | - MntMeatProducts 8 | - MntFishProducts 9 | - MntSweetProducts 10 | - MntGoldProds 11 | - Complain 12 | - Response 13 | - age 14 | - enrollment_years 15 | - family_size 16 | 17 | remove_outliers_threshold: 18 | age: 90 19 | Income: 600000 20 | 21 | encode: 22 | family_size: 23 | Married: 2 24 | Together: 2 25 | Absurd: 1 26 | Widow: 1 27 | YOLO: 1 28 | Divorced: 1 29 | Single: 1 30 | Alone: 1 31 | 32 | -------------------------------------------------------------------------------- /config/segment/AffinityPropagation.yaml: 
-------------------------------------------------------------------------------- 1 | algorithm: AffinityPropagation 2 | args: -------------------------------------------------------------------------------- /config/segment/AgglomerativeClustering.yaml: -------------------------------------------------------------------------------- 1 | algorithm: AgglomerativeClustering 2 | args: -------------------------------------------------------------------------------- /config/segment/Birch.yaml: -------------------------------------------------------------------------------- 1 | algorithm: Birch 2 | args: 3 | n_clusters: 3 -------------------------------------------------------------------------------- /config/segment/DBSCAN.yaml: -------------------------------------------------------------------------------- 1 | algorithm: DBSCAN 2 | args: -------------------------------------------------------------------------------- /config/segment/KMeans.yaml: -------------------------------------------------------------------------------- 1 | algorithm: KMeans 2 | args: 3 | n_clusters: 8 -------------------------------------------------------------------------------- /config/segment/MeanShift.yaml: -------------------------------------------------------------------------------- 1 | algorithm: MeanShift 2 | args: -------------------------------------------------------------------------------- /config/segment/OPTICS.yaml: -------------------------------------------------------------------------------- 1 | algorithm: OPTICS 2 | args: -------------------------------------------------------------------------------- /config/segment/SpectralClustering.yaml: -------------------------------------------------------------------------------- 1 | algorithm: SpectralClustering 2 | args: 3 | n_clusters: 8 -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | /raw 2 | /intermediate 3 | /final 4 | -------------------------------------------------------------------------------- /data/raw.dvc: -------------------------------------------------------------------------------- 1 | outs: 2 | - md5: 10c3f643286f509fa7f6b4675d9efbad.dir 3 | size: 222379 4 | nfiles: 1 5 | path: raw 6 | -------------------------------------------------------------------------------- /dvc.lock: -------------------------------------------------------------------------------- 1 | schema: '2.0' 2 | stages: 3 | process_data: 4 | cmd: python src/process_data.py 5 | deps: 6 | - path: config/main.yaml 7 | md5: 413f47291554e937fffcac2027faada6 8 | size: 522 9 | - path: config/process 10 | md5: 3b82130fcc2aa6161ffda042584cf844.dir 11 | size: 1683 12 | nfiles: 4 13 | - path: data/raw 14 | md5: 10c3f643286f509fa7f6b4675d9efbad.dir 15 | size: 222379 16 | nfiles: 1 17 | - path: src/process_data.py 18 | md5: 2f1ddcf21d1514f18f5944d379ac1250 19 | size: 3170 20 | outs: 21 | - path: data/intermediate 22 | md5: e2082717fdfbbe71b178d7205dbf9ac9.dir 23 | size: 404048 24 | nfiles: 2 25 | segment: 26 | cmd: python src/segment.py 27 | deps: 28 | - path: config/main.yaml 29 | md5: 413f47291554e937fffcac2027faada6 30 | size: 522 31 | - path: config/segment 32 | md5: c704b1c0b8900693e098cef212f095bf.dir 33 | size: 276 34 | nfiles: 8 35 | - path: data/intermediate 36 | md5: e2082717fdfbbe71b178d7205dbf9ac9.dir 37 | size: 404048 38 | nfiles: 2 39 | - path: src/segment.py 40 | md5: 892a53397b3fc800bb5faec71a4df76d 41 | size: 5089 42 | outs: 43 | - path: 
data/final 44 | md5: 895bd4c5211f1f6e32ff37683e1a7210.dir 45 | size: 868954 46 | nfiles: 2 47 | - path: image 48 | md5: 920873df71bb734d51076b33e4fa0f1b.dir 49 | size: 217931 50 | nfiles: 3 51 | -------------------------------------------------------------------------------- /dvc.yaml: -------------------------------------------------------------------------------- 1 | stages: 2 | process_data: 3 | cmd: python src/process_data.py 4 | deps: 5 | - config/main.yaml 6 | - config/process 7 | - data/raw 8 | - src/process_data.py 9 | outs: 10 | - data/intermediate: 11 | persist: true 12 | segment: 13 | cmd: python src/segment.py 14 | deps: 15 | - config/main.yaml 16 | - config/segment 17 | - data/intermediate 18 | - src/segment.py 19 | outs: 20 | - data/final: 21 | persist: true 22 | - image: 23 | persist: true 24 | -------------------------------------------------------------------------------- /model/cluster.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khuyentran1401/customer_segmentation/5c5890c3e0416128a834d690a771701e0d046ce2/model/cluster.pkl -------------------------------------------------------------------------------- /notebook/analyze_data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "f894664a", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext nb_black" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "6440ec12", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "%cd ~/customer_segmentation/" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "cf83bd4d", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import pandas as pd\n", 31 | "import seaborn as sns\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "from IPython.core.pylabtools import figsize\n", 34 | "from pandasgui import show\n", 35 | "import plotly.express as px\n", 36 | "from yellowbrick.cluster import KElbowVisualizer\n", 37 | "from sklearn.cluster import KMeans\n", 38 | "import matplotlib.pyplot as plt, numpy as np\n", 39 | "from mpl_toolkits.mplot3d import Axes3D\n", 40 | "from sklearn.cluster import AgglomerativeClustering\n", 41 | "from matplotlib.colors import ListedColormap\n", 42 | "from sklearn import metrics" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "id": "f38a06b7", 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "df = pd.read_csv(\"data/final/segmented.csv\", index_col=0)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "192c4429", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "df.columns" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "8dd843bf", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "sns.countplot(x=df[\"clusters\"])\n", 73 | "plt.plot()" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "29f69846", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "def create_box_plot(df: pd.DataFrame, y: str):\n", 84 | " sorted_clusters = (\n", 85 | " df.groupby([\"clusters\"]).agg({y: \"median\"}).sort_values(by=y).index\n", 86 | " )\n", 87 | " sns.boxplot(data=df, x=\"clusters\", y=y, order=sorted_clusters)\n", 88 | " plt.plot()" 89 | ] 90 | }, 91 | { 92 | 
"cell_type": "code", 93 | "execution_count": null, 94 | "id": "406eec83", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "create_box_plot(df, \"Income\")" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "id": "e67c7ac7", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "create_box_plot(df, \"NumWebVisitsMonth\")" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "1e968a1a", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "create_box_plot(df, \"age\")" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "8dc713d2", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "create_box_plot(df, \"total_purchases\")" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "e42bc7d2", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "create_box_plot(df, \"family_size\")" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "8c34b63b", 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [] 148 | } 149 | ], 150 | "metadata": { 151 | "hide_input": false, 152 | "interpreter": { 153 | "hash": "4d434d5ef086bedd16efcd2361c0add02e37807d20c32e890f8acb9574c06275" 154 | }, 155 | "kernelspec": { 156 | "display_name": "Python 3.8.10 64-bit ('customer-segmentation-a-AULLE--py3.8': poetry)", 157 | "name": "python3" 158 | }, 159 | "language_info": { 160 | "codemirror_mode": { 161 | "name": "ipython", 162 | "version": 3 163 | }, 164 | "file_extension": ".py", 165 | "mimetype": "text/x-python", 166 | "name": "python", 167 | "nbconvert_exporter": "python", 168 | "pygments_lexer": "ipython3", 169 | "version": "3.8.10" 170 | }, 171 | "toc": { 172 | "base_numbering": 1, 173 | "nav_menu": {}, 174 | "number_sections": true, 175 | "sideBar": true, 176 | "skip_h1_title": false, 177 | "title_cell": "Table of Contents", 178 | "title_sidebar": "Contents", 179 | "toc_cell": false, 180 | "toc_position": {}, 181 | "toc_section_display": true, 182 | "toc_window_display": false 183 | } 184 | }, 185 | "nbformat": 4, 186 | "nbformat_minor": 5 187 | } 188 | -------------------------------------------------------------------------------- /process_data.log: -------------------------------------------------------------------------------- 1 | [2022-03-10 14:18:17,222][prefect.FlowRunner][INFO] - Beginning Flow run for 'process_data' 2 | [2022-03-10 14:18:17,322][prefect.TaskRunner][INFO] - Task 'load_data': Starting task run... 3 | [2022-03-10 14:18:17,538][prefect.TaskRunner][INFO] - Task 'load_data': Finished task run for task with final state: 'Success' 4 | [2022-03-10 14:18:17,602][prefect.TaskRunner][INFO] - Task 'drop_na': Starting task run... 5 | [2022-03-10 14:18:17,645][prefect.TaskRunner][INFO] - Task 'drop_na': Finished task run for task with final state: 'Success' 6 | [2022-03-10 14:18:17,708][prefect.TaskRunner][INFO] - Task 'get_age': Starting task run... 7 | [2022-03-10 14:18:17,747][prefect.TaskRunner][INFO] - Task 'get_age': Finished task run for task with final state: 'Success' 8 | [2022-03-10 14:18:17,812][prefect.TaskRunner][INFO] - Task 'get_total_children': Starting task run... 9 | [2022-03-10 14:18:17,859][prefect.TaskRunner][INFO] - Task 'get_total_children': Finished task run for task with final state: 'Success' 10 | [2022-03-10 14:18:17,928][prefect.TaskRunner][INFO] - Task 'get_total_purchases': Starting task run... 
11 | [2022-03-10 14:18:17,975][prefect.TaskRunner][INFO] - Task 'get_total_purchases': Finished task run for task with final state: 'Success' 12 | [2022-03-10 14:18:18,040][prefect.TaskRunner][INFO] - Task 'get_enrollment_years': Starting task run... 13 | [2022-03-10 14:18:18,086][prefect.TaskRunner][INFO] - Task 'get_enrollment_years': Finished task run for task with final state: 'Success' 14 | [2022-03-10 14:18:18,149][prefect.TaskRunner][INFO] - Task 'get_family_size': Starting task run... 15 | [2022-03-10 14:18:18,190][prefect.TaskRunner][INFO] - Task 'get_family_size': Finished task run for task with final state: 'Success' 16 | [2022-03-10 14:18:18,263][prefect.TaskRunner][INFO] - Task 'drop_columns_and_rows': Starting task run... 17 | [2022-03-10 14:18:18,345][prefect.TaskRunner][INFO] - Task 'drop_columns_and_rows': Finished task run for task with final state: 'Success' 18 | [2022-03-10 14:18:18,420][prefect.TaskRunner][INFO] - Task 'get_scaler': Starting task run... 19 | [2022-03-10 14:18:18,457][prefect.TaskRunner][INFO] - Task 'get_scaler': Finished task run for task with final state: 'Success' 20 | [2022-03-10 14:18:18,517][prefect.TaskRunner][INFO] - Task 'scale_features': Starting task run... 21 | [2022-03-10 14:18:18,643][prefect.TaskRunner][INFO] - Task 'scale_features': Finished task run for task with final state: 'Success' 22 | [2022-03-10 14:18:18,646][prefect.FlowRunner][INFO] - Flow run SUCCESS: all reference tasks succeeded 23 | [2022-03-10 15:04:21,286][prefect.FlowRunner][INFO] - Beginning Flow run for 'process_data' 24 | [2022-03-10 15:04:21,382][prefect.TaskRunner][INFO] - Task 'load_data': Starting task run... 25 | [2022-03-10 15:04:21,572][prefect.TaskRunner][INFO] - Task 'load_data': Finished task run for task with final state: 'Success' 26 | [2022-03-10 15:04:21,639][prefect.TaskRunner][INFO] - Task 'drop_na': Starting task run... 27 | [2022-03-10 15:04:21,679][prefect.TaskRunner][INFO] - Task 'drop_na': Finished task run for task with final state: 'Success' 28 | [2022-03-10 15:04:21,741][prefect.TaskRunner][INFO] - Task 'get_age': Starting task run... 29 | [2022-03-10 15:04:21,779][prefect.TaskRunner][INFO] - Task 'get_age': Finished task run for task with final state: 'Success' 30 | [2022-03-10 15:04:21,841][prefect.TaskRunner][INFO] - Task 'get_total_children': Starting task run... 31 | [2022-03-10 15:04:21,879][prefect.TaskRunner][INFO] - Task 'get_total_children': Finished task run for task with final state: 'Success' 32 | [2022-03-10 15:04:21,943][prefect.TaskRunner][INFO] - Task 'get_total_purchases': Starting task run... 33 | [2022-03-10 15:04:21,981][prefect.TaskRunner][INFO] - Task 'get_total_purchases': Finished task run for task with final state: 'Success' 34 | [2022-03-10 15:04:22,046][prefect.TaskRunner][INFO] - Task 'get_enrollment_years': Starting task run... 35 | [2022-03-10 15:04:22,092][prefect.TaskRunner][INFO] - Task 'get_enrollment_years': Finished task run for task with final state: 'Success' 36 | [2022-03-10 15:04:22,154][prefect.TaskRunner][INFO] - Task 'get_family_size': Starting task run... 37 | [2022-03-10 15:04:22,196][prefect.TaskRunner][INFO] - Task 'get_family_size': Finished task run for task with final state: 'Success' 38 | [2022-03-10 15:04:22,269][prefect.TaskRunner][INFO] - Task 'drop_columns_and_rows': Starting task run... 
39 | [2022-03-10 15:04:22,337][prefect.TaskRunner][INFO] - Task 'drop_columns_and_rows': Finished task run for task with final state: 'Success' 40 | [2022-03-10 15:04:22,400][prefect.TaskRunner][INFO] - Task 'get_scaler': Starting task run... 41 | [2022-03-10 15:04:22,443][prefect.TaskRunner][INFO] - Task 'get_scaler': Finished task run for task with final state: 'Success' 42 | [2022-03-10 15:04:22,504][prefect.TaskRunner][INFO] - Task 'scale_features': Starting task run... 43 | [2022-03-10 15:04:22,631][prefect.TaskRunner][INFO] - Task 'scale_features': Finished task run for task with final state: 'Success' 44 | [2022-03-10 15:04:22,634][prefect.FlowRunner][INFO] - Flow run SUCCESS: all reference tasks succeeded 45 | [2022-03-10 16:14:44,449][prefect.FlowRunner][INFO] - Beginning Flow run for 'process_data' 46 | [2022-03-10 16:14:44,543][prefect.TaskRunner][INFO] - Task 'load_data': Starting task run... 47 | [2022-03-10 16:14:44,719][prefect.TaskRunner][INFO] - Task 'load_data': Finished task run for task with final state: 'Success' 48 | [2022-03-10 16:14:44,780][prefect.TaskRunner][INFO] - Task 'drop_na': Starting task run... 49 | [2022-03-10 16:14:44,821][prefect.TaskRunner][INFO] - Task 'drop_na': Finished task run for task with final state: 'Success' 50 | [2022-03-10 16:14:44,881][prefect.TaskRunner][INFO] - Task 'get_age': Starting task run... 51 | [2022-03-10 16:14:44,916][prefect.TaskRunner][INFO] - Task 'get_age': Finished task run for task with final state: 'Success' 52 | [2022-03-10 16:14:44,976][prefect.TaskRunner][INFO] - Task 'get_total_children': Starting task run... 53 | [2022-03-10 16:14:45,010][prefect.TaskRunner][INFO] - Task 'get_total_children': Finished task run for task with final state: 'Success' 54 | [2022-03-10 16:14:45,068][prefect.TaskRunner][INFO] - Task 'get_total_purchases': Starting task run... 55 | [2022-03-10 16:14:45,109][prefect.TaskRunner][INFO] - Task 'get_total_purchases': Finished task run for task with final state: 'Success' 56 | [2022-03-10 16:14:45,174][prefect.TaskRunner][INFO] - Task 'get_enrollment_years': Starting task run... 57 | [2022-03-10 16:14:45,221][prefect.TaskRunner][INFO] - Task 'get_enrollment_years': Finished task run for task with final state: 'Success' 58 | [2022-03-10 16:14:45,285][prefect.TaskRunner][INFO] - Task 'get_family_size': Starting task run... 59 | [2022-03-10 16:14:45,327][prefect.TaskRunner][INFO] - Task 'get_family_size': Finished task run for task with final state: 'Success' 60 | [2022-03-10 16:14:45,391][prefect.TaskRunner][INFO] - Task 'drop_columns_and_rows': Starting task run... 61 | [2022-03-10 16:14:45,457][prefect.TaskRunner][INFO] - Task 'drop_columns_and_rows': Finished task run for task with final state: 'Success' 62 | [2022-03-10 16:14:45,522][prefect.TaskRunner][INFO] - Task 'get_scaler': Starting task run... 63 | [2022-03-10 16:14:45,562][prefect.TaskRunner][INFO] - Task 'get_scaler': Finished task run for task with final state: 'Success' 64 | [2022-03-10 16:14:45,633][prefect.TaskRunner][INFO] - Task 'scale_features': Starting task run... 
65 | [2022-03-10 16:14:45,763][prefect.TaskRunner][INFO] - Task 'scale_features': Finished task run for task with final state: 'Success' 66 | [2022-03-10 16:14:45,766][prefect.FlowRunner][INFO] - Flow run SUCCESS: all reference tasks succeeded 67 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "customer_segmentation" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["khuyentran1401 "] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.8" 9 | numpy = ">=1.14.3,<1.21.0" 10 | matplotlib = "^3.5.1" 11 | seaborn = "^0.11.2" 12 | scikit-learn = "^1.0.2" 13 | hydra-core = "^1.1.1" 14 | dvc = "^2.9.3" 15 | yellowbrick = "^1.3.post1" 16 | pandas = "<1.3.5" 17 | wandb = "^0.12.9" 18 | bentoml = "1.0.0a3" 19 | prefect = "^0.15.13" 20 | pydantic = "^1.9.0" 21 | streamlit = "^1.5.1" 22 | pygit2 = "1.8.0" 23 | 24 | [tool.poetry.dev-dependencies] 25 | pytest = "^6.2.5" 26 | pre-commit = "^2.17.0" 27 | 28 | [virtualenvs] 29 | create = true 30 | in-project = true 31 | 32 | [build-system] 33 | requires = ["poetry-core>=1.0.0"] 34 | build-backend = "poetry.core.masonry.api" 35 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khuyentran1401/customer_segmentation/5c5890c3e0416128a834d690a771701e0d046ce2/src/__init__.py -------------------------------------------------------------------------------- /src/helper.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import pandas as pd 3 | import wandb 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | 7 | @hydra.main( 8 | config_path="../config", 9 | config_name="main", 10 | ) 11 | def initialize_wandb(config: DictConfig): 12 | wandb.init( 13 | project="customer_segmentation", 14 | config=OmegaConf.to_object(config), 15 | reinit=True, 16 | mode="disabled", 17 | ) 18 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import wandb 3 | from omegaconf import DictConfig, OmegaConf 4 | 5 | from process_data import process_data 6 | from segment import segment 7 | 8 | 9 | @hydra.main( 10 | config_path="../config", 11 | config_name="main", 12 | ) 13 | def main(config: DictConfig): 14 | 15 | process_data(config) 16 | segment(config) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /src/process_data.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | 3 | import hydra 4 | import pandas as pd 5 | import wandb 6 | from omegaconf import DictConfig 7 | from prefect import Flow, Parameter, task 8 | from prefect.engine.results import LocalResult 9 | from prefect.engine.serializers import PandasSerializer 10 | from sklearn.preprocessing import StandardScaler 11 | 12 | INTERMEDIATE_OUTPUT = LocalResult( 13 | "data/intermediate/", 14 | location="{task_name}.csv", 15 | serializer=PandasSerializer("csv", serialize_kwargs={"index": False}), 16 | ) 17 | 18 | 19 | @task 20 | def load_data(data_name: str) -> pd.DataFrame: 21 | data = pd.read_csv(data_name) 22 | return data 23 | 24 | 25 | @task 26 
| def drop_na(df: pd.DataFrame) -> pd.DataFrame: 27 | return df.dropna() 28 | 29 | 30 | @task 31 | def get_age(df: pd.DataFrame) -> pd.DataFrame: 32 | return df.assign(age=df["Year_Birth"].apply(lambda row: 2021 - row)) 33 | 34 | 35 | @task 36 | def get_total_children(df: pd.DataFrame) -> pd.DataFrame: 37 | return df.assign(total_children=df["Kidhome"] + df["Teenhome"]) 38 | 39 | 40 | @task 41 | def get_total_purchases(df: pd.DataFrame) -> pd.DataFrame: 42 | purchases_columns = df.filter(like="Purchases", axis=1).columns 43 | return df.assign(total_purchases=df[purchases_columns].sum(axis=1)) 44 | 45 | 46 | @task 47 | def get_enrollment_years(df: pd.DataFrame) -> pd.DataFrame: 48 | df["Dt_Customer"] = pd.to_datetime(df["Dt_Customer"]) 49 | return df.assign(enrollment_years=2022 - df["Dt_Customer"].dt.year) 50 | 51 | 52 | @task 53 | def get_family_size(df: pd.DataFrame, size_map: dict) -> pd.DataFrame: 54 | return df.assign( 55 | family_size=df["Marital_Status"].map(size_map) + df["total_children"] 56 | ) 57 | 58 | 59 | def drop_features(df: pd.DataFrame, keep_columns: list): 60 | df = df[keep_columns] 61 | return df 62 | 63 | 64 | def drop_outliers(df: pd.DataFrame, column_threshold: dict): 65 | for col, threshold in column_threshold.items(): 66 | df = df[df[col] < threshold] 67 | return df.reset_index(drop=True) 68 | 69 | 70 | @task(result=INTERMEDIATE_OUTPUT) 71 | def drop_columns_and_rows( 72 | df: pd.DataFrame, keep_columns: DictConfig, remove_outliers_threshold: DictConfig 73 | ) -> pd.DataFrame: 74 | df = df.pipe(drop_features, keep_columns=keep_columns).pipe( 75 | drop_outliers, column_threshold=remove_outliers_threshold 76 | ) 77 | 78 | return df 79 | 80 | 81 | @task(result=LocalResult("processors", location="scaler.pkl")) 82 | def get_scaler(df: pd.DataFrame): 83 | scaler = StandardScaler() 84 | scaler.fit(df) 85 | 86 | return scaler 87 | 88 | 89 | @task(result=INTERMEDIATE_OUTPUT) 90 | def scale_features(df: pd.DataFrame, scaler: StandardScaler): 91 | return pd.DataFrame(scaler.transform(df), columns=df.columns) 92 | 93 | 94 | @hydra.main( 95 | config_path="../config", 96 | config_name="main", 97 | ) 98 | def process_data(config: DictConfig): 99 | 100 | with Flow("process_data") as flow: 101 | df = load_data(config.raw_data.path) 102 | df = drop_na(df) 103 | df = get_age(df) 104 | df = get_total_children(df) 105 | df = get_total_purchases(df) 106 | df = get_enrollment_years(df) 107 | df = get_family_size(df, config.process.encode.family_size) 108 | df = drop_columns_and_rows( 109 | df, config.process.keep_columns, config.process.remove_outliers_threshold 110 | ) 111 | scaler = get_scaler(df) 112 | df = scale_features(df, scaler) 113 | 114 | flow.run() 115 | # flow.register(project_name="customer_segmentation") 116 | 117 | 118 | if __name__ == "__main__": 119 | process_data() 120 | -------------------------------------------------------------------------------- /src/segment.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Tuple 3 | 4 | import hydra 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | from omegaconf import DictConfig, OmegaConf 9 | from prefect import Flow, case, task 10 | from prefect.engine.results import LocalResult 11 | from prefect.engine.serializers import PandasSerializer 12 | from prefect.tasks.control_flow import merge 13 | from sklearn.cluster import (DBSCAN, OPTICS, AffinityPropagation, 14 | AgglomerativeClustering, Birch, 
KMeans, MeanShift,
15 |                              SpectralClustering)
16 | from sklearn.decomposition import PCA
17 | from sklearn.metrics import silhouette_score
18 | from yellowbrick.cluster import KElbowVisualizer
19 | 
20 | import wandb
21 | 
22 | FINAL_OUTPUT = LocalResult(
23 |     "data/final/",
24 |     location="{task_name}.csv",
25 |     serializer=PandasSerializer("csv", serialize_kwargs={"index": False}),
26 | )
27 | 
28 | 
29 | @task
30 | def initialize_wandb(config: DictConfig):
31 |     wandb.init(
32 |         project="customer_segmentation",
33 |         config=OmegaConf.to_object(config),
34 |         reinit=True,
35 |         mode=config.wandb_mode,
36 |     )
37 | 
38 | 
39 | @task
40 | def get_pca_model(data: pd.DataFrame) -> PCA:
41 | 
42 |     pca = PCA(n_components=3)
43 |     pca.fit(data)
44 |     return pca
45 | 
46 | 
47 | @task
48 | def reduce_dimension(df: pd.DataFrame, pca: PCA) -> pd.DataFrame:
49 |     return pd.DataFrame(pca.transform(df), columns=["col1", "col2", "col3"])
50 | 
51 | 
52 | @task
53 | def get_3d_projection(pca_df: pd.DataFrame) -> dict:
54 |     """A 3D projection of the data in the reduced dimensionality space"""
55 |     return {"x": pca_df["col1"], "y": pca_df["col2"], "z": pca_df["col3"]}
56 | 
57 | 
58 | @task
59 | def check_has_nclusters(config: DictConfig):
60 |     args = config.segment.args
61 |     return args is not None and "n_clusters" in args
62 | 
63 | 
64 | @task
65 | def get_best_k_cluster(
66 |     pca_df: pd.DataFrame, image_path: str, elbow_metric: str
67 | ) -> int:
68 | 
69 |     fig = plt.figure(figsize=(10, 8))
70 |     fig.add_subplot(111)
71 | 
72 |     elbow = KElbowVisualizer(KMeans(), metric=elbow_metric)
73 |     elbow.fit(pca_df)
74 |     elbow.fig.savefig(image_path)
75 | 
76 |     k_best = elbow.elbow_value_
77 | 
78 |     # Log
79 |     wandb.log(
80 |         {
81 |             "elbow": wandb.Image(image_path),
82 |             "k_best": k_best,
83 |             "score_best": elbow.elbow_score_,
84 |         }
85 |     )
86 |     return k_best
87 | 
88 | 
89 | @task
90 | def predict_with_predefined_clusters(
91 |     pca_df: pd.DataFrame, k: int, model: dict
92 | ) -> np.ndarray:
93 |     """Get model with the parameter `n_clusters`"""
94 | 
95 |     model_args = dict(model.args)
96 |     model_args["n_clusters"] = k
97 | 
98 |     model = eval(model.algorithm)(**model_args)
99 | 
100 |     # Predict
101 |     return model.fit_predict(pca_df)
102 | 
103 | 
104 | @task
105 | def predict_without_predefined_clusters(
106 |     pca_df: pd.DataFrame, model: dict
107 | ) -> np.ndarray:
108 |     """Get model without the parameter `n_clusters`"""
109 |     if model.args is None:
110 |         model_args = {}
111 |     else:
112 |         model_args = dict(model.args)
113 |     model = eval(model.algorithm)(**model_args)
114 | 
115 |     # Predict
116 |     return model.fit_predict(pca_df)
117 | 
118 | 
119 | @task
120 | def get_silhouette_score(pca_df: pd.DataFrame, labels: np.ndarray) -> float:
121 |     sil_score = silhouette_score(pca_df, labels)
122 |     wandb.log({"silhouette_score": sil_score})
123 |     return sil_score
124 | 
125 | 
126 | @task
127 | def plot_silhouette_score(
128 |     pca_df: pd.DataFrame, silhouette_score: float, image_path: str
129 | ):
130 |     fig = plt.figure(figsize=(10, 8))
131 |     ax = fig.add_subplot(111)
132 | 
133 |     ax.set_xlim([-1, 1])
134 |     ax.set_ylim([0, len(pca_df)])
135 |     plt.plot(silhouette_score)
136 |     plt.savefig(image_path)
137 |     wandb.log({"silhouette_score_plot": wandb.Image(image_path)})
138 | 
139 | 
140 | @task(result=FINAL_OUTPUT)
141 | def insert_clusters_to_df(df: pd.DataFrame, clusters: np.ndarray) -> pd.DataFrame:
142 |     return df.assign(clusters=clusters)
143 | 
144 | 
145 | @task
146 | def plot_clusters(
147 |     pca_df: pd.DataFrame, preds: np.ndarray, projections: dict, image_path: str
148 | ) -> None:
149 |     pca_df["clusters"] = preds
150 | 
151 |     plt.figure(figsize=(10, 8))
152 |     ax = plt.subplot(111, projection="3d")
153 |     ax.scatter(
154 |         projections["x"],
155 |         projections["y"],
156 |         projections["z"],
157 |         s=40,
158 |         c=pca_df["clusters"],
159 |         marker="o",
160 |         cmap="Accent",
161 |     )
162 |     ax.set_title("The Plot Of The Clusters")
163 | 
164 |     plt.savefig(image_path)
165 | 
166 |     # Log plot
167 |     wandb.log({"clusters": wandb.Image(image_path)})
168 | 
169 | 
170 | @task
171 | def wandb_log(config: DictConfig):
172 | 
173 |     # log data
174 |     wandb.log_artifact(config.raw_data.path, name="raw_data", type="data")
175 |     wandb.log_artifact(config.intermediate.path, name="intermediate_data", type="data")
176 |     wandb.log_artifact(config.segmented.path, name="segmented_data", type="data")
177 | 
178 |     # log number of columns
179 |     wandb.log({"num_cols": len(config.process.keep_columns)})
180 | 
181 | 
182 | @hydra.main(config_path="../config", config_name="main")
183 | def segment(config: DictConfig) -> None:
184 | 
185 |     with Flow("segmentation") as flow:
186 | 
187 |         initialize_wandb(config)
188 | 
189 |         data = pd.read_csv(config.intermediate.path)
190 |         pca = get_pca_model(data)
191 |         pca_df = reduce_dimension(data, pca)
192 | 
193 |         projections = get_3d_projection(pca_df)
194 | 
195 |         has_nclusters = check_has_nclusters(config)
196 | 
197 |         with case(has_nclusters, True):
198 |             k_best = get_best_k_cluster(pca_df, config.image.elbow, config.elbow_metric)
199 |             prediction1 = predict_with_predefined_clusters(
200 |                 pca_df, k_best, config.segment
201 |             )
202 | 
203 |         with case(has_nclusters, False):
204 |             prediction2 = predict_without_predefined_clusters(pca_df, config.segment)
205 | 
206 |         prediction = merge(prediction1, prediction2)
207 | 
208 |         score = get_silhouette_score(pca_df, prediction)
209 | 
210 |         data = insert_clusters_to_df(data, prediction)
211 | 
212 |         plot_clusters(pca_df, prediction, projections, config.image.clusters)
213 | 
214 |         wandb_log(config)
215 | 
216 |     flow.run()
217 |     # flow.visualize()
218 |     # flow.register(project_name="customer_segmentation")
219 | 
220 | 
221 | if __name__ == "__main__":
222 |     segment()
223 | 
--------------------------------------------------------------------------------
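One note on `segment.py` above: `eval(model.algorithm)(**model_args)` instantiates whichever class the selected `config/segment/*.yaml` file names, and it works because every one of those files names a class imported from `sklearn.cluster`. Under that same assumption, a hedged sketch of an equivalent lookup that avoids `eval` — a typo in the config then fails with an `AttributeError` instead of executing arbitrary code:

```python
from typing import Optional

from sklearn import cluster


def build_clustering_model(algorithm: str, args: Optional[dict] = None):
    # Look the class up on the sklearn.cluster module rather than eval-ing
    # the config string; only names defined in that module can be built.
    model_class = getattr(cluster, algorithm)
    return model_class(**(args or {}))


# Mirrors config/segment/KMeans.yaml:
#   build_clustering_model("KMeans", {"n_clusters": 8})
```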
/src/streamlit_app.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | 
4 | import requests
5 | import streamlit as st
6 | 
7 | st.title("Customer Segmentation Web App")
8 | 
9 | data = {}
10 | 
11 | data["Income"] = st.number_input(
12 |     "Income",
13 |     min_value=0,
14 |     step=500,
15 |     value=58138,
16 |     help="Customer's yearly household income",
17 | )
18 | data["Recency"] = st.number_input(
19 |     "Recency",
20 |     min_value=0,
21 |     value=58,
22 |     help="Number of days since customer's last purchase",
23 | )
24 | data["NumWebVisitsMonth"] = st.number_input(
25 |     "NumWebVisitsMonth",
26 |     min_value=0,
27 |     value=7,
28 |     help="Number of visits to company’s website in the last month",
29 | )
30 | data["Complain"] = st.number_input(
31 |     "Complain",
32 |     min_value=0,
33 |     value=0,
34 |     help="1 if the customer complained in the last 2 years, 0 otherwise",
35 | )
36 | data["age"] = st.number_input(
37 |     "age",
38 |     min_value=0,
39 |     value=64,
40 |     help="Customer's age",
41 | )
42 | data["total_purchases"] = st.number_input(
43 |     "total_purchases",
44 |     min_value=0,
45 |     value=25,
46 |     help="Total number of purchases through website, catalogue, or store",
47 | )
48 | data["enrollment_years"] = st.number_input(
49 |     "enrollment_years",
50 |     min_value=0,
51 |     value=10,
52 |     help="Number of years since the customer enrolled with the company",
53 | )
54 | data["family_size"] = st.number_input(
55 |     "family_size",
56 |     min_value=0,
57 |     value=1,
58 |     help="Total number of members in a customer's family",
59 | )
60 | 
61 | 
62 | if st.button("Get the cluster of this customer"):
63 |     if not any(math.isnan(v) for v in data.values()):
64 |         data_json = json.dumps(data)
65 | 
66 |         prediction = requests.post(
67 |             "http://127.0.0.1:5000/predict",
68 |             headers={"content-type": "application/json"},
69 |             data=data_json,
70 |         ).text
71 |         st.write(f"This customer belongs to cluster {prediction}")
--------------------------------------------------------------------------------
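The Streamlit app above POSTs to a BentoML service on port 5000 — the `src/bentoml_app.py:service` referenced by `bentofile.yaml`, which is not part of this listing. For a quick local smoke test without that server, a sketch along these lines can exercise the tracked artifacts directly. It assumes the scaler saved by `get_scaler` in `processors/scaler.pkl` was fitted on the eight `process_2` columns in this order, and that `model/cluster.pkl` predicts on the scaled features rather than on the PCA projection built in `segment.py` (if not, a `pca.transform` step belongs in between):

```python
import pickle

import pandas as pd

# Feature order taken from config/process/process_2.yaml's keep_columns.
COLUMNS = [
    "Income", "Recency", "NumWebVisitsMonth", "Complain",
    "age", "total_purchases", "enrollment_years", "family_size",
]

with open("processors/scaler.pkl", "rb") as f:
    scaler = pickle.load(f)  # StandardScaler persisted by the process flow
with open("model/cluster.pkl", "rb") as f:
    model = pickle.load(f)  # clustering model checked into model/

# Same default values as the Streamlit inputs above.
customer = pd.DataFrame([[58138, 58, 7, 0, 64, 25, 10, 1]], columns=COLUMNS)
scaled = scaler.transform(customer)
print("cluster:", model.predict(scaled)[0])
```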
data["enrollment_years"] = st.number_input( 49 | "enrollment_years", 50 | min_value=0, 51 | value=10, 52 | help="Number of years a client has enrolled with a company", 53 | ) 54 | data["family_size"] = st.number_input( 55 | "family_size", 56 | min_value=0, 57 | value=1, 58 | help="Total number of members in a customer's family", 59 | ) 60 | 61 | 62 | if st.button("Get the cluster of this customer"): 63 | if not any(math.isnan(v) for v in data.values()): 64 | data_json = json.dumps(data) 65 | 66 | prediction = requests.post( 67 | "http://127.0.0.1:5000/predict", 68 | headers={"content-type": "application/json"}, 69 | data=data_json, 70 | ).text 71 | st.write(f"This customer belongs to the cluster {prediction}") 72 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/khuyentran1401/customer_segmentation/5c5890c3e0416128a834d690a771701e0d046ce2/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_process_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas.testing import assert_frame_equal 3 | 4 | from src.process_data import (get_age, get_enrollment_years, get_family_size, 5 | get_total_purchases, scale_features) 6 | 7 | 8 | def test_get_age(): 9 | df = pd.DataFrame({"Year_Birth": [1999, 2000]}) 10 | assert_frame_equal( 11 | get_age(df), 12 | pd.DataFrame({"Year_Birth": [1999, 2000], "age": [22, 21]}), 13 | ) 14 | 15 | 16 | def test_get_total_purchases(): 17 | df = pd.DataFrame({"FirstPurchases": [1, 2], "SecondPurchases": [3, 4]}) 18 | out = get_total_purchases(df) 19 | assert out["total_purchases"].tolist() == [4, 6] 20 | 21 | 22 | def test_get_enrollment_years(): 23 | df = pd.DataFrame({"Dt_Customer": ["04-09-2012"]}) 24 | assert_frame_equal( 25 | get_enrollment_years(df), 26 | pd.DataFrame( 27 | { 28 | "Dt_Customer": [pd.Timestamp("2012-04-09 00:00:00")], 29 | "enrollment_years": [10], 30 | } 31 | ), 32 | ) 33 | 34 | 35 | def test_scale_features(): 36 | df = pd.DataFrame( 37 | {"FirstPurchases": [1, 2, 5], "SecondPurchases": [3, 4, 7]} 38 | ) 39 | out = scale_features.run(df) 40 | assert_frame_equal( 41 | out, 42 | pd.DataFrame( 43 | { 44 | "FirstPurchases": [-0.980, -0.392, 1.373], 45 | "SecondPurchases": [-0.980, -0.392, 1.373], 46 | } 47 | ), 48 | atol=2, 49 | ) 50 | 51 | 52 | def test_get_family_size(): 53 | df = pd.DataFrame( 54 | { 55 | "Marital_Status": ["Married", "Absurd", "Single"], 56 | "total_children": [1, 2, 3], 57 | } 58 | ) 59 | assert_frame_equal( 60 | get_family_size(df, {"Married": 2, "Absurd": 1, "Single": 1}), 61 | pd.DataFrame( 62 | { 63 | "Marital_Status": ["Married", "Absurd", "Single"], 64 | "total_children": [1, 2, 3], 65 | "family_size": [3, 3, 4], 66 | } 67 | ), 68 | ) 69 | --------------------------------------------------------------------------------