├── nbs ├── .gitignore ├── figs │ ├── forecast__ercot.png │ ├── forecast__predict.png │ ├── cross_validation__series.png │ ├── forecast__cross_validation.png │ ├── cross_validation__predictions.png │ ├── forecast__predict_intervals.png │ ├── quick_start_distributed__sample.png │ ├── forecast__cross_validation_intervals.png │ ├── forecast__predict_intervals_window_size_1.png │ ├── electricity_peak_forecasting__predicted_peak.png │ └── quick_start_distributed__sample_prediction.png ├── nbdev.yml ├── styles.css ├── sidebar.yml ├── distributed.models.ray.lgb.ipynb ├── _quarto.yml ├── distributed.models.dask.xgb.ipynb ├── distributed.models.ray.xgb.ipynb ├── distributed.models.dask.lgb.ipynb ├── distributed.models.spark.lgb.ipynb ├── distributed.models.spark.xgb.ipynb ├── docs │ ├── install.ipynb │ ├── cross_validation.ipynb │ └── quick_start_distributed.ipynb └── index.ipynb ├── .github ├── CODEOWNERS ├── workflows │ ├── deploy.yaml │ ├── lint.yaml │ ├── release.yml │ └── ci.yaml ├── pull_request_template.md ├── release-drafter.yml └── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── mlforecast ├── distributed │ ├── models │ │ ├── __init__.py │ │ ├── dask │ │ │ ├── __init__.py │ │ │ ├── xgb.py │ │ │ └── lgb.py │ │ ├── ray │ │ │ ├── __init__.py │ │ │ ├── lgb.py │ │ │ └── xgb.py │ │ └── spark │ │ │ ├── __init__.py │ │ │ ├── lgb.py │ │ │ └── xgb.py │ ├── __init__.py │ └── forecast.py ├── __init__.py ├── utils.py ├── lgb_cv.py └── _modidx.py ├── .mypy.ini ├── .gitattributes ├── figs └── index.png ├── .pylintrc ├── action_files ├── lint ├── clean_nbs └── remove_logs_cells ├── MANIFEST.in ├── local_environment.yml ├── environment.yml ├── settings.ini ├── .gitignore ├── setup.py ├── CONTRIBUTING.md ├── LICENSE └── README.md /nbs/.gitignore: -------------------------------------------------------------------------------- 1 | /.quarto/ 2 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @jmoralez 2 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/dask/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/ray/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/spark/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | nbs/** linguist-documentation 2 | *.ipynb merge=nbdev-merge 3 | -------------------------------------------------------------------------------- /figs/index.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/figs/index.png -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | ignore=_nbdev.py 3 | 4 | [MESSAGES CONTROL] 5 | disable=all 6 | enable=W0612,W0613 7 | -------------------------------------------------------------------------------- /nbs/figs/forecast__ercot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/forecast__ercot.png -------------------------------------------------------------------------------- /nbs/figs/forecast__predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/forecast__predict.png -------------------------------------------------------------------------------- /mlforecast/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.0" 2 | __all__ = ['MLForecast'] 3 | from mlforecast.forecast import MLForecast 4 | -------------------------------------------------------------------------------- /nbs/figs/cross_validation__series.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/cross_validation__series.png -------------------------------------------------------------------------------- /action_files/lint: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | mypy mlforecast || exit -1 3 | flake8 --select=F mlforecast || exit -1 4 | pylint mlforecast 5 | -------------------------------------------------------------------------------- /nbs/figs/forecast__cross_validation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/forecast__cross_validation.png -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include settings.ini 2 | include LICENSE 3 | include CONTRIBUTING.md 4 | include README.md 5 | recursive-exclude * __pycache__ 6 | -------------------------------------------------------------------------------- /nbs/figs/cross_validation__predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/cross_validation__predictions.png -------------------------------------------------------------------------------- /nbs/figs/forecast__predict_intervals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/forecast__predict_intervals.png -------------------------------------------------------------------------------- /mlforecast/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['DistributedMLForecast'] 2 | from mlforecast.distributed.forecast import DistributedMLForecast 3 | -------------------------------------------------------------------------------- 
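The two `__init__.py` files above define the package's public entry points: `MLForecast` for local training and `DistributedMLForecast` for the distributed interface. As a minimal, hedged sketch of how the local entry point is typically used (the constructor and `fit`/`predict` signatures are assumed from the 0.6.0 sources in this snapshot and may differ in other versions), something like the following should work:

```python
# Minimal local-usage sketch (an assumption, not an official example):
# assumes mlforecast==0.6.0 and a scikit-learn-style regressor such as LightGBM.
import lightgbm as lgb

from mlforecast import MLForecast
from mlforecast.utils import generate_daily_series  # defined later in this snapshot

series = generate_daily_series(n_series=20)  # synthetic panel with unique_id, ds, y
fcst = MLForecast(models=[lgb.LGBMRegressor()], freq="D", lags=[7, 14])
fcst.fit(series)        # uses the default column names unique_id / ds / y
print(fcst.predict(7))  # 7-step-ahead forecasts for every series
```

The distributed counterpart follows the same fit/predict pattern and is driven by the wrapper models defined under `mlforecast/distributed/models/` (Dask, Spark and Ray variants shown later in this snapshot).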
/nbs/figs/quick_start_distributed__sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/quick_start_distributed__sample.png -------------------------------------------------------------------------------- /nbs/figs/forecast__cross_validation_intervals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/forecast__cross_validation_intervals.png -------------------------------------------------------------------------------- /nbs/figs/forecast__predict_intervals_window_size_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/forecast__predict_intervals_window_size_1.png -------------------------------------------------------------------------------- /nbs/figs/electricity_peak_forecasting__predicted_peak.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/electricity_peak_forecasting__predicted_peak.png -------------------------------------------------------------------------------- /nbs/figs/quick_start_distributed__sample_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pietroppeter/mlforecast/main/nbs/figs/quick_start_distributed__sample_prediction.png -------------------------------------------------------------------------------- /.github/workflows/deploy.yaml: -------------------------------------------------------------------------------- 1 | name: Deploy to GitHub Pages 2 | on: 3 | push: 4 | branches: ["main"] 5 | workflow_dispatch: 6 | jobs: 7 | deploy: 8 | runs-on: ubuntu-latest 9 | steps: [uses: fastai/workflows/quarto-ghp@master] 10 | -------------------------------------------------------------------------------- /nbs/nbdev.yml: -------------------------------------------------------------------------------- 1 | project: 2 | output-dir: _docs 3 | 4 | website: 5 | title: "mlforecast" 6 | site-url: "https://Nixtla.github.io/" 7 | description: "Scalable machine learning based time series forecasting" 8 | repo-branch: main 9 | repo-url: "https://github.com/Nixtla/mlforecast" 10 | -------------------------------------------------------------------------------- /action_files/clean_nbs: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | nbdev_clean 3 | # the following sets the kernel as python 3 to avoid annoying diffs 4 | for file in $(find nbs/ -type f -name "*.ipynb") 5 | do 6 | sed -i 's/Python 3.*,$/Python 3\",/g' $file 7 | done 8 | # distributed training produces logs with different IPs each time 9 | ./action_files/remove_logs_cells 10 | -------------------------------------------------------------------------------- /local_environment.yml: -------------------------------------------------------------------------------- 1 | name: mlforecast 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - black 6 | - flake8 7 | - jupyterlab 8 | - lightgbm 9 | - matplotlib 10 | - mypy 11 | - numba 12 | - pandas 13 | - pip 14 | - pylint 15 | - scikit-learn 16 | - window-ops 17 | - xgboost 18 | - pip: 19 | - datasetsforecast 20 | - nbdev 21 | -------------------------------------------------------------------------------- 
/environment.yml: -------------------------------------------------------------------------------- 1 | name: mlforecast 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - black 6 | - dask<2023.1.1 7 | - flake8 8 | - jupyterlab 9 | - lightgbm 10 | - matplotlib 11 | - mypy 12 | - numba 13 | - pandas 14 | - pip 15 | - pylint 16 | - pyspark 17 | - scikit-learn 18 | - window-ops 19 | - xgboost 20 | - pip: 21 | - datasetsforecast 22 | - fugue[ray] 23 | - lightgbm_ray 24 | - nbdev 25 | - xgboost_ray 26 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/ray/lgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.ray.lgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['RayLGBMForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.ray.lgb.ipynb 3 7 | import lightgbm as lgb 8 | from lightgbm_ray import RayLGBMRegressor 9 | 10 | # %% ../../../../nbs/distributed.models.ray.lgb.ipynb 4 11 | class RayLGBMForecast(RayLGBMRegressor): 12 | @property 13 | def model_(self): 14 | return self._lgb_ray_to_local(lgb.LGBMRegressor) 15 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | ## Description 8 | 9 | 10 | Checklist: 11 | - [ ] This PR has a meaningful title and a clear description. 12 | - [ ] The tests pass. 13 | - [ ] All linting tasks pass. 14 | - [ ] The notebooks are clean. -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: 'v$NEXT_PATCH_VERSION' 2 | tag-template: 'v$NEXT_PATCH_VERSION' 3 | categories: 4 | - title: 'New Features' 5 | label: 'feature' 6 | - title: 'Breaking' 7 | label: 'breaking' 8 | - title: 'Bug Fixes' 9 | label: 'fix' 10 | - title: 'Documentation' 11 | label: 'documentation' 12 | - title: 'Maintenance' 13 | label: 'maintenance' 14 | - title: 'Enhancement' 15 | label: 'enhancement' 16 | change-template: '- $TITLE @$AUTHOR (#$NUMBER)' 17 | template: | 18 | ## Changes 19 | $CHANGES 20 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Clone repo 14 | uses: actions/checkout@v2 15 | 16 | - name: Set up python 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.8 20 | 21 | - name: Install linters 22 | run: pip install mypy flake8 pylint 23 | 24 | - name: Lint 25 | run: ./action_files/lint 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an enhancement to the project 4 | --- 5 | 6 | 9 | 10 | ## Summary 11 | 14 | 15 | ## Motivation 16 | 19 | 20 | ## Description 21 | 24 | 25 | ## References 26 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/dask/xgb.py: 
-------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.dask.xgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['DaskXGBForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.dask.xgb.ipynb 3 7 | import xgboost as xgb 8 | 9 | # %% ../../../../nbs/distributed.models.dask.xgb.ipynb 4 10 | class DaskXGBForecast(xgb.dask.DaskXGBRegressor): 11 | @property 12 | def model_(self): 13 | model_str = self.get_booster().save_raw("ubj") 14 | local_model = xgb.XGBRegressor() 15 | local_model.load_model(model_str) 16 | return local_model 17 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/ray/xgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.ray.xgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['RayXGBForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.ray.xgb.ipynb 3 7 | import xgboost as xgb 8 | from xgboost_ray import RayXGBRegressor 9 | 10 | # %% ../../../../nbs/distributed.models.ray.xgb.ipynb 4 11 | class RayXGBForecast(RayXGBRegressor): 12 | @property 13 | def model_(self): 14 | model_str = self.get_booster().save_raw("ubj") 15 | local_model = xgb.XGBRegressor() 16 | local_model.load_model(model_str) 17 | return local_model 18 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/dask/lgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.dask.lgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['DaskLGBMForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.dask.lgb.ipynb 3 7 | import warnings 8 | 9 | import lightgbm as lgb 10 | 11 | # %% ../../../../nbs/distributed.models.dask.lgb.ipynb 4 12 | class DaskLGBMForecast(lgb.dask.DaskLGBMRegressor): 13 | if lgb.__version__ < "3.3.0": 14 | warnings.warn( 15 | "It is recommended to install LightGBM version >= 3.3.0, since " 16 | "the current LightGBM version might be affected by https://github.com/microsoft/LightGBM/issues/4026, " 17 | "which was fixed in 3.3.0" 18 | ) 19 | 20 | @property 21 | def model_(self): 22 | return self.to_local() 23 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | defaults: 9 | run: 10 | shell: bash -l {0} 11 | 12 | jobs: 13 | release: 14 | if: github.repository == 'Nixtla/mlforecast' 15 | runs-on: ubuntu-latest 16 | steps: 17 | - name: Clone repo 18 | uses: actions/checkout@v2 19 | 20 | - name: Set up python 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: '3.10' 24 | 25 | - name: Install build dependencies 26 | run: python -m pip install build wheel 27 | 28 | - name: Build distributions 29 | run: python -m build -sw 30 | 31 | - name: Publish package to PyPI 32 | uses: pypa/gh-action-pypi-publish@master 33 | with: 34 | user: __token__ 35 | password: ${{ secrets.pypi_token }} 36 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Something isn't 
working as expected. 4 | --- 5 | 6 | 9 | 10 | ## Description 11 | 14 | 15 | ## Reproducible example 16 | 19 | 20 | ```python 21 | # code goes here 22 | ``` 23 | 24 |
25 | Error message 26 | 27 | ```python 28 | # Stacktrace 29 | ``` 30 |
31 | 32 | ## Environment info 33 | Install method (pip, conda, github): 34 | 35 | Package version: 36 | 37 | 38 | ## Additional information 39 | 42 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/spark/lgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.spark.lgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['SparkLGBMForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.spark.lgb.ipynb 3 7 | import lightgbm as lgb 8 | 9 | try: 10 | from synapse.ml.lightgbm import LightGBMRegressor 11 | except ModuleNotFoundError: 12 | import os 13 | 14 | if os.getenv("QUARTO_PREVIEW", "0") == "1" or os.getenv("IN_TEST", "0") == "1": 15 | LightGBMRegressor = object 16 | else: 17 | raise 18 | 19 | # %% ../../../../nbs/distributed.models.spark.lgb.ipynb 4 20 | class SparkLGBMForecast(LightGBMRegressor): 21 | def _pre_fit(self, target_col): 22 | return self.setLabelCol(target_col) 23 | 24 | def extract_local_model(self, trained_model): 25 | model_str = trained_model.getNativeModel() 26 | local_model = lgb.Booster(model_str=model_str) 27 | return local_model 28 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/spark/xgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.spark.xgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['SparkXGBForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.spark.xgb.ipynb 3 7 | import xgboost as xgb 8 | 9 | try: 10 | from xgboost.spark import SparkXGBRegressor # type: ignore 11 | except ModuleNotFoundError: 12 | import os 13 | 14 | if os.getenv("IN_TEST", "0") == "1": 15 | SparkXGBRegressor = object 16 | else: 17 | raise 18 | 19 | # %% ../../../../nbs/distributed.models.spark.xgb.ipynb 4 20 | class SparkXGBForecast(SparkXGBRegressor): 21 | def _pre_fit(self, target_col): 22 | self.setParams(label_col=target_col) 23 | return self 24 | 25 | def extract_local_model(self, trained_model): 26 | model_str = trained_model.get_booster().save_raw("ubj") 27 | local_model = xgb.XGBRegressor() 28 | local_model.load_model(model_str) 29 | return local_model 30 | -------------------------------------------------------------------------------- /action_files/remove_logs_cells: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import re 3 | import json 4 | from pathlib import Path 5 | from nbdev.clean import process_write 6 | 7 | IP_REGEX = re.compile(r'[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}') 8 | HOURS_REGEX = re.compile(r'\d{2}:\d{2}:\d{2}') 9 | 10 | def cell_contains_ips(cell): 11 | if 'outputs' not in cell: 12 | return False 13 | for output in cell['outputs']: 14 | if 'text' not in output: 15 | return False 16 | for line in output['text']: 17 | if IP_REGEX.search(line) or HOURS_REGEX.search(line): 18 | return True 19 | return False 20 | 21 | 22 | def clean_nb(nb): 23 | for cell in nb['cells']: 24 | if cell_contains_ips(cell): 25 | cell['outputs'] = [] 26 | 27 | 28 | if __name__ == '__main__': 29 | repo_root = Path(__file__).parents[1] 30 | for nb in (repo_root / 'nbs').glob('*distributed*'): 31 | process_write(warn_msg='Failed to clean_nb', proc_nb=clean_nb, f_in=nb) 32 | -------------------------------------------------------------------------------- 
/nbs/styles.css: -------------------------------------------------------------------------------- 1 | .cell { 2 | margin-bottom: 1rem; 3 | } 4 | 5 | .cell > .sourceCode { 6 | margin-bottom: 0; 7 | } 8 | 9 | .cell-output > pre { 10 | margin-bottom: 0; 11 | } 12 | 13 | .cell-output > pre, .cell-output > .sourceCode > pre, .cell-output-stdout > pre { 14 | margin-left: 0.8rem; 15 | margin-top: 0; 16 | background: none; 17 | border-left: 2px solid lightsalmon; 18 | border-top-left-radius: 0; 19 | border-top-right-radius: 0; 20 | } 21 | 22 | .cell-output > .sourceCode { 23 | border: none; 24 | } 25 | 26 | .cell-output > .sourceCode { 27 | background: none; 28 | margin-top: 0; 29 | } 30 | 31 | div.description { 32 | padding-left: 2px; 33 | padding-top: 5px; 34 | font-style: italic; 35 | font-size: 135%; 36 | opacity: 70%; 37 | } 38 | 39 | /* show_doc signature */ 40 | blockquote > pre { 41 | font-size: 14px; 42 | } 43 | 44 | .table { 45 | font-size: 16px; 46 | /* disable striped tables */ 47 | --bs-table-striped-bg: var(--bs-table-bg); 48 | } 49 | 50 | .quarto-figure-center > figure > figcaption { 51 | text-align: center; 52 | } 53 | 54 | .figure-caption { 55 | font-size: 75%; 56 | font-style: italic; 57 | } 58 | -------------------------------------------------------------------------------- /nbs/sidebar.yml: -------------------------------------------------------------------------------- 1 | website: 2 | reader-mode: false 3 | sidebar: 4 | collapse-level: 3 5 | contents: 6 | - index.ipynb 7 | - section: Getting started 8 | contents: 9 | - docs/install.ipynb 10 | - docs/quick_start_local.ipynb 11 | - docs/quick_start_distributed.ipynb 12 | - docs/end_to_end_walkthrough.ipynb 13 | - section: Tutorials 14 | contents: 15 | - docs/cross_validation.ipynb 16 | - docs/electricity_peak_forecasting.ipynb 17 | - docs/prediction_intervals.ipynb 18 | - docs/transfer_learning.ipynb 19 | - section: API reference 20 | contents: 21 | - section: Local 22 | contents: 23 | - forecast.ipynb 24 | - lgb_cv.ipynb 25 | - utils.ipynb 26 | - core.ipynb 27 | - section: Distributed 28 | contents: 29 | - distributed.forecast.ipynb 30 | - distributed.models.dask.lgb.ipynb 31 | - distributed.models.dask.xgb.ipynb 32 | - distributed.models.spark.lgb.ipynb 33 | - distributed.models.spark.xgb.ipynb 34 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | host = github 3 | lib_name = mlforecast 4 | user = Nixtla 5 | description = Scalable machine learning based time series forecasting 6 | keywords = python forecast forecasting machine-learning dask 7 | author = José Morales 8 | author_email = jmoralz92@gmail.com 9 | copyright = Nixtla 10 | branch = main 11 | version = 0.6.0 12 | min_python = 3.6 13 | audience = Developers 14 | language = English 15 | custom_sidebar = True 16 | license = apache2 17 | status = 3 18 | requirements = numba pandas scikit-learn window-ops 19 | distributed_requirements = dask[complete] fugue[ray] pyspark lightgbm_ray xgboost_ray 20 | dev_requirements = black datasetsforecast flake8 lightgbm matplotlib mypy nbdev pylint statsforecast xgboost 21 | nbs_path = nbs 22 | doc_path = _docs 23 | recursive = True 24 | doc_host = https://Nixtla.github.io 25 | doc_baseurl = / 26 | git_url = https://github.com/Nixtla/mlforecast 27 | lib_path = mlforecast 28 | title = mlforecast 29 | tst_flags = 30 | black_formatting = True 31 | readme_nb = index.ipynb 32 | 
allowed_metadata_keys = 33 | allowed_cell_metadata_keys = 34 | jupyter_hooks = True 35 | clean_ids = True 36 | clear_all = False 37 | put_version_in_init = True 38 | -------------------------------------------------------------------------------- /nbs/distributed.models.ray.lgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.ray.lgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "12972535-1a7c-4814-a19c-5e2c48824e85", 16 | "metadata": {}, 17 | "source": [ 18 | "# RayLGBMForecast\n", 19 | "\n", 20 | "> ray LightGBM forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "fd9d0998-ca46-4e7a-9c64-b8378c0c1b85", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `lightgbm.ray.RayLGBMRegressor` that adds a `model_` property that contains the fitted booster and is sent to the workers to in the forecasting step." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import lightgbm as lgb\n", 40 | "from lightgbm_ray import RayLGBMRegressor" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "#|export\n", 51 | "class RayLGBMForecast(RayLGBMRegressor):\n", 52 | " @property\n", 53 | " def model_(self):\n", 54 | " return self._lgb_ray_to_local(lgb.LGBMRegressor)" 55 | ] 56 | } 57 | ], 58 | "metadata": { 59 | "kernelspec": { 60 | "display_name": "python3", 61 | "language": "python", 62 | "name": "python3" 63 | } 64 | }, 65 | "nbformat": 4, 66 | "nbformat_minor": 5 67 | } 68 | -------------------------------------------------------------------------------- /nbs/_quarto.yml: -------------------------------------------------------------------------------- 1 | project: 2 | type: website 3 | 4 | format: 5 | html: 6 | theme: cosmo 7 | fontsize: 1em 8 | linestretch: 1.7 9 | css: styles.css 10 | toc: true 11 | 12 | website: 13 | twitter-card: 14 | image: "https://farm6.staticflickr.com/5510/14338202952_93595258ff_z.jpg" 15 | site: "@Nixtlainc" 16 | open-graph: 17 | image: "https://github.com/Nixtla/styles/blob/2abf51612584169874c90cd7c4d347e3917eaf73/images/Banner%20Github.png" 18 | google-analytics: "G-NXJNCVR18L" 19 | repo-actions: [issue] 20 | navbar: 21 | background: primary 22 | search: true 23 | collapse-below: lg 24 | left: 25 | - text: "Get Started" 26 | href: docs/quick_start_local.ipynb 27 | - text: "NixtlaVerse" 28 | menu: 29 | - text: "StatsForecast ⚡️" 30 | href: https://github.com/nixtla/statsforecast 31 | - text: "NeuralForecast 🧠" 32 | href: https://github.com/nixtla/neuralforecast 33 | - text: "HierarchicalForecast 👑" 34 | href: "https://github.com/nixtla/hierarchicalforecast" 35 | 36 | - text: "Help" 37 | menu: 38 | - text: "Report an Issue" 39 | icon: bug 40 | href: https://github.com/nixtla/mlforecast/issues/new/choose 41 | - text: "Join our Slack" 42 | icon: chat-right-text 43 | href: https://join.slack.com/t/nixtlaworkspace/shared_invite/zt-135dssye9-fWTzMpv2WBthq8NK0Yvu6A 44 | right: 45 | - icon: github 46 | href: "https://github.com/nixtla/mlforecast" 47 | - icon: twitter 48 | href: 
https://twitter.com/nixtlainc 49 | aria-label: Nixtla Twitter 50 | 51 | sidebar: 52 | style: floating 53 | body-footer: | 54 | Give us a ⭐ on [Github](https://github.com/nixtla/mlforecast) 55 | 56 | metadata-files: [nbdev.yml, sidebar.yml] 57 | -------------------------------------------------------------------------------- /nbs/distributed.models.dask.xgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.dask.xgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "5ee154af-e882-4914-8bf2-f536a8d01b94", 16 | "metadata": {}, 17 | "source": [ 18 | "# DaskXGBForecast\n", 19 | "\n", 20 | "> dask XGBoost forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "4f4c7bc1-9779-4771-8224-f852e6b7987c", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `xgboost.dask.DaskXGBRegressor` that adds a `model_` property that contains the fitted model and is sent to the workers in the forecasting step." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import xgboost as xgb" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "#|export\n", 50 | "class DaskXGBForecast(xgb.dask.DaskXGBRegressor):\n", 51 | " @property\n", 52 | " def model_(self):\n", 53 | " model_str = self.get_booster().save_raw('ubj')\n", 54 | " local_model = xgb.XGBRegressor()\n", 55 | " local_model.load_model(model_str)\n", 56 | " return local_model" 57 | ] 58 | } 59 | ], 60 | "metadata": { 61 | "kernelspec": { 62 | "display_name": "python3", 63 | "language": "python", 64 | "name": "python3" 65 | } 66 | }, 67 | "nbformat": 4, 68 | "nbformat_minor": 5 69 | } 70 | -------------------------------------------------------------------------------- /nbs/distributed.models.ray.xgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.ray.xgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "5ee154af-e882-4914-8bf2-f536a8d01b94", 16 | "metadata": {}, 17 | "source": [ 18 | "# RayXGBForecast\n", 19 | "\n", 20 | "> dask XGBoost forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "4f4c7bc1-9779-4771-8224-f852e6b7987c", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `xgboost.ray.RayXGBRegressor` that adds a `model_` property that contains the fitted model and is sent to the workers in the forecasting step." 
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import xgboost as xgb\n", 40 | "from xgboost_ray import RayXGBRegressor" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "#|export\n", 51 | "class RayXGBForecast(RayXGBRegressor):\n", 52 | " @property\n", 53 | " def model_(self):\n", 54 | " model_str = self.get_booster().save_raw(\"ubj\")\n", 55 | " local_model = xgb.XGBRegressor()\n", 56 | " local_model.load_model(model_str)\n", 57 | " return local_model" 58 | ] 59 | } 60 | ], 61 | "metadata": { 62 | "kernelspec": { 63 | "display_name": "python3", 64 | "language": "python", 65 | "name": "python3" 66 | } 67 | }, 68 | "nbformat": 4, 69 | "nbformat_minor": 5 70 | } 71 | -------------------------------------------------------------------------------- /nbs/distributed.models.dask.lgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.dask.lgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "12972535-1a7c-4814-a19c-5e2c48824e85", 16 | "metadata": {}, 17 | "source": [ 18 | "# DaskLGBMForecast\n", 19 | "\n", 20 | "> dask LightGBM forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "fd9d0998-ca46-4e7a-9c64-b8378c0c1b85", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `lightgbm.dask.DaskLGBMRegressor` that adds a `model_` property that contains the fitted booster and is sent to the workers to in the forecasting step." 
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import warnings\n", 40 | "\n", 41 | "import lightgbm as lgb" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "#|export\n", 52 | "class DaskLGBMForecast(lgb.dask.DaskLGBMRegressor):\n", 53 | " if lgb.__version__ < \"3.3.0\":\n", 54 | " warnings.warn(\n", 55 | " \"It is recommended to install LightGBM version >= 3.3.0, since \"\n", 56 | " \"the current LightGBM version might be affected by https://github.com/microsoft/LightGBM/issues/4026, \"\n", 57 | " \"which was fixed in 3.3.0\"\n", 58 | " )\n", 59 | "\n", 60 | " @property\n", 61 | " def model_(self):\n", 62 | " return self.to_local()" 63 | ] 64 | } 65 | ], 66 | "metadata": { 67 | "kernelspec": { 68 | "display_name": "python3", 69 | "language": "python", 70 | "name": "python3" 71 | } 72 | }, 73 | "nbformat": 4, 74 | "nbformat_minor": 5 75 | } 76 | -------------------------------------------------------------------------------- /nbs/distributed.models.spark.lgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.spark.lgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "12972535-1a7c-4814-a19c-5e2c48824e85", 16 | "metadata": {}, 17 | "source": [ 18 | "# SparkLGBMForecast\n", 19 | "\n", 20 | "> spark LightGBM forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "fd9d0998-ca46-4e7a-9c64-b8378c0c1b85", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `synapse.ml.lightgbm.LightGBMRegressor` that adds an `extract_local_model` method to get a local version of the trained model and broadcast it to the workers." 
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import lightgbm as lgb\n", 40 | "try:\n", 41 | " from synapse.ml.lightgbm import LightGBMRegressor\n", 42 | "except ModuleNotFoundError:\n", 43 | " import os\n", 44 | " \n", 45 | " if os.getenv('QUARTO_PREVIEW', '0') == '1' or os.getenv('IN_TEST', '0') == '1':\n", 46 | " LightGBMRegressor = object\n", 47 | " else:\n", 48 | " raise" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "#|export\n", 59 | "class SparkLGBMForecast(LightGBMRegressor):\n", 60 | " def _pre_fit(self, target_col):\n", 61 | " return self.setLabelCol(target_col)\n", 62 | " \n", 63 | " def extract_local_model(self, trained_model):\n", 64 | " model_str = trained_model.getNativeModel()\n", 65 | " local_model = lgb.Booster(model_str=model_str)\n", 66 | " return local_model" 67 | ] 68 | } 69 | ], 70 | "metadata": { 71 | "kernelspec": { 72 | "display_name": "python3", 73 | "language": "python", 74 | "name": "python3" 75 | } 76 | }, 77 | "nbformat": 4, 78 | "nbformat_minor": 5 79 | } 80 | -------------------------------------------------------------------------------- /nbs/distributed.models.spark.xgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.spark.xgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "5ee154af-e882-4914-8bf2-f536a8d01b94", 16 | "metadata": {}, 17 | "source": [ 18 | "# SparkXGBForecast\n", 19 | "\n", 20 | "> spark XGBoost forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "4f4c7bc1-9779-4771-8224-f852e6b7987c", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `xgboost.spark.SparkXGBRegressor` that adds an `extract_local_model` method to get a local version of the trained model and broadcast it to the workers." 
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import xgboost as xgb\n", 40 | "try:\n", 41 | " from xgboost.spark import SparkXGBRegressor # type: ignore\n", 42 | "except ModuleNotFoundError:\n", 43 | " import os\n", 44 | " \n", 45 | " if os.getenv('IN_TEST', '0') == '1':\n", 46 | " SparkXGBRegressor = object\n", 47 | " else:\n", 48 | " raise" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "#|export\n", 59 | "class SparkXGBForecast(SparkXGBRegressor): \n", 60 | " def _pre_fit(self, target_col):\n", 61 | " self.setParams(label_col=target_col)\n", 62 | " return self\n", 63 | "\n", 64 | " def extract_local_model(self, trained_model):\n", 65 | " model_str = trained_model.get_booster().save_raw('ubj')\n", 66 | " local_model = xgb.XGBRegressor()\n", 67 | " local_model.load_model(model_str)\n", 68 | " return local_model" 69 | ] 70 | } 71 | ], 72 | "metadata": { 73 | "kernelspec": { 74 | "display_name": "python3", 75 | "language": "python", 76 | "name": "python3" 77 | } 78 | }, 79 | "nbformat": 4, 80 | "nbformat_minor": 5 81 | } 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.bak 2 | .gitattributes 3 | .last_checked 4 | .gitconfig 5 | *.bak 6 | *.log 7 | *~ 8 | ~* 9 | _tmp* 10 | tmp* 11 | tags 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # dotenv 95 | .env 96 | 97 | # virtualenv 98 | .venv 99 | venv/ 100 | ENV/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | 115 | .vscode 116 | *.swp 117 | 118 | # osx generated files 119 | .DS_Store 120 | .DS_Store? 
121 | .Trashes 122 | ehthumbs.db 123 | Thumbs.db 124 | .idea 125 | 126 | # pytest 127 | .pytest_cache 128 | 129 | # tools/trust-doc-nbs 130 | docs_src/.last_checked 131 | 132 | # symlinks to fastai 133 | docs_src/fastai 134 | tools/fastai 135 | 136 | # link checker 137 | checklink/cookies.txt 138 | 139 | # .gitconfig is now autogenerated 140 | .gitconfig 141 | 142 | # dask 143 | dask-worker-space 144 | 145 | # gemfiles 146 | Gemfile* 147 | 148 | # jekyll 149 | .jekyll-cache 150 | 151 | # series files 152 | nbs/data 153 | 154 | # nbdev 155 | nbs/_docs 156 | _proc/ 157 | index_files 158 | _docs 159 | nbs/data 160 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | from configparser import ConfigParser 3 | import setuptools 4 | assert parse_version(setuptools.__version__)>=parse_version('36.2') 5 | 6 | # note: all settings are in settings.ini; edit there, not here 7 | config = ConfigParser(delimiters=['=']) 8 | config.read('settings.ini') 9 | cfg = config['DEFAULT'] 10 | 11 | cfg_keys = 'version description keywords author author_email'.split() 12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split() 13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o) 14 | setup_cfg = {o:cfg[o] for o in cfg_keys} 15 | 16 | licenses = { 17 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), 18 | 'mit': ('MIT License', 'OSI Approved :: MIT License'), 19 | 'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'), 20 | 'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'), 21 | 'bsd3': ('BSD License', 'OSI Approved :: BSD License'), 22 | } 23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', 24 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] 25 | py_versions = '2.0 2.1 2.2 2.3 2.4 2.5 2.6 2.7 3.0 3.1 3.2 3.3 3.4 3.5 3.6 3.7 3.8 3.9 3.10'.split() 26 | 27 | requirements = cfg.get('requirements','').split() 28 | distributed_requirements = cfg.get('distributed_requirements', '').split() 29 | dev_requirements = requirements + distributed_requirements + cfg.get('dev_requirements', '').split() 30 | min_python = cfg['min_python'] 31 | lic = licenses.get(cfg['license'].lower(), (cfg['license'], None)) 32 | 33 | setuptools.setup( 34 | name = 'mlforecast', 35 | license = lic[0], 36 | classifiers = [ 37 | 'Development Status :: ' + statuses[int(cfg['status'])], 38 | 'Intended Audience :: ' + cfg['audience'].title(), 39 | 'Natural Language :: ' + cfg['language'].title(), 40 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1] ] if lic[1] else []), 41 | url = cfg['git_url'], 42 | packages = setuptools.find_packages(), 43 | include_package_data = True, 44 | install_requires = requirements, 45 | extras_require = { 46 | 'distributed': distributed_requirements, 47 | 'dev': dev_requirements, 48 | }, 49 | dependency_links = cfg.get('dep_links','').split(), 50 | python_requires = '>=' + cfg['min_python'], 51 | long_description = open('README.md', encoding='utf-8').read(), 52 | long_description_content_type = 'text/markdown', 53 | zip_safe = False, 54 | entry_points = { 'console_scripts': cfg.get('console_scripts','').split() }, 55 | **setup_cfg) 56 | 57 | 
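# Added commentary (not part of the generated setup.py): with the settings.ini
# shown above, `requirements` becomes install_requires, while the extras map to
# pip install "mlforecast[distributed]" (dask[complete], fugue[ray], pyspark,
# lightgbm_ray, xgboost_ray) and pip install "mlforecast[dev]", which is the union
# of the base, distributed and dev requirement lists used in CONTRIBUTING.md.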
-------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | workflow_dispatch: 9 | 10 | defaults: 11 | run: 12 | shell: bash -l {0} 13 | 14 | concurrency: 15 | group: ${{ github.workflow }}-${{ github.ref }} 16 | cancel-in-progress: true 17 | 18 | jobs: 19 | nb-sync: 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Clone repo 23 | uses: actions/checkout@v2 24 | 25 | - name: Set up python 26 | uses: actions/setup-python@v2 27 | 28 | - name: Install nbdev 29 | run: pip install nbdev 30 | 31 | - name: Check if all notebooks are cleaned 32 | run: | 33 | echo "Check we are starting with clean git checkout" 34 | if [ -n "$(git status -uno -s)" ]; then echo "git status is not clean"; false; fi 35 | echo "Trying to strip out notebooks" 36 | nbdev_clean 37 | echo "Check that strip out was unnecessary" 38 | git status -s # display the status to see which nbs need cleaning up 39 | if [ -n "$(git status -uno -s)" ]; then echo -e "!!! Detected unstripped out notebooks\n!!!Remember to run nbdev_install_hooks"; false; fi 40 | 41 | run-all-tests: 42 | runs-on: ubuntu-latest 43 | strategy: 44 | fail-fast: false 45 | matrix: 46 | python-version: ['3.7', '3.8', '3.9', '3.10'] 47 | steps: 48 | - name: Clone repo 49 | uses: actions/checkout@v2 50 | 51 | - name: Set up environment 52 | uses: mamba-org/provision-with-micromamba@main 53 | with: 54 | extra-specs: python=${{ matrix.python-version }} 55 | cache-env: true 56 | 57 | - name: Install the library 58 | run: pip install ./ 59 | 60 | - name: Run all tests 61 | run: nbdev_test --n_workers 1 --do_print --timing 62 | 63 | run-local-tests: 64 | runs-on: ${{ matrix.os }} 65 | strategy: 66 | fail-fast: false 67 | matrix: 68 | os: [macos-latest, windows-latest] 69 | python-version: ['3.7', '3.8', '3.9', '3.10'] 70 | steps: 71 | - name: Clone repo 72 | uses: actions/checkout@v2 73 | 74 | - name: Set up environment 75 | uses: mamba-org/provision-with-micromamba@main 76 | with: 77 | environment-file: local_environment.yml 78 | extra-specs: python=${{ matrix.python-version }} 79 | cache-env: true 80 | 81 | - name: Install the library 82 | run: pip install ./ 83 | 84 | - name: Run local tests 85 | run: nbdev_test --n_workers 1 --do_print --timing --skip_file_glob "*distributed*" 86 | 87 | check-deps: 88 | runs-on: ubuntu-latest 89 | steps: 90 | - name: Clone repo 91 | uses: actions/checkout@v2 92 | 93 | - name: Set up python 94 | uses: actions/setup-python@v2 95 | with: 96 | python-version: '3.10' 97 | 98 | - name: Install the library and import it 99 | run: | 100 | pip install . 101 | python -c 'import mlforecast' 102 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | ## Did you find a bug? 4 | 5 | * Ensure the bug was not already reported by searching on GitHub under Issues. 6 | * If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring. 7 | * Be sure to add the complete error messages. 8 | 9 | ## Do you have a feature request? 
10 | 11 | * Ensure that it hasn't already been implemented in the `main` branch of the repository and that there isn't an Issue requesting it yet. 12 | * Open a new issue and make sure to describe it clearly, mention how it improves the project and why it's useful. 13 | 14 | ## Do you want to fix a bug or implement a feature? 15 | 16 | Bug fixes and features are added through pull requests (PRs). 17 | 18 | ## PR submission guidelines 19 | 20 | * Keep each PR focused. Even if it's more convenient, do not combine several unrelated fixes together. Create as many branches as needed to keep each PR focused. 21 | * Ensure that your PR includes a test that fails without your patch, and passes with it. 22 | * Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable. 23 | * Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs and they will most likely get rejected. 24 | * Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can. 25 | * Do not turn an already submitted PR into your development playground. If, after you submitted a PR, you discover that more work is needed, close the PR, do the required work and then submit a new PR. Otherwise, each of your commits requires attention from the maintainers of the project. 26 | * If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it'll take many, many commits to complete the requests, it's probably best to close the PR, do the work and then submit it again. Use common sense when choosing one way over the other. 27 | 28 | ### Local setup for working on a PR 29 | 30 | #### 1. Clone the repository 31 | * HTTPS: `git clone https://github.com/Nixtla/mlforecast.git` 32 | * SSH: `git clone git@github.com:Nixtla/mlforecast.git` 33 | * GitHub CLI: `gh repo clone Nixtla/mlforecast` 34 | 35 | #### 2. Install the required dependencies for development 36 | ##### conda/mamba 37 | The repo comes with an `environment.yml` file which contains the libraries needed to run all the tests (please note that the distributed interface is only available on Linux). In order to set up the environment you must have `conda/mamba` installed; we recommend [mambaforge](https://github.com/conda-forge/miniforge#mambaforge). 38 | 39 | Once you have `conda/mamba`, go to the top level directory of the repository and run: 40 | ``` 41 | {conda|mamba} env create -f environment.yml 42 | ``` 43 | 44 | Once you have your environment set up, activate it using `conda activate mlforecast`. 45 | ##### PyPI 46 | From the top level directory of the repository run: `pip install ".[dev]"` 47 | 48 | #### 3. Install the library 49 | From the top level directory of the repository run: `pip install -e .` 50 | 51 | ### Building the library 52 | The library is built using the notebooks contained in the `nbs` folder. If you want to make any changes to the library, you have to find the relevant notebook, make your changes and then call `nbdev_export`. 53 | 54 | ### Running tests 55 | 56 | * If you're working on the local interface, use `nbdev_test --skip_file_glob "distributed*" --n_workers 1`. 57 | * If you're modifying the distributed interface, run the tests using `nbdev_test --n_workers 1`.
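As an optional sanity check (a suggestion added here, not part of the official workflow), you can confirm that the editable install and `nbdev_export` picked up your changes before running the heavier test suites:

```python
# Hedged sanity check after `pip install -e .`; the expected version string comes
# from settings.ini / mlforecast/__init__.py elsewhere in this snapshot.
import mlforecast

assert mlforecast.__version__ == "0.6.0"
print("editable install OK:", mlforecast.__version__)
```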
58 | ### Linters 59 | This project uses a couple of linters to validate different aspects of the code. Before opening a PR, please make sure that it passes all the linting tasks by following the next steps. 60 | 61 | #### Run the linting tasks 62 | * `mypy mlforecast/` 63 | * `flake8 --select=F mlforecast/` 64 | 65 | ### Cleaning notebooks 66 | Run `nbdev_clean`. 67 | ## Do you want to contribute to the documentation? 68 | 69 | * Docs are automatically created from the notebooks in the `nbs` folder. 70 | * In order to modify the documentation: 71 | 1. Find the relevant notebook. 72 | 2. Make your changes. 73 | 3. Run all cells. 74 | 4. Run `nbdev_preview` 75 | 5. If you modified the `index.ipynb` notebook, run `nbdev_readme`. 76 | -------------------------------------------------------------------------------- /nbs/docs/install.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e4498c76-c1a4-4bb4-ac55-96f9c0475acc", 6 | "metadata": {}, 7 | "source": [ 8 | "# Install\n", 9 | "\n", 10 | "> Instructions to install the package from different sources." 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "a2312c2d-b391-4d19-a1bd-f99339c290d7", 16 | "metadata": {}, 17 | "source": [ 18 | "## Released versions" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "id": "d465c613-a0a6-4538-9497-57984c261dc0", 24 | "metadata": {}, 25 | "source": [ 26 | "### PyPI" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "417f8f55-8000-4595-af03-ab88bcc62488", 32 | "metadata": {}, 33 | "source": [ 34 | "#### Latest release" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "id": "0ea582c6-04e1-4c0f-a4bd-6e551f08f20d", 40 | "metadata": {}, 41 | "source": [ 42 | "To install the latest release of mlforecast from [PyPI](https://pypi.org/project/mlforecast/) you just have to run the following in a terminal:\n", 43 | "\n", 44 | "`pip install mlforecast`" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "342b7c8d-4bb3-43e0-8ee8-73bfd65e2b5f", 50 | "metadata": {}, 51 | "source": [ 52 | "#### Specific version" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "a09a5c90-cce9-4c49-ac38-f8569f760a98", 58 | "metadata": {}, 59 | "source": [ 60 | "If you want a specific version you can include a filter, for example:\n", 61 | "\n", 62 | "* `pip install \"mlforecast==0.3.0\"` to install the 0.3.0 version\n", 63 | "* `pip install \"mlforecast<0.4.0\"` to install any version prior to 0.4.0" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "id": "3d3acf50-ace7-4c00-b935-36a801c79dc7", 69 | "metadata": {}, 70 | "source": [ 71 | "#### Distributed training" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "8912889b-cd26-4b37-a43f-bc672cdec9f0", 77 | "metadata": {}, 78 | "source": [ 79 | "If you want to perform distributed training you have to include the [dask](https://www.dask.org/) extra:\n", 80 | "\n", 81 | "`pip install \"mlforecast[dask]\"`\n", 82 | "\n", 83 | "and also either [LightGBM](https://github.com/microsoft/LightGBM/tree/master/python-package) or [XGBoost](https://xgboost.readthedocs.io/en/latest/install.html#python)." 
84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "id": "6772b367-7e1b-4612-8bcd-26039b2badf3", 89 | "metadata": {}, 90 | "source": [ 91 | "### Conda" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "id": "3df21509-dcd1-433e-8a3e-a9f5bce7dc51", 97 | "metadata": {}, 98 | "source": [ 99 | "#### Latest release" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "13744f0d-6916-4358-be73-25209033eb74", 105 | "metadata": {}, 106 | "source": [ 107 | "The mlforecast package is also published to [conda-forge](https://anaconda.org/conda-forge/mlforecast), which you can install by running the following in a terminal:\n", 108 | "\n", 109 | "`conda install -c conda-forge mlforecast`\n", 110 | "\n", 111 | "Note that this happens about a day later after it is published to PyPI, so you may have to wait to get the latest release." 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "debb123c-a31e-4c9f-a39c-1c4db07a30a3", 117 | "metadata": {}, 118 | "source": [ 119 | "#### Specific version" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "id": "ed4c9f64-db1f-4f3e-b72d-5076ecf7be3d", 125 | "metadata": {}, 126 | "source": [ 127 | "If you want a specific version you can include a filter, for example:\n", 128 | "\n", 129 | "* `conda install -c conda-forge \"mlforecast==0.3.0\"` to install the 0.3.0 version\n", 130 | "* `conda install -c conda-forge \"mlforecast<0.4.0\"` to install any version prior to 0.4.0" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "id": "40307de9-f0ec-4f84-a324-8a36d66e7fdb", 136 | "metadata": {}, 137 | "source": [ 138 | "#### Distributed training" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "id": "83dab9f7-e614-42ef-95c0-a76ae20f74a7", 144 | "metadata": {}, 145 | "source": [ 146 | "If you want to perform distributed training you also have to install [dask](https://www.dask.org/):\n", 147 | "\n", 148 | "`conda install -c conda-forge dask`\n", 149 | "\n", 150 | "and also either [LightGBM](https://github.com/microsoft/LightGBM/tree/master/python-package) or [XGBoost](https://xgboost.readthedocs.io/en/latest/install.html#python)." 151 | ] 152 | }, 153 | { 154 | "cell_type": "markdown", 155 | "id": "937ed413-2207-43f2-965e-62ebbaf0c8db", 156 | "metadata": {}, 157 | "source": [ 158 | "## Development version" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "id": "ec4a0e27-24c3-4f79-929c-c77f891eb27e", 164 | "metadata": {}, 165 | "source": [ 166 | "If you want to try out a new feature that hasn't made it into a release yet you have the following options:\n", 167 | "\n", 168 | "* Install from github: `pip install git+https://github.com/Nixtla/mlforecast`\n", 169 | "* Clone and install:\n", 170 | " * `git clone https://github.com/Nixtla/mlforecast`\n", 171 | " * `pip install mlforecast`\n", 172 | "\n", 173 | "which will install the version from the current main branch." 174 | ] 175 | } 176 | ], 177 | "metadata": { 178 | "kernelspec": { 179 | "display_name": "python3", 180 | "language": "python", 181 | "name": "python3" 182 | } 183 | }, 184 | "nbformat": 4, 185 | "nbformat_minor": 5 186 | } 187 | -------------------------------------------------------------------------------- /mlforecast/utils.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/utils.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['generate_daily_series', 'generate_prices_for_series', 'backtest_splits'] 5 | 6 | # %% ../nbs/utils.ipynb 2 7 | import random 8 | import reprlib 9 | from itertools import chain 10 | from math import ceil, log10 11 | from typing import Optional, Union 12 | 13 | import numpy as np 14 | import pandas as pd 15 | 16 | # %% ../nbs/utils.ipynb 4 17 | def generate_daily_series( 18 | n_series: int, 19 | min_length: int = 50, 20 | max_length: int = 500, 21 | n_static_features: int = 0, 22 | equal_ends: bool = False, 23 | static_as_categorical: bool = True, 24 | with_trend: bool = False, 25 | seed: int = 0, 26 | ) -> pd.DataFrame: 27 | """Generates `n_series` of different lengths in the interval [`min_length`, `max_length`]. 28 | 29 | If `n_static_features > 0`, then each serie gets static features with random values. 30 | If `equal_ends == True` then all series end at the same date.""" 31 | rng = np.random.RandomState(seed) 32 | random.seed(seed) 33 | series_lengths = rng.randint(min_length, max_length + 1, n_series) 34 | total_length = series_lengths.sum() 35 | n_digits = ceil(log10(n_series)) 36 | 37 | dates = pd.date_range("2000-01-01", periods=max_length, freq="D").values 38 | uids = [ 39 | [f"id_{i:0{n_digits}}"] * serie_length 40 | for i, serie_length in enumerate(series_lengths) 41 | ] 42 | if equal_ends: 43 | ds = [dates[-serie_length:] for serie_length in series_lengths] 44 | else: 45 | ds = [dates[:serie_length] for serie_length in series_lengths] 46 | y = np.arange(total_length) % 7 + rng.rand(total_length) * 0.5 47 | series = pd.DataFrame( 48 | { 49 | "unique_id": list(chain.from_iterable(uids)), 50 | "ds": list(chain.from_iterable(ds)), 51 | "y": y, 52 | } 53 | ) 54 | for i in range(n_static_features): 55 | static_values = np.repeat(rng.randint(0, 100, n_series), series_lengths) 56 | series[f"static_{i}"] = static_values 57 | if static_as_categorical: 58 | series[f"static_{i}"] = series[f"static_{i}"].astype("category") 59 | if i == 0: 60 | series["y"] = series["y"] * 0.1 * (1 + static_values) 61 | series["unique_id"] = series["unique_id"].astype("category") 62 | series["unique_id"] = series["unique_id"].cat.as_ordered() 63 | if with_trend: 64 | coefs = pd.Series( 65 | rng.rand(n_series), index=[f"id_{i:0{n_digits}}" for i in range(n_series)] 66 | ) 67 | trends = series.groupby("unique_id").cumcount() 68 | trends.index = series["unique_id"] 69 | series["y"] += (coefs * trends).values 70 | return series 71 | 72 | # %% ../nbs/utils.ipynb 15 73 | def generate_prices_for_series( 74 | series: pd.DataFrame, horizon: int = 7, seed: int = 0 75 | ) -> pd.DataFrame: 76 | rng = np.random.RandomState(seed) 77 | unique_last_dates = series.groupby("unique_id")["ds"].max().nunique() 78 | if unique_last_dates > 1: 79 | raise ValueError("series must have equal ends.") 80 | if "product_id" not in series: 81 | raise ValueError("series must have a product_id column.") 82 | day_offset = pd.tseries.frequencies.Day() 83 | starts_ends = series.groupby("product_id")["ds"].agg([min, max]) 84 | dfs = [] 85 | for idx, (start, end) in starts_ends.iterrows(): 86 | product_df = pd.DataFrame( 87 | { 88 | "product_id": idx, 89 | "price": rng.rand((end - start).days + 1 + horizon), 90 | }, 91 | index=pd.date_range(start, end + horizon * day_offset, name="ds"), 92 | ) 93 | dfs.append(product_df) 94 | prices_catalog = pd.concat(dfs).reset_index() 95 | return prices_catalog 96 | 97 | # %% ../nbs/utils.ipynb 18 98 | def single_split( 99 | data: pd.DataFrame, 100 | i_window: int, 101 | 
n_windows: int, 102 | window_size: int, 103 | id_col: str, 104 | time_col: str, 105 | freq: Union[pd.offsets.BaseOffset, int], 106 | max_dates: pd.Series, 107 | step_size: Optional[int] = None, 108 | input_size: Optional[int] = None, 109 | ): 110 | if step_size is None: 111 | step_size = window_size 112 | test_size = window_size + step_size * (n_windows - 1) 113 | offset = test_size - i_window * step_size 114 | train_ends = max_dates - offset * freq 115 | valid_ends = train_ends + window_size * freq 116 | train_mask = data[time_col].le(train_ends) 117 | if input_size is not None: 118 | train_mask &= data[time_col].gt(train_ends - input_size * freq) 119 | train_sizes = train_mask.groupby(data[id_col], observed=True).sum() 120 | if train_sizes.eq(0).any(): 121 | ids = reprlib.repr(train_sizes[train_sizes.eq(0)].index.tolist()) 122 | raise ValueError(f"The following series are too short for the window: {ids}") 123 | valid_mask = data[time_col].gt(train_ends) & data[time_col].le(valid_ends) 124 | cutoffs = ( 125 | train_ends.set_axis(data[id_col]) 126 | .groupby(id_col, observed=True) 127 | .head(1) 128 | .rename("cutoff") 129 | ) 130 | return cutoffs, train_mask, valid_mask 131 | 132 | # %% ../nbs/utils.ipynb 19 133 | def backtest_splits( 134 | data: pd.DataFrame, 135 | n_windows: int, 136 | window_size: int, 137 | id_col: str, 138 | time_col: str, 139 | freq: Union[pd.offsets.BaseOffset, int], 140 | step_size: Optional[int] = None, 141 | input_size: Optional[int] = None, 142 | ): 143 | max_dates = data.groupby(id_col, observed=True)[time_col].transform("max") 144 | for i in range(n_windows): 145 | cutoffs, train_mask, valid_mask = single_split( 146 | data, 147 | i_window=i, 148 | n_windows=n_windows, 149 | window_size=window_size, 150 | id_col=id_col, 151 | time_col=time_col, 152 | freq=freq, 153 | max_dates=max_dates, 154 | step_size=step_size, 155 | input_size=input_size, 156 | ) 157 | train, valid = data[train_mask], data[valid_mask] 158 | yield cutoffs, train, valid 159 | 160 | # %% ../nbs/utils.ipynb 22 161 | class PredictionIntervals: 162 | """Class for storing prediction intervals metadata information.""" 163 | 164 | def __init__( 165 | self, 166 | n_windows: int = 2, 167 | window_size: int = 1, 168 | method: str = "conformal_distribution", 169 | ): 170 | if n_windows < 2: 171 | raise ValueError( 172 | "You need at least two windows to compute conformal intervals" 173 | ) 174 | allowed_methods = ["conformal_error", "conformal_distribution"] 175 | if method not in allowed_methods: 176 | raise ValueError(f"method must be one of {allowed_methods}") 177 | self.n_windows = n_windows 178 | self.window_size = window_size 179 | self.method = method 180 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 Nixtla 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Nixtla   2 | [![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=Statistical%20Forecasting%20Algorithms%20by%20Nixtla%20&url=https://github.com/Nixtla/statsforecast&via=nixtlainc&hashtags=StatisticalModels,TimeSeries,Forecasting) 3 |  ![Slack](https://img.shields.io/badge/Slack-4A154B?&logo=slack&logoColor=white.png) 4 | ================ 5 | 6 | 7 | 8 |
Machine Learning 🤖 Forecast

Scalable machine learning for time series forecasting
19 | 20 | [![CI](https://github.com/Nixtla/mlforecast/actions/workflows/ci.yaml/badge.svg)](https://github.com/Nixtla/mlforecast/actions/workflows/ci.yaml) 21 | [![Python](https://img.shields.io/pypi/pyversions/mlforecast.png)](https://pypi.org/project/mlforecast/) 22 | [![PyPi](https://img.shields.io/pypi/v/mlforecast?color=blue.png)](https://pypi.org/project/mlforecast/) 23 | [![conda-forge](https://img.shields.io/conda/vn/conda-forge/mlforecast?color=blue.png)](https://anaconda.org/conda-forge/mlforecast) 24 | [![License](https://img.shields.io/github/license/Nixtla/mlforecast.png)](https://github.com/Nixtla/mlforecast/blob/main/LICENSE) 25 | 26 | **mlforecast** is a framework to perform time series forecasting using 27 | machine learning models, with the option to scale to massive amounts of 28 | data using remote clusters. 29 | 30 |
31 | 32 | ## Install 33 | 34 | ### PyPI 35 | 36 | `pip install mlforecast` 37 | 38 | If you want to perform distributed training, you can instead use 39 | `pip install "mlforecast[distributed]"`, which will also install 40 | [dask](https://dask.org/). Note that you’ll also need to install either 41 | [LightGBM](https://github.com/microsoft/LightGBM/tree/master/python-package) 42 | or 43 | [XGBoost](https://xgboost.readthedocs.io/en/latest/install.html#python). 44 | 45 | ### conda-forge 46 | 47 | `conda install -c conda-forge mlforecast` 48 | 49 | Note that this installation comes with the required dependencies for the 50 | local interface. If you want to perform distributed training, you must 51 | install dask (`conda install -c conda-forge dask`) and either 52 | [LightGBM](https://github.com/microsoft/LightGBM/tree/master/python-package) 53 | or 54 | [XGBoost](https://xgboost.readthedocs.io/en/latest/install.html#python). 55 | 56 | ## Quick Start 57 | 58 | **Minimal Example** 59 | 60 | ``` python 61 | import lightgbm as lgb 62 | 63 | from mlforecast import MLForecast 64 | from sklearn.linear_model import LinearRegression 65 | 66 | mlf = MLForecast( 67 | models = [LinearRegression(), lgb.LGBMRegressor()], 68 | lags=[1, 12], 69 | freq = 'M' 70 | ) 71 | mlf.fit(df) 72 | mlf.predict(12) 73 | ``` 74 | 75 | **Get Started with this [quick 76 | guide](https://nixtla.github.io/mlforecast/docs/quick_start_local.html).** 77 | 78 | **Follow this [end-to-end 79 | walkthrough](https://nixtla.github.io/mlforecast/docs/end_to_end_walkthrough.html) 80 | for best practices.** 81 | 82 | ## Why? 83 | 84 | Current Python alternatives for machine learning models are slow, 85 | inaccurate and don’t scale well. So we created a library that can be 86 | used to forecast in production environments. `MLForecast` includes 87 | efficient feature engineering to train any machine learning model (with 88 | `fit` and `predict` methods such as 89 | [`sklearn`](https://scikit-learn.org/stable/)) to fit millions of time 90 | series. 91 | 92 | ## Features 93 | 94 | - Fastest implementations of feature engineering for time series 95 | forecasting in Python. 96 | - Out-of-the-box compatibility with Spark, Dask, and Ray. 97 | - Probabilistic Forecasting with Conformal Prediction. 98 | - Support for exogenous variables and static covariates. 99 | - Familiar `sklearn` syntax: `.fit` and `.predict`. 100 | 101 | Missing something? Please open an issue or write us in 102 | [![Slack](https://img.shields.io/badge/Slack-4A154B?&logo=slack&logoColor=white.png)](https://join.slack.com/t/nixtlaworkspace/shared_invite/zt-135dssye9-fWTzMpv2WBthq8NK0Yvu6A) 103 | 104 | ## Examples and Guides 105 | 106 | 📚 [End to End 107 | Walkthrough](https://nixtla.github.io/mlforecast/docs/end_to_end_walkthrough.html): 108 | model training, evaluation and selection for multiple time series. 109 | 110 | 🔎 [Probabilistic 111 | Forecasting](https://nixtla.github.io/mlforecast/docs/prediction_intervals.html): 112 | use Conformal Prediction to produce prediciton intervals. 113 | 114 | 👩‍🔬 [Cross 115 | Validation](https://nixtla.github.io/mlforecast/docs/cross_validation.html): 116 | robust model’s performance evaluation. 117 | 118 | 🔌 [Predict Demand 119 | Peaks](https://nixtla.github.io/mlforecast/docs/electricity_peak_forecasting.html): 120 | electricity load forecasting for detecting daily peaks and reducing 121 | electric bills. 
122 | 123 | 📈 [Transfer 124 | Learning](https://nixtla.github.io/mlforecast/docs/transfer_learning.html): 125 | pretrain a model using a set of time series and then predict another one 126 | using that pretrained model. 127 | 128 | 🌡️ [Distributed 129 | Training](https://nixtla.github.io/mlforecast/docs/quick_start_distributed.html): 130 | use a Dask cluster to train models at scale. 131 | 132 | ## How to use 133 | 134 | The following provides a very basic overview, for a more detailed 135 | description see the 136 | [documentation](https://nixtla.github.io/mlforecast/). 137 | 138 | ### Data setup 139 | 140 | Store your time series in a pandas dataframe in long format, that is, 141 | each row represents an observation for a specific serie and timestamp. 142 | 143 | ``` python 144 | from mlforecast.utils import generate_daily_series 145 | 146 | series = generate_daily_series( 147 | n_series=20, 148 | max_length=100, 149 | n_static_features=1, 150 | static_as_categorical=False, 151 | with_trend=True 152 | ) 153 | series.head() 154 | ``` 155 | 156 |
|   | unique_id | ds         | y         | static_0 |
|---|-----------|------------|-----------|----------|
| 0 | id_00     | 2000-01-01 | 1.751917  | 72       |
| 1 | id_00     | 2000-01-02 | 9.196715  | 72       |
| 2 | id_00     | 2000-01-03 | 18.577788 | 72       |
| 3 | id_00     | 2000-01-04 | 24.520646 | 72       |
| 4 | id_00     | 2000-01-05 | 33.418028 | 72       |
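If your own data lives in wide format (one column per series), it has to be
reshaped into this long format before being passed to `MLForecast`. The
following is a minimal sketch of that conversion using `pandas.melt`; the
`store_1`/`store_2` column names are hypothetical and only illustrate the
expected `unique_id`, `ds`, `y` layout.

``` python
import pandas as pd

# Hypothetical wide-format data: one timestamp column plus one column per series.
wide = pd.DataFrame(
    {
        'ds': pd.date_range('2000-01-01', periods=3, freq='D'),
        'store_1': [10.0, 12.0, 13.0],
        'store_2': [20.0, 21.0, 19.0],
    }
)

# Reshape to long format: one row per (series, timestamp) observation.
long_df = wide.melt(id_vars='ds', var_name='unique_id', value_name='y')
long_df = long_df[['unique_id', 'ds', 'y']].sort_values(['unique_id', 'ds'])
print(long_df)
```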
206 | 207 | ### Models 208 | 209 | Next define your models. If you want to use the local interface this can 210 | be any regressor that follows the scikit-learn API. For distributed 211 | training there are `LGBMForecast` and `XGBForecast`. 212 | 213 | ``` python 214 | import lightgbm as lgb 215 | import xgboost as xgb 216 | from sklearn.ensemble import RandomForestRegressor 217 | 218 | models = [ 219 | lgb.LGBMRegressor(), 220 | xgb.XGBRegressor(), 221 | RandomForestRegressor(random_state=0), 222 | ] 223 | ``` 224 | 225 | ### Forecast object 226 | 227 | Now instantiate a `MLForecast` object with the models and the features 228 | that you want to use. The features can be lags, transformations on the 229 | lags and date features. The lag transformations are defined as 230 | [numba](http://numba.pydata.org/) *jitted* functions that transform an 231 | array, if they have additional arguments you can either supply a tuple 232 | (`transform_func`, `arg1`, `arg2`, …) or define new functions fixing the 233 | arguments. You can also define differences to apply to the series before 234 | fitting that will be restored when predicting. 235 | 236 | ``` python 237 | from mlforecast import MLForecast 238 | from numba import njit 239 | from window_ops.expanding import expanding_mean 240 | from window_ops.rolling import rolling_mean 241 | 242 | 243 | @njit 244 | def rolling_mean_28(x): 245 | return rolling_mean(x, window_size=28) 246 | 247 | 248 | fcst = MLForecast( 249 | models=models, 250 | freq='D', 251 | lags=[7, 14], 252 | lag_transforms={ 253 | 1: [expanding_mean], 254 | 7: [rolling_mean_28] 255 | }, 256 | date_features=['dayofweek'], 257 | differences=[1], 258 | ) 259 | ``` 260 | 261 | ### Training 262 | 263 | To compute the features and train the models call `fit` on your 264 | `Forecast` object. 265 | 266 | ``` python 267 | fcst.fit(series) 268 | ``` 269 | 270 | MLForecast(models=[LGBMRegressor, XGBRegressor, RandomForestRegressor], freq=, lag_features=['lag7', 'lag14', 'expanding_mean_lag1', 'rolling_mean_28_lag7'], date_features=['dayofweek'], num_threads=1) 271 | 272 | ### Predicting 273 | 274 | To get the forecasts for the next `n` days call `predict(n)` on the 275 | forecast object. This will automatically handle the updates required by 276 | the features using a recursive strategy. 277 | 278 | ``` python 279 | predictions = fcst.predict(14) 280 | predictions 281 | ``` 282 | 283 |
|     | unique_id | ds         | LGBMRegressor | XGBRegressor | RandomForestRegressor |
|-----|-----------|------------|---------------|--------------|-----------------------|
| 0   | id_00     | 2000-04-04 | 69.082830     | 67.761337    | 68.184016             |
| 1   | id_00     | 2000-04-05 | 75.706024     | 74.588699    | 75.470680             |
| 2   | id_00     | 2000-04-06 | 82.222473     | 81.058289    | 82.846249             |
| 3   | id_00     | 2000-04-07 | 89.577638     | 88.735947    | 90.201271             |
| 4   | id_00     | 2000-04-08 | 44.149095     | 44.981384    | 46.096322             |
| …   | …         | …          | …             | …            | …                     |
| 275 | id_19     | 2000-03-23 | 30.236012     | 31.949095    | 32.656369             |
| 276 | id_19     | 2000-03-24 | 31.308269     | 32.765919    | 33.624488             |
| 277 | id_19     | 2000-03-25 | 32.788550     | 33.628864    | 34.581486             |
| 278 | id_19     | 2000-03-26 | 34.086976     | 34.508457    | 35.553173             |
| 279 | id_19     | 2000-03-27 | 34.288968     | 35.411613    | 36.526505             |

280 rows × 5 columns
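As noted in the *Forecast object* section above, lag transformations that
need extra arguments can also be supplied as `(transform_func, arg1, …)`
tuples instead of defining a wrapper such as `rolling_mean_28`. The sketch
below shows that alternative form, reusing the `models` list from the
Models section; it is illustrative only and not a replacement for the
`fcst` object already fitted above.

``` python
from mlforecast import MLForecast
from window_ops.expanding import expanding_mean
from window_ops.rolling import rolling_mean

fcst_alt = MLForecast(
    models=models,
    freq='D',
    lags=[7, 14],
    lag_transforms={
        1: [expanding_mean],
        7: [(rolling_mean, 28)],  # tuple form, equivalent to the rolling_mean_28 helper
    },
    date_features=['dayofweek'],
    differences=[1],
)
```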
388 | 389 | ### Visualize results 390 | 391 | ``` python 392 | import matplotlib.pyplot as plt 393 | import pandas as pd 394 | 395 | fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(12, 6), gridspec_kw=dict(hspace=0.3)) 396 | for i, (uid, axi) in enumerate(zip(series['unique_id'].unique(), ax.flat)): 397 | fltr = lambda df: df['unique_id'].eq(uid) 398 | pd.concat([series.loc[fltr, ['ds', 'y']], predictions.loc[fltr]]).set_index('ds').plot(ax=axi) 399 | axi.set(title=uid, xlabel=None) 400 | if i % 2 == 0: 401 | axi.legend().remove() 402 | else: 403 | axi.legend(bbox_to_anchor=(1.01, 1.0)) 404 | fig.savefig('figs/index.png', bbox_inches='tight') 405 | plt.close() 406 | ``` 407 | 408 | ![](https://raw.githubusercontent.com/Nixtla/mlforecast/main/figs/index.png) 409 | 410 | ## Sample notebooks 411 | 412 | - [m5](https://www.kaggle.com/code/lemuz90/m5-mlforecast-eval) 413 | - [m4](https://www.kaggle.com/code/lemuz90/m4-competition) 414 | - [m4-cv](https://www.kaggle.com/code/lemuz90/m4-competition-cv) 415 | 416 | ## How to contribute 417 | 418 | See 419 | [CONTRIBUTING.md](https://github.com/Nixtla/mlforecast/blob/main/CONTRIBUTING.md). 420 | -------------------------------------------------------------------------------- /nbs/docs/cross_validation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Cross validation\n", 8 | "\n", 9 | "> In this example, we'll implement time series cross-validation to evaluate model's performance. " 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "::: {.callout-warning collapse=\"true\"}\n", 17 | "\n", 18 | "## Prerequesites\n", 19 | "\n", 20 | "This tutorial assumes basic familiarity with `MLForecast`. For a minimal example visit the [Quick Start](https://nixtla.github.io/mlforecast/docs/quick_start_local.html) \n", 21 | ":::" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Introduction \n", 29 | "\n", 30 | "Time series cross-validation is a method for evaluating how a model would have performed in the past. It works by defining a sliding window across the historical data and predicting the period following it. \n", 31 | "\n", 32 | "![](https://raw.githubusercontent.com/Nixtla/statsforecast/main/nbs/imgs/ChainedWindows.gif) " 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "[MLForecast](https://nixtla.github.io/mlforecast/) has an implementation of time series cross-validation that is fast and easy to use. This implementation makes cross-validation a efficient operation, which makes it less time-consuming. In this notebook, we'll use it on a subset of the [M4 Competition](https://www.sciencedirect.com/science/article/pii/S0169207019301128) hourly dataset. " 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "**Outline:**\n", 47 | "\n", 48 | "1. Install libraries \n", 49 | "2. Load and explore data \n", 50 | "3. Train model\n", 51 | "4. Perform time series cross-validation \n", 52 | "5. 
Evaluate results " 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "::: {.callout-tip}\n", 60 | "You can use Colab to run this Notebook interactively \"Open\n", 61 | "::: " 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## Install libraries" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "We assume that you have `MLForecast` already installed. If not, check this guide for instructions on [how to install MLForecast](https://nixtla.github.io/mlforecast/docs/install.html)." 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "Install the necessary packages with `pip install mlforecast`." 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "%%capture\n", 92 | "pip install mlforecast lightgbm" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "from mlforecast import MLForecast # required to instantiate MLForecast object and use cross-validation method " 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## Load and explore the data" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "metadata": {}, 114 | "source": [ 115 | "As stated in the introduction, we'll use the M4 Competition hourly dataset. We'll first import the data from an URL using `pandas`. " 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "data": { 125 | "text/html": [ 126 | "
\n", 127 | "\n", 140 | "\n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | "
unique_iddsy
0H11605.0
1H12586.0
2H13586.0
3H14559.0
4H15511.0
\n", 182 | "
" 183 | ], 184 | "text/plain": [ 185 | " unique_id ds y\n", 186 | "0 H1 1 605.0\n", 187 | "1 H1 2 586.0\n", 188 | "2 H1 3 586.0\n", 189 | "3 H1 4 559.0\n", 190 | "4 H1 5 511.0" 191 | ] 192 | }, 193 | "execution_count": null, 194 | "metadata": {}, 195 | "output_type": "execute_result" 196 | } 197 | ], 198 | "source": [ 199 | "import pandas as pd \n", 200 | "\n", 201 | "Y_df = pd.read_csv('https://datasets-nixtla.s3.amazonaws.com/m4-hourly.csv') # load the data \n", 202 | "Y_df.head() " 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "The input to `MLForecast` is a data frame in [long format](https://www.theanalysisfactor.com/wide-and-long-data/) with three columns: `unique_id`, `ds` and `y`: \n", 210 | "\n", 211 | "- The `unique_id` (string, int, or category) represents an identifier for the series. \n", 212 | "- The `ds` (datestamp or int) column should be either an integer indexing time or a datestamp in format YYYY-MM-DD or YYYY-MM-DD HH:MM:SS. \n", 213 | "- The `y` (numeric) represents the measurement we wish to forecast. \n", 214 | "\n", 215 | "The data in this example already has this format, so no changes are needed. " 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "We can plot the time series we'll work with using the following function. " 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "import matplotlib.pyplot as plt\n", 232 | "\n", 233 | "def plot(df, fname, last_n=24 * 14):\n", 234 | " fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(14, 6), gridspec_kw=dict(hspace=0.5))\n", 235 | " uids = df['unique_id'].unique()\n", 236 | " for i, (uid, axi) in enumerate(zip(uids, ax.flat)):\n", 237 | " legend = i % 2 == 0\n", 238 | " df[df['unique_id'].eq(uid)].tail(last_n).set_index('ds').plot(ax=axi, title=uid, legend=legend)\n", 239 | " fig.savefig(fname, bbox_inches='tight')\n", 240 | " plt.close()" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "plot(Y_df, '../figs/cross_validation__series.png')" 250 | ] 251 | }, 252 | { 253 | "cell_type": "markdown", 254 | "metadata": {}, 255 | "source": [ 256 | "![](../figs/cross_validation__series.png)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": {}, 262 | "source": [ 263 | "## Train model" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": {}, 269 | "source": [ 270 | "For this example, we'll use LightGBM. We first need to import it and then we need to instantiate a new `MLForecast` object. " 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "The `MLForecast` object has the following parameters: \n", 278 | "\n", 279 | "- `models`: a list of sklearn-like (`fit` and `predict`) models. \n", 280 | "- `freq`: a string indicating the frequency of the data. See [panda’s available frequencies.](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)\n", 281 | "- `differences`: Differences to take of the target before computing the features. These are restored at the forecasting step.\n", 282 | "- `lags`: Lags of the target to use as features.\n", 283 | "\n", 284 | "In this example, we are only using `differences` and `lags` to produce features. 
See [the full documentation](https://nixtla.github.io/mlforecast/forecast.html) to see all available features.\n", 285 | "\n", 286 | "Any settings are passed into the constructor. Then you call its `fit` method and pass in the historical data frame `df`. " 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "import lightgbm as lgb\n", 296 | "\n", 297 | "models = [lgb.LGBMRegressor()]\n", 298 | "\n", 299 | "mlf = MLForecast(\n", 300 | " models = models, \n", 301 | " freq = 1,# our series have integer timestamps, so we'll just add 1 in every timeste, \n", 302 | " differences=[24],\n", 303 | " lags=range(1, 25, 1)\n", 304 | ")" 305 | ] 306 | }, 307 | { 308 | "cell_type": "markdown", 309 | "metadata": {}, 310 | "source": [ 311 | "## Perform time series cross-validation" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "metadata": {}, 317 | "source": [ 318 | "Once the `MLForecast` object has been instantiated, we can use the `cross_validation` method, which takes the following arguments: \n", 319 | "\n", 320 | "- `data`: training data frame with `MLForecast` format \n", 321 | "- `window_size` (int): represents the h steps into the future that will be forecasted \n", 322 | "- `n_windows` (int): number of windows used for cross-validation, meaning the number of forecasting processes in the past you want to evaluate. \n", 323 | "- `id_col`: identifies each time series.\n", 324 | "- `time_col`: indetifies the temporal column of the time series. \n", 325 | "- `target_col`: identifies the column to model." 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": {}, 331 | "source": [ 332 | "For this particular example, we'll use 3 windows of 24 hours." 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": null, 338 | "metadata": {}, 339 | "outputs": [], 340 | "source": [ 341 | "crossvalidation_df = mlf.cross_validation(\n", 342 | " data=Y_df,\n", 343 | " window_size=24,\n", 344 | " n_windows=3,\n", 345 | ")" 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": {}, 351 | "source": [ 352 | "The crossvaldation_df object is a new data frame that includes the following columns:\n", 353 | "\n", 354 | "- `unique_id`: identifies each time series.\n", 355 | "- `ds`: datestamp or temporal index.\n", 356 | "- `cutoff`: the last datestamp or temporal index for the `n_windows`. \n", 357 | "- `y`: true value\n", 358 | "- `\"model\"`: columns with the model’s name and fitted value." 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [ 366 | { 367 | "data": { 368 | "text/html": [ 369 | "
\n", 370 | "\n", 383 | "\n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | " \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | "
unique_iddscutoffyLGBMRegressor
0H1677676691.0673.703191
1H1678676618.0552.306270
2H1679676563.0541.778027
3H1680676529.0502.778027
4H1681676504.0480.778027
\n", 437 | "
" 438 | ], 439 | "text/plain": [ 440 | " unique_id ds cutoff y LGBMRegressor\n", 441 | "0 H1 677 676 691.0 673.703191\n", 442 | "1 H1 678 676 618.0 552.306270\n", 443 | "2 H1 679 676 563.0 541.778027\n", 444 | "3 H1 680 676 529.0 502.778027\n", 445 | "4 H1 681 676 504.0 480.778027" 446 | ] 447 | }, 448 | "execution_count": null, 449 | "metadata": {}, 450 | "output_type": "execute_result" 451 | } 452 | ], 453 | "source": [ 454 | "crossvalidation_df.head()" 455 | ] 456 | }, 457 | { 458 | "cell_type": "markdown", 459 | "metadata": {}, 460 | "source": [ 461 | "We'll now plot the forecast for each cutoff period." 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "metadata": {}, 468 | "outputs": [], 469 | "source": [ 470 | "def plot_cv(df, df_cv, uid, fname, last_n=24 * 14):\n", 471 | " cutoffs = df_cv.query('unique_id == @uid')['cutoff'].unique()\n", 472 | " fig, ax = plt.subplots(nrows=len(cutoffs), ncols=1, figsize=(14, 6), gridspec_kw=dict(hspace=0.8))\n", 473 | " for cutoff, axi in zip(cutoffs, ax.flat):\n", 474 | " df.query('unique_id == @uid').tail(last_n).set_index('ds').plot(ax=axi, title=uid, y='y')\n", 475 | " df_cv.query('unique_id == @uid & cutoff == @cutoff').set_index('ds').plot(ax=axi, title=uid, y='LGBMRegressor')\n", 476 | " fig.savefig(fname, bbox_inches='tight')\n", 477 | " plt.close()" 478 | ] 479 | }, 480 | { 481 | "cell_type": "code", 482 | "execution_count": null, 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "plot_cv(Y_df, crossvalidation_df, 'H1', '../figs/cross_validation__predictions.png')" 487 | ] 488 | }, 489 | { 490 | "cell_type": "markdown", 491 | "metadata": {}, 492 | "source": [ 493 | "![](../figs/cross_validation__predictions.png)" 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "metadata": {}, 499 | "source": [ 500 | "Notice that in each cutoff period, we generated a forecast for the next 24 hours using only the data `y` before said period. " 501 | ] 502 | }, 503 | { 504 | "cell_type": "markdown", 505 | "metadata": {}, 506 | "source": [ 507 | "## Evaluate results " 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": {}, 513 | "source": [ 514 | "We can now compute the accuracy of the forecast using an appropiate accuracy metric. Here we'll use the [Root Mean Squared Error (RMSE).](https://en.wikipedia.org/wiki/Root-mean-square_deviation) To do this, we first need to install [datasetsforecast](https://github.com/Nixtla/datasetsforecast/tree/main/), a Python library developed by Nixtla that includes a function to compute the RMSE. " 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [ 523 | "%%capture\n", 524 | "pip install datasetsforecast" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": null, 530 | "metadata": {}, 531 | "outputs": [], 532 | "source": [ 533 | "from datasetsforecast.losses import rmse " 534 | ] 535 | }, 536 | { 537 | "cell_type": "markdown", 538 | "metadata": {}, 539 | "source": [ 540 | "The function to compute the RMSE takes two arguments: \n", 541 | " \n", 542 | "1. The actual values. \n", 543 | "2. The forecasts, in this case, `LGBMRegressor`. \n", 544 | "\n", 545 | "In this case we will compute the `rmse` per time series and cutoff and then we will take the mean of the results." 
546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": null, 551 | "metadata": {}, 552 | "outputs": [ 553 | { 554 | "name": "stdout", 555 | "output_type": "stream", 556 | "text": [ 557 | "RMSE using cross-validation: 249.90517171185527\n" 558 | ] 559 | } 560 | ], 561 | "source": [ 562 | "cv_rmse = crossvalidation_df.groupby(['unique_id', 'cutoff']).apply(lambda df: rmse(df['y'], df['LGBMRegressor'])).mean()\n", 563 | "print(\"RMSE using cross-validation: \", cv_rmse)" 564 | ] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "metadata": {}, 569 | "source": [ 570 | "This measure should better reflect the predictive abilities of our model, since it used different time periods to test its accuracy. " 571 | ] 572 | }, 573 | { 574 | "cell_type": "markdown", 575 | "metadata": {}, 576 | "source": [ 577 | "## References \n", 578 | "\n", 579 | "[Rob J. Hyndman and George Athanasopoulos (2018). \"Forecasting principles and practice, Time series cross-validation\"](https://otexts.com/fpp3/tscv.html)." 580 | ] 581 | } 582 | ], 583 | "metadata": { 584 | "kernelspec": { 585 | "display_name": "python3", 586 | "language": "python", 587 | "name": "python3" 588 | } 589 | }, 590 | "nbformat": 4, 591 | "nbformat_minor": 4 592 | } 593 | -------------------------------------------------------------------------------- /nbs/index.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "#| hide\n", 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "#| hide\n", 21 | "import os\n", 22 | "os.chdir('..')" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "# Nixtla   [![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=Statistical%20Forecasting%20Algorithms%20by%20Nixtla%20&url=https://github.com/Nixtla/statsforecast&via=nixtlainc&hashtags=StatisticalModels,TimeSeries,Forecasting)  ![Slack](https://img.shields.io/badge/Slack-4A154B?&logo=slack&logoColor=white)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "
\n", 37 | "
\n", 38 | "\n", 39 | "
\n", 40 | "

Machine Learning 🤖 Forecast

\n", 41 | "

Scalable machine learning for time series forecasting

\n", 42 | " \n", 43 | "[![CI](https://github.com/Nixtla/mlforecast/actions/workflows/ci.yaml/badge.svg)](https://github.com/Nixtla/mlforecast/actions/workflows/ci.yaml)\n", 44 | "[![Python](https://img.shields.io/pypi/pyversions/mlforecast)](https://pypi.org/project/mlforecast/)\n", 45 | "[![PyPi](https://img.shields.io/pypi/v/mlforecast?color=blue)](https://pypi.org/project/mlforecast/)\n", 46 | "[![conda-forge](https://img.shields.io/conda/vn/conda-forge/mlforecast?color=blue)](https://anaconda.org/conda-forge/mlforecast)\n", 47 | "[![License](https://img.shields.io/github/license/Nixtla/mlforecast)](https://github.com/Nixtla/mlforecast/blob/main/LICENSE)\n", 48 | " \n", 49 | "**mlforecast** is a framework to perform time series forecasting using machine learning models, with the option to scale to massive amounts of data using remote clusters.\n", 50 | "\n", 51 | "
" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "## Install" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "### PyPI\n", 66 | "\n", 67 | "`pip install mlforecast`\n", 68 | "\n", 69 | "If you want to perform distributed training, you can instead use `pip install \"mlforecast[distributed]\"`, which will also install [dask](https://dask.org/). Note that you'll also need to install either [LightGBM](https://github.com/microsoft/LightGBM/tree/master/python-package) or [XGBoost](https://xgboost.readthedocs.io/en/latest/install.html#python)." 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "### conda-forge\n", 77 | "`conda install -c conda-forge mlforecast`\n", 78 | "\n", 79 | "Note that this installation comes with the required dependencies for the local interface. If you want to perform distributed training, you must install dask (`conda install -c conda-forge dask`) and either [LightGBM](https://github.com/microsoft/LightGBM/tree/master/python-package) or [XGBoost](https://xgboost.readthedocs.io/en/latest/install.html#python)." 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "## Quick Start\n", 87 | "\n", 88 | "**Minimal Example**\n", 89 | "\n", 90 | "```python\n", 91 | "import lightgbm as lgb\n", 92 | "\n", 93 | "from mlforecast import MLForecast\n", 94 | "from sklearn.linear_model import LinearRegression\n", 95 | "\n", 96 | "mlf = MLForecast(\n", 97 | " models = [LinearRegression(), lgb.LGBMRegressor()],\n", 98 | " lags=[1, 12],\n", 99 | " freq = 'M'\n", 100 | ")\n", 101 | "mlf.fit(df)\n", 102 | "mlf.predict(12)\n", 103 | "```\n", 104 | "\n", 105 | "**Get Started with this [quick guide](https://nixtla.github.io/mlforecast/docs/quick_start_local.html).**\n", 106 | "\n", 107 | "**Follow this [end-to-end walkthrough](https://nixtla.github.io/mlforecast/docs/end_to_end_walkthrough.html) for best practices.**" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "## Why? \n", 115 | "\n", 116 | "Current Python alternatives for machine learning models are slow, inaccurate and don't scale well. So we created a library that can be used to forecast in production environments. `MLForecast` includes efficient feature engineering to train any machine learning model (with `fit` and `predict` methods such as [`sklearn`](https://scikit-learn.org/stable/)) to fit millions of time series.\n", 117 | "\n", 118 | "## Features\n", 119 | "\n", 120 | "* Fastest implementations of feature engineering for time series forecasting in Python. \n", 121 | "* Out-of-the-box compatibility with Spark, Dask, and Ray.\n", 122 | "* Probabilistic Forecasting with Conformal Prediction.\n", 123 | "* Support for exogenous variables and static covariates.\n", 124 | "* Familiar `sklearn` syntax: `.fit` and `.predict`.\n", 125 | "\n", 126 | "\n", 127 | "Missing something? 
Please open an issue or write us in [![Slack](https://img.shields.io/badge/Slack-4A154B?&logo=slack&logoColor=white)](https://join.slack.com/t/nixtlaworkspace/shared_invite/zt-135dssye9-fWTzMpv2WBthq8NK0Yvu6A)" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "## Examples and Guides\n", 135 | "\n", 136 | "📚 [End to End Walkthrough](https://nixtla.github.io/mlforecast/docs/end_to_end_walkthrough.html): model training, evaluation and selection for multiple time series.\n", 137 | "\n", 138 | "🔎 [Probabilistic Forecasting](https://nixtla.github.io/mlforecast/docs/prediction_intervals.html): use Conformal Prediction to produce prediciton intervals. \n", 139 | "\n", 140 | "👩‍🔬 [Cross Validation](https://nixtla.github.io/mlforecast/docs/cross_validation.html): robust model’s performance evaluation.\n", 141 | "\n", 142 | "🔌 [Predict Demand Peaks](https://nixtla.github.io/mlforecast/docs/electricity_peak_forecasting.html): electricity load forecasting for detecting daily peaks and reducing electric bills.\n", 143 | "\n", 144 | "📈 [Transfer Learning](https://nixtla.github.io/mlforecast/docs/transfer_learning.html): pretrain a model using a set of time series and then predict another one using that pretrained model. \n", 145 | "\n", 146 | "🌡️ [Distributed Training](https://nixtla.github.io/mlforecast/docs/quick_start_distributed.html): use a Dask cluster to train models at scale.\n" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": {}, 152 | "source": [ 153 | "## How to use\n", 154 | "\n", 155 | "The following provides a very basic overview, for a more detailed description see the [documentation](https://nixtla.github.io/mlforecast/)." 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "### Data setup" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "Store your time series in a pandas dataframe in long format, that is, each row represents an observation for a specific serie and timestamp." 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "data": { 179 | "text/html": [ 180 | "
\n", 181 | "\n", 194 | "\n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | "
unique_iddsystatic_0
0id_002000-01-011.75191772
1id_002000-01-029.19671572
2id_002000-01-0318.57778872
3id_002000-01-0424.52064672
4id_002000-01-0533.41802872
\n", 242 | "
" 243 | ], 244 | "text/plain": [ 245 | " unique_id ds y static_0\n", 246 | "0 id_00 2000-01-01 1.751917 72\n", 247 | "1 id_00 2000-01-02 9.196715 72\n", 248 | "2 id_00 2000-01-03 18.577788 72\n", 249 | "3 id_00 2000-01-04 24.520646 72\n", 250 | "4 id_00 2000-01-05 33.418028 72" 251 | ] 252 | }, 253 | "execution_count": null, 254 | "metadata": {}, 255 | "output_type": "execute_result" 256 | } 257 | ], 258 | "source": [ 259 | "from mlforecast.utils import generate_daily_series\n", 260 | "\n", 261 | "series = generate_daily_series(\n", 262 | " n_series=20,\n", 263 | " max_length=100,\n", 264 | " n_static_features=1,\n", 265 | " static_as_categorical=False,\n", 266 | " with_trend=True\n", 267 | ")\n", 268 | "series.head()" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": {}, 274 | "source": [ 275 | "### Models" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "Next define your models. If you want to use the local interface this can be any regressor that follows the scikit-learn API. For distributed training there are `LGBMForecast` and `XGBForecast`." 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "metadata": {}, 289 | "outputs": [], 290 | "source": [ 291 | "import lightgbm as lgb\n", 292 | "import xgboost as xgb\n", 293 | "from sklearn.ensemble import RandomForestRegressor\n", 294 | "\n", 295 | "models = [\n", 296 | " lgb.LGBMRegressor(),\n", 297 | " xgb.XGBRegressor(),\n", 298 | " RandomForestRegressor(random_state=0),\n", 299 | "]" 300 | ] 301 | }, 302 | { 303 | "cell_type": "markdown", 304 | "metadata": {}, 305 | "source": [ 306 | "### Forecast object" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "Now instantiate a `MLForecast` object with the models and the features that you want to use. The features can be lags, transformations on the lags and date features. The lag transformations are defined as [numba](http://numba.pydata.org/) *jitted* functions that transform an array, if they have additional arguments you can either supply a tuple (`transform_func`, `arg1`, `arg2`, ...) or define new functions fixing the arguments. You can also define differences to apply to the series before fitting that will be restored when predicting." 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "from mlforecast import MLForecast\n", 323 | "from numba import njit\n", 324 | "from window_ops.expanding import expanding_mean\n", 325 | "from window_ops.rolling import rolling_mean\n", 326 | "\n", 327 | "\n", 328 | "@njit\n", 329 | "def rolling_mean_28(x):\n", 330 | " return rolling_mean(x, window_size=28)\n", 331 | "\n", 332 | "\n", 333 | "fcst = MLForecast(\n", 334 | " models=models,\n", 335 | " freq='D',\n", 336 | " lags=[7, 14],\n", 337 | " lag_transforms={\n", 338 | " 1: [expanding_mean],\n", 339 | " 7: [rolling_mean_28]\n", 340 | " },\n", 341 | " date_features=['dayofweek'],\n", 342 | " differences=[1],\n", 343 | ")" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "metadata": {}, 349 | "source": [ 350 | "### Training" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": {}, 356 | "source": [ 357 | "To compute the features and train the models call `fit` on your `Forecast` object." 
358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": null, 363 | "metadata": {}, 364 | "outputs": [ 365 | { 366 | "data": { 367 | "text/plain": [ 368 | "MLForecast(models=[LGBMRegressor, XGBRegressor, RandomForestRegressor], freq=, lag_features=['lag7', 'lag14', 'expanding_mean_lag1', 'rolling_mean_28_lag7'], date_features=['dayofweek'], num_threads=1)" 369 | ] 370 | }, 371 | "execution_count": null, 372 | "metadata": {}, 373 | "output_type": "execute_result" 374 | } 375 | ], 376 | "source": [ 377 | "fcst.fit(series)" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": {}, 383 | "source": [ 384 | "### Predicting" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": {}, 390 | "source": [ 391 | "To get the forecasts for the next `n` days call `predict(n)` on the forecast object. This will automatically handle the updates required by the features using a recursive strategy." 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": {}, 398 | "outputs": [ 399 | { 400 | "data": { 401 | "text/html": [ 402 | "
\n", 403 | "\n", 416 | "\n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | "
unique_iddsLGBMRegressorXGBRegressorRandomForestRegressor
0id_002000-04-0469.08283067.76133768.184016
1id_002000-04-0575.70602474.58869975.470680
2id_002000-04-0682.22247381.05828982.846249
3id_002000-04-0789.57763888.73594790.201271
4id_002000-04-0844.14909544.98138446.096322
..................
275id_192000-03-2330.23601231.94909532.656369
276id_192000-03-2431.30826932.76591933.624488
277id_192000-03-2532.78855033.62886434.581486
278id_192000-03-2634.08697634.50845735.553173
279id_192000-03-2734.28896835.41161336.526505
\n", 518 | "

280 rows × 5 columns

\n", 519 | "
" 520 | ], 521 | "text/plain": [ 522 | " unique_id ds LGBMRegressor XGBRegressor RandomForestRegressor\n", 523 | "0 id_00 2000-04-04 69.082830 67.761337 68.184016\n", 524 | "1 id_00 2000-04-05 75.706024 74.588699 75.470680\n", 525 | "2 id_00 2000-04-06 82.222473 81.058289 82.846249\n", 526 | "3 id_00 2000-04-07 89.577638 88.735947 90.201271\n", 527 | "4 id_00 2000-04-08 44.149095 44.981384 46.096322\n", 528 | ".. ... ... ... ... ...\n", 529 | "275 id_19 2000-03-23 30.236012 31.949095 32.656369\n", 530 | "276 id_19 2000-03-24 31.308269 32.765919 33.624488\n", 531 | "277 id_19 2000-03-25 32.788550 33.628864 34.581486\n", 532 | "278 id_19 2000-03-26 34.086976 34.508457 35.553173\n", 533 | "279 id_19 2000-03-27 34.288968 35.411613 36.526505\n", 534 | "\n", 535 | "[280 rows x 5 columns]" 536 | ] 537 | }, 538 | "execution_count": null, 539 | "metadata": {}, 540 | "output_type": "execute_result" 541 | } 542 | ], 543 | "source": [ 544 | "predictions = fcst.predict(14)\n", 545 | "predictions" 546 | ] 547 | }, 548 | { 549 | "cell_type": "markdown", 550 | "metadata": {}, 551 | "source": [ 552 | "### Visualize results" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "metadata": {}, 559 | "outputs": [], 560 | "source": [ 561 | "import matplotlib.pyplot as plt\n", 562 | "import pandas as pd\n", 563 | "\n", 564 | "fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(12, 6), gridspec_kw=dict(hspace=0.3))\n", 565 | "for i, (uid, axi) in enumerate(zip(series['unique_id'].unique(), ax.flat)):\n", 566 | " fltr = lambda df: df['unique_id'].eq(uid)\n", 567 | " pd.concat([series.loc[fltr, ['ds', 'y']], predictions.loc[fltr]]).set_index('ds').plot(ax=axi)\n", 568 | " axi.set(title=uid, xlabel=None)\n", 569 | " if i % 2 == 0:\n", 570 | " axi.legend().remove()\n", 571 | " else:\n", 572 | " axi.legend(bbox_to_anchor=(1.01, 1.0))\n", 573 | "fig.savefig('figs/index.png', bbox_inches='tight')\n", 574 | "plt.close()" 575 | ] 576 | }, 577 | { 578 | "cell_type": "markdown", 579 | "metadata": {}, 580 | "source": [ 581 | "![](https://raw.githubusercontent.com/Nixtla/mlforecast/main/figs/index.png)" 582 | ] 583 | }, 584 | { 585 | "cell_type": "markdown", 586 | "metadata": {}, 587 | "source": [ 588 | "## Sample notebooks" 589 | ] 590 | }, 591 | { 592 | "cell_type": "markdown", 593 | "metadata": {}, 594 | "source": [ 595 | "* [m5](https://www.kaggle.com/code/lemuz90/m5-mlforecast-eval)\n", 596 | "* [m4](https://www.kaggle.com/code/lemuz90/m4-competition)\n", 597 | "* [m4-cv](https://www.kaggle.com/code/lemuz90/m4-competition-cv)" 598 | ] 599 | }, 600 | { 601 | "cell_type": "markdown", 602 | "metadata": {}, 603 | "source": [ 604 | "## How to contribute\n", 605 | "See [CONTRIBUTING.md](https://github.com/Nixtla/mlforecast/blob/main/CONTRIBUTING.md)." 
606 | ] 607 | } 608 | ], 609 | "metadata": { 610 | "kernelspec": { 611 | "display_name": "python3", 612 | "language": "python", 613 | "name": "python3" 614 | } 615 | }, 616 | "nbformat": 4, 617 | "nbformat_minor": 4 618 | } 619 | -------------------------------------------------------------------------------- /nbs/docs/quick_start_distributed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "9693434b-0290-4ffd-88f9-a4267c11b548", 6 | "metadata": {}, 7 | "source": [ 8 | "# Quick start (distributed)\n", 9 | "\n", 10 | "> Minimal example of distributed training with MLForecast" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "3a829526-7edd-4c9d-9aba-95441d184b31", 16 | "metadata": {}, 17 | "source": [ 18 | "## Main concepts" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "id": "dcc46367-361e-4e2b-9297-31a98a5c4b09", 24 | "metadata": {}, 25 | "source": [ 26 | "The main component for distributed training with mlforecast is the `DistributedMLForecast` class, which abstracts away:\n", 27 | "\n", 28 | "* Feature engineering and model training through `DistributedMLForecast.fit`\n", 29 | "* Feature updates and multi step ahead predictions through `DistributedMLForecast.predict`" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "b088eaaf-6c5b-4292-b420-da2916233558", 35 | "metadata": {}, 36 | "source": [ 37 | "## Setup" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "a3d06187-a523-46de-af83-17f17d63849d", 43 | "metadata": {}, 44 | "source": [ 45 | "In order to perform distributed training you need a dask cluster. In this example we'll use a local cluster but you can replace it with any other type of remote cluster and the processing will take place there." 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "4f9225b4-4eeb-445f-b4bd-fd9bef11cd43", 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "name": "stderr", 56 | "output_type": "stream", 57 | "text": [ 58 | "2023-03-29 22:52:01,537 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-g5huyr4x', purging\n", 59 | "2023-03-29 22:52:01,537 - distributed.diskutils - INFO - Found stale lock file and directory '/tmp/dask-worker-space/worker-la8ynbsi', purging\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "from dask.distributed import Client, LocalCluster\n", 65 | "\n", 66 | "cluster = LocalCluster(n_workers=2, threads_per_worker=1) # change this to use a remote cluster\n", 67 | "client = Client(cluster)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "29aa9351-46ab-4f64-9bbc-86d0069630cd", 73 | "metadata": {}, 74 | "source": [ 75 | "## Data format" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "id": "39b91ff7-2661-46a5-a127-a1fce8c02611", 81 | "metadata": {}, 82 | "source": [ 83 | "The data is expected to be a dask dataframe in long format, that is, each row represents an observation of a single serie at a given time, with at least three columns:\n", 84 | "\n", 85 | "* `id_col`: column that identifies each serie.\n", 86 | "* `target_col`: column that has the series values at each timestamp.\n", 87 | "* `time_col`: column that contains the time the series value was observed. These are usually timestamps, but can also be consecutive integers.\n", 88 | "\n", 89 | "**You need to make sure that each serie is only in a single partition**. 
You can do so by setting the id_col as the index in dask or with repartitionByRange in spark.\n", 90 | "\n", 91 | "Here we present an example with synthetic data." 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "58e1f392-29e3-4328-af11-5e1935aa0bf8", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "import dask.dataframe as dd\n", 102 | "from mlforecast.utils import generate_daily_series" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "a09142c9-6105-4c39-9503-9b55054f66e8", 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "text/html": [ 114 | "
\n", 115 | "\n", 128 | "\n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | "
unique_iddsy
0id_002000-01-010.497650
1id_002000-01-021.554489
2id_002000-01-032.734311
3id_002000-01-044.028039
4id_002000-01-055.366009
............
26998id_992000-06-2534.165302
26999id_992000-06-2628.277320
27000id_992000-06-2729.450129
27001id_992000-06-2830.241885
27002id_992000-06-2931.576907
\n", 206 | "

27003 rows × 3 columns

\n", 207 | "
" 208 | ], 209 | "text/plain": [ 210 | " unique_id ds y\n", 211 | "0 id_00 2000-01-01 0.497650\n", 212 | "1 id_00 2000-01-02 1.554489\n", 213 | "2 id_00 2000-01-03 2.734311\n", 214 | "3 id_00 2000-01-04 4.028039\n", 215 | "4 id_00 2000-01-05 5.366009\n", 216 | "... ... ... ...\n", 217 | "26998 id_99 2000-06-25 34.165302\n", 218 | "26999 id_99 2000-06-26 28.277320\n", 219 | "27000 id_99 2000-06-27 29.450129\n", 220 | "27001 id_99 2000-06-28 30.241885\n", 221 | "27002 id_99 2000-06-29 31.576907\n", 222 | "\n", 223 | "[27003 rows x 3 columns]" 224 | ] 225 | }, 226 | "execution_count": null, 227 | "metadata": {}, 228 | "output_type": "execute_result" 229 | } 230 | ], 231 | "source": [ 232 | "series = generate_daily_series(100, with_trend=True)\n", 233 | "series" 234 | ] 235 | }, 236 | { 237 | "cell_type": "markdown", 238 | "id": "b905be3b-30e7-459e-94ed-aa6008201ccc", 239 | "metadata": {}, 240 | "source": [ 241 | "Here we can see that the index goes from `id_00` to `id_99`, which means we have 100 different series stacked together.\n", 242 | "\n", 243 | "We also have the `ds` column that contains the timestamps, in this case with a daily frequency, and the `y` column that contains the series values in each timestamp.\n", 244 | "\n", 245 | "In order to perform distributed processing and training we need to have these in a dask dataframe, this is typically done loading them directly in a distributed way, for example with `dd.read_parquet`." 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "e02ac871-05a8-48c7-b38d-b56174a9ed58", 252 | "metadata": {}, 253 | "outputs": [ 254 | { 255 | "data": { 256 | "text/html": [ 257 | "
" 310 | ], 311 | "text/plain": [ 312 | "Dask DataFrame Structure:\n", 313 | " unique_id ds y\n", 314 | "npartitions=2 \n", 315 | "id_00 object datetime64[ns] float64\n", 316 | "id_49 ... ... ...\n", 317 | "id_99 ... ... ...\n", 318 | "Dask Name: assign, 5 graph layers" 319 | ] 320 | }, 321 | "execution_count": null, 322 | "metadata": {}, 323 | "output_type": "execute_result" 324 | } 325 | ], 326 | "source": [ 327 | "series_ddf = dd.from_pandas(series.set_index('unique_id'), npartitions=2) # make sure we split by id\n", 328 | "series_ddf = series_ddf.map_partitions(lambda part: part.reset_index()) # we can't have an index\n", 329 | "series_ddf['unique_id'] = series_ddf['unique_id'].astype('str') # categoricals aren't supported at the moment\n", 330 | "series_ddf" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "id": "be52c394-8b5c-46b6-8963-b60ebf3736f9", 336 | "metadata": {}, 337 | "source": [ 338 | "We now have a dask dataframe with two partitions which will be processed independently in each machine and their outputs will be combined to perform distributed training." 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "id": "8e8b5f8b-fa7c-4e4f-8fc3-799fc6a0d160", 344 | "metadata": {}, 345 | "source": [ 346 | "## Modeling" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "id": "01f984a5-eae5-4cc8-8636-ad56834c7c61", 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "import random\n", 357 | "import matplotlib.pyplot as plt\n", 358 | "\n", 359 | "def plot_sample(df, ax):\n", 360 | " idxs = df['unique_id'].unique()\n", 361 | " random.seed(0)\n", 362 | " sample_idxs = random.choices(idxs, k=4)\n", 363 | " for uid, axi in zip(sample_idxs, ax.flat):\n", 364 | " df[df['unique_id'].eq(uid)].set_index('ds').plot(ax=axi, title=uid)" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "id": "0ccdb56e-595c-4099-b703-9f9baec846d7", 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 6), gridspec_kw=dict(hspace=0.5))\n", 375 | "plot_sample(series, ax)\n", 376 | "fig.savefig('../figs/quick_start_distributed__sample.png', bbox_inches='tight')\n", 377 | "plt.close()" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "id": "2422d6e2-12d5-431d-bd5a-0973d2f49fd0", 383 | "metadata": {}, 384 | "source": [ 385 | "![](../figs/quick_start_distributed__sample.png)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "id": "b1efdf73-8533-4315-a0f4-404cca531a94", 391 | "metadata": {}, 392 | "source": [ 393 | "We can see that the series have a clear trend, so we can take the first difference, i.e. take each value and subtract the value at the previous month. This can be achieved by setting `differences=[1]`.\n", 394 | "\n", 395 | "We can then train a LightGBM model using the value from the same day of the week at the previous week (lag 7) as a feature, this is done by passing `lags=[7]`." 
396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": null, 401 | "id": "6a03d9d2-8cd6-4c14-aa1c-e092781a083c", 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "from mlforecast.distributed import DistributedMLForecast\n", 406 | "from mlforecast.distributed.models.dask.lgb import DaskLGBMForecast" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "id": "8cdec3a1-f615-4681-a8ee-87cc6e34e22f", 413 | "metadata": {}, 414 | "outputs": [ 415 | { 416 | "name": "stderr", 417 | "output_type": "stream", 418 | "text": [ 419 | "/home/jose/mambaforge/envs/mlforecast/lib/python3.10/site-packages/lightgbm/dask.py:525: UserWarning: Parameter n_jobs will be ignored.\n", 420 | " _log_warning(f\"Parameter {param_alias} will be ignored.\")\n" 421 | ] 422 | }, 423 | { 424 | "name": "stdout", 425 | "output_type": "stream", 426 | "text": [ 427 | "Finding random open ports for workers\n", 428 | "[LightGBM] [Info] Trying to bind port 36331...\n", 429 | "[LightGBM] [Info] Binding port 36331 succeeded\n", 430 | "[LightGBM] [Info] Listening...\n", 431 | "[LightGBM] [Warning] Connecting to rank 1 failed, waiting for 200 milliseconds\n", 432 | "[LightGBM] [Info] Trying to bind port 43259...\n", 433 | "[LightGBM] [Info] Binding port 43259 succeeded\n", 434 | "[LightGBM] [Info] Listening...\n", 435 | "[LightGBM] [Info] Connected to rank 1\n", 436 | "[LightGBM] [Info] Connected to rank 0\n", 437 | "[LightGBM] [Info] Local rank: 0, total number of machines: 2\n", 438 | "[LightGBM] [Info] Local rank: 1, total number of machines: 2\n", 439 | "[LightGBM] [Warning] num_threads is set=1, n_jobs=-1 will be ignored. Current value: num_threads=1\n", 440 | "[LightGBM] [Warning] num_threads is set=1, n_jobs=-1 will be ignored. Current value: num_threads=1\n" 441 | ] 442 | }, 443 | { 444 | "data": { 445 | "text/plain": [ 446 | "DistributedMLForecast(models=[DaskLGBMForecast], freq=, lag_features=['lag7'], date_features=[], num_threads=1, engine=None)" 447 | ] 448 | }, 449 | "execution_count": null, 450 | "metadata": {}, 451 | "output_type": "execute_result" 452 | } 453 | ], 454 | "source": [ 455 | "fcst = DistributedMLForecast(\n", 456 | " models=DaskLGBMForecast(verbosity=-1),\n", 457 | " freq='D',\n", 458 | " lags=[7],\n", 459 | " differences=[1],\n", 460 | ")\n", 461 | "fcst.fit(series_ddf)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "id": "3d61860c-8e00-416a-8cc5-ee6b54f8578a", 467 | "metadata": {}, 468 | "source": [ 469 | "The previous line computed the features and trained the model, so now we're ready to compute our forecasts." 470 | ] 471 | }, 472 | { 473 | "cell_type": "markdown", 474 | "id": "f7ba3510-84fc-4b09-8f0b-e4abe2b1b9e6", 475 | "metadata": {}, 476 | "source": [ 477 | "## Forecasting" 478 | ] 479 | }, 480 | { 481 | "cell_type": "markdown", 482 | "id": "ba67c696-c3d1-4270-9b17-9644eca5e32e", 483 | "metadata": {}, 484 | "source": [ 485 | "Compute the forecast for the next 14 days." 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": null, 491 | "id": "aa63c1a6-a709-4fed-a6d2-35d5648b79e1", 492 | "metadata": {}, 493 | "outputs": [ 494 | { 495 | "data": { 496 | "text/html": [ 497 | "
" 550 | ], 551 | "text/plain": [ 552 | "Dask DataFrame Structure:\n", 553 | " unique_id ds DaskLGBMForecast\n", 554 | "npartitions=2 \n", 555 | "id_00 object datetime64[ns] float64\n", 556 | "id_49 ... ... ...\n", 557 | "id_99 ... ... ...\n", 558 | "Dask Name: map, 17 graph layers" 559 | ] 560 | }, 561 | "execution_count": null, 562 | "metadata": {}, 563 | "output_type": "execute_result" 564 | } 565 | ], 566 | "source": [ 567 | "preds = fcst.predict(14)\n", 568 | "preds" 569 | ] 570 | }, 571 | { 572 | "cell_type": "markdown", 573 | "id": "b1d3b53b-e6ed-4939-9e76-f43051203718", 574 | "metadata": {}, 575 | "source": [ 576 | "These are returned as a dask dataframe as well. If it's safe (memory-wise) we can bring them to the main process." 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": null, 582 | "id": "9b45104e-11d2-4a5f-85fa-c2573f545f10", 583 | "metadata": {}, 584 | "outputs": [], 585 | "source": [ 586 | "local_preds = preds.compute()" 587 | ] 588 | }, 589 | { 590 | "cell_type": "markdown", 591 | "id": "0783cf79-5d52-43e3-9e3c-290530504d4e", 592 | "metadata": {}, 593 | "source": [ 594 | "## Visualize results" 595 | ] 596 | }, 597 | { 598 | "cell_type": "markdown", 599 | "id": "b6d0ea21-b713-4c46-bd43-e162b21dbfb8", 600 | "metadata": {}, 601 | "source": [ 602 | "We can visualize what our prediction looks like." 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "id": "9f6e208d-6eb8-4e41-8dfb-0095c3e70f34", 609 | "metadata": {}, 610 | "outputs": [], 611 | "source": [ 612 | "import pandas as pd" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": null, 618 | "id": "6b060f4e-e037-428b-b38c-195981be66c6", 619 | "metadata": {}, 620 | "outputs": [], 621 | "source": [ 622 | "fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(10, 6), gridspec_kw=dict(hspace=0.5))\n", 623 | "plot_sample(pd.concat([series, local_preds.set_index('unique_id')]), ax)\n", 624 | "fig.savefig('../figs/quick_start_distributed__sample_prediction.png', bbox_inches='tight')\n", 625 | "plt.close()" 626 | ] 627 | }, 628 | { 629 | "cell_type": "markdown", 630 | "id": "9c65a60b-246b-42ee-a888-b082822c568e", 631 | "metadata": {}, 632 | "source": [ 633 | "![](../figs/quick_start_distributed__sample_prediction.png)" 634 | ] 635 | }, 636 | { 637 | "cell_type": "markdown", 638 | "id": "4c7076e1-0388-4fbf-a27c-bc99ed291259", 639 | "metadata": {}, 640 | "source": [ 641 | "And that's it! You've trained a distributed LightGBM model and computed predictions for the next 14 days." 642 | ] 643 | } 644 | ], 645 | "metadata": { 646 | "kernelspec": { 647 | "display_name": "python3", 648 | "language": "python", 649 | "name": "python3" 650 | } 651 | }, 652 | "nbformat": 4, 653 | "nbformat_minor": 5 654 | } 655 | -------------------------------------------------------------------------------- /mlforecast/lgb_cv.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/lgb_cv.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['LightGBMCV'] 5 | 6 | # %% ../nbs/lgb_cv.ipynb 3 7 | import copy 8 | import os 9 | from concurrent.futures import ThreadPoolExecutor 10 | from functools import partial 11 | from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union 12 | 13 | import lightgbm as lgb 14 | import numpy as np 15 | import pandas as pd 16 | 17 | from mlforecast.core import ( 18 | DateFeature, 19 | Differences, 20 | Freq, 21 | LagTransforms, 22 | Lags, 23 | TimeSeries, 24 | ) 25 | from .utils import backtest_splits 26 | 27 | # %% ../nbs/lgb_cv.ipynb 5 28 | def _mape(y_true, y_pred): 29 | abs_pct_err = abs(y_true - y_pred) / y_true 30 | return ( 31 | abs_pct_err.groupby(y_true.index.get_level_values(0), observed=True) 32 | .mean() 33 | .mean() 34 | ) 35 | 36 | 37 | def _rmse(y_true, y_pred): 38 | sq_err = (y_true - y_pred) ** 2 39 | return ( 40 | sq_err.groupby(y_true.index.get_level_values(0), observed=True) 41 | .mean() 42 | .pow(0.5) 43 | .mean() 44 | ) 45 | 46 | 47 | _metric2fn = {"mape": _mape, "rmse": _rmse} 48 | 49 | 50 | def _update(bst, n): 51 | for _ in range(n): 52 | bst.update() 53 | 54 | 55 | def _predict( 56 | ts, 57 | bst, 58 | valid, 59 | h, 60 | id_col, 61 | time_col, 62 | dynamic_dfs, 63 | before_predict_callback, 64 | after_predict_callback, 65 | ): 66 | preds = ts.predict( 67 | {"Booster": bst}, 68 | h, 69 | dynamic_dfs, 70 | before_predict_callback, 71 | after_predict_callback, 72 | ) 73 | return valid.merge(preds, on=[id_col, time_col], how="left") 74 | 75 | 76 | def _update_and_predict( 77 | ts, 78 | bst, 79 | valid, 80 | n, 81 | h, 82 | id_col, 83 | time_col, 84 | dynamic_dfs, 85 | before_predict_callback, 86 | after_predict_callback, 87 | ): 88 | _update(bst, n) 89 | return _predict( 90 | ts, 91 | bst, 92 | valid, 93 | h, 94 | id_col, 95 | time_col, 96 | dynamic_dfs, 97 | before_predict_callback, 98 | after_predict_callback, 99 | ) 100 | 101 | # %% ../nbs/lgb_cv.ipynb 6 102 | CVResult = Tuple[int, float] 103 | 104 | # %% ../nbs/lgb_cv.ipynb 7 105 | class LightGBMCV: 106 | def __init__( 107 | self, 108 | freq: Optional[Freq] = None, 109 | lags: Optional[Lags] = None, 110 | lag_transforms: Optional[LagTransforms] = None, 111 | date_features: Optional[Iterable[DateFeature]] = None, 112 | differences: Optional[Differences] = None, 113 | num_threads: int = 1, 114 | ): 115 | """Create LightGBM CV object. 116 | 117 | Parameters 118 | ---------- 119 | freq : str or int, optional (default=None) 120 | Pandas offset alias, e.g. 'D', 'W-THU' or integer denoting the frequency of the series. 121 | lags : list of int, optional (default=None) 122 | Lags of the target to use as features. 123 | lag_transforms : dict of int to list of functions, optional (default=None) 124 | Mapping of target lags to their transformations. 125 | date_features : list of str or callable, optional (default=None) 126 | Features computed from the dates. Can be pandas date attributes or functions that will take the dates as input. 127 | differences : list of int, optional (default=None) 128 | Differences to take of the target before computing the features. These are restored at the forecasting step. 129 | num_threads : int (default=1) 130 | Number of threads to use when computing the features. 
131 | """ 132 | self.num_threads = num_threads 133 | cpu_count = os.cpu_count() 134 | if cpu_count is None: 135 | num_cpus = 1 136 | else: 137 | num_cpus = cpu_count 138 | self.bst_threads = max(num_cpus // num_threads, 1) 139 | self.ts = TimeSeries( 140 | freq, lags, lag_transforms, date_features, differences, self.bst_threads 141 | ) 142 | 143 | def __repr__(self): 144 | return ( 145 | f"{self.__class__.__name__}(" 146 | f"freq={self.ts.freq}, " 147 | f"lag_features={list(self.ts.transforms.keys())}, " 148 | f"date_features={self.ts.date_features}, " 149 | f"num_threads={self.num_threads}, " 150 | f"bst_threads={self.bst_threads})" 151 | ) 152 | 153 | def setup( 154 | self, 155 | data: pd.DataFrame, 156 | n_windows: int, 157 | window_size: int, 158 | id_col: str = "unique_id", 159 | time_col: str = "ds", 160 | target_col: str = "y", 161 | step_size: Optional[int] = None, 162 | params: Optional[Dict[str, Any]] = None, 163 | static_features: Optional[List[str]] = None, 164 | dropna: bool = True, 165 | keep_last_n: Optional[int] = None, 166 | weights: Optional[Sequence[float]] = None, 167 | metric: Union[str, Callable] = "mape", 168 | input_size: Optional[int] = None, 169 | ): 170 | """Initialize internal data structures to iteratively train the boosters. Use this before calling partial_fit. 171 | 172 | Parameters 173 | ---------- 174 | data : pandas DataFrame 175 | Series data in long format. 176 | n_windows : int 177 | Number of windows to evaluate. 178 | window_size : int 179 | Number of test periods in each window. 180 | id_col : str (default='unique_id') 181 | Column that identifies each serie. 182 | time_col : str (default='ds') 183 | Column that identifies each timestep, its values can be timestamps or integers. 184 | target_col : str (default='y') 185 | Column that contains the target. 186 | step_size : int, optional (default=None) 187 | Step size between each cross validation window. If None it will be equal to `window_size`. 188 | params : dict, optional(default=None) 189 | Parameters to be passed to the LightGBM Boosters. 190 | static_features : list of str, optional (default=None) 191 | Names of the features that are static and will be repeated when forecasting. 192 | dropna : bool (default=True) 193 | Drop rows with missing values produced by the transformations. 194 | keep_last_n : int, optional (default=None) 195 | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. 196 | weights : sequence of float, optional (default=None) 197 | Weights to multiply the metric of each window. If None, all windows have the same weight. 198 | metric : str or callable, default='mape' 199 | Metric used to assess the performance of the models and perform early stopping. 200 | input_size : int, optional (default=None) 201 | Maximum training samples per serie in each window. If None, will use an expanding window. 202 | 203 | Returns 204 | ------- 205 | self : LightGBMCV 206 | CV object with internal data structures for partial_fit. 
207 | """ 208 | if weights is None: 209 | self.weights = np.full(n_windows, 1 / n_windows) 210 | elif len(weights) != n_windows: 211 | raise ValueError("Must specify as many weights as the number of windows") 212 | else: 213 | self.weights = np.asarray(weights) 214 | if callable(metric): 215 | self.metric_fn = metric 216 | self.metric_name = "custom_metric" 217 | else: 218 | if metric not in _metric2fn: 219 | raise ValueError( 220 | f'{metric} is not one of the implemented metrics: ({", ".join(_metric2fn.keys())})' 221 | ) 222 | self.metric_fn = _metric2fn[metric] 223 | self.metric_name = metric 224 | if np.issubdtype(data[time_col].dtype.type, np.integer): 225 | freq = 1 226 | else: 227 | freq = self.ts.freq 228 | self.items = [] 229 | self.window_size = window_size 230 | self.id_col = id_col 231 | self.time_col = time_col 232 | self.target_col = target_col 233 | self.params = {} if params is None else params 234 | splits = backtest_splits( 235 | data, 236 | n_windows=n_windows, 237 | window_size=window_size, 238 | id_col=id_col, 239 | time_col=time_col, 240 | freq=freq, 241 | step_size=step_size, 242 | input_size=input_size, 243 | ) 244 | for _, train, valid in splits: 245 | ts = copy.deepcopy(self.ts) 246 | prep = ts.fit_transform( 247 | train, 248 | id_col, 249 | time_col, 250 | target_col, 251 | static_features, 252 | dropna, 253 | keep_last_n, 254 | ) 255 | ds = lgb.Dataset( 256 | prep.drop(columns=[id_col, time_col, target_col]), prep[target_col] 257 | ).construct() 258 | bst = lgb.Booster({**self.params, "num_threads": self.bst_threads}, ds) 259 | bst.predict = partial(bst.predict, num_threads=self.bst_threads) 260 | valid = valid.set_index(time_col, append=True) 261 | self.items.append((ts, bst, valid)) 262 | return self 263 | 264 | def _single_threaded_partial_fit( 265 | self, 266 | metric_values, 267 | num_iterations, 268 | dynamic_dfs, 269 | before_predict_callback: Optional[Callable] = None, 270 | after_predict_callback: Optional[Callable] = None, 271 | ): 272 | for j, (ts, bst, valid) in enumerate(self.items): 273 | preds = _update_and_predict( 274 | ts=ts, 275 | bst=bst, 276 | valid=valid, 277 | n=num_iterations, 278 | h=self.window_size, 279 | id_col=self.id_col, 280 | time_col=self.time_col, 281 | dynamic_dfs=dynamic_dfs, 282 | before_predict_callback=before_predict_callback, 283 | after_predict_callback=after_predict_callback, 284 | ) 285 | metric_values[j] = self.metric_fn(preds[self.target_col], preds["Booster"]) 286 | 287 | def _multithreaded_partial_fit( 288 | self, 289 | metric_values, 290 | num_iterations, 291 | dynamic_dfs, 292 | before_predict_callback: Optional[Callable] = None, 293 | after_predict_callback: Optional[Callable] = None, 294 | ): 295 | with ThreadPoolExecutor(self.num_threads) as executor: 296 | futures = [] 297 | for ts, bst, valid in self.items: 298 | _update(bst, num_iterations) 299 | future = executor.submit( 300 | _predict, 301 | ts=ts, 302 | bst=bst, 303 | valid=valid, 304 | h=self.window_size, 305 | id_col=self.id_col, 306 | time_col=self.time_col, 307 | dynamic_dfs=dynamic_dfs, 308 | before_predict_callback=before_predict_callback, 309 | after_predict_callback=after_predict_callback, 310 | ) 311 | futures.append(future) 312 | cv_preds = [f.result() for f in futures] 313 | metric_values[:] = [ 314 | self.metric_fn(preds[self.target_col], preds["Booster"]) 315 | for preds in cv_preds 316 | ] 317 | 318 | def partial_fit( 319 | self, 320 | num_iterations: int, 321 | dynamic_dfs: Optional[List[pd.DataFrame]] = None, 322 | before_predict_callback: 
Optional[Callable] = None, 323 | after_predict_callback: Optional[Callable] = None, 324 | ) -> float: 325 | """Train the boosters for some iterations. 326 | 327 | Parameters 328 | ---------- 329 | num_iterations : int 330 | Number of boosting iterations to run 331 | dynamic_dfs : list of pandas DataFrame, optional (default=None) 332 | Future values of the dynamic features, e.g. prices. 333 | before_predict_callback : callable, optional (default=None) 334 | Function to call on the features before computing the predictions. 335 | This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure. 336 | The series identifier is on the index. 337 | after_predict_callback : callable, optional (default=None) 338 | Function to call on the predictions before updating the targets. 339 | This function will take a pandas Series with the predictions and should return another one with the same structure. 340 | The series identifier is on the index. 341 | 342 | Returns 343 | ------- 344 | metric_value : float 345 | Weighted metric after training for num_iterations. 346 | """ 347 | metric_values = np.empty(len(self.items)) 348 | if self.num_threads == 1: 349 | self._single_threaded_partial_fit( 350 | metric_values, 351 | num_iterations, 352 | dynamic_dfs, 353 | before_predict_callback, 354 | after_predict_callback, 355 | ) 356 | else: 357 | self._multithreaded_partial_fit( 358 | metric_values, 359 | num_iterations, 360 | dynamic_dfs, 361 | before_predict_callback, 362 | after_predict_callback, 363 | ) 364 | return metric_values @ self.weights 365 | 366 | def should_stop(self, hist, early_stopping_evals, early_stopping_pct) -> bool: 367 | if len(hist) < early_stopping_evals + 1: 368 | return False 369 | improvement_pct = 1 - hist[-1][1] / hist[-(early_stopping_evals + 1)][1] 370 | return improvement_pct < early_stopping_pct 371 | 372 | def find_best_iter(self, hist, early_stopping_evals) -> int: 373 | best_iter, best_score = hist[-1] 374 | for r, m in hist[-(early_stopping_evals + 1) : -1]: 375 | if m < best_score: 376 | best_score = m 377 | best_iter = r 378 | return best_iter 379 | 380 | def fit( 381 | self, 382 | data: pd.DataFrame, 383 | n_windows: int, 384 | window_size: int, 385 | id_col: str = "unique_id", 386 | time_col: str = "ds", 387 | target_col: str = "y", 388 | step_size: Optional[int] = None, 389 | num_iterations: int = 100, 390 | params: Optional[Dict[str, Any]] = None, 391 | static_features: Optional[List[str]] = None, 392 | dropna: bool = True, 393 | keep_last_n: Optional[int] = None, 394 | dynamic_dfs: Optional[List[pd.DataFrame]] = None, 395 | eval_every: int = 10, 396 | weights: Optional[Sequence[float]] = None, 397 | metric: Union[str, Callable] = "mape", 398 | verbose_eval: bool = True, 399 | early_stopping_evals: int = 2, 400 | early_stopping_pct: float = 0.01, 401 | compute_cv_preds: bool = False, 402 | before_predict_callback: Optional[Callable] = None, 403 | after_predict_callback: Optional[Callable] = None, 404 | input_size: Optional[int] = None, 405 | ) -> List[CVResult]: 406 | """Train boosters simultaneously and assess their performance on the complete forecasting window. 407 | 408 | Parameters 409 | ---------- 410 | data : pandas DataFrame 411 | Series data in long format. 412 | n_windows : int 413 | Number of windows to evaluate. 414 | window_size : int 415 | Number of test periods in each window. 416 | id_col : str (default='unique_id') 417 | Column that identifies each serie. 
418 | time_col : str (default='ds') 419 | Column that identifies each timestep, its values can be timestamps or integers. 420 | target_col : str (default='y') 421 | Column that contains the target. 422 | step_size : int, optional (default=None) 423 | Step size between each cross validation window. If None it will be equal to `window_size`. 424 | num_iterations : int (default=100) 425 | Maximum number of boosting iterations to run. 426 | params : dict, optional(default=None) 427 | Parameters to be passed to the LightGBM Boosters. 428 | static_features : list of str, optional (default=None) 429 | Names of the features that are static and will be repeated when forecasting. 430 | dropna : bool (default=True) 431 | Drop rows with missing values produced by the transformations. 432 | keep_last_n : int, optional (default=None) 433 | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. 434 | dynamic_dfs : list of pandas DataFrame, optional (default=None) 435 | Future values of the dynamic features, e.g. prices. 436 | eval_every : int (default=10) 437 | Number of boosting iterations to train before evaluating on the whole forecast window. 438 | weights : sequence of float, optional (default=None) 439 | Weights to multiply the metric of each window. If None, all windows have the same weight. 440 | metric : str or callable, default='mape' 441 | Metric used to assess the performance of the models and perform early stopping. 442 | verbose_eval : bool 443 | Print the metrics of each evaluation. 444 | early_stopping_evals : int (default=2) 445 | Maximum number of evaluations to run without improvement. 446 | early_stopping_pct : float (default=0.01) 447 | Minimum percentage improvement in metric value in `early_stopping_evals` evaluations. 448 | compute_cv_preds : bool (default=True) 449 | Compute predictions for each window after finding the best iteration. 450 | before_predict_callback : callable, optional (default=None) 451 | Function to call on the features before computing the predictions. 452 | This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure. 453 | The series identifier is on the index. 454 | after_predict_callback : callable, optional (default=None) 455 | Function to call on the predictions before updating the targets. 456 | This function will take a pandas Series with the predictions and should return another one with the same structure. 457 | The series identifier is on the index. 458 | input_size : int, optional (default=None) 459 | Maximum training samples per serie in each window. If None, will use an expanding window. 460 | 461 | Returns 462 | ------- 463 | cv_result : list of tuple. 464 | List of (boosting rounds, metric value) tuples. 
465 | """ 466 | self.setup( 467 | data=data, 468 | n_windows=n_windows, 469 | window_size=window_size, 470 | params=params, 471 | id_col=id_col, 472 | time_col=time_col, 473 | target_col=target_col, 474 | input_size=input_size, 475 | step_size=step_size, 476 | static_features=static_features, 477 | dropna=dropna, 478 | keep_last_n=keep_last_n, 479 | weights=weights, 480 | metric=metric, 481 | ) 482 | hist = [] 483 | for i in range(0, num_iterations, eval_every): 484 | metric_value = self.partial_fit( 485 | eval_every, dynamic_dfs, before_predict_callback, after_predict_callback 486 | ) 487 | rounds = eval_every + i 488 | hist.append((rounds, metric_value)) 489 | if verbose_eval: 490 | print(f"[{rounds:,d}] {self.metric_name}: {metric_value:,f}") 491 | if self.should_stop(hist, early_stopping_evals, early_stopping_pct): 492 | print(f"Early stopping at round {rounds:,}") 493 | break 494 | self.best_iteration_ = self.find_best_iter(hist, early_stopping_evals) 495 | print(f"Using best iteration: {self.best_iteration_:,}") 496 | hist = hist[: self.best_iteration_ // eval_every] 497 | for _, bst, _ in self.items: 498 | bst.best_iteration = self.best_iteration_ 499 | 500 | self.cv_models_ = {f"Booster{i}": item[1] for i, item in enumerate(self.items)} 501 | if compute_cv_preds: 502 | with ThreadPoolExecutor(self.num_threads) as executor: 503 | futures = [] 504 | for ts, bst, valid in self.items: 505 | future = executor.submit( 506 | _predict, 507 | ts=ts, 508 | bst=bst, 509 | valid=valid, 510 | h=self.window_size, 511 | id_col=self.id_col, 512 | time_col=self.time_col, 513 | dynamic_dfs=dynamic_dfs, 514 | before_predict_callback=before_predict_callback, 515 | after_predict_callback=after_predict_callback, 516 | ) 517 | futures.append(future) 518 | self.cv_preds_ = pd.concat( 519 | [f.result().assign(window=i) for i, f in enumerate(futures)] 520 | ) 521 | self.ts._fit(data, id_col, time_col, target_col, static_features, keep_last_n) 522 | return hist 523 | 524 | def predict( 525 | self, 526 | horizon: int, 527 | dynamic_dfs: Optional[List[pd.DataFrame]] = None, 528 | before_predict_callback: Optional[Callable] = None, 529 | after_predict_callback: Optional[Callable] = None, 530 | ) -> pd.DataFrame: 531 | """Compute predictions with each of the trained boosters. 532 | 533 | Parameters 534 | ---------- 535 | horizon : int 536 | Number of periods to predict. 537 | dynamic_dfs : list of pandas DataFrame, optional (default=None) 538 | Future values of the dynamic features, e.g. prices. 539 | before_predict_callback : callable, optional (default=None) 540 | Function to call on the features before computing the predictions. 541 | This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure. 542 | The series identifier is on the index. 543 | after_predict_callback : callable, optional (default=None) 544 | Function to call on the predictions before updating the targets. 545 | This function will take a pandas Series with the predictions and should return another one with the same structure. 546 | The series identifier is on the index. 547 | 548 | Returns 549 | ------- 550 | result : pandas DataFrame 551 | Predictions for each serie and timestep, with one column per window. 
552 | """ 553 | return self.ts.predict( 554 | self.cv_models_, 555 | horizon, 556 | dynamic_dfs, 557 | before_predict_callback, 558 | after_predict_callback, 559 | ) 560 | -------------------------------------------------------------------------------- /mlforecast/_modidx.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by nbdev 2 | 3 | d = { 'settings': { 'branch': 'main', 4 | 'doc_baseurl': '/', 5 | 'doc_host': 'https://Nixtla.github.io', 6 | 'git_url': 'https://github.com/Nixtla/mlforecast', 7 | 'lib_path': 'mlforecast'}, 8 | 'syms': { 'mlforecast.core': { 'mlforecast.core.GroupedArray': ('core.html#groupedarray', 'mlforecast/core.py'), 9 | 'mlforecast.core.GroupedArray.__getitem__': ('core.html#groupedarray.__getitem__', 'mlforecast/core.py'), 10 | 'mlforecast.core.GroupedArray.__init__': ('core.html#groupedarray.__init__', 'mlforecast/core.py'), 11 | 'mlforecast.core.GroupedArray.__len__': ('core.html#groupedarray.__len__', 'mlforecast/core.py'), 12 | 'mlforecast.core.GroupedArray.__repr__': ('core.html#groupedarray.__repr__', 'mlforecast/core.py'), 13 | 'mlforecast.core.GroupedArray.__setitem__': ('core.html#groupedarray.__setitem__', 'mlforecast/core.py'), 14 | 'mlforecast.core.GroupedArray.append': ('core.html#groupedarray.append', 'mlforecast/core.py'), 15 | 'mlforecast.core.GroupedArray.expand_target': ( 'core.html#groupedarray.expand_target', 16 | 'mlforecast/core.py'), 17 | 'mlforecast.core.GroupedArray.from_sorted_df': ( 'core.html#groupedarray.from_sorted_df', 18 | 'mlforecast/core.py'), 19 | 'mlforecast.core.GroupedArray.restore_difference': ( 'core.html#groupedarray.restore_difference', 20 | 'mlforecast/core.py'), 21 | 'mlforecast.core.GroupedArray.take_from_groups': ( 'core.html#groupedarray.take_from_groups', 22 | 'mlforecast/core.py'), 23 | 'mlforecast.core.GroupedArray.transform_series': ( 'core.html#groupedarray.transform_series', 24 | 'mlforecast/core.py'), 25 | 'mlforecast.core.TimeSeries': ('core.html#timeseries', 'mlforecast/core.py'), 26 | 'mlforecast.core.TimeSeries.__init__': ('core.html#timeseries.__init__', 'mlforecast/core.py'), 27 | 'mlforecast.core.TimeSeries.__repr__': ('core.html#timeseries.__repr__', 'mlforecast/core.py'), 28 | 'mlforecast.core.TimeSeries._apply_multithreaded_transforms': ( 'core.html#timeseries._apply_multithreaded_transforms', 29 | 'mlforecast/core.py'), 30 | 'mlforecast.core.TimeSeries._apply_transforms': ( 'core.html#timeseries._apply_transforms', 31 | 'mlforecast/core.py'), 32 | 'mlforecast.core.TimeSeries._compute_date_feature': ( 'core.html#timeseries._compute_date_feature', 33 | 'mlforecast/core.py'), 34 | 'mlforecast.core.TimeSeries._compute_transforms': ( 'core.html#timeseries._compute_transforms', 35 | 'mlforecast/core.py'), 36 | 'mlforecast.core.TimeSeries._date_feature_names': ( 'core.html#timeseries._date_feature_names', 37 | 'mlforecast/core.py'), 38 | 'mlforecast.core.TimeSeries._fit': ('core.html#timeseries._fit', 'mlforecast/core.py'), 39 | 'mlforecast.core.TimeSeries._get_features_for_next_step': ( 'core.html#timeseries._get_features_for_next_step', 40 | 'mlforecast/core.py'), 41 | 'mlforecast.core.TimeSeries._get_predictions': ( 'core.html#timeseries._get_predictions', 42 | 'mlforecast/core.py'), 43 | 'mlforecast.core.TimeSeries._get_raw_predictions': ( 'core.html#timeseries._get_raw_predictions', 44 | 'mlforecast/core.py'), 45 | 'mlforecast.core.TimeSeries._predict_multi': ('core.html#timeseries._predict_multi', 'mlforecast/core.py'), 46 | 
'mlforecast.core.TimeSeries._predict_recursive': ( 'core.html#timeseries._predict_recursive', 47 | 'mlforecast/core.py'), 48 | 'mlforecast.core.TimeSeries._predict_setup': ('core.html#timeseries._predict_setup', 'mlforecast/core.py'), 49 | 'mlforecast.core.TimeSeries._restore_differences': ( 'core.html#timeseries._restore_differences', 50 | 'mlforecast/core.py'), 51 | 'mlforecast.core.TimeSeries._transform': ('core.html#timeseries._transform', 'mlforecast/core.py'), 52 | 'mlforecast.core.TimeSeries._update_features': ( 'core.html#timeseries._update_features', 53 | 'mlforecast/core.py'), 54 | 'mlforecast.core.TimeSeries._update_y': ('core.html#timeseries._update_y', 'mlforecast/core.py'), 55 | 'mlforecast.core.TimeSeries.features': ('core.html#timeseries.features', 'mlforecast/core.py'), 56 | 'mlforecast.core.TimeSeries.fit_transform': ('core.html#timeseries.fit_transform', 'mlforecast/core.py'), 57 | 'mlforecast.core.TimeSeries.predict': ('core.html#timeseries.predict', 'mlforecast/core.py'), 58 | 'mlforecast.core._append_new': ('core.html#_append_new', 'mlforecast/core.py'), 59 | 'mlforecast.core._apply_difference': ('core.html#_apply_difference', 'mlforecast/core.py'), 60 | 'mlforecast.core._as_tuple': ('core.html#_as_tuple', 'mlforecast/core.py'), 61 | 'mlforecast.core._build_transform_name': ('core.html#_build_transform_name', 'mlforecast/core.py'), 62 | 'mlforecast.core._diff': ('core.html#_diff', 'mlforecast/core.py'), 63 | 'mlforecast.core._expand_target': ('core.html#_expand_target', 'mlforecast/core.py'), 64 | 'mlforecast.core._identity': ('core.html#_identity', 'mlforecast/core.py'), 65 | 'mlforecast.core._name_models': ('core.html#_name_models', 'mlforecast/core.py'), 66 | 'mlforecast.core._restore_difference': ('core.html#_restore_difference', 'mlforecast/core.py'), 67 | 'mlforecast.core._transform_series': ('core.html#_transform_series', 'mlforecast/core.py')}, 68 | 'mlforecast.distributed.forecast': { 'mlforecast.distributed.forecast.DistributedMLForecast': ( 'distributed.forecast.html#distributedmlforecast', 69 | 'mlforecast/distributed/forecast.py'), 70 | 'mlforecast.distributed.forecast.DistributedMLForecast.__init__': ( 'distributed.forecast.html#distributedmlforecast.__init__', 71 | 'mlforecast/distributed/forecast.py'), 72 | 'mlforecast.distributed.forecast.DistributedMLForecast.__repr__': ( 'distributed.forecast.html#distributedmlforecast.__repr__', 73 | 'mlforecast/distributed/forecast.py'), 74 | 'mlforecast.distributed.forecast.DistributedMLForecast._fit': ( 'distributed.forecast.html#distributedmlforecast._fit', 75 | 'mlforecast/distributed/forecast.py'), 76 | 'mlforecast.distributed.forecast.DistributedMLForecast._get_predict_schema': ( 'distributed.forecast.html#distributedmlforecast._get_predict_schema', 77 | 'mlforecast/distributed/forecast.py'), 78 | 'mlforecast.distributed.forecast.DistributedMLForecast._predict': ( 'distributed.forecast.html#distributedmlforecast._predict', 79 | 'mlforecast/distributed/forecast.py'), 80 | 'mlforecast.distributed.forecast.DistributedMLForecast._preprocess': ( 'distributed.forecast.html#distributedmlforecast._preprocess', 81 | 'mlforecast/distributed/forecast.py'), 82 | 'mlforecast.distributed.forecast.DistributedMLForecast._preprocess_partition': ( 'distributed.forecast.html#distributedmlforecast._preprocess_partition', 83 | 'mlforecast/distributed/forecast.py'), 84 | 'mlforecast.distributed.forecast.DistributedMLForecast._preprocess_partitions': ( 'distributed.forecast.html#distributedmlforecast._preprocess_partitions', 85 | 
'mlforecast/distributed/forecast.py'), 86 | 'mlforecast.distributed.forecast.DistributedMLForecast._retrieve_df': ( 'distributed.forecast.html#distributedmlforecast._retrieve_df', 87 | 'mlforecast/distributed/forecast.py'), 88 | 'mlforecast.distributed.forecast.DistributedMLForecast.cross_validation': ( 'distributed.forecast.html#distributedmlforecast.cross_validation', 89 | 'mlforecast/distributed/forecast.py'), 90 | 'mlforecast.distributed.forecast.DistributedMLForecast.fit': ( 'distributed.forecast.html#distributedmlforecast.fit', 91 | 'mlforecast/distributed/forecast.py'), 92 | 'mlforecast.distributed.forecast.DistributedMLForecast.predict': ( 'distributed.forecast.html#distributedmlforecast.predict', 93 | 'mlforecast/distributed/forecast.py'), 94 | 'mlforecast.distributed.forecast.DistributedMLForecast.preprocess': ( 'distributed.forecast.html#distributedmlforecast.preprocess', 95 | 'mlforecast/distributed/forecast.py')}, 96 | 'mlforecast.distributed.models.dask.lgb': { 'mlforecast.distributed.models.dask.lgb.DaskLGBMForecast': ( 'distributed.models.dask.lgb.html#dasklgbmforecast', 97 | 'mlforecast/distributed/models/dask/lgb.py'), 98 | 'mlforecast.distributed.models.dask.lgb.DaskLGBMForecast.model_': ( 'distributed.models.dask.lgb.html#dasklgbmforecast.model_', 99 | 'mlforecast/distributed/models/dask/lgb.py')}, 100 | 'mlforecast.distributed.models.dask.xgb': { 'mlforecast.distributed.models.dask.xgb.DaskXGBForecast': ( 'distributed.models.dask.xgb.html#daskxgbforecast', 101 | 'mlforecast/distributed/models/dask/xgb.py'), 102 | 'mlforecast.distributed.models.dask.xgb.DaskXGBForecast.model_': ( 'distributed.models.dask.xgb.html#daskxgbforecast.model_', 103 | 'mlforecast/distributed/models/dask/xgb.py')}, 104 | 'mlforecast.distributed.models.ray.lgb': { 'mlforecast.distributed.models.ray.lgb.RayLGBMForecast': ( 'distributed.models.ray.lgb.html#raylgbmforecast', 105 | 'mlforecast/distributed/models/ray/lgb.py'), 106 | 'mlforecast.distributed.models.ray.lgb.RayLGBMForecast.model_': ( 'distributed.models.ray.lgb.html#raylgbmforecast.model_', 107 | 'mlforecast/distributed/models/ray/lgb.py')}, 108 | 'mlforecast.distributed.models.ray.xgb': { 'mlforecast.distributed.models.ray.xgb.RayXGBForecast': ( 'distributed.models.ray.xgb.html#rayxgbforecast', 109 | 'mlforecast/distributed/models/ray/xgb.py'), 110 | 'mlforecast.distributed.models.ray.xgb.RayXGBForecast.model_': ( 'distributed.models.ray.xgb.html#rayxgbforecast.model_', 111 | 'mlforecast/distributed/models/ray/xgb.py')}, 112 | 'mlforecast.distributed.models.spark.lgb': { 'mlforecast.distributed.models.spark.lgb.SparkLGBMForecast': ( 'distributed.models.spark.lgb.html#sparklgbmforecast', 113 | 'mlforecast/distributed/models/spark/lgb.py'), 114 | 'mlforecast.distributed.models.spark.lgb.SparkLGBMForecast._pre_fit': ( 'distributed.models.spark.lgb.html#sparklgbmforecast._pre_fit', 115 | 'mlforecast/distributed/models/spark/lgb.py'), 116 | 'mlforecast.distributed.models.spark.lgb.SparkLGBMForecast.extract_local_model': ( 'distributed.models.spark.lgb.html#sparklgbmforecast.extract_local_model', 117 | 'mlforecast/distributed/models/spark/lgb.py')}, 118 | 'mlforecast.distributed.models.spark.xgb': { 'mlforecast.distributed.models.spark.xgb.SparkXGBForecast': ( 'distributed.models.spark.xgb.html#sparkxgbforecast', 119 | 'mlforecast/distributed/models/spark/xgb.py'), 120 | 'mlforecast.distributed.models.spark.xgb.SparkXGBForecast._pre_fit': ( 'distributed.models.spark.xgb.html#sparkxgbforecast._pre_fit', 121 | 
'mlforecast/distributed/models/spark/xgb.py'), 122 | 'mlforecast.distributed.models.spark.xgb.SparkXGBForecast.extract_local_model': ( 'distributed.models.spark.xgb.html#sparkxgbforecast.extract_local_model', 123 | 'mlforecast/distributed/models/spark/xgb.py')}, 124 | 'mlforecast.forecast': { 'mlforecast.forecast.MLForecast': ('forecast.html#mlforecast', 'mlforecast/forecast.py'), 125 | 'mlforecast.forecast.MLForecast.__init__': ( 'forecast.html#mlforecast.__init__', 126 | 'mlforecast/forecast.py'), 127 | 'mlforecast.forecast.MLForecast.__repr__': ( 'forecast.html#mlforecast.__repr__', 128 | 'mlforecast/forecast.py'), 129 | 'mlforecast.forecast.MLForecast._conformity_scores': ( 'forecast.html#mlforecast._conformity_scores', 130 | 'mlforecast/forecast.py'), 131 | 'mlforecast.forecast.MLForecast.cross_validation': ( 'forecast.html#mlforecast.cross_validation', 132 | 'mlforecast/forecast.py'), 133 | 'mlforecast.forecast.MLForecast.fit': ('forecast.html#mlforecast.fit', 'mlforecast/forecast.py'), 134 | 'mlforecast.forecast.MLForecast.fit_models': ( 'forecast.html#mlforecast.fit_models', 135 | 'mlforecast/forecast.py'), 136 | 'mlforecast.forecast.MLForecast.freq': ('forecast.html#mlforecast.freq', 'mlforecast/forecast.py'), 137 | 'mlforecast.forecast.MLForecast.from_cv': ( 'forecast.html#mlforecast.from_cv', 138 | 'mlforecast/forecast.py'), 139 | 'mlforecast.forecast.MLForecast.predict': ( 'forecast.html#mlforecast.predict', 140 | 'mlforecast/forecast.py'), 141 | 'mlforecast.forecast.MLForecast.preprocess': ( 'forecast.html#mlforecast.preprocess', 142 | 'mlforecast/forecast.py'), 143 | 'mlforecast.forecast._add_conformal_distribution_intervals': ( 'forecast.html#_add_conformal_distribution_intervals', 144 | 'mlforecast/forecast.py'), 145 | 'mlforecast.forecast._add_conformal_error_intervals': ( 'forecast.html#_add_conformal_error_intervals', 146 | 'mlforecast/forecast.py'), 147 | 'mlforecast.forecast._get_conformal_method': ( 'forecast.html#_get_conformal_method', 148 | 'mlforecast/forecast.py')}, 149 | 'mlforecast.lgb_cv': { 'mlforecast.lgb_cv.LightGBMCV': ('lgb_cv.html#lightgbmcv', 'mlforecast/lgb_cv.py'), 150 | 'mlforecast.lgb_cv.LightGBMCV.__init__': ('lgb_cv.html#lightgbmcv.__init__', 'mlforecast/lgb_cv.py'), 151 | 'mlforecast.lgb_cv.LightGBMCV.__repr__': ('lgb_cv.html#lightgbmcv.__repr__', 'mlforecast/lgb_cv.py'), 152 | 'mlforecast.lgb_cv.LightGBMCV._multithreaded_partial_fit': ( 'lgb_cv.html#lightgbmcv._multithreaded_partial_fit', 153 | 'mlforecast/lgb_cv.py'), 154 | 'mlforecast.lgb_cv.LightGBMCV._single_threaded_partial_fit': ( 'lgb_cv.html#lightgbmcv._single_threaded_partial_fit', 155 | 'mlforecast/lgb_cv.py'), 156 | 'mlforecast.lgb_cv.LightGBMCV.find_best_iter': ( 'lgb_cv.html#lightgbmcv.find_best_iter', 157 | 'mlforecast/lgb_cv.py'), 158 | 'mlforecast.lgb_cv.LightGBMCV.fit': ('lgb_cv.html#lightgbmcv.fit', 'mlforecast/lgb_cv.py'), 159 | 'mlforecast.lgb_cv.LightGBMCV.partial_fit': ( 'lgb_cv.html#lightgbmcv.partial_fit', 160 | 'mlforecast/lgb_cv.py'), 161 | 'mlforecast.lgb_cv.LightGBMCV.predict': ('lgb_cv.html#lightgbmcv.predict', 'mlforecast/lgb_cv.py'), 162 | 'mlforecast.lgb_cv.LightGBMCV.setup': ('lgb_cv.html#lightgbmcv.setup', 'mlforecast/lgb_cv.py'), 163 | 'mlforecast.lgb_cv.LightGBMCV.should_stop': ( 'lgb_cv.html#lightgbmcv.should_stop', 164 | 'mlforecast/lgb_cv.py'), 165 | 'mlforecast.lgb_cv._mape': ('lgb_cv.html#_mape', 'mlforecast/lgb_cv.py'), 166 | 'mlforecast.lgb_cv._predict': ('lgb_cv.html#_predict', 'mlforecast/lgb_cv.py'), 167 | 'mlforecast.lgb_cv._rmse': 
('lgb_cv.html#_rmse', 'mlforecast/lgb_cv.py'), 168 | 'mlforecast.lgb_cv._update': ('lgb_cv.html#_update', 'mlforecast/lgb_cv.py'), 169 | 'mlforecast.lgb_cv._update_and_predict': ('lgb_cv.html#_update_and_predict', 'mlforecast/lgb_cv.py')}, 170 | 'mlforecast.utils': { 'mlforecast.utils.PredictionIntervals': ('utils.html#predictionintervals', 'mlforecast/utils.py'), 171 | 'mlforecast.utils.PredictionIntervals.__init__': ( 'utils.html#predictionintervals.__init__', 172 | 'mlforecast/utils.py'), 173 | 'mlforecast.utils.backtest_splits': ('utils.html#backtest_splits', 'mlforecast/utils.py'), 174 | 'mlforecast.utils.generate_daily_series': ('utils.html#generate_daily_series', 'mlforecast/utils.py'), 175 | 'mlforecast.utils.generate_prices_for_series': ( 'utils.html#generate_prices_for_series', 176 | 'mlforecast/utils.py'), 177 | 'mlforecast.utils.single_split': ('utils.html#single_split', 'mlforecast/utils.py')}}} 178 | -------------------------------------------------------------------------------- /mlforecast/distributed/forecast.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/distributed.forecast.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['DistributedMLForecast'] 5 | 6 | # %% ../../nbs/distributed.forecast.ipynb 5 7 | import copy 8 | from collections import namedtuple 9 | from typing import Any, Callable, Iterable, List, Optional 10 | 11 | import cloudpickle 12 | 13 | try: 14 | import dask.dataframe as dd 15 | 16 | DASK_INSTALLED = True 17 | except ModuleNotFoundError: 18 | DASK_INSTALLED = False 19 | import fugue 20 | import fugue.api as fa 21 | import pandas as pd 22 | 23 | try: 24 | from pyspark.ml.feature import VectorAssembler 25 | from pyspark.sql import DataFrame as SparkDataFrame 26 | 27 | SPARK_INSTALLED = True 28 | except ModuleNotFoundError: 29 | SPARK_INSTALLED = False 30 | try: 31 | from lightgbm_ray import RayDMatrix 32 | from ray.data import Dataset as RayDataset 33 | 34 | RAY_INSTALLED = True 35 | except ModuleNotFoundError: 36 | RAY_INSTALLED = False 37 | from sklearn.base import clone 38 | 39 | from mlforecast.core import ( 40 | DateFeature, 41 | Differences, 42 | Freq, 43 | LagTransforms, 44 | Lags, 45 | TimeSeries, 46 | _name_models, 47 | ) 48 | from ..utils import single_split 49 | 50 | # %% ../../nbs/distributed.forecast.ipynb 6 51 | WindowInfo = namedtuple( 52 | "WindowInfo", ["n_windows", "window_size", "step_size", "i_window", "input_size"] 53 | ) 54 | 55 | # %% ../../nbs/distributed.forecast.ipynb 7 56 | class DistributedMLForecast: 57 | """Multi backend distributed pipeline""" 58 | 59 | def __init__( 60 | self, 61 | models, 62 | freq: Optional[Freq] = None, 63 | lags: Optional[Lags] = None, 64 | lag_transforms: Optional[LagTransforms] = None, 65 | date_features: Optional[Iterable[DateFeature]] = None, 66 | differences: Optional[Differences] = None, 67 | num_threads: int = 1, 68 | engine=None, 69 | num_partitions: Optional[int] = None, 70 | ): 71 | """Create distributed forecast object 72 | 73 | Parameters 74 | ---------- 75 | models : regressor or list of regressors 76 | Models that will be trained and used to compute the forecasts. 77 | freq : str or int, optional (default=None) 78 | Pandas offset alias, e.g. 'D', 'W-THU' or integer denoting the frequency of the series. 79 | lags : list of int, optional (default=None) 80 | Lags of the target to use as features. 
81 |         lag_transforms : dict of int to list of functions, optional (default=None)
82 |             Mapping of target lags to their transformations.
83 |         date_features : list of str or callable, optional (default=None)
84 |             Features computed from the dates. Can be pandas date attributes or functions that will take the dates as input.
85 |         differences : list of int, optional (default=None)
86 |             Differences to take of the target before computing the features. These are restored at the forecasting step.
87 |         num_threads : int (default=1)
88 |             Number of threads to use when computing the features.
89 |         engine : fugue execution engine, optional (default=None)
90 |             Dask Client, Spark Session, etc to use for the distributed computation.
91 |             If None will infer depending on the input type.
92 |         num_partitions : int, optional (default=None)
93 |             Number of data partitions to use. If None, the default partitions provided by the AnyDataFrame used
94 |             by the `fit` and `cross_validation` methods will be used. If a Ray
95 |             Dataset is provided and `num_partitions` is None, the partitioning
96 |             will be done by the `id_col`.
97 |         """
98 |         if not isinstance(models, dict) and not isinstance(models, list):
99 |             models = [models]
100 |         if isinstance(models, list):
101 |             model_names = _name_models([m.__class__.__name__ for m in models])
102 |             models_with_names = dict(zip(model_names, models))
103 |         else:
104 |             models_with_names = models
105 |         self.models = models_with_names
106 |         self._base_ts = TimeSeries(
107 |             freq, lags, lag_transforms, date_features, differences, num_threads
108 |         )
109 |         self.engine = engine
110 |         self.num_partitions = num_partitions
111 |
112 |     def __repr__(self) -> str:
113 |         return (
114 |             f'{self.__class__.__name__}(models=[{", ".join(self.models.keys())}], '
115 |             f"freq={self._base_ts.freq}, "
116 |             f"lag_features={list(self._base_ts.transforms.keys())}, "
117 |             f"date_features={self._base_ts.date_features}, "
118 |             f"num_threads={self._base_ts.num_threads}, "
119 |             f"engine={self.engine})"
120 |         )
121 |
122 |     @staticmethod
123 |     def _preprocess_partition(
124 |         part: pd.DataFrame,
125 |         base_ts: TimeSeries,
126 |         id_col: str,
127 |         time_col: str,
128 |         target_col: str,
129 |         static_features: Optional[List[str]] = None,
130 |         dropna: bool = True,
131 |         keep_last_n: Optional[int] = None,
132 |         window_info: Optional[WindowInfo] = None,
133 |         fit_ts_only: bool = False,
134 |     ) -> List[List[Any]]:
135 |         ts = copy.deepcopy(base_ts)
136 |         if fit_ts_only:
137 |             ts._fit(
138 |                 part,
139 |                 id_col=id_col,
140 |                 time_col=time_col,
141 |                 target_col=target_col,
142 |                 static_features=static_features,
143 |                 keep_last_n=keep_last_n,
144 |             )
145 |             return [
146 |                 [
147 |                     cloudpickle.dumps(ts),
148 |                     cloudpickle.dumps(None),
149 |                     cloudpickle.dumps(None),
150 |                 ]
151 |             ]
152 |         if window_info is None:
153 |             train = part
154 |             valid = None
155 |         else:
156 |             max_dates = part.groupby(id_col, observed=True)[time_col].transform("max")
157 |             cutoffs, train_mask, valid_mask = single_split(
158 |                 part,
159 |                 i_window=window_info.i_window,
160 |                 n_windows=window_info.n_windows,
161 |                 window_size=window_info.window_size,
162 |                 id_col=id_col,
163 |                 time_col=time_col,
164 |                 freq=base_ts.freq,
165 |                 max_dates=max_dates,
166 |                 step_size=window_info.step_size,
167 |                 input_size=window_info.input_size,
168 |             )
169 |             train = part[train_mask]
170 |             valid_keep_cols = part.columns
171 |             if static_features is not None:
172 |                 valid_keep_cols = valid_keep_cols.drop(static_features)
173 |             valid = part.loc[valid_mask, valid_keep_cols].merge(cutoffs, on=id_col)
174 |         transformed = ts.fit_transform(
175 |             train,
176 |             id_col=id_col,
177 |             time_col=time_col,
178 |             target_col=target_col,
179 |             static_features=static_features,
180 |             dropna=dropna,
181 |             keep_last_n=keep_last_n,
182 |         )
183 |         return [
184 |             [
185 |                 cloudpickle.dumps(ts),
186 |                 cloudpickle.dumps(transformed),
187 |                 cloudpickle.dumps(valid),
188 |             ]
189 |         ]
190 |
191 |     @staticmethod
192 |     def _retrieve_df(items: List[List[Any]]) -> Iterable[pd.DataFrame]:
193 |         for _, serialized_train, _ in items:
194 |             yield cloudpickle.loads(serialized_train)
195 |
196 |     def _preprocess_partitions(
197 |         self,
198 |         data: fugue.AnyDataFrame,
199 |         id_col: str,
200 |         time_col: str,
201 |         target_col: str,
202 |         static_features: Optional[List[str]] = None,
203 |         dropna: bool = True,
204 |         keep_last_n: Optional[int] = None,
205 |         window_info: Optional[WindowInfo] = None,
206 |         fit_ts_only: bool = False,
207 |     ) -> List[Any]:
208 |         if self.num_partitions:
209 |             partition = dict(by=id_col, num=self.num_partitions, algo="coarse")
210 |         elif isinstance(
211 |             data, RayDataset
212 |         ):  # num partitions is None but data is a RayDataset
213 |             # We need to add this because
214 |             # currently ray doesn't support partitioning a Dataset
215 |             # based on a column.
216 |             # If a Dataset is partitioned using `.repartition(num_partitions)`
217 |             # we will have awkward results.
218 |             partition = dict(by=id_col)
219 |         else:
220 |             partition = None
221 |         return fa.transform(
222 |             data,
223 |             DistributedMLForecast._preprocess_partition,
224 |             params={
225 |                 "base_ts": self._base_ts,
226 |                 "id_col": id_col,
227 |                 "time_col": time_col,
228 |                 "target_col": target_col,
229 |                 "static_features": static_features,
230 |                 "dropna": dropna,
231 |                 "keep_last_n": keep_last_n,
232 |                 "window_info": window_info,
233 |                 "fit_ts_only": fit_ts_only,
234 |             },
235 |             schema="ts:binary,train:binary,valid:binary",
236 |             engine=self.engine,
237 |             as_fugue=True,
238 |             partition=partition,
239 |         )
240 |
241 |     def _preprocess(
242 |         self,
243 |         data: fugue.AnyDataFrame,
244 |         id_col: str,
245 |         time_col: str,
246 |         target_col: str,
247 |         static_features: Optional[List[str]] = None,
248 |         dropna: bool = True,
249 |         keep_last_n: Optional[int] = None,
250 |         window_info: Optional[WindowInfo] = None,
251 |     ) -> fugue.AnyDataFrame:
252 |         self.id_col = id_col
253 |         self.time_col = time_col
254 |         self.target_col = target_col
255 |         self.static_features = static_features
256 |         self.dropna = dropna
257 |         self.keep_last_n = keep_last_n
258 |         self.partition_results = self._preprocess_partitions(
259 |             data=data,
260 |             id_col=id_col,
261 |             time_col=time_col,
262 |             target_col=target_col,
263 |             static_features=static_features,
264 |             dropna=dropna,
265 |             keep_last_n=keep_last_n,
266 |             window_info=window_info,
267 |         )
268 |         base_schema = str(fa.get_schema(data))
269 |         features_schema = ",".join(f"{feat}:double" for feat in self._base_ts.features)
270 |         res = fa.transform(
271 |             self.partition_results,
272 |             DistributedMLForecast._retrieve_df,
273 |             schema=f"{base_schema},{features_schema}",
274 |             engine=self.engine,
275 |         )
276 |         return fa.get_native_as_df(res)
277 |
278 |     def preprocess(
279 |         self,
280 |         data: fugue.AnyDataFrame,
281 |         id_col: str = "unique_id",
282 |         time_col: str = "ds",
283 |         target_col: str = "y",
284 |         static_features: Optional[List[str]] = None,
285 |         dropna: bool = True,
286 |         keep_last_n: Optional[int] = None,
287 |     ) -> fugue.AnyDataFrame:
288 |         """Add the features to `data`.
289 | 290 | Parameters 291 | ---------- 292 | data : dask or spark DataFrame. 293 | Series data in long format. 294 | id_col : str (default='unique_id') 295 | Column that identifies each serie. 296 | time_col : str (default='ds') 297 | Column that identifies each timestep, its values can be timestamps or integers. 298 | target_col : str (default='y') 299 | Column that contains the target. 300 | static_features : list of str, optional (default=None) 301 | Names of the features that are static and will be repeated when forecasting. 302 | dropna : bool (default=True) 303 | Drop rows with missing values produced by the transformations. 304 | keep_last_n : int, optional (default=None) 305 | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. 306 | 307 | Returns 308 | ------- 309 | result : same type as input 310 | data with added features. 311 | """ 312 | return self._preprocess( 313 | data, 314 | id_col=id_col, 315 | time_col=time_col, 316 | target_col=target_col, 317 | static_features=static_features, 318 | dropna=dropna, 319 | keep_last_n=keep_last_n, 320 | ) 321 | 322 | def _fit( 323 | self, 324 | data: fugue.AnyDataFrame, 325 | id_col: str, 326 | time_col: str, 327 | target_col: str, 328 | static_features: Optional[List[str]] = None, 329 | dropna: bool = True, 330 | keep_last_n: Optional[int] = None, 331 | window_info: Optional[WindowInfo] = None, 332 | ) -> "DistributedMLForecast": 333 | prep = self._preprocess( 334 | data, 335 | id_col=id_col, 336 | time_col=time_col, 337 | target_col=target_col, 338 | static_features=static_features, 339 | dropna=dropna, 340 | keep_last_n=keep_last_n, 341 | window_info=window_info, 342 | ) 343 | features = [ 344 | x 345 | for x in fa.get_column_names(prep) 346 | if x not in {id_col, time_col, target_col} 347 | ] 348 | self.models_ = {} 349 | if SPARK_INSTALLED and isinstance(data, SparkDataFrame): 350 | featurizer = VectorAssembler(inputCols=features, outputCol="features") 351 | train_data = featurizer.transform(prep)[target_col, "features"] 352 | for name, model in self.models.items(): 353 | trained_model = model._pre_fit(target_col).fit(train_data) 354 | self.models_[name] = model.extract_local_model(trained_model) 355 | elif DASK_INSTALLED and isinstance(data, dd.DataFrame): 356 | X, y = prep[features], prep[target_col] 357 | for name, model in self.models.items(): 358 | trained_model = clone(model).fit(X, y) 359 | self.models_[name] = trained_model.model_ 360 | elif RAY_INSTALLED and isinstance(data, RayDataset): 361 | X = RayDMatrix( 362 | prep.select_columns(cols=features + [target_col]), 363 | label=target_col, 364 | ) 365 | for name, model in self.models.items(): 366 | trained_model = clone(model).fit(X, y=None) 367 | self.models_[name] = trained_model.model_ 368 | else: 369 | raise NotImplementedError( 370 | "Only spark, dask, and ray engines are supported." 371 | ) 372 | return self 373 | 374 | def fit( 375 | self, 376 | data: fugue.AnyDataFrame, 377 | id_col: str = "unique_id", 378 | time_col: str = "ds", 379 | target_col: str = "y", 380 | static_features: Optional[List[str]] = None, 381 | dropna: bool = True, 382 | keep_last_n: Optional[int] = None, 383 | ) -> "DistributedMLForecast": 384 | """Apply the feature engineering and train the models. 385 | 386 | Parameters 387 | ---------- 388 | data : dask or spark DataFrame 389 | Series data in long format. 390 | id_col : str (default='unique_id') 391 | Column that identifies each serie. 
392 | time_col : str (default='ds') 393 | Column that identifies each timestep, its values can be timestamps or integers. 394 | target_col : str (default='y') 395 | Column that contains the target. 396 | static_features : list of str, optional (default=None) 397 | Names of the features that are static and will be repeated when forecasting. 398 | dropna : bool (default=True) 399 | Drop rows with missing values produced by the transformations. 400 | keep_last_n : int, optional (default=None) 401 | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. 402 | 403 | Returns 404 | ------- 405 | self : DistributedMLForecast 406 | Forecast object with series values and trained models. 407 | """ 408 | return self._fit( 409 | data, 410 | id_col=id_col, 411 | time_col=time_col, 412 | target_col=target_col, 413 | static_features=static_features, 414 | dropna=dropna, 415 | keep_last_n=keep_last_n, 416 | ) 417 | 418 | @staticmethod 419 | def _predict( 420 | items: List[List[Any]], 421 | models, 422 | horizon, 423 | dynamic_dfs=None, 424 | before_predict_callback=None, 425 | after_predict_callback=None, 426 | ) -> Iterable[pd.DataFrame]: 427 | for serialized_ts, _, serialized_valid in items: 428 | valid = cloudpickle.loads(serialized_valid) 429 | ts = cloudpickle.loads(serialized_ts) 430 | if valid is not None: 431 | dynamic_features = valid.columns.drop( 432 | [ts.id_col, ts.time_col, ts.target_col] 433 | ) 434 | if not dynamic_features.empty: 435 | dynamic_dfs = [valid.drop(columns=ts.target_col)] 436 | res = ts.predict( 437 | models=models, 438 | horizon=horizon, 439 | dynamic_dfs=dynamic_dfs, 440 | before_predict_callback=before_predict_callback, 441 | after_predict_callback=after_predict_callback, 442 | ) 443 | if valid is not None: 444 | res = res.merge(valid, how="left") 445 | yield res 446 | 447 | def _get_predict_schema(self) -> str: 448 | model_names = self.models.keys() 449 | models_schema = ",".join(f"{model_name}:double" for model_name in model_names) 450 | schema = f"{self.id_col}:string,{self.time_col}:datetime," + models_schema 451 | return schema 452 | 453 | def predict( 454 | self, 455 | horizon: int, 456 | dynamic_dfs: Optional[List[pd.DataFrame]] = None, 457 | before_predict_callback: Optional[Callable] = None, 458 | after_predict_callback: Optional[Callable] = None, 459 | new_data: Optional[fugue.AnyDataFrame] = None, 460 | ) -> fugue.AnyDataFrame: 461 | """Compute the predictions for the next `horizon` steps. 462 | 463 | Parameters 464 | ---------- 465 | horizon : int 466 | Number of periods to predict. 467 | dynamic_dfs : list of pandas DataFrame, optional (default=None) 468 | Future values of the dynamic features, e.g. prices. 469 | before_predict_callback : callable, optional (default=None) 470 | Function to call on the features before computing the predictions. 471 | This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure. 472 | The series identifier is on the index. 473 | after_predict_callback : callable, optional (default=None) 474 | Function to call on the predictions before updating the targets. 475 | This function will take a pandas Series with the predictions and should return another one with the same structure. 476 | The series identifier is on the index. 477 | new_data : dask or spark DataFrame, optional (default=None) 478 | Series data of new observations for which forecasts are to be generated. 
479 | This dataframe should have the same structure as the one used to fit the model, including any features and time series data. 480 | If `new_data` is not None, the method will generate forecasts for the new observations. 481 | 482 | Returns 483 | ------- 484 | result : dask, spark or ray DataFrame 485 | Predictions for each serie and timestep, with one column per model. 486 | """ 487 | if new_data is not None: 488 | partition_results = self._preprocess_partitions( 489 | data=new_data, 490 | id_col=self.id_col, 491 | time_col=self.time_col, 492 | target_col=self.target_col, 493 | static_features=self.static_features, 494 | dropna=self.dropna, 495 | keep_last_n=self.keep_last_n, 496 | fit_ts_only=True, 497 | ) 498 | else: 499 | partition_results = self.partition_results 500 | schema = self._get_predict_schema() 501 | res = fa.transform( 502 | partition_results, 503 | DistributedMLForecast._predict, 504 | params={ 505 | "models": self.models_, 506 | "horizon": horizon, 507 | "dynamic_dfs": dynamic_dfs, 508 | "before_predict_callback": before_predict_callback, 509 | "after_predict_callback": after_predict_callback, 510 | }, 511 | schema=schema, 512 | engine=self.engine, 513 | ) 514 | return fa.get_native_as_df(res) 515 | 516 | def cross_validation( 517 | self, 518 | data: fugue.AnyDataFrame, 519 | n_windows: int, 520 | window_size: int, 521 | id_col: str = "unique_id", 522 | time_col: str = "ds", 523 | target_col: str = "y", 524 | step_size: Optional[int] = None, 525 | static_features: Optional[List[str]] = None, 526 | dropna: bool = True, 527 | keep_last_n: Optional[int] = None, 528 | refit: bool = True, 529 | before_predict_callback: Optional[Callable] = None, 530 | after_predict_callback: Optional[Callable] = None, 531 | input_size: Optional[int] = None, 532 | ) -> fugue.AnyDataFrame: 533 | """Perform time series cross validation. 534 | Creates `n_windows` splits where each window has `window_size` test periods, 535 | trains the models, computes the predictions and merges the actuals. 536 | 537 | Parameters 538 | ---------- 539 | data : dask, spark or ray DataFrame 540 | Series data in long format. 541 | n_windows : int 542 | Number of windows to evaluate. 543 | window_size : int 544 | Number of test periods in each window. 545 | id_col : str (default='unique_id') 546 | Column that identifies each serie. 547 | time_col : str (default='ds') 548 | Column that identifies each timestep, its values can be timestamps or integers. 549 | target_col : str (default='y') 550 | Column that contains the target. 551 | step_size : int, optional (default=None) 552 | Step size between each cross validation window. If None it will be equal to `window_size`. 553 | static_features : list of str, optional (default=None) 554 | Names of the features that are static and will be repeated when forecasting. 555 | dropna : bool (default=True) 556 | Drop rows with missing values produced by the transformations. 557 | keep_last_n : int, optional (default=None) 558 | Keep only these many records from each serie for the forecasting step. Can save time and memory if your features allow it. 559 | refit : bool (default=True) 560 | Retrain model for each cross validation window. 561 | If False, the models are trained at the beginning and then used to predict each window. 562 | before_predict_callback : callable, optional (default=None) 563 | Function to call on the features before computing the predictions. 
564 |             This function will take the input dataframe that will be passed to the model for predicting and should return a dataframe with the same structure.
565 |             The series identifier is on the index.
566 |         after_predict_callback : callable, optional (default=None)
567 |             Function to call on the predictions before updating the targets.
568 |             This function will take a pandas Series with the predictions and should return another one with the same structure.
569 |             The series identifier is on the index.
570 |         input_size : int, optional (default=None)
571 |             Maximum training samples per serie in each window. If None, will use an expanding window.
572 |
573 |         Returns
574 |         -------
575 |         result : dask, spark or ray DataFrame
576 |             Predictions for each window with the series id, timestamp, target value and predictions from each model.
577 |         """
578 |         self.cv_models_ = []
579 |         results = []
580 |         for i in range(n_windows):
581 |             window_info = WindowInfo(n_windows, window_size, step_size, i, input_size)
582 |             if refit or i == 0:
583 |                 self._fit(
584 |                     data,
585 |                     id_col=id_col,
586 |                     time_col=time_col,
587 |                     target_col=target_col,
588 |                     static_features=static_features,
589 |                     dropna=dropna,
590 |                     keep_last_n=keep_last_n,
591 |                     window_info=window_info,
592 |                 )
593 |                 self.cv_models_.append(self.models_)
594 |                 partition_results = self.partition_results
595 |             elif not refit:
596 |                 partition_results = self._preprocess_partitions(
597 |                     data=data,
598 |                     id_col=id_col,
599 |                     time_col=time_col,
600 |                     target_col=target_col,
601 |                     static_features=static_features,
602 |                     dropna=dropna,
603 |                     keep_last_n=keep_last_n,
604 |                     window_info=window_info,
605 |                 )
606 |             schema = (
607 |                 self._get_predict_schema()
608 |                 + f",cutoff:datetime,{self.target_col}:double"
609 |             )
610 |             preds = fa.transform(
611 |                 partition_results,
612 |                 DistributedMLForecast._predict,
613 |                 params={
614 |                     "models": self.models_,
615 |                     "horizon": window_size,
616 |                     "before_predict_callback": before_predict_callback,
617 |                     "after_predict_callback": after_predict_callback,
618 |                 },
619 |                 schema=schema,
620 |                 engine=self.engine,
621 |             )
622 |             results.append(fa.get_native_as_df(preds))
623 |         if len(results) == 1:
624 |             return results[0]
625 |         if len(results) == 2:
626 |             return fa.union(results[0], results[1])
627 |         return fa.union(results[0], results[1], *results[2:])
628 |
--------------------------------------------------------------------------------
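
Usage sketch for the `DistributedMLForecast` API defined above. This is not part of the repository: it is a minimal example assuming a local Dask cluster created with `dask.distributed.Client`, the `DaskLGBMForecast` wrapper from `mlforecast/distributed/models/dask/lgb.py`, and the `generate_daily_series` helper from `mlforecast/utils.py`. The lags, date features, horizon and number of windows are illustrative values only.

# Minimal sketch: fit, predict and cross-validate with DistributedMLForecast on Dask.
# Assumes dask[distributed], lightgbm and mlforecast are installed; not an excerpt from the repo.
import dask.dataframe as dd
from dask.distributed import Client

from mlforecast.distributed import DistributedMLForecast
from mlforecast.distributed.models.dask.lgb import DaskLGBMForecast
from mlforecast.utils import generate_daily_series

client = Client(n_workers=2)  # local cluster; the engine is inferred from the Dask collection

# Toy panel data in long format with columns: unique_id, ds, y.
series = generate_daily_series(n_series=10)
ddf = dd.from_pandas(series, npartitions=2)

fcst = DistributedMLForecast(
    models=[DaskLGBMForecast()],  # distributed LightGBM wrapper defined in this repo
    freq="D",                     # daily frequency
    lags=[1, 7],                  # lagged target values used as features
    date_features=["dayofweek"],  # calendar feature computed from the timestamps
)
fcst.fit(ddf, id_col="unique_id", time_col="ds", target_col="y")

preds = fcst.predict(horizon=7).compute()  # Dask DataFrame -> pandas
cv_res = fcst.cross_validation(ddf, n_windows=2, window_size=7).compute()

The same pipeline runs on Spark or Ray by swapping the input collection and the model wrapper (e.g. `SparkLGBMForecast` or `RayLGBMForecast`); `fit`, `predict` and `cross_validation` return a DataFrame of the same backend as the input.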