├── .gitmodules ├── utilsforecast ├── py.typed ├── __init__.py ├── compat.py ├── grouped_array.py ├── validation.py ├── data.py ├── preprocessing.py ├── evaluation.py ├── feature_engineering.py └── plotting.py ├── nbs ├── .gitignore └── _quarto.yml ├── .gitattributes ├── action_files ├── clean_nbs ├── lint ├── remove_logs_cells └── nbdev_test ├── docs ├── mintlify │ ├── dark.png │ ├── light.png │ ├── imgs │ │ ├── index.png │ │ ├── plotting.png │ │ └── losses │ │ │ ├── q_loss.png │ │ │ ├── mae_loss.png │ │ │ ├── mq_loss.png │ │ │ ├── mse_loss.png │ │ │ ├── mape_loss.png │ │ │ ├── mase_loss.png │ │ │ ├── rmae_loss.png │ │ │ └── rmse_loss.png │ ├── docs.json │ └── favicon.svg ├── evaluation.html.md ├── data.html.md ├── preprocessing.html.md ├── plotting.html.md ├── feature_engineering.html.md ├── to_mdx.py ├── convert_to_mkdocstrings.py └── losses.html.md ├── scripts ├── cvt.py ├── filter_licenses.py ├── extract_test.sh ├── cli.py └── cli.ipynb ├── MANIFEST.in ├── .github ├── dependabot.yml ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── documentation-issue.yml │ ├── feature-request.yml │ └── bug-report.yml ├── workflows │ ├── release-drafter.yml │ ├── lint.yaml │ ├── no-response.yaml │ ├── release.yml │ ├── pytest.yml │ └── build-docs.yaml └── release-drafter.yml ├── tests ├── conftest.py ├── test_data.py ├── test_validation.py ├── test_plotting.py ├── test_grouped_array.py ├── test_feature_engineering.py └── test_preprocessing.py ├── .pre-commit-config.yaml ├── THIRD_PARTY_LICENSES.md ├── Makefile ├── .gitignore ├── pyproject.toml ├── README.md ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md └── LICENSE /.gitmodules: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utilsforecast/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nbs/.gitignore: -------------------------------------------------------------------------------- 1 | /.quarto/ 2 | -------------------------------------------------------------------------------- /utilsforecast/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.15" 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | nbs/** linguist-documentation 2 | *.ipynb merge=nbdev-merge 3 | -------------------------------------------------------------------------------- /action_files/clean_nbs: -------------------------------------------------------------------------------- 1 | nbdev_clean 2 | ./action_files/remove_logs_cells 3 | -------------------------------------------------------------------------------- /docs/mintlify/dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/dark.png -------------------------------------------------------------------------------- /docs/mintlify/light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/light.png -------------------------------------------------------------------------------- /scripts/cvt.py: -------------------------------------------------------------------------------- 1 | from 
nbdev.export import nb_export 2 | nb_export('cli.ipynb', lib_path='.', name='cli') -------------------------------------------------------------------------------- /docs/mintlify/imgs/index.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/imgs/index.png -------------------------------------------------------------------------------- /action_files/lint: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | ruff check utilsforecast || exit -1 3 | mypy utilsforecast || exit -1 4 | -------------------------------------------------------------------------------- /docs/mintlify/imgs/plotting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/imgs/plotting.png -------------------------------------------------------------------------------- /docs/mintlify/imgs/losses/q_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/imgs/losses/q_loss.png -------------------------------------------------------------------------------- /docs/mintlify/imgs/losses/mae_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/imgs/losses/mae_loss.png -------------------------------------------------------------------------------- /docs/mintlify/imgs/losses/mq_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/imgs/losses/mq_loss.png -------------------------------------------------------------------------------- /docs/mintlify/imgs/losses/mse_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/imgs/losses/mse_loss.png -------------------------------------------------------------------------------- /docs/mintlify/imgs/losses/mape_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/imgs/losses/mape_loss.png -------------------------------------------------------------------------------- /docs/mintlify/imgs/losses/mase_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/imgs/losses/mase_loss.png -------------------------------------------------------------------------------- /docs/mintlify/imgs/losses/rmae_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/imgs/losses/rmae_loss.png -------------------------------------------------------------------------------- /docs/mintlify/imgs/losses/rmse_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nixtla/utilsforecast/HEAD/docs/mintlify/imgs/losses/rmse_loss.png -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include 
settings.ini 2 | include LICENSE 3 | include CONTRIBUTING.md 4 | include README.md 5 | recursive-exclude * __pycache__ 6 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: github-actions 4 | directory: / 5 | schedule: 6 | interval: weekly 7 | groups: 8 | ci-dependencies: 9 | patterns: ["*"] 10 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def assert_raises_with_message(func, expected_msg, *args, **kwargs): 5 | with pytest.raises((AssertionError, ValueError, Exception)) as exc_info: 6 | func(*args, **kwargs) 7 | assert expected_msg in str(exc_info.value) 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Ask a question or get support 4 | url: https://join.slack.com/t/nixtlacommunity/shared_invite/zt-1h77esh5y-iL1m8N0F7qV1HmH~0KYeAQ 5 | about: Ask a question or request support for using a library of the nixtlaverse 6 | -------------------------------------------------------------------------------- /docs/evaluation.html.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Evaluation 3 | description: Model performance evaluation 4 | --- 5 | 6 | ::: utilsforecast.evaluation.evaluate 7 | handler: python 8 | options: 9 | docstring_style: google 10 | heading_level: 3 11 | show_root_heading: true 12 | show_source: true 13 | -------------------------------------------------------------------------------- /docs/data.html.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Data 3 | description: Utilies for generating time series datasets 4 | --- 5 | 6 | 7 | ::: utilsforecast.data.generate_series 8 | handler: python 9 | options: 10 | docstring_style: google 11 | heading_level: 3 12 | show_root_heading: true 13 | show_source: true 14 | 15 | 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: true 2 | 3 | repos: 4 | - repo: https://github.com/astral-sh/ruff-pre-commit 5 | rev: v0.2.1 6 | hooks: 7 | - id: ruff 8 | files: utilsforecast 9 | - repo: https://github.com/pre-commit/mirrors-mypy 10 | rev: v1.8.0 11 | hooks: 12 | - id: mypy 13 | args: [--ignore-missing-imports] 14 | files: utilsforecast 15 | -------------------------------------------------------------------------------- /scripts/filter_licenses.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | df = pd.read_csv('third_party_licenses.csv') 4 | df = df[df['License'].str.contains('GPL|AGPL|LGPL|MPL', na=False)] 5 | 6 | # if the license has a long agreement, only capture the title and skip the rest 7 | df['License'] = df['License'].apply(lambda x: x.split('\n')[0]) 8 | 9 | df = df[~df['Name'].str.contains('quadprog')] # ignore quadprog 10 | df.to_markdown('THIRD_PARTY_LICENSES.md', index=False) 11 | 
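A quick, self-contained illustration of what the filter above keeps (the package rows below are invented for illustration and only `pandas` is required; `print` stands in for the `to_markdown` call, which additionally needs `tabulate`):

```python
# Hypothetical rows shaped like the pip-licenses CSV that filter_licenses.py consumes
# (see the `licenses` target in the Makefile, which writes third_party_licenses.csv).
import pandas as pd

rows = [
    {"Name": "certifi", "Version": "2025.11.12", "License": "Mozilla Public License 2.0 (MPL 2.0)"},
    {"Name": "numpy", "Version": "2.1.0", "License": "BSD License"},
    {"Name": "quadprog", "Version": "0.1.13", "License": "GPL-2.0\nLong license agreement text..."},
]
df = pd.DataFrame(rows)
# keep only copyleft-style licenses (GPL/AGPL/LGPL/MPL)
df = df[df["License"].str.contains("GPL|AGPL|LGPL|MPL", na=False)]
# collapse multi-line license text to its first line
df["License"] = df["License"].apply(lambda x: x.split("\n")[0])
# quadprog is explicitly excluded
df = df[~df["Name"].str.contains("quadprog")]
print(df)  # only the certifi (MPL 2.0) row survives
```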
-------------------------------------------------------------------------------- /.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name: Release Drafter 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | update_release_draft: 13 | permissions: 14 | contents: write 15 | pull-requests: read 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: release-drafter/release-drafter@b1476f6e6eb133afa41ed8589daba6dc69b4d3f5 # v6.1.0 19 | env: 20 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 21 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: 'v$NEXT_PATCH_VERSION' 2 | tag-template: 'v$NEXT_PATCH_VERSION' 3 | categories: 4 | - title: 'New Features' 5 | label: 'feature' 6 | - title: 'Breaking Change' 7 | label: 'breaking change' 8 | - title: 'Bug Fixes' 9 | label: 'fix' 10 | - title: 'Documentation' 11 | label: 'documentation' 12 | - title: 'Dependencies' 13 | label: 'dependencies' 14 | - title: 'Enhancement' 15 | label: 'enhancement' 16 | change-template: '- $TITLE @$AUTHOR (#$NUMBER)' 17 | template: | 18 | $CHANGES 19 | -------------------------------------------------------------------------------- /docs/preprocessing.html.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Preprocessing 3 | description: Utilities for processing data before training/analysis 4 | --- 5 | 6 | ::: utilsforecast.preprocessing.id_time_grid 7 | handler: python 8 | options: 9 | docstring_style: google 10 | heading_level: 3 11 | show_root_heading: true 12 | show_source: true 13 | 14 | ::: utilsforecast.preprocessing.fill_gaps 15 | handler: python 16 | options: 17 | docstring_style: google 18 | heading_level: 3 19 | show_root_heading: true 20 | show_source: true 21 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Clone repo 14 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 15 | 16 | - name: Set up python 17 | uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Install dependencies 22 | run: pip install black nbdev pre-commit 23 | 24 | - name: Run pre-commit 25 | run: pre-commit run --files utilsforecast/* -------------------------------------------------------------------------------- /THIRD_PARTY_LICENSES.md: -------------------------------------------------------------------------------- 1 | | Name | Version | License | Author | URL | 2 | |:---------|:-----------|:--------------------------------------------------|:-------------------------------------|:------------------------------------------| 3 | | certifi | 2025.11.12 | Mozilla Public License 2.0 (MPL 2.0) | Kenneth Reitz | https://github.com/certifi/python-certifi | 4 | | pathspec | 0.12.1 | Mozilla Public License 2.0 (MPL 2.0) | "Caleb P. 
Burns" | UNKNOWN | 5 | | tqdm | 4.67.1 | MIT License; Mozilla Public License 2.0 (MPL 2.0) | UNKNOWN | https://tqdm.github.io | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | from utilsforecast.data import generate_series 2 | 3 | 4 | def test_data(): 5 | synthetic_panel = generate_series(n_series=2) 6 | synthetic_panel.groupby("unique_id", observed=True).head(4) 7 | level = [40, 80, 95] 8 | series = generate_series(100, n_models=2, level=level) 9 | for model in ["model0", "model1"]: 10 | for lv in level: 11 | assert ( 12 | series[model] 13 | .between(series[f"{model}-lo-{lv}"], series[f"{model}-hi-{lv}"]) 14 | .all() 15 | ) 16 | for lv_lo, lv_hi in zip(level[:-1], level[1:]): 17 | assert series[f"{model}-lo-{lv_lo}"].ge(series[f"{model}-lo-{lv_hi}"]).all() 18 | assert series[f"{model}-hi-{lv_lo}"].le(series[f"{model}-hi-{lv_hi}"]).all() 19 | -------------------------------------------------------------------------------- /.github/workflows/no-response.yaml: -------------------------------------------------------------------------------- 1 | name: No Response Bot 2 | 3 | on: 4 | issue_comment: 5 | types: [created] 6 | schedule: 7 | - cron: '0 4 * * *' 8 | 9 | jobs: 10 | noResponse: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: lee-dohm/no-response@9bb0a4b5e6a45046f00353d5de7d90fb8bd773bb # v0.5.0 14 | with: 15 | closeComment: > 16 | This issue has been automatically closed because it has been awaiting a response for too long. 17 | When you have time to to work with the maintainers to resolve this issue, please post a new comment and it will be re-opened. 18 | If the issue has been locked for editing by the time you return to it, please open a new issue and reference this one. 19 | daysUntilClose: 30 20 | responseRequiredLabel: awaiting response 21 | token: ${{ github.token }} 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation-issue.yml: -------------------------------------------------------------------------------- 1 | name: Documentation 2 | title: "[] " 3 | description: Report an issue with the library documentation 4 | labels: [documentation] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: Thank you for helping us improve the library documentation! 9 | 10 | - type: textarea 11 | attributes: 12 | label: Description 13 | description: | 14 | Tell us about the change you'd like to see. For example, "I'd like to 15 | see more examples of how to use `cross_validation`." 16 | validations: 17 | required: true 18 | 19 | - type: textarea 20 | attributes: 21 | label: Link 22 | description: | 23 | If the problem is related to an existing section, please add a link to 24 | the section. 
25 | validations: 26 | required: false 27 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | defaults: 9 | run: 10 | shell: bash -l {0} 11 | 12 | jobs: 13 | release: 14 | if: github.repository == 'Nixtla/utilsforecast' 15 | runs-on: ubuntu-latest 16 | permissions: 17 | id-token: write 18 | steps: 19 | - name: Clone repo 20 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 21 | 22 | - name: Set up python 23 | uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 24 | with: 25 | python-version: '3.10' 26 | 27 | - name: Install build dependencies 28 | run: python -m pip install build wheel 29 | 30 | - name: Build distributions 31 | run: python -m build -sw 32 | 33 | - name: Publish package to PyPI 34 | uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: Library feature request 2 | description: Suggest an idea for a project 3 | title: "[] " 4 | labels: [enhancement, feature] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for finding the time to propose a new feature! 10 | We really appreciate the community efforts to improve the nixtlaverse. 11 | 12 | - type: textarea 13 | attributes: 14 | label: Description 15 | description: A short description of your feature 16 | 17 | - type: textarea 18 | attributes: 19 | label: Use case 20 | description: > 21 | Describe the use case of your feature request. It will help us understand and 22 | prioritize the feature request. 23 | placeholder: > 24 | Rather than telling us how you might implement this feature, try to take a 25 | step back and describe what you are trying to achieve. 
26 | -------------------------------------------------------------------------------- /docs/mintlify/docs.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://mintlify.com/docs.json", 3 | "theme": "mint", 4 | "name": "Nixtla", 5 | "colors": { 6 | "primary": "#0E0E0E", 7 | "light": "#FAFAFA", 8 | "dark": "#0E0E0E" 9 | }, 10 | "favicon": "/favicon.svg", 11 | "navigation": { 12 | "groups": [ 13 | { 14 | "group": " ", 15 | "pages": [ 16 | "index.html" 17 | ] 18 | }, 19 | { 20 | "group": "API Reference", 21 | "pages": [ 22 | "preprocessing.html", 23 | "feature_engineering.html", 24 | "evaluation.html", 25 | "losses.html", 26 | "plotting.html", 27 | "data.html" 28 | ] 29 | } 30 | ] 31 | }, 32 | "logo": { 33 | "light": "/light.png", 34 | "dark": "/dark.png" 35 | }, 36 | "navbar": { 37 | "primary": { 38 | "type": "github", 39 | "href": "https://github.com/Nixtla/utilsforecast" 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /action_files/remove_logs_cells: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import re 3 | from pathlib import Path 4 | from nbdev.clean import process_write 5 | 6 | IP_REGEX = re.compile(r'[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}') 7 | HOURS_REGEX = re.compile(r'\d{2}:\d{2}:\d{2}') 8 | 9 | def cell_contains_ips(cell): 10 | if 'outputs' not in cell: 11 | return False 12 | for output in cell['outputs']: 13 | if 'text' not in output: 14 | return False 15 | for line in output['text']: 16 | if IP_REGEX.search(line) or HOURS_REGEX.search(line) or '[LightGBM]' in line: 17 | return True 18 | return False 19 | 20 | 21 | def clean_nb(nb): 22 | for cell in nb['cells']: 23 | if cell_contains_ips(cell): 24 | cell['outputs'] = [] 25 | 26 | 27 | if __name__ == '__main__': 28 | repo_root = Path(__file__).parents[1] 29 | for nb in (repo_root / 'nbs').glob('*.ipynb'): 30 | process_write(warn_msg='Failed to clean_nb', proc_nb=clean_nb, f_in=nb) 31 | -------------------------------------------------------------------------------- /docs/plotting.html.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Plotting 3 | description: Time series visualizations 4 | --- 5 | 6 | 7 | ::: utilsforecast.plotting.plot_series 8 | handler: python 9 | options: 10 | docstring_style: google 11 | heading_level: 3 12 | show_root_heading: true 13 | show_source: true 14 | 15 | 16 | ```python 17 | from utilsforecast.data import generate_series 18 | ``` 19 | 20 | 21 | ```python 22 | level = [80, 95] 23 | series = generate_series(4, freq='D', equal_ends=True, with_trend=True, n_models=2, level=level) 24 | test_pd = series.groupby('unique_id', observed=True).tail(10).copy() 25 | train_pd = series.drop(test_pd.index) 26 | ``` 27 | 28 | 29 | ```python 30 | plt.style.use('ggplot') 31 | fig = plot_series( 32 | train_pd, 33 | forecasts_df=test_pd, 34 | ids=[0, 3], 35 | plot_random=False, 36 | level=level, 37 | max_insample_length=50, 38 | engine='matplotlib', 39 | plot_anomalies=True, 40 | ) 41 | fig.savefig('imgs/plotting.png', bbox_inches='tight') 42 | ``` 43 | 44 | ![](imgs/plotting.png) 45 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | load_docs_scripts: 2 | # load processing scripts 3 | if [ ! 
-d "docs-scripts" ] ; then \ 4 | git clone -b scripts https://github.com/Nixtla/docs.git docs-scripts --single-branch; \ 5 | fi 6 | 7 | api_docs: 8 | python docs/to_mdx.py docs 9 | 10 | examples_docs: 11 | mkdir -p nbs/_extensions 12 | cp -r docs-scripts/mintlify/ nbs/_extensions/mintlify 13 | python docs-scripts/update-quarto.py 14 | quarto render nbs --output-dir ../docs/mintlify/ 15 | 16 | format_docs: 17 | # replace _docs with docs 18 | sed -i -e 's/_docs/docs/g' ./docs-scripts/docs-final-formatting.bash 19 | bash ./docs-scripts/docs-final-formatting.bash 20 | 21 | 22 | preview_docs: 23 | cd docs/mintlify && mintlify dev 24 | 25 | clean: 26 | find docs/mintlify -name "*.mdx" -exec rm -f {} + 27 | 28 | 29 | all_docs: load_docs_scripts api_docs examples_docs format_docs 30 | 31 | licenses: 32 | pip-licenses --format=csv --with-authors --with-urls > third_party_licenses.csv 33 | python scripts/filter_licenses.py 34 | rm -f third_party_licenses.csv 35 | @echo "✓ THIRD_PARTY_LICENSES.md updated" -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | workflow_dispatch: 9 | 10 | concurrency: 11 | group: ${{ github.workflow }}-${{ github.ref }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | all-tests: 16 | runs-on: ${{ matrix.os }} 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | os: [ubuntu-latest, macos-latest, windows-latest] 21 | python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] 22 | steps: 23 | - name: Clone repo 24 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 25 | 26 | - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | 30 | - name: Set matplotlib backend (headless) 31 | run: echo "MPLBACKEND=Agg" >> $GITHUB_ENV 32 | shell: bash 33 | 34 | - name: Install dependencies 35 | run: | 36 | pip install -e ".[dev]" 37 | 38 | - name: Run pytest 39 | run: pytest 40 | 41 | -------------------------------------------------------------------------------- /docs/feature_engineering.html.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Feature Engineering 3 | description: Create exogenous regressors for your models 4 | --- 5 | 6 | ::: utilsforecast.feature_engineering.fourier 7 | handler: python 8 | options: 9 | docstring_style: google 10 | heading_level: 3 11 | show_root_heading: true 12 | show_source: true 13 | 14 | ::: utilsforecast.feature_engineering.trend 15 | handler: python 16 | options: 17 | docstring_style: google 18 | heading_level: 3 19 | show_root_heading: true 20 | show_source: true 21 | 22 | ::: utilsforecast.feature_engineering.time_features 23 | handler: python 24 | options: 25 | docstring_style: google 26 | heading_level: 3 27 | show_root_heading: true 28 | show_source: true 29 | 30 | ::: utilsforecast.feature_engineering.future_exog_to_historic 31 | handler: python 32 | options: 33 | docstring_style: google 34 | heading_level: 3 35 | show_root_heading: true 36 | show_source: true 37 | 38 | ::: utilsforecast.feature_engineering.pipeline 39 | handler: python 40 | options: 41 | docstring_style: google 42 | heading_level: 3 43 | show_root_heading: true 44 | show_source: true 45 | 46 | 
-------------------------------------------------------------------------------- /scripts/extract_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo Python: $(which python) 4 | source /Users/deven367/miniforge3/bin/activate 5 | conda activate nixtla 6 | 7 | echo Activated env: $(which python) 8 | 9 | 10 | mkdir -p ../tests 11 | 12 | tst_flags='datasets distributed matplotlib polars pyarrow scipy' 13 | 14 | nbs=$(ls ../nbs/*.ipynb) 15 | # echo "Available notebooks: $nbs" 16 | 17 | 18 | # this approach3 that was discussed on Slack 19 | # mkdir -p ../tests3 20 | # for flag in $tst_flags; do 21 | # # echo "Extracting $flag" 22 | 23 | # for nb in $nbs; do 24 | # # get name of notebook without extension 25 | # nb_name=$(basename "$nb" .ipynb) 26 | 27 | # # echo "Processing notebook: $nb" 28 | # # print_dir_in_nb "$nb" --dir_name no_dir_and_dir --dir "$flag" >> "../tests/test_$flag_$nb_name.py" 29 | # print_dir_in_nb "$nb" --dir_name no_dir_and_dir --dir "$flag" >> "../tests3/test_${flag}_$nb_name.py" 30 | # done 31 | # done 32 | 33 | for nb in $nbs; do 34 | # get name of notebook without extension 35 | nb_name=$(basename "$nb" .ipynb) 36 | 37 | # echo "Processing notebook: $nb" 38 | # print_dir_in_nb "$nb" --dir_name no_dir_and_dir --dir "$flag" >> "../tests/test_$flag_$nb_name.py" 39 | python cli.py "$nb" --dir_name get_all_tests >> "../tests/test_$nb_name.py" 40 | done -------------------------------------------------------------------------------- /docs/mintlify/favicon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /nbs/_quarto.yml: -------------------------------------------------------------------------------- 1 | format: 2 | mintlify-md: 3 | code-fold: true 4 | metadata-files: 5 | - nbdev.yml 6 | - sidebar.yml 7 | project: 8 | type: mintlify 9 | website: 10 | body-footer: 'Give us a ⭐ on [Github](https://github.com/nixtla/utilsforecast) 11 | 12 | ' 13 | google-analytics: G-NXJNCVR18L 14 | navbar: 15 | background: primary 16 | collapse-below: lg 17 | left: 18 | - menu: 19 | - href: https://github.com/nixtla/statsforecast 20 | text: StatsForecast ⚡️ 21 | - href: https://github.com/nixtla/neuralforecast 22 | text: NeuralForecast 🧠 23 | - href: https://github.com/nixtla/mlforecast 24 | text: MLForecast 🤖 25 | - href: https://github.com/nixtla/hierarchicalforecast 26 | text: HierarchicalForecast 👑 27 | text: NixtlaVerse 28 | - menu: 29 | - href: https://github.com/nixtla/utilsforecast/issues/new/choose 30 | icon: bug 31 | text: Report an Issue 32 | - href: https://join.slack.com/t/nixtlaworkspace/shared_invite/zt-135dssye9-fWTzMpv2WBthq8NK0Yvu6A 33 | icon: chat-right-text 34 | text: Join our Slack 35 | text: Help 36 | right: 37 | - href: https://github.com/nixtla/utilsforecast 38 | icon: github 39 | - aria-label: Nixtla Twitter 40 | href: https://twitter.com/nixtlainc 41 | icon: twitter 42 | search: true 43 | open-graph: 44 | image: https://github.com/Nixtla/styles/blob/2abf51612584169874c90cd7c4d347e3917eaf73/images/Banner%20Github.png 45 | repo-actions: 46 | - issue 47 | sidebar: 48 | style: floating 49 | twitter-card: 50 | image: https://farm6.staticflickr.com/5510/14338202952_93595258ff_z.jpg 51 | site: '@Nixtlainc' 52 | -------------------------------------------------------------------------------- /docs/to_mdx.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | from pathlib import Path 4 | 5 | from convert_to_mkdocstrings import MkDocstringsParser 6 | 7 | comment_pat = re.compile(r"", re.DOTALL) 8 | anchor_pat = re.compile(r"(.*?)") 9 | output_path = Path("docs/mintlify") 10 | 11 | 12 | def process_files(input_dir): 13 | """Process files with MkDocstrings parser, then clean with regex""" 14 | # Step 1: Use MkDocstrings parser to generate initial MDX files 15 | parser = MkDocstringsParser() 16 | for file in Path(input_dir).glob("*.md"): 17 | output_file = file.with_suffix(".mdx").name 18 | print(f"Processing {file} -> {output_file}") 19 | parser.process_file(str(file), str(Path(input_dir) / "mintlify" / output_file)) 20 | 21 | # Step 2: Clean up the generated MDX files with regex patterns 22 | for mdx_file in (Path(input_dir) / "mintlify").glob("*.mdx"): 23 | if mdx_file.name == "index.mdx": # Skip index.mdx as it's handled separately 24 | continue 25 | print(f"Cleaning up {mdx_file}") 26 | text = mdx_file.read_text() 27 | text = comment_pat.sub("", text) 28 | text = anchor_pat.sub("", text) 29 | mdx_file.write_text(text) 30 | 31 | 32 | def copy_readme(): 33 | """Copy README.md to index.mdx with proper header""" 34 | header = """--- 35 | description: Forecasting utilities 36 | title: "utilsforecast" 37 | --- 38 | """ 39 | readme_text = Path("README.md").read_text() 40 | readme_text = header + readme_text 41 | (output_path / "index.html.mdx").write_text(readme_text) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser(description="Process markdown files to MDX format") 46 | parser.add_argument( 47 | "input_dir", nargs="?", default="docs", help="Input directory (default: docs)" 48 | ) 49 | args = parser.parse_args() 50 | 51 | # Step 1: Process files with MkDocstrings parser, then clean with regex 52 | process_files(args.input_dir) 53 | 54 | # Step 2: Always copy the README 55 | copy_readme() -------------------------------------------------------------------------------- /utilsforecast/compat.py: -------------------------------------------------------------------------------- 1 | __all__ = ["DataFrame", "Series", "DistributedDFType", "AnyDFType"] 2 | 3 | 4 | import warnings 5 | from functools import wraps 6 | from typing import TypeVar, Union 7 | 8 | import pandas as pd 9 | 10 | try: 11 | import polars 12 | import polars as pl 13 | from polars import DataFrame as pl_DataFrame 14 | from polars import Expr as pl_Expr 15 | from polars import Series as pl_Series 16 | 17 | DFType = TypeVar("DFType", pd.DataFrame, polars.DataFrame) 18 | POLARS_INSTALLED = True 19 | except ImportError: 20 | pl = None 21 | 22 | class pl_DataFrame: ... 23 | 24 | class pl_Expr: ... 25 | 26 | class pl_Series: ... 27 | 28 | DFType = pd.DataFrame 29 | POLARS_INSTALLED = False 30 | 31 | try: 32 | from numba import njit # noqa: F04 33 | except ImportError: 34 | 35 | def _doublewrap(f): 36 | @wraps(f) 37 | def new_dec(*args, **kwargs): 38 | if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): 39 | return f(args[0]) 40 | else: 41 | return lambda realf: f(realf, *args, **kwargs) 42 | 43 | return new_dec 44 | 45 | @_doublewrap 46 | def njit(f, *_args, **_kwargs): 47 | @wraps(f) 48 | def wrapper(*args, **kwargs): 49 | warnings.warn( 50 | "numba is not installed, some operations may be very slow. 
" 51 | "You can find install instructions at " 52 | "https://numba.pydata.org/numba-doc/latest/user/installing.html" 53 | ) 54 | return f(*args, **kwargs) 55 | 56 | return wrapper 57 | 58 | 59 | try: 60 | from dask.dataframe import DataFrame as DaskDataFrame 61 | except ModuleNotFoundError: 62 | pass 63 | 64 | try: 65 | from pyspark.sql import DataFrame as SparkDataFrame 66 | except ModuleNotFoundError: 67 | pass 68 | 69 | DataFrame = Union[pd.DataFrame, pl_DataFrame] 70 | Series = Union[pd.Series, pl_Series] 71 | DistributedDFType = TypeVar( 72 | "DistributedDFType", 73 | "DaskDataFrame", 74 | "SparkDataFrame", 75 | ) 76 | AnyDFType = TypeVar( 77 | "AnyDFType", 78 | "DaskDataFrame", 79 | pd.DataFrame, 80 | "pl_DataFrame", 81 | "SparkDataFrame", 82 | ) 83 | -------------------------------------------------------------------------------- /action_files/nbdev_test: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import time,os,sys,traceback,contextlib, inspect 3 | from fastcore.basics import * 4 | from fastcore.imports import * 5 | from fastcore.foundation import * 6 | from fastcore.parallel import * 7 | from fastcore.script import * 8 | from fastcore.meta import delegates 9 | 10 | from nbdev.config import * 11 | from nbdev.doclinks import * 12 | from nbdev.process import NBProcessor, nb_lang 13 | from nbdev.frontmatter import FrontmatterProc 14 | from nbdev.test import _keep_file, test_nb 15 | 16 | from execnb.nbio import * 17 | from execnb.shell import * 18 | 19 | 20 | @call_parse 21 | @delegates(nbglob_cli) 22 | def nbdev_test( 23 | path:str=None, # A notebook name or glob to test 24 | flags:str='', # Space separated list of test flags to run that are normally ignored 25 | n_workers:int=None, # Number of workers 26 | timing:bool=False, # Time each notebook to see which are slow 27 | do_print:bool=False, # Print start and end of each notebook 28 | pause:float=0.01, # Pause time (in seconds) between notebooks to avoid race conditions 29 | ignore_fname:str='.notest', # Filename that will result in siblings being ignored 30 | **kwargs): 31 | "Test in parallel notebooks matching `path`, passing along `flags`" 32 | skip_flags = get_config().tst_flags.split() 33 | force_flags = flags.split() 34 | files = nbglob(path, as_path=True, **kwargs) 35 | files = [f.absolute() for f in sorted(files) if _keep_file(f, ignore_fname)] 36 | if len(files)==0: return print('No files were eligible for testing') 37 | 38 | if n_workers is None: n_workers = 0 if len(files)==1 else min(num_cpus(), 8) 39 | kw = {'method': 'spawn'} 40 | wd_pth = get_config().nbs_path 41 | with working_directory(wd_pth if (wd_pth and wd_pth.exists()) else os.getcwd()): 42 | results = parallel(test_nb, files, skip_flags=skip_flags, force_flags=force_flags, n_workers=n_workers, 43 | basepath=get_config().config_path, pause=pause, do_print=do_print, **kw) 44 | passed,times = zip(*results) 45 | if all(passed): print("Success.") 46 | else: 47 | _fence = '='*50 48 | failed = '\n\t'.join(f.name for p,f in zip(passed,files) if not p) 49 | sys.stderr.write(f"\nnbdev Tests Failed On The Following Notebooks:\n{_fence}\n\t{failed}\n") 50 | sys.exit(1) 51 | if timing: 52 | for i,t in sorted(enumerate(times), key=lambda o:o[1], reverse=True): print(f"{files[i].name}: {int(t)} secs") 53 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 
| name: Bug report 2 | title: "[] " 3 | description: Problems and issues with code of the library 4 | labels: [bug] 5 | body: 6 | - type: markdown 7 | attributes: 8 | value: | 9 | Thank you for reporting the problem.. 10 | Please make sure what you are reporting is a bug with reproducible steps. To ask questions 11 | or share ideas, please post on our [Slack community](https://join.slack.com/t/nixtlacommunity/shared_invite/zt-1h77esh5y-iL1m8N0F7qV1HmH~0KYeAQ) instead. 12 | 13 | - type: textarea 14 | attributes: 15 | label: What happened + What you expected to happen 16 | description: Describe 1. the bug 2. expected behavior 3. useful information (e.g., logs) 17 | placeholder: > 18 | Please provide the context in which the problem occurred and explain what happened. Further, 19 | please also explain why you think the behaviour is erroneous. It is extremely helpful if you can 20 | copy and paste the fragment of logs showing the exact error messages or wrong behaviour here. 21 | 22 | **NOTE**: please copy and paste texts instead of taking screenshots of them for easy future search. 23 | validations: 24 | required: true 25 | 26 | - type: textarea 27 | attributes: 28 | label: Versions / Dependencies 29 | description: Please specify the versions of the library, Python, OS, and other libraries that are used. 30 | value: | 31 |
<details><summary>Click to expand</summary> 32 | Dependencies: 33 | 34 | </details>
35 | validations: 36 | required: true 37 | 38 | - type: textarea 39 | attributes: 40 | label: Reproducible example 41 | description: > 42 | Please provide a reproducible script. Providing a simple way to reproduce the issue 43 | (minimal / no external dependencies) will help us triage and address issues in the timely manner! 44 | value: | 45 | ```python 46 | # paste your code here 47 | ``` 48 | validations: 49 | required: true 50 | 51 | - type: dropdown 52 | attributes: 53 | label: Issue Severity 54 | description: | 55 | How does this issue affect your experience as user? 56 | multiple: false 57 | options: 58 | - "Low: It annoys or frustrates me." 59 | - "Medium: It is a significant difficulty but I can work around it." 60 | - "High: It blocks me from completing my task." 61 | validations: 62 | required: false 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | _docs/ 2 | _proc/ 3 | 4 | *.bak 5 | .gitattributes 6 | .last_checked 7 | .gitconfig 8 | *.bak 9 | *.log 10 | *~ 11 | ~* 12 | _tmp* 13 | tmp* 14 | tags 15 | *.pkg 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | env/ 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # dotenv 99 | .env 100 | 101 | # virtualenv 102 | .venv 103 | venv/ 104 | ENV/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | 119 | .vscode 120 | *.swp 121 | 122 | # osx generated files 123 | .DS_Store 124 | .DS_Store? 
125 | .Trashes 126 | ehthumbs.db 127 | Thumbs.db 128 | .idea 129 | 130 | # pytest 131 | .pytest_cache 132 | 133 | # tools/trust-doc-nbs 134 | docs_src/.last_checked 135 | 136 | # symlinks to fastai 137 | docs_src/fastai 138 | tools/fastai 139 | 140 | # link checker 141 | checklink/cookies.txt 142 | 143 | # .gitconfig is now autogenerated 144 | .gitconfig 145 | 146 | # Quarto installer 147 | .deb 148 | .pkg 149 | 150 | # Quarto 151 | .quarto 152 | 153 | *.csv 154 | *.parquet 155 | nbs/_extensions 156 | 157 | # VSCode 158 | *.code-workspace 159 | 160 | docs/**/*.mdx 161 | 162 | docs-scripts 163 | docs/mintlify/examples 164 | 165 | .quarto 166 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name="utilsforecast" 7 | dynamic = ["version"] 8 | description = "Forecasting utilities" 9 | authors = [{name = "Nixtla", email = "business@nixtla.io"}] 10 | license = {text = "Apache Software License 2.0"} 11 | readme = "README.md" 12 | requires-python=">=3.9" 13 | classifiers = [ 14 | "Development Status :: 3 - Alpha", 15 | "Intended Audience :: Developers", 16 | "Natural Language :: English", 17 | "License :: OSI Approved :: Apache Software License", 18 | "Programming Language :: Python :: 3.9", 19 | "Programming Language :: Python :: 3.10", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Programming Language :: Python :: 3.13", 23 | ] 24 | keywords = ["time-series", "analysis", "forecasting"] 25 | dependencies = [ 26 | "numpy", 27 | "packaging", 28 | "pandas>=1.1.1", 29 | "narwhals>=2.0", 30 | ] 31 | 32 | [tool.setuptools.dynamic] 33 | version = {attr = "utilsforecast.__version__"} 34 | 35 | [project.optional-dependencies] 36 | plotting = [ 37 | "pandas[plot]", 38 | "plotly", 39 | "plotly-resampler", 40 | ] 41 | polars = [ 42 | "polars[numpy]<=1.31", 43 | ] 44 | dev = [ 45 | "black", 46 | "datasetsforecast==0.0.8", 47 | "nbformat", 48 | "numba>=0.58.0", 49 | "pip", 50 | "pre-commit", 51 | "pyarrow", 52 | "scipy", 53 | "pandas[plot]", 54 | "plotly", 55 | "plotly-resampler", 56 | "polars[numpy]", 57 | "pytest", 58 | "pytest-cov", 59 | "fugue[dask,spark]>=0.8.1", 60 | "dask<=2024.12.1", 61 | "pip-licenses", 62 | "mkdocstrings-parser@git+https://github.com/Nixtla/mkdocstrings-parser.git", 63 | ] 64 | 65 | [project.urls] 66 | Homepage = "https://github.com/Nixtla/utilsforecast" 67 | Documentation = "https://nixtlaverse.verse.io/utilsforecast" 68 | Repository = "https://github.com/Nixtla/utilsforecast" 69 | 70 | 71 | 72 | [tool.setuptools] 73 | include-package-data = true 74 | 75 | [tool.setuptools.packages.find] 76 | include = ["utilsforecast*"] 77 | 78 | [tool.setuptools.package-data] 79 | utilsforecast = ["py.typed"] 80 | 81 | [tool.mypy] 82 | ignore_missing_imports = true 83 | 84 | [[tool.mypy.overrides]] 85 | module = 'utilsforecast.compat' 86 | ignore_errors = true 87 | 88 | [tool.ruff.lint] 89 | select = ["F", "ARG", "I"] 90 | 91 | [tool.ruff.format] 92 | quote-style = "double" 93 | 94 | [tool.coverage] 95 | branch = true 96 | source = ["utilsforecast"] 97 | 98 | [tool.coverage.run] 99 | omit = ["tests/*"] 100 | 101 | [tool.coverage.report] 102 | fail_under = 80 103 | show_missing = true 104 | 105 | [tool.pytest.ini_options] 106 | testpaths = ["tests"] 107 | #addopts = "--cov utilsforecast 
-vv" 108 | -------------------------------------------------------------------------------- /scripts/cli.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: cli.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['tst_flags', 'to_skip', 'mapper', 'print_execs', 'print_hide', 'other_tests', 'get_markdown', 'extract_dir', 5 | 'no_dir_and_dir', 'get_all_tests', 'print_dir_in_nb'] 6 | 7 | # %% cli.ipynb 1 8 | from execnb.nbio import read_nb 9 | from nbdev.processors import NBProcessor 10 | from nbdev.export import ExportModuleProc, nb_export 11 | from nbdev.maker import ModuleMaker 12 | from fastcore.xtras import globtastic, Path 13 | from functools import partial 14 | from fastcore.script import call_parse 15 | from nbdev import nbdev_export 16 | 17 | # %% cli.ipynb 3 18 | tst_flags = 'datasets distributed matplotlib polars pyarrow scipy'.split() 19 | to_skip = [ 20 | 'showdoc', 21 | 'load_ext', 22 | 'from nbdev' 23 | ] 24 | 25 | 26 | def print_execs(cell): 27 | if 'exec' in cell.source: print(cell.source) 28 | 29 | def print_hide(cell): 30 | if 'hide' in cell.directives_: print(cell.source) 31 | 32 | def other_tests(cell): 33 | if len(cell.directives_) == 0: 34 | print(cell.source) 35 | 36 | def get_markdown(cell): 37 | if cell.cell_type == "markdown": 38 | print(cell.source) 39 | 40 | def extract_dir(cell, dir): 41 | if dir in cell.directives_: 42 | print(cell.source) 43 | 44 | def no_dir_and_dir(cell, dir): 45 | if len(cell.directives_) == 0: 46 | print(cell.source) 47 | 48 | if dir in cell.directives_: 49 | print(cell.source) 50 | 51 | def get_all_tests2(cell): 52 | if cell.cell_type == "code": 53 | 54 | if len(cell.directives_) == 0: 55 | print(cell.source) 56 | 57 | 58 | elif any(x in tst_flags + ['hide'] for x in cell.directives_): 59 | if not (x in cell.source for x in to_skip): 60 | print(cell.source) 61 | 62 | def get_all_tests(cell): 63 | if len(cell.directives_) == 0: 64 | print(cell.source) 65 | 66 | if any(x in tst_flags + ["hide"] for x in cell.directives_): 67 | print(cell.source) 68 | 69 | 70 | 71 | # %% cli.ipynb 7 72 | mapper = { 73 | 'print_execs': print_execs, 74 | 'print_hide': print_hide, 75 | 'other_tests': other_tests, 76 | 'get_markdown': get_markdown, 77 | 'extract_dir': extract_dir, 78 | 'no_dir_and_dir': no_dir_and_dir, 79 | 'get_all_tests':get_all_tests 80 | } 81 | 82 | # %% cli.ipynb 8 83 | @call_parse 84 | def print_dir_in_nb(nb_path:str, 85 | dir:str=None, 86 | dir_name:str=None, 87 | ): 88 | if dir_name not in mapper.keys(): 89 | raise ValueError(f'Choose processor from the the following: {mapper.keys()}') 90 | 91 | if dir_name == 'extract_dir': 92 | processor = NBProcessor(nb_path, partial(extract_dir, dir=dir)) 93 | processor.process() 94 | return 95 | elif dir_name == 'no_dir_and_dir': 96 | processor = NBProcessor(nb_path, partial(no_dir_and_dir, dir=dir)) 97 | processor.process() 98 | return 99 | 100 | processor = NBProcessor(nb_path, mapper[dir_name]) 101 | processor.process() 102 | 103 | -------------------------------------------------------------------------------- /tests/test_validation.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import pandas as pd 4 | import polars as pl 5 | from conftest import assert_raises_with_message 6 | 7 | from utilsforecast.compat import POLARS_INSTALLED 8 | from utilsforecast.validation import ( 9 | _is_dt_dtype, 10 | _is_int_dtype, 11 | ensure_time_dtype, 12 | validate_format, 13 | 
validate_freq, 14 | ) 15 | 16 | 17 | def test_dtypes(): 18 | assert _is_int_dtype(pd.Series([1, 2])) 19 | assert _is_int_dtype(pd.Index([1, 2], dtype='uint8')) 20 | assert not _is_int_dtype(pd.Series([1.0])) 21 | assert _is_dt_dtype(pd.to_datetime(['2000-01-01'])) 22 | assert _is_dt_dtype(pd.to_datetime(['2000-01-01'], utc=True)) 23 | 24 | 25 | def test_dtypes_arrow(): 26 | assert _is_dt_dtype(pd.to_datetime(['2000-01-01']).astype('datetime64[s]')) 27 | assert _is_int_dtype(pd.Series([1, 2], dtype='int32[pyarrow]')) 28 | assert _is_dt_dtype(pd.to_datetime(['2000-01-01']).astype('timestamp[ns][pyarrow]')) 29 | assert _is_int_dtype(pl.Series([1, 2])) 30 | assert _is_int_dtype(pl.Series([1, 2], dtype=pl.UInt8)) 31 | 32 | 33 | def test_dtypes_polars(): 34 | assert not _is_int_dtype(pl.Series([1.0])) 35 | assert _is_dt_dtype(pl.Series([datetime.date(2000, 1, 1)])) 36 | assert _is_dt_dtype(pl.Series([datetime.datetime(2000, 1, 1)])) 37 | assert _is_dt_dtype( 38 | pl.Series([datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc)]) 39 | ) 40 | 41 | 42 | def test_ensure_time_dtype(): 43 | pd.testing.assert_frame_equal( 44 | ensure_time_dtype(pd.DataFrame({'ds': ['2000-01-01']})), 45 | pd.DataFrame({'ds': pd.to_datetime(['2000-01-01'])}), 46 | ) 47 | df = pd.DataFrame({'ds': [1, 2]}) 48 | assert df is ensure_time_dtype(df) 49 | assert_raises_with_message(ensure_time_dtype, 'Please make sure that it contains valid timestamps', pd.DataFrame({'ds': ['2000-14-14']})) 50 | pl.testing.assert_frame_equal( 51 | ensure_time_dtype(pl.DataFrame({'ds': ['2000-01-01']})), 52 | pl.DataFrame().with_columns(ds=pl.datetime(2000, 1, 1)), 53 | ) 54 | df = pl.DataFrame({'ds': [1, 2]}) 55 | assert df is ensure_time_dtype(df) 56 | assert_raises_with_message(ensure_time_dtype, 'Please make sure that it contains valid timestamps', pl.DataFrame({'ds': ['hello']})) 57 | 58 | 59 | def test_validate_format(): 60 | assert_raises_with_message(validate_format, "got ", 1) 61 | constructors = [pd.DataFrame] 62 | if POLARS_INSTALLED: 63 | constructors.append(pl.DataFrame) 64 | for constructor in constructors: 65 | df = constructor({'unique_id': [1]}) 66 | assert_raises_with_message(validate_format, "missing: ['ds', 'y']", df) 67 | df = constructor({'unique_id': [1], 'time': ['x'], 'y': [1]}) 68 | assert_raises_with_message(validate_format,"('time') should have either timestamps or integers", df, time_col='time'), 69 | for time in [1, datetime.datetime(2000, 1, 1)]: 70 | df = constructor({'unique_id': [1], 'ds': [time], 'sales': ['x']}) 71 | assert_raises_with_message(validate_format, "('sales') should have a numeric data type", df, target_col='sales') 72 | 73 | 74 | def test_validate_freq(): 75 | assert_raises_with_message(validate_freq, 'provide a valid integer',pd.Series([1, 2]), 'D') 76 | assert_raises_with_message(validate_freq, 'provide a valid pandas or polars offset', pd.to_datetime(['2000-01-01']).to_series(), 1), 77 | assert_raises_with_message(validate_freq, 'provide a valid integer', pl.Series([1, 2]), '1d') 78 | assert_raises_with_message(validate_freq, 'provide a valid pandas or polars offset', pl.Series([datetime.datetime(2000, 1, 1)]), 1), 79 | assert_raises_with_message(validate_freq, 'valid polars offset', pl.Series([datetime.datetime(2000, 1, 1)]), 'D'), 80 | 81 | -------------------------------------------------------------------------------- /.github/workflows/build-docs.yaml: -------------------------------------------------------------------------------- 1 | name: build-docs 2 | on: 3 | push: 4 | branches: 
[main] 5 | pull_request: 6 | branches: [main] 7 | release: 8 | types: [published] 9 | workflow_dispatch: 10 | inputs: 11 | environment: 12 | description: "The environment to deploy to" 13 | required: True 14 | type: choice 15 | default: "staging" 16 | options: 17 | - staging 18 | - production 19 | 20 | jobs: 21 | build-docs: 22 | runs-on: ubuntu-latest 23 | steps: 24 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 25 | with: 26 | submodules: "true" 27 | 28 | - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 29 | with: 30 | python-version: '3.10' 31 | 32 | - name: Install dependencies 33 | run: pip install uv && uv pip install --system '.[dev]' 34 | 35 | # setup quarto for rendering example/tutorial nbs 36 | - uses: quarto-dev/quarto-actions/setup@v2 37 | with: 38 | version: 1.4.515 39 | 40 | - name: Build Docs 41 | run: make all_docs 42 | 43 | - name: Deploy (Push to main or Pull Request) 44 | if: (github.event_name == 'push' && github.ref == 'refs/heads/main') || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) || (github.event_name == 'workflow_dispatch' && github.event.inputs.environment == 'staging') 45 | uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 46 | with: 47 | github_token: ${{ secrets.GITHUB_TOKEN }} 48 | publish_branch: docs-preview 49 | publish_dir: docs/mintlify 50 | user_name: github-actions[bot] 51 | user_email: 41898282+github-actions[bot]@users.noreply.github.com 52 | 53 | - name: Deploy (Release) 54 | if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && github.event.inputs.environment == 'production') 55 | uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0 56 | with: 57 | github_token: ${{ secrets.GITHUB_TOKEN }} 58 | publish_branch: docs 59 | publish_dir: docs/mintlify 60 | user_name: github-actions[bot] 61 | user_email: 41898282+github-actions[bot]@users.noreply.github.com 62 | tag_name: ${{ github.event.release.tag_name }} 63 | tag_message: 'Documentation for release ${{ github.event.release.tag_name }}' 64 | 65 | - name: Trigger mintlify workflow (Push to main or Pull Request) 66 | if: (github.event_name == 'push' && github.ref == 'refs/heads/main') || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) || (github.event_name == 'workflow_dispatch' && github.event.inputs.environment == 'staging') 67 | uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 68 | with: 69 | github-token: ${{ secrets.DOCS_WORKFLOW_TOKEN }} 70 | script: | 71 | await github.rest.actions.createWorkflowDispatch({ 72 | owner: 'nixtla', 73 | repo: 'docs', 74 | workflow_id: 'preview.yml', 75 | ref: 'main', 76 | }); 77 | 78 | - name: Trigger mintlify workflow (Release) 79 | if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && github.event.inputs.environment == 'production') 80 | uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 81 | with: 82 | github-token: ${{ secrets.DOCS_WORKFLOW_TOKEN }} 83 | script: | 84 | await github.rest.actions.createWorkflowDispatch({ 85 | owner: 'nixtla', 86 | repo: 'docs', 87 | workflow_id: 'production.yml', 88 | ref: 'main', 89 | }); 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 
Install 2 | 3 | ### PyPI 4 | 5 | ```sh 6 | pip install utilsforecast 7 | ``` 8 | 9 | ### Conda 10 | 11 | ```sh 12 | conda install -c conda-forge utilsforecast 13 | ``` 14 | 15 | --- 16 | 17 | ## How to use 18 | 19 | ### Generate synthetic data 20 | 21 | ```python 22 | from utilsforecast.data import generate_series 23 | 24 | series = generate_series(3, with_trend=True, static_as_categorical=False) 25 | series 26 | ``` 27 | 28 | ``` 29 | | | unique_id | ds | y | 30 | |-----|-----------|------------|------------| 31 | | 0 | 0 | 2000-01-01 | 0.422133 | 32 | | 1 | 0 | 2000-01-02 | 1.501407 | 33 | | 2 | 0 | 2000-01-03 | 2.568495 | 34 | | 3 | 0 | 2000-01-04 | 3.529085 | 35 | | 4 | 0 | 2000-01-05 | 4.481929 | 36 | | ... | ... | ... | ... | 37 | | 481 | 2 | 2000-06-11 | 163.914625 | 38 | | 482 | 2 | 2000-06-12 | 166.018479 | 39 | | 483 | 2 | 2000-06-13 | 160.839176 | 40 | | 484 | 2 | 2000-06-14 | 162.679603 | 41 | | 485 | 2 | 2000-06-15 | 165.089288 | 42 | ``` 43 | 44 | --- 45 | 46 | ### Plotting 47 | 48 | ```python 49 | from utilsforecast.plotting import plot_series 50 | 51 | fig = plot_series(series, plot_random=False, max_insample_length=50, engine='matplotlib') 52 | fig.savefig('imgs/index.png', bbox_inches='tight') 53 | ``` 54 | 55 | ![](./docs/mintlify/imgs/index.png) 56 | ![](./imgs/index.png) 57 | 58 | --- 59 | 60 | ### Preprocessing 61 | 62 | ```python 63 | from utilsforecast.preprocessing import fill_gaps 64 | 65 | serie = series[series['unique_id'].eq(0)].tail(10) 66 | # drop some points 67 | with_gaps = serie.sample(frac=0.5, random_state=0).sort_values('ds') 68 | with_gaps 69 | ``` 70 | 71 | Example output with missing dates: 72 | 73 | ``` 74 | | | unique_id | ds | y | 75 | |-----|-----------|------------|-----------| 76 | | 213 | 0 | 2000-08-01 | 18.543147 | 77 | | 214 | 0 | 2000-08-02 | 19.941764 | 78 | | 216 | 0 | 2000-08-04 | 21.968733 | 79 | | 220 | 0 | 2000-08-08 | 19.091509 | 80 | | 221 | 0 | 2000-08-09 | 20.220739 | 81 | ``` 82 | 83 | ```python 84 | fill_gaps(with_gaps, freq='D') 85 | ``` 86 | 87 | Returns: 88 | 89 | ``` 90 | | | unique_id | ds | y | 91 | |-----|-----------|------------|-----------| 92 | | 0 | 0 | 2000-08-01 | 18.543147 | 93 | | 1 | 0 | 2000-08-02 | 19.941764 | 94 | | 2 | 0 | 2000-08-03 | NaN | 95 | | 3 | 0 | 2000-08-04 | 21.968733 | 96 | | 4 | 0 | 2000-08-05 | NaN | 97 | | 5 | 0 | 2000-08-06 | NaN | 98 | | 6 | 0 | 2000-08-07 | NaN | 99 | | 7 | 0 | 2000-08-08 | 19.091509 | 100 | | 8 | 0 | 2000-08-09 | 20.220739 | 101 | ``` 102 | 103 | --- 104 | 105 | ### Evaluating 106 | 107 | ```python 108 | from functools import partial 109 | import numpy as np 110 | 111 | from utilsforecast.evaluation import evaluate 112 | from utilsforecast.losses import mape, mase 113 | ``` 114 | 115 | ```python 116 | valid = series.groupby('unique_id').tail(7).copy() 117 | train = series.drop(valid.index) 118 | 119 | rng = np.random.RandomState(0) 120 | valid['seas_naive'] = train.groupby('unique_id')['y'].tail(7).values 121 | valid['rand_model'] = valid['y'] * rng.rand(valid['y'].shape[0]) 122 | 123 | daily_mase = partial(mase, seasonality=7) 124 | 125 | evaluate(valid, metrics=[mape, daily_mase], train_df=train) 126 | ``` 127 | 128 | 129 | ``` 130 | | | unique_id | metric | seas_naive | rand_model | 131 | |-----|-----------|--------|------------|------------| 132 | | 0 | 0 | mape | 0.024139 | 0.440173 | 133 | | 1 | 1 | mape | 0.054259 | 0.278123 | 134 | | 2 | 2 | mape | 0.042642 | 0.480316 | 135 | | 3 | 0 | mase | 0.907149 | 16.418014 | 136 | | 4 | 1 | mase | 0.991635 | 6.404254 | 137 | 
| 5 | 2 | mase | 1.013596 | 11.365040 | 138 | ``` 139 | 140 | --- 141 | -------------------------------------------------------------------------------- /tests/test_plotting.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from itertools import product 3 | from pathlib import Path 4 | 5 | import matplotlib.pyplot as plt 6 | import pytest 7 | from conftest import assert_raises_with_message 8 | 9 | from utilsforecast.compat import POLARS_INSTALLED 10 | from utilsforecast.data import generate_series 11 | from utilsforecast.plotting import plot_series 12 | 13 | if POLARS_INSTALLED: 14 | import polars as pl 15 | 16 | try: 17 | import plotly 18 | PLOTLY_INSTALLED = True 19 | except ImportError: 20 | PLOTLY_INSTALLED = False 21 | 22 | try: 23 | import plotly_resampler 24 | PLOTLY_RESAMPLER_INSTALLED = True 25 | except ImportError: 26 | PLOTLY_RESAMPLER_INSTALLED = False 27 | 28 | 29 | @pytest.fixture 30 | def set_paths(): 31 | ROOT_DIR = Path(__file__).resolve().parent.parent 32 | IMG_PATH = ROOT_DIR / "nbs" / "imgs" 33 | IMG_PATH.mkdir(parents=True, exist_ok=True) 34 | return IMG_PATH 35 | 36 | 37 | @pytest.fixture 38 | def set_series(): 39 | level = [80, 95] 40 | series = generate_series( 41 | 4, freq="D", equal_ends=True, with_trend=True, n_models=2, level=level 42 | ) 43 | test_pd = series.groupby("unique_id", observed=True).tail(10).copy() 44 | train_pd = series.drop(test_pd.index) 45 | return series, test_pd, train_pd, level 46 | 47 | 48 | def test_plot_series(set_series, set_paths): 49 | _, test_pd, train_pd, level = set_series 50 | plt.style.use("ggplot") 51 | fig = plot_series( 52 | train_pd, 53 | forecasts_df=test_pd, 54 | ids=[0, 3], 55 | plot_random=False, 56 | level=level, 57 | max_insample_length=50, 58 | engine="matplotlib", 59 | plot_anomalies=True, 60 | ) 61 | fig.savefig(set_paths / "plotting.png", bbox_inches="tight") 62 | 63 | 64 | # Prepare combinations 65 | bools = [True, False] 66 | polars_opts = bools if POLARS_INSTALLED else [False] 67 | anomalies = bools 68 | randoms = bools 69 | forecasts = bools 70 | ids_list = [[0], [3, 1], None] 71 | levels = [[80], None] 72 | max_insample_lengths = [None, 50] 73 | 74 | engines = ["matplotlib"] 75 | if PLOTLY_INSTALLED: 76 | engines.append("plotly") 77 | if PLOTLY_RESAMPLER_INSTALLED: 78 | engines.append("plotly-resampler") 79 | 80 | @pytest.mark.parametrize("as_polars,ids,plot_anomalies,level,max_insample_length,engine,plot_random,with_forecasts", product( 81 | polars_opts, ids_list, anomalies, levels, max_insample_lengths, engines, randoms, forecasts 82 | )) 83 | def test_plotting_combinations( 84 | as_polars, 85 | ids, 86 | plot_anomalies, 87 | level, 88 | max_insample_length, 89 | engine, 90 | plot_random, 91 | with_forecasts, 92 | set_series, 93 | ): 94 | _, test_pd, train_pd, _ = set_series 95 | 96 | if POLARS_INSTALLED and as_polars: 97 | train = pl.DataFrame(train_pd.to_records(index=False)) 98 | test = pl.DataFrame(test_pd.to_records(index=False)) if with_forecasts else None 99 | else: 100 | train = train_pd 101 | test = test_pd if with_forecasts else None 102 | 103 | fn = lambda: plot_series( 104 | train, 105 | forecasts_df=test, 106 | ids=ids, 107 | plot_random=plot_random, 108 | plot_anomalies=plot_anomalies, 109 | level=level, 110 | max_insample_length=max_insample_length, 111 | engine=engine, 112 | ) 113 | 114 | if level is None and plot_anomalies: 115 | assert_raises_with_message(fn, "specify the `level` argument") 116 | elif level is not None and 
plot_anomalies and not with_forecasts: 117 | assert_raises_with_message(fn, "provide a `forecasts_df` with prediction") 118 | else: 119 | with warnings.catch_warnings(): 120 | warnings.filterwarnings( 121 | "ignore", 122 | message="The behavior of DatetimeProperties.to_pydatetime is deprecated", 123 | category=FutureWarning, 124 | ) 125 | fn() 126 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | ## Did you find a bug? 4 | 5 | - Ensure the bug was not already reported by searching on GitHub under Issues. 6 | - If you're unable to find an open issue addressing the problem, open a new one. 7 | Be sure to include a title and clear description, as much relevant information 8 | as possible, and a code sample or an executable test case demonstrating the 9 | expected behavior that is not occurring. 10 | - Be sure to add the complete error messages. 11 | 12 | ## Do you have a feature request? 13 | 14 | - Ensure that it hasn't been yet implemented in the `main` branch of the 15 | repository and that there's not an Issue requesting it yet. 16 | - Open a new issue and make sure to describe it clearly, mention how it improves 17 | the project and why its useful. 18 | 19 | ## Do you want to fix a bug or implement a feature? 20 | 21 | Bug fixes and features are added through pull requests (PRs). 22 | 23 | ## PR submission guidelines 24 | 25 | - Keep each PR focused. While it's more convenient, do not combine several 26 | unrelated fixes together. Create as many branches as needing to keep each PR 27 | focused. 28 | - Ensure that your PR includes a test that fails without your patch, and passes 29 | with it. 30 | - Ensure the PR description clearly describes the problem and solution. Include 31 | the relevant issue number if applicable. 32 | - Do not mix style changes/fixes with "functional" changes. It's very difficult 33 | to review such PRs and it most likely get rejected. 34 | - Do not add/remove vertical whitespace. Preserve the original style of the file 35 | you edit as much as you can. 36 | - Do not turn an already submitted PR into your development playground. If after 37 | you submitted PR, you discovered that more work is needed - close the PR, do 38 | the required work and then submit a new PR. Otherwise each of your commits 39 | requires attention from maintainers of the project. 40 | - If, however, you submitted a PR and received a request for changes, you should 41 | proceed with commits inside that PR, so that the maintainer can see the 42 | incremental fixes and won't need to review the whole PR again. In the 43 | exception case where you realize it'll take many many commits to complete the 44 | requests, then it's probably best to close the PR, do the work and then submit 45 | it again. Use common sense where you'd choose one way over another. 
46 | 47 | ### Local setup for working on a PR 48 | 49 | #### Clone the repository 50 | 51 | - HTTPS: `git clone https://github.com/Nixtla/utilsforecast.git` 52 | - SSH: `git clone git@github.com:Nixtla/utilsforecast.git` 53 | - GitHub CLI: `gh repo clone Nixtla/utilsforecast` 54 | 55 | ## 🛠️ Create the Development Environment 56 | 57 | ```bash 58 | pip install uv 59 | uv venv --python 3.10 60 | source .venv/bin/activate 61 | 62 | # Install the library in editable mode for development 63 | uv pip install -e ".[dev]" -U 64 | ``` 65 | 66 | ## 🔧 Install Pre-commit Hooks 67 | 68 | Pre-commit hooks help maintain code quality by running checks before commits. 🛡️ 69 | 70 | ```bash 71 | pre-commit install 72 | pre-commit run --all-files 73 | ``` 74 | 75 | ## Viewing documentation locally 76 | 77 | The new documentation pipeline relies on `mintlify` and `lazydocs`. 78 | 79 | ### install mintlify 80 | 81 | > [!NOTE] 82 | > Please install Node.js before proceeding. 83 | 84 | ```sh 85 | npm i -g mint 86 | ``` 87 | 88 | For additional instructions, you can read about it [here](https://mintlify.com/docs/installation). 89 | 90 | ```sh 91 | make all_docs 92 | ``` 93 | 94 | Finally to view the documentation 95 | 96 | ```sh 97 | make preview_docs 98 | ``` 99 | 100 | ## Running tests 101 | 102 | If you're working on the local interface you can just use 103 | 104 | ```sh 105 | uv run pytest 106 | ``` 107 | 108 | ## Do you want to contribute to the documentation? 109 | 110 | - The docs are automatically generated from the docstrings in the utilsforecast folder. 111 | - To contribute, ensure your docstrings follow the Google style format. 112 | - Once your docstring is correctly written, the documentation framework will scrape it and regenerate the corresponding `.mdx` files and your changes will then appear in the updated docs. 113 | - Make an appropriate entry in the `docs/mintlify/mint.json` file. 
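For reference, a minimal sketch of the Google-style layout the docs scraper expects (the function name and arguments below are invented purely for illustration; only the section headings such as `Args:` and `Returns:` matter):

```python
def example_transform(df, freq="D"):
    """One-line summary of what the function does.

    Args:
        df (pandas or polars DataFrame): Input data in long format.
        freq (str, optional): Series' frequency. Defaults to 'D'.

    Returns:
        pandas or polars DataFrame: Description of the returned value.
    """
    ...
```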
114 | -------------------------------------------------------------------------------- /utilsforecast/grouped_array.py: -------------------------------------------------------------------------------- 1 | __all__ = ['GroupedArray'] 2 | 3 | 4 | from typing import Sequence, Tuple, Union 5 | 6 | import numpy as np 7 | 8 | from .compat import DataFrame 9 | from .processing import counts_by_id, value_cols_to_numpy 10 | 11 | 12 | def _append_one( 13 | data: np.ndarray, indptr: np.ndarray, new: np.ndarray 14 | ) -> Tuple[np.ndarray, np.ndarray]: 15 | """Append each value of new to each group in data formed by indptr.""" 16 | n_groups = len(indptr) - 1 17 | n_rows = data.shape[0] + new.shape[0] 18 | if data.ndim == 2: 19 | new_data = np.empty_like(data, shape=(n_rows, data.shape[1])) 20 | else: 21 | new_data = np.empty_like(data, shape=n_rows) 22 | new_indptr = indptr.copy() 23 | new_indptr[1:] += np.arange(1, n_groups + 1) 24 | for i in range(n_groups): 25 | prev_slice = slice(indptr[i], indptr[i + 1]) 26 | new_slice = slice(new_indptr[i], new_indptr[i + 1] - 1) 27 | new_data[new_slice] = data[prev_slice] 28 | new_data[new_indptr[i + 1] - 1] = new[i] 29 | return new_data, new_indptr 30 | 31 | 32 | def _append_several( 33 | data: np.ndarray, 34 | indptr: np.ndarray, 35 | new_sizes: np.ndarray, 36 | new_values: np.ndarray, 37 | new_groups: np.ndarray, 38 | ) -> Tuple[np.ndarray, np.ndarray]: 39 | n_rows = data.shape[0] + new_values.shape[0] 40 | if data.ndim == 2: 41 | new_data = np.empty_like(data, shape=(n_rows, data.shape[1])) 42 | else: 43 | new_data = np.empty_like(data, shape=n_rows) 44 | new_indptr = np.empty_like(indptr, shape=new_sizes.size + 1) 45 | new_indptr[0] = 0 46 | old_indptr_idx = 0 47 | new_vals_idx = 0 48 | for i, is_new in enumerate(new_groups): 49 | new_size = new_sizes[i] 50 | if is_new: 51 | old_size = 0 52 | else: 53 | prev_slice = slice(indptr[old_indptr_idx], indptr[old_indptr_idx + 1]) 54 | old_indptr_idx += 1 55 | old_size = prev_slice.stop - prev_slice.start 56 | new_size += old_size 57 | new_data[new_indptr[i] : new_indptr[i] + old_size] = data[prev_slice] 58 | new_indptr[i + 1] = new_indptr[i] + new_size 59 | new_data[new_indptr[i] + old_size : new_indptr[i + 1]] = new_values[ 60 | new_vals_idx : new_vals_idx + new_sizes[i] 61 | ] 62 | new_vals_idx += new_sizes[i] 63 | return new_data, new_indptr 64 | 65 | 66 | class GroupedArray: 67 | def __init__(self, data: np.ndarray, indptr: np.ndarray): 68 | self.data = data 69 | self.indptr = indptr 70 | self.n_groups = len(indptr) - 1 71 | 72 | def __len__(self): 73 | return self.n_groups 74 | 75 | def __getitem__(self, idx: int) -> np.ndarray: 76 | if idx < 0: 77 | idx = self.n_groups + idx 78 | return self.data[self.indptr[idx] : self.indptr[idx + 1]] 79 | 80 | @classmethod 81 | def from_sorted_df( 82 | cls, df: DataFrame, id_col: str, time_col: str, target_col: str 83 | ) -> "GroupedArray": 84 | id_counts = counts_by_id(df, id_col) 85 | sizes = id_counts["counts"].to_numpy() 86 | indptr = np.append(0, sizes.cumsum()) 87 | data = value_cols_to_numpy(df, id_col, time_col, target_col) 88 | if data.dtype not in (np.float32, np.float64): 89 | data = data.astype(np.float32) 90 | return cls(data, indptr) 91 | 92 | def _take_from_ranges(self, ranges: Sequence) -> Tuple[np.ndarray, np.ndarray]: 93 | items = [self.data[r] for r in ranges] 94 | sizes = np.array([item.shape[0] for item in items]) 95 | if self.data.ndim == 2: 96 | data = np.vstack(items) 97 | else: 98 | data = np.hstack(items) 99 | indptr = np.append(0, sizes.cumsum()) 
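        # `indptr` holds the cumulative sizes of the selected groups, so group i of the result spans data[indptr[i] : indptr[i + 1]]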
100 | return data, indptr 101 | 102 | def take(self, idxs: Sequence[int]) -> Tuple[np.ndarray, np.ndarray]: 103 | """Subset specific groups by their indices.""" 104 | ranges = [range(self.indptr[i], self.indptr[i + 1]) for i in idxs] 105 | return self._take_from_ranges(ranges) 106 | 107 | def take_from_groups(self, idx: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]: 108 | """Select a subset from each group.""" 109 | if isinstance(idx, int): 110 | # this preserves the 2d structure of data when indexing with the range 111 | idx = slice(idx, idx + 1) 112 | ranges = [ 113 | range(self.indptr[i], self.indptr[i + 1])[idx] for i in range(self.n_groups) 114 | ] 115 | return self._take_from_ranges(ranges) 116 | 117 | def append(self, new: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 118 | """Appends each element of `new` to each existing group. Returns a copy.""" 119 | if new.shape[0] != self.n_groups: 120 | raise ValueError(f"new must have {self.n_groups} rows.") 121 | return _append_one(self.data, self.indptr, new) 122 | 123 | def append_several( 124 | self, new_sizes: np.ndarray, new_values: np.ndarray, new_groups: np.ndarray 125 | ) -> Tuple[np.ndarray, np.ndarray]: 126 | return _append_several( 127 | self.data, self.indptr, new_sizes, new_values, new_groups 128 | ) 129 | 130 | def __repr__(self): 131 | return f"{self.__class__.__name__}(n_rows={self.data.shape[0]:,}, n_groups={self.n_groups:,})" 132 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | ops@nixtla.io. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. 
Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /utilsforecast/validation.py: -------------------------------------------------------------------------------- 1 | """Utilities to validate input data""" 2 | 3 | __all__ = ['ensure_shallow_copy', 'ensure_time_dtype', 'validate_format', 'validate_freq'] 4 | 5 | 6 | import re 7 | from typing import Optional, Union 8 | 9 | import pandas as pd 10 | 11 | from .compat import DataFrame, DFType, Series, pl, pl_DataFrame, pl_Series 12 | 13 | 14 | def _is_int_dtype(s: Union[pd.Index, Series]) -> bool: 15 | if isinstance(s, (pd.Index, pd.Series)): 16 | out = pd.api.types.is_integer_dtype(s.dtype) 17 | else: 18 | try: 19 | out = s.dtype.is_integer() 20 | except AttributeError: 21 | out = s.is_integer() 22 | return out 23 | 24 | 25 | def _is_dt_dtype(s: Union[pd.Index, Series]) -> bool: 26 | if isinstance(s, (pd.Index, pd.Series)): 27 | out = pd.api.types.is_datetime64_any_dtype(s.dtype) 28 | else: 29 | out = s.dtype in (pl.Date, pl.Datetime) 30 | return out 31 | 32 | 33 | def _is_dt_or_int(s: Series) -> bool: 34 | return _is_dt_dtype(s) or _is_int_dtype(s) 35 | 36 | 37 | def ensure_shallow_copy(df: pd.DataFrame) -> pd.DataFrame: 38 | from packaging.version import Version 39 | 40 | if Version(pd.__version__) < Version("1.4"): 41 | # https://github.com/pandas-dev/pandas/pull/43406 42 | df = df.copy() 43 | return df 44 | 45 | 46 | def ensure_time_dtype(df: DFType, time_col: str = "ds") -> DFType: 47 | """Make sure that `time_col` contains timestamps or integers. 48 | If it contains strings, try to cast them as timestamps.""" 49 | times = df[time_col] 50 | if _is_dt_or_int(times): 51 | return df 52 | parse_err_msg = ( 53 | f"Failed to parse '{time_col}' from string to datetime. " 54 | "Please make sure that it contains valid timestamps or integers." 
55 | ) 56 | if isinstance(times, pd.Series) and pd.api.types.is_object_dtype(times): 57 | try: 58 | times = pd.to_datetime(times) 59 | except ValueError: 60 | raise ValueError(parse_err_msg) 61 | df = ensure_shallow_copy(df.copy(deep=False)) 62 | df[time_col] = times 63 | elif isinstance(times, pl_Series) and times.dtype == pl.Utf8: 64 | try: 65 | times = times.str.to_datetime() 66 | except pl.exceptions.ComputeError: 67 | raise ValueError(parse_err_msg) 68 | df = df.with_columns(times) 69 | else: 70 | raise ValueError(f"'{time_col}' should have valid timestamps or integers.") 71 | return df 72 | 73 | 74 | def validate_format( 75 | df: DataFrame, 76 | id_col: str = "unique_id", 77 | time_col: str = "ds", 78 | target_col: Optional[str] = "y", 79 | ) -> None: 80 | """Ensure DataFrame has expected format. 81 | 82 | Args: 83 | df (pandas or polars DataFrame): DataFrame with time series in long format. 84 | id_col (str, optional): Column that identifies each serie. Defaults to 'unique_id'. 85 | time_col (str, optional): Column that identifies each timestamp. Defaults to 'ds'. 86 | target_col (str, optional): Column that contains the target. Defaults to 'y'. 87 | 88 | Returns: 89 | None 90 | """ 91 | if not isinstance(df, (pd.DataFrame, pl_DataFrame)): 92 | raise ValueError( 93 | f"`df` must be either pandas or polars dataframe, got {type(df)}" 94 | ) 95 | 96 | # required columns 97 | expected_cols = {id_col, time_col} 98 | if target_col is not None: 99 | expected_cols.add(target_col) 100 | missing_cols = sorted(expected_cols - set(df.columns)) 101 | if missing_cols: 102 | raise ValueError(f"The following columns are missing: {missing_cols}") 103 | 104 | # time col 105 | if not _is_dt_or_int(df[time_col]): 106 | times_dtype = df[time_col].dtype 107 | raise ValueError( 108 | f"The time column ('{time_col}') should have either timestamps or integers, got '{times_dtype}'." 109 | ) 110 | 111 | # target col 112 | if target_col is None: 113 | return None 114 | target = df[target_col] 115 | if isinstance(target, pd.Series): 116 | is_numeric = pd.api.types.is_numeric_dtype(target.dtype) 117 | else: 118 | try: 119 | is_numeric = target.dtype.is_numeric() 120 | except AttributeError: 121 | is_numeric = target.is_numeric() 122 | if not is_numeric: 123 | raise ValueError( 124 | f"The target column ('{target_col}') should have a numeric data type, got '{target.dtype}')" 125 | ) 126 | 127 | 128 | def validate_freq( 129 | times: Series, 130 | freq: Union[str, int], 131 | ) -> None: 132 | if _is_int_dtype(times) and not isinstance(freq, int): 133 | raise ValueError( 134 | "Time column contains integers but the specified frequency is not an integer. " 135 | "Please provide a valid integer, e.g. `freq=1`" 136 | ) 137 | if _is_dt_dtype(times) and isinstance(freq, int): 138 | raise ValueError( 139 | "Time column contains timestamps but the specified frequency is an integer. " 140 | "Please provide a valid pandas or polars offset, e.g. `freq='D'` or `freq='1d'`." 141 | ) 142 | # try to catch pandas frequency in polars dataframe 143 | if isinstance(times, pl_Series) and isinstance(freq, str): 144 | missing_n = re.search(r"\d+", freq) is None 145 | uppercase = re.sub(r"\d+", "", freq).isupper() 146 | if missing_n or uppercase: 147 | raise ValueError( 148 | "You must specify a valid polars offset when using polars dataframes. 
" 149 | "You can find the available offsets in " 150 | "https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.offset_by.html" 151 | ) 152 | -------------------------------------------------------------------------------- /tests/test_grouped_array.py: -------------------------------------------------------------------------------- 1 | # test _append_one 2 | import numpy as np 3 | from conftest import assert_raises_with_message 4 | 5 | from utilsforecast.data import generate_series 6 | from utilsforecast.grouped_array import GroupedArray, _append_one, _append_several 7 | 8 | 9 | def test_append_one(): 10 | data = np.arange(5) 11 | indptr = np.array([0, 2, 5]) 12 | new = np.array([7, 8]) 13 | new_data, new_indptr = _append_one(data, indptr, new) 14 | np.testing.assert_equal(new_data, np.array([0, 1, 7, 2, 3, 4, 8])) 15 | np.testing.assert_equal( 16 | new_indptr, 17 | np.array([0, 3, 7]), 18 | ) 19 | 20 | # 2d 21 | def test_append_one_2d(): 22 | data = np.arange(5).reshape(-1, 1) 23 | indptr = np.array([0, 2, 5]) 24 | new = np.array([7, 8]) 25 | new_data, new_indptr = _append_one(data, indptr, new) 26 | np.testing.assert_equal(new_data, np.array([0, 1, 7, 2, 3, 4, 8]).reshape(-1, 1)) 27 | np.testing.assert_equal( 28 | new_indptr, 29 | np.array([0, 3, 7]), 30 | ) 31 | 32 | # test append several 33 | def test_append_several(): 34 | data = np.arange(5) 35 | indptr = np.array([0, 2, 5]) 36 | new_sizes = np.array([0, 2, 1]) 37 | new_values = np.array([6, 7, 5]) 38 | new_groups = np.array([False, True, False]) 39 | new_data, new_indptr = _append_several(data, indptr, new_sizes, new_values, new_groups) 40 | np.testing.assert_equal(new_data, np.array([0, 1, 6, 7, 2, 3, 4, 5])) 41 | np.testing.assert_equal( 42 | new_indptr, 43 | np.array([0, 2, 4, 8]), 44 | ) 45 | 46 | # 2d 47 | def test_append_several_2d(): 48 | data = np.arange(5).reshape(-1, 1) 49 | indptr = np.array([0, 2, 5]) 50 | new_sizes = np.array([0, 2, 1]) 51 | new_values = np.array([6, 7, 5]).reshape(-1, 1) 52 | new_groups = np.array([False, True, False]) 53 | new_data, new_indptr = _append_several(data, indptr, new_sizes, new_values, new_groups) 54 | np.testing.assert_equal(new_data, np.array([0, 1, 6, 7, 2, 3, 4, 5]).reshape(-1, 1)) 55 | np.testing.assert_equal( 56 | new_indptr, 57 | np.array([0, 2, 4, 8]), 58 | ) 59 | 60 | 61 | # The `GroupedArray` is used internally for storing the series values and performing transformations. 
62 | def test_grouped_array(): 63 | data = np.arange(20, dtype=np.float32).reshape(-1, 2) 64 | indptr = np.array([0, 2, 10]) # group 1: [0, 1], group 2: [2..9] 65 | ga = GroupedArray(data, indptr) 66 | assert len(ga) == 2 67 | # Iterate through the groups 68 | ga_iter = iter(ga) 69 | np.testing.assert_equal(next(ga_iter), np.arange(4).reshape(-1, 2)) 70 | np.testing.assert_equal(next(ga_iter), np.arange(4, 20).reshape(-1, 2)) 71 | # Take the last two observations from each group 72 | last2_data, last2_indptr = ga.take_from_groups(slice(-2, None)) 73 | np.testing.assert_equal( 74 | last2_data, 75 | np.vstack( 76 | [ 77 | np.arange(4).reshape(-1, 2), 78 | np.arange(16, 20).reshape(-1, 2), 79 | ] 80 | ), 81 | ) 82 | np.testing.assert_equal(last2_indptr, np.array([0, 2, 4])) 83 | 84 | # 1d 85 | def test_grouped_array_1d(): 86 | data = np.arange(20, dtype=np.float32).reshape(-1, 2) 87 | indptr = np.array([0, 2, 10]) 88 | ga = GroupedArray(data, indptr) 89 | ga1d = GroupedArray(np.arange(10), indptr) 90 | last2_data1d, last2_indptr1d = ga1d.take_from_groups(slice(-2, None)) 91 | np.testing.assert_equal(last2_data1d, np.array([0, 1, 8, 9])) 92 | np.testing.assert_equal(last2_indptr1d, np.array([0, 2, 4])) 93 | # Take the second observation from each group 94 | second_data, second_indptr = ga.take_from_groups(1) 95 | np.testing.assert_equal(second_data, np.array([[2, 3], [6, 7]])) 96 | np.testing.assert_equal(second_indptr, np.array([0, 1, 2])) 97 | 98 | # 1d 99 | second_data1d, second_indptr1d = ga1d.take_from_groups(1) 100 | np.testing.assert_equal(second_data1d, np.array([1, 3])) 101 | np.testing.assert_equal(second_indptr1d, np.array([0, 1, 2])) 102 | # Take the last four observations from every group. Note that since group 1 only has two elements, only these are returned. 
103 | last4_data, last4_indptr = ga.take_from_groups(slice(-4, None)) 104 | np.testing.assert_equal( 105 | last4_data, 106 | np.vstack( 107 | [ 108 | np.arange(4).reshape(-1, 2), 109 | np.arange(12, 20).reshape(-1, 2), 110 | ] 111 | ), 112 | ) 113 | np.testing.assert_equal(last4_indptr, np.array([0, 2, 6])) 114 | 115 | # 1d 116 | last4_data1d, last4_indptr1d = ga1d.take_from_groups(slice(-4, None)) 117 | np.testing.assert_equal(last4_data1d, np.array([0, 1, 6, 7, 8, 9])) 118 | np.testing.assert_equal(last4_indptr1d, np.array([0, 2, 6])) 119 | # Select a specific subset of groups 120 | indptr = np.array([0, 2, 4, 7, 10]) 121 | ga2 = GroupedArray(data, indptr) 122 | subset = GroupedArray(*ga2.take([0, 2])) 123 | np.testing.assert_allclose(subset[0].data, ga2[0].data) 124 | np.testing.assert_allclose(subset[1].data, ga2[2].data) 125 | 126 | # 1d 127 | ga2_1d = GroupedArray(np.arange(10), indptr) 128 | subset1d = GroupedArray(*ga2_1d.take([0, 2])) 129 | np.testing.assert_allclose(subset1d[0].data, ga2_1d[0].data) 130 | np.testing.assert_allclose(subset1d[1].data, ga2_1d[2].data) 131 | # try to append new values that don't match the number of groups 132 | assert_raises_with_message(lambda: ga.append(np.array([1.0, 2.0, 3.0])), "new must have 2 rows") 133 | # build from df 134 | series_pd = generate_series(10, static_as_categorical=False, engine="pandas") 135 | ga_pd = GroupedArray.from_sorted_df(series_pd, "unique_id", "ds", "y") 136 | series_pl = generate_series(10, static_as_categorical=False, engine="polars") 137 | ga_pl = GroupedArray.from_sorted_df(series_pl, "unique_id", "ds", "y") 138 | np.testing.assert_allclose(ga_pd.data, ga_pl.data) 139 | np.testing.assert_equal(ga_pd.indptr, ga_pl.indptr) 140 | -------------------------------------------------------------------------------- /tests/test_feature_engineering.py: -------------------------------------------------------------------------------- 1 | from functools import partial, reduce 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import polars as pl 6 | import polars.testing 7 | import pytest 8 | 9 | from utilsforecast.data import generate_series 10 | from utilsforecast.feature_engineering import ( 11 | fourier, 12 | future_exog_to_historic, 13 | pipeline, 14 | time_features, 15 | trend, 16 | ) 17 | 18 | 19 | @pytest.fixture 20 | def setup_series(): 21 | series = generate_series(5, equal_ends=True) 22 | series_pl = generate_series(5, equal_ends=True, engine="polars") 23 | return series, series_pl 24 | 25 | 26 | def test_fourier_transform(setup_series): 27 | series, series_pl = setup_series 28 | 29 | transformed_df, future_df = fourier(series, freq="D", season_length=7, k=2, h=1) 30 | transformed_df2, future_df2 = fourier( 31 | series.sample(frac=1.0), freq="D", season_length=7, k=2, h=1 32 | ) 33 | pd.testing.assert_frame_equal( 34 | transformed_df, 35 | transformed_df2.sort_values(["unique_id", "ds"]).reset_index(drop=True), 36 | ) 37 | pd.testing.assert_frame_equal(future_df, future_df2) 38 | 39 | transformed_pl, future_pl = fourier(series_pl, freq="1d", season_length=7, k=2, h=1) 40 | transformed_pl2, future_pl2 = fourier( 41 | series_pl.sample(fraction=1.0), freq="1d", season_length=7, k=2, h=1 42 | ) 43 | pl.testing.assert_frame_equal(transformed_pl, transformed_pl2) 44 | pl.testing.assert_frame_equal(future_pl, future_pl2) 45 | pd.testing.assert_frame_equal( 46 | transformed_df.drop(columns=["unique_id"]), 47 | transformed_pl.drop("unique_id").to_pandas(), 48 | ) 49 | pd.testing.assert_frame_equal( 50 | 
future_df.drop(columns=["unique_id"]), future_pl.drop("unique_id").to_pandas() 51 | ) 52 | series = generate_series(5, equal_ends=True) 53 | transformed_df, future_df = trend(series, freq="D", h=1) 54 | transformed_df 55 | future_df 56 | transformed_df, future_df = time_features( 57 | series, freq="D", features=["month", "day", "week"], h=1 58 | ) 59 | series_with_prices = series.assign(price=np.random.rand(len(series))).sample( 60 | frac=1.0 61 | ) 62 | series_with_prices 63 | transformed_df, future_df = future_exog_to_historic( 64 | df=series_with_prices, 65 | freq="D", 66 | features=["price"], 67 | h=2, 68 | ) 69 | pd.testing.assert_frame_equal( 70 | ( 71 | series_with_prices.sort_values(["unique_id", "ds"]) 72 | .groupby("unique_id", observed=True) 73 | .tail(2)[["unique_id", "price"]] 74 | .reset_index(drop=True) 75 | ), 76 | future_df[["unique_id", "price"]], 77 | ) 78 | series_with_prices_pl = pl.from_pandas( 79 | series_with_prices.astype({"unique_id": "int64"}) 80 | ) 81 | transformed_pl, future_pl = future_exog_to_historic( 82 | df=series_with_prices_pl, 83 | freq="1d", 84 | features=["price"], 85 | h=2, 86 | ) 87 | pd.testing.assert_frame_equal( 88 | future_pl.to_pandas(), future_df.astype({"unique_id": "int64"}) 89 | ) 90 | 91 | 92 | def is_weekend(times): 93 | if isinstance(times, pd.Index): 94 | dow = times.weekday + 1 # monday=0 in pandas and 1 in polars 95 | else: 96 | dow = times.dt.weekday() 97 | return dow >= 6 98 | 99 | 100 | def even_days_and_months(times): 101 | if isinstance(times, pd.Index): 102 | out = pd.DataFrame( 103 | { 104 | "even_day": (times.weekday + 1) % 2 == 0, 105 | "even_month": times.month % 2 == 0, 106 | } 107 | ) 108 | else: 109 | # for polars you can return a list of expressions 110 | out = [ 111 | (times.dt.weekday() % 2 == 0).alias("even_day"), 112 | (times.dt.month() % 2 == 0).alias("even_month"), 113 | ] 114 | return out 115 | 116 | 117 | @pytest.fixture 118 | def setup_features(): 119 | features = [ 120 | trend, 121 | partial(fourier, season_length=7, k=1), 122 | partial(fourier, season_length=28, k=1), 123 | partial(time_features, features=["day", is_weekend, even_days_and_months]), 124 | ] 125 | return features 126 | 127 | 128 | @pytest.fixture 129 | def setup_pipeline(setup_series, setup_features): 130 | def _inner(freq): 131 | series, _ = setup_series 132 | transformed_df, future_df = pipeline( 133 | series, 134 | features=setup_features, 135 | freq=freq, 136 | h=1, 137 | ) 138 | return transformed_df, future_df 139 | 140 | return _inner 141 | 142 | 143 | def reduce_join(dfs, on): 144 | return reduce( 145 | lambda left, right: left.merge(right, on=on, how="left"), 146 | dfs, 147 | ) 148 | 149 | 150 | @pytest.mark.parametrize("freq", ["D", "1d"]) 151 | def test_pipeline(setup_series, freq, setup_pipeline, setup_features): 152 | series, series_pl = setup_series 153 | transformed_df, future_df = setup_pipeline(freq) 154 | 155 | individual_results = [f(series, freq=freq, h=1) for f in setup_features] 156 | expected_transformed = reduce_join( 157 | [r[0] for r in individual_results], on=["unique_id", "ds", "y"] 158 | ) 159 | expected_future = reduce_join( 160 | [r[1] for r in individual_results], on=["unique_id", "ds"] 161 | ) 162 | 163 | pd.testing.assert_frame_equal(transformed_df, expected_transformed) 164 | pd.testing.assert_frame_equal(future_df, expected_future) 165 | 166 | transformed_pl, future_pl = pipeline( 167 | series_pl, 168 | features=setup_features, 169 | freq="1d", 170 | h=1, 171 | ) 172 | pd.testing.assert_frame_equal( 173 | 
transformed_pl.drop("unique_id").to_pandas(), 174 | transformed_df.drop(columns="unique_id"), 175 | check_dtype=False, 176 | ) 177 | pd.testing.assert_frame_equal( 178 | future_pl.drop("unique_id").to_pandas(), 179 | future_df.drop(columns="unique_id"), 180 | check_dtype=False, 181 | ) 182 | -------------------------------------------------------------------------------- /utilsforecast/data.py: -------------------------------------------------------------------------------- 1 | """Utilies for generating time series datasets""" 2 | 3 | __all__ = ['generate_series'] 4 | 5 | 6 | from typing import List, Literal, Optional, overload 7 | 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from .compat import DataFrame, pl, pl_DataFrame 12 | 13 | 14 | @overload 15 | def generate_series( 16 | n_series: int, 17 | freq: str = "D", 18 | min_length: int = 50, 19 | max_length: int = 500, 20 | n_static_features: int = 0, 21 | equal_ends: bool = False, 22 | with_trend: bool = False, 23 | static_as_categorical: bool = True, 24 | n_models: int = 0, 25 | level: Optional[List[float]] = None, 26 | engine: Literal["pandas"] = "pandas", 27 | ) -> pd.DataFrame: ... 28 | 29 | 30 | @overload 31 | def generate_series( 32 | n_series: int, 33 | freq: str = "D", 34 | min_length: int = 50, 35 | max_length: int = 500, 36 | n_static_features: int = 0, 37 | equal_ends: bool = False, 38 | with_trend: bool = False, 39 | static_as_categorical: bool = True, 40 | n_models: int = 0, 41 | level: Optional[List[float]] = None, 42 | engine: Literal["polars"] = "polars", 43 | ) -> pl_DataFrame: ... 44 | 45 | 46 | def generate_series( 47 | n_series: int, 48 | freq: str = "D", 49 | min_length: int = 50, 50 | max_length: int = 500, 51 | n_static_features: int = 0, 52 | equal_ends: bool = False, 53 | with_trend: bool = False, 54 | static_as_categorical: bool = True, 55 | n_models: int = 0, 56 | level: Optional[List[float]] = None, 57 | engine: Literal["pandas", "polars"] = "pandas", 58 | seed: int = 0, 59 | ) -> DataFrame: 60 | """Generate Synthetic Panel Series. 61 | 62 | Args: 63 | n_series (int): Number of series for synthetic panel. 64 | freq (str, optional): Frequency of the data (pandas alias). 65 | Seasonalities are implemented for hourly, daily and monthly. 66 | Defaults to 'D'. 67 | min_length (int, optional): Minimum length of synthetic panel's series. 68 | Defaults to 50. 69 | max_length (int, optional): Maximum length of synthetic panel's series. 70 | Defaults to 500. 71 | n_static_features (int, optional): Number of static exogenous variables 72 | for synthetic panel's series. Defaults to 0. 73 | equal_ends (bool, optional): Series should end in the same timestamp. 74 | Defaults to False. 75 | with_trend (bool, optional): Series should have a (positive) trend. 76 | Defaults to False. 77 | static_as_categorical (bool, optional): Static features should have a 78 | categorical data type. Defaults to True. 79 | n_models (int, optional): Number of models predictions to simulate. 80 | Defaults to 0. 81 | level (list of float, optional): Confidence level for intervals to 82 | simulate for each model. Defaults to None. 83 | engine (str, optional): Output Dataframe type. Defaults to 'pandas'. 84 | seed (int, optional): Random seed used for generating the data. 85 | Defaults to 0. 86 | 87 | Returns: 88 | pandas or polars DataFrame: Synthetic panel with columns [`unique_id`, 89 | `ds`, `y`] and exogenous features. 
90 | """ 91 | available_engines = ["pandas", "polars"] 92 | engine = engine.lower() # type: ignore 93 | if engine not in available_engines: 94 | raise ValueError( 95 | f"{engine} is not a correct engine; available options: {available_engines}" 96 | ) 97 | seasonalities = { 98 | pd.offsets.Hour(): 24, 99 | pd.offsets.Day(): 7, 100 | pd.offsets.MonthBegin(): 12, 101 | pd.offsets.MonthEnd(): 12, 102 | } 103 | freq = pd.tseries.frequencies.to_offset(freq) 104 | season = seasonalities.get(freq, 1) 105 | 106 | rng = np.random.RandomState(seed) 107 | series_lengths = rng.randint(min_length, max_length + 1, n_series) 108 | total_length = series_lengths.sum() 109 | 110 | vals_dict = {"unique_id": np.repeat(np.arange(n_series), series_lengths)} 111 | 112 | dates = pd.date_range("2000-01-01", periods=max_length, freq=freq).values 113 | if equal_ends: 114 | series_dates = [dates[-length:] for length in series_lengths] 115 | else: 116 | series_dates = [dates[:length] for length in series_lengths] 117 | vals_dict["ds"] = np.concatenate(series_dates) 118 | 119 | vals_dict["y"] = np.arange(total_length) % season + rng.rand(total_length) * 0.5 120 | 121 | for i in range(n_static_features): 122 | static_values = np.repeat(rng.randint(0, 100, n_series), series_lengths) 123 | vals_dict[f"static_{i}"] = static_values 124 | if i == 0: 125 | vals_dict["y"] = vals_dict["y"] * (1 + vals_dict[f"static_{i}"]) 126 | 127 | if with_trend: 128 | coefs = np.repeat(rng.rand(n_series), series_lengths) 129 | trends = np.concatenate([np.arange(length) for length in series_lengths]) 130 | vals_dict["y"] += coefs * trends 131 | 132 | for i in range(n_models): 133 | rands = rng.rand(total_length) 134 | vals_dict[f"model{i}"] = vals_dict["y"] * (0.2 * rands + 0.9) 135 | level = level or [] 136 | for lv in level: 137 | lv_rands = 0.5 * rands * lv / 100 138 | vals_dict[f"model{i}-lo-{lv}"] = vals_dict[f"model{i}"] * (1 - lv_rands) 139 | vals_dict[f"model{i}-hi-{lv}"] = vals_dict[f"model{i}"] * (1 + lv_rands) 140 | 141 | cat_cols = [col for col in vals_dict.keys() if "static" in col] 142 | cat_cols.append("unique_id") 143 | if engine == "pandas": 144 | df = pd.DataFrame(vals_dict) 145 | if static_as_categorical: 146 | df[cat_cols] = df[cat_cols].astype("category") 147 | df["unique_id"] = df["unique_id"].cat.as_ordered() 148 | else: 149 | df = pl.DataFrame(vals_dict) 150 | df = df.with_columns(pl.col("unique_id").sort()) 151 | if static_as_categorical: 152 | df = df.with_columns( 153 | *[pl.col(col).cast(str).cast(pl.Categorical) for col in cat_cols] 154 | ) 155 | return df 156 | -------------------------------------------------------------------------------- /docs/convert_to_mkdocstrings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import re 4 | from typing import Any, Dict, Optional 5 | 6 | import griffe 7 | import yaml 8 | from griffe2md import ConfigDict, render_object_docs 9 | 10 | logging.getLogger("griffe").setLevel(logging.ERROR) 11 | 12 | 13 | class MkDocstringsParser: 14 | def __init__(self): 15 | pass 16 | 17 | def parse_docstring_block( 18 | self, block_content: str 19 | ) -> tuple[str, str, Dict[str, Any]]: 20 | """Parse a ::: block to extract module path, handler, and options""" 21 | lines = block_content.strip().split("\n") 22 | 23 | # First line contains the module path 24 | module_path = lines[0].replace(":::", "").strip() 25 | 26 | # Parse YAML configuration 27 | yaml_content = "\n".join(lines[1:]) if len(lines) > 1 else "" 28 | 29 | 
try: 30 | config = yaml.safe_load(yaml_content) or {} 31 | except yaml.YAMLError: 32 | config = {} 33 | 34 | handler_type = config.get("handler", "python") 35 | options = config.get("options", {}) 36 | 37 | return module_path, handler_type, options 38 | 39 | def generate_documentation(self, module_path: str, options: Dict[str, Any]) -> str: 40 | """Generate documentation for a given module using griffe and griffe2md""" 41 | try: 42 | if "." in module_path: 43 | parts = module_path.split(".") 44 | package_name = parts[0] 45 | object_path = ".".join(parts[1:]) 46 | else: 47 | package_name = module_path 48 | object_path = "" 49 | 50 | package = griffe.load(package_name) 51 | to_replace = ".".join((package_name + "." + object_path).split(".")[:-1]) 52 | if object_path: 53 | obj = package[object_path] 54 | else: 55 | obj = package 56 | 57 | # Ensure the docstring is properly parsed with Google parser 58 | # For functions, we might need to get the actual runtime docstring 59 | if hasattr(obj, "kind") and obj.kind.value == "function": 60 | try: 61 | # Try to get the actual function object to access runtime docstring 62 | import importlib 63 | 64 | module_parts = module_path.split(".") 65 | module_name = ".".join(module_parts[:-1]) 66 | func_name = module_parts[-1] 67 | 68 | actual_module = importlib.import_module(module_name) 69 | actual_func = getattr(actual_module, func_name) 70 | 71 | # If the actual function has a docstring but griffe obj doesn't, use it 72 | if actual_func.__doc__ and ( 73 | not obj.docstring or not obj.docstring.value 74 | ): 75 | from griffe import Docstring 76 | 77 | obj.docstring = Docstring(actual_func.__doc__, lineno=1) 78 | except: 79 | pass # Fall back to griffe's detection 80 | 81 | if obj.docstring: 82 | # Force parsing with Google parser to get structured sections 83 | obj.docstring.parsed = griffe.parse_google(obj.docstring) 84 | 85 | if hasattr(obj, "members"): 86 | for member_name, member in obj.members.items(): 87 | if member.docstring: 88 | member.docstring.parsed = griffe.parse_google(member.docstring) 89 | 90 | # Create ConfigDict with the options 91 | # Adjust default options based on object type 92 | if hasattr(obj, "kind") and obj.kind.value == "function": 93 | # Configuration for functions 94 | default_options = { 95 | "docstring_section_style": "table", 96 | "heading_level": 3, 97 | "show_root_heading": True, 98 | "show_source": True, 99 | "show_signature": True, 100 | } 101 | else: 102 | # Configuration for classes and modules 103 | default_options = { 104 | "docstring_section_style": "table", 105 | "heading_level": 3, 106 | "show_root_heading": True, 107 | "show_source": True, 108 | "summary": {"functions": False}, 109 | } 110 | 111 | default_options.update(options) 112 | config = ConfigDict(**default_options) 113 | 114 | # Generate the documentation using griffe2md 115 | # Type ignore since griffe2md can handle various object types 116 | markdown_docs = render_object_docs(obj, config) # type: ignore 117 | 118 | markdown_docs = markdown_docs.replace(f"### `{to_replace}.", "### `") 119 | 120 | return markdown_docs 121 | 122 | except Exception as e: 123 | return f"" 124 | 125 | def process_markdown(self, content: str) -> str: 126 | """Process markdown content, replacing ::: blocks with generated documentation""" 127 | 128 | # Pattern to match ::: blocks (including multi-line YAML config) 129 | pattern = r":::\s*([^\n]+)(?:\n((?:\s{4}.*\n?)*))?" 
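        # Group 1 captures the module path on the ':::' line; group 2 optionally captures
        # the 4-space-indented lines that follow, which hold the handler/options YAML.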
130 | 131 | def replace_block(match): 132 | module_line = match.group(1).strip() 133 | yaml_block = match.group(2) or "" 134 | 135 | # Reconstruct the full block 136 | full_block = f":::{module_line}\n{yaml_block}".rstrip() 137 | 138 | try: 139 | module_path, handler_type, options = self.parse_docstring_block( 140 | full_block 141 | ) 142 | generated_docs = self.generate_documentation(module_path, options) 143 | return generated_docs 144 | except Exception as e: 145 | return f"" 146 | 147 | return re.sub(pattern, replace_block, content, flags=re.MULTILINE) 148 | 149 | def process_file(self, input_file: str, output_file: Optional[str] = None) -> str: 150 | """Process a markdown file and return the result""" 151 | with open(input_file, "r", encoding="utf-8") as f: 152 | content = f.read() 153 | 154 | processed_content = self.process_markdown(content) 155 | 156 | if output_file: 157 | with open(output_file, "w", encoding="utf-8") as f: 158 | f.write(processed_content) 159 | 160 | return processed_content 161 | 162 | def get_args(self): 163 | parser = argparse.ArgumentParser( 164 | description="Convert ::: blocks to mkdocstrings" 165 | ) 166 | parser.add_argument("input_file", type=str, help="Input markdown file") 167 | parser.add_argument("output_file", type=str, help="Output markdown file") 168 | return parser.parse_args() 169 | 170 | # Usage example 171 | if __name__ == "__main__": 172 | parser = MkDocstringsParser() 173 | 174 | test_class = """::: coreforecast.lag_transforms.Lag 175 | handler: python 176 | options: 177 | docstring_style: google 178 | members: 179 | - stack 180 | - take 181 | - transform 182 | - update 183 | heading_level: 3 184 | show_root_heading: true 185 | show_source: true 186 | """ 187 | 188 | test_function = """::: coreforecast.differences.diff 189 | handler: python 190 | options: 191 | docstring_style: google 192 | heading_level: 3 193 | show_root_heading: true 194 | show_source: true 195 | show_signature: true 196 | """ 197 | 198 | print("Class documentation:") 199 | print(parser.process_markdown(test_class)) 200 | print("\n" + "=" * 50 + "\n") 201 | print("Function documentation:") 202 | fn = parser.process_markdown(test_function) 203 | print(fn) 204 | -------------------------------------------------------------------------------- /utilsforecast/preprocessing.py: -------------------------------------------------------------------------------- 1 | """Utilities for processing data before training/analysis""" 2 | 3 | __all__ = ['id_time_grid', 'fill_gaps'] 4 | 5 | 6 | import warnings 7 | from datetime import date, datetime 8 | from functools import partial 9 | from typing import Union 10 | 11 | import numpy as np 12 | import pandas as pd 13 | 14 | from .compat import DFType, pl, pl_DataFrame, pl_Series 15 | from .processing import group_by, repeat 16 | from .validation import _is_int_dtype, validate_format, validate_freq 17 | 18 | 19 | def _determine_bound(bound, freq, times_by_id, agg) -> np.ndarray: 20 | if bound == "per_serie": 21 | out = times_by_id[agg].to_numpy() 22 | else: 23 | # the following return a scalar 24 | if bound == "global": 25 | val = getattr(times_by_id[agg].to_numpy(), agg)() 26 | if isinstance(freq, str): 27 | val = np.datetime64(val) 28 | else: 29 | if isinstance(freq, str): 30 | # this raises a nice error message if it isn't a valid datetime 31 | if isinstance(bound, pd.Timestamp) and bound.tz is not None: 32 | bound = bound.tz_convert("UTC").tz_localize(None) 33 | val = np.datetime64(bound) 34 | else: 35 | val = bound 36 | out = 
np.full(times_by_id.shape[0], val) 37 | if isinstance(freq, str): 38 | out = out.astype(f"datetime64[{freq}]") 39 | return out 40 | 41 | 42 | def _determine_bound_pl( 43 | bound: Union[str, int, date, datetime], 44 | times_by_id: pl_DataFrame, 45 | agg: str, 46 | ) -> pl_Series: 47 | if bound == "per_serie": 48 | out = times_by_id[agg] 49 | else: 50 | if bound == "global": 51 | val = getattr(times_by_id[agg], agg)() 52 | else: 53 | val = bound 54 | out = repeat(pl_Series([val]), times_by_id.shape[0]) 55 | return out 56 | 57 | 58 | def id_time_grid( 59 | df: DFType, 60 | freq: Union[str, int], 61 | start: Union[str, int, date, datetime] = "per_serie", 62 | end: Union[str, int, date, datetime] = "global", 63 | id_col: str = "unique_id", 64 | time_col: str = "ds", 65 | ) -> DFType: 66 | """Generate all expected combiations of ids and times. 67 | 68 | Args: 69 | df (pandas or polars DataFrame): Input data 70 | freq (str or int): Series' frequency 71 | start (str, int, date or datetime, optional): Initial timestamp for the series. 72 | * 'per_serie' uses each serie's first timestamp 73 | * 'global' uses the first timestamp seen in the data 74 | * Can also be a specific timestamp or integer, e.g. '2000-01-01', 2000 or datetime(2000, 1, 1) 75 | Defaults to "per_serie". 76 | end (str, int, date or datetime, optional): Initial timestamp for the series. 77 | * 'per_serie' uses each serie's last timestamp 78 | * 'global' uses the last timestamp seen in the data 79 | * Can also be a specific timestamp or integer, e.g. '2000-01-01', 2000 or datetime(2000, 1, 1) 80 | Defaults to "global". 81 | id_col (str, optional): Column that identifies each serie. Defaults to 'unique_id'. 82 | time_col (str, optional): Column that identifies each timestamp. Defaults to 'ds'. 83 | 84 | Returns: 85 | pandas or polars DataFrame: Dataframe with expected ids and times. 
86 | """ 87 | if isinstance(df, pl_DataFrame): 88 | times_by_id = ( 89 | group_by(df, id_col) 90 | .agg( 91 | pl.col(time_col).min().alias("min"), 92 | pl.col(time_col).max().alias("max"), 93 | ) 94 | .sort(id_col) 95 | ) 96 | starts = _determine_bound_pl(start, times_by_id, "min") 97 | ends = _determine_bound_pl(end, times_by_id, "max") 98 | grid = pl_DataFrame({id_col: times_by_id[id_col]}) 99 | if _is_int_dtype(starts): 100 | grid = grid.with_columns( 101 | pl.int_ranges(starts, ends + freq, step=freq, eager=True).alias( 102 | time_col 103 | ) 104 | ) 105 | else: 106 | if starts.dtype == pl.Date: 107 | ranges_fn = pl.date_ranges 108 | else: 109 | ranges_fn = partial( 110 | pl.datetime_ranges, 111 | time_unit=df.schema[time_col].time_unit, 112 | ) 113 | grid = grid.with_columns( 114 | ranges_fn( 115 | starts, 116 | ends, 117 | interval=freq, 118 | eager=True, 119 | ).alias(time_col) 120 | ) 121 | return grid.explode(time_col) 122 | if isinstance(freq, str): 123 | offset = pd.tseries.frequencies.to_offset(freq) 124 | n = offset.n 125 | if isinstance(offset.base, pd.offsets.Minute): 126 | # minutes are represented as 'm' in numpy 127 | freq = "m" 128 | elif isinstance(offset.base, pd.offsets.BusinessDay): 129 | if n != 1: 130 | raise NotImplementedError("Multiple of a business day") 131 | freq = "D" 132 | elif isinstance(offset.base, pd.offsets.Hour): 133 | # hours are represented as 'h' in numpy 134 | freq = "h" 135 | elif isinstance(offset.base, (pd.offsets.QuarterBegin, pd.offsets.QuarterEnd)): 136 | n = 3 137 | freq = "M" 138 | elif isinstance(offset.base, (pd.offsets.YearBegin, pd.offsets.YearEnd)): 139 | freq = "Y" 140 | elif isinstance(offset.base, pd.offsets.Second): 141 | freq = "s" 142 | elif isinstance(offset.base, pd.offsets.Milli): 143 | freq = "ms" 144 | elif isinstance(offset.base, pd.offsets.Micro): 145 | freq = "us" 146 | elif isinstance(offset.base, pd.offsets.Nano): 147 | freq = "ns" 148 | elif isinstance(offset.base, (pd.offsets.MonthBegin, pd.offsets.MonthEnd)): 149 | freq = "M" 150 | elif isinstance(offset.base, pd.offsets.Week): 151 | freq = "W" 152 | if n > 1: 153 | freq = freq.replace(str(n), "") 154 | try: 155 | pd.Timedelta(offset) 156 | except ValueError: 157 | # irregular freq, try using first letter of abbreviation 158 | # such as MS = 'Month Start' -> 'M', YS = 'Year Start' -> 'Y' 159 | freq = freq[0] 160 | delta: Union[np.timedelta64, int] = np.timedelta64(n, freq) 161 | if df[time_col].dt.tz is not None: 162 | df = df.copy(deep=False) 163 | df[time_col] = df[time_col].dt.tz_convert("UTC").dt.tz_localize(None) 164 | else: 165 | delta = freq 166 | times_by_id = df.groupby(id_col, observed=True)[time_col].agg(["min", "max"]) 167 | starts = _determine_bound(start, freq, times_by_id, "min") 168 | ends = _determine_bound(end, freq, times_by_id, "max") + delta 169 | sizes = ((ends - starts) / delta).astype(np.int64) 170 | times = np.hstack( 171 | [np.arange(start, end, delta) for start, end in zip(starts, ends)] 172 | ) 173 | uids = np.repeat(times_by_id.index, sizes) 174 | if isinstance(freq, str): 175 | if isinstance(offset.base, pd.offsets.BusinessDay): 176 | # data was generated daily, we need to keep only business days 177 | bdays = np.is_busday(times) 178 | uids = uids[bdays] 179 | times = times[bdays] 180 | times = pd.Index(times.astype("datetime64[ns]", copy=False)) 181 | first_time = np.datetime64(df.iloc[0][time_col]) 182 | was_truncated = first_time != first_time.astype(f"datetime64[{freq}]") 183 | if was_truncated: 184 | times += offset.base 185 | 
return pd.DataFrame( 186 | { 187 | id_col: uids, 188 | time_col: times, 189 | } 190 | ) 191 | 192 | 193 | def fill_gaps( 194 | df: DFType, 195 | freq: Union[str, int], 196 | start: Union[str, int, date, datetime] = "per_serie", 197 | end: Union[str, int, date, datetime] = "global", 198 | id_col: str = "unique_id", 199 | time_col: str = "ds", 200 | ) -> DFType: 201 | """Enforce start and end datetimes for dataframe. 202 | 203 | Args: 204 | df (pandas or polars DataFrame): Input data 205 | freq (str or int): Series' frequency 206 | start (str, int, date or datetime, optional): Initial timestamp for the series. 207 | * 'per_serie' uses each serie's first timestamp 208 | * 'global' uses the first timestamp seen in the data 209 | * Can also be a specific timestamp or integer, e.g. '2000-01-01', 2000 or datetime(2000, 1, 1) 210 | Defaults to "per_serie". 211 | end (str, int, date or datetime, optional): Initial timestamp for the series. 212 | * 'per_serie' uses each serie's last timestamp 213 | * 'global' uses the last timestamp seen in the data 214 | * Can also be a specific timestamp or integer, e.g. '2000-01-01', 2000 or datetime(2000, 1, 1) 215 | Defaults to "global". 216 | id_col (str, optional): Column that identifies each serie. Defaults to 'unique_id'. 217 | time_col (str, optional): Column that identifies each timestamp. Defaults to 'ds'. 218 | 219 | Returns: 220 | pandas or polars DataFrame: Dataframe with gaps filled. 221 | """ 222 | validate_format(df, id_col=id_col, time_col=time_col, target_col=None) 223 | validate_freq(df[time_col], freq=freq) 224 | 225 | grid = id_time_grid( 226 | df=df, 227 | freq=freq, 228 | start=start, 229 | end=end, 230 | id_col=id_col, 231 | time_col=time_col, 232 | ) 233 | if isinstance(df, pl_DataFrame): 234 | return grid.join(df, on=[id_col, time_col], how="left") 235 | idx = pd.MultiIndex.from_frame(grid) 236 | if isinstance(freq, str): 237 | tz = df[time_col].dt.tz 238 | if tz is not None: 239 | df = df.copy(deep=False) 240 | df[time_col] = df[time_col].dt.tz_convert("UTC").dt.tz_localize(None) 241 | res = df.set_index([id_col, time_col]).reindex(idx).reset_index() 242 | if isinstance(freq, str): 243 | if tz is not None: 244 | res[time_col] = res[time_col].dt.tz_localize("UTC").dt.tz_convert(tz) 245 | extra_cols = df.columns.drop([id_col, time_col]).tolist() 246 | if extra_cols: 247 | check_col = extra_cols[0] 248 | if res[check_col].count() < df[check_col].count(): 249 | warnings.warn( 250 | "Some values were lost during filling, " 251 | "please make sure that all your times meet the specified frequency.\n" 252 | "For example if you have 'W-TUE' as your frequency, " 253 | "make sure that all your times are actually Tuesdays." 
254 | ) 255 | return res 256 | -------------------------------------------------------------------------------- /tests/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from datetime import date, datetime 3 | from itertools import product 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import polars as pl 8 | 9 | from utilsforecast.data import generate_series 10 | from utilsforecast.preprocessing import fill_gaps 11 | 12 | df = pd.DataFrame( 13 | { 14 | "unique_id": [0, 0, 0, 1, 1], 15 | "ds": pd.to_datetime(["2020", "2021", "2023", "2021", "2022"]), 16 | "y": np.arange(5), 17 | } 18 | ) 19 | fill_gaps( 20 | df, 21 | freq="YS", 22 | ) 23 | fill_gaps( 24 | df, 25 | freq="YS", 26 | end="per_serie", 27 | ) 28 | fill_gaps( 29 | df, 30 | freq="YS", 31 | end="2024", 32 | ) 33 | fill_gaps(df, freq="YS", start="global") 34 | fill_gaps( 35 | df, 36 | freq="YS", 37 | start="2019", 38 | ) 39 | df = pd.DataFrame( 40 | { 41 | "unique_id": [0, 0, 0, 1, 1], 42 | "ds": [2020, 2021, 2023, 2021, 2022], 43 | "y": np.arange(5), 44 | } 45 | ) 46 | 47 | fill_gaps( 48 | df, 49 | freq=1, 50 | start=2019, 51 | end=2024, 52 | ) 53 | df = pl.DataFrame( 54 | { 55 | "unique_id": [0, 0, 0, 1, 1], 56 | "ds": [ 57 | datetime(2020, 1, 1), 58 | datetime(2022, 1, 1), 59 | datetime(2023, 1, 1), 60 | datetime(2021, 1, 1), 61 | datetime(2022, 1, 1), 62 | ], 63 | "y": np.arange(5), 64 | } 65 | ) 66 | df 67 | polars_ms = fill_gaps( 68 | df.with_columns(pl.col("ds").cast(pl.Datetime(time_unit="ms"))), 69 | freq="1y", 70 | start=datetime(2019, 1, 1), 71 | end=datetime(2024, 1, 1), 72 | ) 73 | 74 | 75 | def test_fill_gaps_polars(): 76 | assert polars_ms.schema["ds"].time_unit == "ms" 77 | 78 | 79 | df = pl.DataFrame( 80 | { 81 | "unique_id": [0, 0, 0, 1, 1], 82 | "ds": [ 83 | date(2020, 1, 1), 84 | date(2022, 1, 1), 85 | date(2023, 1, 1), 86 | date(2021, 1, 1), 87 | date(2022, 1, 1), 88 | ], 89 | "y": np.arange(5), 90 | } 91 | ) 92 | 93 | fill_gaps( 94 | df, 95 | freq="1y", 96 | start=date(2020, 1, 1), 97 | end=date(2024, 1, 1), 98 | ) 99 | df = pl.DataFrame( 100 | { 101 | "unique_id": [0, 0, 0, 1, 1], 102 | "ds": [2020, 2021, 2023, 2021, 2022], 103 | "y": np.arange(5), 104 | } 105 | ) 106 | 107 | fill_gaps( 108 | df, 109 | freq=1, 110 | start=2019, 111 | end=2024, 112 | ) 113 | 114 | 115 | def check_fill(dates, freq, start, end, include_start, include_end): 116 | base_idxs = [] 117 | if include_start: 118 | base_idxs.append(0) 119 | if include_end: 120 | base_idxs.append(dates.size - 1) 121 | base_idxs = np.array(base_idxs, dtype=np.int64) 122 | date_idxs = np.hstack( 123 | [ 124 | np.append( 125 | base_idxs, 126 | np.random.choice( 127 | np.arange(1, dates.size - 1), 128 | size=n_periods // 2 - len(base_idxs), 129 | replace=False, 130 | ), 131 | ) 132 | for _ in range(2) 133 | ], 134 | ) 135 | data = pd.DataFrame( 136 | { 137 | "unique_id": np.repeat([1, 2], n_periods // 2), 138 | "ds": dates[date_idxs], 139 | "y": np.arange(n_periods, dtype=np.float64), 140 | } 141 | ) 142 | filled = fill_gaps(data, freq, start=start, end=end) 143 | data_starts_ends = data.groupby("unique_id", observed=True)["ds"].agg( 144 | ["min", "max"] 145 | ) 146 | global_start = data_starts_ends["min"].min() 147 | global_end = data_starts_ends["max"].max() 148 | filled_starts_ends = filled.groupby("unique_id", observed=True)["ds"].agg( 149 | ["min", "max"] 150 | ) 151 | 152 | # inferred frequency is the expected 153 | first_serie = filled[filled["unique_id"] == 1] 154 | if 
isinstance(freq, str): 155 | if first_serie["ds"].dt.tz is not None: 156 | first_serie = first_serie.copy() 157 | first_serie["ds"] = first_serie["ds"].dt.tz_convert("UTC") 158 | inferred_freq = pd.infer_freq(first_serie["ds"].dt.tz_localize(None)) 159 | assert inferred_freq == pd.tseries.frequencies.to_offset(freq) 160 | else: 161 | assert all(first_serie["ds"].diff().value_counts().index == [freq]) 162 | 163 | # fill keeps original data 164 | assert filled["y"].count() == n_periods 165 | # check starts 166 | if start == "per_serie": 167 | pd.testing.assert_series_equal( 168 | data_starts_ends["min"], 169 | filled_starts_ends["min"], 170 | ) 171 | else: # global or specific 172 | min_dates = filled_starts_ends["min"].unique() 173 | assert min_dates.size == 1 174 | expected_start = global_start if start == "global" else start 175 | assert min_dates[0] == expected_start 176 | 177 | # check ends 178 | if end == "per_serie": 179 | pd.testing.assert_series_equal( 180 | data_starts_ends["max"], 181 | filled_starts_ends["max"], 182 | ) 183 | else: # global or specific 184 | max_dates = filled_starts_ends["max"].unique() 185 | assert max_dates.size == 1 186 | expected_end = global_end if end == "global" else end 187 | assert max_dates[0] == expected_end 188 | 189 | 190 | n_periods = 100 191 | freqs = [ 192 | "YE", 193 | "YS", 194 | "ME", 195 | "MS", 196 | "W", 197 | "W-TUE", 198 | "D", 199 | "s", 200 | "ms", 201 | 1, 202 | 2, 203 | "20D", 204 | "30s", 205 | "2YE", 206 | "3YS", 207 | "30min", 208 | "B", 209 | "1h", 210 | "QS-NOV", 211 | "QE", 212 | ] 213 | try: 214 | pd.tseries.frequencies.to_offset("YE") 215 | except ValueError: 216 | freqs = [ 217 | f.replace("YE", "Y").replace("ME", "M").replace("h", "H").replace("QE", "Q") 218 | for f in freqs 219 | if isinstance(f, str) 220 | ] 221 | for freq in freqs: 222 | if isinstance(freq, (pd.offsets.BaseOffset, str)): 223 | offset = pd.tseries.frequencies.to_offset(freq) 224 | if isinstance(freq, str): 225 | try: 226 | delta = pd.Timedelta(freq) 227 | if delta.days > 0: 228 | tz = None 229 | else: 230 | tz = "Europe/Berlin" 231 | except ValueError: 232 | tz = None 233 | dates = pd.date_range("1950-01-01", periods=n_periods, freq=freq, tz=tz) 234 | else: 235 | dates = np.arange(0, freq * n_periods, freq, dtype=np.int64) 236 | offset = freq 237 | global_start = dates[0] 238 | global_end = dates[-1] 239 | starts = ["global", "per_serie", global_start - offset] 240 | ends = ["global", "per_serie", global_end + offset] 241 | include_starts = [True, False] 242 | include_ends = [True, False] 243 | iterable = product(starts, ends, include_starts, include_ends) 244 | for start, end, include_start, include_end in iterable: 245 | check_fill(dates, freq, start, end, include_start, include_end) 246 | # last value doesn't meet frequency (year start) 247 | dfx = pd.DataFrame( 248 | { 249 | "unique_id": [0, 0, 0, 1, 1], 250 | "ds": pd.to_datetime(["2020-01", "2021-01", "2023-01", "2021-01", "2022-02"]), 251 | "y": np.arange(5), 252 | } 253 | ) 254 | with warnings.catch_warnings(record=True) as w: 255 | fill_gaps(dfx, "YS") 256 | assert "values were lost" in str(w[0].message) 257 | 258 | 259 | # frequency and time column are not compatible 260 | def error_freq(dates, freq, start, end, include_start, include_end, lib): 261 | base_idxs = [] 262 | if include_start: 263 | base_idxs.append(0) 264 | if include_end: 265 | base_idxs.append(np.size(dates) - 1) 266 | base_idxs = np.array(base_idxs, dtype=np.int64) 267 | date_idxs = np.hstack( 268 | [ 269 | np.append( 270 | 
base_idxs, 271 | np.random.choice( 272 | np.arange(1, np.size(dates) - 1), 273 | size=n_periods // 2 - len(base_idxs), 274 | replace=False, 275 | ), 276 | ) 277 | for _ in range(2) 278 | ], 279 | ) 280 | if lib == "pandas": 281 | data = pd.DataFrame( 282 | { 283 | "unique_id": np.repeat([1, 2], n_periods // 2), 284 | "ds": dates[date_idxs], 285 | "y": np.arange(n_periods, dtype=np.float64), 286 | } 287 | ) 288 | 289 | if lib == "polars": 290 | data = pl.DataFrame( 291 | { 292 | "unique_id": np.repeat([1, 2], n_periods // 2), 293 | "ds": dates[date_idxs], 294 | "y": np.arange(n_periods, dtype=np.float64), 295 | } 296 | ) 297 | 298 | try: 299 | filled = fill_gaps(data, freq, start=start, end=end) 300 | except Exception as e: 301 | assert isinstance(e, ValueError) 302 | 303 | 304 | # https://pandas.pydata.org/docs/user_guide/timeseries.html#timeseries-offset-aliases 305 | freqs_pd = [ 306 | "YE", 307 | "YS", 308 | "ME", 309 | "MS", 310 | "W", 311 | "W-TUE", 312 | "D", 313 | "s", 314 | "ms", 315 | "20D", 316 | "30s", 317 | "2YE", 318 | "3YS", 319 | "30min", 320 | "B", 321 | "1h", 322 | "QS-NOV", 323 | "QE", 324 | ] 325 | 326 | # https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.dt.offset_by.html 327 | freqs_pl = ["1d", "1w", "1mo", "1q", "1y"] 328 | 329 | # integer freqs 330 | freqs_int = list(range(1, 10 + 1)) 331 | 332 | n_periods = 100 333 | for lib in ["pandas", "polars"]: 334 | freqs_list = freqs_pd if lib == "pandas" else freqs_pl 335 | for freq_int, freq_str in product(freqs_int, freqs_list): 336 | dates_int = np.arange(1, (n_periods * freq_int) + 1, freq_int) 337 | 338 | if lib == "pandas": 339 | dates_str = pd.date_range("1950-01-01", periods=n_periods, freq=freq_str) 340 | offset = pd.tseries.frequencies.to_offset(freq) 341 | first_date = dates_str[0] - offset 342 | last_date = dates_str[-1] + offset 343 | 344 | if lib == "polars": 345 | pl_dt = pl.date(1950, 1, 1) 346 | dates_str = pl.date_range( 347 | pl_dt, 348 | pl_dt.dt.offset_by(f"{n_periods}{freq_str[1:]}"), 349 | interval=freq_str, 350 | eager=True, 351 | ) 352 | first_date = dates_str.dt.offset_by(f"-{freq_str}")[0] 353 | last_date = dates_str.dt.offset_by(freq_str)[-1] 354 | 355 | starts = ["global", "per_serie", first_date] 356 | ends = ["global", "per_serie", last_date] 357 | include_starts = [True, False] 358 | include_ends = [True, False] 359 | iterable = product(starts, ends, include_starts, include_ends) 360 | for start, end, include_start, include_end in iterable: 361 | error_freq(dates_str, freq_int, start, end, include_start, include_end, lib) 362 | error_freq(dates_int, freq_str, start, end, include_start, include_end, lib) 363 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022, fastai 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /docs/losses.html.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Losses 3 | description: Loss functions for model evaluation. 4 | --- 5 | 6 | 7 | The most important train signal is the forecast error, which is the 8 | difference between the observed value $y_{\tau}$ and the prediction 9 | $\hat{y}_{\tau}$, at time $y_{\tau}$: 10 | 11 | $$ 12 | 13 | e_{\tau} = y_{\tau}-\hat{y}_{\tau} \qquad \qquad \tau \in \{t+1,\dots,t+H \} 14 | 15 | $$ 16 | 17 | The train loss summarizes the forecast errors in different evaluation 18 | metrics. 19 | 20 | 21 | ## 1. Scale-dependent Errors 22 | 23 | ### Mean Absolute Error 24 | 25 | $$ 26 | 27 | \mathrm{MAE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} |y_{\tau} - \hat{y}_{\tau}| 28 | 29 | $$ 30 | ![](./imgs/losses/mae_loss.png) 31 | 32 | ::: utilsforecast.losses.mae 33 | handler: python 34 | options: 35 | docstring_style: google 36 | heading_level: 4 37 | show_root_heading: true 38 | show_source: true 39 | 40 | ### Mean Squared Error 41 | 42 | $$ 43 | 44 | \mathrm{MSE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} (y_{\tau} - \hat{y}_{\tau})^{2} 45 | 46 | $$ 47 | 48 | ![](./imgs/losses/mse_loss.png) 49 | 50 | ::: utilsforecast.losses.mse 51 | handler: python 52 | options: 53 | docstring_style: google 54 | heading_level: 4 55 | show_root_heading: true 56 | show_source: true 57 | 58 | ### Root Mean Squared Error 59 | 60 | $$ 61 | 62 | \mathrm{RMSE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \sqrt{\frac{1}{H} \sum^{t+H}_{\tau=t+1} (y_{\tau} - \hat{y}_{\tau})^{2}} 63 | 64 | $$ 65 | 66 | ![](./imgs/losses/rmse_loss.png) 67 | ::: utilsforecast.losses.rmse 68 | handler: python 69 | options: 70 | docstring_style: google 71 | heading_level: 4 72 | show_root_heading: true 73 | show_source: true 74 | 75 | ### Bias 76 | 77 | $$ 78 | 79 | \mathrm{Bias}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} (\hat{y}_{\tau} - \mathbf{y}_{\tau}) 80 | 81 | $$ 82 | 83 | ::: utilsforecast.losses.bias 84 | handler: python 85 | options: 86 | docstring_style: google 87 | heading_level: 4 88 | show_root_heading: true 89 | show_source: true 90 | 91 | ### Cumulative Forecast Error 92 | 93 | $$ 94 | 95 | \mathrm{CFE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \sum^{t+H}_{\tau=t+1} (\hat{y}_{\tau} - \mathbf{y}_{\tau}) 96 | 97 | $$ 98 | 99 | ::: utilsforecast.losses.cfe 100 | handler: python 101 | options: 102 | docstring_style: google 103 | heading_level: 4 104 | show_root_heading: true 105 | show_source: true 106 | 107 | ### Absolute Periods In Stock 108 | 109 | $$ 110 | 111 | \mathrm{PIS}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \sum^{t+H}_{\tau=t+1} |y_{\tau} - \hat{y}_{\tau}| 112 | 113 | $$ 114 | 115 | ::: utilsforecast.losses.pis 116 | handler: python 117 | options: 118 | docstring_style: google 119 | heading_level: 4 120 | show_root_heading: true 121 | show_source: true 122 | 123 | ### Linex 124 | 
125 | $$ 126 | \mathrm{Linex}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} (e^{a(y_{\tau} - \hat{y}_{\tau})} - a(y_{\tau} - \hat{y}_{\tau}) - 1) 127 | $$ 128 | 129 | where must be $a\neq0$. 130 | 131 | ::: utilsforecast.losses.linex 132 | handler: python 133 | options: 134 | docstring_style: google 135 | heading_level: 4 136 | show_root_heading: true 137 | show_source: true 138 | 139 | ## 2. Percentage Errors 140 | 141 | ### Mean Absolute Percentage Error 142 | 143 | $$ 144 | 145 | \mathrm{MAPE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} \frac{|y_{\tau}-\hat{y}_{\tau}|}{|y_{\tau}|} 146 | 147 | $$ 148 | 149 | ![](./imgs/losses/mape_loss.png) 150 | ::: utilsforecast.losses.mape 151 | handler: python 152 | options: 153 | docstring_style: google 154 | heading_level: 4 155 | show_root_heading: true 156 | show_source: true 157 | 158 | ### Symmetric Mean Absolute Percentage Error 159 | 160 | $$ 161 | 162 | \mathrm{SMAPE}_{2}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} \frac{|y_{\tau}-\hat{y}_{\tau}|}{|y_{\tau}|+|\hat{y}_{\tau}|} 163 | 164 | $$ 165 | 166 | ::: utilsforecast.losses.smape 167 | handler: python 168 | options: 169 | docstring_style: google 170 | heading_level: 4 171 | show_root_heading: true 172 | show_source: true 173 | 174 | ## 3. Scale-independent Errors 175 | 176 | ### Mean Absolute Scaled Error 177 | 178 | $$ 179 | 180 | \mathrm{MASE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}, \mathbf{\hat{y}}^{season}_{\tau}) = 181 | \frac{1}{H} \sum^{t+H}_{\tau=t+1} \frac{|y_{\tau}-\hat{y}_{\tau}|}{\mathrm{MAE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{season}_{\tau})} 182 | 183 | $$ 184 | 185 | ![](./imgs/losses/mase_loss.png) 186 | 187 | 188 | ::: utilsforecast.losses.mase 189 | handler: python 190 | options: 191 | docstring_style: google 192 | heading_level: 4 193 | show_root_heading: true 194 | show_source: true 195 | 196 | ### Relative Mean Absolute Error 197 | 198 | $$ 199 | 200 | \mathrm{RMAE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}, \mathbf{\hat{y}}^{base}_{\tau}) = \frac{1}{H} \sum^{t+H}_{\tau=t+1} \frac{|y_{\tau}-\hat{y}_{\tau}|}{\mathrm{MAE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{base}_{\tau})} 201 | 202 | $$ 203 | 204 | ![](./imgs/losses/rmae_loss.png) 205 | 206 | ::: utilsforecast.losses.rmae 207 | handler: python 208 | options: 209 | docstring_style: google 210 | heading_level: 4 211 | show_root_heading: true 212 | show_source: true 213 | 214 | ### Normalized Deviation 215 | 216 | $$ 217 | 218 | \mathrm{ND}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \frac{\sum^{t+H}_{\tau=t+1} |y_{\tau} - \hat{y}_{\tau}|}{\sum^{t+H}_{\tau=t+1} | y_{\tau} |} 219 | 220 | $$ 221 | 222 | ::: utilsforecast.losses.nd 223 | handler: python 224 | options: 225 | docstring_style: google 226 | heading_level: 4 227 | show_root_heading: true 228 | show_source: true 229 | 230 | ### Mean Squared Scaled Error 231 | 232 | $$ 233 | 234 | \mathrm{MSSE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}, \mathbf{\hat{y}}^{season}_{\tau}) = 235 | \frac{1}{H} \sum^{t+H}_{\tau=t+1} \frac{(y_{\tau}-\hat{y}_{\tau})^2}{\mathrm{MSE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{season}_{\tau})} 236 | 237 | $$ 238 | 239 | ::: utilsforecast.losses.msse 240 | handler: python 241 | options: 242 | docstring_style: google 243 | heading_level: 3 244 | show_root_heading: true 245 | show_source: true 246 | 247 | ### Root Mean Squared Scaled Error 248 | 249 | $$ 250 | 251 | \mathrm{RMSSE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}, 
\mathbf{\hat{y}}^{season}_{\tau}) = 252 | \sqrt{\frac{1}{H} \sum^{t+H}_{\tau=t+1} \frac{(y_{\tau}-\hat{y}_{\tau})^2}{\mathrm{MSE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{season}_{\tau})}} 253 | 254 | $$ 255 | 256 | 257 | ::: utilsforecast.losses.rmsse 258 | handler: python 259 | options: 260 | docstring_style: google 261 | heading_level: 3 262 | show_root_heading: true 263 | show_source: true 264 | 265 | ### Scaled Absolute Periods In Stock 266 | 267 | $$ 268 | 269 | \mathrm{PIS}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}_{\tau}) = \sum^{t+H}_{\tau=t+1} \frac{|y_{\tau} - \hat{y}_{\tau}|}{\bar{y}} 270 | 271 | $$ 272 | 273 | where $\bar{y}=\frac{1}{H}\sum^{t+H}_{\tau=t+1} y_{\tau}$. 274 | 275 | ::: utilsforecast.losses.spis 276 | handler: python 277 | options: 278 | docstring_style: google 279 | heading_level: 4 280 | show_root_heading: true 281 | show_source: true 282 | 283 | ## 4. Probabilistic Errors 284 | 285 | 286 | ### Quantile Loss 287 | 288 | $$ 289 | 290 | \mathrm{QL}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{(q)}_{\tau}) = 291 | \frac{1}{H} \sum^{t+H}_{\tau=t+1} 292 | \Big( (1-q)\,( \hat{y}^{(q)}_{\tau} - y_{\tau} )_{+} 293 | + q\,( y_{\tau} - \hat{y}^{(q)}_{\tau} )_{+} \Big) 294 | 295 | $$ 296 | 297 | ![](./imgs/losses/q_loss.png) 298 | 299 | ::: utilsforecast.losses.quantile_loss 300 | handler: python 301 | options: 302 | docstring_style: google 303 | heading_level: 4 304 | show_root_heading: true 305 | show_source: true 306 | 307 | ### Scaled Quantile Loss 308 | 309 | $$ 310 | 311 | \mathrm{SQL}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{(q)}_{\tau}) = 312 | \frac{1}{H} \sum^{t+H}_{\tau=t+1} 313 | \frac{(1-q)\,( \hat{y}^{(q)}_{\tau} - y_{\tau} )_{+} 314 | + q\,( y_{\tau} - \hat{y}^{(q)}_{\tau} )_{+}}{\mathrm{MAE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{season}_{\tau})} 315 | 316 | $$ 317 | 318 | ::: utilsforecast.losses.scaled_quantile_loss 319 | handler: python 320 | options: 321 | docstring_style: google 322 | heading_level: 4 323 | show_root_heading: tr 324 | 325 | ### Multi-Quantile Loss 326 | 327 | $$ 328 | 329 | \mathrm{MQL}(\mathbf{y}_{\tau}, 330 | [\mathbf{\hat{y}}^{(q_{1})}_{\tau}, ... ,\hat{y}^{(q_{n})}_{\tau}]) = 331 | \frac{1}{n} \sum_{q_{i}} \mathrm{QL}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{(q_{i})}_{\tau}) 332 | 333 | $$ 334 | 335 | ![](./imgs/losses/mq_loss.png) 336 | 337 | ::: utilsforecast.losses.mqloss 338 | handler: python 339 | options: 340 | docstring_style: google 341 | heading_level: 4 342 | show_root_heading: true 343 | show_source: true 344 | 345 | ### Scaled Multi-Quantile Loss 346 | 347 | $$ 348 | 349 | \mathrm{MQL}(\mathbf{y}_{\tau}, 350 | [\mathbf{\hat{y}}^{(q_{1})}_{\tau}, ... 
,\hat{y}^{(q_{n})}_{\tau}]) = 351 | \frac{1}{n} \sum_{q_{i}} \frac{\mathrm{QL}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{(q_{i})}_{\tau})}{\mathrm{MAE}(\mathbf{y}_{\tau}, \mathbf{\hat{y}}^{season}_{\tau})} 352 | 353 | $$ 354 | 355 | ::: utilsforecast.losses.scaled_mqloss 356 | handler: python 357 | options: 358 | docstring_style: google 359 | heading_level: 4 360 | show_root_heading: true 361 | show_source: true 362 | 363 | ### Coverage 364 | 365 | ::: utilsforecast.losses.coverage 366 | handler: python 367 | options: 368 | docstring_style: google 369 | heading_level: 4 370 | show_root_heading: true 371 | show_source: true 372 | 373 | ### Calibration 374 | 375 | ::: utilsforecast.losses.calibration 376 | handler: python 377 | options: 378 | docstring_style: google 379 | heading_level: 4 380 | show_root_heading: true 381 | show_source: true 382 | 383 | ### CRPS 384 | 385 | $$ 386 | 387 | \mathrm{sCRPS}(\hat{F}_{\tau}, \mathbf{y}_{\tau}) = \frac{2}{N} \sum_{i} 388 | \int^{1}_{0} \frac{\mathrm{QL}(\hat{F}_{i,\tau}, y_{i,\tau})_{q}}{\sum_{i} | y_{i,\tau} |} dq 389 | 390 | $$ 391 | 392 | Where $\hat{F}_{\tau}$ is the an estimated multivariate distribution, 393 | and $y_{i,\tau}$ are its realizations. 394 | 395 | ::: utilsforecast.losses.scaled_crps 396 | handler: python 397 | options: 398 | docstring_style: google 399 | heading_level: 4 400 | show_root_heading: true 401 | show_source: true 402 | 403 | ### Tweedie Deviance 404 | 405 | For a set of forecasts $\{\mu_i\}_{i=1}^N$ and observations 406 | $\{y_i\}_{i=1}^N$, the mean Tweedie deviance with power $p$ is 407 | 408 | $$ 409 | 410 | \mathrm{TD}_{p}(\boldsymbol{\mu}, \mathbf{y}) 411 | = \frac{1}{N} \sum_{i=1}^{N} d_{p}(y_i, \mu_i) 412 | 413 | $$ 414 | 415 | where the unit-scaled deviance for each pair $(y,\mu)$ is 416 | 417 | $$ 418 | 419 | d_{p}(y,\mu) 420 | = 421 | 2 422 | \begin{cases} 423 | \displaystyle 424 | \frac{y^{2-p}}{(1-p)(2-p)} 425 | \;-\; 426 | \frac{y\,\mu^{1-p}}{1-p} 427 | \;+\; 428 | \frac{\mu^{2-p}}{2-p}, 429 | & p \notin\{1,2\},\\[1em] 430 | \displaystyle 431 | y\,\ln\!\frac{y}{\mu}\;-\;(y-\mu), 432 | & p = 1\quad(\text{Poisson deviance}),\\[0.5em] 433 | \displaystyle 434 | -2\Bigl[\ln\!\frac{y}{\mu}\;-\;\frac{y-\mu}{\mu}\Bigr], 435 | & p = 2\quad(\text{Gamma deviance}). 436 | \end{cases} 437 | 438 | $$ 439 | 440 | - $y_i$ are the true values, $\mu_i$ the predicted means. 441 | - $p$ controls the variance relationship 442 | $\mathrm{Var}(Y)\propto\mu^{p}$. 
443 | - When $1 np.ndarray: 30 | """Returns quantiles associated to `level` and the sorte columns of `model_name`""" 31 | level = sorted(level) 32 | alphas = [100 - lv for lv in level] 33 | quantiles = [alpha / 200 for alpha in reversed(alphas)] 34 | quantiles.extend([1 - alpha / 200 for alpha in alphas]) 35 | return np.array(quantiles) 36 | 37 | 38 | def _models_from_levels(model_name: str, level: List[int]) -> List[str]: 39 | level = sorted(level) 40 | cols = [f"{model_name}-lo-{lv}" for lv in reversed(level)] 41 | cols.extend([f"{model_name}-hi-{lv}" for lv in level]) 42 | return cols 43 | 44 | 45 | def _get_model_cols( 46 | cols: List[str], 47 | id_col: str, 48 | time_col: str, 49 | target_col: str, 50 | cutoff_col: str, 51 | ) -> List[str]: 52 | return [ 53 | c 54 | for c in cols 55 | if c not in [id_col, time_col, target_col, cutoff_col] 56 | and not re.search(r"-(?:lo|hi)-\d+", c) 57 | ] 58 | 59 | 60 | def _evaluate_wrapper( 61 | df: pd.DataFrame, 62 | metrics: List[Callable], 63 | models: Optional[List[str]], 64 | level: Optional[List[int]], 65 | id_col: str, 66 | time_col: str, 67 | target_col: str, 68 | cutoff_col: str, 69 | agg_fn: Optional[str], 70 | ) -> pd.DataFrame: 71 | group_cols = _get_group_cols(df, id_col, cutoff_col) 72 | if "_in_sample" in df: 73 | in_sample_mask = df["_in_sample"] 74 | train_df = df.loc[in_sample_mask, [*group_cols, time_col, target_col]] 75 | df = df.loc[~in_sample_mask].drop(columns="_in_sample") 76 | else: 77 | train_df = None 78 | return evaluate( 79 | df=df, 80 | metrics=metrics, 81 | models=models, 82 | train_df=train_df, 83 | level=level, 84 | id_col=id_col, 85 | time_col=time_col, 86 | target_col=target_col, 87 | cutoff_col=cutoff_col, 88 | agg_fn=agg_fn, 89 | ) 90 | 91 | 92 | def _distributed_evaluate( 93 | df: DistributedDFType, 94 | metrics: List[Callable], 95 | models: Optional[List[str]], 96 | train_df: Optional[DFType], 97 | level: Optional[List[int]], 98 | id_col: str, 99 | time_col: str, 100 | target_col: str, 101 | cutoff_col: str, 102 | agg_fn: Optional[str], 103 | ) -> DistributedDFType: 104 | import fugue.api as fa 105 | 106 | if agg_fn is not None: 107 | raise ValueError("`agg_fn` is not supported in distributed") 108 | df_cols = fa.get_column_names(df) 109 | group_cols: list[str] = _get_group_cols(df, id_col, cutoff_col) 110 | if train_df is not None: 111 | # align columns in order to vstack them 112 | def assign_cols(df: pd.DataFrame, cols) -> pd.DataFrame: 113 | return df.assign(**cols) 114 | train_cols = [*group_cols, time_col, target_col] 115 | extra_cols = [c for c in df_cols if c not in train_cols] 116 | train_df = fa.select_columns(train_df, train_cols) 117 | train_df = fa.transform( 118 | train_df, 119 | using=assign_cols, 120 | schema=( 121 | "*," + str(fa.get_schema(df).extract(extra_cols)) + ",_in_sample:bool" 122 | ), 123 | params={ 124 | "cols": { 125 | **{c: float("nan") for c in extra_cols}, 126 | "_in_sample": True, 127 | }, 128 | }, 129 | ) 130 | df = fa.transform( 131 | df, 132 | using=assign_cols, 133 | schema="*,_in_sample:bool", 134 | params={"cols": {"_in_sample": False}}, 135 | ) 136 | df = fa.union(train_df, df) 137 | 138 | if models is None: 139 | model_cols = _get_model_cols(df_cols, id_col, time_col, target_col, cutoff_col) 140 | else: 141 | model_cols = models 142 | models_schema = ",".join(f"{m}:double" for m in model_cols) 143 | result_schema = fa.get_schema(df).extract(*group_cols) + "metric:str" + models_schema 144 | return fa.transform( 145 | df, 146 | using=_evaluate_wrapper, 147 | 
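# each partition (grouped by the id/cutoff columns below) is evaluated locally by `_evaluate_wrapper`, which calls the in-memory `evaluate`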
schema=result_schema, 148 | params=dict( 149 | metrics=metrics, 150 | models=models, 151 | level=level, 152 | id_col=id_col, 153 | time_col=time_col, 154 | target_col=target_col, 155 | cutoff_col=cutoff_col, 156 | agg_fn=agg_fn, 157 | ), 158 | partition={"by": group_cols, "algo": "coarse"}, 159 | ) 160 | 161 | 162 | def evaluate( 163 | df: AnyDFType, 164 | metrics: List[Callable], 165 | models: Optional[List[str]] = None, 166 | train_df: Optional[AnyDFType] = None, 167 | level: Optional[List[int]] = None, 168 | id_col: str = "unique_id", 169 | time_col: str = "ds", 170 | target_col: str = "y", 171 | cutoff_col: str = "cutoff", 172 | agg_fn: Optional[str] = None, 173 | ) -> AnyDFType: 174 | """Evaluate forecast using different metrics. 175 | 176 | Args: 177 | df (pandas, polars, dask or spark DataFrame): Forecasts to evaluate. 178 | Must have `id_col`, `time_col`, `target_col` and models' predictions. 179 | metrics (list of callable): Functions with arguments `df`, `models`, 180 | `id_col`, `target_col` and optionally `train_df`. 181 | models (list of str, optional): Names of the models to evaluate. 182 | If `None` will use every column in the dataframe after removing 183 | id, time and target. Defaults to None. 184 | train_df (pandas, polars, dask or spark DataFrame, optional): Training set. 185 | Used to evaluate metrics such as `mase`. Defaults to None. 186 | level (list of int, optional): Prediction interval levels. Used to compute 187 | losses that rely on quantiles. Defaults to None. 188 | id_col (str, optional): Column that identifies each serie. 189 | Defaults to 'unique_id'. 190 | time_col (str, optional): Column that identifies each timestep, its values 191 | can be timestamps or integers. Defaults to 'ds'. 192 | target_col (str, optional): Column that contains the target. 193 | Defaults to 'y'. 194 | cutoff_col (str, optional): Column that identifies the cutoff point for 195 | each forecast cross-validation fold. Defaults to 'cutoff'. 196 | agg_fn (str, optional): Statistic to compute on the scores by id to reduce 197 | them to a single number. Defaults to None. 198 | 199 | Returns: 200 | pandas, polars, dask or spark DataFrame: Metrics with one row per 201 | (id, metric) combination and one column per model. If `agg_fn` is 202 | not `None`, there is only one row per metric. 
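    Example:
        A minimal sketch, assuming pandas, `utilsforecast.losses.mae` and a toy
        forecast column named `model1` (the column name is only illustrative):

            import pandas as pd
            from utilsforecast.losses import mae

            fcst_df = pd.DataFrame(
                {
                    "unique_id": [1, 1, 2, 2],
                    "ds": [1, 2, 1, 2],
                    "y": [1.0, 2.0, 3.0, 4.0],
                    "model1": [1.1, 1.9, 2.8, 4.3],
                }
            )
            # one row per (id, metric) combination and one column per model
            evaluate(fcst_df, metrics=[mae])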
203 | """ 204 | if not isinstance(df, (pd.DataFrame, pl_DataFrame)): 205 | return _distributed_evaluate( 206 | df=df, 207 | metrics=metrics, 208 | models=models, 209 | train_df=train_df, 210 | level=level, 211 | id_col=id_col, 212 | time_col=time_col, 213 | target_col=target_col, 214 | cutoff_col=cutoff_col, 215 | agg_fn=agg_fn, 216 | ) 217 | if models is None: 218 | model_cols = _get_model_cols(df.columns, id_col, time_col, target_col, cutoff_col) 219 | else: 220 | model_cols = models 221 | 222 | # interval cols 223 | if level is not None: 224 | expected_cols = { 225 | f"{m}-{side}-{lvl}" 226 | for m in model_cols 227 | for side in ("lo", "hi") 228 | for lvl in level 229 | } 230 | missing = expected_cols - set(df.columns) 231 | if missing: 232 | raise ValueError( 233 | f"The following columns are required for level={level} " 234 | f"and are missing: {missing}" 235 | ) 236 | else: 237 | requires_level = [ 238 | m 239 | for m in metrics 240 | if get_origin(inspect.signature(m).parameters["models"].annotation) is dict 241 | ] 242 | if requires_level: 243 | raise ValueError( 244 | f"The following metrics require setting `level`: {requires_level}" 245 | ) 246 | 247 | # y_train 248 | metric_requires_y_train = { 249 | _function_name(m): "train_df" in inspect.signature(m).parameters 250 | for m in metrics 251 | } 252 | y_train_metrics = [ 253 | m for m, requires_yt in metric_requires_y_train.items() if requires_yt 254 | ] 255 | if y_train_metrics: 256 | if train_df is None: 257 | raise ValueError( 258 | f"The following metrics require y_train: {y_train_metrics}. " 259 | "Please provide `train_df`." 260 | ) 261 | train_df = ufp.sort(train_df, by=[id_col, time_col]) 262 | missing_series = set(df[id_col].unique()) - set(train_df[id_col].unique()) 263 | if missing_series: 264 | raise ValueError( 265 | f"The following series are missing from the train_df: {reprlib.repr(missing_series)}" 266 | ) 267 | 268 | results_per_metric = [] 269 | for metric in metrics: 270 | metric_name = _function_name(metric) 271 | kwargs = dict(df=df, models=model_cols, id_col=id_col, target_col=target_col) 272 | if metric_requires_y_train[metric_name]: 273 | kwargs["train_df"] = train_df 274 | kwargs["cutoff_col"] = cutoff_col 275 | kwargs["time_col"] = time_col 276 | metric_params = inspect.signature(metric).parameters 277 | if "baseline" in metric_params: 278 | metric_name = f"{metric_name}_{metric_params['baseline'].default}" 279 | if "q" in metric_params or metric_params["models"].annotation is Dict[str, str]: 280 | assert level is not None # we've already made sure of this above 281 | for lvl in level: 282 | quantiles = _quantiles_from_levels([lvl]) 283 | for q, side in zip(quantiles, ["lo", "hi"]): 284 | kwargs["models"] = { 285 | model: f"{model}-{side}-{lvl}" for model in model_cols 286 | } 287 | if "q" in metric_params: 288 | # this is for calibration, since it uses the predictions for q 289 | # but doesn't use it 290 | kwargs["q"] = q 291 | result = metric(**kwargs) 292 | result = ufp.assign_columns(result, "metric", f"{metric_name}_q{q}") 293 | results_per_metric.append(result) 294 | elif "quantiles" in metric_params: 295 | assert level is not None # we've already made sure of this above 296 | quantiles = _quantiles_from_levels(level) 297 | kwargs["quantiles"] = quantiles 298 | kwargs["models"] = { 299 | model: _models_from_levels(model, level) for model in model_cols 300 | } 301 | result = metric(**kwargs) 302 | result = ufp.assign_columns(result, "metric", metric_name) 303 | results_per_metric.append(result) 304 
| elif "level" in metric_params: 305 | assert level is not None # we've already made sure of this above 306 | for lvl in level: 307 | kwargs["level"] = lvl 308 | result = metric(**kwargs) 309 | result = ufp.assign_columns( 310 | result, "metric", f"{metric_name}_level{lvl}" 311 | ) 312 | results_per_metric.append(result) 313 | else: 314 | result = metric(**kwargs) 315 | result = ufp.assign_columns(result, "metric", metric_name) 316 | results_per_metric.append(result) 317 | if isinstance(df, pd.DataFrame): 318 | df = pd.concat(results_per_metric).reset_index(drop=True) 319 | else: 320 | df = pl.concat(results_per_metric, how="diagonal") 321 | 322 | if cutoff_col in df.columns: 323 | id_cols = [id_col, cutoff_col, "metric"] 324 | else: 325 | id_cols = [id_col, "metric"] 326 | 327 | model_cols = [c for c in df.columns if c not in id_cols] 328 | df = df[id_cols + model_cols] 329 | if agg_fn is not None: 330 | group_cols = id_cols[1:] # exclude id_col 331 | df = ufp.group_by_agg( 332 | df, 333 | by=group_cols, 334 | aggs={m: agg_fn for m in model_cols}, 335 | maintain_order=True, 336 | ) 337 | return df 338 | -------------------------------------------------------------------------------- /utilsforecast/feature_engineering.py: -------------------------------------------------------------------------------- 1 | """Create exogenous regressors for your models""" 2 | 3 | __all__ = ['fourier', 'trend', 'time_features', 'future_exog_to_historic', 'pipeline'] 4 | 5 | 6 | from functools import partial 7 | from typing import Callable, List, Optional, Tuple, Union 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | import utilsforecast.processing as ufp 13 | 14 | from .compat import DataFrame, DFType, pl, pl_DataFrame, pl_Expr 15 | from .validation import validate_format, validate_freq 16 | 17 | _Features = Tuple[List[str], np.ndarray, np.ndarray] 18 | 19 | 20 | def _add_features( 21 | df: DFType, 22 | freq: Union[str, int], 23 | h: int, 24 | id_col: str, 25 | time_col: str, 26 | f: Callable[[np.ndarray, int], _Features], 27 | ) -> Tuple[DFType, DFType]: 28 | # validations 29 | if not isinstance(h, int) or h < 0: 30 | raise ValueError("`h` must be a non-negative integer") 31 | validate_format(df, id_col, time_col, None) 32 | validate_freq(df[time_col], freq) 33 | 34 | # decompose series 35 | id_counts = ufp.counts_by_id(df, id_col) 36 | uids = id_counts[id_col] 37 | sizes = id_counts["counts"].to_numpy() 38 | 39 | # compute values 40 | cols, vals, future_vals = f(sizes=sizes, h=h) # type: ignore 41 | 42 | # assign back to df 43 | sort_idxs = ufp.maybe_compute_sort_indices(df, id_col, time_col) 44 | times = df[time_col] 45 | if sort_idxs is not None: 46 | restore_idxs = np.empty_like(sort_idxs) 47 | restore_idxs[sort_idxs] = np.arange(sort_idxs.size) 48 | vals = vals[restore_idxs] 49 | times = ufp.take_rows(times, sort_idxs) 50 | last_times = ufp.take_rows(times, sizes.cumsum() - 1) 51 | df = ufp.copy_if_pandas(df, deep=False) 52 | transformed = ufp.assign_columns(df, cols, vals) 53 | 54 | if h == 0: 55 | return transformed, type(df)({}) 56 | 57 | # future vals 58 | future_df = ufp.make_future_dataframe( 59 | uids=uids, 60 | last_times=last_times, 61 | freq=freq, 62 | h=h, 63 | id_col=id_col, 64 | time_col=time_col, 65 | ) 66 | future_df = ufp.assign_columns(future_df, cols, future_vals) 67 | return transformed, future_df 68 | 69 | 70 | def _assign_slices( 71 | sizes: np.ndarray, 72 | feats: np.ndarray, 73 | h: int, 74 | ) -> Tuple[np.ndarray, np.ndarray]: 75 | max_samples, n_feats = feats.shape 
76 | vals = np.empty((sizes.sum(), n_feats), dtype=np.float32) 77 | future_vals = np.empty((h * sizes.size, n_feats)) 78 | start = 0 79 | for i, size in enumerate(sizes): 80 | vals[start : start + size, :] = feats[max_samples - size - h : max_samples - h] 81 | future_vals[i * h : (i + 1) * h] = feats[max_samples - h :] 82 | start += size 83 | return vals, future_vals 84 | 85 | 86 | def _fourier( 87 | sizes: np.ndarray, 88 | h: int, 89 | season_length: int, 90 | k: int, 91 | ) -> _Features: 92 | # taken from: https://github.com/tblume1992/TSUtilities/blob/main/TSUtilities/TSFeatures/fourier_seasonality.py 93 | x = 2 * np.pi * np.arange(1, k + 1) / season_length 94 | x = x.astype(np.float32) 95 | t = np.arange(1, sizes.max() + 1 + h, dtype=np.float32) 96 | x = x * t[:, None] 97 | terms = np.hstack([np.sin(x), np.cos(x)]) 98 | cols = [f"{op}{i+1}_{season_length}" for op in ("sin", "cos") for i in range(k)] 99 | vals, future_vals = _assign_slices(sizes=sizes, feats=terms, h=h) 100 | return cols, vals, future_vals 101 | 102 | 103 | def _trend(sizes: np.ndarray, h: int) -> _Features: 104 | t = np.arange(1, sizes.max() + 1 + h, dtype=np.float32).reshape(-1, 1) 105 | cols = ["trend"] 106 | vals, future_vals = _assign_slices(sizes=sizes, feats=t, h=h) 107 | return cols, vals, future_vals 108 | 109 | 110 | def fourier( 111 | df: DFType, 112 | freq: Union[str, int], 113 | season_length: int, 114 | k: int, 115 | h: int = 0, 116 | id_col: str = "unique_id", 117 | time_col: str = "ds", 118 | ) -> Tuple[DFType, DFType]: 119 | """Compute fourier seasonal terms for training and forecasting 120 | 121 | Args: 122 | df (pandas or polars DataFrame): Dataframe with ids, times and values 123 | for the exogenous regressors. 124 | freq (str or int): Frequency of the data. Must be a valid pandas or 125 | polars offset alias, or an integer. 126 | season_length (int): Number of observations per unit of time. 127 | Ex: 24 Hourly data. 128 | k (int): Maximum order of the fourier terms 129 | h (int, optional): Forecast horizon. Defaults to 0. 130 | id_col (str, optional): Column that identifies each serie. 131 | Defaults to 'unique_id'. 132 | time_col (str, optional): Column that identifies each timestep, its 133 | values can be timestamps or integers. Defaults to 'ds'. 134 | 135 | Returns: 136 | tuple[pandas or polars DataFrame, pandas or polars DataFrame]: A tuple 137 | containing the original DataFrame with the computed features and 138 | DataFrame with future values. 139 | """ 140 | f = partial(_fourier, season_length=season_length, k=k) 141 | return _add_features( 142 | df=df, 143 | freq=freq, 144 | h=h, 145 | id_col=id_col, 146 | time_col=time_col, 147 | f=f, 148 | ) 149 | 150 | 151 | def trend( 152 | df: DFType, 153 | freq: Union[str, int], 154 | h: int = 0, 155 | id_col: str = "unique_id", 156 | time_col: str = "ds", 157 | ) -> Tuple[DFType, DFType]: 158 | """Add a trend column with consecutive integers for training and forecasting 159 | 160 | Args: 161 | df (pandas or polars DataFrame): Dataframe with ids, times and values 162 | for the exogenous regressors. 163 | freq (str or int): Frequency of the data. Must be a valid pandas or 164 | polars offset alias, or an integer. 165 | h (int, optional): Forecast horizon. Defaults to 0. 166 | id_col (str, optional): Column that identifies each serie. 167 | Defaults to 'unique_id'. 168 | time_col (str, optional): Column that identifies each timestep, its 169 | values can be timestamps or integers. Defaults to 'ds'. 
170 | 171 | Returns: 172 | tuple[pandas or polars DataFrame, pandas or polars DataFrame]: A tuple 173 | containing the original DataFrame with the computed features and 174 | DataFrame with future values. 175 | """ 176 | return _add_features( 177 | df=df, 178 | freq=freq, 179 | h=h, 180 | id_col=id_col, 181 | time_col=time_col, 182 | f=_trend, 183 | ) 184 | 185 | 186 | def _compute_time_feature( 187 | times: Union[pd.Index, pl_Expr], 188 | feature: Union[str, Callable], 189 | ) -> Tuple[ 190 | Union[str, List[str]], 191 | Union[pd.DataFrame, pl_Expr, List[pl_Expr], pd.Index, np.ndarray], 192 | ]: 193 | if callable(feature): 194 | feat_vals = feature(times) 195 | if isinstance(feat_vals, pd.DataFrame): 196 | feat_name = feat_vals.columns.tolist() 197 | feat_vals = feat_vals.to_numpy() 198 | else: 199 | feat_name = feature.__name__ 200 | else: 201 | feat_name = feature 202 | if isinstance(times, pd.DatetimeIndex): 203 | if feature in ("week", "weekofyear"): 204 | times = times.isocalendar() 205 | feat_vals = getattr(times, feature).to_numpy() 206 | else: 207 | feat_vals = getattr(times.dt, feature)() 208 | return feat_name, feat_vals 209 | 210 | 211 | def _add_time_features( 212 | df: DFType, 213 | features: List[Union[str, Callable]], 214 | time_col: str = "ds", 215 | ) -> DFType: 216 | df = ufp.copy_if_pandas(df, deep=False) 217 | unique_times = df[time_col].unique() 218 | if isinstance(df, pd.DataFrame): 219 | times = pd.Index(unique_times) 220 | time2pos = {time: i for i, time in enumerate(times)} 221 | restore_idxs = df[time_col].map(time2pos).to_numpy() 222 | for feature in features: 223 | name, vals = _compute_time_feature(times, feature) 224 | df[name] = vals[restore_idxs] 225 | elif isinstance(df, pl_DataFrame): 226 | exprs = [] 227 | for feature in features: 228 | name, vals = _compute_time_feature(pl.col(time_col), feature) 229 | if isinstance(vals, list): 230 | exprs.extend(vals) 231 | else: 232 | assert isinstance(vals, pl_Expr) 233 | exprs.append(vals.alias(name)) 234 | feats = unique_times.to_frame().with_columns(*exprs) 235 | df = df.join(feats, on=time_col, how="left") 236 | return df 237 | 238 | 239 | def time_features( 240 | df: DFType, 241 | freq: Union[str, int], 242 | features: List[Union[str, Callable]], 243 | h: int = 0, 244 | id_col: str = "unique_id", 245 | time_col: str = "ds", 246 | ) -> Tuple[DFType, DFType]: 247 | """Compute timestamp-based features for training and forecasting 248 | 249 | Args: 250 | df (pandas or polars DataFrame): Dataframe with ids, times and values 251 | for the exogenous regressors. 252 | freq (str or int): Frequency of the data. Must be a valid pandas or 253 | polars offset alias, or an integer. 254 | features (list of str or callable): Features to compute. Can be string 255 | aliases of timestamp attributes or functions to apply to the times. 256 | h (int, optional): Forecast horizon. Defaults to 0. 257 | id_col (str, optional): Column that identifies each serie. 258 | Defaults to 'unique_id'. 259 | time_col (str, optional): Column that identifies each timestep, its 260 | values can be timestamps or integers. Defaults to 'ds'. 261 | 262 | Returns: 263 | tuple[pandas or polars DataFrame, pandas or polars DataFrame]: A tuple 264 | containing the original DataFrame with the computed features and 265 | DataFrame with future values. 
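    Example:
        A minimal sketch, assuming pandas and daily data; `month` and `day` are
        timestamp attributes available in both pandas and polars:

            import pandas as pd

            series = pd.DataFrame(
                {
                    "unique_id": [1, 1, 1],
                    "ds": pd.date_range("2000-01-01", periods=3, freq="D"),
                    "y": [1.0, 2.0, 3.0],
                }
            )
            transformed, future = time_features(
                series, freq="D", features=["month", "day"], h=2
            )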
266 | """ 267 | transformed = _add_time_features(df=df, features=features, time_col=time_col) 268 | if h == 0: 269 | return transformed, type(df)({}) 270 | times_by_id = ufp.group_by_agg(df, id_col, {time_col: "max"}, maintain_order=True) 271 | times_by_id = ufp.sort(times_by_id, id_col) 272 | future = ufp.make_future_dataframe( 273 | uids=times_by_id[id_col], 274 | last_times=times_by_id[time_col], 275 | freq=freq, 276 | h=h, 277 | id_col=id_col, 278 | time_col=time_col, 279 | ) 280 | future = _add_time_features(df=future, features=features, time_col=time_col) 281 | return transformed, future 282 | 283 | 284 | def future_exog_to_historic( 285 | df: DFType, 286 | freq: Union[str, int], 287 | features: List[str], 288 | h: int = 0, 289 | id_col: str = "unique_id", 290 | time_col: str = "ds", 291 | ) -> Tuple[DFType, DFType]: 292 | """Turn future exogenous features into historic by shifting them `h` steps. 293 | 294 | Args: 295 | df (pandas or polars DataFrame): Dataframe with ids, times and values 296 | for the exogenous regressors. 297 | freq (str or int): Frequency of the data. Must be a valid pandas or 298 | polars offset alias, or an integer. 299 | features (list of str): Features to be converted into historic. 300 | h (int, optional): Forecast horizon. Defaults to 0. 301 | id_col (str, optional): Column that identifies each serie. 302 | Defaults to 'unique_id'. 303 | time_col (str, optional): Column that identifies each timestep, its 304 | values can be timestamps or integers. Defaults to 'ds'. 305 | 306 | Returns: 307 | tuple[pandas or polars DataFrame, pandas or polars DataFrame]: A tuple 308 | containing the original DataFrame with the computed features and 309 | DataFrame with future values. 310 | """ 311 | if h == 0: 312 | return df, type(df)({}) 313 | new_feats = ufp.copy_if_pandas(df[[id_col, time_col, *features]]) 314 | new_feats = ufp.assign_columns( 315 | new_feats, 316 | time_col, 317 | ufp.offset_times(new_feats[time_col], freq=freq, n=h), 318 | ) 319 | df = ufp.drop_columns(df, features) 320 | df = ufp.join(df, new_feats, on=[id_col, time_col], how="left") 321 | times_by_id = ufp.group_by_agg(df, id_col, {time_col: "max"}, maintain_order=True) 322 | times_by_id = ufp.sort(times_by_id, id_col) 323 | future = ufp.make_future_dataframe( 324 | uids=times_by_id[id_col], 325 | last_times=times_by_id[time_col], 326 | freq=freq, 327 | h=h, 328 | id_col=id_col, 329 | time_col=time_col, 330 | ) 331 | future = ufp.join(future, new_feats, on=[id_col, time_col], how="left") 332 | return df, future 333 | 334 | 335 | def pipeline( 336 | df: DFType, 337 | features: List[Callable], 338 | freq: Union[str, int], 339 | h: int = 0, 340 | id_col: str = "unique_id", 341 | time_col: str = "ds", 342 | ) -> Tuple[DFType, DFType]: 343 | """Compute several features for training and forecasting 344 | 345 | Args: 346 | df (pandas or polars DataFrame): Dataframe with ids, times and values 347 | for the exogenous regressors. 348 | features (list of callable): List of features to compute. Must take only 349 | df, freq, h, id_col and time_col (other arguments must be fixed). 350 | freq (str or int): Frequency of the data. Must be a valid pandas or 351 | polars offset alias, or an integer. 352 | h (int, optional): Forecast horizon. Defaults to 0. 353 | id_col (str, optional): Column that identifies each serie. 354 | Defaults to 'unique_id'. 355 | time_col (str, optional): Column that identifies each timestep, its 356 | values can be timestamps or integers. Defaults to 'ds'. 
357 | 358 | Returns: 359 | tuple[pandas or polars DataFrame, pandas or polars DataFrame]: A tuple 360 | containing the original DataFrame with the computed features and 361 | DataFrame with future values. 362 | """ 363 | transformed: Optional[DataFrame] = None 364 | future: Optional[DataFrame] = None 365 | for f in features: 366 | f_transformed, f_future = f( 367 | df=df, freq=freq, h=h, id_col=id_col, time_col=time_col 368 | ) 369 | if transformed is None: 370 | transformed = f_transformed 371 | future = f_future 372 | else: 373 | feat_cols = [c for c in f_future.columns if c not in (id_col, time_col)] 374 | transformed = ufp.horizontal_concat([transformed, f_transformed[feat_cols]]) 375 | future = ufp.horizontal_concat([future, f_future[feat_cols]]) 376 | return transformed, future 377 | -------------------------------------------------------------------------------- /utilsforecast/plotting.py: -------------------------------------------------------------------------------- 1 | """Time series visualizations""" 2 | 3 | __all__ = ['plot_series'] 4 | 5 | 6 | import re 7 | from typing import TYPE_CHECKING, Dict, List, Optional, Union 8 | 9 | try: 10 | import matplotlib as mpl 11 | import matplotlib.colors as cm 12 | import matplotlib.pyplot as plt 13 | except ImportError: 14 | raise ImportError( 15 | "matplotlib is not installed. Please install it and try again.\n" 16 | "You can find detailed instructions at https://matplotlib.org/stable/users/installing/index.html" 17 | ) 18 | import numpy as np 19 | import pandas as pd 20 | 21 | if TYPE_CHECKING: 22 | import plotly 23 | from packaging.version import Version 24 | from packaging.version import parse as parse_version 25 | 26 | import utilsforecast.processing as ufp 27 | 28 | from .compat import DFType, pl, pl_Series 29 | from .validation import validate_format 30 | 31 | 32 | def _filter_series(df, id_col, time_col, uids, models=None, max_insample_length=None): 33 | out_cols = [id_col, time_col] 34 | if models is not None: 35 | models_pat = r"|".join(models).replace(r"(", r"\(").replace(r")", r"\)") 36 | interval_cols = [ 37 | c for c in df.columns if re.search(rf"^({models_pat})-(?:lo|hi)-\d+", c) 38 | ] 39 | out_cols.extend(models + interval_cols) 40 | mask = ufp.is_in(df[id_col], uids) 41 | df = ufp.filter_with_mask(df, mask) 42 | df = df[out_cols] 43 | df = ufp.sort(df, time_col) 44 | if max_insample_length is not None: 45 | df = ufp.group_by(df, id_col, maintain_order=True).tail(max_insample_length) 46 | return df 47 | 48 | 49 | def plot_series( 50 | df: Optional[DFType] = None, 51 | forecasts_df: Optional[DFType] = None, 52 | ids: Optional[List[str]] = None, 53 | plot_random: bool = True, 54 | max_ids: int = 8, 55 | models: Optional[List[str]] = None, 56 | level: Optional[List[float]] = None, 57 | max_insample_length: Optional[int] = None, 58 | plot_anomalies: bool = False, 59 | engine: str = "matplotlib", 60 | palette: Optional[str] = None, 61 | id_col: str = "unique_id", 62 | time_col: str = "ds", 63 | target_col: str = "y", 64 | seed: int = 0, 65 | resampler_kwargs: Optional[Dict] = None, 66 | ax: Optional[Union[plt.Axes, np.ndarray, "plotly.graph_objects.Figure"]] = None, 67 | ): 68 | """Plot forecasts and insample values. 69 | 70 | Args: 71 | df (pandas or polars DataFrame, optional): DataFrame with columns 72 | [`id_col`, `time_col`, `target_col`]. Defaults to None. 73 | forecasts_df (pandas or polars DataFrame, optional): DataFrame with 74 | columns [`id_col`, `time_col`] and models. Defaults to None. 
75 |         ids (list of str, optional): Time Series to plot. If None, time series 76 |             are selected randomly. Defaults to None. 77 |         plot_random (bool, optional): Select time series to plot randomly. 78 |             Defaults to True. 79 |         max_ids (int, optional): Maximum number of ids to plot. Defaults to 8. 80 |         models (list of str, optional): Models to plot. Defaults to None. 81 |         level (list of float, optional): Prediction intervals to plot. 82 |             Defaults to None. 83 |         max_insample_length (int, optional): Maximum number of train/insample 84 |             observations to be plotted. Defaults to None. 85 |         plot_anomalies (bool, optional): Plot anomalies for each prediction 86 |             interval. Defaults to False. 87 |         engine (str, optional): Library used to plot. 'plotly', 'plotly-resampler' 88 |             or 'matplotlib'. Defaults to 'matplotlib'. 89 |         palette (str, optional): Name of the matplotlib colormap to use for the 90 |             plots. If None, uses the current style. Defaults to None. 91 |         id_col (str, optional): Column that identifies each series. 92 |             Defaults to 'unique_id'. 93 |         time_col (str, optional): Column that identifies each timestep, its 94 |             values can be timestamps or integers. Defaults to 'ds'. 95 |         target_col (str, optional): Column that contains the target. 96 |             Defaults to 'y'. 97 |         seed (int, optional): Seed used for the random number generator. Only 98 |             used if plot_random is True. Defaults to 0. 99 |         resampler_kwargs (dict, optional): Keyword arguments to be passed to 100 |             the plotly-resampler constructor. For further customization of 101 |             `show_dash`, store the returned plotting object and pass the extra 102 |             arguments to its `show_dash` method. Defaults to None. 103 |         ax (matplotlib axes, array of matplotlib axes or plotly Figure, optional): 104 |             Object where plots will be added. Defaults to None. 105 | 106 |     Returns: 107 |         matplotlib or plotly figure: Plot's figure 108 |     """ 109 |     # checks 110 |     supported_engines = ["matplotlib", "plotly", "plotly-resampler"] 111 |     if engine not in supported_engines: 112 |         raise ValueError(f"engine must be one of {supported_engines}, got '{engine}'.") 113 |     if engine.startswith("plotly"): 114 |         try: 115 |             import plotly.graph_objects as go 116 |             from plotly.subplots import make_subplots 117 |         except ImportError: 118 |             raise ImportError( 119 |                 "plotly is not installed. Please install it and try again.\n" 120 |                 "You can find detailed instructions at https://github.com/plotly/plotly.py#installation" 121 |             ) 122 |     if plot_anomalies: 123 |         if level is None: 124 |             raise ValueError( 125 |                 "In order to plot anomalies you have to specify the `level` argument" 126 |             ) 127 |         elif forecasts_df is None or not any("lo" in c for c in forecasts_df.columns): 128 |             raise ValueError( 129 |                 "In order to plot anomalies you have to provide a `forecasts_df` with prediction intervals."
130 | ) 131 | if level is not None and not isinstance(level, list): 132 | raise ValueError( 133 | "Please use a list for the `level` argument " 134 | "If you only have one level, use `level=[your_level]`" 135 | ) 136 | elif level is None: 137 | level = [] 138 | if df is None and forecasts_df is None: 139 | raise ValueError("At least one of `df` and `forecasts_df` must be provided.") 140 | elif df is not None: 141 | validate_format(df, id_col, time_col, target_col) 142 | elif forecasts_df is not None: 143 | validate_format(forecasts_df, id_col, time_col, None) 144 | 145 | # models to plot 146 | if models is None: 147 | if forecasts_df is None: 148 | models = [] 149 | else: 150 | models = [ 151 | c 152 | for c in forecasts_df.columns 153 | if c not in [id_col, time_col, target_col] 154 | and not re.search(r"-(?:lo|hi)-\d+", c) 155 | ] 156 | 157 | # ids 158 | if ids is None: 159 | if df is not None: 160 | uids: Union[np.ndarray, pl_Series, List] = df[id_col].unique() 161 | else: 162 | assert forecasts_df is not None 163 | uids = forecasts_df[id_col].unique() 164 | else: 165 | uids = ids 166 | if ax is not None: 167 | if isinstance(ax, plt.Axes): 168 | ax = np.array([ax]) 169 | if isinstance(ax, np.ndarray) and isinstance(ax.flat[0], plt.Axes): 170 | gs = ax.flat[0].get_gridspec() 171 | n_rows, n_cols = gs.nrows, gs.ncols 172 | ax = ax.reshape(n_rows, n_cols) 173 | elif engine.startswith("plotly") and isinstance(ax, go.Figure): 174 | rows, cols = ax._get_subplot_rows_columns() 175 | # rows and cols are ranges 176 | n_rows = len(rows) 177 | n_cols = len(cols) 178 | else: 179 | raise ValueError(f"Cannot process `ax` of type: {type(ax).__name__}.") 180 | max_ids = n_rows * n_cols 181 | if len(uids) > max_ids and plot_random: 182 | rng = np.random.RandomState(seed) 183 | uids = rng.choice(uids, size=max_ids, replace=False) 184 | else: 185 | uids = uids[:max_ids] 186 | n_series = len(uids) 187 | if ax is None: 188 | if n_series == 1: 189 | n_cols = 1 190 | else: 191 | n_cols = 2 192 | quot, resid = divmod(n_series, n_cols) 193 | n_rows = quot + resid 194 | 195 | # filtering 196 | if df is not None: 197 | df = _filter_series( 198 | df=df, 199 | id_col=id_col, 200 | time_col=time_col, 201 | uids=uids, 202 | models=[target_col], 203 | max_insample_length=max_insample_length, 204 | ) 205 | if forecasts_df is not None: 206 | forecasts_df = _filter_series( 207 | df=forecasts_df, 208 | id_col=id_col, 209 | time_col=time_col, 210 | uids=uids, 211 | models=[target_col] + models if target_col in forecasts_df else models, 212 | max_insample_length=None, 213 | ) 214 | if df is None: 215 | df = forecasts_df 216 | else: 217 | if isinstance(df, pd.DataFrame): 218 | df = pd.concat([df, forecasts_df]) 219 | else: 220 | df = pl.concat([df, forecasts_df], how="align") 221 | 222 | xlabel = f"Time [{time_col}]" 223 | ylabel = f"Target [{target_col}]" 224 | if palette is not None: 225 | if parse_version(mpl.__version__) < Version("3.6"): 226 | cmap = plt.cm.get_cmap(palette, len(models) + 1) 227 | else: 228 | cmap = mpl.colormaps[palette].resampled(len(models) + 1) 229 | colors = [cm.to_hex(color) for color in cmap.colors] 230 | else: 231 | colors_stylesheet = plt.rcParams["axes.prop_cycle"].by_key()["color"] 232 | cmap = cm.LinearSegmentedColormap.from_list( 233 | "mymap", colors_stylesheet 234 | ).resampled(len(models) + 1) 235 | rgb_colors = cmap(np.linspace(0, 1, len(models) + 1)) 236 | colors = [cm.to_hex(color) for color in rgb_colors] 237 | 238 | # define plot grid 239 | if ax is None: 240 | postprocess = True 
241 | if engine.startswith("plotly"): 242 | fig = make_subplots( 243 | rows=n_rows, 244 | cols=n_cols, 245 | vertical_spacing=0.15, 246 | horizontal_spacing=0.07, 247 | x_title=xlabel, 248 | y_title=ylabel, 249 | subplot_titles=[f"{id_col}={uid}" for uid in uids], 250 | ) 251 | if engine == "plotly-resampler": 252 | try: 253 | from plotly_resampler import FigureResampler 254 | except ImportError: 255 | raise ImportError( 256 | "The 'plotly-resampler' package is required " 257 | "when `engine='plotly-resampler'`." 258 | ) 259 | resampler_kwargs = {} if resampler_kwargs is None else resampler_kwargs 260 | fig = FigureResampler(fig, **resampler_kwargs) 261 | else: 262 | fig, ax = plt.subplots( 263 | nrows=n_rows, 264 | ncols=n_cols, 265 | figsize=(16, 3.5 * n_rows), 266 | squeeze=False, 267 | constrained_layout=True, 268 | ) 269 | else: 270 | postprocess = False 271 | if engine.startswith("plotly"): 272 | fig = ax 273 | else: 274 | fig = plt.gcf() 275 | 276 | def _add_mpl_plot(axi, df, y_col, levels): 277 | axi.plot(df[time_col], df[y_col], label=y_col, color=color) 278 | if y_col == target_col: 279 | return 280 | times = df[time_col] 281 | for level in levels: 282 | lo = df[f"{y_col}-lo-{level}"] 283 | hi = df[f"{y_col}-hi-{level}"] 284 | min_alpha = 0.1 # fix alpha to avoid transparency issues 285 | max_alpha = 0.9 286 | alpha = max_alpha - (float(level) / 100) * (max_alpha - min_alpha) 287 | axi.fill_between( 288 | times, 289 | lo, 290 | hi, 291 | alpha=alpha, 292 | color=color, 293 | label=f"{y_col}_level_{level}", 294 | ) 295 | if plot_anomalies: 296 | anomalies = df[target_col].lt(lo) | df[target_col].gt(hi) 297 | anomalies = anomalies.to_numpy().astype("bool") 298 | if not anomalies.any(): 299 | continue 300 | axi.scatter( 301 | x=times.to_numpy()[anomalies], 302 | y=df[target_col].to_numpy()[anomalies], 303 | color=color, 304 | s=30, 305 | alpha=float(level) / 100, 306 | label=f"{y_col}_anomalies_level_{level}", 307 | linewidths=0.5, 308 | edgecolors="red", 309 | ) 310 | 311 | def _add_plotly_plot(fig, df, y_col, levels): 312 | show_legend = row == 0 and col == 0 313 | fig.add_trace( 314 | go.Scatter( 315 | x=df[time_col], 316 | y=df[y_col], 317 | mode="lines", 318 | name=y_col, 319 | legendgroup=y_col, 320 | line=dict(color=color, width=1), 321 | showlegend=show_legend, 322 | ), 323 | row=row + 1, 324 | col=col + 1, 325 | ) 326 | if y_col == target_col: 327 | return 328 | x = np.concatenate([df[time_col], df[time_col][::-1]]) 329 | for level in levels: 330 | name = f"{y_col}_level_{level}" 331 | lo = df[f"{y_col}-lo-{level}"] 332 | hi = df[f"{y_col}-hi-{level}"] 333 | min_alpha = 0.1 334 | max_alpha = 0.9 335 | alpha = max_alpha - (float(level) / 100) * (max_alpha - min_alpha) 336 | y = np.concatenate([hi, lo[::-1]]) 337 | fig.add_trace( 338 | go.Scatter( 339 | x=x, 340 | y=y, 341 | fill="toself", 342 | mode="lines", 343 | fillcolor=color, 344 | opacity=alpha, 345 | name=name, 346 | legendgroup=name, 347 | line=dict(color=color, width=1), 348 | showlegend=show_legend, 349 | ), 350 | row=row + 1, 351 | col=col + 1, 352 | ) 353 | if plot_anomalies: 354 | anomalies = df[target_col].lt(lo) | df[target_col].gt(hi) 355 | anomalies = anomalies.to_numpy().astype("bool") 356 | if not anomalies.any(): 357 | continue 358 | name = f"{y_col}_anomalies_level_{level}" 359 | fig.add_trace( 360 | go.Scatter( 361 | x=df[time_col].to_numpy()[anomalies], 362 | y=df[target_col].to_numpy()[anomalies], 363 | fillcolor=color, 364 | mode="markers", 365 | opacity=float(level) / 100, 366 | name=name, 367 | 
legendgroup=name, 368 | line=dict(color=color, width=0.7), 369 | marker=dict(size=4, line=dict(color="red", width=0.5)), 370 | showlegend=show_legend, 371 | ), 372 | row=row + 1, 373 | col=col + 1, 374 | ) 375 | 376 | for i, uid in enumerate(uids): 377 | mask = df[id_col].eq(uid) 378 | uid_df = ufp.filter_with_mask(df, mask) 379 | row, col = divmod(i, n_cols) 380 | for y_col, color in zip([target_col] + models, colors): 381 | if isinstance(ax, np.ndarray): 382 | _add_mpl_plot(ax[row, col], uid_df, y_col, level) 383 | else: 384 | _add_plotly_plot(fig, uid_df, y_col, level) 385 | title = f"{id_col}={uid}" 386 | if isinstance(ax, np.ndarray): 387 | ax[row, col].set_title(title) 388 | if col == 0: 389 | ax[row, col].set_ylabel(ylabel) 390 | if row == n_rows - 1: 391 | ax[row, col].set_xlabel(xlabel) 392 | ax[row, col].tick_params(axis="x", labelrotation=30) 393 | else: 394 | fig.update_annotations(selector={"text": str(i)}, text=title) 395 | 396 | if isinstance(ax, np.ndarray): 397 | handles, labels = ax[0, 0].get_legend_handles_labels() 398 | fig.legend( 399 | handles, 400 | labels, 401 | loc="upper left", 402 | bbox_to_anchor=(1.01, 0.97), 403 | ) 404 | plt.close(fig) 405 | if len(ax.flat) > n_series: 406 | for axi in ax.flat[n_series:]: 407 | axi.set_axis_off() 408 | else: 409 | fig.update_xaxes(matches=None, showticklabels=True, visible=True) 410 | fig.update_annotations(font_size=10) 411 | if postprocess: 412 | fig.update_layout(margin=dict(l=60, r=10, t=20, b=50)) 413 | fig.update_layout(template="plotly_white", font=dict(size=10)) 414 | fig.update_layout(autosize=True, height=200 * n_rows) 415 | return fig 416 | -------------------------------------------------------------------------------- /scripts/cli.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d9bbcfbf", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#|default_exp cli" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "b622bc9b", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "name": "stderr", 21 | "output_type": "stream", 22 | "text": [ 23 | "/Users/deven367/miniforge3/envs/nixtla/lib/python3.11/site-packages/nbdev/doclinks.py:20: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. 
Refrain from using this package or pin to Setuptools<81.\n", 24 | " import pkg_resources,importlib\n" 25 | ] 26 | } 27 | ], 28 | "source": [ 29 | "#|export\n", 30 | "from execnb.nbio import read_nb\n", 31 | "from nbdev.processors import NBProcessor\n", 32 | "from nbdev.export import ExportModuleProc, nb_export\n", 33 | "from nbdev.maker import ModuleMaker\n", 34 | "from fastcore.xtras import globtastic, Path\n", 35 | "from functools import partial\n", 36 | "from fastcore.script import call_parse\n", 37 | "from nbdev import nbdev_export" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "a98e24fc", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "nb_path = \"../nbs/evaluation.ipynb\"\n", 48 | "nb = read_nb(nb_path)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "6391ab8e", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "#|export\n", 59 | "tst_flags = 'datasets distributed matplotlib polars pyarrow scipy'.split()\n", 60 | "to_skip = [\n", 61 | " 'showdoc',\n", 62 | " 'load_ext',\n", 63 | " 'from nbdev'\n", 64 | "]\n", 65 | "\n", 66 | "\n", 67 | "def print_execs(cell):\n", 68 | " if 'exec' in cell.source: print(cell.source)\n", 69 | "\n", 70 | "def print_hide(cell):\n", 71 | " if 'hide' in cell.directives_: print(cell.source)\n", 72 | "\n", 73 | "def other_tests(cell):\n", 74 | " if len(cell.directives_) == 0:\n", 75 | " print(cell.source)\n", 76 | "\n", 77 | "def get_markdown(cell):\n", 78 | " if cell.cell_type == \"markdown\":\n", 79 | " print(cell.source)\n", 80 | "\n", 81 | "def extract_dir(cell, dir):\n", 82 | " if dir in cell.directives_:\n", 83 | " print(cell.source)\n", 84 | "\n", 85 | "def no_dir_and_dir(cell, dir):\n", 86 | " if len(cell.directives_) == 0:\n", 87 | " print(cell.source)\n", 88 | "\n", 89 | " if dir in cell.directives_:\n", 90 | " print(cell.source)\n", 91 | "\n", 92 | "def get_all_tests2(cell):\n", 93 | " if cell.cell_type == \"code\":\n", 94 | "\n", 95 | " if len(cell.directives_) == 0:\n", 96 | " print(cell.source)\n", 97 | "\n", 98 | "\n", 99 | " elif any(x in tst_flags + ['hide'] for x in cell.directives_):\n", 100 | " if not (x in cell.source for x in to_skip):\n", 101 | " print(cell.source)\n", 102 | "\n", 103 | "def get_all_tests(cell):\n", 104 | " if len(cell.directives_) == 0:\n", 105 | " print(cell.source)\n", 106 | "\n", 107 | " if any(x in tst_flags + [\"hide\"] for x in cell.directives_):\n", 108 | " print(cell.source)\n", 109 | "\n" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "50e5bdf7", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "tst_cell = nb.cells[0]" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "dc8942a7", 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "data": { 130 | "text/markdown": [ 131 | "```json\n", 132 | "{ 'cell_type': 'code',\n", 133 | " 'execution_count': None,\n", 134 | " 'id': '5aace6e9-4c24-4e66-b786-f468e32227a6',\n", 135 | " 'idx_': 0,\n", 136 | " 'metadata': {},\n", 137 | " 'outputs': [],\n", 138 | " 'source': '#| hide\\n%load_ext autoreload\\n%autoreload 2'}\n", 139 | "```" 140 | ], 141 | "text/plain": [ 142 | "{'cell_type': 'code',\n", 143 | " 'execution_count': None,\n", 144 | " 'id': '5aace6e9-4c24-4e66-b786-f468e32227a6',\n", 145 | " 'metadata': {},\n", 146 | " 'outputs': [],\n", 147 | " 'source': '#| hide\\n%load_ext autoreload\\n%autoreload 2',\n", 148 | " 'idx_': 0}" 149 | ] 150 | }, 
151 | "execution_count": null, 152 | "metadata": {}, 153 | "output_type": "execute_result" 154 | } 155 | ], 156 | "source": [ 157 | "tst_cell" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "id": "4ebddd94", 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "data": { 168 | "text/plain": [ 169 | "['datasets', 'distributed', 'matplotlib', 'polars', 'pyarrow', 'scipy', 'hide']" 170 | ] 171 | }, 172 | "execution_count": null, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | } 176 | ], 177 | "source": [ 178 | "tst_flags + ['hide']" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": null, 184 | "id": "502fa72c", 185 | "metadata": {}, 186 | "outputs": [], 187 | "source": [ 188 | "#|export\n", 189 | "mapper = {\n", 190 | " 'print_execs': print_execs,\n", 191 | " 'print_hide': print_hide,\n", 192 | " 'other_tests': other_tests,\n", 193 | " 'get_markdown': get_markdown,\n", 194 | " 'extract_dir': extract_dir,\n", 195 | " 'no_dir_and_dir': no_dir_and_dir,\n", 196 | " 'get_all_tests':get_all_tests\n", 197 | "}" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "1c189eef", 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "#|export\n", 208 | "@call_parse\n", 209 | "def print_dir_in_nb(nb_path:str,\n", 210 | " dir:str=None,\n", 211 | " dir_name:str=None,\n", 212 | " ):\n", 213 | " if dir_name not in mapper.keys():\n", 214 | " raise ValueError(f'Choose processor from the the following: {mapper.keys()}')\n", 215 | "\n", 216 | " if dir_name == 'extract_dir':\n", 217 | " processor = NBProcessor(nb_path, partial(extract_dir, dir=dir))\n", 218 | " processor.process()\n", 219 | " return\n", 220 | " elif dir_name == 'no_dir_and_dir':\n", 221 | " processor = NBProcessor(nb_path, partial(no_dir_and_dir, dir=dir))\n", 222 | " processor.process()\n", 223 | " return\n", 224 | "\n", 225 | " processor = NBProcessor(nb_path, mapper[dir_name])\n", 226 | " processor.process()\n" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "id": "fbb108ba", 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "name": "stdout", 237 | "output_type": "stream", 238 | "text": [ 239 | "%load_ext autoreload\n", 240 | "%autoreload 2\n", 241 | "from nbdev import show_doc\n", 242 | "show_doc(evaluate)\n", 243 | "from functools import partial\n", 244 | "\n", 245 | "import numpy as np\n", 246 | "import pandas as pd\n", 247 | "\n", 248 | "from utilsforecast.losses import *\n", 249 | "from utilsforecast.data import generate_series\n", 250 | "series = generate_series(10, n_models=2, level=[80, 95])\n", 251 | "series['unique_id'] = series['unique_id'].astype('int')\n", 252 | "models = ['model0', 'model1']\n", 253 | "metrics = [\n", 254 | " mae,\n", 255 | " mse,\n", 256 | " rmse,\n", 257 | " mape,\n", 258 | " smape,\n", 259 | " partial(mase, seasonality=7),\n", 260 | " quantile_loss,\n", 261 | " mqloss,\n", 262 | " coverage,\n", 263 | " calibration,\n", 264 | " scaled_crps,\n", 265 | "]\n", 266 | "evaluation = evaluate(\n", 267 | " series,\n", 268 | " metrics=metrics,\n", 269 | " models=models,\n", 270 | " train_df=series,\n", 271 | " level=[80, 95],\n", 272 | ")\n", 273 | "evaluation\n", 274 | "summary = evaluation.drop(columns='unique_id').groupby('metric').mean().reset_index()\n", 275 | "summary\n", 276 | "import polars.testing\n", 277 | "series_pl = generate_series(10, n_models=2, level=[80, 95], engine='polars')\n", 278 | "pl_evaluation = (\n", 279 | " 
evaluate(\n", 280 | " series_pl,\n", 281 | " metrics=metrics,\n", 282 | " train_df=series_pl,\n", 283 | " level=[80, 95],\n", 284 | " ).drop('unique_id')\n", 285 | ")\n", 286 | "pl_summary = ufp.group_by(pl_evaluation, 'metric').mean()\n", 287 | "pd.testing.assert_frame_equal(\n", 288 | " summary.sort_values('metric'),\n", 289 | " pl_summary.sort('metric').to_pandas(),\n", 290 | ")\n", 291 | "pl.testing.assert_frame_equal(\n", 292 | " evaluate(\n", 293 | " series_pl, metrics=metrics, train_df=series_pl, level=[80, 95], agg_fn='mean'\n", 294 | " ).sort('metric'),\n", 295 | " pl_summary.sort('metric'),\n", 296 | ")\n", 297 | "from datasetsforecast.evaluation import accuracy as ds_evaluate\n", 298 | "import datasetsforecast.losses as ds_losses\n", 299 | "def daily_mase(y, y_hat, y_train):\n", 300 | " return ds_losses.mase(y, y_hat, y_train, seasonality=7)\n", 301 | "\n", 302 | "level = [80, 95]\n", 303 | "for agg_fn in [None, 'mean']:\n", 304 | " uf_res = evaluate(\n", 305 | " series,\n", 306 | " metrics=metrics,\n", 307 | " models=models,\n", 308 | " train_df=series,\n", 309 | " level=level,\n", 310 | " agg_fn=agg_fn,\n", 311 | " )\n", 312 | " agg_by = None if agg_fn == 'mean' else ['unique_id']\n", 313 | " ds_res = ds_evaluate(\n", 314 | " series,\n", 315 | " metrics=[\n", 316 | " ds_losses.mae,\n", 317 | " ds_losses.mse,\n", 318 | " ds_losses.rmse,\n", 319 | " ds_losses.mape,\n", 320 | " daily_mase,\n", 321 | " ds_losses.smape,\n", 322 | " ds_losses.quantile_loss, \n", 323 | " ds_losses.mqloss,\n", 324 | " ds_losses.coverage, \n", 325 | " ds_losses.calibration,\n", 326 | " ds_losses.scaled_crps,\n", 327 | " ],\n", 328 | " level=level,\n", 329 | " Y_df=series,\n", 330 | " agg_by=agg_by,\n", 331 | " )\n", 332 | " ds_res['metric'] = ds_res['metric'].str.replace('-', '_')\n", 333 | " ds_res['metric'] = ds_res['metric'].str.replace('q_', 'q')\n", 334 | " ds_res['metric'] = ds_res['metric'].str.replace('lv_', 'level')\n", 335 | " ds_res['metric'] = ds_res['metric'].str.replace('daily_mase', 'mase')\n", 336 | " # utils doesn't multiply pct metrics by 100\n", 337 | " ds_res.loc[ds_res['metric'].str.startswith('coverage'), ['model0', 'model1']] /= 100\n", 338 | " ds_res.loc[ds_res['metric'].eq('mape'), ['model0', 'model1']] /= 100\n", 339 | " # we report smape between 0 and 1 instead of 0-200\n", 340 | " ds_res.loc[ds_res['metric'].eq('smape'), ['model0', 'model1']] /= 200\n", 341 | "\n", 342 | " ds_res = ds_res[uf_res.columns]\n", 343 | " if agg_fn is None:\n", 344 | " ds_res = ds_res.sort_values(['unique_id', 'metric'])\n", 345 | " uf_res = uf_res.sort_values(['unique_id', 'metric'])\n", 346 | " else:\n", 347 | " ds_res = ds_res.sort_values('metric')\n", 348 | " uf_res = uf_res.sort_values('metric')\n", 349 | " \n", 350 | " pd.testing.assert_frame_equal(\n", 351 | " uf_res.reset_index(drop=True),\n", 352 | " ds_res.reset_index(drop=True),\n", 353 | " )\n", 354 | "import sys\n", 355 | "from itertools import product\n", 356 | "\n", 357 | "import dask.dataframe as dd\n", 358 | "import fugue.api as fa\n", 359 | "from pyspark.sql import SparkSession\n", 360 | "if sys.version_info >= (3, 9):\n", 361 | " spark = SparkSession.builder.getOrCreate()\n", 362 | " spark.sparkContext.setLogLevel('FATAL')\n", 363 | " dask_df = dd.from_pandas(series, npartitions=2)\n", 364 | " spark_df = spark.createDataFrame(series).repartition(2)\n", 365 | " for distributed_df, use_train in product([dask_df, spark_df], [True, False]):\n", 366 | " distr_metrics = [rmse, mae]\n", 367 | " if use_train:\n", 368 | " 
distr_metrics.append(partial(mase, seasonality=7))\n", 369 | " local_train = series\n", 370 | " distr_train = distributed_df\n", 371 | " else:\n", 372 | " local_train = None\n", 373 | " distr_train = None\n", 374 | " local_res = evaluate(series, metrics=distr_metrics, level=level, train_df=local_train)\n", 375 | " distr_res = fa.as_fugue_df(\n", 376 | " evaluate(distributed_df, metrics=distr_metrics, level=level, train_df=distr_train)\n", 377 | " ).as_pandas()\n", 378 | " pd.testing.assert_frame_equal(\n", 379 | " local_res.sort_values(['unique_id', 'metric']).reset_index(drop=True),\n", 380 | " distr_res.sort_values(['unique_id', 'metric']).reset_index(drop=True),\n", 381 | " check_dtype=False,\n", 382 | " )\n" 383 | ] 384 | } 385 | ], 386 | "source": [ 387 | "NBProcessor(nb_path, procs=get_all_tests).process()" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "id": "e3be25c9", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "processor = NBProcessor(nb_path, partial(extract_dir, dir='distributed'))" 398 | ] 399 | }, 400 | { 401 | "cell_type": "code", 402 | "execution_count": null, 403 | "id": "df36376d", 404 | "metadata": {}, 405 | "outputs": [ 406 | { 407 | "name": "stdout", 408 | "output_type": "stream", 409 | "text": [ 410 | "import sys\n", 411 | "from itertools import product\n", 412 | "\n", 413 | "import dask.dataframe as dd\n", 414 | "import fugue.api as fa\n", 415 | "from pyspark.sql import SparkSession\n", 416 | "if sys.version_info >= (3, 9):\n", 417 | " spark = SparkSession.builder.getOrCreate()\n", 418 | " spark.sparkContext.setLogLevel('FATAL')\n", 419 | " dask_df = dd.from_pandas(series, npartitions=2)\n", 420 | " spark_df = spark.createDataFrame(series).repartition(2)\n", 421 | " for distributed_df, use_train in product([dask_df, spark_df], [True, False]):\n", 422 | " distr_metrics = [rmse, mae]\n", 423 | " if use_train:\n", 424 | " distr_metrics.append(partial(mase, seasonality=7))\n", 425 | " local_train = series\n", 426 | " distr_train = distributed_df\n", 427 | " else:\n", 428 | " local_train = None\n", 429 | " distr_train = None\n", 430 | " local_res = evaluate(series, metrics=distr_metrics, level=level, train_df=local_train)\n", 431 | " distr_res = fa.as_fugue_df(\n", 432 | " evaluate(distributed_df, metrics=distr_metrics, level=level, train_df=distr_train)\n", 433 | " ).as_pandas()\n", 434 | " pd.testing.assert_frame_equal(\n", 435 | " local_res.sort_values(['unique_id', 'metric']).reset_index(drop=True),\n", 436 | " distr_res.sort_values(['unique_id', 'metric']).reset_index(drop=True),\n", 437 | " check_dtype=False,\n", 438 | " )\n" 439 | ] 440 | } 441 | ], 442 | "source": [ 443 | "processor.process()" 444 | ] 445 | }, 446 | { 447 | "cell_type": "code", 448 | "execution_count": null, 449 | "id": "f61392f9", 450 | "metadata": {}, 451 | "outputs": [], 452 | "source": [ 453 | "nb_export(nb_path, '../tests', partial(extract_dir, dir='distributed'), 'foo2')" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "id": "82c974d8", 460 | "metadata": {}, 461 | "outputs": [], 462 | "source": [] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "id": "31426130", 468 | "metadata": {}, 469 | "outputs": [], 470 | "source": [] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "id": "92247f47", 476 | "metadata": {}, 477 | "outputs": [], 478 | "source": [ 479 | "nbs = globtastic('../nbs', file_glob='*.ipynb', 
recursive=False).map(Path).sorted()\n", 480 | "tst_flags = 'datasets distributed matplotlib polars pyarrow scipy'.split()" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "id": "242be5c9", 487 | "metadata": {}, 488 | "outputs": [ 489 | { 490 | "data": { 491 | "text/plain": [ 492 | "(#12) [Path('../nbs/compat.ipynb'),Path('../nbs/data.ipynb'),Path('../nbs/evaluation.ipynb'),Path('../nbs/feature_engineering.ipynb'),Path('../nbs/grouped_array.ipynb'),Path('../nbs/index.ipynb'),Path('../nbs/losses.ipynb'),Path('../nbs/plotting.ipynb'),Path('../nbs/preprocessing.ipynb'),Path('../nbs/processing.ipynb'),Path('../nbs/read.ipynb'),Path('../nbs/validation.ipynb')]" 493 | ] 494 | }, 495 | "execution_count": null, 496 | "metadata": {}, 497 | "output_type": "execute_result" 498 | } 499 | ], 500 | "source": [ 501 | "nbs" 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "id": "fc44bc54", 508 | "metadata": {}, 509 | "outputs": [], 510 | "source": [ 511 | "TEST_PATH = Path('../tests').resolve()" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": null, 517 | "id": "c01da20b", 518 | "metadata": {}, 519 | "outputs": [ 520 | { 521 | "data": { 522 | "text/plain": [ 523 | "Path('/Users/deven367/projects/public/utilsforecast/tests')" 524 | ] 525 | }, 526 | "execution_count": null, 527 | "metadata": {}, 528 | "output_type": "execute_result" 529 | } 530 | ], 531 | "source": [ 532 | "TEST_PATH" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": null, 538 | "id": "b725156d", 539 | "metadata": {}, 540 | "outputs": [], 541 | "source": [ 542 | "if not TEST_PATH.exists():\n", 543 | " TEST_PATH.mkdir()" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "id": "522bd94f", 550 | "metadata": {}, 551 | "outputs": [ 552 | { 553 | "data": { 554 | "text/plain": [ 555 | "Path('../nbs/grouped_array.ipynb')" 556 | ] 557 | }, 558 | "execution_count": null, 559 | "metadata": {}, 560 | "output_type": "execute_result" 561 | } 562 | ], 563 | "source": [ 564 | "nbs[4]" 565 | ] 566 | }, 567 | { 568 | "cell_type": "code", 569 | "execution_count": null, 570 | "id": "c889adde", 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "# for tst in tst_flags:\n", 575 | "# for nb in nbs:\n", 576 | "# nb_name = nb.stem\n", 577 | "# nb_export(nb, lib_path='../tests', procs=partial(extract_dir, dir=tst), name=f'{tst}_{nb_name}')" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": null, 583 | "id": "5c848413", 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "# for tst in tst_flags:\n", 588 | "# for nb in nbs:\n", 589 | "# nb_name = nb.stem\n", 590 | "# code = NBProcessor(nb, partial(extract_dir, dir=tst))\n", 591 | "# with open(TEST_PATH / f'{tst}.py', '+a') as f:\n", 592 | "# if code.process() is not None:\n", 593 | "# f.write(code.process())" 594 | ] 595 | } 596 | ], 597 | "metadata": { 598 | "kernelspec": { 599 | "display_name": "python3", 600 | "language": "python", 601 | "name": "python3" 602 | } 603 | }, 604 | "nbformat": 4, 605 | "nbformat_minor": 5 606 | } 607 | --------------------------------------------------------------------------------
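Usage note (editor's sketch, not part of the repository): the `pipeline` helper in utilsforecast/feature_engineering.py composes feature callables that may only receive `df`, `freq`, `h`, `id_col` and `time_col`, so any extra arguments have to be frozen first, e.g. with `functools.partial`. The sketch below illustrates that contract under assumptions: the public time-feature helper is assumed to be named `time_features` with a `features` list of datetime attribute names (only its body appears in the dump above), so treat that name and its parameters as unverified.

from functools import partial

from utilsforecast.data import generate_series
from utilsforecast.feature_engineering import (
    pipeline,
    time_features,  # assumed public name of the helper whose body appears above
)

# two daily synthetic series (generate_series is used the same way in scripts/cli.ipynb)
series = generate_series(2, freq="D")

# freeze every argument except df, freq, h, id_col and time_col
features = [
    partial(time_features, features=["month", "dayofweek"]),  # assumed signature
]
transformed, future = pipeline(series, features=features, freq="D", h=7)
# `transformed`: training frame with the new feature columns appended
# `future`: next 7 timestamps per id carrying the same feature columns

# future_exog_to_historic (defined above) plugs into the same pipeline once its
# `features` argument is fixed, e.g. partial(future_exog_to_historic, features=["price"]),
# provided the input frame actually contains a (hypothetical) "price" column.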
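Similarly, `plot_series` only needs frames with the id, time and target columns (plus model and interval columns in `forecasts_df`), so a quick smoke test can be built entirely from `generate_series`, which scripts/cli.ipynb already calls with `n_models` and `level`. This is an illustrative sketch rather than library documentation: the 'model0' column and its 80% bounds are synthetic, and the output filename is arbitrary.

from utilsforecast.data import generate_series
from utilsforecast.plotting import plot_series

# four synthetic series with one fake model and 80% prediction intervals
series = generate_series(4, n_models=1, level=[80])
train = series[["unique_id", "ds", "y"]]
fcst = series.drop(columns="y")  # unique_id, ds, model0, model0-lo-80, model0-hi-80

fig = plot_series(
    train,
    forecasts_df=fcst,
    level=[80],
    max_insample_length=50,
    engine="matplotlib",
)
# plot_series closes the pyplot figure but still returns it, so save it explicitly
fig.savefig("plot_series_example.png", bbox_inches="tight")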