├── .gitattributes
├── MANIFEST.in
├── cbm
│   ├── _version.py
│   ├── __init__.py
│   ├── sklearn.py
│   ├── CBM.py
│   └── CBMExplainer.py
├── images
│   └── cbm_kaggle.png
├── tests
│   ├── test_common.py
│   ├── test_sklearn.py
│   └── test_cbm.py
├── pyproject.toml
├── CODE_OF_CONDUCT.md
├── LICENSE
├── src
│   ├── pycbm.h
│   ├── cbm.cpp
│   ├── pycbm.cpp
│   └── cbm.h
├── .vscode
│   └── settings.json
├── .github
│   └── workflows
│       ├── publish-to-test-pypi.yml
│       └── build.yml
├── .gitignore
├── setup.py
├── SECURITY.md
├── README.md
├── data
│   └── nyc_bb_bicyclist_counts.csv
└── kaggle
    └── favorita-grocery-sales-forecasting
        └── kaggle.ipynb

/.gitattributes:
--------------------------------------------------------------------------------
1 | images/cbm_kaggle.png filter=lfs diff=lfs merge=lfs -text
2 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.md
2 | include setup.py
3 | recursive-include cbm *.py
4 | recursive-include src *.cpp *.h
--------------------------------------------------------------------------------
/cbm/_version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | __version__ = "0.0.4"
--------------------------------------------------------------------------------
/images/cbm_kaggle.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:203399280cbc76634867f1f6731f8f29030afdf8cd194cda0ba3d7435e9f8cff
3 | size 63445
4 | 
--------------------------------------------------------------------------------
/tests/test_common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | from sklearn.utils.estimator_checks import check_estimator
5 | 
6 | from cbm import CBM
7 | 
8 | def test_all_estimators():
9 |     check_estimator(CBM())
--------------------------------------------------------------------------------
/cbm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | from .CBM import CBM
5 | from .sklearn import DateEncoder, TemporalSplit
6 | from .CBMExplainer import CBMExplainer
7 | from ._version import __version__
8 | 
9 | __all__ = ['CBM', 'CBMExplainer', '__version__', 'DateEncoder', 'TemporalSplit']
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | # These are the assumed default build requirements from pip:
3 | # https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
4 | requires = ["setuptools>=40.8.0", "wheel", "pybind11", "numpy", "pandas", "scikit-learn", "lightgbm"] # , "interpret"]
5 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 | 
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 | 
5 | Resources:
6 | 
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) Microsoft Corporation.
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/src/pycbm.h:
--------------------------------------------------------------------------------
1 | /* Copyright (c) Microsoft Corporation.
2 | Licensed under the MIT License. */
3 | 
4 | #pragma once
5 | 
6 | #include <pybind11/pybind11.h>
7 | #include <pybind11/numpy.h>
8 | #include <vector>
9 | 
10 | #include "cbm.h"
11 | 
12 | namespace cbm
13 | {
14 | 
15 |     namespace py = pybind11;
16 | 
17 |     class PyCBM
18 |     {
19 | 
20 |         CBM _cbm;
21 | 
22 |     public:
23 |         PyCBM();
24 |         PyCBM(const std::vector<std::vector<double>> &f, double y_mean);
25 | 
26 |         void fit(
27 |             py::buffer y_b,
28 |             py::buffer x_b,
29 |             double y_mean,
30 |             py::buffer x_max_b,
31 |             double learning_rate_step_size,
32 |             size_t max_iterations,
33 |             size_t min_iterations_early_stopping,
34 |             double epsilon_early_stopping,
35 |             bool single_update_per_iteration,
36 |             std::string metric,
37 |             bool enable_bin_count);
38 | 
39 |         py::array_t<double> predict(py::buffer x_b, bool explain);
40 | 
41 |         const std::vector<std::vector<double>> &get_weights() const;
42 | 
43 |         void set_weights(std::vector<std::vector<double>> &);
44 | 
45 |         float get_y_mean() const;
46 | 
47 |         void set_y_mean(float mean);
48 | 
49 |         size_t get_iterations() const;
50 | 
51 |         const std::vector<std::vector<uint32_t>> &get_bin_count() const;
52 |     };
53 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |     "files.associations": {
3 |         "array": "cpp",
4 |         "atomic": "cpp",
5 |         "bit": "cpp",
6 |         "*.tcc": "cpp",
7 |         "cctype": "cpp",
8 |         "clocale": "cpp",
9 |         "cmath": "cpp",
10 |         "cstdarg": "cpp",
11 |         "cstddef": "cpp",
12 |         "cstdint": "cpp",
13 |         "cstdio": "cpp",
14 |         "cstdlib": "cpp",
15 |         "cwchar": "cpp",
16 |         "cwctype": "cpp",
17 |         "deque": "cpp",
18 |         "unordered_map": "cpp",
19 |         "unordered_set": "cpp",
20 |         "vector": "cpp",
21 |         "exception": "cpp",
22 |         "algorithm": "cpp",
23 |         "functional": "cpp",
24 |         "iterator": "cpp",
25 |         "memory": "cpp",
26 |         "memory_resource": "cpp",
27 |         "numeric": "cpp",
28 |         "optional": "cpp",
29 |         "random": "cpp",
30 |         "string": "cpp",
31 |         "string_view": "cpp",
32 |         "system_error": "cpp",
33 |         "tuple": "cpp",
34 |         "type_traits": "cpp",
35 |         "utility": "cpp",
36 |         "fstream": "cpp",
37 |         "initializer_list": "cpp",
38 |         "iosfwd": "cpp",
39 |         "iostream": "cpp",
40 |         "istream": "cpp",
41 |         "limits": "cpp",
42 |         "new": "cpp",
43 |         "ostream": "cpp",
44 |         "sstream": "cpp",
45 |         "stdexcept": "cpp",
46 |         "streambuf": "cpp",
47 |         "typeinfo": "cpp"
48 |     },
49 |     "jupyter.jupyterServerType": "local",
50 |     "C_Cpp.errorSquiggles": "Disabled",
51 |     "yaml.schemas": {
52 |         "https://json.schemastore.org/github-workflow": "/.github/workflows/**/*.yml"
53 |     }
54 | }
--------------------------------------------------------------------------------
/tests/test_sklearn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | import pytest
5 | import numpy as np
6 | import pandas as pd
7 | from sklearn import linear_model
8 | from sklearn.metrics import make_scorer, mean_squared_error
9 | from sklearn.preprocessing import OrdinalEncoder, KBinsDiscretizer
10 | from sklearn.pipeline import Pipeline, make_pipeline
11 | from sklearn.compose import ColumnTransformer, make_column_transformer
12 | from sklearn.model_selection import train_test_split, GridSearchCV
13 | 
14 | import lightgbm as lgb
15 | import timeit
16 | import cbm
17 | 
18 | def test_nyc_bicycle_sklearn():
19 |     # read data
20 |     bic = pd.read_csv(
21 |         'data/nyc_bb_bicyclist_counts.csv',
22 |         parse_dates=['Date'])
23 | 
24 |     X_train = bic.drop('BB_COUNT', axis=1)
25 |     y_train = bic['BB_COUNT']
26 | 
27 |     cats = make_column_transformer(
28 |         # https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OrdinalEncoder.html
29 |         # (OrdinalEncoder(dtype='int', handle_unknown='use_encoded_value', unknown_value=-1), # +1 in CBM code
30 |         #  ['store_nbr', 'item_nbr', 'onpromotion', 'family', 'class', 'perishable']),
31 | 
32 |         (cbm.DateEncoder('weekday'), ['Date', 'Date']),
33 |         (cbm.DateEncoder('month'), ['Date']),
34 |         (KBinsDiscretizer(n_bins=2, encode='ordinal'), ['HIGH_T', 'LOW_T']),
35 |         (KBinsDiscretizer(n_bins=5, encode='ordinal'), ['PRECIP']),
36 |     )
37 | 
38 |     cbm_model = cbm.CBM()
39 |     pipeline = make_pipeline(cats, cbm_model)
40 | 
41 |     cv = GridSearchCV(
42 |         pipeline,
43 |         param_grid={'columntransformer__kbinsdiscretizer-1__n_bins': np.arange(2, 15)},
44 |         # RMSE is a loss: lower is better, so the scorer must be negated for GridSearchCV
45 |         scoring=make_scorer(mean_squared_error, squared=False, greater_is_better=False),
46 |         cv=3
47 |     )
48 | 
49 |     cv.fit(X_train, y_train)
50 | 
51 |     print(cv.cv_results_['mean_test_score'])
52 |     print(cv.best_params_)
53 | 
54 |     cbm.CBMExplainer(cv.best_estimator_).plot_importance()
--------------------------------------------------------------------------------
/.github/workflows/publish-to-test-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI and TestPyPI
2 | 
3 | on: push
4 | 
5 | jobs:
6 |   build-windows:
7 |     name: Build on Windows
8 |     runs-on: windows-latest
9 |     steps:
10 |       - uses: actions/checkout@master
11 |       - name: Set up Python 3.8
12 |         uses: actions/setup-python@v1
13 |         with:
14 |           python-version: 3.8
15 | 
16 |       - name: Install pypa/build
17 |         run: python -m pip install build --user
18 | 
19 |       - name: Build a binary wheel and a source tarball (Windows)
20 |         run: python -m build --sdist --wheel --outdir dist/ .
21 | 
22 |       - name: Store the binary wheel
23 |         uses: actions/upload-artifact@v2
24 |         with:
25 |           name: python-package-distributions
26 |           path: dist
27 | 
28 |   build-linux:
29 |     name: Build on Linux
30 |     runs-on: ubuntu-20.04
31 |     steps:
32 |       - uses: actions/checkout@master
33 |       - name: Set up Python 3.8
34 |         uses: actions/setup-python@v1
35 |         with:
36 |           python-version: 3.8
37 | 
38 |       - name: Install pypa/build
39 |         run: python -m pip install build --user
40 | 
41 |       - name: Build a source tarball
42 |         run: python -m build --sdist --outdir dist/ .
43 | 44 | - name: Install from source tarball 45 | run: python -m pip install dist/*.tar.gz --user 46 | 47 | - name: Store the binary wheel 48 | uses: actions/upload-artifact@v2 49 | with: 50 | name: python-package-distributions 51 | path: dist 52 | 53 | publish: 54 | name: Publish to PyPI and TestPyPI 55 | runs-on: ubuntu-20.04 56 | needs: 57 | - build-windows 58 | - build-linux 59 | steps: 60 | - name: Download all the dists 61 | uses: actions/download-artifact@v2 62 | with: 63 | name: python-package-distributions 64 | path: dist/ 65 | 66 | # - name: Publish distribution 📦 to Test PyPI 67 | # uses: pypa/gh-action-pypi-publish@master 68 | # with: 69 | # password: ${{ secrets.TEST_PYPI_API_TOKEN }} 70 | # repository_url: https://test.pypi.org/legacy/ 71 | 72 | - name: Publish distribution 📦 to PyPI 73 | if: startsWith(github.ref, 'refs/tags') 74 | uses: pypa/gh-action-pypi-publish@master 75 | with: 76 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | on: [push, workflow_dispatch] 3 | jobs: 4 | build: 5 | name: Build 6 | runs-on: ${{ matrix.os }} 7 | strategy: 8 | matrix: 9 | os: [ubuntu-latest, macos-latest, windows-latest] 10 | python: [3.6, 3.7, 3.8] 11 | env: 12 | OS: ${{ matrix.os }} 13 | steps: 14 | - uses: actions/checkout@master 15 | - name: Setup Python 16 | uses: actions/setup-python@master 17 | with: 18 | python-version: ${{ matrix.python }} 19 | 20 | - name: Install dependencies 21 | run: python -m pip install pytest pytest-cov pybind11 numpy pandas scikit-learn lightgbm # interpret 22 | 23 | - name: Install dependency for lightgbm 24 | if: matrix.os == 'macos-latest' 25 | run: brew install libomp 26 | 27 | - name: Install package 28 | run: python -m pip install -e . 29 | 30 | - name: Generate coverage report 31 | run: pytest --cov=./ --cov-report=xml tests/ 32 | - name: Upload coverage to Codecov 33 | uses: codecov/codecov-action@v2 34 | with: 35 | token: ${{ secrets.CODECOV_TOKEN }} 36 | directory: ./coverage/reports/ 37 | env_vars: OS,PYTHON 38 | fail_ci_if_error: true 39 | files: ./coverage.xml 40 | flags: unittests 41 | name: codecov-umbrella 42 | # path_to_write_report: ./coverage/codecov_report.txt 43 | verbose: true 44 | 45 | quality-control: 46 | name: Quality Control 47 | runs-on: ubuntu-latest 48 | steps: 49 | - name: Checkout Code 50 | uses: actions/checkout@v2 51 | with: 52 | # Full git history is needed to get a proper list of changed files within `super-linter` 53 | fetch-depth: 0 54 | 55 | # linting 56 | - name: Lint Code Base 57 | uses: github/super-linter@v4 58 | continue-on-error: true 59 | env: 60 | VALIDATE_ALL_CODEBASE: false 61 | DEFAULT_BRANCH: main 62 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 63 | 64 | - name: CodeQL Init 65 | uses: github/codeql-action/init@v1 66 | with: 67 | languages: python 68 | 69 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 
70 |       # If this step fails, then you should remove it and run the build manually (see below)
71 |       - name: CodeQL Autobuild
72 |         uses: github/codeql-action/autobuild@v1
73 | 
74 |       - name: CodeQL Analysis
75 |         uses: github/codeql-action/analyze@v1
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import sysconfig
2 | 
3 | from setuptools import setup
4 | from setuptools.extension import Extension
5 | import platform
6 | 
7 | class get_pybind_include(object):
8 |     def __init__(self, user=False):
9 |         self.user = user
10 | 
11 |     def __str__(self):
12 |         import pybind11
13 | 
14 |         return pybind11.get_include(self.user)
15 | 
16 | def get_extra_compile_args():
17 |     if platform.system() == "Windows":
18 |         return []
19 | 
20 |     cflags = sysconfig.get_config_var("CFLAGS")
21 |     if cflags is None:
22 |         cflags = ""
23 | 
24 |     cflags = cflags.split() \
25 |         + ["-std=c++11", "-Wall", "-Wextra", "-march=native", "-msse2", "-ffast-math", "-mfpmath=sse"]
26 | 
27 |     if platform.system() == "Linux":
28 |         cflags += ["-fopenmp", "-lgomp"]
29 | 
30 |     return cflags
31 | 
32 | def get_libraries():
33 |     if platform.system() == "Windows":
34 |         return []
35 | 
36 |     return ["stdc++"]
37 | 
38 | from pathlib import Path
39 | this_directory = Path(__file__).parent
40 | long_description = (this_directory / "README.md").read_text()
41 | 
42 | setup(
43 |     name="cyclicbm",
44 |     version="0.0.9",
45 |     description="Cyclic Boosting Machines",
46 |     long_description=long_description,
47 |     long_description_content_type='text/markdown',
48 |     url="https://github.com/Microsoft/CBM",
49 |     author="Markus Cozowicz",
50 |     author_email="marcozo@microsoft.com",
51 |     license="MIT",
52 |     classifiers=[
53 |         "Development Status :: 4 - Beta",
54 |         "License :: OSI Approved :: MIT License",
55 |         "Programming Language :: Python :: 3.6",
56 |         "Programming Language :: Python :: 3.7",
57 |         "Programming Language :: Python :: 3.8",
58 |         "Intended Audience :: Developers",
59 |         "Intended Audience :: Science/Research",
60 |         "Topic :: Scientific/Engineering :: Mathematics",
61 |     ],
62 |     setup_requires=["pytest-runner"],
63 |     install_requires=["pybind11>=2.2", "numpy", "scikit-learn", "pandas"],
64 |     tests_require=["pytest", "lightgbm"], #, "interpret"],
65 |     extras_require={
66 |         'interactive': ['matplotlib>=2.2.0'],
67 |     },
68 |     packages=["cbm"],
69 |     ext_modules=[
70 |         Extension(
71 |             "cbm_cpp",
72 |             ["src/pycbm.cpp", "src/cbm.cpp"],
73 |             include_dirs=[get_pybind_include(), get_pybind_include(user=True), "src"],
74 |             extra_compile_args=get_extra_compile_args(),
75 |             libraries=get_libraries(),
76 |             language="c++",
77 |             extra_link_args=['-fopenmp'] if platform.system() == "Linux" else []
78 |         )
79 |     ],
80 |     headers=["src/pycbm.h", "src/cbm.h"],
81 |     zip_safe=False,
82 | )
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## Security
4 | 
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 | 
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 | 
9 | ## Reporting Security Issues
10 | 
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 | 
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 | 
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 | 
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 | 
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 | 
21 |   * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 |   * Full paths of source file(s) related to the manifestation of the issue
23 |   * The location of the affected source code (tag/branch/commit or direct URL)
24 |   * Any special configuration required to reproduce the issue
25 |   * Step-by-step instructions to reproduce the issue
26 |   * Proof-of-concept or exploit code (if possible)
27 |   * Impact of the issue, including how an attacker might exploit the issue
28 | 
29 | This information will help us triage your report more quickly.
30 | 
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 | 
33 | ## Preferred Languages
34 | 
35 | We prefer all communications to be in English.
36 | 
37 | ## Policy
38 | 
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 | 
41 | 
42 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cyclic Boosting Machines
2 | 
3 | ![Build](https://github.com/Microsoft/cbm/workflows/Build/badge.svg)
4 | ![Python](https://img.shields.io/pypi/pyversions/cyclicbm.svg)
5 | [![codecov](https://codecov.io/gh/microsoft/CBM/branch/main/graph/badge.svg?token=VRppFx2o8v)](https://codecov.io/gh/microsoft/CBM)
6 | [![PyPI version](https://badge.fury.io/py/cyclicbm.svg)](https://badge.fury.io/py/cyclicbm)
7 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
8 | [![Academic Paper](https://img.shields.io/badge/academic-paper-7fdcf7)](https://arxiv.org/abs/2002.03425)
9 | 
10 | This is an efficient, scikit-learn compatible implementation of [Cyclic Boosting](https://arxiv.org/abs/2002.03425), an explainable supervised machine learning algorithm, specifically suited to predicting count data such as sales and demand.
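At its core, a fitted model is multiplicative: a prediction is the global mean of the target scaled by one learned factor per feature, selected by the bin the example falls into. A minimal sketch of that idea (illustrative only, with made-up numbers; not the library's internal code):

```python
import numpy as np

# Hypothetical fitted state: the global target mean plus one weight vector
# per feature, with one entry per bin. All numbers here are invented.
y_mean = 2500.0
weights = [
    np.array([0.8, 1.2]),       # feature 0: two bins
    np.array([0.9, 1.0, 1.1]),  # feature 1: three bins
]

x = [1, 2]  # a single example, already encoded as per-feature bin indices

# prediction = global mean * product of the per-feature bin factors
y_hat = y_mean * np.prod([w[b] for w, b in zip(weights, x)])
print(round(y_hat, 2))  # 2500 * 1.2 * 1.1 = 3300.0
```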
11 | 
12 | ## Features
13 | 
14 | * Optimized for categorical features.
15 | * Continuous features are discretized using [pandas.qcut](https://pandas.pydata.org/docs/reference/api/pandas.qcut.html).
16 | * Date auto-expansion (weekday + month).
17 | * Feature importance plots: categorical, continuous and interactions.
18 | * Metrics to stop training: RMSE, L1, SMAPE.
19 | 
20 | ## Usage
21 | 
22 | The CBM model predicts by multiplying the global mean with the weight estimate for each bin and feature. The weights can thus be interpreted as a percentage increase or decrease relative to the global mean: e.g. a weight of 1.2 for the bin _Monday_ of the feature _Day-of-Week_ corresponds to a 20% increase of the target.
23 | 
24 | Install with
25 | 
26 | ```bash
27 | pip install cyclicbm
28 | ```
29 | 
30 | ```python
31 | import pandas as pd
32 | import cbm
33 | from sklearn.metrics import mean_squared_error
34 | 
35 | # load data using https://www.kaggle.com/c/demand-forecasting-kernels-only
36 | train = pd.read_csv('data/train.csv', parse_dates=['date'])
37 | test = pd.read_csv('data/test.csv', parse_dates=['date'])
38 | 
39 | # feature engineering
40 | min_date = train['date'].min()
41 | 
42 | def featurize(df):
43 |     out = pd.DataFrame({
44 |         # TODO: for prediction such features need separate modelling
45 |         'seasonal' : (df['date'] - min_date).dt.days // 60,
46 |         'store'    : df['store'],
47 |         'item'     : df['item'],
48 |         'date'     : df['date'],
49 |         # _X_ to mark interaction features
50 |         'item_X_month': df['item'].astype(str) + '_' + df['date'].dt.month.astype(str)
51 |     })
52 | 
53 |     return out
54 | 
55 | x_train_df = featurize(train)
56 | x_test_df = featurize(test)
57 | y_train = train['sales']
58 | 
59 | # model training
60 | model = cbm.CBM()
61 | model.fit(x_train_df, y_train)
62 | 
63 | # error on the training data
64 | y_pred_train = model.predict(x_train_df).flatten()
65 | print('RMSE', mean_squared_error(y_pred_train, y_train, squared=False))
66 | 
67 | # plotting
68 | model.plot_importance(figsize=(20, 20), continuous_features=['seasonal'])
69 | ```
70 | 
71 | ![Feature Importance Plot](images/cbm_kaggle.png)
72 | 
73 | ## Contributing
74 | 
75 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
76 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
77 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
78 | 
79 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
80 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
81 | provided by the bot. You will only need to do this once across all repos using our CLA.
82 | 
83 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
84 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
85 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
86 | 
87 | ## Trademarks
88 | 
89 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
90 | trademarks or logos is subject to and must follow
91 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
92 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
93 | Any use of third-party trademarks or logos is subject to those third parties' policies.
94 | 
--------------------------------------------------------------------------------
/cbm/sklearn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | import calendar
5 | import numpy as np
6 | 
7 | from sklearn.model_selection import TimeSeriesSplit
8 | from sklearn.utils import indexable
9 | from sklearn.base import BaseEstimator, TransformerMixin
10 | from sklearn.utils.validation import _check_feature_names_in
11 | 
12 | from datetime import timedelta
13 | 
14 | # TODO
15 | class TemporalSplit(TimeSeriesSplit):
16 |     def __init__(self, step=timedelta(days=1), n_splits=5, *, max_train_size=None, test_size=None, gap=0):
17 |         super().__init__(n_splits)
18 |         self.step = step
19 |         self.max_train_size = max_train_size
20 |         self.test_size = test_size
21 |         self.gap = gap
22 | 
23 |     def _create_date_ranges(self, start, end, step):
24 |         start_ = start
25 |         while start_ < end:
26 |             end_ = start_ + step
27 |             yield start_
28 |             start_ = end_
29 | 
30 |     def split(self, X, y=None, groups=None):
31 |         """Generate indices to split data into training and test set.
32 |         Parameters
33 |         ----------
34 |         X : array-like of shape (n_samples, n_features)
35 |             Training data, where `n_samples` is the number of samples
36 |             and `n_features` is the number of features.
37 |         y : array-like of shape (n_samples,)
38 |             Always ignored, exists for compatibility.
39 |         groups : array-like of shape (n_samples,)
40 |             Always ignored, exists for compatibility.
41 |         Yields
42 |         ------
43 |         train : ndarray
44 |             The training set indices for that split.
45 |         test : ndarray
46 |             The testing set indices for that split.
47 |         """
48 |         X, y, groups = indexable(X, y, groups)
49 | 
50 |         date_range = list(self._create_date_ranges(X.index.min(), X.index.max(), self.step))
51 |         n_samples = len(date_range)
52 |         n_splits = self.n_splits
53 |         n_folds = n_splits + 1
54 |         gap = self.gap
55 |         test_size = (
56 |             self.test_size if self.test_size is not None else n_samples // n_folds
57 |         )
58 | 
59 |         # Make sure we have enough samples for the given split parameters
60 |         if n_folds > n_samples:
61 |             raise ValueError(
62 |                 f"Cannot have number of folds={n_folds} greater"
63 |                 f" than the number of samples={n_samples}."
64 |             )
65 |         if n_samples - gap - (test_size * n_splits) <= 0:
66 |             raise ValueError(
67 |                 f"Too many splits={n_splits} for number of samples"
68 |                 f"={n_samples} with test_size={test_size} and gap={gap}."
69 |             )
70 | 
71 |         test_starts = range(n_samples - n_splits * test_size, n_samples, test_size)
72 | 
73 |         for test_start in test_starts:
74 |             train_end = test_start - gap
75 |             if self.max_train_size and self.max_train_size < train_end:
76 |                 yield (
77 |                     np.where(np.logical_and(X.index >= date_range[train_end - self.max_train_size], X.index <= date_range[train_end - 1]))[0],
78 |                     np.where(np.logical_and(X.index >= date_range[test_start], X.index <= date_range[test_start + test_size - 1]))[0]
79 |                 )
80 |             else:
81 |                 yield (
82 |                     np.where(X.index < date_range[train_end])[0],
83 |                     np.where(np.logical_and(X.index >= date_range[test_start], X.index <= date_range[test_start + test_size - 1]))[0]
84 |                 )
85 | 
86 | 
87 | # TODO: add unit test
88 | class DateEncoder(BaseEstimator, TransformerMixin):
89 |     def __init__(self, component = 'month' ):
90 |         if component == 'weekday':
91 |             self.categories_ = list(calendar.day_abbr)
92 |             self.column_to_ordinal_ = lambda col: col.dt.weekday.values
93 |         elif component == 'dayofyear':
94 |             # 1..366 to cover leap years
95 |             self.categories_ = list(range(1, 367))
96 |             self.column_to_ordinal_ = lambda col: col.dt.dayofyear.values
97 |         elif component == 'month':
98 |             self.categories_ = list(calendar.month_abbr)
99 |             self.column_to_ordinal_ = lambda col: col.dt.month.values
100 |         else:
101 |             raise ValueError("component must be one of 'weekday', 'dayofyear' or 'month'")
102 | 
103 |         self.component = component
104 | 
105 |     def fit(self, X, y = None):
106 |         self._validate_data(X, dtype="datetime64")
107 | 
108 |         return self
109 | 
110 |     def transform(self, X, y = None):
111 |         return X.apply(self.column_to_ordinal_, axis=0)
112 | 
--------------------------------------------------------------------------------
/src/cbm.cpp:
--------------------------------------------------------------------------------
1 | /* Copyright (c) Microsoft Corporation.
2 | Licensed under the MIT License. */
3 | 
4 | #include "cbm.h"
5 | 
6 | namespace cbm
7 | {
8 |     CBM::CBM() : _iterations(0)
9 |     {
10 |     }
11 | 
12 |     CBM::CBM(const std::vector<std::vector<double>> &f, double y_mean) :
13 |         _f(f), _y_mean(y_mean), _iterations(0)
14 |     {
15 |     }
16 | 
17 |     const std::vector<std::vector<double>> &CBM::get_weights() const
18 |     {
19 |         return _f;
20 |     }
21 | 
22 |     void CBM::set_weights(std::vector<std::vector<double>> &w)
23 |     {
24 |         _f = w;
25 |     }
26 | 
27 |     float CBM::get_y_mean() const
28 |     {
29 |         return _y_mean;
30 |     }
31 | 
32 |     void CBM::set_y_mean(float y_mean)
33 |     {
34 |         _y_mean = y_mean;
35 |     }
36 | 
37 |     size_t CBM::get_iterations() const
38 |     {
39 |         return _iterations;
40 |     }
41 | 
42 |     const std::vector<std::vector<uint32_t>> & CBM::get_bin_count() const {
43 |         return _bin_count;
44 |     }
45 | 
46 |     void CBM::fit(
47 |         const uint32_t *y,
48 |         const char *x_data,
49 |         size_t x_stride0,
50 |         size_t x_stride1,
51 |         size_t n_examples,
52 |         size_t n_features,
53 |         double y_mean,
54 |         const uint32_t *x_max,
55 |         double learning_rate_step_size,
56 |         size_t max_iterations,
57 |         size_t min_iterations_early_stopping,
58 |         double epsilon_early_stopping,
59 |         bool single_update_per_iteration,
60 |         uint8_t x_bytes_per_feature,
61 |         float (*metric)(const uint32_t*, const double*, size_t n_examples),
62 |         bool enable_bin_count)
63 |     {
64 |         switch (x_bytes_per_feature)
65 |         {
66 |         case 1:
67 |             if (enable_bin_count)
68 |                 fit_internal<uint8_t, true>(y, x_data, x_stride0, x_stride1, n_examples, n_features, y_mean, x_max, learning_rate_step_size, max_iterations, min_iterations_early_stopping, epsilon_early_stopping, single_update_per_iteration, metric);
69 |             else
70 |                 fit_internal<uint8_t, false>(y, x_data, x_stride0, x_stride1, n_examples, n_features, y_mean, x_max, learning_rate_step_size, max_iterations, min_iterations_early_stopping, epsilon_early_stopping, single_update_per_iteration, metric);
71 |             break;
72 |         case 2:
73 |             if (enable_bin_count)
74 |                 fit_internal<uint16_t, true>(y, x_data, x_stride0, x_stride1, n_examples, n_features, y_mean, x_max, learning_rate_step_size, max_iterations, min_iterations_early_stopping, epsilon_early_stopping, single_update_per_iteration, metric);
75 |             else
76 |                 fit_internal<uint16_t, false>(y, x_data, x_stride0, x_stride1, n_examples, n_features, y_mean, x_max, learning_rate_step_size, max_iterations, min_iterations_early_stopping, epsilon_early_stopping, single_update_per_iteration, metric);
77 |             break;
78 |         case 4:
79 |             if (enable_bin_count)
80 |                 fit_internal<uint32_t, true>(y, x_data, x_stride0, x_stride1, n_examples, n_features, y_mean, x_max, learning_rate_step_size, max_iterations, min_iterations_early_stopping, epsilon_early_stopping, single_update_per_iteration, metric);
81 |             else
82 |                 fit_internal<uint32_t, false>(y, x_data, x_stride0, x_stride1, n_examples, n_features, y_mean, x_max, learning_rate_step_size, max_iterations, min_iterations_early_stopping, epsilon_early_stopping, single_update_per_iteration, metric);
83 |             break;
84 |         }
85 |     }
86 | 
87 |     float metric_RMSE(const uint32_t* y, const double* y_hat, size_t n_examples)
88 |     {
89 |         double rmse = 0;
90 | #pragma omp parallel for schedule(static, 10000) reduction(+: rmse)
91 |         for (size_t i = 0; i < n_examples; i++)
92 |             rmse += (y_hat[i] - y[i]) * (y_hat[i] - y[i]);
93 | 
94 |         return std::sqrt(rmse);
95 |     }
96 | 
97 |     float metric_SMAPE(const uint32_t* y, const double* y_hat, size_t n_examples)
98 |     {
99 |         double smape = 0;
100 | #pragma omp parallel for schedule(static, 10000) reduction(+: smape)
101 |         for (size_t i = 0; i < n_examples; i++) {
102 |             if (y[i] == 0 && y_hat[i] == 0)
103 |                 continue;
104 |             smape += std::abs(y[i] - y_hat[i]) / (y[i] + y_hat[i]);
105 |         }
106 | 
107 |         return (200 * smape) / n_examples;
108 |     }
109 | 
110 | 
111 |     float metric_L1(const uint32_t* y, const double* y_hat, size_t n_examples)
112 |     {
113 |         double l1 = 0;
114 | #pragma omp parallel for schedule(static, 10000) reduction(+: l1)
115 |         for (size_t i = 0; i < n_examples; i++)
116 |             l1 += std::abs(y_hat[i] - y[i]);
117 | 
118 |         return l1 / n_examples;
119 |     }
120 | }
--------------------------------------------------------------------------------
/cbm/CBM.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | import cbm_cpp
5 | import numpy as np
6 | import pandas as pd
7 | 
8 | from sklearn.base import BaseEstimator
9 | from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
10 | from typing import List, Tuple, Union
11 | from pandas.api.types import CategoricalDtype
12 | 
13 | 
14 | class CBM(BaseEstimator):
15 |     cpp: cbm_cpp.PyCBM
16 | 
17 |     def __init__(self,
18 |         learning_rate_step_size:float = 1/100,
19 |         max_iterations:int = 100,
20 |         min_iterations_early_stopping:int = 20,
21 |         epsilon_early_stopping:float = 1e-3,
22 |         single_update_per_iteration:bool = True,
23 |         metric: str = 'rmse',
24 |         enable_bin_count: bool = False
25 |         ) -> None:
26 |         """Initialize the CBM model.
27 | 
28 |         Args:
29 |             learning_rate_step_size (float, optional): Step size used to anneal the learning rate over the iterations. Defaults to 1/100.
30 |             max_iterations (int, optional): Maximum number of boosting iterations. Defaults to 100.
31 |             min_iterations_early_stopping (int, optional): Minimum number of iterations before early stopping is considered. Defaults to 20.
32 |             epsilon_early_stopping (float, optional): Minimum metric improvement required to continue training. Defaults to 1e-3.
33 |             single_update_per_iteration (bool, optional): If True, perform only a single weight update per iteration. Defaults to True.
34 |             metric (str, optional): Metric used to determine when to stop. Defaults to 'rmse'. Options are rmse, smape, l1.
35 |             enable_bin_count (bool, optional): If True, record the number of training examples per bin. Defaults to False.
36 |         """
37 | 
38 |         self.learning_rate_step_size = learning_rate_step_size
39 |         self.max_iterations = max_iterations
40 |         self.min_iterations_early_stopping = min_iterations_early_stopping
41 |         self.epsilon_early_stopping = epsilon_early_stopping
42 |         self.single_update_per_iteration = single_update_per_iteration
43 |         self.enable_bin_count = enable_bin_count
44 | 
45 |         self.metric = metric
46 | 
47 |     def fit(self,
48 |         X: Union[np.ndarray, pd.DataFrame],
49 |         y: np.ndarray
50 |         ) -> "CBM":
51 | 
52 |         X, y = check_X_y(X, y, y_numeric=True)
53 | 
54 |         # pre-processing
55 |         y_mean = np.average(y)
56 | 
57 |         # determine max bin per categorical
58 |         x_max = X.max(axis=0)
59 |         x_max_max = x_max.max()
60 | 
61 |         if x_max_max <= 255:
62 |             self._x_type = "uint8"
63 |         elif x_max_max <= 65535:
64 |             self._x_type = "uint16"
65 |         elif x_max_max <= 4294967295:
66 |             self._x_type = "uint32"
67 |         else:
68 |             raise ValueError("Maximum of 4294967295 categories per feature supported")
69 | 
70 |         X = X.astype(self._x_type)
71 | 
72 |         self._cpp = cbm_cpp.PyCBM()
73 |         self._cpp.fit(
74 |             y.astype("uint32"),
75 |             X,
76 |             y_mean,
77 |             x_max.astype("uint32"),
78 |             self.learning_rate_step_size,
79 |             self.max_iterations,
80 |             self.min_iterations_early_stopping,
81 |             self.epsilon_early_stopping,
82 |             self.single_update_per_iteration,
83 |             self.metric,
84 |             self.enable_bin_count
85 |         )
86 | 
87 |         self.is_fitted_ = True
88 | 
89 |         return self
90 | 
91 |     def predict(self, X: np.ndarray, explain: bool = False):
92 |         X = check_array(X)
93 |         check_is_fitted(self, "is_fitted_")
94 | 
95 |         return self._cpp.predict(X.astype(self._x_type), explain)
96 | 
97 |     def update(self, weights: list, y_mean: float):
98 |         if "_cpp" not in self.__dict__:
99 |             self._cpp = cbm_cpp.PyCBM()
100 | 
101 |         x_max_max = max(map(len, weights))
102 |         if x_max_max <= 255:
103 |             self._x_type = "uint8"
104 |         elif x_max_max <= 65535:
105 |             self._x_type = "uint16"
106 |         elif x_max_max <= 4294967295:
107 |             self._x_type = "uint32"
108 |         else:
109 |             raise ValueError("Maximum of 4294967295 categories per feature supported")
110 | 
111 |         self._cpp.weights = weights
112 |         self._cpp.y_mean = y_mean
113 | 
114 |         self.is_fitted_ = True
115 | 
116 |     @property
117 |     def weights(self):
118 |         return self._cpp.weights
119 | 
120 |     @property
121 |     def y_mean(self):
122 |         return self._cpp.y_mean
123 | 
124 |     @property
125 |     def iterations(self):
126 |         return self._cpp.iterations
127 | 
128 |     @property
129 |     def bin_count(self):
130 |         return self._cpp.bin_count
--------------------------------------------------------------------------------
/tests/test_cbm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | 
4 | import pytest
5 | import numpy as np
6 | import pandas as pd
7 | from sklearn import linear_model
8 | from sklearn.metrics import mean_squared_error
9 | from sklearn.preprocessing import OneHotEncoder
10 | # from interpret.glassbox import ExplainableBoostingRegressor
11 | 
12 | import lightgbm as lgb
13 | import timeit
14 | import cbm
15 | 
16 | def test_poisson_random():
17 |     np.random.seed(42)
18 | 
19 |     # n = 1000 # test w/ 100, 1000, 10000, 100000
20 |     # features = 2
21 |     # bins = 2
22 | 
23 |     # y_base = np.random.poisson([[1, 3], [7, 20]], (n, features, bins))
24 | 
25 |     # x = np.random.randint(0, bins, (n, features), dtype='uint8')
26 | 
27 |     # y = np.zeros(n, dtype='uint32')
28 |     # # TODO: figure out proper take, take_along_axis, ...
29 | # for i, idx in enumerate(x): 30 | # y[i] = y_base[i, idx[0], idx[1]] 31 | 32 | def test_nyc_bicycle_validate(): 33 | np.random.seed(42) 34 | 35 | # read data 36 | bic = pd.read_csv('data/nyc_bb_bicyclist_counts.csv') 37 | bic['Date'] = pd.to_datetime(bic['Date']) 38 | bic['Weekday'] = bic['Date'].dt.weekday 39 | 40 | y = bic['BB_COUNT'].values.astype('uint32') 41 | 42 | # train/test split 43 | split = int(len(y) * 0.8) 44 | train_idx = np.arange(0, split) 45 | test_idx = np.arange(split + 1, len(y)) 46 | 47 | y_train = y[train_idx] 48 | y_test = y[test_idx] 49 | 50 | test_err_expected = {2: 449.848, 3: 533.465, 4: 503.399, 5: 534.738, 6: 527.854, 7: 529.942, 8: 597.041, 9: 615.646, 10: 560.182} 51 | train_err_expected = {2: 632.521, 3: 578.816, 4: 588.342, 5: 563.843, 6: 552.219, 7: 547.073, 8: 518.893, 9: 525.629, 10: 523.194} 52 | 53 | for bins in [2, 3, 4, 5, 6, 7, 8, 9, 10]: 54 | x = np.stack([ 55 | bic['Weekday'].values, 56 | pd.qcut(bic['HIGH_T'], bins).cat.codes, 57 | pd.qcut(bic['LOW_T'], bins).cat.codes, 58 | pd.qcut(bic['PRECIP'], 5, duplicates='drop').cat.codes 59 | ], 60 | axis=1)\ 61 | .astype('uint8') 62 | 63 | x_train = x[train_idx, ] 64 | x_test = x[test_idx, ] 65 | 66 | # fit CBM model 67 | model = cbm.CBM(single_update_per_iteration=False) 68 | model.fit(x_train, y_train) 69 | 70 | y_pred = model.predict(x_test) 71 | y_pred_train = model.predict(x_train) 72 | 73 | test_err = mean_squared_error(y_test, y_pred, squared=False) 74 | train_err = mean_squared_error(y_train, y_pred_train, squared=False) 75 | 76 | assert test_err_expected[bins] == pytest.approx(test_err, abs=1e-2) 77 | assert train_err_expected[bins] == pytest.approx(train_err, abs=1e-2) 78 | 79 | def test_nyc_bicycle(): 80 | np.random.seed(42) 81 | 82 | # read data 83 | bic = pd.read_csv('data/nyc_bb_bicyclist_counts.csv') 84 | bic['Date'] = pd.to_datetime(bic['Date']) 85 | bic['Weekday'] = bic['Date'].dt.weekday 86 | 87 | y = bic['BB_COUNT'].values.astype('uint32') 88 | 89 | # train/test split 90 | split = int(len(y) * 0.8) 91 | train_idx = np.arange(0, split) 92 | test_idx = np.arange(split + 1, len(y)) 93 | 94 | y_train = y[train_idx] 95 | y_test = y[test_idx] 96 | 97 | #### CBM 98 | 99 | # TODO: move to CBM.py and support pandas interface? 100 | # CBM can only handle categorical information 101 | # def histedges_equalN(x, nbin): 102 | # npt = len(x) 103 | # return np.interp(np.linspace(0, npt, nbin + 1), 104 | # np.arange(npt), 105 | # np.sort(x)) 106 | 107 | # def histedges_equalN(x, nbin): 108 | # return pd.qcut(x, nbin) 109 | 110 | print() 111 | # some hyper-parameter che.. 
ehm tuning
112 |     for bins in [2, 3, 4, 5, 6, 7, 8, 9, 10]:
113 |         x = np.stack([
114 |             bic['Weekday'].values,
115 |             pd.qcut(bic['HIGH_T'], bins).cat.codes,
116 |             pd.qcut(bic['LOW_T'], bins).cat.codes,
117 |             pd.qcut(bic['PRECIP'], 5, duplicates='drop').cat.codes
118 |             ],
119 |             axis=1)\
120 |             .astype('uint8')
121 | 
122 |         x_train = x[train_idx, ]
123 |         x_test = x[test_idx, ]
124 | 
125 |         start = timeit.default_timer()
126 | 
127 |         # fit CBM model
128 |         model = cbm.CBM(single_update_per_iteration=False)
129 |         model.fit(x_train, y_train)
130 | 
131 |         y_pred = model.predict(x_test)
132 |         y_pred_train = model.predict(x_train)
133 | 
134 |         # y_pred_explain[:, 0]  --> predictions
135 |         # y_pred_explain[:, 1:] --> explanations in terms of multiplicative deviation from the global mean
136 |         y_pred_explain = model.predict(x_test, explain=True)
137 | 
138 |         # print("x", x_test[:3])
139 |         # print("y", y_pred_explain[:3])
140 |         # print("f", model.weights)
141 | 
142 |         # validate data predictions line up
143 |         # print(np.all(y_pred[:, 0] == y_pred_explain[:,0]))
144 | 
145 |         print(f"CBM: {mean_squared_error(y_test, y_pred, squared=False):1.4f} (train {mean_squared_error(y_train, y_pred_train, squared=False):1.4f}) bins={bins} {timeit.default_timer() - start}sec")
146 |         print("weights", model.weights)
147 |         print(f"y_mean: {model.y_mean}")
148 |         # print(np.stack((y, y_pred))[:5,].transpose())
149 | 
150 |     model2 = cbm.CBM()
151 |     model2.update(model.weights, model.y_mean)
152 | 
153 |     y_pred_train2 = model2.predict(x_train)
154 | 
155 |     # print(y_pred_train[:10])
156 |     # print(y_pred_train2[:10])
157 |     assert np.allclose(y_pred_train, y_pred_train2)
158 | 
159 |     #### Poisson Regression
160 | 
161 |     # one-hot encode categorical
162 |     start = timeit.default_timer()
163 | 
164 |     x = bic['Weekday'].values.reshape((-1,1)).astype('uint8')
165 | 
166 |     enc = OneHotEncoder()
167 |     enc.fit(x)
168 |     x = enc.transform(x)
169 | 
170 |     x = np.hstack([x.todense(), bic[['HIGH_T', 'LOW_T', 'PRECIP']].values])
171 | 
172 |     clf = linear_model.PoissonRegressor()
173 |     clf.fit(x[train_idx, ], y_train)
174 | 
175 |     y_pred = clf.predict(x[test_idx, ])
176 |     print(f"Poisson Reg: {mean_squared_error(y_test, y_pred, squared=False):1.4f} {timeit.default_timer() - start}sec")
177 |     # print(np.stack((y, y_pred))[:5,].transpose())
178 | 
179 |     #### LightGBM
180 | 
181 |     start = timeit.default_timer()
182 | 
183 |     # train_data = lgb.Dataset(x, label=y, categorical_feature=[0, 1])
184 |     x = bic[['Weekday', 'HIGH_T', 'LOW_T', 'PRECIP']].values
185 | 
186 |     train_data = lgb.Dataset(x[train_idx, ], label=y_train, categorical_feature=[0])
187 |     model = lgb.train({
188 |         'objective': 'poisson',
189 |         'metric': ['poisson', 'rmse'],
190 |         'verbose': -1,
191 |         }, train_data)
192 | 
193 |     y_pred = model.predict(x[test_idx, ])
194 |     print(f"LightGBM Reg: {mean_squared_error(y_test, y_pred, squared=False):1.4f} {timeit.default_timer() - start}sec")
195 |     # print(np.stack((y, y_pred))[:5,].transpose())
196 | 
197 | 
198 |     #### EBM
199 |     # start = timeit.default_timer()
200 | 
201 |     # ebm = ExplainableBoostingRegressor(random_state=23, max_bins=8) #, outer_bags=25, inner_bags=25)
202 |     # ebm.fit(x[train_idx], y_train)
203 | 
204 |     # y_pred = ebm.predict(x[test_idx,])
205 |     # print(f"EBM: {mean_squared_error(y_test, y_pred, squared=False):1.4f} {timeit.default_timer() - start}sec")
--------------------------------------------------------------------------------
/cbm/CBMExplainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License. 3 | 4 | from multiprocessing.sharedctypes import Value 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from argparse import ArgumentTypeError 9 | from sklearn.pipeline import Pipeline 10 | from sklearn.compose import ColumnTransformer 11 | from sklearn.preprocessing import OrdinalEncoder, KBinsDiscretizer 12 | from typing import List, Tuple, Union 13 | 14 | from .sklearn import DateEncoder 15 | from .CBM import CBM 16 | 17 | import matplotlib.pyplot as plt 18 | 19 | from abc import ABC, abstractmethod 20 | 21 | # convenience registry to extract x-axis category labels from fitted transformers 22 | TRANSFORMER_INVERTER_REGISTRY = {} 23 | 24 | def transformer_inverter(transformer_class): 25 | def decorator(inverter_class): 26 | TRANSFORMER_INVERTER_REGISTRY[transformer_class] = inverter_class() 27 | return inverter_class 28 | 29 | return decorator 30 | 31 | # return categories for each feature_names_in_ 32 | class TransformerInverter(ABC): 33 | @abstractmethod 34 | def get_category_names(self, transformer): 35 | pass 36 | 37 | @transformer_inverter(DateEncoder) 38 | class DateEncoderInverter(TransformerInverter): 39 | def get_category_names(self, transformer): 40 | # for each feature we return the set of category labels 41 | return list(map(lambda _: transformer.categories_, transformer.feature_names_in_)) 42 | 43 | @transformer_inverter(KBinsDiscretizer) 44 | class KBinsDiscretizerInverter(TransformerInverter): 45 | def get_category_names(self, transformer): 46 | if transformer.encode != "ordinal": 47 | raise ValueError("Only ordinal encoding supported") 48 | 49 | # bin_edges is feature x bins 50 | def bin_edges_to_str(bin_edges: np.ndarray): 51 | return pd.IntervalIndex(pd.arrays.IntervalArray.from_breaks(np.concatenate([[-np.inf], bin_edges, [np.inf]]))) 52 | 53 | return list(map(bin_edges_to_str, transformer.bin_edges_)) 54 | 55 | @transformer_inverter(OrdinalEncoder) 56 | class OrdinalEncoderInverter(TransformerInverter): 57 | def get_category_names(self, transformer): 58 | return transformer.categories_ 59 | 60 | class CBMExplainerPlot: 61 | feature_index_: int 62 | feature_plots: List[dict] 63 | 64 | def __init__(self): 65 | self.feature_index_ = 0 66 | self.feature_plots_ = [] 67 | 68 | def add_feature_plot(self, col_name: str, x_axis: List): 69 | self.feature_plots_.append({ 70 | "col_name": col_name, 71 | "x_axis": x_axis, 72 | "feature_index": self.feature_index_, 73 | }) 74 | 75 | # increment feature index (assume they are added in order) 76 | self.feature_index_ += 1 77 | 78 | def _plot_categorical(self, ax: plt.Axes, vmin: float, vmax: float, weights: np.ndarray, col_name: str, x_axis, **kwargs): 79 | cmap = plt.get_cmap("RdYlGn") 80 | 81 | is_continuous = isinstance(x_axis, pd.IntervalIndex) 82 | 83 | # plot positive/negative impact (so 1.x to 0.x) 84 | weights -= 1 85 | 86 | alpha = 1 87 | if is_continuous: 88 | ax.plot(range(len(weights)), weights) 89 | alpha = 0.3 90 | 91 | # normalize for color map 92 | weights_normalized = (weights - vmin) / (vmax - vmin) 93 | 94 | # draw bars 95 | ax.bar(range(len(weights)), weights, color=cmap(weights_normalized), edgecolor='black', alpha=alpha) 96 | 97 | ax.set_ylim(vmin-0.1, vmax+0.1) 98 | 99 | ax.set_ylabel('% change') 100 | 101 | ax.set_xlabel(col_name) 102 | 103 | if not is_continuous: 104 | ax.set_xticks(range(len(x_axis))) 105 | ax.set_xticklabels(x_axis, rotation=45) 106 | 107 | # TODO: support 2D interaction plots 108 | # def _plot_importance_interaction(self, ax, feature_idx: int, 
vmin: float, vmax: float): 109 | # import matplotlib.pyplot as plt 110 | 111 | # weights = np.array(self.weights[feature_idx]) - 0 112 | 113 | # cat_df = pd.DataFrame( 114 | # [(int(c.split('_')[-1]), int(c.split('_')[1]), i) for i, c in enumerate(self._feature_categories[feature_idx])], 115 | # columns=['f-1', 'f1', 'idx']) 116 | 117 | # cat_df.sort_values(['f-1', 'f1'], inplace=True) 118 | 119 | # cat_df_1d = cat_df.pivot(index='f0', columns='f1', values='idx') 120 | 121 | # # resort index by mean weight value 122 | # zi = np.array(weights)[cat_df_1d.to_numpy()] 123 | 124 | # sort_order = np.argsort(np.max(zi, axis=0)) 125 | # cat_df_1d = cat_df_2d.reindex(cat_df_2d.index[sort_order]) 126 | 127 | # # construct data matrices 128 | # xi = cat_df_1d.columns 129 | # yi = cat_df_1d.index 130 | # zi = np.array(weights)[cat_df_1d.to_numpy()] 131 | 132 | # im = ax.imshow(zi, cmap=plt.get_cmap("RdYlGn"), aspect='auto', vmin=vmin, vmax=vmax) 133 | 134 | # cbar = ax.figure.colorbar(im, ax=ax) 135 | # cbar.ax.set_ylabel('% change', rotation=-91, va="bottom") 136 | 137 | # if self._feature_names is not None: 138 | # names = self._feature_names[feature_idx].split('_X_') 139 | # ax.set_ylabel(names[-1]) 140 | # ax.set_xlabel(names[0]) 141 | 142 | # # Show all ticks and label them with the respective list entries 143 | # ax.set_xticks(np.arange(len(xi)), labels=xi) 144 | # ax.set_yticks(np.arange(len(yi)), labels=yi) 145 | 146 | def plot(self, model: CBM, **kwargs) -> Tuple[plt.Figure, plt.Axes]: 147 | num_plots = max(self.feature_plots_, key=lambda d: d["feature_index"])["feature_index"] + 1 148 | n_features = len(model.weights) 149 | 150 | if num_plots != n_features: 151 | raise ValueError(f"Missing plots for some features ({num_plots} vs {n_features})") 152 | 153 | # setup plot 154 | n_rows = num_plots 155 | n_cols = 1 156 | 157 | fig, ax = plt.subplots(n_rows, n_cols, **kwargs) 158 | 159 | for i in range(num_plots): 160 | ax[i].set_axis_off() 161 | 162 | fig.suptitle(f'Response mean: {model.y_mean:0.2f} | Iterations {model.iterations}') 163 | 164 | # extract weights from model 165 | weights = model.weights 166 | 167 | # find global min/max 168 | vmin = np.min([np.min(w) for w in weights]) - 1 169 | vmax = np.max([np.max(w) for w in weights]) - 1 170 | 171 | for feature_idx in range(n_features): 172 | ax_sub = ax[feature_idx] 173 | ax_sub.set_axis_on() 174 | 175 | feature_weights = np.array(weights[feature_idx]) 176 | 177 | self._plot_categorical(ax_sub, vmin, vmax, feature_weights, **self.feature_plots_[feature_idx]) 178 | 179 | plt.tight_layout() 180 | 181 | return fig, ax 182 | 183 | class CBMExplainer: 184 | def __init__(self, pipeline: Pipeline): 185 | if not isinstance(pipeline, Pipeline): 186 | raise ArgumentTypeError("pipeline must be of type sklearn.pipeline.Pipeline") 187 | 188 | self.pipeline_ = pipeline 189 | 190 | def _plot_column_transformer(self, transformer: ColumnTransformer, plot: CBMExplainerPlot): 191 | # need to access transformers_ (vs transformers) to get the fitted transformer instance 192 | for (name, transformer, cols) in transformer.transformers_: 193 | # extension methods ;) 194 | transformer_inverter = TRANSFORMER_INVERTER_REGISTRY[type(transformer)] 195 | category_names = transformer_inverter.get_category_names(transformer) 196 | 197 | for (col_name, cat) in zip(cols, category_names): 198 | plot.add_feature_plot(col_name, cat) 199 | 200 | def plot_importance(self, **kwargs) -> Tuple[plt.Figure, plt.Axes]: 201 | plot = CBMExplainerPlot() 202 | 203 | # iterate through 
pipeline 204 | for (name, component) in self.pipeline_.steps[0:-1]: 205 | if isinstance(component, ColumnTransformer): 206 | self._plot_column_transformer(component, plot) 207 | 208 | model = self.pipeline_.steps[-1][1] 209 | return plot.plot(model, **kwargs) 210 | -------------------------------------------------------------------------------- /data/nyc_bb_bicyclist_counts.csv: -------------------------------------------------------------------------------- 1 | Date,HIGH_T,LOW_T,PRECIP,BB_COUNT 2 | 1-Apr-17,46.00,37.00,0.00,606 3 | 2-Apr-17,62.10,41.00,0.00,2021 4 | 3-Apr-17,63.00,50.00,0.03,2470 5 | 4-Apr-17,51.10,46.00,1.18,723 6 | 5-Apr-17,63.00,46.00,0.00,2807 7 | 6-Apr-17,48.90,41.00,0.73,461 8 | 7-Apr-17,48.00,43.00,0.01,1222 9 | 8-Apr-17,55.90,39.90,0.00,1674 10 | 9-Apr-17,66.00,45.00,0.00,2375 11 | 10-Apr-17,73.90,55.00,0.00,3324 12 | 11-Apr-17,80.10,62.10,0.00,3887 13 | 12-Apr-17,73.90,57.90,0.02,2565 14 | 13-Apr-17,64.00,48.90,0.00,3353 15 | 14-Apr-17,64.90,48.90,0.00,2942 16 | 15-Apr-17,64.90,52.00,0.00,2253 17 | 16-Apr-17,84.90,62.10,0.01,2877 18 | 17-Apr-17,73.90,64.00,0.01,3152 19 | 18-Apr-17,66.00,50.00,0.00,3415 20 | 19-Apr-17,52.00,45.00,0.01,1965 21 | 20-Apr-17,64.90,50.00,0.17,1567 22 | 21-Apr-17,53.10,48.00,0.29,1426 23 | 22-Apr-17,55.90,52.00,0.11,1318 24 | 23-Apr-17,64.90,46.90,0.00,2520 25 | 24-Apr-17,60.10,50.00,0.01,2544 26 | 25-Apr-17,54.00,50.00,0.91,611 27 | 26-Apr-17,59.00,54.00,0.34,1247 28 | 27-Apr-17,68.00,59.00,0.00,2959 29 | 28-Apr-17,82.90,57.90,0.00,3679 30 | 29-Apr-17,84.00,64.00,0.06,3315 31 | 30-Apr-17,64.00,54.00,0.00,2225 32 | 1-May-17,72.00,50.00,0.00,3084 33 | 2-May-17,73.90,66.90,0.00,3423 34 | 3-May-17,64.90,57.90,0.00,3342 35 | 4-May-17,63.00,50.00,0.00,3019 36 | 5-May-17,59.00,52.00,3.02,513 37 | 6-May-17,64.90,57.00,0.18,1892 38 | 7-May-17,54.00,48.90,0.01,3539 39 | 8-May-17,57.00,45.00,0.00,2886 40 | 9-May-17,61.00,48.00,0.00,2718 41 | 10-May-17,70.00,51.10,0.00,2810 42 | 11-May-17,61.00,51.80,0.00,2657 43 | 12-May-17,62.10,51.10,0.00,2640 44 | 13-May-17,51.10,45.00,1.31,151 45 | 14-May-17,64.90,46.00,0.02,1452 46 | 15-May-17,66.90,55.90,0.00,2685 47 | 16-May-17,78.10,57.90,0.00,3666 48 | 17-May-17,90.00,66.00,0.00,3535 49 | 18-May-17,91.90,75.00,0.00,3190 50 | 19-May-17,90.00,75.90,0.00,2952 51 | 20-May-17,64.00,55.90,0.01,2161 52 | 21-May-17,66.90,55.00,0.00,2612 53 | 22-May-17,61.00,54.00,0.59,768 54 | 23-May-17,68.00,57.90,0.00,3174 55 | 24-May-17,66.90,57.00,0.04,2969 56 | 25-May-17,57.90,55.90,0.58,488 57 | 26-May-17,73.00,55.90,0.10,2590 58 | 27-May-17,71.10,61.00,0.00,2609 59 | 28-May-17,71.10,59.00,0.00,2640 60 | 29-May-17,57.90,55.90,0.13,836 61 | 30-May-17,59.00,55.90,0.06,2301 62 | 31-May-17,75.00,57.90,0.03,2689 63 | 1-Jun-17,78.10,62.10,0.00,3468 64 | 2-Jun-17,73.90,60.10,0.01,3271 65 | 3-Jun-17,72.00,55.00,0.01,2589 66 | 4-Jun-17,68.00,60.10,0.09,1805 67 | 5-Jun-17,66.90,60.10,0.02,2171 68 | 6-Jun-17,55.90,53.10,0.06,1193 69 | 7-Jun-17,66.90,54.00,0.00,3211 70 | 8-Jun-17,68.00,59.00,0.00,3253 71 | 9-Jun-17,80.10,59.00,0.00,3401 72 | 10-Jun-17,84.00,68.00,0.00,3066 73 | 11-Jun-17,90.00,73.00,0.00,2465 74 | 12-Jun-17,91.90,77.00,0.00,2854 75 | 13-Jun-17,93.90,78.10,0.01,2882 76 | 14-Jun-17,84.00,69.10,0.29,2596 77 | 15-Jun-17,75.00,66.00,0.00,3510 78 | 16-Jun-17,68.00,66.00,0.00,2054 79 | 17-Jun-17,73.00,66.90,1.39,1399 80 | 18-Jun-17,84.00,72.00,0.01,2199 81 | 19-Jun-17,87.10,70.00,1.35,1648 82 | 20-Jun-17,82.00,72.00,0.03,3407 83 | 21-Jun-17,82.00,72.00,0.00,3304 84 | 22-Jun-17,82.00,70.00,0.00,3368 85 | 
23-Jun-17,82.90,75.90,0.04,2283 86 | 24-Jun-17,82.90,71.10,1.29,2307 87 | 25-Jun-17,82.00,69.10,0.00,2625 88 | 26-Jun-17,78.10,66.00,0.00,3386 89 | 27-Jun-17,75.90,61.00,0.18,3182 90 | 28-Jun-17,78.10,62.10,0.00,3766 91 | 29-Jun-17,81.00,68.00,0.00,3356 92 | 30-Jun-17,88.00,73.90,0.01,2687 93 | 1-Jul-17,84.90,72.00,0.23,1848 94 | 2-Jul-17,87.10,73.00,0.00,2467 95 | 3-Jul-17,87.10,71.10,0.45,2714 96 | 4-Jul-17,82.90,70.00,0.00,2296 97 | 5-Jul-17,84.90,71.10,0.00,3170 98 | 6-Jul-17,75.00,71.10,0.01,3065 99 | 7-Jul-17,79.00,68.00,1.78,1513 100 | 8-Jul-17,82.90,70.00,0.00,2718 101 | 9-Jul-17,81.00,69.10,0.00,3048 102 | 10-Jul-17,82.90,71.10,0.00,3506 103 | 11-Jul-17,84.00,75.00,0.00,2929 104 | 12-Jul-17,87.10,77.00,0.00,2860 105 | 13-Jul-17,89.10,77.00,0.00,2563 106 | 14-Jul-17,69.10,64.90,0.35,907 107 | 15-Jul-17,82.90,68.00,0.00,2853 108 | 16-Jul-17,84.90,70.00,0.00,2917 109 | 17-Jul-17,84.90,73.90,0.00,3264 110 | 18-Jul-17,87.10,75.90,0.00,3507 111 | 19-Jul-17,91.00,77.00,0.00,3114 112 | 20-Jul-17,93.00,78.10,0.01,2840 113 | 21-Jul-17,91.00,77.00,0.00,2751 114 | 22-Jul-17,91.00,78.10,0.57,2301 115 | 23-Jul-17,78.10,73.00,0.06,2321 116 | 24-Jul-17,69.10,63.00,0.74,1576 117 | 25-Jul-17,71.10,64.00,0.00,3191 118 | 26-Jul-17,75.90,66.00,0.00,3821 119 | 27-Jul-17,77.00,66.90,0.01,3287 120 | 28-Jul-17,84.90,73.00,0.00,3123 121 | 29-Jul-17,75.90,68.00,0.00,2074 122 | 30-Jul-17,81.00,64.90,0.00,3331 123 | 31-Jul-17,88.00,66.90,0.00,3560 124 | 1-Aug-17,91.00,72.00,0.00,3492 125 | 2-Aug-17,86.00,69.10,0.09,2637 126 | 3-Aug-17,86.00,70.00,0.00,3346 127 | 4-Aug-17,82.90,70.00,0.15,2400 128 | 5-Aug-17,77.00,70.00,0.30,3409 129 | 6-Aug-17,75.90,64.00,0.00,3130 130 | 7-Aug-17,71.10,64.90,0.76,804 131 | 8-Aug-17,77.00,66.00,0.00,3598 132 | 9-Aug-17,82.90,66.00,0.00,3893 133 | 10-Aug-17,82.90,69.10,0.00,3423 134 | 11-Aug-17,81.00,70.00,0.01,3148 135 | 12-Aug-17,75.90,64.90,0.11,4146 136 | 13-Aug-17,82.00,71.10,0.00,3274 137 | 14-Aug-17,80.10,70.00,0.00,3291 138 | 15-Aug-17,73.00,69.10,0.45,2149 139 | 16-Aug-17,84.90,70.00,0.00,3685 140 | 17-Aug-17,82.00,71.10,0.00,3637 141 | 18-Aug-17,81.00,73.00,0.88,1064 142 | 19-Aug-17,84.90,73.00,0.00,4693 143 | 20-Aug-17,81.00,70.00,0.00,2822 144 | 21-Aug-17,84.90,73.00,0.00,3088 145 | 22-Aug-17,88.00,75.00,0.30,2983 146 | 23-Aug-17,80.10,71.10,0.01,2994 147 | 24-Aug-17,79.00,66.00,0.00,3688 148 | 25-Aug-17,78.10,64.00,0.00,3144 149 | 26-Aug-17,77.00,62.10,0.00,2710 150 | 27-Aug-17,77.00,63.00,0.00,2676 151 | 28-Aug-17,75.00,63.00,0.00,3332 152 | 29-Aug-17,68.00,62.10,0.10,1472 153 | 30-Aug-17,75.90,61.00,0.01,3468 154 | 31-Aug-17,81.00,64.00,0.00,3279 155 | 1-Sep-17,70.00,55.00,0.00,2945 156 | 2-Sep-17,66.90,54.00,0.53,1876 157 | 3-Sep-17,69.10,60.10,0.74,1004 158 | 4-Sep-17,79.00,62.10,0.00,2866 159 | 5-Sep-17,84.00,70.00,0.01,3244 160 | 6-Sep-17,70.00,62.10,0.42,1232 161 | 7-Sep-17,71.10,59.00,0.01,3249 162 | 8-Sep-17,70.00,59.00,0.00,3234 163 | 9-Sep-17,69.10,55.00,0.00,2609 164 | 10-Sep-17,72.00,57.00,0.00,4960 165 | 11-Sep-17,75.90,55.00,0.00,3657 166 | 12-Sep-17,78.10,61.00,0.00,3497 167 | 13-Sep-17,82.00,64.90,0.06,2994 168 | 14-Sep-17,81.00,70.00,0.02,3013 169 | 15-Sep-17,81.00,66.90,0.00,3344 170 | 16-Sep-17,82.00,70.00,0.00,2560 171 | 17-Sep-17,80.10,70.00,0.00,2676 172 | 18-Sep-17,73.00,69.10,0.00,2673 173 | 19-Sep-17,78.10,69.10,0.22,2012 174 | 20-Sep-17,78.10,71.10,0.00,3296 175 | 21-Sep-17,80.10,71.10,0.00,3317 176 | 22-Sep-17,82.00,66.00,0.00,3297 177 | 23-Sep-17,86.00,68.00,0.00,2810 178 | 24-Sep-17,90.00,69.10,0.00,2543 179 | 
25-Sep-17,87.10,72.00,0.00,3276
180 | 26-Sep-17,82.00,69.10,0.00,3157
181 | 27-Sep-17,84.90,71.10,0.00,3216
182 | 28-Sep-17,78.10,66.00,0.00,3421
183 | 29-Sep-17,66.90,55.00,0.00,2988
184 | 30-Sep-17,64.00,55.90,0.00,1903
185 | 1-Oct-17,66.90,50.00,0.00,2297
186 | 2-Oct-17,72.00,52.00,0.00,3387
187 | 3-Oct-17,70.00,57.00,0.00,3386
188 | 4-Oct-17,75.00,55.90,0.00,3412
189 | 5-Oct-17,82.00,64.90,0.00,3312
190 | 6-Oct-17,81.00,69.10,0.00,2982
191 | 7-Oct-17,80.10,66.00,0.00,2750
192 | 8-Oct-17,77.00,72.00,0.22,1235
193 | 9-Oct-17,75.90,72.00,0.26,898
194 | 10-Oct-17,80.10,66.00,0.00,3922
195 | 11-Oct-17,75.00,64.90,0.06,2721
196 | 12-Oct-17,63.00,55.90,0.07,2411
197 | 13-Oct-17,64.90,52.00,0.00,2839
198 | 14-Oct-17,71.10,62.10,0.08,2021
199 | 15-Oct-17,72.00,66.00,0.01,2169
200 | 16-Oct-17,60.10,52.00,0.01,2751
201 | 17-Oct-17,57.90,43.00,0.00,2869
202 | 18-Oct-17,71.10,50.00,0.00,3264
203 | 19-Oct-17,70.00,55.90,0.00,3265
204 | 20-Oct-17,73.00,57.90,0.00,3169
205 | 21-Oct-17,78.10,57.00,0.00,2538
206 | 22-Oct-17,75.90,57.00,0.00,2744
207 | 23-Oct-17,73.90,64.00,0.00,3189
208 | 24-Oct-17,73.00,66.90,0.20,954
209 | 25-Oct-17,64.90,57.90,0.00,3367
210 | 26-Oct-17,57.00,53.10,0.00,2565
211 | 27-Oct-17,62.10,48.00,0.00,3150
212 | 28-Oct-17,68.00,55.90,0.00,2245
213 | 29-Oct-17,64.90,61.00,3.03,183
214 | 30-Oct-17,55.00,46.00,0.25,1428
215 | 31-Oct-17,54.00,44.00,0.00,2727
--------------------------------------------------------------------------------
/src/pycbm.cpp:
--------------------------------------------------------------------------------
/* Copyright (c) Microsoft Corporation.
   Licensed under the MIT License. */

#include "pycbm.h"

#include <sstream>

namespace py = pybind11;

namespace cbm
{
    PyCBM::PyCBM()
    {
    }

    PyCBM::PyCBM(const std::vector<std::vector<double>> &f, double y_mean) : _cbm(f, y_mean)
    {
    }

    void PyCBM::fit(
        py::buffer y_b,
        py::buffer x_b,
        double y_mean,
        py::buffer x_max_b,
        double learning_rate_step_size,
        size_t max_iterations,
        size_t min_iterations_early_stopping,
        double epsilon_early_stopping,
        bool single_update_per_iteration,
        std::string metric,
        bool enable_bin_count)
    {
        // can't just compare the format: Linux returns I, Windows returns L when using astype('uint32')
        // https://docs.python.org/3/library/struct.html#format-characters
        py::buffer_info y_info = y_b.request();
        if (!(y_info.itemsize == 4 && (y_info.format == "I" ||
                                       y_info.format == "H" ||
                                       y_info.format == "N" ||
                                       y_info.format == "B" ||
                                       y_info.format == "L")))
        {
            std::ostringstream oss;
            oss << "y must be of type unsigned integer/long with 4 bytes! Must use y.astype('uint32'). "
                << "Format: " << y_info.format << " Size: " << y_info.itemsize;
            throw std::runtime_error(oss.str());
        }

        if (y_info.ndim != 1)
            throw std::runtime_error("y must be 1-dimensional!");
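        // Input-preparation sketch (comment-only aid; `df` and `y_raw` are illustrative
        // names, not part of this file). The Python-side caller is expected to hand over
        // contiguous unsigned-integer buffers matching the checks in this function:
        //
        //     import numpy as np
        //     x = df.to_numpy().astype('uint8')               # or uint16/uint32
        //     y = np.ascontiguousarray(y_raw, dtype='uint32')
        //     x_max = x.max(axis=0).astype('uint32')          # largest bin index per feature
        //
        // The accepted struct format letters ("I", "H", "N", "B", "L") are the
        // platform-dependent codes for unsigned integers of the matching width.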
61 | << "Format: " << x_info.format << " Size: " << x_info.itemsize; 62 | 63 | throw std::runtime_error(oss.str().c_str()); 64 | } 65 | 66 | if (x_info.ndim != 2) 67 | throw std::runtime_error("x must be 2-dimensional!"); 68 | 69 | py::buffer_info x_max_info = x_max_b.request(); 70 | if (!(x_max_info.itemsize == 4 && (x_max_info.format == "I" || 71 | x_max_info.format == "H" || 72 | x_max_info.format == "N" || 73 | x_max_info.format == "B" || 74 | x_max_info.format == "L"))) 75 | throw std::runtime_error("Incompatible format: expected a uint32_t array for x_max!"); 76 | 77 | if (x_max_info.ndim != 1) 78 | throw std::runtime_error("Incompatible buffer dimension!"); 79 | 80 | if (y_info.shape[0] != x_info.shape[0]) 81 | throw std::runtime_error("len(y) != len(x)"); 82 | 83 | // data 84 | uint32_t *y = static_cast(y_info.ptr); 85 | uint32_t *x_max = static_cast(x_max_info.ptr); 86 | char *x_data = static_cast(x_info.ptr); 87 | 88 | // dimensions 89 | ssize_t n_examples = y_info.shape[0]; 90 | ssize_t n_features = x_info.shape[1]; 91 | 92 | float (*metric_func)(const uint32_t*, const double*, size_t n_examples) = nullptr; 93 | 94 | if (metric == "rmse") 95 | metric_func = metric_RMSE; 96 | else if (metric == "smape") 97 | metric_func = metric_SMAPE; 98 | else if (metric == "l1") 99 | metric_func = metric_L1; 100 | else 101 | throw std::runtime_error("Unknown metric!"); 102 | 103 | _cbm.fit( 104 | y, 105 | x_data, 106 | x_info.strides[0], 107 | x_info.strides[1], 108 | n_examples, 109 | n_features, 110 | y_mean, 111 | x_max, 112 | learning_rate_step_size, 113 | max_iterations, 114 | min_iterations_early_stopping, 115 | epsilon_early_stopping, 116 | single_update_per_iteration, 117 | (uint8_t)x_info.itemsize, 118 | metric_func, 119 | enable_bin_count); 120 | } 121 | 122 | py::array_t PyCBM::predict(py::buffer x_b, bool explain) 123 | { 124 | // TODO: fix error messages 125 | py::buffer_info x_info = x_b.request(); 126 | if (!(x_info.itemsize >= 1 && x_info.itemsize <= 4 && (x_info.format == "I" || 127 | x_info.format == "H" || 128 | x_info.format == "N" || 129 | x_info.format == "B" || 130 | x_info.format == "L"))) 131 | { 132 | std::ostringstream oss; 133 | oss << "x must be of type unsigned integer/long with 1, 2, 4 bytes! Must use x.astype('uint8') or uint16/uint32." 134 | << "Format: " << x_info.format << " Size: " << x_info.itemsize; 135 | 136 | throw std::runtime_error(oss.str().c_str()); 137 | } 138 | 139 | if (x_info.ndim != 2) 140 | throw std::runtime_error("Incompatible buffer dimension!"); 141 | 142 | char *x_data = static_cast(x_info.ptr); 143 | 144 | // TODO: handle ssize_t vs size_t 145 | ssize_t n_examples = x_info.shape[0]; 146 | ssize_t n_features = x_info.shape[1]; 147 | 148 | py::array_t out_data( 149 | {(int)n_examples, explain ? 
    py::array_t<double> PyCBM::predict(py::buffer x_b, bool explain)
    {
        // TODO: fix error messages
        py::buffer_info x_info = x_b.request();
        if (!(x_info.itemsize >= 1 && x_info.itemsize <= 4 && (x_info.format == "I" ||
                                                               x_info.format == "H" ||
                                                               x_info.format == "N" ||
                                                               x_info.format == "B" ||
                                                               x_info.format == "L")))
        {
            std::ostringstream oss;
            oss << "x must be of type unsigned integer/long with 1, 2, 4 bytes! Must use x.astype('uint8') or uint16/uint32. "
                << "Format: " << x_info.format << " Size: " << x_info.itemsize;

            throw std::runtime_error(oss.str().c_str());
        }

        if (x_info.ndim != 2)
            throw std::runtime_error("Incompatible buffer dimension!");

        char *x_data = static_cast<char *>(x_info.ptr);

        // TODO: handle ssize_t vs size_t
        ssize_t n_examples = x_info.shape[0];
        ssize_t n_features = x_info.shape[1];

        py::array_t<double> out_data(
            {(int)n_examples, explain ? (int)(1 + n_features) : 1});

        switch (x_info.itemsize)
        {
        case 1:
            if (explain)
                _cbm.predict<uint8_t, true>(x_data, x_info.strides[0], x_info.strides[1], n_examples, n_features, out_data.mutable_data());
            else
                _cbm.predict<uint8_t, false>(x_data, x_info.strides[0], x_info.strides[1], n_examples, n_features, out_data.mutable_data());
            break;

        case 2:
            if (explain)
                _cbm.predict<uint16_t, true>(x_data, x_info.strides[0], x_info.strides[1], n_examples, n_features, out_data.mutable_data());
            else
                _cbm.predict<uint16_t, false>(x_data, x_info.strides[0], x_info.strides[1], n_examples, n_features, out_data.mutable_data());
            break;

        case 4:
            if (explain)
                _cbm.predict<uint32_t, true>(x_data, x_info.strides[0], x_info.strides[1], n_examples, n_features, out_data.mutable_data());
            else
                _cbm.predict<uint32_t, false>(x_data, x_info.strides[0], x_info.strides[1], n_examples, n_features, out_data.mutable_data());
            break;
        }

        return out_data;
    }

    const std::vector<std::vector<double>> &PyCBM::get_weights() const
    {
        return _cbm.get_weights();
    }

    void PyCBM::set_weights(std::vector<std::vector<double>> &w)
    {
        _cbm.set_weights(w);
    }

    float PyCBM::get_y_mean() const
    {
        return _cbm.get_y_mean();
    }

    void PyCBM::set_y_mean(float y_mean)
    {
        _cbm.set_y_mean(y_mean);
    }

    size_t PyCBM::get_iterations() const
    {
        return _cbm.get_iterations();
    }

    const std::vector<std::vector<uint32_t>> &PyCBM::get_bin_count() const {
        return _cbm.get_bin_count();
    }
};

PYBIND11_MODULE(cbm_cpp, m)
{
    py::class_<cbm::PyCBM> estimator(m, "PyCBM");

    estimator.def(py::init([]()
                           { return new cbm::PyCBM(); }))
        .def("fit", &cbm::PyCBM::fit)
        .def("predict", &cbm::PyCBM::predict)
        .def_property("y_mean", &cbm::PyCBM::get_y_mean, &cbm::PyCBM::set_y_mean)
        .def_property("weights", &cbm::PyCBM::get_weights, &cbm::PyCBM::set_weights)
        .def_property_readonly("iterations", &cbm::PyCBM::get_iterations)
        .def_property_readonly("bin_count", &cbm::PyCBM::get_bin_count)
        .def(py::pickle(
            [](const cbm::PyCBM &p) { // __getstate__
                /* TODO: this does not include the feature pre-processing */
                /* Return a tuple that fully encodes the state of the object */
                return py::make_tuple(p.get_weights(), p.get_y_mean());
            },
            [](py::tuple t) { // __setstate__
                if (t.size() != 2)
                    throw std::runtime_error("Invalid state!");

                /* Create a new C++ instance */
                cbm::PyCBM p(t[0].cast<std::vector<std::vector<double>>>(),
                             t[1].cast<double>());

                return p;
            }));
}
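/* Persistence sketch (comment only): the py::pickle pair above round-trips just
   (weights, y_mean); per the TODO it does not capture feature pre-processing. When
   CBM sits at the end of an sklearn Pipeline (as in the kaggle notebook in this repo),
   pickle the fitted pipeline so the encoders travel with the model -- `pipeline` below
   is an illustrative name:

       import pickle
       blob = pickle.dumps(pipeline)    # encoders + PyCBM(weights, y_mean)
       restored = pickle.loads(blob)
       y_hat = restored.predict(df)     # df: the raw columns the pipeline was fit on
*/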
--------------------------------------------------------------------------------
/src/cbm.h:
--------------------------------------------------------------------------------
/* Copyright (c) Microsoft Corporation.
   Licensed under the MIT License. */

#pragma once

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>
// #include <cstdio>
// #include <chrono>
// using namespace std::chrono;

namespace cbm
{
    float metric_RMSE(const uint32_t* y, const double* y_hat, size_t n_examples);
    float metric_SMAPE(const uint32_t* y, const double* y_hat, size_t n_examples);
    float metric_L1(const uint32_t* y, const double* y_hat, size_t n_examples);

    class CBM
    {
        // n_features x max_bin[j] (jagged)
        std::vector<std::vector<double>> _f;
        double _y_mean;

        size_t _iterations;
        std::vector<std::vector<uint32_t>> _bin_count;

        template <typename T>
        void update_y_hat(
            std::vector<double> &y_hat,
            std::vector<std::vector<T>> &x,
            size_t n_examples,
            size_t n_features)
        {
            // predict
            y_hat.assign(n_examples, _y_mean);

            #pragma omp parallel for schedule(static, 10000)
            for (size_t i = 0; i < n_examples; i++)
                for (size_t j = 0; j < n_features; j++)
                    y_hat[i] *= _f[j][x[j][i]];
        }

        template <typename T>
        void update_y_hat_sum(
            std::vector<double> &y_hat,
            std::vector<std::vector<double>> &y_hat_sum,
            std::vector<std::vector<T>> &x,
            size_t n_examples,
            size_t n_features)
        {
            update_y_hat(y_hat, x, n_examples, n_features);

            // reset y_hat_sum
            #pragma omp parallel for
            for (size_t j = 0; j < n_features; j++)
                std::fill(y_hat_sum[j].begin(), y_hat_sum[j].end(), 0);

            // accumulate y_hat into y_hat_sum per bin
            #pragma omp parallel for
            for (size_t j = 0; j < n_features; j++)
                for (size_t i = 0; i < n_examples; i++)
                    // TODO: use log to stabilize?
                    y_hat_sum[j][x[j][i]] += y_hat[i];
        }

        template <typename T, bool enableBinCount>
        void fit_internal(
            const uint32_t *y,
            const char *x_data,
            size_t x_stride0,
            size_t x_stride1,
            size_t n_examples,
            size_t n_features,
            double y_mean,
            const uint32_t *x_max,
            double learning_rate_step_size,
            size_t max_iterations,
            size_t min_iterations_early_stopping,
            double epsilon_early_stopping,
            bool single_update_per_iteration,
            float (*metric)(const uint32_t*, const double*, size_t n_examples))
        {
            _y_mean = y_mean;

            // allocation
            std::vector<std::vector<T>> x(n_features);              // n_features x n_examples
            std::vector<std::vector<double>> g(n_features);         // n_features x max_bin[j] (jagged)
            std::vector<std::vector<uint64_t>> y_sum(n_features);   // n_features x max_bin[j] (jagged); wide type to avoid overflow
            std::vector<std::vector<double>> y_hat_sum(n_features); // n_features x max_bin[j] (jagged)
            std::vector<double> y_hat(n_examples);

            _f.resize(n_features);
            if (enableBinCount)
                _bin_count.resize(n_features);

            #pragma omp parallel for
            for (size_t j = 0; j < n_features; j++)
            {
                uint32_t max_bin = x_max[j];

                g[j].resize(max_bin + 1);
                _f[j].resize(max_bin + 1, 1);
                y_sum[j].resize(max_bin + 1);
                y_hat_sum[j].resize(max_bin + 1);

                if (enableBinCount)
                    _bin_count[j].resize(max_bin + 1, 0);

                // alloc and store columnar
                x[j].reserve(n_examples);
                for (size_t i = 0; i < n_examples; i++)
                {
                    // https://docs.python.org/3/c-api/buffer.html#complex-arrays
                    // strides are expressed in char, not in the target type
                    T x_ij = *reinterpret_cast<const T *>(x_data + i * x_stride0 + j * x_stride1);
                    x[j].push_back(x_ij);

                    y_sum[j][x_ij] += y[i];

                    if (enableBinCount)
                        _bin_count[j][x_ij]++;
                }
            }
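            // Update rule (existing behavior, restated): per bin k of feature j the fit
            // computes a multiplicative residual
            //
            //     g = y_sum[j][k] / y_hat_sum[j][k]                    // eqn. 2 (a)
            //     f[j][k] <- f[j][k] * g^learning_rate                 // eqn. 2 (b) + eqn. 4
            //
            // Bins whose observed label total exceeds the current predicted total are
            // scaled up, and vice versa; g = 1 leaves a weight unchanged. Worked example:
            // with y_sum = 120, y_hat_sum = 100 and learning_rate = 0.5, the factor is
            // exp(0.5 * ln 1.2) ~= 1.095.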
            // iterations
            double learning_rate = learning_rate_step_size;
            double rmse0 = std::numeric_limits<double>::infinity();

            for (_iterations = 0; _iterations < max_iterations; _iterations++, learning_rate += learning_rate_step_size)
            {
                // cap at 1
                if (learning_rate > 1)
                    learning_rate = 1;

                update_y_hat_sum(y_hat, y_hat_sum, x, n_examples, n_features);

                // compute g
                for (size_t j = 0; j < n_features; j++)
                {
                    for (size_t k = 0; k <= x_max[j]; k++)
                    {
                        // TODO: check if a bin is empty. might be better to remap/exclude the bins?
                        if (y_sum[j][k])
                        {
                            // improve stability
                            double g = (double)y_sum[j][k] / y_hat_sum[j][k]; // eqn. 2 (a)

                            // magic numbers found in Regularization section (worsen it quite a bit)
                            // double g = (2.0 * y_sum[j][k]) / (1.67834 * y_hat_sum[j][k]); // eqn. 2 (a)

                            if (learning_rate == 1)
                                _f[j][k] *= g;
                            else
                                _f[j][k] *= std::exp(learning_rate * std::log(g)); // eqn 2 (b) + eqn 4

                            if (!single_update_per_iteration) {
                                update_y_hat_sum(y_hat, y_hat_sum, x, n_examples, n_features);
                            }
                        }
                    }

                    // update y_hat_sum after every feature
                    update_y_hat_sum(y_hat, y_hat_sum, x, n_examples, n_features);
                }

                // prediction
                update_y_hat(y_hat, x, n_examples, n_features);

                double rmse = metric(y, y_hat.data(), n_examples);

                // check for early stopping
                // TODO: expose minimum number of rounds
                if (_iterations > min_iterations_early_stopping &&
                    (rmse > rmse0 || (rmse0 - rmse) < epsilon_early_stopping))
                {
                    // TODO: record diagnostics?
                    // printf("early stopping %1.4f vs %1.4f after t=%d\n", rmse, rmse0, (int)_iterations);
                    break;
                }
                rmse0 = rmse;
            }
        }

    public:
        CBM();
        CBM(const std::vector<std::vector<double>> &f, double y_mean);

        void fit(
            const uint32_t *y,
            const char *x_data,
            size_t x_stride0,
            size_t x_stride1,
            size_t n_examples,
            size_t n_features,
            double y_mean,
            const uint32_t *x_max,
            double learning_rate_step_size,
            size_t max_iterations,
            size_t min_iterations_early_stopping,
            double epsilon_early_stopping,
            bool single_update_per_iteration,
            uint8_t x_bytes_per_feature,
            float (*metric)(const uint32_t*, const double*, size_t n_examples),
            bool enable_bin_count);
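        // Parameter note (no behavior change): learning_rate_step_size doubles as a
        // warm-up schedule -- the rate starts at the step size and grows linearly each
        // iteration until capped at 1, so e.g. 1/500 reaches full-strength updates after
        // 500 iterations. Smaller steps give more, gentler iterations before early
        // stopping can trigger; the kaggle notebook in this repo uses 1/500 and 1/64000.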
        template <typename T, bool explain>
        void predict(
            const char *x_data,
            size_t x_stride0,
            size_t x_stride1,
            size_t n_examples,
            size_t n_features,
            double *out_data)
        {
            if (n_features != _f.size())
                throw std::runtime_error("Features need to match!");

            // column-wise oriented output data
            double *out_y_hat = out_data;
            std::fill(out_y_hat, out_y_hat + n_examples, _y_mean);

            #pragma omp parallel for schedule(static, 10000)
            for (size_t i = 0; i < n_examples; i++)
            {
                double &y_hat_i = *(out_y_hat + i);

                for (size_t j = 0; j < n_features; j++)
                {
                    // TODO: simd gather?
                    T x_ij = *reinterpret_cast<const T *>(x_data + i * x_stride0 + j * x_stride1);
                    y_hat_i *= _f[j][x_ij];

                    if (explain)
                    {
                        *(out_data + (j + 1) * n_examples + i) = _f[j][x_ij];
                    }
                }
            }
        }

        const std::vector<std::vector<double>> &get_weights() const;
        void set_weights(std::vector<std::vector<double>> &);

        float get_y_mean() const;
        void set_y_mean(float mean);

        size_t get_iterations() const;

        const std::vector<std::vector<uint32_t>> &get_bin_count() const;
    };
}
--------------------------------------------------------------------------------
/kaggle/favorita-grocery-sales-forecasting/kaggle.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 1,
6 |    "id": "2a36103a",
7 |    "metadata": {},
8 |    "outputs": [
9 |     {
10 |      "data": {
11 |       "text/html": [
12 |        ""
13 |       ],
14 |       "text/plain": [
15 |        ""
16 |       ]
17 |      },
18 |      "metadata": {},
19 |      "output_type": "display_data"
20 |     },
21 |     {
22 |      "data": {
23 |       "text/html": [
24 |        ""
25 |       ],
26 |       "text/plain": [
27 |        ""
28 |       ]
29 |      },
30 |      "metadata": {},
31 |      "output_type": "display_data"
32 |     },
33 |     {
34 |      "data": {
35 |       "text/html": [
36 |        ""
37 |       ],
38 |       "text/plain": [
39 |        ""
40 |       ]
41 |      },
42 |      "metadata": {},
43 |      "output_type": "display_data"
44 |     }
45 |    ],
46 |    "source": [
47 |     "%load_ext autoreload\n",
48 |     "%autoreload 2\n",
49 |     "\n",
50 |     "import numpy as np\n",
51 |     "import pandas as pd\n",
52 |     "import matplotlib.pyplot as plt\n",
53 |     "\n",
54 |     "from IPython.core.display import display, HTML\n",
55 |     "display(HTML(\"\"))\n",
56 |     "display(HTML(\"\"))\n",
57 |     "display(HTML(\"\"))"
58 |    ]
59 |   },
60 |   {
61 |    "cell_type": "code",
62 |    "execution_count": 2,
63 |    "id": "04b0763f",
64 |    "metadata": {},
65 |    "outputs": [],
66 |    "source": [
67 |     "# \n",
68 |     "# train = pd.read_csv('data/train.csv', \n",
69 |     "#                     parse_dates=['date'], \n",
70 |     "#                     index_col='id', \n",
71 |     "#                     dtype={\n",
72 |     "#                         # 'date': np.datetime64, \n",
73 |     "#                         'store_nbr': np.short,\n",
74 |     "#                         'item_nbr': np.int64,\n",
75 |     "#                         'unit_sales': np.float64\n",
76 |     "#                     },\n",
77 |     "#                     converters={'onpromotion': lambda x: 'T' if x == 'True' else ('F' if x == 'False' else 'U')}\n",
78 |     "#                     )\n",
79 |     "#\n",
80 |     "# train.merge(items).to_parquet('data/train_items.parquet')\n",
81 |     "# train['class'] = train['class'].astype('str')\n",
82 |     "# train['item_nbr'] = train['item_nbr'].astype('str')\n",
83 |     "# train.to_parquet('data/train.parquet')"
84 |    ]
85 |   },
86 |   {
87 |    "cell_type": "code",
88 |    "execution_count": 14,
89 |    "id": "7412ac39",
90 |    "metadata": {},
91 |    "outputs": [
92 |     {
93 |      "data": {
94 |       "text/html": [
95 | 
\n", 96 | "\n", 109 | "\n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | "
datestore_nbritem_nbrunit_salesonpromotionfamilyclassperishable
02013-01-01251036657.0UBREAD/BAKERY27121
12013-01-0211036652.0UBREAD/BAKERY27121
22013-01-0221036655.0UBREAD/BAKERY27121
32013-01-0231036656.0UBREAD/BAKERY27121
42013-01-0241036652.0UBREAD/BAKERY27121
\n", 181 | "
" 182 | ], 183 | "text/plain": [ 184 | " date store_nbr item_nbr unit_sales onpromotion family \\\n", 185 | "0 2013-01-01 25 103665 7.0 U BREAD/BAKERY \n", 186 | "1 2013-01-02 1 103665 2.0 U BREAD/BAKERY \n", 187 | "2 2013-01-02 2 103665 5.0 U BREAD/BAKERY \n", 188 | "3 2013-01-02 3 103665 6.0 U BREAD/BAKERY \n", 189 | "4 2013-01-02 4 103665 2.0 U BREAD/BAKERY \n", 190 | "\n", 191 | " class perishable \n", 192 | "0 2712 1 \n", 193 | "1 2712 1 \n", 194 | "2 2712 1 \n", 195 | "3 2712 1 \n", 196 | "4 2712 1 " 197 | ] 198 | }, 199 | "execution_count": 14, 200 | "metadata": {}, 201 | "output_type": "execute_result" 202 | } 203 | ], 204 | "source": [ 205 | "# scikit-based\n", 206 | "\n", 207 | "train_all = pd.read_parquet('data/train_items.parquet')\n", 208 | "train_all.head()" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 15, 214 | "id": "9400687b", 215 | "metadata": {}, 216 | "outputs": [ 217 | { 218 | "data": { 219 | "text/plain": [ 220 | "(125497040, 8)" 221 | ] 222 | }, 223 | "execution_count": 15, 224 | "metadata": {}, 225 | "output_type": "execute_result" 226 | } 227 | ], 228 | "source": [ 229 | "train_all.shape" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 3, 235 | "id": "c62e0a47", 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "data": { 240 | "text/plain": [ 241 | "0 1\n", 242 | "1 2\n", 243 | "2 2\n", 244 | "3 2\n", 245 | "4 2\n", 246 | "Name: date, dtype: int64" 247 | ] 248 | }, 249 | "execution_count": 3, 250 | "metadata": {}, 251 | "output_type": "execute_result" 252 | } 253 | ], 254 | "source": [ 255 | "# train['Weekday'] = train['date'].dt.dayofweek_str\n", 256 | "train.head()['date'].dt.dayofweek" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 9, 262 | "id": "dad0788e", 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "from sklearn.model_selection import TimeSeriesSplit\n", 267 | "from sklearn.utils import indexable\n", 268 | "\n", 269 | "from datetime import datetime, timedelta\n", 270 | "\n", 271 | "# TODO\n", 272 | "class TemporalSplit(TimeSeriesSplit):\n", 273 | " def __init__(self, step=timedelta(days=1), n_splits=5, *, max_train_size=None, test_size=None, gap=0):\n", 274 | " super().__init__(n_splits)\n", 275 | " self.step = step\n", 276 | " self.max_train_size = max_train_size\n", 277 | " self.test_size = test_size\n", 278 | " self.gap = gap\n", 279 | "\n", 280 | " def _create_date_ranges(self, start, end, step):\n", 281 | " start_ = start\n", 282 | " while start_ < end:\n", 283 | " end_ = start_ + step\n", 284 | " yield start_\n", 285 | " start_ = end_\n", 286 | " \n", 287 | " def split(self, X, y=None, groups=None):\n", 288 | " \"\"\"Generate indices to split data into training and test set.\n", 289 | " Parameters\n", 290 | " ----------\n", 291 | " X : array-like of shape (n_samples, n_features)\n", 292 | " Training data, where `n_samples` is the number of samples\n", 293 | " and `n_features` is the number of features.\n", 294 | " y : array-like of shape (n_samples,)\n", 295 | " Always ignored, exists for compatibility.\n", 296 | " groups : array-like of shape (n_samples,)\n", 297 | " Always ignored, exists for compatibility.\n", 298 | " Yields\n", 299 | " ------\n", 300 | " train : ndarray\n", 301 | " The training set indices for that split.\n", 302 | " test : ndarray\n", 303 | " The testing set indices for that split.\n", 304 | " \"\"\"\n", 305 | " X, y, groups = indexable(X, y, groups)\n", 306 | " \n", 307 | " date_range = 
list(self._create_date_ranges(X.index.min(), X.index.max(), self.step))\n", 308 | " n_samples = len(date_range)\n", 309 | " n_splits = self.n_splits\n", 310 | " n_folds = n_splits + 1\n", 311 | " gap = self.gap\n", 312 | " test_size = (\n", 313 | " self.test_size if self.test_size is not None else n_samples // n_folds\n", 314 | " )\n", 315 | "\n", 316 | " # Make sure we have enough samples for the given split parameters\n", 317 | " if n_folds > n_samples:\n", 318 | " raise ValueError(\n", 319 | " f\"Cannot have number of folds={n_folds} greater\"\n", 320 | " f\" than the number of samples={n_samples}.\"\n", 321 | " )\n", 322 | " if n_samples - gap - (test_size * n_splits) <= 0:\n", 323 | " raise ValueError(\n", 324 | " f\"Too many splits={n_splits} for number of samples\"\n", 325 | " f\"={n_samples} with test_size={test_size} and gap={gap}.\"\n", 326 | " )\n", 327 | "\n", 328 | " # indices = np.arange(n_samples)\n", 329 | " test_starts = range(n_samples - n_splits * test_size, n_samples, test_size)\n", 330 | "\n", 331 | " for test_start in test_starts:\n", 332 | " train_end = test_start - gap\n", 333 | " if self.max_train_size and self.max_train_size < train_end:\n", 334 | " yield (\n", 335 | " # TODO: unit test\n", 336 | " # TODO: not sure why np.where returns a tuple.\n", 337 | " np.where(np.logical_and(X.index >= date_range[train_end - self.max_train_size], X.index <= date_range[train_end - 1]))[0],\n", 338 | " np.where(np.logical_and(X.index >= date_range[test_start], X.index <= date_range[test_start + test_size - 1]))[0]\n", 339 | " # indices[train_end - self.max_train_size : train_end],\n", 340 | " # indices[test_start : test_start + test_size],\n", 341 | " )\n", 342 | " else:\n", 343 | " yield (\n", 344 | " np.where(X.index < date_range[train_end])[0],\n", 345 | " np.where(np.logical_and(X.index >= date_range[test_start], X.index <= date_range[test_start + test_size - 1]))[0]\n", 346 | " # indices[:train_end],\n", 347 | " # indices[test_start : test_start + test_size],\n", 348 | " )\n", 349 | "\n", 350 | "# cv = list(TemporalSplit(n_splits=3, test_size=100).split(train_idx))\n", 351 | "# for s in cv:\n", 352 | "# print(s[0])\n", 353 | "# print(s[0][0])\n", 354 | "# print(f'Train: {train_idx.iloc[s[0]].index.min()} - {train_idx.iloc[s[0]].index.max()}')\n", 355 | "# print(f'Test: {train_idx.iloc[s[1]].index.min()} - {train_idx.iloc[s[1]].index.max()}')\n", 356 | "# print()\n", 357 | " \n", 358 | "# cv = list(TimeSeriesSplit(n_splits=3, test_size=100).split(train_idx))\n", 359 | "# for s in cv:\n", 360 | "# print(s[0])\n", 361 | "# print(s[0][0])\n", 362 | "# print(f'Train: {train_idx.iloc[s[0]].index.min()} - {train_idx.iloc[s[0]].index.max()}')\n", 363 | "# print(f'Test: {train_idx.iloc[s[1]].index.min()} - {train_idx.iloc[s[1]].index.max()}')\n", 364 | "# print()" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": 101, 370 | "id": "c807d974", 371 | "metadata": {}, 372 | "outputs": [ 373 | { 374 | "data": { 375 | "text/html": [ 376 | "
\n", 377 | "\n", 390 | "\n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | "
datestore_nbritem_nbrunit_salesonpromotionfamilyclassperishable
date
2013-01-012013-01-01256383272.0UGROCERY I10840
2013-01-012013-01-01254643857.0UBREAD/BAKERY27181
2013-01-012013-01-01253603145.0UBREAD/BAKERY27021
2013-01-012013-01-012555725610.0UEGGS25021
2013-01-012013-01-01253585152.0UGROCERY I10240
\n", 473 | "
" 474 | ], 475 | "text/plain": [ 476 | " date store_nbr item_nbr unit_sales onpromotion \\\n", 477 | "date \n", 478 | "2013-01-01 2013-01-01 25 638327 2.0 U \n", 479 | "2013-01-01 2013-01-01 25 464385 7.0 U \n", 480 | "2013-01-01 2013-01-01 25 360314 5.0 U \n", 481 | "2013-01-01 2013-01-01 25 557256 10.0 U \n", 482 | "2013-01-01 2013-01-01 25 358515 2.0 U \n", 483 | "\n", 484 | " family class perishable \n", 485 | "date \n", 486 | "2013-01-01 GROCERY I 1084 0 \n", 487 | "2013-01-01 BREAD/BAKERY 2718 1 \n", 488 | "2013-01-01 BREAD/BAKERY 2702 1 \n", 489 | "2013-01-01 EGGS 2502 1 \n", 490 | "2013-01-01 GROCERY I 1024 0 " 491 | ] 492 | }, 493 | "execution_count": 101, 494 | "metadata": {}, 495 | "output_type": "execute_result" 496 | } 497 | ], 498 | "source": [ 499 | "train = train_all.sample(int(10e6))\n", 500 | "train = train.set_index('date', drop=False).sort_index()\n", 501 | "train.loc[train['unit_sales'] < 0, 'unit_sales'] = 0\n", 502 | "train.head()" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 130, 508 | "id": "890b23a1", 509 | "metadata": {}, 510 | "outputs": [ 511 | { 512 | "data": { 513 | "text/plain": [ 514 | "(10000000, 8)" 515 | ] 516 | }, 517 | "execution_count": 130, 518 | "metadata": {}, 519 | "output_type": "execute_result" 520 | } 521 | ], 522 | "source": [ 523 | "# train = train_all.copy()\n", 524 | "# train = train.set_index('date', drop=False).sort_index()\n", 525 | "# train.loc[train['unit_sales'] < 0, 'unit_sales'] = 0\n", 526 | "\n", 527 | "train.shape" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": 134, 533 | "id": "a6d28ce3", 534 | "metadata": {}, 535 | "outputs": [ 536 | { 537 | "name": "stdout", 538 | "output_type": "stream", 539 | "text": [ 540 | "cross-val 53.315948724746704\n", 541 | "[0.76249221 0.74114169 0.73643648]\n" 542 | ] 543 | } 544 | ], 545 | "source": [ 546 | "from sklearn.preprocessing import OrdinalEncoder, KBinsDiscretizer\n", 547 | "from sklearn.pipeline import Pipeline, make_pipeline\n", 548 | "from sklearn.compose import ColumnTransformer, make_column_transformer\n", 549 | "import time\n", 550 | "import cbm\n", 551 | "from sklearn.metrics import mean_squared_error, make_scorer, mean_squared_log_error\n", 552 | "from sklearn.model_selection import cross_val_score\n", 553 | "from sklearn.base import BaseEstimator, TransformerMixin\n", 554 | "import calendar\n", 555 | "\n", 556 | "class DateEncoder(BaseEstimator, TransformerMixin):\n", 557 | " def __init__(self, feature_name, component = 'month' ):\n", 558 | " self.feature_name = feature_name\n", 559 | " \n", 560 | " if component == 'day':\n", 561 | " self.categories = calendar.day_abbr\n", 562 | " self.column_to_ordinal = lambda col: col.dayofweek.values\n", 563 | " elif component == 'month':\n", 564 | " self.categories = calendar.month_abbr\n", 565 | " self.column_to_ordinal = lambda col: col.month.values\n", 566 | " else:\n", 567 | " raise ValueError('component must be either day or month')\n", 568 | " \n", 569 | " self.component = component\n", 570 | " \n", 571 | " def fit(self, X, y = None):\n", 572 | " return self\n", 573 | " \n", 574 | " def transform(self, X, y = None):\n", 575 | " return self.column_to_ordinal(X.iloc[:,0].dt)[:,np.newaxis]\n", 576 | "\n", 577 | "# Talk to Ilya about this use-case\n", 578 | "cats = make_column_transformer(\n", 579 | " # TODO: pass pipeline to CBM model + inspect pipeline to correlate for plotting\n", 580 | " \n", 581 | " # 
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OrdinalEncoder.html\n", 582 | " (OrdinalEncoder(dtype='int', handle_unknown='use_encoded_value', unknown_value=-1), # +1 in CBM code\n", 583 | " ['store_nbr', 'item_nbr', 'onpromotion', 'family', 'class', 'perishable']),\n", 584 | " \n", 585 | " (DateEncoder('month', 'month'), ['date']),\n", 586 | " (DateEncoder('day', 'day'), ['date'])\n", 587 | " # https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.KBinsDiscretizer.html\n", 588 | " # (KBinsDiscretizer(n_bins=10, encode='ordinal', dtype='int'), [''])\n", 589 | ")\n", 590 | "\n", 591 | "cbm = cbm.CBM(learning_rate_step_size=1/500, min_iterations_early_stopping=10)\n", 592 | "\n", 593 | "pipeline = make_pipeline(\n", 594 | " cats,\n", 595 | " cbm\n", 596 | " )\n", 597 | "\n", 598 | "# model.fit(x_train, train['unit_sales'])\n", 599 | "# pipeline.fit(train.head(10), train['unit_sales'].head(10))\n", 600 | "# pipeline.fit(train.head(100000), train['unit_sales'].head(100000))\n", 601 | "\n", 602 | "\n", 603 | "# pipeline.fit(train, train['unit_sales'])\n", 604 | "# \n", 605 | "\n", 606 | "# from sklearn.model_selection import cross_val_score\n", 607 | "\n", 608 | "start = time.time()\n", 609 | "scores = cross_val_score(pipeline, train, train['unit_sales'], \n", 610 | " scoring=make_scorer(mean_squared_log_error, squared=False), \n", 611 | " cv=TemporalSplit(n_splits=3, test_size=90),\n", 612 | " n_jobs=-1\n", 613 | " )\n", 614 | "\n", 615 | "print(f'cross-val { time.time() - start}')\n", 616 | "print(scores)" 617 | ] 618 | }, 619 | { 620 | "cell_type": "code", 621 | "execution_count": 31, 622 | "id": "38c5c7be", 623 | "metadata": {}, 624 | "outputs": [ 625 | { 626 | "data": { 627 | "text/plain": [ 628 | "0" 629 | ] 630 | }, 631 | "execution_count": 31, 632 | "metadata": {}, 633 | "output_type": "execute_result" 634 | } 635 | ], 636 | "source": [ 637 | "(train['unit_sales'] < 0).sum()\n", 638 | "\n", 639 | "cross-val 53.586721658706665\n", 640 | "[0.76333752 0.74206952 0.73735652]" 641 | ] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "execution_count": 9, 646 | "id": "4d762437", 647 | "metadata": {}, 648 | "outputs": [ 649 | { 650 | "data": { 651 | "text/plain": [ 652 | "1" 653 | ] 654 | }, 655 | "execution_count": 9, 656 | "metadata": {}, 657 | "output_type": "execute_result" 658 | } 659 | ], 660 | "source": [ 661 | "pipeline.fit(train, train['unit_sales'])\n", 662 | "1" 663 | ] 664 | }, 665 | { 666 | "cell_type": "code", 667 | "execution_count": 42, 668 | "id": "82ab90f1", 669 | "metadata": {}, 670 | "outputs": [ 671 | { 672 | "data": { 673 | "text/plain": [ 674 | "Pipeline(steps=[('columntransformer',\n", 675 | " ColumnTransformer(transformers=[('ordinalencoder',\n", 676 | " OrdinalEncoder(dtype='int',\n", 677 | " handle_unknown='use_encoded_value',\n", 678 | " unknown_value=-1),\n", 679 | " ['store_nbr', 'item_nbr',\n", 680 | " 'onpromotion', 'family',\n", 681 | " 'class', 'perishable']),\n", 682 | " ('dateencoder-1',\n", 683 | " DateEncoder(feature_name='month'),\n", 684 | " ['date']),\n", 685 | " ('dateencoder-2',\n", 686 | " DateEncoder(component='day',\n", 687 | " feature_name='day'),\n", 688 | " ['date'])]))])" 689 | ] 690 | }, 691 | "execution_count": 42, 692 | "metadata": {}, 693 | "output_type": "execute_result" 694 | } 695 | ], 696 | "source": [ 697 | "train = train_all.copy()\n", 698 | "train.loc[train['unit_sales'] < 0, 'unit_sales'] = 0\n", 699 | "\n", 700 | "pipeline_feat = make_pipeline(cats)\n", 701 | "pipeline_feat.fit(train)" 
702 | ] 703 | }, 704 | { 705 | "cell_type": "code", 706 | "execution_count": 40, 707 | "id": "ccdf3418", 708 | "metadata": {}, 709 | "outputs": [ 710 | { 711 | "data": { 712 | "text/html": [ 713 | "
\n", 714 | "\n", 727 | "\n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | "
item_nbrfamilyclassperishable
096995GROCERY I10930
199197GROCERY I10670
2103501CLEANING30080
3103520GROCERY I10280
4103665BREAD/BAKERY27121
...............
40952132318GROCERY I10020
40962132945GROCERY I10260
40972132957GROCERY I10680
40982134058BEVERAGES11240
40992134244LIQUOR,WINE,BEER13640
\n", 817 | "

4100 rows × 4 columns

\n", 818 | "
" 819 | ], 820 | "text/plain": [ 821 | " item_nbr family class perishable\n", 822 | "0 96995 GROCERY I 1093 0\n", 823 | "1 99197 GROCERY I 1067 0\n", 824 | "2 103501 CLEANING 3008 0\n", 825 | "3 103520 GROCERY I 1028 0\n", 826 | "4 103665 BREAD/BAKERY 2712 1\n", 827 | "... ... ... ... ...\n", 828 | "4095 2132318 GROCERY I 1002 0\n", 829 | "4096 2132945 GROCERY I 1026 0\n", 830 | "4097 2132957 GROCERY I 1068 0\n", 831 | "4098 2134058 BEVERAGES 1124 0\n", 832 | "4099 2134244 LIQUOR,WINE,BEER 1364 0\n", 833 | "\n", 834 | "[4100 rows x 4 columns]" 835 | ] 836 | }, 837 | "execution_count": 40, 838 | "metadata": {}, 839 | "output_type": "execute_result" 840 | } 841 | ], 842 | "source": [ 843 | "items" 844 | ] 845 | }, 846 | { 847 | "cell_type": "code", 848 | "execution_count": 123, 849 | "id": "f2856ba4", 850 | "metadata": {}, 851 | "outputs": [ 852 | { 853 | "name": "stdout", 854 | "output_type": "stream", 855 | "text": [ 856 | "(3370464, 8)\n", 857 | "id int64\n", 858 | "date datetime64[ns]\n", 859 | "store_nbr int16\n", 860 | "item_nbr int64\n", 861 | "onpromotion object\n", 862 | "family object\n", 863 | "class int64\n", 864 | "perishable int64\n", 865 | "dtype: object\n" 866 | ] 867 | }, 868 | { 869 | "data": { 870 | "text/html": [ 871 | "
\n", 872 | "\n", 885 | "\n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | "
iddatestore_nbritem_nbronpromotionfamilyclassperishable
01254970402017-08-16196995FGROCERY I10930
11254970412017-08-16199197FGROCERY I10670
21254970422017-08-161103501FCLEANING30080
31254970432017-08-161103520FGROCERY I10280
41254970442017-08-161103665FBREAD/BAKERY27121
\n", 957 | "
" 958 | ], 959 | "text/plain": [ 960 | " id date store_nbr item_nbr onpromotion family class \\\n", 961 | "0 125497040 2017-08-16 1 96995 F GROCERY I 1093 \n", 962 | "1 125497041 2017-08-16 1 99197 F GROCERY I 1067 \n", 963 | "2 125497042 2017-08-16 1 103501 F CLEANING 3008 \n", 964 | "3 125497043 2017-08-16 1 103520 F GROCERY I 1028 \n", 965 | "4 125497044 2017-08-16 1 103665 F BREAD/BAKERY 2712 \n", 966 | "\n", 967 | " perishable \n", 968 | "0 0 \n", 969 | "1 0 \n", 970 | "2 0 \n", 971 | "3 0 \n", 972 | "4 1 " 973 | ] 974 | }, 975 | "execution_count": 123, 976 | "metadata": {}, 977 | "output_type": "execute_result" 978 | } 979 | ], 980 | "source": [ 981 | "items = pd.read_csv('data/items.csv')\n", 982 | "\n", 983 | "test = (pd.read_csv('data/test.csv', \n", 984 | " parse_dates=['date'], \n", 985 | " # index_col='id', \n", 986 | " dtype={\n", 987 | " 'id': np.int64,\n", 988 | " 'store_nbr': np.short,\n", 989 | " 'item_nbr': np.int64,\n", 990 | " 'unit_sales': np.float64\n", 991 | " },\n", 992 | " converters={'onpromotion': lambda x: 'T' if x == 'True' else ('F' if x == 'False' else 'U')}\n", 993 | " )\n", 994 | " .merge(items, how='left'))\n", 995 | "\n", 996 | "# test['class'] = test['class'].astype('str')\n", 997 | "# test['item_nbr'] = test['item_nbr'].astype('str')\n", 998 | "\n", 999 | "print(test.shape)\n", 1000 | "print(test.dtypes)\n", 1001 | "test.head()" 1002 | ] 1003 | }, 1004 | { 1005 | "cell_type": "code", 1006 | "execution_count": 45, 1007 | "id": "33bfade4", 1008 | "metadata": {}, 1009 | "outputs": [ 1010 | { 1011 | "data": { 1012 | "text/plain": [ 1013 | "0" 1014 | ] 1015 | }, 1016 | "execution_count": 45, 1017 | "metadata": {}, 1018 | "output_type": "execute_result" 1019 | } 1020 | ], 1021 | "source": [ 1022 | "# pipeline_feat.transform(test)\n", 1023 | "test.isna().sum(axis=1).sum()" 1024 | ] 1025 | }, 1026 | { 1027 | "cell_type": "code", 1028 | "execution_count": 124, 1029 | "id": "5019d5e0", 1030 | "metadata": {}, 1031 | "outputs": [ 1032 | { 1033 | "data": { 1034 | "text/plain": [ 1035 | "array([[ 0],\n", 1036 | " [ 1],\n", 1037 | " [ 2],\n", 1038 | " ...,\n", 1039 | " [-1],\n", 1040 | " [-1],\n", 1041 | " [-1]])" 1042 | ] 1043 | }, 1044 | "execution_count": 124, 1045 | "metadata": {}, 1046 | "output_type": "execute_result" 1047 | } 1048 | ], 1049 | "source": [ 1050 | "pipeline_feat = make_pipeline(cats)\n", 1051 | "features = pipeline_feat.transform(test)\n", 1052 | "features" 1053 | ] 1054 | }, 1055 | { 1056 | "cell_type": "code", 1057 | "execution_count": 135, 1058 | "id": "d473f84e", 1059 | "metadata": {}, 1060 | "outputs": [ 1061 | { 1062 | "data": { 1063 | "text/plain": [ 1064 | "Pipeline(steps=[('columntransformer',\n", 1065 | " ColumnTransformer(transformers=[('ordinalencoder',\n", 1066 | " OrdinalEncoder(dtype='int',\n", 1067 | " handle_unknown='use_encoded_value',\n", 1068 | " unknown_value=-1),\n", 1069 | " ['store_nbr', 'item_nbr',\n", 1070 | " 'onpromotion', 'family',\n", 1071 | " 'class', 'perishable']),\n", 1072 | " ('dateencoder-1',\n", 1073 | " DateEncoder(feature_name='month'),\n", 1074 | " ['date']),\n", 1075 | " ('dateencoder-2',\n", 1076 | " DateEncoder(component='day',\n", 1077 | " feature_name='day'),\n", 1078 | " ['date'])])),\n", 1079 | " ('cbm',\n", 1080 | " CBM(learning_rate_step_size=0.002,\n", 1081 | " min_iterations_early_stopping=10))])" 1082 | ] 1083 | }, 1084 | "execution_count": 135, 1085 | "metadata": {}, 1086 | "output_type": "execute_result" 1087 | } 1088 | ], 1089 | "source": [ 1090 | "# training\n", 1091 | "\n", 1092 | 
"train_all_all = train_all.copy()\n", 1093 | "train_all_all.loc[train_all_all['unit_sales'] < 0, 'unit_sales'] = 0\n", 1094 | "\n", 1095 | "pipeline_all = make_pipeline(cats, cbm)\n", 1096 | "pipeline_all.fit(train_all_all, train_all_all['unit_sales'])" 1097 | ] 1098 | }, 1099 | { 1100 | "cell_type": "code", 1101 | "execution_count": 136, 1102 | "id": "2fb0eda4", 1103 | "metadata": {}, 1104 | "outputs": [ 1105 | { 1106 | "data": { 1107 | "text/plain": [ 1108 | "count 3.370464e+06\n", 1109 | "mean 7.313644e+00\n", 1110 | "std 1.613676e+01\n", 1111 | "min 3.352459e-01\n", 1112 | "25% 2.642382e+00\n", 1113 | "50% 4.420249e+00\n", 1114 | "75% 7.781694e+00\n", 1115 | "max 2.224308e+03\n", 1116 | "dtype: float64" 1117 | ] 1118 | }, 1119 | "execution_count": 136, 1120 | "metadata": {}, 1121 | "output_type": "execute_result" 1122 | } 1123 | ], 1124 | "source": [ 1125 | "y_pred_test = pipeline_all.predict(test)\n", 1126 | "\n", 1127 | "pd.Series(y_pred_test.flatten()).describe()" 1128 | ] 1129 | }, 1130 | { 1131 | "cell_type": "code", 1132 | "execution_count": 137, 1133 | "id": "234be585", 1134 | "metadata": {}, 1135 | "outputs": [ 1136 | { 1137 | "data": { 1138 | "text/plain": [ 1139 | "7.535582 21\n", 1140 | "7.772725 21\n", 1141 | "5.474212 21\n", 1142 | "5.866997 21\n", 1143 | "7.266502 21\n", 1144 | " ..\n", 1145 | "2.044499 1\n", 1146 | "9.149783 1\n", 1147 | "6.819987 1\n", 1148 | "2.328584 1\n", 1149 | "8.180235 1\n", 1150 | "Length: 1539115, dtype: int64" 1151 | ] 1152 | }, 1153 | "execution_count": 137, 1154 | "metadata": {}, 1155 | "output_type": "execute_result" 1156 | } 1157 | ], 1158 | "source": [ 1159 | "pd.Series(y_pred_test.flatten()).value_counts()" 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "code", 1164 | "execution_count": 138, 1165 | "id": "6c151939", 1166 | "metadata": {}, 1167 | "outputs": [], 1168 | "source": [ 1169 | "submission = test.copy()\n", 1170 | "submission.loc[:, 'unit_sales'] = y_pred_test.flatten()\n", 1171 | "submission[['id', 'unit_sales']].to_csv('submission.csv', index=False)" 1172 | ] 1173 | }, 1174 | { 1175 | "cell_type": "code", 1176 | "execution_count": 128, 1177 | "id": "abba959f", 1178 | "metadata": {}, 1179 | "outputs": [ 1180 | { 1181 | "name": "stdout", 1182 | "output_type": "stream", 1183 | "text": [ 1184 | "id,unit_sales\r\n", 1185 | "125497040,2.5024013377093564\r\n", 1186 | "125497041,4.222599293880319\r\n", 1187 | "125497042,4.873269254461008\r\n", 1188 | "125497043,4.187582054240339\r\n", 1189 | "125497044,4.932258731082357\r\n", 1190 | "125497045,13.245315426347876\r\n", 1191 | "125497046,21.230761691507237\r\n", 1192 | "125497047,14.97772186798399\r\n", 1193 | "125497048,6.772408878599083\r\n" 1194 | ] 1195 | } 1196 | ], 1197 | "source": [ 1198 | "!head submission.csv" 1199 | ] 1200 | }, 1201 | { 1202 | "cell_type": "code", 1203 | "execution_count": 139, 1204 | "id": "90c762f1", 1205 | "metadata": {}, 1206 | "outputs": [ 1207 | { 1208 | "name": "stdout", 1209 | "output_type": "stream", 1210 | "text": [ 1211 | "100%|██████████████████████████████████████| 91.2M/91.2M [00:17<00:00, 5.61MB/s]\n", 1212 | "Successfully submitted to Corporación Favorita Grocery Sales Forecasting" 1213 | ] 1214 | } 1215 | ], 1216 | "source": [ 1217 | "!kaggle competitions submit -c favorita-grocery-sales-forecasting -f submission.csv -m sklearn1" 1218 | ] 1219 | }, 1220 | { 1221 | "cell_type": "code", 1222 | "execution_count": null, 1223 | "id": "523c44df", 1224 | "metadata": {}, 1225 | "outputs": [], 1226 | "source": [] 1227 | }, 1228 | { 1229 | "cell_type": 
"code", 1230 | "execution_count": null, 1231 | "id": "3a0ddf3d", 1232 | "metadata": {}, 1233 | "outputs": [], 1234 | "source": [] 1235 | }, 1236 | { 1237 | "cell_type": "code", 1238 | "execution_count": 32, 1239 | "id": "5b8d18ac", 1240 | "metadata": {}, 1241 | "outputs": [ 1242 | { 1243 | "data": { 1244 | "text/plain": [ 1245 | "array([[9.46390226],\n", 1246 | " [9.46412469],\n", 1247 | " [9.48171461],\n", 1248 | " [9.48268984],\n", 1249 | " [9.46475154]])" 1250 | ] 1251 | }, 1252 | "execution_count": 32, 1253 | "metadata": {}, 1254 | "output_type": "execute_result" 1255 | } 1256 | ], 1257 | "source": [ 1258 | "pipeline.predict(train.head())" 1259 | ] 1260 | }, 1261 | { 1262 | "cell_type": "code", 1263 | "execution_count": 18, 1264 | "id": "c6c389af", 1265 | "metadata": {}, 1266 | "outputs": [ 1267 | { 1268 | "data": { 1269 | "text/plain": [ 1270 | "(125497040, 8)" 1271 | ] 1272 | }, 1273 | "execution_count": 18, 1274 | "metadata": {}, 1275 | "output_type": "execute_result" 1276 | } 1277 | ], 1278 | "source": [ 1279 | "train.shape" 1280 | ] 1281 | }, 1282 | { 1283 | "cell_type": "code", 1284 | "execution_count": 54, 1285 | "id": "61fca780", 1286 | "metadata": {}, 1287 | "outputs": [ 1288 | { 1289 | "data": { 1290 | "text/plain": [ 1291 | "172197" 1292 | ] 1293 | }, 1294 | "execution_count": 54, 1295 | "metadata": {}, 1296 | "output_type": "execute_result" 1297 | } 1298 | ], 1299 | "source": [ 1300 | "from collections import defaultdict\n", 1301 | "\n", 1302 | "item_store_map = defaultdict(int)\n", 1303 | "\n", 1304 | "# have the first item as back-off\n", 1305 | "item_store_map.update({(row['item_nbr'], row['store_nbr']): idx + 1 for idx, row in train[['item_nbr','store_nbr']].value_counts(ascending=True).reset_index(name='count').query('count > 5').iterrows()})\n", 1306 | "\n", 1307 | "len(item_store_map)" 1308 | ] 1309 | }, 1310 | { 1311 | "cell_type": "code", 1312 | "execution_count": 55, 1313 | "id": "3fe9b08f", 1314 | "metadata": {}, 1315 | "outputs": [], 1316 | "source": [ 1317 | "train['item_store'] = train[['item_nbr','store_nbr']].apply(lambda x: item_store_map[tuple(x)], axis=1)" 1318 | ] 1319 | }, 1320 | { 1321 | "cell_type": "code", 1322 | "execution_count": 3, 1323 | "id": "c7a3237f", 1324 | "metadata": {}, 1325 | "outputs": [ 1326 | { 1327 | "data": { 1328 | "text/html": [ 1329 | "
\n", 1330 | "\n", 1343 | "\n", 1344 | " \n", 1345 | " \n", 1346 | " \n", 1347 | " \n", 1348 | " \n", 1349 | " \n", 1350 | " \n", 1351 | " \n", 1352 | " \n", 1353 | " \n", 1354 | " \n", 1355 | " \n", 1356 | " \n", 1357 | " \n", 1358 | " \n", 1359 | " \n", 1360 | " \n", 1361 | " \n", 1362 | " \n", 1363 | " \n", 1364 | " \n", 1365 | " \n", 1366 | " \n", 1367 | " \n", 1368 | " \n", 1369 | " \n", 1370 | " \n", 1371 | " \n", 1372 | " \n", 1373 | " \n", 1374 | " \n", 1375 | " \n", 1376 | " \n", 1377 | " \n", 1378 | " \n", 1379 | " \n", 1380 | " \n", 1381 | " \n", 1382 | " \n", 1383 | " \n", 1384 | " \n", 1385 | " \n", 1386 | " \n", 1387 | " \n", 1388 | " \n", 1389 | " \n", 1390 | " \n", 1391 | " \n", 1392 | " \n", 1393 | " \n", 1394 | " \n", 1395 | " \n", 1396 | " \n", 1397 | " \n", 1398 | " \n", 1399 | " \n", 1400 | " \n", 1401 | " \n", 1402 | " \n", 1403 | " \n", 1404 | " \n", 1405 | " \n", 1406 | " \n", 1407 | " \n", 1408 | " \n", 1409 | " \n", 1410 | " \n", 1411 | " \n", 1412 | " \n", 1413 | " \n", 1414 | "
datestore_nbritem_nbrunit_salesonpromotionfamilyclassperishable
02013-01-012547.0UBREAD/BAKERY1841
12013-01-02142.0UBREAD/BAKERY1841
22013-01-02245.0UBREAD/BAKERY1841
32013-01-02346.0UBREAD/BAKERY1841
42013-01-02442.0UBREAD/BAKERY1841
\n", 1415 | "
" 1416 | ], 1417 | "text/plain": [ 1418 | " date store_nbr item_nbr unit_sales onpromotion family \\\n", 1419 | "0 2013-01-01 25 4 7.0 U BREAD/BAKERY \n", 1420 | "1 2013-01-02 1 4 2.0 U BREAD/BAKERY \n", 1421 | "2 2013-01-02 2 4 5.0 U BREAD/BAKERY \n", 1422 | "3 2013-01-02 3 4 6.0 U BREAD/BAKERY \n", 1423 | "4 2013-01-02 4 4 2.0 U BREAD/BAKERY \n", 1424 | "\n", 1425 | " class perishable \n", 1426 | "0 184 1 \n", 1427 | "1 184 1 \n", 1428 | "2 184 1 \n", 1429 | "3 184 1 \n", 1430 | "4 184 1 " 1431 | ] 1432 | }, 1433 | "execution_count": 3, 1434 | "metadata": {}, 1435 | "output_type": "execute_result" 1436 | } 1437 | ], 1438 | "source": [ 1439 | "train = pd.read_parquet('data/train_items.parquet')\n", 1440 | "\n", 1441 | "class_map = {x: i for i, x in enumerate(np.sort(train['class'].unique()))}\n", 1442 | "item_nbr_map = {x: i for i, x in enumerate(np.sort(train['item_nbr'].unique()))}\n", 1443 | "\n", 1444 | "train['class'] = train['class'].map(class_map)\n", 1445 | "train['item_nbr'] = train['item_nbr'].map(item_nbr_map)\n", 1446 | "\n", 1447 | "train.head()" 1448 | ] 1449 | }, 1450 | { 1451 | "cell_type": "code", 1452 | "execution_count": 4, 1453 | "id": "8b7b478c", 1454 | "metadata": {}, 1455 | "outputs": [], 1456 | "source": [ 1457 | "\n", 1458 | "# # train['unit_sales'] = train['unit_sales'].astype(np.int32)\n", 1459 | "# train[['unit_sales']].to_parquet('data/train_unit_sales.parquet')\n", 1460 | "\n", 1461 | "# x_train = pd.read_parquet('data/train_items_featurized.parquet')\n", 1462 | "# x_train_unit_sales = pd.read_parquet('data/train_unit_sales.parquet')" 1463 | ] 1464 | }, 1465 | { 1466 | "cell_type": "code", 1467 | "execution_count": null, 1468 | "id": "7c852782", 1469 | "metadata": {}, 1470 | "outputs": [], 1471 | "source": [] 1472 | }, 1473 | { 1474 | "cell_type": "code", 1475 | "execution_count": null, 1476 | "id": "785a97c4", 1477 | "metadata": {}, 1478 | "outputs": [], 1479 | "source": [] 1480 | }, 1481 | { 1482 | "cell_type": "code", 1483 | "execution_count": 5, 1484 | "id": "5e7fc02a", 1485 | "metadata": {}, 1486 | "outputs": [ 1487 | { 1488 | "data": { 1489 | "text/plain": [ 1490 | "8191095" 1491 | ] 1492 | }, 1493 | "execution_count": 5, 1494 | "metadata": {}, 1495 | "output_type": "execute_result" 1496 | } 1497 | ], 1498 | "source": [ 1499 | "(((train['unit_sales'] * 10) % 10 ) > 0).sum()" 1500 | ] 1501 | }, 1502 | { 1503 | "cell_type": "code", 1504 | "execution_count": null, 1505 | "id": "a8534d07", 1506 | "metadata": {}, 1507 | "outputs": [], 1508 | "source": [] 1509 | }, 1510 | { 1511 | "cell_type": "code", 1512 | "execution_count": 6, 1513 | "id": "e8f9d08e", 1514 | "metadata": {}, 1515 | "outputs": [ 1516 | { 1517 | "data": { 1518 | "text/plain": [ 1519 | "0.06526922866069192" 1520 | ] 1521 | }, 1522 | "execution_count": 6, 1523 | "metadata": {}, 1524 | "output_type": "execute_result" 1525 | } 1526 | ], 1527 | "source": [ 1528 | "8191095 / len(train)" 1529 | ] 1530 | }, 1531 | { 1532 | "cell_type": "code", 1533 | "execution_count": 7, 1534 | "id": "adfe5485", 1535 | "metadata": {}, 1536 | "outputs": [ 1537 | { 1538 | "name": "stdout", 1539 | "output_type": "stream", 1540 | "text": [ 1541 | "featurize 1.988840103149414\n", 1542 | "train 59.32568883895874\n", 1543 | "iterations 3\n", 1544 | "23.605482539271893\n" 1545 | ] 1546 | } 1547 | ], 1548 | "source": [ 1549 | "import cbm\n", 1550 | "import time\n", 1551 | "from sklearn.model_selection import TimeSeriesSplit\n", 1552 | "from sklearn.metrics import mean_squared_error\n", 1553 | "\n", 1554 | "def 
featurize(df):\n", 1555 | " return pd.DataFrame({\n", 1556 | " 'store_nbr' : df['store_nbr'],\n", 1557 | " 'item_nbr' : df['item_nbr'],\n", 1558 | " 'onpromotion' : df['onpromotion'],\n", 1559 | " 'family' : df['family'],\n", 1560 | " 'class' : df['class'],\n", 1561 | " 'perishable' : df['perishable'],\n", 1562 | " 'date' : df['date'],\n", 1563 | " })\n", 1564 | "\n", 1565 | "start = time.time()\n", 1566 | "\n", 1567 | "x_train = featurize(train)\n", 1568 | "\n", 1569 | "print(f'featurize { time.time() - start}')\n", 1570 | "\n", 1571 | "# enable_bin_count=True) # \n", 1572 | "model = cbm.CBM(learning_rate_step_size=1/64000, min_iterations_early_stopping=2)\n", 1573 | "model.fit(x_train, train['unit_sales'])\n", 1574 | "\n", 1575 | "print(f'train {time.time() - start}')\n", 1576 | "print(f'iterations {model.iterations}')\n", 1577 | "\n", 1578 | "y_pred_train = model.predict(x_train).flatten()\n", 1579 | "\n", 1580 | "rmsle = mean_squared_error(train['unit_sales'], y_pred_train, squared=False)\n", 1581 | "print(rmsle)\n", 1582 | "\n", 1583 | "# 29sec - for store/item\n", 1584 | "# 76sec - for store/onprom/family/class/perishable - 612k\n", 1585 | "# 589k w/ 1/2000 learning rate\n", 1586 | "# 612k w/ 1/200\n", 1587 | "# 132 w/ 1/4000\n", 1588 | "# 37 w/ 1/8000\n", 1589 | "# 25 w/ 1/16000 \n", 1590 | "# 23 w 1/32000 it>=15\n", 1591 | "# 23 w 1/32000 it>=5\n", 1592 | "# 23 w 1/64000 it>=2\n", 1593 | "# 23.60 w 1/64000 it>=2 + item_nbr\n", 1594 | "# 23.60 w 1/64000 it>=2 + item_nbr + date" 1595 | ] 1596 | }, 1597 | { 1598 | "cell_type": "code", 1599 | "execution_count": 9, 1600 | "id": "b672d0c7", 1601 | "metadata": {}, 1602 | "outputs": [ 1603 | { 1604 | "data": { 1605 | "text/plain": [ 1606 | "333" 1607 | ] 1608 | }, 1609 | "execution_count": 9, 1610 | "metadata": {}, 1611 | "output_type": "execute_result" 1612 | } 1613 | ], 1614 | "source": [ 1615 | "x_train['class'].max()" 1616 | ] 1617 | }, 1618 | { 1619 | "cell_type": "code", 1620 | "execution_count": 10, 1621 | "id": "33d38059", 1622 | "metadata": {}, 1623 | "outputs": [ 1624 | { 1625 | "data": { 1626 | "text/plain": [ 1627 | "4035" 1628 | ] 1629 | }, 1630 | "execution_count": 10, 1631 | "metadata": {}, 1632 | "output_type": "execute_result" 1633 | } 1634 | ], 1635 | "source": [ 1636 | "x_train['item_nbr'].max()" 1637 | ] 1638 | }, 1639 | { 1640 | "cell_type": "code", 1641 | "execution_count": 8, 1642 | "id": "09310633", 1643 | "metadata": {}, 1644 | "outputs": [], 1645 | "source": [ 1646 | "# model.plot_importance(figsize=(20, 20))" 1647 | ] 1648 | }, 1649 | { 1650 | "cell_type": "code", 1651 | "execution_count": 45, 1652 | "id": "052d4c10", 1653 | "metadata": {}, 1654 | "outputs": [ 1655 | { 1656 | "name": "stderr", 1657 | "output_type": "stream", 1658 | "text": [ 1659 | "/home/marcozo/miniconda3/envs/cbm/lib/python3.7/site-packages/numpy/lib/arraysetops.py:583: FutureWarning: elementwise comparison failed; returning scalar instead, but in the future will perform elementwise comparison\n", 1660 | " mask |= (ar1 == a)\n" 1661 | ] 1662 | }, 1663 | { 1664 | "data": { 1665 | "text/html": [ 1666 | "
\n", 1667 | "\n", 1680 | "\n", 1681 | " \n", 1682 | " \n", 1683 | " \n", 1684 | " \n", 1685 | " \n", 1686 | " \n", 1687 | " \n", 1688 | " \n", 1689 | " \n", 1690 | " \n", 1691 | " \n", 1692 | " \n", 1693 | " \n", 1694 | " \n", 1695 | " \n", 1696 | " \n", 1697 | " \n", 1698 | " \n", 1699 | " \n", 1700 | " \n", 1701 | " \n", 1702 | " \n", 1703 | " \n", 1704 | " \n", 1705 | " \n", 1706 | " \n", 1707 | " \n", 1708 | " \n", 1709 | " \n", 1710 | " \n", 1711 | " \n", 1712 | " \n", 1713 | " \n", 1714 | " \n", 1715 | " \n", 1716 | " \n", 1717 | " \n", 1718 | " \n", 1719 | " \n", 1720 | " \n", 1721 | " \n", 1722 | " \n", 1723 | " \n", 1724 | " \n", 1725 | " \n", 1726 | " \n", 1727 | " \n", 1728 | " \n", 1729 | " \n", 1730 | " \n", 1731 | " \n", 1732 | " \n", 1733 | " \n", 1734 | " \n", 1735 | " \n", 1736 | " \n", 1737 | " \n", 1738 | " \n", 1739 | " \n", 1740 | " \n", 1741 | " \n", 1742 | " \n", 1743 | " \n", 1744 | " \n", 1745 | "
datestore_nbritem_nbronpromotionfamilyclassperishable
02017-08-1610.0FGROCERY I63.00
12017-08-1620.0FGROCERY I63.00
22017-08-1630.0FGROCERY I63.00
32017-08-1640.0FGROCERY I63.00
42017-08-1650.0FGROCERY I63.00
\n", 1746 | "
" 1747 | ], 1748 | "text/plain": [ 1749 | " date store_nbr item_nbr onpromotion family class perishable\n", 1750 | "0 2017-08-16 1 0.0 F GROCERY I 63.0 0\n", 1751 | "1 2017-08-16 2 0.0 F GROCERY I 63.0 0\n", 1752 | "2 2017-08-16 3 0.0 F GROCERY I 63.0 0\n", 1753 | "3 2017-08-16 4 0.0 F GROCERY I 63.0 0\n", 1754 | "4 2017-08-16 5 0.0 F GROCERY I 63.0 0" 1755 | ] 1756 | }, 1757 | "execution_count": 45, 1758 | "metadata": {}, 1759 | "output_type": "execute_result" 1760 | } 1761 | ], 1762 | "source": [ 1763 | "items = pd.read_csv('data/items.csv')\n", 1764 | "\n", 1765 | "test = pd.read_csv('data/test.csv',\n", 1766 | " parse_dates=['date'], \n", 1767 | " index_col='id', \n", 1768 | " dtype={\n", 1769 | " # 'date': np.datetime64, \n", 1770 | " 'store_nbr': np.short,\n", 1771 | " 'item_nbr': np.int64,\n", 1772 | " 'unit_sales': np.float64\n", 1773 | " },\n", 1774 | ").merge(items)\n", 1775 | "\n", 1776 | "test['onpromotion'] = test['onpromotion'].map({True: 'T', False: 'F'})\n", 1777 | "test['class'] = test['class'].map(class_map)\n", 1778 | "test['item_nbr'] = test['item_nbr'].map(item_nbr_map)\n", 1779 | "\n", 1780 | "test.head()" 1781 | ] 1782 | }, 1783 | { 1784 | "cell_type": "code", 1785 | "execution_count": 53, 1786 | "id": "c4a482dc", 1787 | "metadata": {}, 1788 | "outputs": [ 1789 | { 1790 | "data": { 1791 | "text/html": [ 1792 | "
\n", 1793 | "\n", 1806 | "\n", 1807 | " \n", 1808 | " \n", 1809 | " \n", 1810 | " \n", 1811 | " \n", 1812 | " \n", 1813 | " \n", 1814 | " \n", 1815 | " \n", 1816 | " \n", 1817 | " \n", 1818 | " \n", 1819 | " \n", 1820 | " \n", 1821 | "
datestore_nbritem_nbronpromotionfamilyclassperishable
\n", 1822 | "
" 1823 | ], 1824 | "text/plain": [ 1825 | "Empty DataFrame\n", 1826 | "Columns: [date, store_nbr, item_nbr, onpromotion, family, class, perishable]\n", 1827 | "Index: []" 1828 | ] 1829 | }, 1830 | "execution_count": 53, 1831 | "metadata": {}, 1832 | "output_type": "execute_result" 1833 | } 1834 | ], 1835 | "source": [ 1836 | "# TODO: handle NA by multiplying by 1\n", 1837 | "test['item_nbr'] = test['item_nbr'].fillna(0).astype(int)\n", 1838 | "test['class'] = test['class'] .fillna(0).astype(int)\n", 1839 | "test[test.isna().any(axis=1)]" 1840 | ] 1841 | }, 1842 | { 1843 | "cell_type": "code", 1844 | "execution_count": 12, 1845 | "id": "78377989", 1846 | "metadata": {}, 1847 | "outputs": [], 1848 | "source": [ 1849 | "# class_cats = train_raw['class'].astype('category').cat.categories.tolist()\n", 1850 | "\n", 1851 | "# test['class'] = test['class'].astype(pd.CategoricalDtype(categories=class_cats, ordered=True)).cat.codes\n", 1852 | "# test.head()" 1853 | ] 1854 | }, 1855 | { 1856 | "cell_type": "code", 1857 | "execution_count": 54, 1858 | "id": "288bb676", 1859 | "metadata": {}, 1860 | "outputs": [ 1861 | { 1862 | "data": { 1863 | "text/html": [ 1864 | "
\n", 1865 | "\n", 1878 | "\n", 1879 | " \n", 1880 | " \n", 1881 | " \n", 1882 | " \n", 1883 | " \n", 1884 | " \n", 1885 | " \n", 1886 | " \n", 1887 | " \n", 1888 | " \n", 1889 | " \n", 1890 | " \n", 1891 | " \n", 1892 | " \n", 1893 | " \n", 1894 | " \n", 1895 | " \n", 1896 | " \n", 1897 | " \n", 1898 | " \n", 1899 | " \n", 1900 | " \n", 1901 | " \n", 1902 | " \n", 1903 | " \n", 1904 | " \n", 1905 | " \n", 1906 | " \n", 1907 | " \n", 1908 | " \n", 1909 | " \n", 1910 | " \n", 1911 | " \n", 1912 | " \n", 1913 | " \n", 1914 | " \n", 1915 | " \n", 1916 | " \n", 1917 | " \n", 1918 | " \n", 1919 | " \n", 1920 | " \n", 1921 | " \n", 1922 | " \n", 1923 | " \n", 1924 | " \n", 1925 | " \n", 1926 | " \n", 1927 | " \n", 1928 | " \n", 1929 | " \n", 1930 | " \n", 1931 | " \n", 1932 | " \n", 1933 | " \n", 1934 | " \n", 1935 | " \n", 1936 | " \n", 1937 | " \n", 1938 | " \n", 1939 | " \n", 1940 | " \n", 1941 | " \n", 1942 | " \n", 1943 | "
store_nbritem_nbronpromotionfamilyclassperishabledate
010FGROCERY I6302017-08-16
120FGROCERY I6302017-08-16
230FGROCERY I6302017-08-16
340FGROCERY I6302017-08-16
450FGROCERY I6302017-08-16
\n", 1944 | "
" 1945 | ], 1946 | "text/plain": [ 1947 | " store_nbr item_nbr onpromotion family class perishable date\n", 1948 | "0 1 0 F GROCERY I 63 0 2017-08-16\n", 1949 | "1 2 0 F GROCERY I 63 0 2017-08-16\n", 1950 | "2 3 0 F GROCERY I 63 0 2017-08-16\n", 1951 | "3 4 0 F GROCERY I 63 0 2017-08-16\n", 1952 | "4 5 0 F GROCERY I 63 0 2017-08-16" 1953 | ] 1954 | }, 1955 | "execution_count": 54, 1956 | "metadata": {}, 1957 | "output_type": "execute_result" 1958 | } 1959 | ], 1960 | "source": [ 1961 | "x_test = featurize(test)\n", 1962 | "x_test.head()" 1963 | ] 1964 | }, 1965 | { 1966 | "cell_type": "code", 1967 | "execution_count": 55, 1968 | "id": "e2f2f2e7", 1969 | "metadata": {}, 1970 | "outputs": [ 1971 | { 1972 | "data": { 1973 | "text/plain": [ 1974 | "array([[8.67074847, 1.00170424, 1.00180464, ..., 1.00164436, 1.00164009,\n", 1975 | " 1.00161252],\n", 1976 | " [8.67054548, 1.00168079, 1.00180464, ..., 1.00164436, 1.00164009,\n", 1977 | " 1.00161252],\n", 1978 | " [8.67063081, 1.00169065, 1.00180464, ..., 1.00164436, 1.00164009,\n", 1979 | " 1.00161252],\n", 1980 | " ...,\n", 1981 | " [8.67042574, 1.00148041, 1.00180464, ..., 1.00164436, 1.0016697 ,\n", 1982 | " 1.00161252],\n", 1983 | " [8.67126881, 1.00157779, 1.00180464, ..., 1.00164436, 1.0016697 ,\n", 1984 | " 1.00161252],\n", 1985 | " [8.67073749, 1.00151642, 1.00180464, ..., 1.00164436, 1.0016697 ,\n", 1986 | " 1.00161252]])" 1987 | ] 1988 | }, 1989 | "execution_count": 55, 1990 | "metadata": {}, 1991 | "output_type": "execute_result" 1992 | } 1993 | ], 1994 | "source": [ 1995 | "y_pred_test = model.predict(x_test, explain=True) #.flatten()\n", 1996 | "y_pred_test" 1997 | ] 1998 | }, 1999 | { 2000 | "cell_type": "code", 2001 | "execution_count": 57, 2002 | "id": "6483a859", 2003 | "metadata": {}, 2004 | "outputs": [ 2005 | { 2006 | "data": { 2007 | "text/plain": [ 2008 | "8.666806 21\n", 2009 | "8.666954 21\n", 2010 | "8.666864 21\n", 2011 | "8.667392 21\n", 2012 | "8.666751 21\n", 2013 | " ..\n", 2014 | "8.656977 1\n", 2015 | "8.659257 1\n", 2016 | "8.656693 1\n", 2017 | "8.656806 1\n", 2018 | "8.641494 1\n", 2019 | "Length: 1538737, dtype: int64" 2020 | ] 2021 | }, 2022 | "execution_count": 57, 2023 | "metadata": {}, 2024 | "output_type": "execute_result" 2025 | } 2026 | ], 2027 | "source": [ 2028 | "pd.Series(y_pred_test[:,0].flatten()).value_counts()" 2029 | ] 2030 | }, 2031 | { 2032 | "cell_type": "code", 2033 | "execution_count": 58, 2034 | "id": "2798c44f", 2035 | "metadata": {}, 2036 | "outputs": [], 2037 | "source": [ 2038 | "y_pred_test = model.predict(x_test)" 2039 | ] 2040 | }, 2041 | { 2042 | "cell_type": "code", 2043 | "execution_count": 64, 2044 | "id": "d3d53829", 2045 | "metadata": {}, 2046 | "outputs": [], 2047 | "source": [ 2048 | "test[['unit_sales']].index.rename('id', inplace=True)\n", 2049 | "test['unit_sales'] = y_pred_test\n", 2050 | "test[['unit_sales']].to_csv('submission.csv', index=True)" 2051 | ] 2052 | }, 2053 | { 2054 | "cell_type": "code", 2055 | "execution_count": 65, 2056 | "id": "5ae1ac1b", 2057 | "metadata": {}, 2058 | "outputs": [], 2059 | "source": [ 2060 | "test['unit_sales'] = y_pred_test\n", 2061 | "test[['unit_sales']].to_csv('submission.csv', index=True)" 2062 | ] 2063 | }, 2064 | { 2065 | "cell_type": "code", 2066 | "execution_count": 66, 2067 | "id": "1f3c2e53", 2068 | "metadata": {}, 2069 | "outputs": [ 2070 | { 2071 | "name": "stdout", 2072 | "output_type": "stream", 2073 | "text": [ 2074 | "id,unit_sales\r\n", 2075 | "0,8.670748466911578\r\n", 2076 | "1,8.670545482030695\r\n", 2077 | 
"2,8.670630814956288\r\n", 2078 | "3,8.67052676822758\r\n", 2079 | "4,8.66955171527891\r\n", 2080 | "5,8.670202952780572\r\n", 2081 | "6,8.670841233962294\r\n", 2082 | "7,8.670437318412505\r\n", 2083 | "8,8.669850978368242\r\n" 2084 | ] 2085 | } 2086 | ], 2087 | "source": [ 2088 | "!head submission.csv" 2089 | ] 2090 | }, 2091 | { 2092 | "cell_type": "code", 2093 | "execution_count": 20, 2094 | "id": "db2247b6", 2095 | "metadata": {}, 2096 | "outputs": [ 2097 | { 2098 | "data": { 2099 | "text/plain": [ 2100 | "8.665899 28\n", 2101 | "8.649263 26\n", 2102 | "8.649260 26\n", 2103 | "8.649381 24\n", 2104 | "8.649342 22\n", 2105 | " ..\n", 2106 | "8.652211 1\n", 2107 | "8.659009 1\n", 2108 | "8.664719 1\n", 2109 | "8.652154 1\n", 2110 | "8.634132 1\n", 2111 | "Length: 20537444, dtype: int64" 2112 | ] 2113 | }, 2114 | "execution_count": 20, 2115 | "metadata": {}, 2116 | "output_type": "execute_result" 2117 | } 2118 | ], 2119 | "source": [ 2120 | "pd.Series(y_pred_train).value_counts()" 2121 | ] 2122 | }, 2123 | { 2124 | "cell_type": "code", 2125 | "execution_count": 67, 2126 | "id": "6c9e45a5", 2127 | "metadata": {}, 2128 | "outputs": [ 2129 | { 2130 | "name": "stdout", 2131 | "output_type": "stream", 2132 | "text": [ 2133 | "100%|██████████████████████████████████████| 81.9M/81.9M [00:21<00:00, 3.94MB/s]\n", 2134 | "Successfully submitted to Corporación Favorita Grocery Sales Forecasting" 2135 | ] 2136 | } 2137 | ], 2138 | "source": [ 2139 | "!kaggle competitions submit -c favorita-grocery-sales-forecasting -f submission.csv -m v1" 2140 | ] 2141 | } 2142 | ], 2143 | "metadata": { 2144 | "kernelspec": { 2145 | "display_name": "Python 3 (ipykernel)", 2146 | "language": "python", 2147 | "name": "python3" 2148 | }, 2149 | "language_info": { 2150 | "codemirror_mode": { 2151 | "name": "ipython", 2152 | "version": 3 2153 | }, 2154 | "file_extension": ".py", 2155 | "mimetype": "text/x-python", 2156 | "name": "python", 2157 | "nbconvert_exporter": "python", 2158 | "pygments_lexer": "ipython3", 2159 | "version": "3.7.11" 2160 | } 2161 | }, 2162 | "nbformat": 4, 2163 | "nbformat_minor": 5 2164 | } 2165 | --------------------------------------------------------------------------------