├── .gitattributes ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── config.yml │ ├── documentation-issue.yml │ └── feature-request.yml ├── dependabot.yml ├── pull_request_template.md ├── release-drafter.yml └── workflows │ ├── build-docs.yaml │ ├── ci.yaml │ ├── lint.yaml │ ├── no-response.yaml │ ├── release-drafter.yml │ └── release.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── action_files ├── clean_nbs ├── lint └── remove_logs_cells ├── mlforecast ├── __init__.py ├── _modidx.py ├── auto.py ├── callbacks.py ├── compat.py ├── core.py ├── distributed │ ├── __init__.py │ ├── forecast.py │ └── models │ │ ├── __init__.py │ │ ├── dask │ │ ├── __init__.py │ │ ├── lgb.py │ │ └── xgb.py │ │ ├── ray │ │ ├── __init__.py │ │ ├── lgb.py │ │ └── xgb.py │ │ └── spark │ │ ├── __init__.py │ │ ├── lgb.py │ │ └── xgb.py ├── feature_engineering.py ├── flavor.py ├── forecast.py ├── grouped_array.py ├── lag_transforms.py ├── lgb_cv.py ├── optimization.py ├── py.typed ├── target_transforms.py └── utils.py ├── nbs ├── .gitignore ├── _quarto.yml ├── auto.ipynb ├── callbacks.ipynb ├── compat.ipynb ├── core.ipynb ├── distributed.forecast.ipynb ├── distributed.models.dask.lgb.ipynb ├── distributed.models.dask.xgb.ipynb ├── distributed.models.ray.lgb.ipynb ├── distributed.models.ray.xgb.ipynb ├── distributed.models.spark.lgb.ipynb ├── distributed.models.spark.xgb.ipynb ├── docs │ ├── getting-started │ │ ├── end_to_end_walkthrough.ipynb │ │ ├── install.ipynb │ │ ├── quick_start_distributed.ipynb │ │ └── quick_start_local.ipynb │ ├── how-to-guides │ │ ├── analyzing_models.ipynb │ │ ├── cross_validation.ipynb │ │ ├── custom_date_features.ipynb │ │ ├── custom_training.ipynb │ │ ├── exogenous_features.ipynb │ │ ├── hyperparameter_optimization.ipynb │ │ ├── lag_transforms_guide.ipynb │ │ ├── mlflow.ipynb │ │ ├── one_model_per_horizon.ipynb │ │ ├── predict_callbacks.ipynb │ │ ├── predict_subset.ipynb │ │ ├── prediction_intervals.ipynb │ │ ├── sample_weights.ipynb │ │ ├── sklearn_pipelines.ipynb │ │ ├── target_transforms_guide.ipynb │ │ ├── training_with_numpy.ipynb │ │ ├── transfer_learning.ipynb │ │ └── transforming_exog.ipynb │ └── tutorials │ │ ├── electricity_load_forecasting.ipynb │ │ ├── electricity_peak_forecasting.ipynb │ │ └── prediction_intervals_in_forecasting_models.ipynb ├── favicon.png ├── feature_engineering.ipynb ├── figs │ ├── cross_validation__predictions.png │ ├── cross_validation__series.png │ ├── electricity_peak_forecasting__eda.png │ ├── electricity_peak_forecasting__predicted_peak.png │ ├── end_to_end_walkthrough__cv.png │ ├── end_to_end_walkthrough__differences.png │ ├── end_to_end_walkthrough__eda.png │ ├── end_to_end_walkthrough__final_forecast.png │ ├── end_to_end_walkthrough__lgbcv.png │ ├── end_to_end_walkthrough__predictions.png │ ├── forecast__cross_validation.png │ ├── forecast__cross_validation_intervals.png │ ├── forecast__ercot.png │ ├── forecast__predict.png │ ├── forecast__predict_intervals.png │ ├── forecast__predict_intervals_window_size_1.png │ ├── index.png │ ├── load_forecasting__differences.png │ ├── load_forecasting__prediction_intervals.png │ ├── load_forecasting__predictions.png │ ├── load_forecasting__raw.png │ ├── load_forecasting__transformed.png │ ├── logo.png │ ├── prediction_intervals__eda.png │ ├── prediction_intervals__knn.png │ ├── prediction_intervals__lasso.png │ ├── prediction_intervals__lr.png │ ├── prediction_intervals__mlp.png │ ├── 
prediction_intervals__ridge.png │ ├── prediction_intervals_in_forecasting_models__autocorrelation.png │ ├── prediction_intervals_in_forecasting_models__eda.png │ ├── prediction_intervals_in_forecasting_models__plot_forecasting_intervals.png │ ├── prediction_intervals_in_forecasting_models__plot_residual_model.png │ ├── prediction_intervals_in_forecasting_models__plot_values.png │ ├── prediction_intervals_in_forecasting_models__seasonal_decompose_aditive.png │ ├── prediction_intervals_in_forecasting_models__seasonal_decompose_multiplicative.png │ ├── prediction_intervals_in_forecasting_models__train_test.png │ ├── quick_start_local__eda.png │ ├── quick_start_local__predictions.png │ ├── target_transforms__diff1.png │ ├── target_transforms__diff2.png │ ├── target_transforms__eda.png │ ├── target_transforms__log.png │ ├── target_transforms__log_diffs.png │ ├── target_transforms__minmax.png │ ├── target_transforms__standardized.png │ ├── target_transforms__zeros.png │ ├── transfer_learning__eda.png │ └── transfer_learning__forecast.png ├── forecast.ipynb ├── grouped_array.ipynb ├── index.ipynb ├── lag_transforms.ipynb ├── lgb_cv.ipynb ├── mint.json ├── nbdev.yml ├── optimization.ipynb ├── sidebar.yml ├── styles.css ├── target_transforms.ipynb └── utils.ipynb ├── pyproject.toml ├── settings.ini ├── setup.py └── tests ├── test_m4.py └── test_pipeline.py /.gitattributes: -------------------------------------------------------------------------------- 1 | nbs/** linguist-documentation 2 | *.ipynb merge=nbdev-merge 3 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @jmoralez 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | title: "[] " 3 | description: Problems and issues with code of the library 4 | labels: [bug] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for reporting the problem! 10 | Please make sure what you are reporting is a bug with reproducible steps. To ask questions 11 | or share ideas, please post on our [Slack community](https://join.slack.com/t/nixtlacommunity/shared_invite/zt-1h77esh5y-iL1m8N0F7qV1HmH~0KYeAQ) instead. 12 | 13 | - type: textarea 14 | attributes: 15 | label: What happened + What you expected to happen 16 | description: Describe 1. the bug 2. expected behavior 3. useful information (e.g., logs) 17 | placeholder: > 18 | Please provide the context in which the problem occurred and explain what happened. Further, 19 | please also explain why you think the behaviour is erroneous. It is extremely helpful if you can 20 | copy and paste the fragment of logs showing the exact error messages or wrong behaviour here. 21 | 22 | **NOTE**: please copy and paste texts instead of taking screenshots of them for easy future search. 23 | validations: 24 | required: true 25 | 26 | - type: textarea 27 | attributes: 28 | label: Versions / Dependencies 29 | description: Please specify the versions of the library, Python, OS, and other libraries that are used. 30 | placeholder: > 31 | Please specify the versions of dependencies. 32 | validations: 33 | required: true 34 | 35 | - type: textarea 36 | attributes: 37 | label: Reproduction script 38 | description: > 39 | Please provide a reproducible script. 
Providing a narrow reproduction (minimal / no external dependencies) will 40 | help us triage and address issues in a timely manner! 41 | placeholder: > 42 | Please provide a short code snippet (less than 50 lines if possible) that can be copy-pasted to 43 | reproduce the issue. The snippet should have **no external library dependencies** 44 | (i.e., use fake or mock data / environments). 45 | 46 | **NOTE**: If the code snippet cannot be run by itself, the issue will be marked as "needs-repro-script" 47 | until the repro instructions are updated. 48 | validations: 49 | required: true 50 | 51 | - type: dropdown 52 | attributes: 53 | label: Issue Severity 54 | description: | 55 | How does this issue affect your experience as a user? 56 | multiple: false 57 | options: 58 | - "Low: It annoys or frustrates me." 59 | - "Medium: It is a significant difficulty but I can work around it." 60 | - "High: It blocks me from completing my task." 61 | validations: 62 | required: false 63 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Ask a question or get support 4 | url: https://join.slack.com/t/nixtlacommunity/shared_invite/zt-1h77esh5y-iL1m8N0F7qV1HmH~0KYeAQ 5 | about: Ask a question or request support for using a library of the nixtlaverse 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation-issue.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | title: "[] " 3 | description: Report an issue with the library documentation 4 | labels: [documentation] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: Thank you for helping us improve the library documentation! 9 | 10 | - type: textarea 11 | attributes: 12 | label: Description 13 | description: | 14 | Tell us about the change you'd like to see. For example, "I'd like to 15 | see more examples of how to use `cross_validation`." 16 | validations: 17 | required: true 18 | 19 | - type: textarea 20 | attributes: 21 | label: Link 22 | description: | 23 | If the problem is related to an existing section, please add a link to 24 | the section. 25 | validations: 26 | required: false 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: Library feature request 2 | description: Suggest an idea for the project 3 | title: "[] " 4 | labels: [enhancement, feature] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for finding the time to propose a new feature! 10 | We really appreciate the community efforts to improve the nixtlaverse. 11 | 12 | - type: textarea 13 | attributes: 14 | label: Description 15 | description: A short description of your feature 16 | 17 | - type: textarea 18 | attributes: 19 | label: Use case 20 | description: > 21 | Describe the use case of your feature request. It will help us understand and 22 | prioritize the feature request. 23 | placeholder: > 24 | Rather than telling us how you might implement this feature, try to take a 25 | step back and describe what you are trying to achieve.
26 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: github-actions 4 | directory: / 5 | schedule: 6 | interval: weekly 7 | groups: 8 | ci-dependencies: 9 | patterns: ["*"] 10 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | ## Description 8 | 9 | 10 | Checklist: 11 | - [ ] This PR has a meaningful title and a clear description. 12 | - [ ] The tests pass. 13 | - [ ] All linting tasks pass. 14 | - [ ] The notebooks are clean. -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: 'v$NEXT_PATCH_VERSION' 2 | tag-template: 'v$NEXT_PATCH_VERSION' 3 | categories: 4 | - title: 'New Features' 5 | label: 'feature' 6 | - title: 'Breaking Change' 7 | label: 'breaking change' 8 | - title: 'Bug Fixes' 9 | label: 'fix' 10 | - title: 'Documentation' 11 | label: 'documentation' 12 | - title: 'Dependencies' 13 | label: 'dependencies' 14 | - title: 'Enhancement' 15 | label: 'enhancement' 16 | change-template: '- $TITLE @$AUTHOR (#$NUMBER)' 17 | template: | 18 | ## Changes 19 | $CHANGES 20 | -------------------------------------------------------------------------------- /.github/workflows/build-docs.yaml: -------------------------------------------------------------------------------- 1 | name: "build-docs" 2 | on: 3 | push: 4 | branches: ["main"] 5 | pull_request: 6 | branches: ["main"] 7 | workflow_dispatch: 8 | 9 | jobs: 10 | build-docs: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Clone repo 14 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 15 | - name: Clone docs repo 16 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 17 | with: 18 | repository: Nixtla/docs 19 | ref: scripts 20 | path: docs-scripts 21 | - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # 5.4.0 22 | with: 23 | python-version: "3.10" 24 | - name: Build docs 25 | run: | 26 | pip install uv && uv pip install --system ".[all]" 27 | mkdir nbs/_extensions 28 | cp -r docs-scripts/mintlify/ nbs/_extensions/ 29 | python docs-scripts/update-quarto.py 30 | nbdev_docs 31 | - name: Apply final formats 32 | run: bash ./docs-scripts/docs-final-formatting.bash 33 | - name: Copy over necessary assets 34 | run: | 35 | cp nbs/mint.json _docs/mint.json 36 | cp docs-scripts/imgs/* _docs/ 37 | - name: Configure redirects for gh-pages 38 | run: python docs-scripts/configure-redirects.py mlforecast 39 | - name: Deploy to Mintlify Docs 40 | if: github.event_name == 'push' 41 | uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 42 | with: 43 | github_token: ${{ secrets.GITHUB_TOKEN }} 44 | publish_branch: docs 45 | publish_dir: ./_docs 46 | user_name: github-actions[bot] 47 | user_email: 41898282+github-actions[bot]@users.noreply.github.com 48 | - name: Trigger mintlify workflow 49 | if: github.event_name == 'push' 50 | uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 51 | with: 52 | github-token: ${{ secrets.DOCS_WORKFLOW_TOKEN }} 53 | script: | 54 | await github.rest.actions.createWorkflowDispatch({ 55 | owner: 
'nixtla', 56 | repo: 'docs', 57 | workflow_id: 'mintlify-action.yml', 58 | ref: 'main', 59 | }); 60 | - name: Deploy to Github Pages 61 | if: github.event_name == 'push' 62 | uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 63 | with: 64 | github_token: ${{ secrets.GITHUB_TOKEN }} 65 | publish_branch: gh-pages 66 | publish_dir: ./gh-pages 67 | user_name: github-actions[bot] 68 | user_email: 41898282+github-actions[bot]@users.noreply.github.com 69 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | workflow_dispatch: 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.ref }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | run-all-tests: 16 | runs-on: ubuntu-latest 17 | timeout-minutes: 30 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | python-version: ["3.9", "3.10", "3.11", "3.12"] 22 | env: 23 | AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID_NIXTLA_TMP }} 24 | AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY_NIXTLA_TMP }} 25 | steps: 26 | - name: Clone repo 27 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 28 | 29 | - name: Set up environment 30 | uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # 5.4.0 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | 34 | - name: Install the library 35 | run: pip install uv && uv pip install --system ".[all]" "numba>=0.60" "scikit-learn<1.6" shap window-ops 36 | 37 | - name: Run all tests 38 | run: nbdev_test --n_workers 0 --do_print --timing --skip_file_re 'electricity' --flags 'polars shap window_ops' 39 | 40 | run-local-tests: 41 | runs-on: ${{ matrix.os }} 42 | timeout-minutes: 30 43 | strategy: 44 | fail-fast: false 45 | matrix: 46 | os: [macos-13, macos-14, windows-latest] 47 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 48 | steps: 49 | - name: Clone repo 50 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 51 | 52 | - name: Set up environment 53 | uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # 5.4.0 54 | with: 55 | python-version: ${{ matrix.python-version }} 56 | 57 | - name: Install the library 58 | run: pip install uv && uv pip install --system ".[dev]" 59 | 60 | - name: Install OpenMP 61 | if: startsWith(matrix.os, 'macos') 62 | run: brew install libomp 63 | 64 | - name: Run local tests 65 | run: nbdev_test --n_workers 0 --do_print --timing --skip_file_re "(distributed|electricity)" --flags 'polars' 66 | 67 | check-deps: 68 | runs-on: ubuntu-latest 69 | steps: 70 | - name: Clone repo 71 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 72 | 73 | - name: Set up python 74 | uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # 5.4.0 75 | with: 76 | python-version: "3.10" 77 | 78 | - name: Install forecast notebook dependencies 79 | run: pip install uv && uv pip install --system . 
lightgbm matplotlib nbdev pyarrow xgboost 80 | 81 | - name: Run forecast notebook 82 | run: nbdev_test --path nbs/forecast.ipynb 83 | 84 | efficiency-tests: 85 | runs-on: ubuntu-latest 86 | steps: 87 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 88 | 89 | - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # 5.4.0 90 | with: 91 | python-version: "3.12" 92 | 93 | - name: Install dependencies 94 | run: pip install uv && uv pip install --system . pytest-codspeed pytest-xdist 95 | 96 | - name: Run benchmarks 97 | uses: CodSpeedHQ/action@63ae6025a0ffee97d7736a37c9192dbd6ed4e75f # 3.4.0 98 | with: 99 | token: ${{ secrets.CODESPEED_TOKEN }} 100 | run: pytest tests/test_pipeline.py --codspeed -n 2 101 | 102 | performance-tests: 103 | runs-on: ubuntu-latest 104 | steps: 105 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 106 | 107 | - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # 5.4.0 108 | with: 109 | python-version: "3.10" 110 | 111 | - name: Install dependencies 112 | run: pip install uv && uv pip install --system . datasetsforecast lightgbm pytest 113 | 114 | - name: Run m4 performance tests 115 | run: pytest tests/test_m4.py 116 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Clone repo 14 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 15 | 16 | - name: Set up python 17 | uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # 5.4.0 18 | with: 19 | python-version: '3.10' 20 | cache: 'pip' 21 | 22 | - name: Install dependencies 23 | run: pip install black 'nbdev<2.3.26' pre-commit 24 | 25 | - name: Run pre-commit 26 | run: pre-commit run --files mlforecast/* 27 | -------------------------------------------------------------------------------- /.github/workflows/no-response.yaml: -------------------------------------------------------------------------------- 1 | name: No Response Bot 2 | 3 | on: 4 | issue_comment: 5 | types: [created] 6 | schedule: 7 | - cron: '0 4 * * *' 8 | 9 | jobs: 10 | noResponse: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: lee-dohm/no-response@9bb0a4b5e6a45046f00353d5de7d90fb8bd773bb # v0.5.0 14 | with: 15 | closeComment: > 16 | This issue has been automatically closed because it has been awaiting a response for too long. 17 | When you have time to work with the maintainers to resolve this issue, please post a new comment and it will be re-opened. 18 | If the issue has been locked for editing by the time you return to it, please open a new issue and reference this one.
19 | daysUntilClose: 30 20 | responseRequiredLabel: awaiting response 21 | token: ${{ github.token }} 22 | -------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name: Release Drafter 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | update_release_draft: 13 | permissions: 14 | contents: write 15 | pull-requests: read 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: release-drafter/release-drafter@b1476f6e6eb133afa41ed8589daba6dc69b4d3f5 # v6.1.0 19 | env: 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | defaults: 9 | run: 10 | shell: bash -l {0} 11 | 12 | jobs: 13 | release: 14 | if: github.repository == 'Nixtla/mlforecast' 15 | runs-on: ubuntu-latest 16 | permissions: 17 | id-token: write 18 | steps: 19 | - name: Clone repo 20 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 21 | 22 | - name: Set up python 23 | uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # 5.4.0 24 | with: 25 | python-version: '3.10' 26 | 27 | - name: Install build dependencies 28 | run: python -m pip install build wheel 29 | 30 | - name: Build distributions 31 | run: python -m build -sw 32 | 33 | - name: Publish package to PyPI 34 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.bak 2 | .gitattributes 3 | .last_checked 4 | .gitconfig 5 | *.bak 6 | *.log 7 | *~ 8 | ~* 9 | _tmp* 10 | tmp* 11 | tags 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # dotenv 95 | .env 96 | 97 | # virtualenv 98 | .venv 99 | venv/ 100 | ENV/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | 115 | .vscode 116 | *.swp 117 | 118 | # osx generated files 119 | .DS_Store 120 | .DS_Store? 121 | .Trashes 122 | ehthumbs.db 123 | Thumbs.db 124 | .idea 125 | 126 | # pytest 127 | .pytest_cache 128 | 129 | # tools/trust-doc-nbs 130 | docs_src/.last_checked 131 | 132 | # symlinks to fastai 133 | docs_src/fastai 134 | tools/fastai 135 | 136 | # link checker 137 | checklink/cookies.txt 138 | 139 | # .gitconfig is now autogenerated 140 | .gitconfig 141 | 142 | # dask 143 | dask-worker-space 144 | 145 | # gemfiles 146 | Gemfile* 147 | 148 | # jekyll 149 | .jekyll-cache 150 | 151 | # series files 152 | /**/data/ 153 | 154 | # nbdev 155 | nbs/_docs 156 | _proc/ 157 | index_files 158 | _docs 159 | nbs/docs/**/*.xls* 160 | nbs/_extensions 161 | catboost_info 162 | *.pkl 163 | mlruns 164 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/.gitmodules -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: true 2 | 3 | repos: 4 | - repo: local 5 | hooks: 6 | - id: clean_nbs 7 | name: Clean notebooks 8 | entry: sh action_files/clean_nbs 9 | language: system 10 | - repo: https://github.com/fastai/nbdev 11 | rev: 2.2.10 12 | hooks: 13 | - id: nbdev_export 14 | - repo: https://github.com/astral-sh/ruff-pre-commit 15 | rev: v0.2.1 16 | hooks: 17 | - id: ruff 18 | - repo: https://github.com/pre-commit/mirrors-mypy 19 | rev: v1.8.0 20 | hooks: 21 | - id: mypy 22 | args: [--ignore-missing-imports] 23 | exclude: 'setup.py' 24 | additional_dependencies: ['types-PyYAML'] 25 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, 
race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | ops@nixtla.io. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | ## Did you find a bug? 4 | 5 | * Ensure the bug was not already reported by searching on GitHub under Issues. 6 | * If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring. 7 | * Be sure to add the complete error messages. 8 | 9 | ## Do you have a feature request? 10 | 11 | * Ensure that it hasn't already been implemented in the `main` branch of the repository and that there isn't an open issue requesting it yet. 12 | * Open a new issue and make sure to describe it clearly, mention how it improves the project and why it's useful. 13 | 14 | ## Do you want to fix a bug or implement a feature? 15 | 16 | Bug fixes and features are added through pull requests (PRs). 17 | 18 | ## PR submission guidelines 19 | 20 | * Keep each PR focused. While it may be more convenient, do not combine several unrelated fixes in a single PR. Create as many branches as needed to keep each PR focused. 21 | * Ensure that your PR includes a test that fails without your patch, and passes with it. 22 | * Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable. 23 | * Do not mix style changes/fixes with "functional" changes.
Such PRs are very difficult to review and will most likely get rejected. 24 | * Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can. 25 | * Do not turn an already submitted PR into your development playground. If, after you submit a PR, you discover that more work is needed, close the PR, do the required work, and then submit a new PR. Otherwise, each of your commits requires attention from the maintainers of the project. 26 | * If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it'll take many commits to complete the requests, it's probably best to close the PR, do the work, and then submit it again. Use common sense when choosing one approach over the other. 27 | 28 | ### Local setup for working on a PR 29 | 30 | #### 1. Clone the repository 31 | * HTTPS: `git clone https://github.com/Nixtla/mlforecast.git` 32 | * SSH: `git clone git@github.com:Nixtla/mlforecast.git` 33 | * GitHub CLI: `gh repo clone Nixtla/mlforecast` 34 | 35 | #### 2. Install the required dependencies for development 36 | ##### conda/mamba 37 | The repo comes with an `environment.yml` file, which contains the libraries needed to run all the tests (please note that the distributed interface is only available on Linux). In order to set up the environment you must have `conda/mamba` installed; we recommend [mambaforge](https://github.com/conda-forge/miniforge#mambaforge). 38 | 39 | Once you have `conda/mamba`, go to the top level directory of the repository and run: 40 | ``` 41 | {conda|mamba} env create -f environment.yml 42 | ``` 43 | 44 | Once you have your environment set up, activate it using `conda activate mlforecast`. 45 | 46 | ##### PyPI 47 | From the top level directory of the repository run: `pip install ".[dev]"` 48 | 49 | #### 3. Install the library 50 | From the top level directory of the repository run: `pip install -e .[dev]` 51 | 52 | ##### Setting up pre-commit 53 | Run `pre-commit install` 54 | 55 | ### Building the library 56 | The library is built using the notebooks contained in the `nbs` folder. If you want to make any changes to the library, you have to find the relevant notebook, make your changes and then call `nbdev_export`. 57 | 58 | ### Running tests 59 | 60 | * If you're working on the local interface, use `nbdev_test --skip_file_glob "distributed*" --n_workers 1`. 61 | * If you're modifying the distributed interface, run the tests using `nbdev_test --n_workers 1`. 62 | 63 | ### Run the linting tasks 64 | Run `./action_files/lint` 65 | 66 | ### Cleaning notebooks 67 | Run `./action_files/clean_nbs` 68 | 69 | 70 | ## Do you want to contribute to the documentation? 71 | 72 | * Docs are automatically created from the notebooks in the `nbs` folder. 73 | * In order to modify the documentation: 74 | 1. Find the relevant notebook. 75 | 2. Make your changes. 76 | 3. Run all cells. 77 | 4. Run `nbdev_preview`. 78 | 5. If you modified the `index.ipynb` notebook, run `nbdev_readme`. 79 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 Nixtla 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include settings.ini 2 | include LICENSE 3 | include CONTRIBUTING.md 4 | include README.md 5 | recursive-exclude * __pycache__ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mlforecast 2 | [![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=Machine%20Learning%20Forecasting%20by%20Nixtla%20&url=https://github.com/Nixtla/mlforecast&via=nixtlainc&hashtags=MachineLearning,TimeSeries,Forecasting) 3 | [![Slack](https://img.shields.io/badge/Slack-4A154B?&logo=slack&logoColor=white.png)](https://join.slack.com/t/nixtlacommunity/shared_invite/zt-1pmhan9j5-F54XR20edHk0UtYAPcW4KQ) 4 | 5 | 6 | 7 | 
8 | 9 | 10 | 11 | 12 | 13 | Machine Learning 🤖 Forecast 14 | 15 | 16 | Scalable machine learning for time series forecasting 17 | 
18 | 19 | [![CI](https://github.com/Nixtla/mlforecast/actions/workflows/ci.yaml/badge.svg)](https://github.com/Nixtla/mlforecast/actions/workflows/ci.yaml) 20 | [![Python](https://img.shields.io/pypi/pyversions/mlforecast.png)](https://pypi.org/project/mlforecast/) 21 | [![PyPi](https://img.shields.io/pypi/v/mlforecast?color=blue.png)](https://pypi.org/project/mlforecast/) 22 | [![conda-forge](https://img.shields.io/conda/vn/conda-forge/mlforecast?color=blue.png)](https://anaconda.org/conda-forge/mlforecast) 23 | [![License](https://img.shields.io/github/license/Nixtla/mlforecast.png)](https://github.com/Nixtla/mlforecast/blob/main/LICENSE) 24 | 25 | **mlforecast** is a framework to perform time series forecasting using 26 | machine learning models, with the option to scale to massive amounts of 27 | data using remote clusters. 28 | 29 |
30 | 31 | ## Install 32 | 33 | ### PyPI 34 | 35 | `pip install mlforecast` 36 | 37 | ### conda-forge 38 | 39 | `conda install -c conda-forge mlforecast` 40 | 41 | For more detailed instructions you can refer to the [installation 42 | page](https://nixtla.github.io/mlforecast/docs/getting-started/install.html). 43 | 44 | ## Quick Start 45 | 46 | **Get Started with this [quick 47 | guide](https://nixtla.github.io/mlforecast/docs/getting-started/quick_start_local.html).** 48 | 49 | **Follow this [end-to-end 50 | walkthrough](https://nixtla.github.io/mlforecast/docs/getting-started/end_to_end_walkthrough.html) 51 | for best practices.** 52 | 53 | ### Videos 54 | 55 | - [Overview](https://www.youtube.com/live/EnhyJx8l2LE) 56 | 57 | ### Sample notebooks 58 | 59 | - [m5](https://www.kaggle.com/code/lemuz90/m5-mlforecast-eval) 60 | - [m5-polars](https://www.kaggle.com/code/lemuz90/m5-mlforecast-eval-polars) 61 | - [m4](https://www.kaggle.com/code/lemuz90/m4-competition) 62 | - [m4-cv](https://www.kaggle.com/code/lemuz90/m4-competition-cv) 63 | - [favorita](https://www.kaggle.com/code/lemuz90/mlforecast-favorita) 64 | - [VN1](https://colab.research.google.com/drive/1UdhCAk49k6HgMezG-U_1ETnAB5pYvZk9) 65 | 66 | ## Why? 67 | 68 | Current Python alternatives for machine learning models are slow, 69 | inaccurate and don’t scale well. So we created a library that can be 70 | used to forecast in production environments. 71 | [`MLForecast`](https://Nixtla.github.io/mlforecast/forecast.html#mlforecast) 72 | includes efficient feature engineering to train any machine learning 73 | model (with `fit` and `predict` methods such as 74 | [`sklearn`](https://scikit-learn.org/stable/)) to fit millions of time 75 | series. 76 | 77 | ## Features 78 | 79 | - Fastest implementations of feature engineering for time series 80 | forecasting in Python. 81 | - Out-of-the-box compatibility with pandas, polars, spark, dask, and 82 | ray. 83 | - Probabilistic Forecasting with Conformal Prediction. 84 | - Support for exogenous variables and static covariates. 85 | - Familiar `sklearn` syntax: `.fit` and `.predict`. 86 | 87 | Missing something? Please open an issue or write us in 88 | [![Slack](https://img.shields.io/badge/Slack-4A154B?&logo=slack&logoColor=white.png)](https://join.slack.com/t/nixtlaworkspace/shared_invite/zt-135dssye9-fWTzMpv2WBthq8NK0Yvu6A) 89 | 90 | ## Examples and Guides 91 | 92 | 📚 [End to End 93 | Walkthrough](https://nixtla.github.io/mlforecast/docs/getting-started/end_to_end_walkthrough.html): 94 | model training, evaluation and selection for multiple time series. 95 | 96 | 🔎 [Probabilistic 97 | Forecasting](https://nixtla.github.io/mlforecast/docs/how-to-guides/prediction_intervals.html): 98 | use Conformal Prediction to produce prediciton intervals. 99 | 100 | 👩‍🔬 [Cross 101 | Validation](https://nixtla.github.io/mlforecast/docs/how-to-guides/cross_validation.html): 102 | robust model’s performance evaluation. 103 | 104 | 🔌 [Predict Demand 105 | Peaks](https://nixtla.github.io/mlforecast/docs/tutorials/electricity_peak_forecasting.html): 106 | electricity load forecasting for detecting daily peaks and reducing 107 | electric bills. 108 | 109 | 📈 [Transfer 110 | Learning](https://nixtla.github.io/mlforecast/docs/how-to-guides/transfer_learning.html): 111 | pretrain a model using a set of time series and then predict another one 112 | using that pretrained model. 
113 | 114 | 🌡️ [Distributed 115 | Training](https://nixtla.github.io/mlforecast/docs/getting-started/quick_start_distributed.html): 116 | use a Dask, Ray or Spark cluster to train models at scale. 117 | 118 | ## How to use 119 | 120 | The following provides a very basic overview; for a more detailed 121 | description, see the 122 | [documentation](https://nixtla.github.io/mlforecast/). 123 | 124 | ### Data setup 125 | 126 | Store your time series in a pandas dataframe in long format, that is, 127 | each row represents an observation for a specific series and timestamp. 128 | 129 | ``` python 130 | from mlforecast.utils import generate_daily_series 131 | 132 | series = generate_daily_series( 133 | n_series=20, 134 | max_length=100, 135 | n_static_features=1, 136 | static_as_categorical=False, 137 | with_trend=True 138 | ) 139 | series.head() 140 | ``` 141 | 142 | 
143 | 144 | | | unique_id | ds | y | static_0 | 145 | |-----|-----------|------------|------------|----------| 146 | | 0 | id_00 | 2000-01-01 | 17.519167 | 72 | 147 | | 1 | id_00 | 2000-01-02 | 87.799695 | 72 | 148 | | 2 | id_00 | 2000-01-03 | 177.442975 | 72 | 149 | | 3 | id_00 | 2000-01-04 | 232.704110 | 72 | 150 | | 4 | id_00 | 2000-01-05 | 317.510474 | 72 | 151 | 152 |
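If your own data is already in long format but uses different column names, the simplest route is usually to rename the columns to `unique_id`, `ds` and `y`. A minimal sketch, using hypothetical raw column names (`store`, `date`, `sales`):

``` python
import pandas as pd

# hypothetical raw data in long format with custom column names
raw = pd.DataFrame({
    "store": ["s1", "s1", "s2", "s2"],
    "date": pd.to_datetime(["2000-01-01", "2000-01-02", "2000-01-01", "2000-01-02"]),
    "sales": [10.0, 12.0, 3.0, 4.0],
})
# rename to the column names expected by mlforecast's defaults
series = raw.rename(columns={"store": "unique_id", "date": "ds", "sales": "y"})
```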
153 | 154 | > Note: The unique_id serves as an identifier for each distinct time 155 | > series in your dataset. If your dataset contains only a single time 156 | > series, set this column to a constant value. 157 | 158 | ### Models 159 | 160 | Next, define your models; each one will be trained on all series. These 161 | can be any regressor that follows the scikit-learn API. 162 | 163 | ``` python 164 | import lightgbm as lgb 165 | from sklearn.linear_model import LinearRegression 166 | ``` 167 | 168 | ``` python 169 | models = [ 170 | lgb.LGBMRegressor(random_state=0, verbosity=-1), 171 | LinearRegression(), 172 | ] 173 | ``` 174 | 175 | ### Forecast object 176 | 177 | Now instantiate an 178 | [`MLForecast`](https://Nixtla.github.io/mlforecast/forecast.html#mlforecast) 179 | object with the models and the features that you want to use. The 180 | features can be lags, transformations on the lags and date features. You 181 | can also define transformations to apply to the target before fitting, 182 | which will be restored when predicting. 183 | 184 | ``` python 185 | from mlforecast import MLForecast 186 | from mlforecast.lag_transforms import ExpandingMean, RollingMean 187 | from mlforecast.target_transforms import Differences 188 | ``` 189 | 190 | ``` python 191 | fcst = MLForecast( 192 | models=models, 193 | freq='D', 194 | lags=[7, 14], 195 | lag_transforms={ 196 | 1: [ExpandingMean()], 197 | 7: [RollingMean(window_size=28)] 198 | }, 199 | date_features=['dayofweek'], 200 | target_transforms=[Differences([1])], 201 | ) 202 | ``` 203 | 204 | ### Training 205 | 206 | To compute the features and train the models, call `fit` on your 207 | `Forecast` object. 208 | 209 | ``` python 210 | fcst.fit(series) 211 | ``` 212 | 213 | MLForecast(models=[LGBMRegressor, LinearRegression], freq=D, lag_features=['lag7', 'lag14', 'expanding_mean_lag1', 'rolling_mean_lag7_window_size28'], date_features=['dayofweek'], num_threads=1) 214 | 215 | ### Predicting 216 | 217 | To get the forecasts for the next `n` days, call `predict(n)` on the 218 | forecast object. This will automatically handle the updates required by 219 | the features using a recursive strategy. 220 | 221 | ``` python 222 | predictions = fcst.predict(14) 223 | predictions 224 | ``` 225 | 226 | 
227 | 228 | | | unique_id | ds | LGBMRegressor | LinearRegression | 229 | |-----|-----------|------------|---------------|------------------| 230 | | 0 | id_00 | 2000-04-04 | 299.923771 | 311.432371 | 231 | | 1 | id_00 | 2000-04-05 | 365.424147 | 379.466214 | 232 | | 2 | id_00 | 2000-04-06 | 432.562441 | 460.234028 | 233 | | 3 | id_00 | 2000-04-07 | 495.628000 | 524.278924 | 234 | | 4 | id_00 | 2000-04-08 | 60.786223 | 79.828767 | 235 | | ... | ... | ... | ... | ... | 236 | | 275 | id_19 | 2000-03-23 | 36.266780 | 28.333215 | 237 | | 276 | id_19 | 2000-03-24 | 44.370984 | 33.368228 | 238 | | 277 | id_19 | 2000-03-25 | 50.746222 | 38.613001 | 239 | | 278 | id_19 | 2000-03-26 | 58.906524 | 43.447398 | 240 | | 279 | id_19 | 2000-03-27 | 63.073949 | 48.666783 | 241 | 242 |

280 rows × 4 columns

243 | </div>
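If you need to inspect the features used at each step of this recursive process, you can pass a callback to `predict`. A small sketch using the `SaveFeatures` callback from `mlforecast.callbacks`:

``` python
from mlforecast.callbacks import SaveFeatures

save_feats = SaveFeatures()
fcst.predict(14, before_predict_callback=save_feats)
# one row per series per forecasting step
features = save_feats.get_features(with_step=True)
```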
244 | 245 | ### Visualize results 246 | 247 | ``` python 248 | from utilsforecast.plotting import plot_series 249 | ``` 250 | 251 | ``` python 252 | fig = plot_series(series, predictions, max_ids=4, plot_random=False) 253 | ``` 254 | 255 | ![](https://raw.githubusercontent.com/Nixtla/mlforecast/main/nbs/figs/index.png) 256 | 257 | ## How to contribute 258 | 259 | See 260 | [CONTRIBUTING.md](https://github.com/Nixtla/mlforecast/blob/main/CONTRIBUTING.md). 261 | -------------------------------------------------------------------------------- /action_files/clean_nbs: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | nbdev_clean 3 | # the following sets the kernel as python 3 to avoid annoying diffs 4 | for file in $(find nbs/ -type f -name "*.ipynb") 5 | do 6 | sed -i 's/Python 3.*,$/Python 3\",/g' $file 7 | done 8 | # distributed training produces logs with different IPs each time 9 | ./action_files/remove_logs_cells 10 | -------------------------------------------------------------------------------- /action_files/lint: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ruff check mlforecast || exit -1 3 | mypy mlforecast || exit -1 4 | -------------------------------------------------------------------------------- /action_files/remove_logs_cells: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import re 3 | import json 4 | from pathlib import Path 5 | from nbdev.clean import process_write 6 | 7 | IP_REGEX = re.compile(r'[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}') 8 | HOURS_REGEX = re.compile(r'\d{2}:\d{2}:\d{2}') 9 | 10 | def cell_contains_ips(cell): 11 | if 'outputs' not in cell: 12 | return False 13 | for output in cell['outputs']: 14 | if 'text' not in output: 15 | return False 16 | for line in output['text']: 17 | if IP_REGEX.search(line) or HOURS_REGEX.search(line) or '[LightGBM]' in line: 18 | return True 19 | return False 20 | 21 | 22 | def clean_nb(nb): 23 | for cell in nb['cells']: 24 | if cell_contains_ips(cell): 25 | cell['outputs'] = [] 26 | 27 | 28 | if __name__ == '__main__': 29 | repo_root = Path(__file__).parents[1] 30 | for nb in (repo_root / 'nbs').rglob('*distributed*.ipynb'): 31 | process_write(warn_msg='Failed to clean_nb', proc_nb=clean_nb, f_in=nb) 32 | -------------------------------------------------------------------------------- /mlforecast/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.2" 2 | __all__ = ['MLForecast'] 3 | from mlforecast.forecast import MLForecast 4 | -------------------------------------------------------------------------------- /mlforecast/callbacks.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/callbacks.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['SaveFeatures'] 5 | 6 | # %% ../nbs/callbacks.ipynb 3 7 | from utilsforecast.compat import DataFrame 8 | from utilsforecast.processing import ( 9 | assign_columns, 10 | drop_index_if_pandas, 11 | vertical_concat, 12 | ) 13 | 14 | # %% ../nbs/callbacks.ipynb 4 15 | class SaveFeatures: 16 | """Saves the features in every timestamp.""" 17 | 18 | def __init__(self): 19 | self._inputs = [] 20 | 21 | def __call__(self, new_x): 22 | self._inputs.append(new_x) 23 | return new_x 24 | 25 | def get_features(self, with_step: bool = False) -> DataFrame: 26 | """Retrieves the input features for every timestep 27 | 28 | Parameters 29 | ---------- 30 | with_step : bool 31 | Add a column indicating the step 32 | 33 | Returns 34 | ------- 35 | pandas or polars DataFrame 36 | DataFrame with input features 37 | """ 38 | if not self._inputs: 39 | raise ValueError( 40 | "Inputs list is empty. " 41 | "Call `predict` using this callback as before_predict_callback" 42 | ) 43 | if with_step: 44 | dfs = [assign_columns(df, "step", i) for i, df in enumerate(self._inputs)] 45 | else: 46 | dfs = self._inputs 47 | res = vertical_concat(dfs, match_categories=False) 48 | res = drop_index_if_pandas(res) 49 | return res 50 | -------------------------------------------------------------------------------- /mlforecast/compat.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/compat.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = [] 5 | 6 | # %% ../nbs/compat.ipynb 1 7 | try: 8 | from catboost import CatBoostRegressor 9 | except ImportError: 10 | 11 | class CatBoostRegressor: 12 | def __init__(self, *args, **kwargs): # noqa: ARG002 13 | raise ImportError("Please install catboost to use this model.") 14 | 15 | 16 | try: 17 | from lightgbm import LGBMRegressor 18 | except ImportError: 19 | 20 | class LGBMRegressor: 21 | def __init__(self, *args, **kwargs): # noqa: ARG002 22 | raise ImportError("Please install lightgbm to use this model.") 23 | 24 | 25 | try: 26 | from xgboost import XGBRegressor 27 | except ImportError: 28 | 29 | class XGBRegressor: 30 | def __init__(self, *args, **kwargs): # noqa: ARG002 31 | raise ImportError("Please install xgboost to use this model.") 32 | 33 | 34 | try: 35 | from window_ops.shift import shift_array 36 | except ImportError: 37 | import numpy as np 38 | from utilsforecast.compat import njit 39 | 40 | @njit 41 | def shift_array(x, offset): 42 | if offset >= x.size or offset < 0: 43 | return np.full_like(x, np.nan) 44 | if offset == 0: 45 | return x.copy() 46 | out = np.empty_like(x) 47 | out[:offset] = np.nan 48 | out[offset:] = x[:-offset] 49 | return out 50 | -------------------------------------------------------------------------------- /mlforecast/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['DistributedMLForecast'] 2 | from mlforecast.distributed.forecast import DistributedMLForecast 3 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/mlforecast/distributed/models/__init__.py -------------------------------------------------------------------------------- /mlforecast/distributed/models/dask/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/mlforecast/distributed/models/dask/__init__.py -------------------------------------------------------------------------------- /mlforecast/distributed/models/dask/lgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.dask.lgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['DaskLGBMForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.dask.lgb.ipynb 3 7 | import warnings 8 | 9 | import lightgbm as lgb 10 | 11 | # %% ../../../../nbs/distributed.models.dask.lgb.ipynb 4 12 | class DaskLGBMForecast(lgb.dask.DaskLGBMRegressor): 13 | if lgb.__version__ < "3.3.0": 14 | warnings.warn( 15 | "It is recommended to install LightGBM version >= 3.3.0, since " 16 | "the current LightGBM version might be affected by https://github.com/microsoft/LightGBM/issues/4026, " 17 | "which was fixed in 3.3.0" 18 | ) 19 | 20 | @property 21 | def model_(self): 22 | return self.to_local() 23 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/dask/xgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.dask.xgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['DaskXGBForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.dask.xgb.ipynb 3 7 | import xgboost as xgb 8 | 9 | # %% ../../../../nbs/distributed.models.dask.xgb.ipynb 4 10 | class DaskXGBForecast(xgb.dask.DaskXGBRegressor): 11 | @property 12 | def model_(self): 13 | model_str = self.get_booster().save_raw("ubj") 14 | local_model = xgb.XGBRegressor() 15 | local_model.load_model(model_str) 16 | return local_model 17 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/ray/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/mlforecast/distributed/models/ray/__init__.py -------------------------------------------------------------------------------- /mlforecast/distributed/models/ray/lgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.ray.lgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['RayLGBMForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.ray.lgb.ipynb 3 7 | import lightgbm as lgb 8 | from lightgbm_ray import RayLGBMRegressor 9 | 10 | # %% ../../../../nbs/distributed.models.ray.lgb.ipynb 4 11 | class RayLGBMForecast(RayLGBMRegressor): 12 | @property 13 | def model_(self): 14 | return self._lgb_ray_to_local(lgb.LGBMRegressor) 15 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/ray/xgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.ray.xgb.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['RayXGBForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.ray.xgb.ipynb 3 7 | import xgboost as xgb 8 | from xgboost_ray import RayXGBRegressor 9 | 10 | # %% ../../../../nbs/distributed.models.ray.xgb.ipynb 4 11 | class RayXGBForecast(RayXGBRegressor): 12 | @property 13 | def model_(self): 14 | model_str = self.get_booster().save_raw("ubj") 15 | local_model = xgb.XGBRegressor() 16 | local_model.load_model(model_str) 17 | return local_model 18 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/spark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/mlforecast/distributed/models/spark/__init__.py -------------------------------------------------------------------------------- /mlforecast/distributed/models/spark/lgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.spark.lgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['SparkLGBMForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.spark.lgb.ipynb 3 7 | import lightgbm as lgb 8 | 9 | try: 10 | from synapse.ml.lightgbm import LightGBMRegressor 11 | except ModuleNotFoundError: 12 | import os 13 | 14 | if os.getenv("QUARTO_PREVIEW", "0") == "1" or os.getenv("IN_TEST", "0") == "1": 15 | LightGBMRegressor = object 16 | else: 17 | raise 18 | 19 | # %% ../../../../nbs/distributed.models.spark.lgb.ipynb 4 20 | class SparkLGBMForecast(LightGBMRegressor): 21 | def _pre_fit(self, target_col): 22 | return self.setLabelCol(target_col) 23 | 24 | def extract_local_model(self, trained_model): 25 | model_str = trained_model.getNativeModel() 26 | local_model = lgb.Booster(model_str=model_str) 27 | return local_model 28 | -------------------------------------------------------------------------------- /mlforecast/distributed/models/spark/xgb.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../../../../nbs/distributed.models.spark.xgb.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['SparkXGBForecast'] 5 | 6 | # %% ../../../../nbs/distributed.models.spark.xgb.ipynb 3 7 | import xgboost as xgb 8 | 9 | try: 10 | from xgboost.spark import SparkXGBRegressor # type: ignore 11 | except ModuleNotFoundError: 12 | import os 13 | 14 | if os.getenv("IN_TEST", "0") == "1": 15 | SparkXGBRegressor = object 16 | else: 17 | raise 18 | 19 | # %% ../../../../nbs/distributed.models.spark.xgb.ipynb 4 20 | class SparkXGBForecast(SparkXGBRegressor): 21 | def _pre_fit(self, target_col): 22 | self.setParams(label_col=target_col) 23 | return self 24 | 25 | def extract_local_model(self, trained_model): 26 | model_str = trained_model.get_booster().save_raw("ubj") 27 | local_model = xgb.XGBRegressor() 28 | local_model.load_model(model_str) 29 | return local_model 30 | -------------------------------------------------------------------------------- /mlforecast/feature_engineering.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/feature_engineering.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['transform_exog'] 5 | 6 | # %% ../nbs/feature_engineering.ipynb 3 7 | from typing import Optional 8 | 9 | import utilsforecast.processing as ufp 10 | from utilsforecast.compat import DFType 11 | from utilsforecast.validation import validate_format 12 | 13 | from .core import _parse_transforms, Lags, LagTransforms 14 | from .grouped_array import GroupedArray 15 | 16 | # %% ../nbs/feature_engineering.ipynb 4 17 | def transform_exog( 18 | df: DFType, 19 | lags: Optional[Lags] = None, 20 | lag_transforms: Optional[LagTransforms] = None, 21 | id_col: str = "unique_id", 22 | time_col: str = "ds", 23 | num_threads: int = 1, 24 | ) -> DFType: 25 | """Compute lag features for dynamic exogenous regressors. 26 | 27 | Parameters 28 | ---------- 29 | df : pandas or polars DataFrame 30 | Dataframe with ids, times and values for the exogenous regressors. 31 | lags : list of int, optional (default=None) 32 | Lags of the regressors to use as features. 33 | lag_transforms : dict of int to list of functions, optional (default=None) 34 | Mapping of regressor lags to their transformations. 35 | id_col : str (default='unique_id') 36 | Column that identifies each serie. 37 | time_col : str (default='ds') 38 | Column that identifies each timestep, its values can be timestamps or integers. 39 | num_threads : int (default=1) 40 | Number of threads to use when computing the features. 41 | 42 | Returns 43 | ------- 44 | pandas or polars DataFrame 45 | Original DataFrame with the computed features 46 | """ 47 | if lags is None and lag_transforms is None: 48 | raise ValueError("At least one of `lags` or `lag_transforms` is required.") 49 | if lags is None: 50 | lags = [] 51 | if lag_transforms is None: 52 | lag_transforms = {} 53 | tfms = _parse_transforms(lags, lag_transforms) 54 | targets = [c for c in df.columns if c not in (id_col, time_col)] 55 | # this is just a dummy target because process_df requires one 56 | target_col = targets[0] 57 | validate_format(df, id_col, time_col, target_col) 58 | processed = ufp.process_df(df, id_col, time_col, target_col) 59 | results = {} 60 | cols = [] 61 | for j, target in enumerate(targets): 62 | ga = GroupedArray(processed.data[:, j], processed.indptr) 63 | named_tfms = {f"{target}_{k}": v for k, v in tfms.items()} 64 | if num_threads == 1 or len(named_tfms) == 1: 65 | computed_tfms = ga.apply_transforms( 66 | transforms=named_tfms, updates_only=False 67 | ) 68 | else: 69 | computed_tfms = ga.apply_multithreaded_transforms( 70 | transforms=named_tfms, num_threads=num_threads, updates_only=False 71 | ) 72 | results.update(computed_tfms) 73 | cols.extend(list(named_tfms.keys())) 74 | if processed.sort_idxs is not None: 75 | base_df = ufp.take_rows(df, processed.sort_idxs) 76 | else: 77 | base_df = df 78 | base_df = ufp.drop_index_if_pandas(base_df) 79 | return ufp.horizontal_concat([base_df, type(df)(results)[cols]]) 80 | -------------------------------------------------------------------------------- /mlforecast/grouped_array.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/grouped_array.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['GroupedArray'] 5 | 6 | # %% ../nbs/grouped_array.ipynb 2 7 | import concurrent.futures 8 | from typing import Any, Dict, Mapping, Tuple, Union 9 | 10 | import numpy as np 11 | from coreforecast.grouped_array import GroupedArray as CoreGroupedArray 12 | from utilsforecast.compat import njit 13 | 14 | from .compat import shift_array 15 | from .lag_transforms import _BaseLagTransform 16 | 17 | # %% ../nbs/grouped_array.ipynb 3 18 | @njit(nogil=True) 19 | def _transform_series(data, indptr, updates_only, lag, func, *args) -> np.ndarray: 20 | """Shifts every group in `data` by `lag` and computes `func(shifted, *args)`. 21 | 22 | If `updates_only=True` only last value of the transformation for each group is returned, 23 | otherwise the full transformation is returned""" 24 | n_series = len(indptr) - 1 25 | if updates_only: 26 | out = np.empty_like(data[:n_series]) 27 | for i in range(n_series): 28 | lagged = shift_array(data[indptr[i] : indptr[i + 1]], lag) 29 | out[i] = func(lagged, *args)[-1] 30 | else: 31 | out = np.empty_like(data) 32 | for i in range(n_series): 33 | lagged = shift_array(data[indptr[i] : indptr[i + 1]], lag) 34 | out[indptr[i] : indptr[i + 1]] = func(lagged, *args) 35 | return out 36 | 37 | # %% ../nbs/grouped_array.ipynb 4 38 | class GroupedArray: 39 | """Array made up of different groups. Can be thought of (and iterated) as a list of arrays. 40 | 41 | All the data is stored in a single 1d array `data`. 42 | The indices for the group boundaries are stored in another 1d array `indptr`.""" 43 | 44 | def __init__(self, data: np.ndarray, indptr: np.ndarray): 45 | self.data = data 46 | self.indptr = indptr 47 | self.n_groups = len(indptr) - 1 48 | 49 | def __len__(self) -> int: 50 | return self.n_groups 51 | 52 | def __getitem__(self, idx: int) -> np.ndarray: 53 | return self.data[self.indptr[idx] : self.indptr[idx + 1]] 54 | 55 | def __setitem__(self, idx: int, vals: np.ndarray): 56 | if self[idx].size != vals.size: 57 | raise ValueError(f"vals must be of size {self[idx].size}") 58 | self[idx][:] = vals 59 | 60 | def __copy__(self): 61 | return GroupedArray(self.data.copy(), self.indptr) 62 | 63 | def take(self, idxs: np.ndarray) -> "GroupedArray": 64 | idxs = np.asarray(idxs) 65 | ranges = [range(self.indptr[i], self.indptr[i + 1]) for i in idxs] 66 | items = [self.data[rng] for rng in ranges] 67 | sizes = np.array([item.size for item in items]) 68 | data = np.hstack(items) 69 | indptr = np.append(0, sizes.cumsum()) 70 | return GroupedArray(data, indptr) 71 | 72 | def apply_transforms( 73 | self, 74 | transforms: Mapping[str, Union[Tuple[Any, ...], _BaseLagTransform]], 75 | updates_only: bool = False, 76 | ) -> Dict[str, np.ndarray]: 77 | """Apply the transformations using the main process. 78 | 79 | If `updates_only` then only the updates are returned. 
80 | """ 81 | results = {} 82 | offset = 1 if updates_only else 0 83 | if any(isinstance(tfm, _BaseLagTransform) for tfm in transforms.values()): 84 | core_ga = CoreGroupedArray(self.data, self.indptr) 85 | for tfm_name, tfm in transforms.items(): 86 | if isinstance(tfm, _BaseLagTransform): 87 | if updates_only: 88 | results[tfm_name] = tfm.update(core_ga) 89 | else: 90 | results[tfm_name] = tfm.transform(core_ga) 91 | else: 92 | lag, tfm, *args = tfm 93 | results[tfm_name] = _transform_series( 94 | self.data, self.indptr, updates_only, lag - offset, tfm, *args 95 | ) 96 | return results 97 | 98 | def apply_multithreaded_transforms( 99 | self, 100 | transforms: Mapping[str, Union[Tuple[Any, ...], _BaseLagTransform]], 101 | num_threads: int, 102 | updates_only: bool = False, 103 | ) -> Dict[str, np.ndarray]: 104 | """Apply the transformations using multithreading. 105 | 106 | If `updates_only` then only the updates are returned. 107 | """ 108 | future_to_result = {} 109 | results = {} 110 | offset = 1 if updates_only else 0 111 | numba_tfms = {} 112 | core_tfms = {} 113 | for name, tfm in transforms.items(): 114 | if isinstance(tfm, _BaseLagTransform): 115 | core_tfms[name] = tfm 116 | else: 117 | numba_tfms[name] = tfm 118 | if numba_tfms: 119 | with concurrent.futures.ThreadPoolExecutor(num_threads) as executor: 120 | for tfm_name, (lag, tfm, *args) in numba_tfms.items(): 121 | future = executor.submit( 122 | _transform_series, 123 | self.data, 124 | self.indptr, 125 | updates_only, 126 | lag - offset, 127 | tfm, 128 | *args, 129 | ) 130 | future_to_result[future] = tfm_name 131 | for future in concurrent.futures.as_completed(future_to_result): 132 | tfm_name = future_to_result[future] 133 | results[tfm_name] = future.result() 134 | if core_tfms: 135 | core_ga = CoreGroupedArray(self.data, self.indptr, num_threads) 136 | for name, tfm in core_tfms.items(): 137 | if updates_only: 138 | results[name] = tfm.update(core_ga) 139 | else: 140 | results[name] = tfm.transform(core_ga) 141 | return results 142 | 143 | def expand_target(self, max_horizon: int) -> np.ndarray: 144 | out = np.full_like( 145 | self.data, np.nan, shape=(self.data.size, max_horizon), order="F" 146 | ) 147 | for j in range(max_horizon): 148 | for i in range(self.n_groups): 149 | if self.indptr[i + 1] - self.indptr[i] > j: 150 | out[self.indptr[i] : self.indptr[i + 1] - j, j] = self.data[ 151 | self.indptr[i] + j : self.indptr[i + 1] 152 | ] 153 | return out 154 | 155 | def take_from_groups(self, idx: Union[int, slice]) -> "GroupedArray": 156 | """Takes `idx` from each group in the array.""" 157 | ranges = [ 158 | range(self.indptr[i], self.indptr[i + 1])[idx] for i in range(self.n_groups) 159 | ] 160 | items = [self.data[rng] for rng in ranges] 161 | sizes = np.array([item.size for item in items]) 162 | data = np.hstack(items) 163 | indptr = np.append(0, sizes.cumsum()) 164 | return GroupedArray(data, indptr) 165 | 166 | def append(self, new_data: np.ndarray) -> "GroupedArray": 167 | """Appends each element of `new_data` to each existing group. 
Returns a copy.""" 168 | if new_data.size != self.n_groups: 169 | raise ValueError(f"`new_data` must be of size {self.n_groups:,}") 170 | core_ga = CoreGroupedArray(self.data, self.indptr) 171 | new_data = new_data.astype(self.data.dtype, copy=False) 172 | new_indptr = np.arange(self.n_groups + 1, dtype=np.int32) 173 | new_ga = CoreGroupedArray(new_data, new_indptr) 174 | combined = core_ga._append(new_ga) 175 | return GroupedArray(combined.data, combined.indptr) 176 | 177 | def append_several( 178 | self, new_sizes: np.ndarray, new_values: np.ndarray, new_groups: np.ndarray 179 | ) -> "GroupedArray": 180 | new_data = np.empty(self.data.size + new_values.size, dtype=self.data.dtype) 181 | new_indptr = np.empty(new_sizes.size + 1, dtype=self.indptr.dtype) 182 | new_indptr[0] = 0 183 | old_indptr_idx = 0 184 | new_vals_idx = 0 185 | for i, is_new in enumerate(new_groups): 186 | new_size = new_sizes[i] 187 | if is_new: 188 | old_size = 0 189 | else: 190 | prev_slice = slice( 191 | self.indptr[old_indptr_idx], self.indptr[old_indptr_idx + 1] 192 | ) 193 | old_indptr_idx += 1 194 | old_size = prev_slice.stop - prev_slice.start 195 | new_size += old_size 196 | new_data[new_indptr[i] : new_indptr[i] + old_size] = self.data[ 197 | prev_slice 198 | ] 199 | new_indptr[i + 1] = new_indptr[i] + new_size 200 | new_data[new_indptr[i] + old_size : new_indptr[i + 1]] = new_values[ 201 | new_vals_idx : new_vals_idx + new_sizes[i] 202 | ] 203 | new_vals_idx += new_sizes[i] 204 | return GroupedArray(new_data, new_indptr) 205 | 206 | def __repr__(self) -> str: 207 | return f"{self.__class__.__name__}(ndata={self.data.size}, n_groups={self.n_groups})" 208 | -------------------------------------------------------------------------------- /mlforecast/lag_transforms.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/lag_transforms.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['RollingMean', 'RollingStd', 'RollingMin', 'RollingMax', 'RollingQuantile', 'SeasonalRollingMean', 5 | 'SeasonalRollingStd', 'SeasonalRollingMin', 'SeasonalRollingMax', 'SeasonalRollingQuantile', 'ExpandingMean', 6 | 'ExpandingStd', 'ExpandingMin', 'ExpandingMax', 'ExpandingQuantile', 'ExponentiallyWeightedMean', 'Offset', 7 | 'Combine'] 8 | 9 | # %% ../nbs/lag_transforms.ipynb 3 10 | import copy 11 | import inspect 12 | import re 13 | from typing import Callable, Optional, Sequence 14 | 15 | import numpy as np 16 | import coreforecast.lag_transforms as core_tfms 17 | from coreforecast.grouped_array import GroupedArray as CoreGroupedArray 18 | from sklearn.base import BaseEstimator, clone 19 | 20 | # %% ../nbs/lag_transforms.ipynb 4 21 | def _pascal2camel(pascal_str: str) -> str: 22 | return re.sub(r"(? 
"_BaseLagTransform": 34 | init_args = {k: getattr(self, k) for k in self._get_init_signature()} 35 | self._core_tfm = getattr(core_tfms, self.__class__.__name__)( 36 | lag=lag, **init_args 37 | ) 38 | return self 39 | 40 | def _get_name(self, lag: int) -> str: 41 | init_params = self._get_init_signature() 42 | result = f"{_pascal2camel(self.__class__.__name__)}_lag{lag}" 43 | changed_params = [ 44 | f"{name}{getattr(self, name)}" 45 | for name, arg in init_params.items() 46 | if arg.default != getattr(self, name) 47 | ] 48 | if changed_params: 49 | result += "_" + "_".join(changed_params) 50 | return result 51 | 52 | def transform(self, ga: CoreGroupedArray) -> np.ndarray: 53 | return self._core_tfm.transform(ga) 54 | 55 | def update(self, ga: CoreGroupedArray) -> np.ndarray: 56 | return self._core_tfm.update(ga) 57 | 58 | def take(self, idxs: np.ndarray) -> "_BaseLagTransform": 59 | out = copy.deepcopy(self) 60 | out._core_tfm = self._core_tfm.take(idxs) 61 | return out 62 | 63 | @staticmethod 64 | def stack(transforms: Sequence["_BaseLagTransform"]) -> "_BaseLagTransform": 65 | out = copy.deepcopy(transforms[0]) 66 | out._core_tfm = transforms[0]._core_tfm.stack( 67 | [tfm._core_tfm for tfm in transforms] 68 | ) 69 | return out 70 | 71 | @property 72 | def _lag(self): 73 | return self._core_tfm.lag - 1 74 | 75 | @property 76 | def update_samples(self) -> int: 77 | return -1 78 | 79 | # %% ../nbs/lag_transforms.ipynb 6 80 | class Lag(_BaseLagTransform): 81 | 82 | def __init__(self, lag: int): 83 | self.lag = lag 84 | self._core_tfm = core_tfms.Lag(lag=lag) 85 | 86 | def _set_core_tfm(self, _lag: int) -> "Lag": 87 | return self 88 | 89 | def _get_name(self, lag: int) -> str: 90 | return f"lag{lag}" 91 | 92 | def __eq__(self, other): 93 | return isinstance(other, Lag) and self.lag == other.lag 94 | 95 | @property 96 | def update_samples(self) -> int: 97 | return self.lag 98 | 99 | # %% ../nbs/lag_transforms.ipynb 7 100 | class _RollingBase(_BaseLagTransform): 101 | "Rolling statistic" 102 | 103 | def __init__(self, window_size: int, min_samples: Optional[int] = None): 104 | """ 105 | Parameters 106 | ---------- 107 | window_size : int 108 | Number of samples in the window. 109 | min_samples: int 110 | Minimum samples required to output the statistic. 111 | If `None`, will be set to `window_size`. 112 | """ 113 | self.window_size = window_size 114 | self.min_samples = min_samples 115 | 116 | @property 117 | def update_samples(self) -> int: 118 | return self._lag + self.window_size 119 | 120 | # %% ../nbs/lag_transforms.ipynb 8 121 | class RollingMean(_RollingBase): ... 122 | 123 | 124 | class RollingStd(_RollingBase): ... 125 | 126 | 127 | class RollingMin(_RollingBase): ... 128 | 129 | 130 | class RollingMax(_RollingBase): ... 
131 | 132 | 133 | class RollingQuantile(_RollingBase): 134 | def __init__(self, p: float, window_size: int, min_samples: Optional[int] = None): 135 | super().__init__(window_size=window_size, min_samples=min_samples) 136 | self.p = p 137 | 138 | def _set_core_tfm(self, lag: int): 139 | self._core_tfm = core_tfms.RollingQuantile( 140 | lag=lag, 141 | p=self.p, 142 | window_size=self.window_size, 143 | min_samples=self.min_samples, 144 | ) 145 | return self 146 | 147 | # %% ../nbs/lag_transforms.ipynb 10 148 | class _Seasonal_RollingBase(_BaseLagTransform): 149 | """Rolling statistic over seasonal periods""" 150 | 151 | def __init__( 152 | self, season_length: int, window_size: int, min_samples: Optional[int] = None 153 | ): 154 | """ 155 | Parameters 156 | ---------- 157 | season_length : int 158 | Periodicity of the seasonal period. 159 | window_size : int 160 | Number of samples in the window. 161 | min_samples: int 162 | Minimum samples required to output the statistic. 163 | If `None`, will be set to `window_size`. 164 | """ 165 | self.season_length = season_length 166 | self.window_size = window_size 167 | self.min_samples = min_samples 168 | 169 | @property 170 | def update_samples(self) -> int: 171 | return self._lag + self.season_length * self.window_size 172 | 173 | # %% ../nbs/lag_transforms.ipynb 11 174 | class SeasonalRollingMean(_Seasonal_RollingBase): ... 175 | 176 | 177 | class SeasonalRollingStd(_Seasonal_RollingBase): ... 178 | 179 | 180 | class SeasonalRollingMin(_Seasonal_RollingBase): ... 181 | 182 | 183 | class SeasonalRollingMax(_Seasonal_RollingBase): ... 184 | 185 | 186 | class SeasonalRollingQuantile(_Seasonal_RollingBase): 187 | def __init__( 188 | self, 189 | p: float, 190 | season_length: int, 191 | window_size: int, 192 | min_samples: Optional[int] = None, 193 | ): 194 | super().__init__( 195 | season_length=season_length, 196 | window_size=window_size, 197 | min_samples=min_samples, 198 | ) 199 | self.p = p 200 | 201 | # %% ../nbs/lag_transforms.ipynb 13 202 | class _ExpandingBase(_BaseLagTransform): 203 | """Expanding statistic""" 204 | 205 | def __init__(self): ... 206 | 207 | @property 208 | def update_samples(self) -> int: 209 | return 1 210 | 211 | # %% ../nbs/lag_transforms.ipynb 14 212 | class ExpandingMean(_ExpandingBase): ... 213 | 214 | 215 | class ExpandingStd(_ExpandingBase): ... 216 | 217 | 218 | class ExpandingMin(_ExpandingBase): ... 219 | 220 | 221 | class ExpandingMax(_ExpandingBase): ... 
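# Illustrative note: an expanding statistic can be refreshed from just the most
# recent sample (update_samples == 1 in _ExpandingBase above), whereas
# ExpandingQuantile below needs the full history and therefore overrides
# update_samples to -1.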
222 | 223 | 224 | class ExpandingQuantile(_ExpandingBase): 225 | def __init__(self, p: float): 226 | self.p = p 227 | 228 | @property 229 | def update_samples(self) -> int: 230 | return -1 231 | 232 | # %% ../nbs/lag_transforms.ipynb 16 233 | class ExponentiallyWeightedMean(_BaseLagTransform): 234 | """Exponentially weighted average 235 | 236 | Parameters 237 | ---------- 238 | alpha : float 239 | Smoothing factor.""" 240 | 241 | def __init__(self, alpha: float): 242 | self.alpha = alpha 243 | 244 | @property 245 | def update_samples(self) -> int: 246 | return 1 247 | 248 | # %% ../nbs/lag_transforms.ipynb 18 249 | class Offset(_BaseLagTransform): 250 | """Shift series before computing transformation 251 | 252 | Parameters 253 | ---------- 254 | tfm : LagTransform 255 | Transformation to be applied 256 | n : int 257 | Number of positions to shift (lag) series before applying the transformation""" 258 | 259 | def __init__(self, tfm: _BaseLagTransform, n: int): 260 | self.tfm = tfm 261 | self.n = n 262 | 263 | def _get_name(self, lag: int) -> str: 264 | return self.tfm._get_name(lag + self.n) 265 | 266 | def _set_core_tfm(self, lag: int) -> "Offset": 267 | self.tfm = clone(self.tfm)._set_core_tfm(lag + self.n) 268 | self._core_tfm = self.tfm._core_tfm 269 | return self 270 | 271 | @property 272 | def update_samples(self) -> int: 273 | return self.tfm.update_samples + self.n 274 | 275 | # %% ../nbs/lag_transforms.ipynb 20 276 | class Combine(_BaseLagTransform): 277 | """Combine two lag transformations using an operator 278 | 279 | Parameters 280 | ---------- 281 | tfm1 : LagTransform 282 | First transformation. 283 | tfm2 : LagTransform 284 | Second transformation. 285 | operator : callable 286 | Binary operator that defines how to combine the two transformations.""" 287 | 288 | def __init__( 289 | self, tfm1: _BaseLagTransform, tfm2: _BaseLagTransform, operator: Callable 290 | ): 291 | self.tfm1 = tfm1 292 | self.tfm2 = tfm2 293 | self.operator = operator 294 | 295 | def _set_core_tfm(self, lag: int) -> "Combine": 296 | self.tfm1 = clone(self.tfm1)._set_core_tfm(lag) 297 | self.tfm2 = clone(self.tfm2)._set_core_tfm(lag) 298 | return self 299 | 300 | def _get_name(self, lag: int) -> str: 301 | lag1 = getattr(self.tfm1, "lag", lag) 302 | lag2 = getattr(self.tfm2, "lag", lag) 303 | return f"{self.tfm1._get_name(lag1)}_{self.operator.__name__}_{self.tfm2._get_name(lag2)}" 304 | 305 | def transform(self, ga: CoreGroupedArray) -> np.ndarray: 306 | return self.operator(self.tfm1.transform(ga), self.tfm2.transform(ga)) 307 | 308 | def update(self, ga: CoreGroupedArray) -> np.ndarray: 309 | return self.operator(self.tfm1.update(ga), self.tfm2.update(ga)) 310 | 311 | @property 312 | def update_samples(self): 313 | return max(self.tfm1.update_samples, self.tfm2.update_samples) 314 | -------------------------------------------------------------------------------- /mlforecast/optimization.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/optimization.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['mlforecast_objective'] 5 | 6 | # %% ../nbs/optimization.ipynb 3 7 | import copy 8 | from typing import Any, Callable, Dict, Optional, Union 9 | 10 | import numpy as np 11 | import optuna 12 | import utilsforecast.processing as ufp 13 | from sklearn.base import BaseEstimator, clone 14 | from utilsforecast.compat import DataFrame 15 | 16 | from . 
import MLForecast 17 | from .compat import CatBoostRegressor 18 | from .core import Freq 19 | 20 | # %% ../nbs/optimization.ipynb 4 21 | _TrialToConfig = Callable[[optuna.Trial], Dict[str, Any]] 22 | 23 | # %% ../nbs/optimization.ipynb 5 24 | def mlforecast_objective( 25 | df: DataFrame, 26 | config_fn: _TrialToConfig, 27 | loss: Callable, 28 | model: BaseEstimator, 29 | freq: Freq, 30 | n_windows: int, 31 | h: int, 32 | step_size: Optional[int] = None, 33 | input_size: Optional[int] = None, 34 | refit: Union[bool, int] = False, 35 | id_col: str = "unique_id", 36 | time_col: str = "ds", 37 | target_col: str = "y", 38 | ) -> Callable[[optuna.Trial], float]: 39 | """optuna objective function for the MLForecast class 40 | 41 | Parameters 42 | ---------- df : pandas or polars DataFrame Series data in long format. 43 | config_fn : callable 44 | Function that takes an optuna trial and produces a configuration with the following keys: 45 | - model_params 46 | - mlf_init_params 47 | - mlf_fit_params 48 | loss : callable 49 | Function that takes the validation and train dataframes and produces a float. 50 | model : BaseEstimator 51 | scikit-learn compatible model to be trained 52 | freq : str or int 53 | pandas' or polars' offset alias or integer denoting the frequency of the series. 54 | n_windows : int 55 | Number of windows to evaluate. 56 | h : int 57 | Forecast horizon. 58 | step_size : int, optional (default=None) 59 | Step size between each cross validation window. If None it will be equal to `h`. 60 | input_size : int, optional (default=None) 61 | Maximum training samples per serie in each window. If None, will use an expanding window. 62 | refit : bool or int (default=False) 63 | Retrain model for each cross validation window. 64 | If False, the models are trained at the beginning and then used to predict each window. 65 | If positive int, the models are retrained every `refit` windows. 66 | id_col : str (default='unique_id') 67 | Column that identifies each serie. 68 | time_col : str (default='ds') 69 | Column that identifies each timestep, its values can be timestamps or integers. 70 | target_col : str (default='y') 71 | Column that contains the target.
72 | 73 | """ 74 | 75 | def objective(trial: optuna.Trial) -> float: 76 | config = config_fn(trial) 77 | trial.set_user_attr("config", copy.deepcopy(config)) 78 | if all( 79 | config["mlf_init_params"].get(k, None) is None 80 | for k in ["lags", "lag_transforms", "date_features"] 81 | ): 82 | # no features 83 | return np.inf 84 | splits = ufp.backtest_splits( 85 | df, 86 | n_windows=n_windows, 87 | h=h, 88 | id_col=id_col, 89 | time_col=time_col, 90 | freq=freq, 91 | step_size=step_size, 92 | input_size=input_size, 93 | ) 94 | model_copy = clone(model) 95 | model_params = config["model_params"] 96 | if config["mlf_fit_params"].get("static_features", []) and isinstance( 97 | model, CatBoostRegressor 98 | ): 99 | # catboost needs the categorical features in the init signature 100 | # we assume all statics are categoricals 101 | model_params["cat_features"] = config["mlf_fit_params"]["static_features"] 102 | model_copy.set_params(**config["model_params"]) 103 | metrics = [] 104 | mlf = MLForecast( 105 | models={"model": model_copy}, 106 | freq=freq, 107 | **config["mlf_init_params"], 108 | ) 109 | for i, (_, train, valid) in enumerate(splits): 110 | should_fit = i == 0 or (refit > 0 and i % refit == 0) 111 | if should_fit: 112 | mlf.fit( 113 | train, 114 | id_col=id_col, 115 | time_col=time_col, 116 | target_col=target_col, 117 | **config["mlf_fit_params"], 118 | ) 119 | static = [c for c in mlf.ts.static_features_.columns if c != id_col] 120 | dynamic = [ 121 | c 122 | for c in valid.columns 123 | if c not in static + [id_col, time_col, target_col] 124 | ] 125 | if dynamic: 126 | X_df: Optional[DataFrame] = ufp.drop_columns( 127 | valid, static + [target_col] 128 | ) 129 | else: 130 | X_df = None 131 | preds = mlf.predict( 132 | h=h, 133 | X_df=X_df, 134 | new_df=None if should_fit else train, 135 | ) 136 | result = ufp.join( 137 | valid[[id_col, time_col, target_col]], 138 | preds, 139 | on=[id_col, time_col], 140 | ) 141 | if result.shape[0] < valid.shape[0]: 142 | raise ValueError( 143 | "Cross validation result produced fewer results than expected. " 144 | "Please verify that the passed frequency (freq) matches your series' " 145 | "and that there aren't any missing periods." 146 | ) 147 | metric = loss(result, train_df=train) 148 | metrics.append(metric) 149 | trial.report(metric, step=i) 150 | if trial.should_prune(): 151 | raise optuna.TrialPruned() 152 | return np.mean(metrics).item() 153 | 154 | return objective 155 | -------------------------------------------------------------------------------- /mlforecast/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/mlforecast/py.typed -------------------------------------------------------------------------------- /mlforecast/utils.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/utils.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['generate_daily_series', 'generate_prices_for_series', 'PredictionIntervals'] 5 | 6 | # %% ../nbs/utils.ipynb 3 7 | from math import ceil, log10 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from utilsforecast.compat import DataFrame, pl 13 | from utilsforecast.data import generate_series 14 | 15 | # %% ../nbs/utils.ipynb 5 16 | def generate_daily_series( 17 | n_series: int, 18 | min_length: int = 50, 19 | max_length: int = 500, 20 | n_static_features: int = 0, 21 | equal_ends: bool = False, 22 | static_as_categorical: bool = True, 23 | with_trend: bool = False, 24 | seed: int = 0, 25 | engine: str = "pandas", 26 | ) -> DataFrame: 27 | """Generate Synthetic Panel Series. 28 | 29 | Parameters 30 | ---------- 31 | n_series : int 32 | Number of series for synthetic panel. 33 | min_length : int (default=50) 34 | Minimum length of synthetic panel's series. 35 | max_length : int (default=500) 36 | Maximum length of synthetic panel's series. 37 | n_static_features : int (default=0) 38 | Number of static exogenous variables for synthetic panel's series. 39 | equal_ends : bool (default=False) 40 | Series should end in the same date stamp `ds`. 41 | static_as_categorical : bool (default=True) 42 | Static features should have a categorical data type. 43 | with_trend : bool (default=False) 44 | Series should have a (positive) trend. 45 | seed : int (default=0) 46 | Random seed used for generating the data. 47 | engine : str (default='pandas') 48 | Output Dataframe type. 49 | 50 | Returns 51 | ------- 52 | pandas or polars DataFrame 53 | Synthetic panel with columns [`unique_id`, `ds`, `y`] and exogenous features. 54 | """ 55 | series = generate_series( 56 | n_series=n_series, 57 | freq="D", 58 | min_length=min_length, 59 | max_length=max_length, 60 | n_static_features=n_static_features, 61 | equal_ends=equal_ends, 62 | static_as_categorical=static_as_categorical, 63 | with_trend=with_trend, 64 | seed=seed, 65 | engine=engine, 66 | ) 67 | n_digits = ceil(log10(n_series)) 68 | 69 | if engine == "pandas": 70 | series["unique_id"] = ( 71 | "id_" + series["unique_id"].astype(str).str.rjust(n_digits, "0") 72 | ).astype("category") 73 | else: 74 | try: 75 | series = series.with_columns( 76 | ("id_" + pl.col("unique_id").cast(pl.Utf8).str.pad_start(n_digits, "0")) 77 | .alias("unique_id") 78 | .cast(pl.Categorical) 79 | ) 80 | except AttributeError: 81 | series = series.with_columns( 82 | ("id_" + pl.col("unique_id").cast(pl.Utf8).str.rjust(n_digits, "0")) 83 | .alias("unique_id") 84 | .cast(pl.Categorical) 85 | ) 86 | return series 87 | 88 | # %% ../nbs/utils.ipynb 16 89 | def generate_prices_for_series( 90 | series: pd.DataFrame, horizon: int = 7, seed: int = 0 91 | ) -> pd.DataFrame: 92 | rng = np.random.RandomState(seed) 93 | unique_last_dates = series.groupby("unique_id", observed=True)["ds"].max().nunique() 94 | if unique_last_dates > 1: 95 | raise ValueError("series must have equal ends.") 96 | day_offset = pd.tseries.frequencies.Day() 97 | starts_ends = series.groupby("unique_id", observed=True)["ds"].agg(["min", "max"]) 98 | dfs = [] 99 | for idx, (start, end) in starts_ends.iterrows(): 100 | product_df = pd.DataFrame( 101 | { 102 | "unique_id": idx, 103 | "price": rng.rand((end - start).days + 1 + horizon), 104 | }, 105 | index=pd.date_range(start, end + horizon * day_offset, name="ds"), 106 | ) 107 | dfs.append(product_df) 108 | prices_catalog = pd.concat(dfs).reset_index() 109 | return prices_catalog 110 | 111 | # %% ../nbs/utils.ipynb 19 112 | 
class PredictionIntervals: 113 | """Class for storing prediction intervals metadata information.""" 114 | 115 | def __init__( 116 | self, 117 | n_windows: int = 2, 118 | h: int = 1, 119 | method: str = "conformal_distribution", 120 | ): 121 | if n_windows < 2: 122 | raise ValueError( 123 | "You need at least two windows to compute conformal intervals" 124 | ) 125 | allowed_methods = ["conformal_error", "conformal_distribution"] 126 | if method not in allowed_methods: 127 | raise ValueError(f"method must be one of {allowed_methods}") 128 | self.n_windows = n_windows 129 | self.h = h 130 | self.method = method 131 | 132 | def __repr__(self): 133 | return f"PredictionIntervals(n_windows={self.n_windows}, h={self.h}, method='{self.method}')" 134 | 135 | # %% ../nbs/utils.ipynb 20 136 | class _ShortSeriesException(Exception): 137 | def __init__(self, idxs): 138 | self.idxs = idxs 139 | -------------------------------------------------------------------------------- /nbs/.gitignore: -------------------------------------------------------------------------------- 1 | /.quarto/ 2 | -------------------------------------------------------------------------------- /nbs/_quarto.yml: -------------------------------------------------------------------------------- 1 | project: 2 | type: website 3 | 4 | format: 5 | html: 6 | theme: cosmo 7 | fontsize: 1em 8 | linestretch: 1.7 9 | css: styles.css 10 | toc: true 11 | 12 | website: 13 | twitter-card: 14 | image: "https://farm6.staticflickr.com/5510/14338202952_93595258ff_z.jpg" 15 | site: "@Nixtlainc" 16 | open-graph: 17 | image: "https://github.com/Nixtla/styles/blob/2abf51612584169874c90cd7c4d347e3917eaf73/images/Banner%20Github.png" 18 | google-analytics: "G-NXJNCVR18L" 19 | repo-actions: [issue] 20 | favicon: favicon.png 21 | navbar: 22 | background: primary 23 | search: true 24 | collapse-below: lg 25 | left: 26 | - text: "Get Started" 27 | href: docs/getting-started/quick_start_local.ipynb 28 | - text: "NixtlaVerse" 29 | menu: 30 | - text: "StatsForecast ⚡️" 31 | href: https://github.com/nixtla/statsforecast 32 | - text: "NeuralForecast 🧠" 33 | href: https://github.com/nixtla/neuralforecast 34 | - text: "HierarchicalForecast 👑" 35 | href: "https://github.com/nixtla/hierarchicalforecast" 36 | 37 | - text: "Help" 38 | menu: 39 | - text: "Report an Issue" 40 | icon: bug 41 | href: https://github.com/nixtla/mlforecast/issues/new/choose 42 | - text: "Join our Slack" 43 | icon: chat-right-text 44 | href: https://join.slack.com/t/nixtlaworkspace/shared_invite/zt-135dssye9-fWTzMpv2WBthq8NK0Yvu6A 45 | right: 46 | - icon: github 47 | href: "https://github.com/nixtla/mlforecast" 48 | - icon: twitter 49 | href: https://twitter.com/nixtlainc 50 | aria-label: Nixtla Twitter 51 | 52 | sidebar: 53 | style: floating 54 | body-footer: | 55 | Give us a ⭐ on [Github](https://github.com/nixtla/mlforecast) 56 | 57 | metadata-files: [nbdev.yml, sidebar.yml] 58 | -------------------------------------------------------------------------------- /nbs/callbacks.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "f4644ef9-9ae6-40e8-91c8-0c7e042af123", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#| default_exp callbacks" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "21b4ed66-5188-4325-b6bb-0f1e4b5429d9", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "#| hide\n", 21 | "%load_ext 
autoreload\n", 22 | "%autoreload 2" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "afe10c28-fec7-40b7-9f73-ef2094ade7fc", 28 | "metadata": {}, 29 | "source": [ 30 | "# Callbacks\n", 31 | "Utility functions use in the predict step." 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "ae0253fb-eb18-4188-985f-5842f45f2db7", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "#| export\n", 42 | "from utilsforecast.compat import DataFrame\n", 43 | "from utilsforecast.processing import (\n", 44 | " assign_columns,\n", 45 | " drop_index_if_pandas,\n", 46 | " vertical_concat\n", 47 | ")" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "id": "da0e5c57-2eac-452a-8435-edb0e919bc65", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "#| export\n", 58 | "class SaveFeatures:\n", 59 | " \"\"\"Saves the features in every timestamp.\"\"\"\n", 60 | " def __init__(self):\n", 61 | " self._inputs = []\n", 62 | "\n", 63 | " def __call__(self, new_x):\n", 64 | " self._inputs.append(new_x)\n", 65 | " return new_x\n", 66 | "\n", 67 | " def get_features(self, with_step: bool = False) -> DataFrame:\n", 68 | " \"\"\"Retrieves the input features for every timestep\n", 69 | " \n", 70 | " Parameters\n", 71 | " ----------\n", 72 | " with_step : bool\n", 73 | " Add a column indicating the step\n", 74 | " \n", 75 | " Returns\n", 76 | " -------\n", 77 | " pandas or polars DataFrame\n", 78 | " DataFrame with input features\n", 79 | " \"\"\"\n", 80 | " if not self._inputs:\n", 81 | " raise ValueError(\n", 82 | " 'Inputs list is empty. '\n", 83 | " 'Call `predict` using this callback as before_predict_callback'\n", 84 | " )\n", 85 | " if with_step:\n", 86 | " dfs = [assign_columns(df, 'step', i) for i, df in enumerate(self._inputs)]\n", 87 | " else:\n", 88 | " dfs = self._inputs\n", 89 | " res = vertical_concat(dfs, match_categories=False)\n", 90 | " res = drop_index_if_pandas(res)\n", 91 | " return res" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "dac9a0ab-140e-43be-abe0-70ea1c71505e", 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "#| hide\n", 102 | "from nbdev import show_doc" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "14b534fa-3247-4883-84c5-0ddf2d684397", 109 | "metadata": {}, 110 | "outputs": [ 111 | { 112 | "data": { 113 | "text/markdown": [ 114 | "---\n", 115 | "\n", 116 | "### SaveFeatures\n", 117 | "\n", 118 | "> SaveFeatures ()\n", 119 | "\n", 120 | "Saves the features in every timestamp." 121 | ], 122 | "text/plain": [ 123 | "---\n", 124 | "\n", 125 | "### SaveFeatures\n", 126 | "\n", 127 | "> SaveFeatures ()\n", 128 | "\n", 129 | "Saves the features in every timestamp." 
130 | ] 131 | }, 132 | "execution_count": null, 133 | "metadata": {}, 134 | "output_type": "execute_result" 135 | } 136 | ], 137 | "source": [ 138 | "show_doc(SaveFeatures)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "557189a2-24fb-4f12-b4ad-b9bb080d6b51", 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "data": { 149 | "text/markdown": [ 150 | "---\n", 151 | "\n", 152 | "### SaveFeatures.get_features\n", 153 | "\n", 154 | "> SaveFeatures.get_features (with_step:bool=False)\n", 155 | "\n", 156 | "Retrieves the input features for every timestep\n", 157 | "\n", 158 | "| | **Type** | **Default** | **Details** |\n", 159 | "| -- | -------- | ----------- | ----------- |\n", 160 | "| with_step | bool | False | Add a column indicating the step |\n", 161 | "| **Returns** | **Union** | | **DataFrame with input features** |" 162 | ], 163 | "text/plain": [ 164 | "---\n", 165 | "\n", 166 | "### SaveFeatures.get_features\n", 167 | "\n", 168 | "> SaveFeatures.get_features (with_step:bool=False)\n", 169 | "\n", 170 | "Retrieves the input features for every timestep\n", 171 | "\n", 172 | "| | **Type** | **Default** | **Details** |\n", 173 | "| -- | -------- | ----------- | ----------- |\n", 174 | "| with_step | bool | False | Add a column indicating the step |\n", 175 | "| **Returns** | **Union** | | **DataFrame with input features** |" 176 | ] 177 | }, 178 | "execution_count": null, 179 | "metadata": {}, 180 | "output_type": "execute_result" 181 | } 182 | ], 183 | "source": [ 184 | "show_doc(SaveFeatures.get_features)" 185 | ] 186 | } 187 | ], 188 | "metadata": { 189 | "kernelspec": { 190 | "display_name": "python3", 191 | "language": "python", 192 | "name": "python3" 193 | } 194 | }, 195 | "nbformat": 4, 196 | "nbformat_minor": 5 197 | } 198 | -------------------------------------------------------------------------------- /nbs/compat.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "955a6b2e-26a9-4fe2-b971-8379ff23fc3f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#| default_exp compat" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "b3af4b54-1713-4171-b15a-a911ed696933", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "#| export\n", 21 | "try:\n", 22 | " from catboost import CatBoostRegressor\n", 23 | "except ImportError:\n", 24 | " class CatBoostRegressor:\n", 25 | " def __init__(self, *args, **kwargs): # noqa: ARG002\n", 26 | " raise ImportError(\n", 27 | " \"Please install catboost to use this model.\"\n", 28 | " )\n", 29 | "\n", 30 | "try:\n", 31 | " from lightgbm import LGBMRegressor\n", 32 | "except ImportError:\n", 33 | " class LGBMRegressor:\n", 34 | " def __init__(self, *args, **kwargs): # noqa: ARG002\n", 35 | " raise ImportError(\n", 36 | " \"Please install lightgbm to use this model.\"\n", 37 | " )\n", 38 | "\n", 39 | "try:\n", 40 | " from xgboost import XGBRegressor\n", 41 | "except ImportError:\n", 42 | " class XGBRegressor:\n", 43 | " def __init__(self, *args, **kwargs): # noqa: ARG002\n", 44 | " raise ImportError(\n", 45 | " \"Please install xgboost to use this model.\"\n", 46 | " )\n", 47 | "\n", 48 | "try:\n", 49 | " from window_ops.shift import shift_array\n", 50 | "except ImportError:\n", 51 | " import numpy as np\n", 52 | " from utilsforecast.compat import njit\n", 53 | "\n", 54 | " @njit\n", 55 | " def shift_array(x, 
offset):\n", 56 | " if offset >= x.size or offset < 0:\n", 57 | " return np.full_like(x, np.nan)\n", 58 | " if offset == 0:\n", 59 | " return x.copy()\n", 60 | " out = np.empty_like(x)\n", 61 | " out[:offset] = np.nan\n", 62 | " out[offset:] = x[:-offset]\n", 63 | " return out" 64 | ] 65 | } 66 | ], 67 | "metadata": { 68 | "kernelspec": { 69 | "display_name": "python3", 70 | "language": "python", 71 | "name": "python3" 72 | } 73 | }, 74 | "nbformat": 4, 75 | "nbformat_minor": 5 76 | } 77 | -------------------------------------------------------------------------------- /nbs/distributed.models.dask.lgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.dask.lgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "12972535-1a7c-4814-a19c-5e2c48824e85", 16 | "metadata": {}, 17 | "source": [ 18 | "# DaskLGBMForecast\n", 19 | "\n", 20 | "> dask LightGBM forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "fd9d0998-ca46-4e7a-9c64-b8378c0c1b85", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `lightgbm.dask.DaskLGBMRegressor` that adds a `model_` property that contains the fitted booster and is sent to the workers to in the forecasting step." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import warnings\n", 40 | "\n", 41 | "import lightgbm as lgb" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "#|export\n", 52 | "class DaskLGBMForecast(lgb.dask.DaskLGBMRegressor):\n", 53 | " if lgb.__version__ < \"3.3.0\":\n", 54 | " warnings.warn(\n", 55 | " \"It is recommended to install LightGBM version >= 3.3.0, since \"\n", 56 | " \"the current LightGBM version might be affected by https://github.com/microsoft/LightGBM/issues/4026, \"\n", 57 | " \"which was fixed in 3.3.0\"\n", 58 | " )\n", 59 | "\n", 60 | " @property\n", 61 | " def model_(self):\n", 62 | " return self.to_local()" 63 | ] 64 | } 65 | ], 66 | "metadata": { 67 | "kernelspec": { 68 | "display_name": "python3", 69 | "language": "python", 70 | "name": "python3" 71 | } 72 | }, 73 | "nbformat": 4, 74 | "nbformat_minor": 5 75 | } 76 | -------------------------------------------------------------------------------- /nbs/distributed.models.dask.xgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.dask.xgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "5ee154af-e882-4914-8bf2-f536a8d01b94", 16 | "metadata": {}, 17 | "source": [ 18 | "# DaskXGBForecast\n", 19 | "\n", 20 | "> dask XGBoost forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "4f4c7bc1-9779-4771-8224-f852e6b7987c", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `xgboost.dask.DaskXGBRegressor` that adds a `model_` property that contains the fitted model and is sent to the workers in 
the forecasting step." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import xgboost as xgb" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "#|export\n", 50 | "class DaskXGBForecast(xgb.dask.DaskXGBRegressor):\n", 51 | "    @property\n", 52 | "    def model_(self):\n", 53 | "        model_str = self.get_booster().save_raw('ubj')\n", 54 | "        local_model = xgb.XGBRegressor()\n", 55 | "        local_model.load_model(model_str)\n", 56 | "        return local_model" 57 | ] 58 | } 59 | ], 60 | "metadata": { 61 | "kernelspec": { 62 | "display_name": "python3", 63 | "language": "python", 64 | "name": "python3" 65 | } 66 | }, 67 | "nbformat": 4, 68 | "nbformat_minor": 5 69 | } 70 | -------------------------------------------------------------------------------- /nbs/distributed.models.ray.lgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.ray.lgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "12972535-1a7c-4814-a19c-5e2c48824e85", 16 | "metadata": {}, 17 | "source": [ 18 | "# RayLGBMForecast\n", 19 | "\n", 20 | "> ray LightGBM forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "fd9d0998-ca46-4e7a-9c64-b8378c0c1b85", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `lightgbm_ray.RayLGBMRegressor` that adds a `model_` property that contains the fitted booster and is sent to the workers in the forecasting step." 
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import lightgbm as lgb\n", 40 | "from lightgbm_ray import RayLGBMRegressor" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "#|export\n", 51 | "class RayLGBMForecast(RayLGBMRegressor):\n", 52 | "    @property\n", 53 | "    def model_(self):\n", 54 | "        return self._lgb_ray_to_local(lgb.LGBMRegressor)" 55 | ] 56 | } 57 | ], 58 | "metadata": { 59 | "kernelspec": { 60 | "display_name": "python3", 61 | "language": "python", 62 | "name": "python3" 63 | } 64 | }, 65 | "nbformat": 4, 66 | "nbformat_minor": 5 67 | } 68 | -------------------------------------------------------------------------------- /nbs/distributed.models.ray.xgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.ray.xgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "5ee154af-e882-4914-8bf2-f536a8d01b94", 16 | "metadata": {}, 17 | "source": [ 18 | "# RayXGBForecast\n", 19 | "\n", 20 | "> ray XGBoost forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "4f4c7bc1-9779-4771-8224-f852e6b7987c", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `xgboost_ray.RayXGBRegressor` that adds a `model_` property that contains the fitted model and is sent to the workers in the forecasting step."
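Both XGBoost wrappers (the dask one above and the ray one below) rebuild the local model the same way: they round-trip the trained booster through its binary UBJSON dump. A minimal, self-contained sketch of that round-trip; the toy data and the single-process `xgb.XGBRegressor` are stand-ins for the distributed training, not part of the library:

```python
import numpy as np
import xgboost as xgb

# toy data; a locally trained model stands in for the distributed one
X = np.random.default_rng(0).random((100, 4))
y = X @ np.array([1.0, 2.0, 3.0, 4.0])
model = xgb.XGBRegressor(n_estimators=10).fit(X, y)

# the round-trip performed by `model_`: dump the booster to the UBJSON
# binary format and load it into a fresh, picklable local regressor
raw = model.get_booster().save_raw('ubj')
local_model = xgb.XGBRegressor()
local_model.load_model(raw)

# the rebuilt model produces the same predictions
np.testing.assert_allclose(model.predict(X), local_model.predict(X))
```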
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import xgboost as xgb\n", 40 | "from xgboost_ray import RayXGBRegressor" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "#|export\n", 51 | "class RayXGBForecast(RayXGBRegressor):\n", 52 | " @property\n", 53 | " def model_(self):\n", 54 | " model_str = self.get_booster().save_raw(\"ubj\")\n", 55 | " local_model = xgb.XGBRegressor()\n", 56 | " local_model.load_model(model_str)\n", 57 | " return local_model" 58 | ] 59 | } 60 | ], 61 | "metadata": { 62 | "kernelspec": { 63 | "display_name": "python3", 64 | "language": "python", 65 | "name": "python3" 66 | } 67 | }, 68 | "nbformat": 4, 69 | "nbformat_minor": 5 70 | } 71 | -------------------------------------------------------------------------------- /nbs/distributed.models.spark.lgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.spark.lgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "12972535-1a7c-4814-a19c-5e2c48824e85", 16 | "metadata": {}, 17 | "source": [ 18 | "# SparkLGBMForecast\n", 19 | "\n", 20 | "> spark LightGBM forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "fd9d0998-ca46-4e7a-9c64-b8378c0c1b85", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `synapse.ml.lightgbm.LightGBMRegressor` that adds an `extract_local_model` method to get a local version of the trained model and broadcast it to the workers." 
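Synapse hands over the trained model as a plain text dump, so `extract_local_model` (defined below) only needs to rebuild an `lgb.Booster` from that string. A hedged sketch of the same round-trip that avoids needing a Spark session: a locally trained booster stands in for the Spark-trained one, and `model_to_string` plays the role of `getNativeModel`:

```python
import numpy as np
import lightgbm as lgb

# toy data; a locally trained booster stands in for the Synapse model
X = np.random.default_rng(0).random((100, 4))
y = X.sum(axis=1)
booster = lgb.train(
    {'objective': 'regression', 'verbosity': -1},
    lgb.Dataset(X, y),
    num_boost_round=10,
)

# the string round-trip performed by `extract_local_model`
model_str = booster.model_to_string()  # what getNativeModel() returns on Spark
local_model = lgb.Booster(model_str=model_str)
np.testing.assert_allclose(booster.predict(X), local_model.predict(X))
```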
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import lightgbm as lgb\n", 40 | "try:\n", 41 | " from synapse.ml.lightgbm import LightGBMRegressor\n", 42 | "except ModuleNotFoundError:\n", 43 | " import os\n", 44 | " \n", 45 | " if os.getenv('QUARTO_PREVIEW', '0') == '1' or os.getenv('IN_TEST', '0') == '1':\n", 46 | " LightGBMRegressor = object\n", 47 | " else:\n", 48 | " raise" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "#|export\n", 59 | "class SparkLGBMForecast(LightGBMRegressor):\n", 60 | " def _pre_fit(self, target_col):\n", 61 | " return self.setLabelCol(target_col)\n", 62 | " \n", 63 | " def extract_local_model(self, trained_model):\n", 64 | " model_str = trained_model.getNativeModel()\n", 65 | " local_model = lgb.Booster(model_str=model_str)\n", 66 | " return local_model" 67 | ] 68 | } 69 | ], 70 | "metadata": { 71 | "kernelspec": { 72 | "display_name": "python3", 73 | "language": "python", 74 | "name": "python3" 75 | } 76 | }, 77 | "nbformat": 4, 78 | "nbformat_minor": 5 79 | } 80 | -------------------------------------------------------------------------------- /nbs/distributed.models.spark.xgb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "533f8f53-cfa2-4560-a28f-1ce032a0949d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp distributed.models.spark.xgb" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "5ee154af-e882-4914-8bf2-f536a8d01b94", 16 | "metadata": {}, 17 | "source": [ 18 | "# SparkXGBForecast\n", 19 | "\n", 20 | "> spark XGBoost forecaster" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "4f4c7bc1-9779-4771-8224-f852e6b7987c", 26 | "metadata": {}, 27 | "source": [ 28 | "Wrapper of `xgboost.spark.SparkXGBRegressor` that adds an `extract_local_model` method to get a local version of the trained model and broadcast it to the workers." 
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "dbae0b4a-545c-472f-8ead-549830fb071c", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#|export\n", 39 | "import xgboost as xgb\n", 40 | "try:\n", 41 | " from xgboost.spark import SparkXGBRegressor # type: ignore\n", 42 | "except ModuleNotFoundError:\n", 43 | " import os\n", 44 | " \n", 45 | " if os.getenv('IN_TEST', '0') == '1':\n", 46 | " SparkXGBRegressor = object\n", 47 | " else:\n", 48 | " raise" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "ef31c6d5-7fb6-4a08-8d72-bfcdc1ae8540", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "#|export\n", 59 | "class SparkXGBForecast(SparkXGBRegressor): \n", 60 | " def _pre_fit(self, target_col):\n", 61 | " self.setParams(label_col=target_col)\n", 62 | " return self\n", 63 | "\n", 64 | " def extract_local_model(self, trained_model):\n", 65 | " model_str = trained_model.get_booster().save_raw('ubj')\n", 66 | " local_model = xgb.XGBRegressor()\n", 67 | " local_model.load_model(model_str)\n", 68 | " return local_model" 69 | ] 70 | } 71 | ], 72 | "metadata": { 73 | "kernelspec": { 74 | "display_name": "python3", 75 | "language": "python", 76 | "name": "python3" 77 | } 78 | }, 79 | "nbformat": 4, 80 | "nbformat_minor": 5 81 | } 82 | -------------------------------------------------------------------------------- /nbs/docs/getting-started/install.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e4498c76-c1a4-4bb4-ac55-96f9c0475acc", 6 | "metadata": {}, 7 | "source": [ 8 | "# Install\n", 9 | "\n", 10 | "> Instructions to install the package from different sources." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "a2312c2d-b391-4d19-a1bd-f99339c290d7", 16 | "metadata": {}, 17 | "source": [ 18 | "## Released versions" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "id": "d465c613-a0a6-4538-9497-57984c261dc0", 24 | "metadata": {}, 25 | "source": [ 26 | "### PyPI" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "417f8f55-8000-4595-af03-ab88bcc62488", 32 | "metadata": {}, 33 | "source": [ 34 | "#### Latest release" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "id": "0ea582c6-04e1-4c0f-a4bd-6e551f08f20d", 40 | "metadata": {}, 41 | "source": [ 42 | "To install the latest release of mlforecast from [PyPI](https://pypi.org/project/mlforecast/) you just have to run the following in a terminal:\n", 43 | "\n", 44 | "`pip install mlforecast`" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "342b7c8d-4bb3-43e0-8ee8-73bfd65e2b5f", 50 | "metadata": {}, 51 | "source": [ 52 | "#### Specific version" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "a09a5c90-cce9-4c49-ac38-f8569f760a98", 58 | "metadata": {}, 59 | "source": [ 60 | "If you want a specific version you can include a filter, for example:\n", 61 | "\n", 62 | "* `pip install \"mlforecast==0.3.0\"` to install the 0.3.0 version\n", 63 | "* `pip install \"mlforecast<0.4.0\"` to install any version prior to 0.4.0" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "id": "108f8d92-506f-4dd0-ab3f-7fda491e616c", 69 | "metadata": {}, 70 | "source": [ 71 | "#### Extras" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "7b575d6e-8ffa-42b1-9b1a-f3a6a28085e0", 77 | "metadata": {}, 78 | "source": [ 79 | "##### polars\n", 80 | "\n", 81 | "Using polars dataframes: `pip install \"mlforecast[polars]\"`" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "id": "738f3c23-7969-423b-8177-e08849c9c16f", 87 | "metadata": {}, 88 | "source": [ 89 | "##### Saving to remote storages\n", 90 | "\n", 91 | "If you want to save your forecast artifacts to a remote storage like S3 or GCS you can use the following extras:\n", 92 | "\n", 93 | "* Saving to S3: `pip install \"mlforecast[aws]\"`\n", 94 | "* Saving to Google Cloud Storage: `pip install \"mlforecast[gcp]\"`\n", 95 | "* Saving to Azure Data Lake: `pip install \"mlforecast[azure]\"` " 96 | ] 97 | }, 98 | { 99 | "attachments": {}, 100 | "cell_type": "markdown", 101 | "id": "6e1d3669-bc0f-4458-98a6-cecc5ca4c58a", 102 | "metadata": {}, 103 | "source": [ 104 | "##### Distributed training\n", 105 | "\n", 106 | "If you want to perform distributed training you can use either dask, ray or spark. 
Once you know which framework you want to use you can include its extra:\n", 107 | "\n", 108 | "* dask: `pip install \"mlforecast[dask]\"`\n", 109 | "* ray: `pip install \"mlforecast[ray]\"`\n", 110 | "* spark: `pip install \"mlforecast[spark]\"`" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "id": "6772b367-7e1b-4612-8bcd-26039b2badf3", 116 | "metadata": {}, 117 | "source": [ 118 | "### Conda" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "id": "3df21509-dcd1-433e-8a3e-a9f5bce7dc51", 124 | "metadata": {}, 125 | "source": [ 126 | "#### Latest release" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "id": "13744f0d-6916-4358-be73-25209033eb74", 132 | "metadata": {}, 133 | "source": [ 134 | "The mlforecast package is also published to [conda-forge](https://anaconda.org/conda-forge/mlforecast), so you can install it by running the following in a terminal:\n", 135 | "\n", 136 | "`conda install -c conda-forge mlforecast`\n", 137 | "\n", 138 | "Note that this happens about a day after it is published to PyPI, so you may have to wait to get the latest release." 139 | ] 140 | }, 141 | { 142 | "attachments": {}, 143 | "cell_type": "markdown", 144 | "id": "b0bd307a-1e96-4e3f-9a48-fed6eb0dc38d", 145 | "metadata": {}, 146 | "source": [ 147 | "#### Specific version\n", 148 | "\n", 149 | "If you want a specific version you can include a filter, for example:\n", 150 | "\n", 151 | "* `conda install -c conda-forge \"mlforecast==0.3.0\"` to install the 0.3.0 version\n", 152 | "* `conda install -c conda-forge \"mlforecast<0.4.0\"` to install any version prior to 0.4.0" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "id": "937ed413-2207-43f2-965e-62ebbaf0c8db", 158 | "metadata": {}, 159 | "source": [ 160 | "## Development version" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "id": "ec4a0e27-24c3-4f79-929c-c77f891eb27e", 166 | "metadata": {}, 167 | "source": [ 168 | "If you want to try out a new feature that hasn't made it into a release yet you have the following options:\n", 169 | "\n", 170 | "* Install from GitHub: `pip install git+https://github.com/Nixtla/mlforecast`\n", 171 | "* Clone and install: `git clone https://github.com/Nixtla/mlforecast mlforecast-dev && pip install mlforecast-dev/`, which will install the version from the current main branch."
172 | ] 173 | } 174 | ], 175 | "metadata": { 176 | "kernelspec": { 177 | "display_name": "python3", 178 | "language": "python", 179 | "name": "python3" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 5 184 | } 185 | -------------------------------------------------------------------------------- /nbs/docs/how-to-guides/custom_date_features.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "84f2bb63-84d6-4f04-a5d7-3ef146e4bc45", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#| hide\n", 11 | "%load_ext autoreload\n", 12 | "%autoreload 2" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "4ea65c79-0e5d-4b83-b6e7-95c76ad9c472", 18 | "metadata": {}, 19 | "source": [ 20 | "# Custom date features\n", 21 | "> Define your own functions to be used as date features" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "99e98cae-11e0-4823-8603-ffe1f5f86734", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from mlforecast import MLForecast\n", 32 | "from mlforecast.utils import generate_daily_series" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "5c156e10-15fa-4f85-830d-906c611e9432", 38 | "metadata": {}, 39 | "source": [ 40 | "The `date_features` argument of MLForecast can take pandas date attributes as well as functions that take a [pandas DatetimeIndex](https://pandas.pydata.org/docs/reference/api/pandas.DatetimeIndex.html) and return a numeric value. The name of the function is used as the name of the feature, so please use unique and descriptive names." 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "e02a3fe7-ae9e-41fe-835b-22678bb28448", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "series = generate_daily_series(1, min_length=6, max_length=6)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "0647652d-b1a7-4506-8fa4-c7e7e94db915", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "def even_day(dates):\n", 61 | " \"\"\"Day of month is even\"\"\"\n", 62 | " return dates.day % 2 == 0\n", 63 | "\n", 64 | "def month_start_or_end(dates):\n", 65 | " \"\"\"Date is month start or month end\"\"\"\n", 66 | " return dates.is_month_start | dates.is_month_end\n", 67 | "\n", 68 | "def is_monday(dates):\n", 69 | " \"\"\"Date is monday\"\"\"\n", 70 | " return dates.dayofweek == 0" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "4bcc2ce3-0e5a-4e3d-886b-5e3dd1c7d45a", 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "data": { 81 | "text/html": [ 82 | "
\n", 83 | "\n", 96 | "\n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | "
unique_iddsydayofweekdayofyeareven_daymonth_start_or_endis_monday
0id_02000-01-010.27440751FalseTrueFalse
1id_02000-01-021.35759562TrueFalseFalse
2id_02000-01-032.30138203FalseFalseTrue
3id_02000-01-043.27244214TrueFalseFalse
4id_02000-01-054.21182725FalseFalseFalse
5id_02000-01-065.32294736TrueFalseFalse
\n", 179 | "
" 180 | ], 181 | "text/plain": [ 182 | " unique_id ds y dayofweek dayofyear even_day \\\n", 183 | "0 id_0 2000-01-01 0.274407 5 1 False \n", 184 | "1 id_0 2000-01-02 1.357595 6 2 True \n", 185 | "2 id_0 2000-01-03 2.301382 0 3 False \n", 186 | "3 id_0 2000-01-04 3.272442 1 4 True \n", 187 | "4 id_0 2000-01-05 4.211827 2 5 False \n", 188 | "5 id_0 2000-01-06 5.322947 3 6 True \n", 189 | "\n", 190 | " month_start_or_end is_monday \n", 191 | "0 True False \n", 192 | "1 False False \n", 193 | "2 False True \n", 194 | "3 False False \n", 195 | "4 False False \n", 196 | "5 False False " 197 | ] 198 | }, 199 | "execution_count": null, 200 | "metadata": {}, 201 | "output_type": "execute_result" 202 | } 203 | ], 204 | "source": [ 205 | "fcst = MLForecast(\n", 206 | " [],\n", 207 | " freq='D',\n", 208 | " date_features=['dayofweek', 'dayofyear', even_day, month_start_or_end, is_monday]\n", 209 | ")\n", 210 | "fcst.preprocess(series)" 211 | ] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": "python3", 217 | "language": "python", 218 | "name": "python3" 219 | } 220 | }, 221 | "nbformat": 4, 222 | "nbformat_minor": 5 223 | } 224 | -------------------------------------------------------------------------------- /nbs/docs/how-to-guides/custom_training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "36511bd5-1402-4ee7-b28f-36a0022a018b", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#| hide\n", 11 | "%load_ext autoreload\n", 12 | "%autoreload 2" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "f448deb8-57d3-4ecc-8b71-74968b3b5ae8", 18 | "metadata": {}, 19 | "source": [ 20 | "# Custom training\n", 21 | "> Customize the training procedure for your models" 22 | ] 23 | }, 24 | { 25 | "attachments": {}, 26 | "cell_type": "markdown", 27 | "id": "09e62587-193a-4df4-be3a-7e552ac3f805", 28 | "metadata": {}, 29 | "source": [ 30 | "mlforecast abstracts away most of the training details, which is useful for iterating quickly. However, sometimes you want more control over the fit parameters, the data that goes into the model, etc. This guide shows how you can train a model in a specific way and then giving it back to mlforecast to produce forecasts with it." 
31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "7170807e-ec08-46f5-be8a-1337aa7fa28b", 36 | "metadata": {}, 37 | "source": [ 38 | "## Data setup" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "id": "7db156b4-ac87-4c78-9197-beb427cedecb", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "from mlforecast.utils import generate_daily_series" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "23d5b34d-e873-4349-9c0a-5d2674229c00", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "series = generate_daily_series(5)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "id": "88860248-cc4d-4d74-aba5-400eca036111", 64 | "metadata": {}, 65 | "source": [ 66 | "## Creating forecast object" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "9f8a9e22-1634-4545-975e-f90c95ea7c5f", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "import lightgbm as lgb\n", 77 | "import numpy as np\n", 78 | "from sklearn.linear_model import LinearRegression\n", 79 | "\n", 80 | "from mlforecast import MLForecast" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "bc23f054-2bd9-4194-9b1d-4b0e57e2923d", 86 | "metadata": {}, 87 | "source": [ 88 | "Suppose we want to train a linear regression with the default settings." 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "0967fcd1-faa2-436f-b7fd-5c5397e8bba2", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "fcst = MLForecast(\n", 99 | " models={'lr': LinearRegression()},\n", 100 | " freq='D',\n", 101 | " lags=[1],\n", 102 | " date_features=['dayofweek'],\n", 103 | ")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "id": "ab9bcca5-03a1-4452-baed-e0577a1df3a1", 109 | "metadata": {}, 110 | "source": [ 111 | "## Generate training set" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "5d663091-d1db-4ff2-a0f0-749f71dcc5e5", 117 | "metadata": {}, 118 | "source": [ 119 | "Use `MLForecast.preprocess` to generate the training data." 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "5ffd0ca0-f6b7-487f-92cc-188a5f10c674", 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/html": [ 131 | "
\n", 132 | "\n", 145 | "\n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | "
unique_iddsylag1dayofweek
1id_02000-01-021.4236260.4289736
2id_02000-01-032.3117821.4236260
3id_02000-01-043.1921912.3117821
4id_02000-01-054.1487673.1921912
5id_02000-01-065.0283564.1487673
\n", 199 | "
" 200 | ], 201 | "text/plain": [ 202 | " unique_id ds y lag1 dayofweek\n", 203 | "1 id_0 2000-01-02 1.423626 0.428973 6\n", 204 | "2 id_0 2000-01-03 2.311782 1.423626 0\n", 205 | "3 id_0 2000-01-04 3.192191 2.311782 1\n", 206 | "4 id_0 2000-01-05 4.148767 3.192191 2\n", 207 | "5 id_0 2000-01-06 5.028356 4.148767 3" 208 | ] 209 | }, 210 | "execution_count": null, 211 | "metadata": {}, 212 | "output_type": "execute_result" 213 | } 214 | ], 215 | "source": [ 216 | "prep = fcst.preprocess(series)\n", 217 | "prep.head()" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "id": "f5f82a08-743b-489d-8426-55747dce1ba1", 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "X = prep.drop(columns=['unique_id', 'ds', 'y'])\n", 228 | "y = prep['y']" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "id": "7408d5e2-7153-4451-9ea9-e765544e4a1b", 234 | "metadata": {}, 235 | "source": [ 236 | "## Regular training" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "id": "e56b7678-41e3-4501-a91a-2b07de6967dc", 242 | "metadata": {}, 243 | "source": [ 244 | "Since we don't want to do anything special in our training process for the linear regression, we can just call `MLForecast.fit_models`" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "id": "b94ad115-0606-40d7-87d5-762b1c716418", 251 | "metadata": {}, 252 | "outputs": [ 253 | { 254 | "data": { 255 | "text/plain": [ 256 | "MLForecast(models=[lr], freq=D, lag_features=['lag1'], date_features=['dayofweek'], num_threads=1)" 257 | ] 258 | }, 259 | "execution_count": null, 260 | "metadata": {}, 261 | "output_type": "execute_result" 262 | } 263 | ], 264 | "source": [ 265 | "fcst.fit_models(X, y)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "id": "e3701c65-e28f-4ae9-8c72-1db64f411963", 271 | "metadata": {}, 272 | "source": [ 273 | "This has trained the linear regression model and is now available in the `MLForecast.models_` attribute." 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "id": "bdd2aa2e-0c2e-4843-86c3-4588d69b28ce", 280 | "metadata": {}, 281 | "outputs": [ 282 | { 283 | "data": { 284 | "text/plain": [ 285 | "{'lr': LinearRegression()}" 286 | ] 287 | }, 288 | "execution_count": null, 289 | "metadata": {}, 290 | "output_type": "execute_result" 291 | } 292 | ], 293 | "source": [ 294 | "fcst.models_" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "id": "112bde47-6af9-4336-8529-a24435d35f2b", 300 | "metadata": {}, 301 | "source": [ 302 | "## Custom training" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "id": "8add7f5e-b4ec-4115-9cab-f2e96e0b138a", 308 | "metadata": {}, 309 | "source": [ 310 | "Now suppose you also want to train a LightGBM model on the same data, but treating the day of the week as a categorical feature and logging the train loss." 
311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "id": "b4de2390-37da-46ce-9ba4-723db3b0e95e", 317 | "metadata": {}, 318 | "outputs": [ 319 | { 320 | "name": "stdout", 321 | "output_type": "stream", 322 | "text": [ 323 | "[20]\ttraining's l2: 0.0823528\n", 324 | "[40]\ttraining's l2: 0.0230292\n", 325 | "[60]\ttraining's l2: 0.0207829\n", 326 | "[80]\ttraining's l2: 0.019675\n", 327 | "[100]\ttraining's l2: 0.018778\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "model = lgb.LGBMRegressor(n_estimators=100, verbosity=-1)\n", 333 | "model.fit(\n", 334 | " X,\n", 335 | " y,\n", 336 | " eval_set=[(X, y)],\n", 337 | " categorical_feature=['dayofweek'],\n", 338 | " callbacks=[lgb.log_evaluation(20)],\n", 339 | ");" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "id": "23bab8e2-514f-4d42-ba76-e5792a80f3e9", 345 | "metadata": {}, 346 | "source": [ 347 | "## Computing forecasts" 348 | ] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "id": "9b82ab00-c292-434f-9ad8-0a0c0859a35a", 353 | "metadata": {}, 354 | "source": [ 355 | "Now we just assign this model to the `MLForecast.models_` dictionary. Note that you can assign as many models as you want." 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "id": "8c908b51-fa59-406d-bb89-ea8f3ac7c9df", 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "data": { 366 | "text/plain": [ 367 | "{'lr': LinearRegression(), 'lgbm': LGBMRegressor(verbosity=-1)}" 368 | ] 369 | }, 370 | "execution_count": null, 371 | "metadata": {}, 372 | "output_type": "execute_result" 373 | } 374 | ], 375 | "source": [ 376 | "fcst.models_['lgbm'] = model\n", 377 | "fcst.models_" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "id": "8a4cea7b-9cc9-48e8-b7af-08407d4cb101", 383 | "metadata": {}, 384 | "source": [ 385 | "And now when calling `MLForecast.predict`, mlforecast will use those models to compute the forecasts." 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": null, 391 | "id": "83791662-2031-4dc2-85cd-88e554c6ee7a", 392 | "metadata": {}, 393 | "outputs": [ 394 | { 395 | "data": { 396 | "text/html": [ 397 | "
\n", 398 | "\n", 411 | "\n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | "
unique_iddslrlgbm
0id_02000-08-103.5491245.166797
1id_12000-04-073.1542854.252490
2id_22000-06-162.8809333.224506
3id_32000-08-304.0618010.245443
4id_42001-01-082.9048722.225106
\n", 459 | "
" 460 | ], 461 | "text/plain": [ 462 | " unique_id ds lr lgbm\n", 463 | "0 id_0 2000-08-10 3.549124 5.166797\n", 464 | "1 id_1 2000-04-07 3.154285 4.252490\n", 465 | "2 id_2 2000-06-16 2.880933 3.224506\n", 466 | "3 id_3 2000-08-30 4.061801 0.245443\n", 467 | "4 id_4 2001-01-08 2.904872 2.225106" 468 | ] 469 | }, 470 | "execution_count": null, 471 | "metadata": {}, 472 | "output_type": "execute_result" 473 | } 474 | ], 475 | "source": [ 476 | "fcst.predict(1)" 477 | ] 478 | } 479 | ], 480 | "metadata": { 481 | "kernelspec": { 482 | "display_name": "python3", 483 | "language": "python", 484 | "name": "python3" 485 | } 486 | }, 487 | "nbformat": 4, 488 | "nbformat_minor": 5 489 | } 490 | -------------------------------------------------------------------------------- /nbs/docs/how-to-guides/one_model_per_horizon.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "02fdd2f9-c0c1-4d3c-96eb-63f600245d75", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#| hide\n", 11 | "%load_ext autoreload\n", 12 | "%autoreload 2" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "9fa8fdaa-85d7-4bcb-bc4f-8ef94811d664", 18 | "metadata": {}, 19 | "source": [ 20 | "# One model per step\n", 21 | "> Train one model to predict each step of the forecasting horizon" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "750a79d9-cf24-46a8-aba0-3b8a34916ef6", 27 | "metadata": {}, 28 | "source": [ 29 | "By default mlforecast uses the recursive strategy, i.e. a model is trained to predict the next value and if we're predicting several values we do it one at a time and then use the model's predictions as the new target, recompute the features and predict the next step.\n", 30 | "\n", 31 | "There's another approach where if we want to predict 10 steps ahead we train 10 different models, where each model is trained to predict the value at each specific step, i.e. one model predicts the next value, another one predicts the value two steps ahead and so on. This can be very time consuming but can also provide better results. If you want to use this approach you can specify `max_horizon` in `MLForecast.fit`, which will train that many models and each model will predict its corresponding horizon when you call `MLForecast.predict`." 
32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "id": "8481d0c4-035d-45f5-b64a-c99e0bd49922", 37 | "metadata": {}, 38 | "source": [ 39 | "## Setup" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "5d48bcd4-6fdc-41b3-adc9-3d5f83197028", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "import random\n", 50 | "import lightgbm as lgb\n", 51 | "import pandas as pd\n", 52 | "from datasetsforecast.m4 import M4, M4Info\n", 53 | "from utilsforecast.evaluation import evaluate\n", 54 | "from utilsforecast.losses import smape\n", 55 | "\n", 56 | "from mlforecast import MLForecast\n", 57 | "from mlforecast.lag_transforms import ExponentiallyWeightedMean, RollingMean\n", 58 | "from mlforecast.target_transforms import Differences" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "id": "92b3efda-abf4-4eaf-b1a9-48056321bc9f", 64 | "metadata": {}, 65 | "source": [ 66 | "### Data" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "5e9d3d66-3807-4549-8bde-4e1d9eda2f05", 72 | "metadata": {}, 73 | "source": [ 74 | "We will use four random series from the M4 dataset" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "90fc0561-04fc-4bfe-b823-1fe5d09bd90a", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "group = 'Hourly'\n", 85 | "await M4.async_download('data', group=group)\n", 86 | "df, *_ = M4.load(directory='data', group=group)\n", 87 | "df['ds'] = df['ds'].astype('int')\n", 88 | "ids = df['unique_id'].unique()\n", 89 | "random.seed(0)\n", 90 | "sample_ids = random.choices(ids, k=4)\n", 91 | "sample_df = df[df['unique_id'].isin(sample_ids)]\n", 92 | "info = M4Info[group]\n", 93 | "horizon = info.horizon\n", 94 | "valid = sample_df.groupby('unique_id').tail(horizon)\n", 95 | "train = sample_df.drop(valid.index)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "ad1761c0-60af-4e13-97fc-e991b5361c5e", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "def avg_smape(df):\n", 106 | " \"\"\"Computes the SMAPE by serie and then averages it across all series.\"\"\"\n", 107 | " full = df.merge(valid)\n", 108 | " return (\n", 109 | " evaluate(full, metrics=[smape])\n", 110 | " .drop(columns='metric')\n", 111 | " .set_index('unique_id')\n", 112 | " .squeeze()\n", 113 | " )" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "id": "ed8bd708-d042-485b-8a00-d699680cf2db", 119 | "metadata": {}, 120 | "source": [ 121 | "## Model" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "id": "bf2837f5-e579-4f02-88b5-e3866e9c049d", 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "fcst = MLForecast(\n", 132 | " models=lgb.LGBMRegressor(random_state=0, verbosity=-1),\n", 133 | " freq=1,\n", 134 | " lags=[24 * (i+1) for i in range(7)],\n", 135 | " lag_transforms={\n", 136 | " 1: [RollingMean(window_size=24)],\n", 137 | " 24: [RollingMean(window_size=24)],\n", 138 | " 48: [ExponentiallyWeightedMean(alpha=0.3)],\n", 139 | " },\n", 140 | " num_threads=1,\n", 141 | " target_transforms=[Differences([24])],\n", 142 | ")" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "6ec0b126-1367-49a8-a3f4-235fee969908", 149 | "metadata": {}, 150 | "outputs": [ 151 | { 152 | "name": "stdout", 153 | "output_type": "stream", 154 | "text": [ 155 | "Average SMAPE per method and serie\n" 156 | ] 157 | }, 158 | { 159 | "data": { 160 | "text/html": [ 
161 | "
\n", 162 | "\n", 175 | "\n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | "
individualrecursive
unique_id
H1960.3%0.3%
H2560.4%0.3%
H38120.9%9.5%
H41311.9%13.6%
\n", 211 | "
" 212 | ], 213 | "text/plain": [ 214 | " individual recursive\n", 215 | "unique_id \n", 216 | "H196 0.3% 0.3%\n", 217 | "H256 0.4% 0.3%\n", 218 | "H381 20.9% 9.5%\n", 219 | "H413 11.9% 13.6%" 220 | ] 221 | }, 222 | "execution_count": null, 223 | "metadata": {}, 224 | "output_type": "execute_result" 225 | } 226 | ], 227 | "source": [ 228 | "horizon = 24\n", 229 | "# the following will train 24 models, one for each horizon\n", 230 | "individual_fcst = fcst.fit(train, max_horizon=horizon)\n", 231 | "individual_preds = individual_fcst.predict(horizon)\n", 232 | "avg_smape_individual = avg_smape(individual_preds).rename('individual')\n", 233 | "# the following will train a single model and use the recursive strategy\n", 234 | "recursive_fcst = fcst.fit(train)\n", 235 | "recursive_preds = recursive_fcst.predict(horizon)\n", 236 | "avg_smape_recursive = avg_smape(recursive_preds).rename('recursive')\n", 237 | "# results\n", 238 | "print('Average SMAPE per method and serie')\n", 239 | "avg_smape_individual.to_frame().join(avg_smape_recursive).applymap('{:.1%}'.format)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "id": "430eff81-dd27-42bf-b852-093ded7fc2e2", 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "#| hide\n", 250 | "# we get the same prediction for the first timestep\n", 251 | "pd.testing.assert_frame_equal(\n", 252 | " individual_preds.groupby('unique_id').head(1).astype({'ds': 'int64'}),\n", 253 | " recursive_preds.groupby('unique_id').head(1).astype({'ds': 'int64'}), \n", 254 | ")" 255 | ] 256 | } 257 | ], 258 | "metadata": { 259 | "kernelspec": { 260 | "display_name": "python3", 261 | "language": "python", 262 | "name": "python3" 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 5 267 | } 268 | -------------------------------------------------------------------------------- /nbs/docs/how-to-guides/predict_subset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "caabe6e0-3eb2-4b58-85b4-20aad5de34e5", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#| hide\n", 11 | "%load_ext autoreload\n", 12 | "%autoreload 2" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "c252c6e4-15f4-4789-a07c-ad326daba639", 18 | "metadata": {}, 19 | "source": [ 20 | "# Predicting a subset of ids\n", 21 | "> Compute predictions for only a subset of the training ids " 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "e9c959e7-6819-477d-b345-c2734785397a", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from lightgbm import LGBMRegressor\n", 32 | "from fastcore.test import test_fail\n", 33 | "\n", 34 | "from mlforecast import MLForecast\n", 35 | "from mlforecast.utils import generate_daily_series" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "b3e4e904-161a-4346-8165-ad352f4c8934", 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "data": { 46 | "text/html": [ 47 | "
\n", 48 | "\n", 61 | "\n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | "
unique_iddslgb
0id_02000-08-103.728396
1id_12000-04-074.749133
2id_22000-06-164.749133
3id_32000-08-302.758949
4id_42001-01-083.331394
\n", 103 | "
" 104 | ], 105 | "text/plain": [ 106 | " unique_id ds lgb\n", 107 | "0 id_0 2000-08-10 3.728396\n", 108 | "1 id_1 2000-04-07 4.749133\n", 109 | "2 id_2 2000-06-16 4.749133\n", 110 | "3 id_3 2000-08-30 2.758949\n", 111 | "4 id_4 2001-01-08 3.331394" 112 | ] 113 | }, 114 | "execution_count": null, 115 | "metadata": {}, 116 | "output_type": "execute_result" 117 | } 118 | ], 119 | "source": [ 120 | "series = generate_daily_series(5)\n", 121 | "fcst = MLForecast({'lgb': LGBMRegressor(verbosity=-1)}, freq='D', date_features=['dayofweek'])\n", 122 | "fcst.fit(series)\n", 123 | "all_preds = fcst.predict(1)\n", 124 | "all_preds" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "id": "7e7b3367-aa2a-44ce-8e59-e69237dd08b5", 130 | "metadata": {}, 131 | "source": [ 132 | "By default all series seen during training will be forecasted with the predict method. If you're only interested in predicting a couple of them you can use the `ids` argument." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "id": "551cebed-50c3-4d39-8034-296bc874aab4", 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "data": { 143 | "text/html": [ 144 | "
\n", 145 | "\n", 158 | "\n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | "
unique_iddslgb
0id_02000-08-103.728396
1id_42001-01-083.331394
\n", 182 | "
" 183 | ], 184 | "text/plain": [ 185 | " unique_id ds lgb\n", 186 | "0 id_0 2000-08-10 3.728396\n", 187 | "1 id_4 2001-01-08 3.331394" 188 | ] 189 | }, 190 | "execution_count": null, 191 | "metadata": {}, 192 | "output_type": "execute_result" 193 | } 194 | ], 195 | "source": [ 196 | "fcst.predict(1, ids=['id_0', 'id_4'])" 197 | ] 198 | }, 199 | { 200 | "cell_type": "markdown", 201 | "id": "a7c59a20-e374-4ebe-b2ce-9a262580f6f8", 202 | "metadata": {}, 203 | "source": [ 204 | "Note that the ids must've been seen during training, if you try to predict an id that wasn't there you'll get an error." 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "id": "63b5d5ac-f9cc-4af7-bf44-1d22f1c2e627", 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "test_fail(lambda: fcst.predict(1, ids=['fake_id']), contains='fake_id')" 215 | ] 216 | } 217 | ], 218 | "metadata": { 219 | "kernelspec": { 220 | "display_name": "python3", 221 | "language": "python", 222 | "name": "python3" 223 | } 224 | }, 225 | "nbformat": 4, 226 | "nbformat_minor": 5 227 | } 228 | -------------------------------------------------------------------------------- /nbs/docs/how-to-guides/sample_weights.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "da17c174-1220-44e7-a746-e6c7a1b175bf", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#| hide\n", 11 | "%load_ext autoreload\n", 12 | "%autoreload 2" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "11cbdd0d-b153-4367-aedd-31320d6e70e6", 18 | "metadata": {}, 19 | "source": [ 20 | "# Sample weights\n", 21 | "> Provide a column to pass through to the underlying models as sample weights" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "241ee142-6279-49db-a609-d26ec23f0825", 27 | "metadata": {}, 28 | "source": [ 29 | "## Data setup" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "e6c7432c-95ca-4c34-bc66-d6f317ef8b9b", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "import numpy as np\n", 40 | "from mlforecast.utils import generate_daily_series" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "8f942426-aa79-4d60-8d86-2cf3dc1c9cfb", 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "data": { 51 | "text/html": [ 52 | "
\n", 53 | "\n", 66 | "\n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | "
unique_iddsyweight
0id_02000-01-010.3575950.636962
1id_02000-01-021.3013820.269787
\n", 93 | "
" 94 | ], 95 | "text/plain": [ 96 | " unique_id ds y weight\n", 97 | "0 id_0 2000-01-01 0.357595 0.636962\n", 98 | "1 id_0 2000-01-02 1.301382 0.269787" 99 | ] 100 | }, 101 | "execution_count": null, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | } 105 | ], 106 | "source": [ 107 | "series = generate_daily_series(2)\n", 108 | "series['weight'] = np.random.default_rng(seed=0).random(series.shape[0])\n", 109 | "series.head(2)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "id": "c9a92b29-db7f-47e0-90b4-28502fb90c4a", 115 | "metadata": {}, 116 | "source": [ 117 | "## Creating forecast object" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "f30ca0d0-bd08-4131-86f8-d8f8685b9553", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "import lightgbm as lgb\n", 128 | "from sklearn.linear_model import LinearRegression\n", 129 | "\n", 130 | "from mlforecast import MLForecast" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "be951ea9-f722-4e33-b460-845b6d672015", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "fcst = MLForecast(\n", 141 | " models={\n", 142 | " 'lr': LinearRegression(),\n", 143 | " 'lgbm': lgb.LGBMRegressor(verbosity=-1),\n", 144 | " },\n", 145 | " freq='D',\n", 146 | " lags=[1],\n", 147 | " date_features=['dayofweek'],\n", 148 | ")" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "id": "5f4ce8fb-9281-4230-92e8-ead0d4288d48", 154 | "metadata": {}, 155 | "source": [ 156 | "## Forecasting\n", 157 | "You can provide the `weight_col` argument to `MLForecast.fit` to indicate which column should be used as the sample weights." 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "e30151a7-1d16-4852-800e-ddc88375020d", 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "data": { 168 | "text/html": [ 169 | "
\n", 170 | "\n", 183 | "\n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | "
unique_iddslrlgbm
0id_02000-08-103.3360195.283677
1id_12000-04-073.3007864.230655
\n", 210 | "
" 211 | ], 212 | "text/plain": [ 213 | " unique_id ds lr lgbm\n", 214 | "0 id_0 2000-08-10 3.336019 5.283677\n", 215 | "1 id_1 2000-04-07 3.300786 4.230655" 216 | ] 217 | }, 218 | "execution_count": null, 219 | "metadata": {}, 220 | "output_type": "execute_result" 221 | } 222 | ], 223 | "source": [ 224 | "fcst.fit(series, weight_col='weight').predict(1)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "id": "2bd56946-761b-4a12-9011-16b8588014c0", 230 | "metadata": {}, 231 | "source": [ 232 | "## Cross validation\n", 233 | "You can provide the `weight_col` argument to `MLForecast.cross_validation` to indicate which column should be used as the sample weights." 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "id": "45ca8c72-8440-42e4-a84e-6f68fd3b33a3", 240 | "metadata": {}, 241 | "outputs": [ 242 | { 243 | "data": { 244 | "text/html": [ 245 | "
\n", 246 | "\n", 259 | "\n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | "
unique_iddscutoffylrlgbm
0id_02000-08-082000-08-073.4363252.7707173.242790
1id_12000-04-052000-04-042.4302762.6879322.075247
2id_02000-08-092000-08-084.1367713.0951404.239010
3id_12000-04-062000-04-053.3635223.0166613.436962
\n", 310 | "
" 311 | ], 312 | "text/plain": [ 313 | " unique_id ds cutoff y lr lgbm\n", 314 | "0 id_0 2000-08-08 2000-08-07 3.436325 2.770717 3.242790\n", 315 | "1 id_1 2000-04-05 2000-04-04 2.430276 2.687932 2.075247\n", 316 | "2 id_0 2000-08-09 2000-08-08 4.136771 3.095140 4.239010\n", 317 | "3 id_1 2000-04-06 2000-04-05 3.363522 3.016661 3.436962" 318 | ] 319 | }, 320 | "execution_count": null, 321 | "metadata": {}, 322 | "output_type": "execute_result" 323 | } 324 | ], 325 | "source": [ 326 | "fcst.cross_validation(series, n_windows=2, h=1, weight_col='weight')" 327 | ] 328 | } 329 | ], 330 | "metadata": { 331 | "kernelspec": { 332 | "display_name": "python3", 333 | "language": "python", 334 | "name": "python3" 335 | } 336 | }, 337 | "nbformat": 4, 338 | "nbformat_minor": 5 339 | } 340 | -------------------------------------------------------------------------------- /nbs/docs/how-to-guides/transfer_learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Transfer Learning" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "Transfer learning refers to the process of pre-training a flexible model on a large dataset and using it later on other data with little to no training. It is one of the most outstanding 🚀 achievements in Machine Learning and has many practical applications.\n", 15 | "\n", 16 | "For time series forecasting, the technique allows you to get lightning-fast predictions ⚡ bypassing the tradeoff between accuracy and speed (more than 30 times faster than our already fast [AutoARIMA](https://github.com/Nixtla/statsforecast) for a similar accuracy).\n", 17 | "\n", 18 | "This notebook shows how to generate a pre-trained model to forecast new time series never seen by the model. 
\n", 19 | "\n", 20 | "Table of Contents\n", 21 | "\n", 22 | "- Installing MLForecast\n", 23 | "- Load M3 Monthly Data\n", 24 | "- Instantiate NeuralForecast core, Fit, and save\n", 25 | "- Use the pre-trained model to predict on AirPassengers\n", 26 | "- Evaluate Results" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "You can run these experiments with Google Colab.\n", 34 | "\n", 35 | "\"Open" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Installing Libraries" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "%%capture\n", 52 | "# !pip install mlforecast datasetsforecast utilsforecast s3fs" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import lightgbm as lgb\n", 62 | "import numpy as np\n", 63 | "import pandas as pd\n", 64 | "from datasetsforecast.m3 import M3\n", 65 | "from sklearn.metrics import mean_absolute_error\n", 66 | "from utilsforecast.plotting import plot_series\n", 67 | "\n", 68 | "from mlforecast import MLForecast\n", 69 | "from mlforecast.target_transforms import Differences" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "## Load M3 Data\n", 77 | "\n", 78 | "The `M3` class will automatically download the complete M3 dataset and process it.\n", 79 | "\n", 80 | "It return three Dataframes: `Y_df` contains the values for the target variables, `X_df` contains exogenous calendar features and `S_df` contains static features for each time-series. For this example we will only use `Y_df`.\n", 81 | "\n", 82 | "If you want to use your own data just replace `Y_df`. Be sure to use a long format and have a simmilar structure than our data set." 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "Y_df_M3, _, _ = M3.load(directory='./', group='Monthly')" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "In this tutorial we are only using `1_000` series to speed up computations. Remove the filter to use the whole dataset." 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "fig = plot_series(Y_df_M3)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "#| hide\n", 117 | "fig.savefig('../../figs/transfer_learning__eda.png', bbox_inches='tight')" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "![](../../figs/transfer_learning__eda.png)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "## Model Training\n", 132 | "\n", 133 | "Using the `MLForecast.fit` method you can train a set of models to your dataset. You can modify the hyperparameters of the model to get a better accuracy, in this case we will use the default hyperparameters of `lgb.LGBMRegressor`. 
" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "models = [lgb.LGBMRegressor(verbosity=-1)]" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "The `MLForecast` object has the following parameters: \n", 150 | "\n", 151 | "- `models`: a list of sklearn-like (`fit` and `predict`) models. \n", 152 | "- `freq`: a string indicating the frequency of the data. See [panda’s available frequencies.](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases)\n", 153 | "- `differences`: Differences to take of the target before computing the features. These are restored at the forecasting step.\n", 154 | "- `lags`: Lags of the target to use as features.\n", 155 | "\n", 156 | "In this example, we are only using `differences` and `lags` to produce features. See [the full documentation](https://nixtla.github.io/mlforecast/forecast.html) to see all available features.\n", 157 | "\n", 158 | "Any settings are passed into the constructor. Then you call its `fit` method and pass in the historical data frame `Y_df_M3`. " 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "fcst = MLForecast(\n", 168 | " models=models, \n", 169 | " lags=range(1, 13),\n", 170 | " freq='MS',\n", 171 | " target_transforms=[Differences([1, 12])],\n", 172 | ")\n", 173 | "fcst.fit(Y_df_M3);" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "## Transfer M3 to AirPassengers\n", 181 | "\n", 182 | "Now we can transfer the trained model to forecast `AirPassengers` with the `MLForecast.predict` method, we just have to pass the new dataframe to the `new_data` argument." 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "Y_df = pd.read_csv('https://datasets-nixtla.s3.amazonaws.com/air-passengers.csv', parse_dates=['ds'])\n", 192 | "\n", 193 | "# We define the train df. \n", 194 | "Y_train_df = Y_df[Y_df.ds<='1959-12-31'] # 132 train\n", 195 | "Y_test_df = Y_df[Y_df.ds>'1959-12-31'] # 12 test" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "data": { 205 | "text/html": [ 206 | "
\n", 207 | "\n", 220 | "\n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | "
unique_iddsLGBMRegressor
0AirPassengers1960-01-01422.740096
1AirPassengers1960-02-01399.480193
2AirPassengers1960-03-01458.220289
3AirPassengers1960-04-01442.960385
4AirPassengers1960-05-01461.700482
\n", 262 | "
" 263 | ], 264 | "text/plain": [ 265 | " unique_id ds LGBMRegressor\n", 266 | "0 AirPassengers 1960-01-01 422.740096\n", 267 | "1 AirPassengers 1960-02-01 399.480193\n", 268 | "2 AirPassengers 1960-03-01 458.220289\n", 269 | "3 AirPassengers 1960-04-01 442.960385\n", 270 | "4 AirPassengers 1960-05-01 461.700482" 271 | ] 272 | }, 273 | "execution_count": null, 274 | "metadata": {}, 275 | "output_type": "execute_result" 276 | } 277 | ], 278 | "source": [ 279 | "Y_hat_df = fcst.predict(h=12, new_df=Y_train_df)\n", 280 | "Y_hat_df.head()" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "Y_hat_df = Y_test_df.merge(Y_hat_df, how='left', on=['unique_id', 'ds'])" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "metadata": {}, 296 | "outputs": [], 297 | "source": [ 298 | "fig = plot_series(Y_train_df, Y_hat_df)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "#| hide\n", 308 | "fig.savefig('../../figs/transfer_learning__forecast.png', bbox_inches='tight')" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "![](../../figs/transfer_learning__forecast.png)" 316 | ] 317 | }, 318 | { 319 | "cell_type": "markdown", 320 | "metadata": {}, 321 | "source": [ 322 | "## Evaluate Results\n", 323 | "\n", 324 | "\n", 325 | "We evaluate the forecasts of the pre-trained model with the Mean Absolute Error (`mae`).\n", 326 | "\n", 327 | "$$\n", 328 | "\\qquad MAE = \\frac{1}{Horizon} \\sum_{\\tau} |y_{\\tau} - \\hat{y}_{\\tau}|\\qquad\n", 329 | "$$" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "y_true = Y_test_df.y.values\n", 339 | "y_hat = Y_hat_df['LGBMRegressor'].values" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "metadata": {}, 346 | "outputs": [ 347 | { 348 | "name": "stdout", 349 | "output_type": "stream", 350 | "text": [ 351 | "LGBMRegressor MAE: 13.560\n", 352 | "ETS MAE: 16.222\n", 353 | "AutoARIMA MAE: 18.551\n" 354 | ] 355 | } 356 | ], 357 | "source": [ 358 | "print(f'LGBMRegressor MAE: {mean_absolute_error(y_hat, y_true):.3f}')\n", 359 | "print('ETS MAE: 16.222')\n", 360 | "print('AutoARIMA MAE: 18.551')" 361 | ] 362 | } 363 | ], 364 | "metadata": { 365 | "kernelspec": { 366 | "display_name": "python3", 367 | "language": "python", 368 | "name": "python3" 369 | } 370 | }, 371 | "nbformat": 4, 372 | "nbformat_minor": 4 373 | } 374 | -------------------------------------------------------------------------------- /nbs/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/favicon.png -------------------------------------------------------------------------------- /nbs/figs/cross_validation__predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/cross_validation__predictions.png -------------------------------------------------------------------------------- /nbs/figs/cross_validation__series.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/cross_validation__series.png -------------------------------------------------------------------------------- /nbs/figs/electricity_peak_forecasting__eda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/electricity_peak_forecasting__eda.png -------------------------------------------------------------------------------- /nbs/figs/electricity_peak_forecasting__predicted_peak.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/electricity_peak_forecasting__predicted_peak.png -------------------------------------------------------------------------------- /nbs/figs/end_to_end_walkthrough__cv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/end_to_end_walkthrough__cv.png -------------------------------------------------------------------------------- /nbs/figs/end_to_end_walkthrough__differences.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/end_to_end_walkthrough__differences.png -------------------------------------------------------------------------------- /nbs/figs/end_to_end_walkthrough__eda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/end_to_end_walkthrough__eda.png -------------------------------------------------------------------------------- /nbs/figs/end_to_end_walkthrough__final_forecast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/end_to_end_walkthrough__final_forecast.png -------------------------------------------------------------------------------- /nbs/figs/end_to_end_walkthrough__lgbcv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/end_to_end_walkthrough__lgbcv.png -------------------------------------------------------------------------------- /nbs/figs/end_to_end_walkthrough__predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/end_to_end_walkthrough__predictions.png -------------------------------------------------------------------------------- /nbs/figs/forecast__cross_validation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/forecast__cross_validation.png -------------------------------------------------------------------------------- /nbs/figs/forecast__cross_validation_intervals.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/forecast__cross_validation_intervals.png -------------------------------------------------------------------------------- /nbs/figs/forecast__ercot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/forecast__ercot.png -------------------------------------------------------------------------------- /nbs/figs/forecast__predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/forecast__predict.png -------------------------------------------------------------------------------- /nbs/figs/forecast__predict_intervals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/forecast__predict_intervals.png -------------------------------------------------------------------------------- /nbs/figs/forecast__predict_intervals_window_size_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/forecast__predict_intervals_window_size_1.png -------------------------------------------------------------------------------- /nbs/figs/index.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/index.png -------------------------------------------------------------------------------- /nbs/figs/load_forecasting__differences.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/load_forecasting__differences.png -------------------------------------------------------------------------------- /nbs/figs/load_forecasting__prediction_intervals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/load_forecasting__prediction_intervals.png -------------------------------------------------------------------------------- /nbs/figs/load_forecasting__predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/load_forecasting__predictions.png -------------------------------------------------------------------------------- /nbs/figs/load_forecasting__raw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/load_forecasting__raw.png -------------------------------------------------------------------------------- /nbs/figs/load_forecasting__transformed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/load_forecasting__transformed.png 
-------------------------------------------------------------------------------- /nbs/figs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/logo.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals__eda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals__eda.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals__knn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals__knn.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals__lasso.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals__lasso.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals__lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals__lr.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals__mlp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals__mlp.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals__ridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals__ridge.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals_in_forecasting_models__autocorrelation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals_in_forecasting_models__autocorrelation.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals_in_forecasting_models__eda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals_in_forecasting_models__eda.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals_in_forecasting_models__plot_forecasting_intervals.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals_in_forecasting_models__plot_forecasting_intervals.png 
-------------------------------------------------------------------------------- /nbs/figs/prediction_intervals_in_forecasting_models__plot_residual_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals_in_forecasting_models__plot_residual_model.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals_in_forecasting_models__plot_values.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals_in_forecasting_models__plot_values.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals_in_forecasting_models__seasonal_decompose_aditive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals_in_forecasting_models__seasonal_decompose_aditive.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals_in_forecasting_models__seasonal_decompose_multiplicative.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals_in_forecasting_models__seasonal_decompose_multiplicative.png -------------------------------------------------------------------------------- /nbs/figs/prediction_intervals_in_forecasting_models__train_test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/prediction_intervals_in_forecasting_models__train_test.png -------------------------------------------------------------------------------- /nbs/figs/quick_start_local__eda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/quick_start_local__eda.png -------------------------------------------------------------------------------- /nbs/figs/quick_start_local__predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/quick_start_local__predictions.png -------------------------------------------------------------------------------- /nbs/figs/target_transforms__diff1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/target_transforms__diff1.png -------------------------------------------------------------------------------- /nbs/figs/target_transforms__diff2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/target_transforms__diff2.png -------------------------------------------------------------------------------- /nbs/figs/target_transforms__eda.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/target_transforms__eda.png -------------------------------------------------------------------------------- /nbs/figs/target_transforms__log.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/target_transforms__log.png -------------------------------------------------------------------------------- /nbs/figs/target_transforms__log_diffs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/target_transforms__log_diffs.png -------------------------------------------------------------------------------- /nbs/figs/target_transforms__minmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/target_transforms__minmax.png -------------------------------------------------------------------------------- /nbs/figs/target_transforms__standardized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/target_transforms__standardized.png -------------------------------------------------------------------------------- /nbs/figs/target_transforms__zeros.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/target_transforms__zeros.png -------------------------------------------------------------------------------- /nbs/figs/transfer_learning__eda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/transfer_learning__eda.png -------------------------------------------------------------------------------- /nbs/figs/transfer_learning__forecast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/mlforecast/f674b450b19c6cde3e8d3498544d60f8740917aa/nbs/figs/transfer_learning__forecast.png -------------------------------------------------------------------------------- /nbs/mint.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://mintlify.com/schema.json", 3 | "name": "Nixtla", 4 | "logo": { 5 | "light": "/light.png", 6 | "dark": "/dark.png" 7 | }, 8 | "favicon": "/favicon.svg", 9 | "colors": { 10 | "primary": "#0E0E0E", 11 | "light": "#FAFAFA", 12 | "dark": "#0E0E0E", 13 | "anchors": { 14 | "from": "#2AD0CA", 15 | "to": "#0E00F8" 16 | } 17 | }, 18 | "topbarCtaButton": { 19 | "type": "github", 20 | "url": "https://github.com/Nixtla/mlforecast" 21 | }, 22 | "topAnchor": { 23 | "name": "MLForecast", 24 | "icon": "robot" 25 | }, 26 | "navigation": [ 27 | { 28 | "group": "", 29 | "pages": ["index.html"] 30 | }, 31 | { 32 | "group": "Getting Started", 33 | "pages": [ 34 | "docs/getting-started/install.html", 35 | "docs/getting-started/quick_start_local.html", 36 | 
"docs/getting-started/quick_start_distributed.html", 37 | "docs/getting-started/end_to_end_walkthrough.html" 38 | ] 39 | }, 40 | { 41 | "group": "How-to guides", 42 | "pages": [ 43 | "docs/how-to-guides/exogenous_features.html", 44 | "docs/how-to-guides/lag_transforms_guide.html", 45 | "docs/how-to-guides/hyperparameter_optimization.html", 46 | "docs/how-to-guides/sklearn_pipelines.html", 47 | "docs/how-to-guides/sample_weights.html", 48 | "docs/how-to-guides/cross_validation.html", 49 | "docs/how-to-guides/prediction_intervals.html", 50 | "docs/how-to-guides/target_transforms_guide.html", 51 | "docs/how-to-guides/analyzing_models.html", 52 | "docs/how-to-guides/mlflow.html", 53 | "docs/how-to-guides/transforming_exog.html", 54 | "docs/how-to-guides/custom_training.html", 55 | "docs/how-to-guides/training_with_numpy.html", 56 | "docs/how-to-guides/one_model_per_horizon.html", 57 | "docs/how-to-guides/custom_date_features.html", 58 | "docs/how-to-guides/predict_callbacks.html", 59 | "docs/how-to-guides/predict_subset.html", 60 | "docs/how-to-guides/transfer_learning.html" 61 | ] 62 | }, 63 | { 64 | "group": "Tutorials", 65 | "pages": [ 66 | "docs/tutorials/electricity_load_forecasting.html", 67 | "docs/tutorials/electricity_peak_forecasting.html", 68 | "docs/tutorials/prediction_intervals_in_forecasting_models.html" 69 | ] 70 | }, 71 | { 72 | "group": "API Reference", 73 | "pages": [ 74 | { 75 | "group": "Local", 76 | "pages": [ 77 | "forecast.html", 78 | "auto.html", 79 | "lgb_cv.html", 80 | "optimization.html", 81 | "utils.html", 82 | "core.html", 83 | "target_transforms.html", 84 | "lag_transforms.html", 85 | "feature_engineering.html", 86 | "callbacks.html" 87 | ] 88 | }, 89 | { 90 | "group": "Distributed", 91 | "pages": [ 92 | "distributed.forecast.html", 93 | { 94 | "group": "Models", 95 | "pages": [ 96 | "distributed.models.dask.lgb.html", 97 | "distributed.models.dask.xgb.html", 98 | "distributed.models.ray.lgb.html", 99 | "distributed.models.ray.xgb.html", 100 | "distributed.models.spark.lgb.html", 101 | "distributed.models.spark.xgb.html" 102 | ] 103 | } 104 | ] 105 | } 106 | ] 107 | } 108 | ] 109 | } 110 | -------------------------------------------------------------------------------- /nbs/nbdev.yml: -------------------------------------------------------------------------------- 1 | project: 2 | output-dir: _docs 3 | 4 | website: 5 | title: "mlforecast" 6 | site-url: "https://Nixtla.github.io/mlforecast/" 7 | description: "Scalable machine learning based time series forecasting" 8 | repo-branch: main 9 | repo-url: "https://github.com/Nixtla/mlforecast" 10 | -------------------------------------------------------------------------------- /nbs/sidebar.yml: -------------------------------------------------------------------------------- 1 | website: 2 | reader-mode: false 3 | sidebar: 4 | collapse-level: 1 5 | contents: 6 | - index.ipynb 7 | - section: Getting started 8 | contents: 9 | - docs/getting-started/install.ipynb 10 | - docs/getting-started/quick_start_local.ipynb 11 | - docs/getting-started/quick_start_distributed.ipynb 12 | - docs/getting-started/end_to_end_walkthrough.ipynb 13 | - section: How-to guides 14 | contents: 15 | - docs/how-to-guides/exogenous_features.ipynb 16 | - docs/how-to-guides/lag_transforms_guide.ipynb 17 | - docs/how-to-guides/transforming_exog.ipynb 18 | - docs/how-to-guides/cross_validation.ipynb 19 | - docs/how-to-guides/prediction_intervals.ipynb 20 | - docs/how-to-guides/target_transforms_guide.ipynb 21 | - 
docs/how-to-guides/analyzing_models.ipynb 22 | - docs/how-to-guides/custom_training.ipynb 23 | - docs/how-to-guides/training_with_numpy.ipynb 24 | - docs/how-to-guides/one_model_per_horizon.ipynb 25 | - docs/how-to-guides/custom_date_features.ipynb 26 | - docs/how-to-guides/predict_callbacks.ipynb 27 | - docs/how-to-guides/predict_subset.ipynb 28 | - docs/how-to-guides/transfer_learning.ipynb 29 | - section: Tutorials 30 | contents: 31 | - docs/tutorials/electricity_load_forecasting.ipynb 32 | - docs/tutorials/electricity_peak_forecasting.ipynb 33 | - docs/tutorials/prediction_intervals_in_forecasting_models.ipynb 34 | - section: API reference 35 | contents: 36 | - section: Local 37 | contents: 38 | - forecast.ipynb 39 | - lgb_cv.ipynb 40 | - utils.ipynb 41 | - core.ipynb 42 | - target_transforms.ipynb 43 | - lag_transforms.ipynb 44 | - feature_engineering.ipynb 45 | - callbacks.ipynb 46 | - section: Distributed 47 | contents: 48 | - distributed.forecast.ipynb 49 | - section: Models 50 | contents: 51 | - distributed.models.dask.lgb.ipynb 52 | - distributed.models.dask.xgb.ipynb 53 | - distributed.models.ray.lgb.ipynb 54 | - distributed.models.ray.xgb.ipynb 55 | - distributed.models.spark.lgb.ipynb 56 | - distributed.models.spark.xgb.ipynb 57 | -------------------------------------------------------------------------------- /nbs/styles.css: -------------------------------------------------------------------------------- 1 | .cell { 2 | margin-bottom: 1rem; 3 | } 4 | 5 | .cell > .sourceCode { 6 | margin-bottom: 0; 7 | } 8 | 9 | .cell-output > pre { 10 | margin-bottom: 0; 11 | } 12 | 13 | .cell-output > pre, .cell-output > .sourceCode > pre, .cell-output-stdout > pre { 14 | margin-left: 0.8rem; 15 | margin-top: 0; 16 | background: none; 17 | border-left: 2px solid lightsalmon; 18 | border-top-left-radius: 0; 19 | border-top-right-radius: 0; 20 | } 21 | 22 | .cell-output > .sourceCode { 23 | border: none; 24 | } 25 | 26 | .cell-output > .sourceCode { 27 | background: none; 28 | margin-top: 0; 29 | } 30 | 31 | div.description { 32 | padding-left: 2px; 33 | padding-top: 5px; 34 | font-style: italic; 35 | font-size: 135%; 36 | opacity: 70%; 37 | } 38 | 39 | /* show_doc signature */ 40 | blockquote > pre { 41 | font-size: 14px; 42 | } 43 | 44 | .table { 45 | font-size: 16px; 46 | /* disable striped tables */ 47 | --bs-table-striped-bg: var(--bs-table-bg); 48 | } 49 | 50 | .quarto-figure-center > figure > figcaption { 51 | text-align: center; 52 | } 53 | 54 | .figure-caption { 55 | font-size: 75%; 56 | font-style: italic; 57 | } 58 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff.lint] 2 | select = ["F", "ARG"] 3 | 4 | exclude = ["_nbdev.py"] 5 | 6 | [tool.mypy] 7 | ignore_missing_imports = true 8 | [[tool.mypy.overrides]] 9 | module = 'mlforecast.compat' 10 | ignore_errors = true 11 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | host = github 3 | lib_name = mlforecast 4 | user = Nixtla 5 | description = Scalable machine learning based time series forecasting 6 | keywords = python forecast forecasting machine-learning dask 7 | author = José Morales 8 | author_email = jmoralz92@gmail.com 9 | copyright = Nixtla 10 | branch = main 11 | version = 1.0.2 12 | min_python = 3.9 13 | audience = Developers 14 | 
language = English 15 | custom_sidebar = True 16 | license = apache2 17 | status = 4 18 | requirements = cloudpickle coreforecast>=0.0.15 fsspec optuna pandas scikit-learn utilsforecast>=0.2.9 19 | dask_requirements = fugue dask[complete]<=2024.12.1 lightgbm xgboost 20 | ray_requirements = fugue[ray] lightgbm_ray xgboost_ray 21 | spark_requirements = fugue pyspark>=3.3 lightgbm xgboost 22 | aws_requirements = fsspec[s3] 23 | gcp_requirements = fsspec[gcs] 24 | azure_requirements = fsspec[adl] 25 | polars_requirements = polars[numpy] 26 | dev_requirements = black>=24 datasetsforecast>=1 gitpython holidays<0.21 lightgbm<4.6 matplotlib mlflow>=2.10.0 mypy nbdev<2.3.26 numpy>=2 pandas>=2.2.2 pre-commit polars[numpy] pyarrow ruff setuptools statsmodels xgboost 27 | nbs_path = nbs 28 | doc_path = _docs 29 | recursive = True 30 | doc_host = https://Nixtla.github.io 31 | doc_baseurl = /mlforecast/ 32 | git_url = https://github.com/Nixtla/mlforecast 33 | lib_path = mlforecast 34 | title = mlforecast 35 | tst_flags = polars ray shap window_ops 36 | black_formatting = True 37 | readme_nb = index.ipynb 38 | allowed_metadata_keys = 39 | allowed_cell_metadata_keys = 40 | jupyter_hooks = True 41 | clean_ids = True 42 | clear_all = False 43 | put_version_in_init = True 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | from configparser import ConfigParser 3 | import setuptools 4 | assert parse_version(setuptools.__version__)>=parse_version('36.2') 5 | 6 | # note: all settings are in settings.ini; edit there, not here 7 | config = ConfigParser(delimiters=['=']) 8 | config.read('settings.ini') 9 | cfg = config['DEFAULT'] 10 | 11 | cfg_keys = 'version description keywords author author_email'.split() 12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split() 13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o) 14 | setup_cfg = {o:cfg[o] for o in cfg_keys} 15 | 16 | licenses = { 17 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), 18 | 'mit': ('MIT License', 'OSI Approved :: MIT License'), 19 | 'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'), 20 | 'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'), 21 | 'bsd3': ('BSD License', 'OSI Approved :: BSD License'), 22 | } 23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', 24 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] 25 | py_versions = '3.9 3.10 3.11 3.12 3.13'.split() 26 | 27 | requirements = cfg['requirements'].split() 28 | dask_requirements = cfg['dask_requirements'].split() 29 | ray_requirements = cfg['ray_requirements'].split() 30 | spark_requirements = cfg['spark_requirements'].split() 31 | aws_requirements = cfg['aws_requirements'].split() 32 | azure_requirements = cfg['azure_requirements'].split() 33 | gcp_requirements = cfg['gcp_requirements'].split() 34 | polars_requirements = cfg['polars_requirements'].split() 35 | dev_requirements = cfg['dev_requirements'].split() 36 | all_requirements = { 37 | *dask_requirements, 38 | *ray_requirements, 39 | *spark_requirements, 40 | *aws_requirements, 41 | *azure_requirements, 42 | *gcp_requirements, 43 | *polars_requirements, 44 | *dev_requirements, 45 | } 46 | min_python = cfg['min_python'] 47 | lic 
= licenses.get(cfg['license'].lower(), (cfg['license'], None)) 48 | 49 | setuptools.setup( 50 | name = 'mlforecast', 51 | license = lic[0], 52 | classifiers = [ 53 | 'Development Status :: ' + statuses[int(cfg['status'])], 54 | 'Intended Audience :: ' + cfg['audience'].title(), 55 | 'Natural Language :: ' + cfg['language'].title(), 56 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1] ] if lic[1] else []), 57 | url = cfg['git_url'], 58 | packages = setuptools.find_packages(), 59 | include_package_data = True, 60 | install_requires = requirements, 61 | extras_require = { 62 | 'dask': dask_requirements, 63 | 'ray': ray_requirements, 64 | 'spark': spark_requirements, 65 | 'aws': aws_requirements, 66 | 'azure': azure_requirements, 67 | 'gcp': gcp_requirements, 68 | 'polars': polars_requirements, 69 | 'dev': dev_requirements, 70 | 'all': all_requirements, 71 | }, 72 | dependency_links = cfg.get('dep_links','').split(), 73 | python_requires = '>=' + cfg['min_python'], 74 | long_description = open('README.md', encoding='utf-8').read(), 75 | long_description_content_type = 'text/markdown', 76 | zip_safe = False, 77 | entry_points = { 78 | 'console_scripts': cfg.get('console_scripts','').split(), 79 | 'nbdev': [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d'] 80 | }, 81 | **setup_cfg) 82 | -------------------------------------------------------------------------------- /tests/test_m4.py: -------------------------------------------------------------------------------- 1 | import lightgbm as lgb 2 | import pandas as pd 3 | import pytest 4 | from datasetsforecast.m4 import M4, M4Info, M4Evaluation 5 | from sklearn.linear_model import ElasticNet 6 | 7 | from mlforecast import MLForecast 8 | from mlforecast.lag_transforms import ( 9 | ExpandingMean, 10 | ExponentiallyWeightedMean, 11 | RollingMean, 12 | ) 13 | from mlforecast.target_transforms import Differences 14 | 15 | configs = { 16 | "Hourly": { 17 | "lgb_params": { 18 | "n_estimators": 200, 19 | "bagging_freq": 1, 20 | "learning_rate": 0.05, 21 | "num_leaves": 2500, 22 | "lambda_l1": 0.03, 23 | "lambda_l2": 0.5, 24 | "bagging_fraction": 0.9, 25 | "feature_fraction": 0.8, 26 | }, 27 | "mlf_params": { 28 | "target_transforms": [Differences([24])], 29 | "lags": [24 * i for i in range(1, 15)], 30 | "lag_transforms": { 31 | 24: [ 32 | ExponentiallyWeightedMean(alpha=0.3), 33 | RollingMean(7 * 24), 34 | RollingMean(7 * 48), 35 | ], 36 | 48: [ 37 | ExponentiallyWeightedMean(alpha=0.3), 38 | RollingMean(7 * 24), 39 | RollingMean(7 * 48), 40 | ], 41 | }, 42 | }, 43 | "metrics": { 44 | "lgb": { 45 | "SMAPE": 10.206856, 46 | "MASE": 0.861700, 47 | "OWA": 0.457511, 48 | }, 49 | "enet": { 50 | "SMAPE": 26.721835, 51 | "MASE": 22.954763, 52 | "OWA": 5.518959, 53 | }, 54 | }, 55 | }, 56 | "Daily": { 57 | "lgb_params": { 58 | "n_estimators": 30, 59 | "num_leaves": 128, 60 | }, 61 | "mlf_params": { 62 | "target_transforms": [Differences([1])], 63 | "lags": [i + 1 for i in range(14)], 64 | "lag_transforms": { 65 | 7: [RollingMean(7)], 66 | 14: [RollingMean(7)], 67 | }, 68 | }, 69 | "metrics": { 70 | "lgb": { 71 | "SMAPE": 2.984652, 72 | "MASE": 3.205519, 73 | "OWA": 0.978931, 74 | }, 75 | "enet": { 76 | "SMAPE": 2.989489, 77 | "MASE": 3.221004, 78 | "OWA": 0.982087, 79 | }, 80 | }, 81 | }, 82 | "Weekly": { 83 | "lgb_params": { 84 | "n_estimators": 100, 85 | "objective": "l1", 86 | "num_leaves": 256, 87 | }, 88 | "mlf_params": { 89 | "target_transforms": [Differences([1])], 90 | 
"lags": [i + 1 for i in range(32)], 91 | "lag_transforms": { 92 | 4: [ExpandingMean(), RollingMean(4)], 93 | 8: [ExpandingMean(), RollingMean(4)], 94 | }, 95 | }, 96 | "metrics": { 97 | "lgb": { 98 | "SMAPE": 8.238175, 99 | "MASE": 2.222099, 100 | "OWA": 0.849666, 101 | }, 102 | "enet": { 103 | "SMAPE": 9.794393, 104 | "MASE": 3.270274, 105 | "OWA": 1.123305, 106 | }, 107 | }, 108 | }, 109 | "Yearly": { 110 | "lgb_params": { 111 | "n_estimators": 100, 112 | "objective": "l1", 113 | "num_leaves": 256, 114 | }, 115 | "mlf_params": { 116 | "target_transforms": [Differences([1])], 117 | "lags": [i + 1 for i in range(6)], 118 | "lag_transforms": { 119 | 1: [ExpandingMean()], 120 | 6: [ExpandingMean()], 121 | }, 122 | }, 123 | "metrics": { 124 | "lgb": { 125 | "SMAPE": 13.281131, 126 | "MASE": 3.018999, 127 | "OWA": 0.786155, 128 | }, 129 | "enet": { 130 | "SMAPE": 15.363430, 131 | "MASE": 3.953421, 132 | "OWA": 0.967420, 133 | }, 134 | }, 135 | }, 136 | } 137 | 138 | 139 | def train_valid_split(group): 140 | df, *_ = M4.load(directory="data", group=group) 141 | df["ds"] = df["ds"].astype("int") 142 | horizon = M4Info[group].horizon 143 | valid = df.groupby("unique_id").tail(horizon) 144 | train = df.drop(valid.index) 145 | return train, valid, horizon 146 | 147 | 148 | @pytest.mark.parametrize("group", configs.keys()) 149 | def test_performance(group): 150 | cfg = configs[group] 151 | train, _, horizon = train_valid_split(group) 152 | fcst = MLForecast( 153 | models={ 154 | "lgb": lgb.LGBMRegressor( 155 | random_state=0, n_jobs=1, verbosity=-1, **cfg["lgb_params"] 156 | ), 157 | "enet": ElasticNet(), 158 | }, 159 | freq=1, 160 | **cfg["mlf_params"], 161 | num_threads=2, 162 | ) 163 | fcst.fit(train) 164 | preds = fcst.predict(horizon) 165 | for model, expected in cfg["metrics"].items(): 166 | model_preds = preds[model].values.reshape(-1, horizon) 167 | model_eval = M4Evaluation.evaluate("data", group, model_preds).loc[group] 168 | pd.testing.assert_series_equal(model_eval, pd.Series(expected, name=group)) 169 | -------------------------------------------------------------------------------- /tests/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.base import BaseEstimator 4 | from sklearn.linear_model import LinearRegression 5 | from utilsforecast.losses import smape 6 | 7 | from mlforecast import MLForecast 8 | from mlforecast.lag_transforms import RollingMean, RollingMax, RollingMin 9 | from mlforecast.target_transforms import Differences, LocalStandardScaler 10 | from mlforecast.utils import generate_daily_series 11 | 12 | 13 | class SeasonalNaive(BaseEstimator): 14 | def fit(self, X, y=None): 15 | return self 16 | 17 | def predict(self, X, y=None): 18 | return X["lag7"] 19 | 20 | 21 | @pytest.fixture(scope="module") 22 | def series(): 23 | n_series = 1_000 24 | n_static = 10 25 | return generate_daily_series( 26 | n_series=n_series, 27 | min_length=500, 28 | max_length=2_000, 29 | n_static_features=n_static, 30 | static_as_categorical=False, 31 | equal_ends=True, 32 | ) 33 | 34 | 35 | @pytest.fixture(scope="module") 36 | def series_with_exog(series): 37 | series = series.copy() 38 | n_exog = 10 39 | exog_names = [f"exog_{i}" for i in range(n_exog)] 40 | series[exog_names] = np.random.random((series.shape[0], n_exog)) 41 | return series 42 | 43 | 44 | @pytest.fixture 45 | def fcst(): 46 | return MLForecast( 47 | models={ 48 | "lr": LinearRegression(), 49 | "seas_naive": SeasonalNaive(), 
50 | }, 51 | freq="D", 52 | lags=[1, 7, 14, 28], 53 | lag_transforms={ 54 | 1 : [RollingMean(7)], 55 | 7 : [RollingMean(7), RollingMin(7), RollingMax(7)], 56 | 14: [RollingMean(7), RollingMin(7), RollingMax(7)], 57 | 28: [RollingMean(7), RollingMin(7), RollingMax(7)], 58 | }, 59 | date_features=["dayofweek", "month", "year", "day"], 60 | target_transforms=[Differences([1, 7]), LocalStandardScaler()], 61 | ) 62 | 63 | 64 | @pytest.fixture 65 | def statics(series): 66 | return series.columns.drop(["unique_id", "ds", "y"]).tolist() 67 | 68 | 69 | @pytest.fixture 70 | def exogs(series_with_exog, statics): 71 | return series_with_exog.columns.drop(["unique_id", "ds", "y"] + statics).tolist() 72 | 73 | 74 | @pytest.mark.parametrize("use_exog", [True, False]) 75 | @pytest.mark.parametrize("num_threads", [1, 2]) 76 | def test_preprocess(benchmark, fcst: MLForecast, series, use_exog, series_with_exog, statics, num_threads): 77 | if use_exog: 78 | series = series_with_exog 79 | fcst.ts.num_threads = num_threads 80 | benchmark(fcst.preprocess, series, static_features=statics) 81 | 82 | 83 | @pytest.mark.parametrize("use_exog", [True, False]) 84 | @pytest.mark.parametrize("num_threads", [1, 2]) 85 | @pytest.mark.parametrize("keep_last_n", [None, 50]) 86 | def test_predict(benchmark, fcst: MLForecast, series, use_exog, series_with_exog, exogs, statics, keep_last_n, num_threads): 87 | horizon = 14 88 | if use_exog: 89 | series = series_with_exog 90 | valid = series.groupby("unique_id").tail(horizon) 91 | train = series.drop(valid.index) 92 | pred_kwargs = {} 93 | if use_exog: 94 | pred_kwargs["X_df"] = valid[["unique_id", "ds"] + exogs] 95 | fcst.ts.num_threads = num_threads 96 | fcst.fit(train, static_features=statics, keep_last_n=keep_last_n) 97 | preds = benchmark(fcst.predict, horizon, **pred_kwargs) 98 | full_preds = preds.merge(valid[["unique_id", "ds", "y"]], on=["unique_id", "ds"]) 99 | models = fcst.models.keys() 100 | evaluation = smape(full_preds, models=models) 101 | summary = evaluation[models].mean(axis=0) 102 | assert summary["lr"] < summary["seas_naive"] 103 | --------------------------------------------------------------------------------
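As a worked illustration of the MAE formula from the transfer-learning notebook above, the short sketch below computes the same quantity that `sklearn.metrics.mean_absolute_error` returns. The `mae` helper and the sample numbers are illustrative additions, not part of the repository.

```python
import numpy as np

def mae(y_true: np.ndarray, y_hat: np.ndarray) -> float:
    # MAE = (1 / H) * sum_{tau=1..H} |y_tau - y_hat_tau|, averaged over the horizon H
    return float(np.mean(np.abs(y_true - y_hat)))

# Three illustrative steps of a forecast horizon (made-up actuals; the predictions
# are taken from the Y_hat_df.head() output shown in the notebook)
y_true = np.array([417.0, 391.0, 419.0])
y_hat = np.array([422.740096, 399.480193, 458.220289])
print(f"MAE: {mae(y_true, y_hat):.3f}")  # MAE: 17.814
```

The repository's tests apply the same pattern at scale: `tests/test_m4.py` checks SMAPE/MASE/OWA against pinned reference values for each M4 group, and `tests/test_pipeline.py` benchmarks `preprocess` and `predict` (its `benchmark` fixture presumably comes from pytest-benchmark) while asserting that the linear model beats the seasonal-naive baseline on SMAPE.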