├── .github
│   ├── ISSUE_TEMPLATE.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── dependabot.yml
│   └── workflows
│       ├── autoapprove.yml.disabled
│       ├── automerge.yml.disabled
│       ├── pypi_release.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yml
├── CHANGELOG.md
├── CODEOWNERS
├── LICENSE
├── README.md
├── build_tools
│   ├── changelog.py
│   └── run_examples.sh
├── codecov.yml
├── docs
│   ├── Makefile
│   ├── make.bat
│   ├── requirements.txt
│   └── source
│       ├── _static
│       │   ├── custom.css
│       │   ├── favicon.png
│       │   ├── favicon.svg
│       │   └── logo.svg
│       ├── _templates
│       │   ├── custom-base-template.rst
│       │   ├── custom-class-template.rst
│       │   └── custom-module-template.rst
│       ├── api.rst
│       ├── conf.py
│       ├── data.rst
│       ├── faq.rst
│       ├── getting-started.rst
│       ├── index.rst
│       ├── installation.rst
│       ├── metrics.rst
│       ├── models.rst
│       ├── tutorials.rst
│       └── tutorials
│           ├── ar.ipynb
│           ├── building.ipynb
│           ├── deepar.ipynb
│           ├── nhits.ipynb
│           └── stallion.ipynb
├── examples
│   ├── ar.py
│   ├── data
│   │   └── stallion.parquet
│   ├── nbeats.py
│   └── stallion.py
├── pyproject.toml
├── pytorch_forecasting
│   ├── __init__.py
│   ├── _registry
│   │   ├── __init__.py
│   │   └── _lookup.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_module.py
│   │   ├── encoders.py
│   │   ├── examples.py
│   │   ├── samplers.py
│   │   └── timeseries
│   │       ├── __init__.py
│   │       ├── _timeseries.py
│   │       └── _timeseries_v2.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── _mqf2_utils.py
│   │   ├── base_metrics.py
│   │   ├── distributions.py
│   │   ├── point.py
│   │   └── quantile.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base
│   │   │   ├── __init__.py
│   │   │   ├── _base_model.py
│   │   │   ├── _base_model_v2.py
│   │   │   └── _base_object.py
│   │   ├── base_model.py
│   │   ├── baseline.py
│   │   ├── deepar
│   │   │   ├── __init__.py
│   │   │   ├── _deepar.py
│   │   │   └── _deepar_metadata.py
│   │   ├── mlp
│   │   │   ├── __init__.py
│   │   │   ├── _decodermlp.py
│   │   │   └── submodules.py
│   │   ├── nbeats
│   │   │   ├── __init__.py
│   │   │   ├── _nbeats.py
│   │   │   ├── _nbeats_metadata.py
│   │   │   └── sub_modules.py
│   │   ├── nhits
│   │   │   ├── __init__.py
│   │   │   ├── _nhits.py
│   │   │   └── sub_modules.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── embeddings.py
│   │   │   └── rnn.py
│   │   ├── rnn
│   │   │   ├── __init__.py
│   │   │   └── _rnn.py
│   │   ├── temporal_fusion_transformer
│   │   │   ├── __init__.py
│   │   │   ├── _tft.py
│   │   │   ├── _tft_v2.py
│   │   │   ├── sub_modules.py
│   │   │   └── tuning.py
│   │   └── tide
│   │       ├── __init__.py
│   │       ├── _tide.py
│   │       ├── _tide_metadata.py
│   │       └── sub_modules.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── _config.py
│   │   ├── _conftest.py
│   │   ├── _data_scenarios.py
│   │   └── test_all_estimators.py
│   └── utils
│       ├── __init__.py
│       ├── _coerce.py
│       ├── _dependencies
│       │   ├── __init__.py
│       │   ├── _dependencies.py
│       │   ├── _safe_import.py
│       │   └── tests
│       │       ├── __init__.py
│       │       └── test_safe_import.py
│       ├── _maint
│       │   ├── __init__.py
│       │   └── _show_versions.py
│       └── _utils.py
└── tests
    ├── conftest.py
    ├── test_data
    │   ├── test_d1.py
    │   ├── test_data_module.py
    │   ├── test_encoders.py
    │   ├── test_samplers.py
    │   └── test_timeseries.py
    ├── test_metrics.py
    ├── test_models
    │   ├── _test_tft_v2.py
    │   ├── conftest.py
    │   ├── test_baseline.py
    │   ├── test_deepar.py
    │   ├── test_mlp.py
    │   ├── test_nbeats.py
    │   ├── test_nhits.py
    │   ├── test_nn
    │   │   ├── test_embeddings.py
    │   │   └── test_rnn.py
    │   ├── test_rnn_model.py
    │   ├── test_temporal_fusion_transformer.py
    │   └── test_tide.py
    └── test_utils
        ├── test_autocorrelation.py
        ├── test_safe_import.py
        └── test_show_versions.py
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | - PyTorch-Forecasting version:
2 | - PyTorch version:
3 | - Python version:
4 | - Operating System:
5 |
6 | ### Expected behavior
7 |
8 | I executed code ... in order to ... and expected to get result ...
9 |
10 | ### Actual behavior
11 |
12 | However, the result was .... I think it has to do with ... because of ...
13 |
14 | ### Code to reproduce the problem
15 |
16 | ```
17 |
18 | ```
19 |
20 | Paste the command(s) you ran and the output. Including a link to a colab notebook will speed up issue resolution.
21 | If there was a crash, please include the traceback here.
22 | The code used to initialize the TimeSeriesDataSet and model should also be included.
23 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ### Description
2 |
3 | This PR ...
4 |
5 | ### Checklist
6 |
7 | - [ ] Linked issues (if existing)
8 | - [ ] Amended changelog for large changes (and added myself there as contributor)
9 | - [ ] Added/modified tests
10 | - [ ] Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with `pre-commit install`.
11 | To run hooks independent of commit, execute `pre-commit run --all-files`
12 |
13 | Make sure to have fun coding!
14 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "daily"
7 | commit-message:
8 | prefix: "[MNT] [Dependabot]"
9 | include: "scope"
10 | labels:
11 | - "maintenance"
12 | - package-ecosystem: "github-actions"
13 | directory: "/"
14 | schedule:
15 | interval: "daily"
16 | commit-message:
17 | prefix: "[MNT] [Dependabot]"
18 | include: "scope"
19 | labels:
20 | - "maintenance"
21 |
--------------------------------------------------------------------------------
/.github/workflows/autoapprove.yml.disabled:
--------------------------------------------------------------------------------
1 | name: Dependabot auto-approve
2 | on: pull_request
3 |
4 | permissions:
5 | pull-requests: write
6 |
7 | jobs:
8 | dependabot:
9 | runs-on: ubuntu-latest
10 | if: ${{ github.actor == 'dependabot[bot]' }}
11 | steps:
12 | - name: Dependabot metadata
13 | id: metadata
14 | uses: dependabot/fetch-metadata@v1.1.1
15 | with:
16 | github-token: "${{ secrets.GITHUB_TOKEN }}"
17 | - name: Approve a PR
18 | run: gh pr review --approve "$PR_URL"
19 | env:
20 | PR_URL: ${{github.event.pull_request.html_url}}
21 | GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
22 |
--------------------------------------------------------------------------------
/.github/workflows/automerge.yml.disabled:
--------------------------------------------------------------------------------
1 | name: Dependabot auto-merge
2 | on: pull_request
3 |
4 | permissions:
5 | pull-requests: write
6 | contents: write
7 |
8 | jobs:
9 | dependabot:
10 | runs-on: ubuntu-latest
11 | if: ${{ github.actor == 'dependabot[bot]' }}
12 | steps:
13 | - name: Dependabot metadata
14 | id: metadata
15 | uses: dependabot/fetch-metadata@v1.1.1
16 | with:
17 | github-token: "${{ secrets.GITHUB_TOKEN }}"
18 | - name: Enable auto-merge for Dependabot PRs
19 | run: gh pr merge --auto --merge "$PR_URL"
20 | env:
21 | PR_URL: ${{github.event.pull_request.html_url}}
22 | GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
23 |
--------------------------------------------------------------------------------
/.github/workflows/pypi_release.yml:
--------------------------------------------------------------------------------
1 | name: PyPI Release
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | jobs:
8 | build_wheels:
9 | name: Build wheels
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v4
14 |
15 | - uses: actions/setup-python@v5
16 | with:
17 | python-version: '3.11'
18 |
19 | - name: Build wheel
20 | run: |
21 | python -m pip install build
22 | python -m build --wheel --sdist --outdir wheelhouse
23 |
24 | - name: Store wheels
25 | uses: actions/upload-artifact@v4
26 | with:
27 | name: wheels
28 | path: wheelhouse/*
29 |
30 | pytest-nosoftdeps:
31 | name: no-softdeps
32 | runs-on: ${{ matrix.os }}
33 | needs: [build_wheels]
34 | strategy:
35 | fail-fast: false
36 | matrix:
37 | os: [ubuntu-latest, macos-latest, windows-latest]
38 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
39 |
40 | steps:
41 | - uses: actions/checkout@v4
42 |
43 | - name: Set up Python ${{ matrix.python-version }}
44 | uses: actions/setup-python@v5
45 | with:
46 | python-version: ${{ matrix.python-version }}
47 |
48 | - name: Setup macOS
49 | if: runner.os == 'macOS'
50 | run: |
51 | brew install libomp # https://github.com/pytorch/pytorch/issues/20030
52 |
53 | - name: Get full Python version
54 | id: full-python-version
55 | shell: bash
56 | run: echo version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))") >> $GITHUB_OUTPUT
57 |
58 | - name: Install dependencies
59 | shell: bash
60 | run: |
61 | pip install ".[dev,github-actions]"
62 |
63 | - name: Show dependencies
64 | run: python -m pip list
65 |
66 | - name: Run pytest
67 | shell: bash
68 | run: python -m pytest tests
69 |
70 | upload_wheels:
71 | name: Upload wheels to PyPI
72 | runs-on: ubuntu-latest
73 | needs: [pytest-nosoftdeps]
74 |
75 | permissions:
76 | id-token: write
77 |
78 | steps:
79 | - uses: actions/download-artifact@v4
80 | with:
81 | name: wheels
82 | path: wheelhouse
83 |
84 | - name: Publish package to PyPI
85 | uses: pypa/gh-action-pypi-publish@release/v1
86 | with:
87 | packages-dir: wheelhouse/
88 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Test
5 |
6 | on:
7 | push:
8 | branches: [main]
9 | pull_request:
10 | branches: [main]
11 |
12 | concurrency:
13 | group: ${{ github.workflow }}-${{ github.ref }}
14 | cancel-in-progress: true
15 |
16 | jobs:
17 | code-quality:
18 | name: code-quality
19 | runs-on: ubuntu-latest
20 | steps:
21 | - name: repository checkout step
22 | uses: actions/checkout@v4
23 |
24 | - name: python environment step
25 | uses: actions/setup-python@v5
26 | with:
27 | python-version: "3.11"
28 |
29 | - name: install pre-commit
30 | run: python3 -m pip install pre-commit
31 |
32 | - name: Checkout code
33 | uses: actions/checkout@v4
34 | with:
35 | fetch-depth: 0
36 |
37 | - name: Get changed files
38 | id: changed-files
39 | run: |
40 | CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | tr '\n' ' ')
41 | echo "CHANGED_FILES=${CHANGED_FILES}" >> $GITHUB_ENV
42 |
43 | - name: Print changed files
44 | run: |
45 | echo "Changed files:" && echo "$CHANGED_FILES" | tr ' ' '\n'
46 |
47 | - name: Run pre-commit on changed files
48 | run: |
49 | if [ -n "$CHANGED_FILES" ]; then
50 | pre-commit run --color always --files $CHANGED_FILES --show-diff-on-failure
51 | else
52 | echo "No changed files to check."
53 | fi
54 |
55 | run-notebook-tutorials:
56 | name: Run notebook tutorials
57 | needs: code-quality
58 | runs-on: ubuntu-latest
59 | steps:
60 | - uses: actions/checkout@v4
61 | - name: Set up Python
62 | uses: actions/setup-python@v5
63 | with:
64 | python-version: 3.9
65 | - name: Install dependencies
66 | run: |
67 | python -m pip install --upgrade pip
68 | python -m pip install ".[dev,all_extras,github-actions]"
69 |
70 | - name: Show dependencies
71 | run: python -m pip list
72 |
73 | - name: Run example notebooks
74 | run: build_tools/run_examples.sh
75 | shell: bash
76 |
77 | pytest-nosoftdeps:
78 | name: no-softdeps
79 | needs: code-quality
80 | runs-on: ${{ matrix.os }}
81 | strategy:
82 | fail-fast: false
83 | matrix:
84 | os: [ubuntu-latest, macos-latest, windows-latest]
85 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
86 |
87 | steps:
88 | - uses: actions/checkout@v4
89 |
90 | - name: Set up Python ${{ matrix.python-version }}
91 | uses: actions/setup-python@v5
92 | with:
93 | python-version: ${{ matrix.python-version }}
94 |
95 | - name: Setup macOS
96 | if: runner.os == 'macOS'
97 | run: |
98 | brew install libomp # https://github.com/pytorch/pytorch/issues/20030
99 |
100 | - name: Get full Python version
101 | id: full-python-version
102 | shell: bash
103 | run: echo version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))") >> $GITHUB_OUTPUT
104 |
105 | - name: Install dependencies
106 | shell: bash
107 | run: |
108 | pip install ".[dev,github-actions]"
109 |
110 | - name: Show dependencies
111 | run: python -m pip list
112 |
113 | - name: Run pytest
114 | shell: bash
115 | run: python -m pytest
116 |
117 | pytest:
118 | name: Run pytest
119 | needs: pytest-nosoftdeps
120 | runs-on: ${{ matrix.os }}
121 | strategy:
122 | fail-fast: false
123 | matrix:
124 | os: [ubuntu-latest, macos-latest, windows-latest]
125 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
126 |
127 | steps:
128 | - uses: actions/checkout@v4
129 |
130 | - name: Set up Python ${{ matrix.python-version }}
131 | uses: actions/setup-python@v5
132 | with:
133 | python-version: ${{ matrix.python-version }}
134 |
135 | - name: Setup macOS
136 | if: runner.os == 'macOS'
137 | run: |
138 | brew install libomp # https://github.com/pytorch/pytorch/issues/20030
139 |
140 | - name: Get full Python version
141 | id: full-python-version
142 | shell: bash
143 | run: echo version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))") >> $GITHUB_OUTPUT
144 |
145 | - name: Install dependencies
146 | shell: bash
147 | run: |
148 | pip install ".[dev,all_extras,github-actions]"
149 |
150 | - name: Show dependencies
151 | run: python -m pip list
152 |
153 | - name: Run pytest
154 | shell: bash
155 | run: python -m pytest
156 |
157 | - name: Statistics
158 | run: |
159 | pip install coverage
160 | coverage report
161 | coverage xml
162 |
163 | - name: Upload coverage to Codecov
164 | uses: codecov/codecov-action@v5
165 | if: always()
166 | continue-on-error: true
167 | with:
168 | token: ${{ secrets.CODECOV_TOKEN }}
169 | file: coverage.xml
170 | flags: cpu,pytest
171 | name: CPU-coverage
172 | fail_ci_if_error: false
173 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 | docs/source/api/
74 | docs/source/CHANGELOG.md
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | # pycharm
134 | .idea
135 |
136 | # vscode
137 | .vscode
138 |
139 | # logs
140 | lightning_logs
141 | .history
142 |
143 | # checkpoints
144 | *.ckpt
145 | *.pkl
146 | .DS_Store
147 |
148 | # data
149 | pytorch_forecasting/data/*.parquet
150 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v4.6.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: end-of-file-fixer
9 | - id: check-yaml
10 | - id: check-ast
11 | - repo: https://github.com/astral-sh/ruff-pre-commit
12 | rev: v0.6.9
13 | hooks:
14 | - id: ruff
15 | args: [--fix]
16 | - id: ruff-format
17 | - repo: https://github.com/nbQA-dev/nbQA
18 | rev: 1.8.7
19 | hooks:
20 | - id: nbqa-black
21 | - id: nbqa-ruff
22 | - id: nbqa-check-ast
23 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | # Build documentation in the docs/ directory with Sphinx
9 | # reference: https://docs.readthedocs.io/en/stable/config-file/v2.html#sphinx
10 | sphinx:
11 | configuration: docs/source/conf.py
12 | # fail_on_warning: true
13 |
14 | # Build documentation with MkDocs
15 | #mkdocs:
16 | # configuration: mkdocs.yml
17 |
18 | # Optionally build your docs in additional formats such as PDF and ePub
19 | formats:
20 | - htmlzip
21 |
22 | build:
23 | os: ubuntu-22.04
24 | tools:
25 | python: "3.12"
26 |
27 | python:
28 | install:
29 | - method: pip
30 | path: .
31 | extra_requirements:
32 | - docs
33 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # The file specifies framework level core developers for automated review requests
2 |
3 | * @benheid @fkiraly @fnhirwa @geetu040 @jdb78 @pranavvp16 @XinyuWuu @yarnabrina
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | THE MIT License
2 |
3 | Copyright (c) 2020 - present, the pytorch-forecasting developers
4 | Copyright (c) 2020 Jan Beitner
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in
14 | all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 | THE SOFTWARE.
23 |
--------------------------------------------------------------------------------
/build_tools/changelog.py:
--------------------------------------------------------------------------------
1 | """RestructuredText changelog generator."""
2 |
3 | from collections import defaultdict
4 | import os
5 |
6 | HEADERS = {
7 | "Accept": "application/vnd.github.v3+json",
8 | }
9 |
10 | if os.getenv("GITHUB_TOKEN") is not None:
11 | HEADERS["Authorization"] = f"token {os.getenv('GITHUB_TOKEN')}"
12 |
13 | OWNER = "jdb78"
14 | REPO = "pytorch-forecasting"
15 | GITHUB_REPOS = "https://api.github.com/repos"
16 |
17 |
18 | def fetch_merged_pull_requests(page: int = 1) -> list[dict]:
19 | """Fetch a page of merged pull requests.
20 |
21 | Parameters
22 | ----------
23 | page : int, optional
24 | Page number to fetch, by default 1.
25 | Returns all merged pull request from the ``page``-th page of closed PRs,
26 | where pages are in descending order of last update.
27 |
28 | Returns
29 | -------
30 | list
31 | List of merged pull requests from the ``page``-th page of closed PRs.
32 | Elements of list are dictionaries with PR details, as obtained
33 | from the GitHub API via ``httpx.get``, from the ``pulls`` endpoint.
34 | """
35 | import httpx
36 |
37 | params = {
38 | "base": "main",
39 | "state": "closed",
40 | "page": page,
41 | "per_page": 50,
42 | "sort": "updated",
43 | "direction": "desc",
44 | }
45 | r = httpx.get(
46 | f"{GITHUB_REPOS}/{OWNER}/{REPO}/pulls",
47 | headers=HEADERS,
48 | params=params,
49 | )
50 | return [pr for pr in r.json() if pr["merged_at"]]
51 |
52 |
53 | def fetch_latest_release(): # noqa: D103
54 | """Fetch the latest release from the GitHub API.
55 |
56 | Returns
57 | -------
58 | dict
59 | Dictionary with details of the latest release.
60 | Dictionary is as obtained from the GitHub API via ``httpx.get``,
61 | for ``releases/latest`` endpoint.
62 | """
63 | import httpx
64 |
65 | response = httpx.get(
66 | f"{GITHUB_REPOS}/{OWNER}/{REPO}/releases/latest", headers=HEADERS
67 | )
68 |
69 | if response.status_code == 200:
70 | return response.json()
71 | else:
72 | raise ValueError(response.text, response.status_code)
73 |
74 |
75 | def fetch_pull_requests_since_last_release() -> list[dict]:
76 | """Fetch all pull requests merged since last release.
77 |
78 | Returns
79 | -------
80 | list
81 | List of pull requests merged since the latest release.
82 | Elements of list are dictionaries with PR details, as obtained
83 | from the GitHub API via ``httpx.get``, through ``fetch_merged_pull_requests``.
84 | """
85 | from dateutil import parser
86 |
87 | release = fetch_latest_release()
88 | published_at = parser.parse(release["published_at"])
89 | print(f"Latest release {release['tag_name']} was published at {published_at}")
90 |
91 | is_exhausted = False
92 | page = 1
93 | all_pulls = []
94 | while not is_exhausted:
95 | pulls = fetch_merged_pull_requests(page=page)
96 | all_pulls.extend(
97 | [p for p in pulls if parser.parse(p["merged_at"]) > published_at]
98 | )
99 | is_exhausted = any(parser.parse(p["updated_at"]) < published_at for p in pulls)
100 | page += 1
101 | return all_pulls
102 |
103 |
104 | def github_compare_tags(tag_left: str, tag_right: str = "HEAD"):
105 | """Compare commit between two tags."""
106 | import httpx
107 |
108 | response = httpx.get(
109 | f"{GITHUB_REPOS}/{OWNER}/{REPO}/compare/{tag_left}...{tag_right}"
110 | )
111 | if response.status_code == 200:
112 | return response.json()
113 | else:
114 | raise ValueError(response.text, response.status_code)
115 |
116 |
117 | def render_contributors(prs: list, fmt: str = "rst"):
118 | """Find unique authors and print a list in given format."""
119 | authors = sorted({pr["user"]["login"] for pr in prs}, key=lambda x: x.lower())
120 |
121 | header = "Contributors"
122 | if fmt == "github":
123 | print(f"### {header}")
124 | print(", ".join(f"@{user}" for user in authors))
125 | elif fmt == "rst":
126 | print(header)
127 | print("~" * len(header), end="\n\n")
128 | print(",\n".join(f":user:`{user}`" for user in authors))
129 |
130 |
131 | def assign_prs(prs, categs: list[dict[str, list[str]]]):
132 | """Assign PR to categories based on labels."""
133 | assigned = defaultdict(list)
134 |
135 | for i, pr in enumerate(prs):
136 | for cat in categs:
137 | pr_labels = [label["name"] for label in pr["labels"]]
138 | if not set(cat["labels"]).isdisjoint(set(pr_labels)):
139 | assigned[cat["title"]].append(i)
140 |
141 | # if any(l.startswith("module") for l in pr_labels):
142 | # print(i, pr_labels)
143 |
144 | assigned["Other"] = list(
145 | set(range(len(prs))) - {i for _, j in assigned.items() for i in j}
146 | )
147 |
148 | return assigned
149 |
150 |
151 | def render_row(pr):
152 | """Render a single row with PR in restructuredText format."""
153 | print(
154 | "*",
155 | pr["title"].replace("`", "``"),
156 | f"(:pr:`{pr['number']}`)",
157 | f":user:`{pr['user']['login']}`",
158 | )
159 |
160 |
161 | def render_changelog(prs, assigned):
162 | # sourcery skip: use-named-expression
163 | """Render changelog."""
164 | from dateutil import parser
165 |
166 | for title, _ in assigned.items():
167 | pr_group = [prs[i] for i in assigned[title]]
168 | if pr_group:
169 | print(f"\n{title}")
170 | print("~" * len(title), end="\n\n")
171 |
172 | for pr in sorted(pr_group, key=lambda x: parser.parse(x["merged_at"])):
173 | render_row(pr)
174 |
175 |
176 | if __name__ == "__main__":
177 | categories = [
178 | {"title": "Enhancements", "labels": ["feature", "enhancement"]},
179 | {"title": "Fixes", "labels": ["bug", "fix", "bugfix"]},
180 | {"title": "Maintenance", "labels": ["maintenance", "chore"]},
181 | {"title": "Refactored", "labels": ["refactor"]},
182 | {"title": "Documentation", "labels": ["documentation"]},
183 | ]
184 |
185 | pulls = fetch_pull_requests_since_last_release()
186 | print(f"Found {len(pulls)} merged PRs since last release")
187 | assigned = assign_prs(pulls, categories)
188 | render_changelog(pulls, assigned)
189 | print()
190 | render_contributors(pulls)
191 |
192 | release = fetch_latest_release()
193 | diff = github_compare_tags(release["tag_name"])
194 | if diff["total_commits"] != len(pulls):
195 | raise ValueError(
196 | "Something went wrong and not all PR were fetched. "
197 | f'There are {len(pulls)} PRs but {diff["total_commits"]} in the diff. '
198 | "Please verify that all PRs are included in the changelog."
199 | )
200 |
--------------------------------------------------------------------------------
/build_tools/run_examples.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Script to run all example notebooks.
4 | # copy-paste from sktime's run_examples.sh
5 | set -euxo pipefail
6 |
7 | CMD="jupyter nbconvert --to notebook --inplace --execute --ExecutePreprocessor.timeout=1200"
8 |
9 | for notebook in docs/source/tutorials/*.ipynb; do
10 | echo "Running: $notebook"
11 | $CMD "$notebook"
12 | done
13 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | precision: 2
3 | round: down
4 | range: "70...100"
5 | status:
6 | project:
7 | default:
8 | threshold: 0.2%
9 | patch:
10 | default:
11 | informational: true
12 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx >3.2
2 | nbsphinx
3 | pandoc
4 | docutils
5 | pydata-sphinx-theme
6 | lightning >=2.0.0
7 | cloudpickle
8 | torch >=2.0,!=2.0.1
9 | optuna >=3.1.0
10 | optuna-integration
11 | scipy
12 | pandas >=1.3
13 | scikit-learn >1.2
14 | matplotlib
15 | statsmodels
16 | ipython
17 | nbconvert >=6.3.0
18 | recommonmark >=0.7.1
19 | pytorch-optimizer >=2.5.1
20 | fastapi >0.80
21 | cpflows
22 |
--------------------------------------------------------------------------------
/docs/source/_static/custom.css:
--------------------------------------------------------------------------------
1 | .container-xl {
2 | max-width: 4000px;
3 | }
4 |
5 | .bd-content {
6 | flex-grow: 1;
7 | max-width: 100%;
8 | }
9 | .bd-page-width {
10 | max-width: 100rem;
11 | }
12 | .bd-main .bd-content .bd-article-container {
13 | max-width: 100%;
14 | }
15 |
16 | a.reference.internal.nav-link {
17 | color: #727991 !important;
18 | }
19 |
20 | html[data-theme="light"] {
21 | --pst-color-primary: #ee4c2c;
22 | }
23 | a.nav-link
24 | {
25 | color: #647db6 !important;
26 | }
27 |
28 | a.nav-link[href="https://github.com/sktime/pytorch-forecasting"]
29 | {
30 | color: #ee4c2c !important;
31 | }
32 |
33 | code {
34 | color: #d14 !important;
35 | }
36 |
37 | pre code,
38 | a > code,
39 | .reference {
40 | color: #647db6 !important;
41 | }
42 |
43 | dt:target,
44 | span.highlighted {
45 | background-color: #fff !important;
46 | }
47 |
48 | /* code highlighting */
49 | .highlight > pre > .mi {
50 | color: #d14 !important;
51 | }
52 | .highlight > pre > .mf {
53 | color: #d14 !important;
54 | }
55 |
56 | .highlight > pre > .s2 {
57 | color: #647db6 !important;
58 | }
59 |
--------------------------------------------------------------------------------
/docs/source/_static/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sktime/pytorch-forecasting/4384140418bfe8e8c64cc6ce5fc19372508bfccd/docs/source/_static/favicon.png
--------------------------------------------------------------------------------
/docs/source/_static/favicon.svg:
--------------------------------------------------------------------------------
1 |
2 |
107 |
--------------------------------------------------------------------------------
/docs/source/_static/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
125 |
--------------------------------------------------------------------------------
/docs/source/_templates/custom-base-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname.split(".")[-1] | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. auto{{ objtype }}:: {{ objname }}
6 |
--------------------------------------------------------------------------------
/docs/source/_templates/custom-class-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname.split(".")[-1] | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 |
6 | .. autoclass:: {{ objname }}
7 | :members:
8 | :show-inheritance:
9 | :exclude-members: __init__
10 | {% set allow_inherited = "zero_grad" not in inherited_members %} {# no inheritance for torch.nn.Modules #}
11 | {%if allow_inherited %}
12 | :inherited-members:
13 | {% endif %}
14 |
15 | {% block methods %}
16 | {% set allowed_methods = [] %}
17 | {% for item in methods %}{% if not item.startswith("_") and (item not in inherited_members or allow_inherited) %}
18 | {% set a=allowed_methods.append(item) %}
19 | {% endif %}{%- endfor %}
20 | {% if allowed_methods %}
21 | .. rubric:: {{ _('Methods') }}
22 |
23 | .. autosummary::
24 | {% for item in allowed_methods %}
25 | ~{{ name }}.{{ item }}
26 | {%- endfor %}
27 | {% endif %}
28 | {% endblock %}
29 |
30 | {% block attributes %}
31 | {% set dynamic_attributes = [] %} {# dynamic attributes are not documented #}
32 | {% set allowed_attributes = [] %}
33 | {% for item in attributes %}{% if not item.startswith("_") and (item not in inherited_members or allow_inherited) and (item not in dynamic_attributes) and allow_inherited %}
34 | {% set a=allowed_attributes.append(item) %}
35 | {% endif %}{%- endfor %}
36 | {% if allowed_attributes %}
37 | .. rubric:: {{ _('Attributes') }}
38 |
39 | .. autosummary::
40 | {% for item in allowed_attributes %}
41 | ~{{ name }}.{{ item }}
42 | {%- endfor %}
43 | {% endif %}
44 | {% endblock %}
45 |
--------------------------------------------------------------------------------
/docs/source/_templates/custom-module-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname.split(".")[-1] | escape | underline}}
2 |
3 | .. automodule:: {{ fullname }}
4 |
5 | {% block attributes %}
6 | {% if attributes %}
7 | .. rubric:: Module Attributes
8 |
9 | .. autosummary::
10 | :toctree:
11 | {% for item in attributes %}
12 | {{ item }}
13 | {%- endfor %}
14 | {% endif %}
15 | {% endblock %}
16 |
17 | {% block functions %}
18 | {% if functions %}
19 | .. rubric:: {{ _('Functions') }}
20 |
21 | .. autosummary::
22 | :toctree:
23 | :template: custom-base-template.rst
24 | {% for item in functions %}
25 | {{ item }}
26 | {%- endfor %}
27 | {% endif %}
28 | {% endblock %}
29 |
30 | {% block classes %}
31 | {% if classes %}
32 | .. rubric:: {{ _('Classes') }}
33 |
34 | .. autosummary::
35 | :toctree:
36 | :template: custom-class-template.rst
37 | {% for item in classes %}
38 | {{ item }}
39 | {%- endfor %}
40 | {% endif %}
41 | {% endblock %}
42 |
43 | {% block exceptions %}
44 | {% if exceptions %}
45 | .. rubric:: {{ _('Exceptions') }}
46 |
47 | .. autosummary::
48 | :toctree:
49 | {% for item in exceptions %}
50 | {{ item }}
51 | {%- endfor %}
52 | {% endif %}
53 | {% endblock %}
54 |
55 | {% block modules %}
56 | {% if all_modules %}
57 | .. rubric:: Modules
58 |
59 | .. autosummary::
60 | :toctree:
61 | :template: custom-module-template.rst
62 | :recursive:
63 | {% for item in all_modules %}
64 | {{ item }}
65 | {%- endfor %}
66 | {% endif %}
67 | {% endblock %}
68 |
--------------------------------------------------------------------------------
/docs/source/api.rst:
--------------------------------------------------------------------------------
1 | API
2 | ====
3 |
4 | .. currentmodule:: pytorch_forecasting
5 |
6 | .. autosummary::
7 | :toctree: api
8 | :template: custom-module-template.rst
9 | :recursive:
10 |
11 | data
12 | models
13 | metrics
14 | utils
15 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/main/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | from pathlib import Path
15 | import shutil
16 | import sys
17 |
18 | from sphinx.application import Sphinx
19 | from sphinx.ext.autosummary import Autosummary
20 | from sphinx.pycode import ModuleAnalyzer
21 |
22 | SOURCE_PATH = Path(os.path.dirname(__file__)) # noqa # docs source
23 | PROJECT_PATH = SOURCE_PATH.joinpath("../..") # noqa # project root
24 |
25 | sys.path.insert(0, str(PROJECT_PATH)) # noqa
26 |
27 | import pytorch_forecasting # isort:skip
28 |
29 | # -- Project information -----------------------------------------------------
30 |
31 | project = "pytorch-forecasting"
32 | copyright = "2020, Jan Beitner"
33 | author = "Jan Beitner"
34 |
35 |
36 | # -- General configuration ---------------------------------------------------
37 |
38 | # Add any Sphinx extension module names here, as strings. They can be
39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
40 | # ones.
41 | extensions = [
42 | "nbsphinx",
43 | "recommonmark",
44 | "sphinx.ext.autodoc",
45 | "sphinx.ext.autosummary",
46 | "sphinx.ext.doctest",
47 | "sphinx.ext.intersphinx",
48 | "sphinx.ext.mathjax",
49 | "sphinx.ext.viewcode",
50 | "sphinx.ext.githubpages",
51 | "sphinx.ext.napoleon",
52 | ]
53 |
54 | # Add any paths that contain templates here, relative to this directory.
55 | templates_path = ["_templates"]
56 |
57 | # List of patterns, relative to source directory, that match files and
58 | # directories to ignore when looking for source files.
59 | # This pattern also affects html_static_path and html_extra_path.
60 | exclude_patterns = ["**/.ipynb_checkpoints"]
61 |
62 |
63 | # -- Options for HTML output -------------------------------------------------
64 |
65 | # The theme to use for HTML and HTML Help pages. See the documentation for
66 | # a list of builtin themes.
67 | #
68 | html_theme = "pydata_sphinx_theme"
69 | html_logo = "_static/logo.svg"
70 | html_favicon = "_static/favicon.png"
71 |
72 | # Add any paths that contain custom static files (such as style sheets) here,
73 | # relative to this directory. They are copied after the builtin static files,
74 | # so a file named "default.css" will overwrite the builtin "default.css".
75 | html_static_path = ["_static"]
76 |
77 |
78 | # setup configuration
79 | def skip(app, what, name, obj, skip, options):
80 | """
81 | Document __init__ methods
82 | """
83 | if name == "__init__":
84 | return True
85 | return skip
86 |
87 |
88 | apidoc_output_folder = SOURCE_PATH.joinpath("api")
89 |
90 | PACKAGES = [pytorch_forecasting.__name__]
91 |
92 |
93 | def get_by_name(string: str):
94 | """
95 | Import by name and return imported module/function/class
96 |
97 | Parameters
98 | ----------
99 | string (str):
100 | module/function/class to import, e.g. 'pandas.read_csv'
101 | will return read_csv function as defined by pandas
102 |
103 | Returns
104 | -------
105 | imported object
106 | """
107 | class_name = string.split(".")[-1]
108 | module_name = ".".join(string.split(".")[:-1])
109 |
110 | if module_name == "":
111 | return getattr(sys.modules[__name__], class_name)
112 |
113 | mod = __import__(module_name, fromlist=[class_name])
114 | return getattr(mod, class_name)
115 |
116 |
117 | class ModuleAutoSummary(Autosummary):
118 | def get_items(self, names):
119 | new_names = []
120 | for name in names:
121 | mod = sys.modules[name]
122 | mod_items = getattr(mod, "__all__", mod.__dict__)
123 | for t in mod_items:
124 | if "." not in t and not t.startswith("_"):
125 | obj = get_by_name(f"{name}.{t}")
126 | if hasattr(obj, "__module__"):
127 | mod_name = obj.__module__
128 | t = f"{mod_name}.{t}"
129 | if t.startswith("pytorch_forecasting"):
130 | new_names.append(t)
131 | new_items = super().get_items(sorted(new_names))
132 | return new_items
133 |
134 |
135 | def setup(app: Sphinx):
136 | app.add_css_file("custom.css")
137 | app.connect("autodoc-skip-member", skip)
138 | app.add_directive("moduleautosummary", ModuleAutoSummary)
139 | app.add_js_file("https://buttons.github.io/buttons.js", **{"async": "async"})
140 |
141 |
142 | # extension configuration
143 | mathjax_path = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-MML-AM_CHTML"
144 |
145 | # theme options
146 | html_theme_options = {
147 | "github_url": "https://github.com/sktime/pytorch-forecasting",
148 | "navbar_end": ["navbar-icon-links.html", "search-field.html"],
149 | "show_nav_level": 2,
150 | "header_links_before_dropdown": 10,
151 | "external_links": [
152 | {"name": "GitHub", "url": "https://github.com/sktime/pytorch-forecasting"}
153 | ],
154 | }
155 |
156 | html_sidebars = {
157 | "index": [],
158 | # "getting-started": [],
159 | # "data": [],
160 | # "models": [],
161 | # "metrics": [],
162 | "faq": [],
163 | "contribute": [],
164 | "CHANGELOG": [],
165 | }
166 |
167 |
168 | autodoc_member_order = "groupwise"
169 | autoclass_content = "both"
170 |
171 | # autosummary
172 | autosummary_generate = True
173 | shutil.rmtree(SOURCE_PATH.joinpath("api"), ignore_errors=True)
174 |
175 | # copy changelog
176 | shutil.copy(
177 | "../../CHANGELOG.md",
178 | "CHANGELOG.md",
179 | )
180 |
181 | intersphinx_mapping = {
182 | "sklearn": ("https://scikit-learn.org/stable/", None),
183 | }
184 |
185 | suppress_warnings = [
186 | "autosummary.import_cycle",
187 | ]
188 |
189 | # -----------nbsphinx extension ----------
190 | nbsphinx_execute = "never" # always
191 | nbsphinx_allow_errors = False # False
192 | nbsphinx_timeout = 600 # seconds
193 |
--------------------------------------------------------------------------------
/docs/source/data.rst:
--------------------------------------------------------------------------------
1 | Data
2 | =====
3 |
4 | .. currentmodule:: pytorch_forecasting.data
5 |
6 | Loading data for timeseries forecasting is not trivial - in particular if covariates are included and values are missing.
7 | PyTorch Forecasting provides the :py:class:`~timeseries.TimeSeriesDataSet` which comes with a :py:meth:`~timeseries.TimeSeriesDataSet.to_dataloader`
8 | method to convert it to a dataloader and a :py:meth:`~timeseries.TimeSeriesDataSet.from_dataset` method to create, e.g. a validation
9 | or test dataset from a training dataset using the same label encoders and data normalization.
10 |
11 | Further, timeseries have to be (almost always) normalized for a neural network to learn efficiently. PyTorch Forecasting
12 | provides multiple such target normalizers (some of which can also be used for normalizing covariates).
13 |
14 |
15 | Time series data set
16 | ---------------------
17 |
18 | The time series dataset is the central data-holding object in PyTorch Forecasting. It primarily takes
19 | a pandas DataFrame along with some metadata. See the :ref:`tutorial on passing data to models ` to learn more about how it is coupled to models.
20 |
21 | .. autoclass:: pytorch_forecasting.data.timeseries.TimeSeriesDataSet
22 | :noindex:
23 | :members: __init__
24 |
25 | Details
26 | --------
27 |
28 | See the API documentation for further details on available data encoders and the :py:class:`~timeseries.TimeSeriesDataSet`:
29 |
30 | .. currentmodule:: pytorch_forecasting
31 |
32 | .. moduleautosummary::
33 | :toctree: api/
34 | :template: custom-module-template.rst
35 | :recursive:
36 |
37 | pytorch_forecasting.data
38 |
--------------------------------------------------------------------------------
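A minimal sketch of the workflow that data.rst above describes: build a training TimeSeriesDataSet, derive a validation set with ``from_dataset`` so that encoders and normalizers are reused, then convert both to dataloaders with ``to_dataloader``. The toy frame and column names below are illustrative assumptions, not taken from the repository:

    import pandas as pd

    from pytorch_forecasting import TimeSeriesDataSet

    # toy data: one series ("a") with an integer time index and a single target column
    data = pd.DataFrame(
        {
            "time_idx": list(range(30)),
            "value": [float(i) for i in range(30)],
            "group": ["a"] * 30,
        }
    )

    # training dataset on the first 24 time steps
    training = TimeSeriesDataSet(
        data[data.time_idx < 24],
        time_idx="time_idx",
        target="value",
        group_ids=["group"],
        max_encoder_length=12,
        max_prediction_length=6,
        time_varying_unknown_reals=["value"],
    )

    # validation dataset reusing the encoders/normalizers fitted on the training set
    validation = TimeSeriesDataSet.from_dataset(training, data, stop_randomization=True)

    train_dataloader = training.to_dataloader(train=True, batch_size=4)
    val_dataloader = validation.to_dataloader(train=False, batch_size=4)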
/docs/source/faq.rst:
--------------------------------------------------------------------------------
1 | FAQ
2 | ====
3 |
4 | .. currentmodule:: pytorch_forecasting
5 |
6 | Common issues and answers. Other places to seek help from:
7 |
8 | * :ref:`Tutorials `
9 | * `PyTorch Lightning documentation `_ and issues
10 | * `PyTorch documentation `_ and issues
11 | * `Stack Overflow `_
12 |
13 |
14 | Creating datasets
15 | -----------------
16 |
17 | * **How do I create a dataset for new samples?**
18 |
19 | Use the :py:meth:`~data.timeseries.TimeSeriesDataSet.from_dataset` method of your training dataset to
20 | create datasets on which you can run inference.
21 |
22 | * **How long should the encoder and decoder/prediction length be?**
23 |
24 | .. _faq_encoder_decoder_length:
25 |
26 | Choose something reasonably long, but not much longer than 500 for the encoder length and
27 | 200 for the decoder length. Consider that longer lengths increase the time it takes
28 | for your model to train.
29 |
30 | The ratio of decoder and encoder length depends on the algorithm used.
31 | Look at :ref:`documentation ` to get clues.
32 |
33 | * **It takes a very long time to create the dataset. Why is that?**
34 |
35 | If you set ``allow_missing_timesteps=True`` in your dataset, the creation of an index
36 | might take far more time as all missing values in the timeseries have to be identified.
37 | It might be possible to speed up the algorithm, but currently it is likely faster for you to
38 | not allow missing values and to fill them in yourself.
39 |
40 |
41 | * **How are missing values treated?**
42 |
43 | #. Missing values between time points are filled in with either a forward-fill
44 | or a constant fill-in strategy
45 | #. Missing values indicated by NaNs are a problem and
46 | should be filled in up-front, e.g. with the median value and another missing indicator categorical variable.
47 | #. Missing values in the future (out of range) are not filled in and
48 | simply not predicted. You have to provide values into the future.
49 | If those values are amongst the unknown future values, they will simply be ignored.
50 |
51 |
52 | Training models
53 | ---------------
54 |
55 | * **My training seems to freeze - nothing seems to be happening although my CPU/GPU is working at 100%.
56 | How do I fix this issue?**
57 |
58 | Probably, your model is too big (check the number of parameters with ``model.size()``) or
59 | the dataset encoder and decoder lengths are unrealistically large. See
60 | :ref:`How long should the encoder and decoder/prediction length be? `
61 |
62 | * **Why does the learning rate finder not finish?**
63 |
64 | First, ensure that the trainer does not have the keywords ``fast_dev_run=True`` and
65 | ``limit_train_batches=...`` set. Second, use a target normalizer in your training dataset.
66 | Third, increase the ``early_stop_threshold`` argument
67 | of the ``lr_find`` method to a large number.
68 |
69 | * **Why do I get lots of matplotlib warnings when running the learning rate finder?**
70 |
71 | This is because you keep on creating plots for logging but without a logger.
72 | Set ``log_interval=-1`` in your model to avoid this behaviour.
73 |
74 | * **How do I choose hyperparameters?**
75 |
76 | Consult the :ref:`model documentation ` to understand which parameters
77 | are important and which ranges are reasonable. Choose the learning rate with
78 | the learning rate finder. To tune hyperparameters, the `optuna package `_
79 | is a great place to start.
80 |
81 |
82 | Interpreting models
83 | -------------------
84 |
85 | * **What interpretation is built into PyTorch Forecasting?**
86 |
87 | Look up the :ref:`model documentation ` for the model you use for model-specific interpretation.
88 | Further, all models come with some basic methods inherited from :py:class:`~models.base_model.BaseModel`.
89 |
--------------------------------------------------------------------------------
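As a concrete illustration of the first FAQ entry above, a hedged sketch of creating an inference dataset from an existing training dataset. It assumes ``training``, ``new_data`` and a fitted ``model`` are already defined; ``predict=True`` restricts sampling to the last prediction window of each series:

    from pytorch_forecasting import TimeSeriesDataSet

    # reuse label encoders and normalizers fitted on the training dataset
    inference_dataset = TimeSeriesDataSet.from_dataset(
        training, new_data, predict=True, stop_randomization=True
    )
    inference_dataloader = inference_dataset.to_dataloader(
        train=False, batch_size=64, num_workers=0
    )

    # run the trained model on the new samples
    predictions = model.predict(inference_dataloader)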
/docs/source/getting-started.rst:
--------------------------------------------------------------------------------
1 | Getting started
2 | ===============
3 |
4 | .. _getting-started:
5 |
6 |
7 | Installation
8 | --------------
9 |
10 | .. _install:
11 |
12 | If you are working on Windows, you first need to install PyTorch with
13 |
14 | .. code-block:: bash
15 |
16 | pip install torch -f https://download.pytorch.org/whl/torch_stable.html
17 |
18 | Otherwise, you can proceed with
19 |
20 | .. code-block:: bash
21 |
22 | pip install pytorch-forecasting
23 |
24 |
25 | Alternatively, to install the package via ``conda``:
26 | .. code-block:: bash
27 |
28 | conda install pytorch-forecasting pytorch>=1.7 -c pytorch -c conda-forge
29 |
30 | PyTorch Forecasting is now installed from the conda-forge channel while PyTorch is installed from the pytorch channel.
31 |
32 | To use the MQF2 loss (multivariate quantile loss), also install
33 |
34 | .. code-block:: bash
35 |
36 | pip install pytorch-forecasting[mqf2]
37 |
38 |
39 | Usage
40 | -------------
41 |
42 | .. currentmodule:: pytorch_forecasting
43 |
44 | The library builds strongly upon `PyTorch Lightning `_ which allows you to train models with ease,
45 | spot bugs quickly and train on multiple GPUs out-of-the-box.
46 |
47 | Further, we rely on `Tensorboard `_ for logging training progress.
48 |
49 | The general setup for training and testing a model is
50 |
51 | #. Create training dataset using :py:class:`~data.timeseries.TimeSeriesDataSet`.
52 | #. Using the training dataset, create a validation dataset with :py:meth:`~data.timeseries.TimeSeriesDataSet.from_dataset`.
53 | Similarly, a test dataset or later a dataset for inference can be created. You can store the dataset parameters
54 | directly if you do not wish to load the entire training dataset at inference time.
55 |
56 | #. Instantiate a model using the ``.from_dataset()`` method.
57 | #. Create a ``lightning.Trainer()`` object.
58 | #. Find the optimal learning rate with its ``.tuner.lr_find()`` method.
59 | #. Train the model with early stopping on the training dataset and use the tensorboard logs
60 | to understand if it has converged with acceptable accuracy.
61 | #. Tune the hyperparameters of the model with your
62 | `favourite package `_.
63 | #. Train the model with the same learning rate schedule on the entire dataset.
64 | #. Load the model from the model checkpoint and apply it to new data.
65 |
66 |
67 | The :ref:`Tutorials ` section provides detailed guidance and examples on how to use models and implement new ones.
68 |
69 |
70 | Example
71 | --------
72 |
73 |
74 | .. code-block:: python
75 |
76 | import lightning.pytorch as pl
77 | from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
78 | from lightning.pytorch.tuner import Tuner
79 | from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
80 |
81 | # load data
82 | data = ...
83 |
84 | # define dataset
85 | max_encoder_length = 36
86 | max_prediction_length = 6
87 | training_cutoff = "YYYY-MM-DD" # day for cutoff
88 |
89 | training = TimeSeriesDataSet(
90 | data[lambda x: x.date < training_cutoff],
91 | time_idx= ...,
92 | target= ...,
93 | # weight="weight",
94 | group_ids=[ ... ],
95 | max_encoder_length=max_encoder_length,
96 | max_prediction_length=max_prediction_length,
97 | static_categoricals=[ ... ],
98 | static_reals=[ ... ],
99 | time_varying_known_categoricals=[ ... ],
100 | time_varying_known_reals=[ ... ],
101 | time_varying_unknown_categoricals=[ ... ],
102 | time_varying_unknown_reals=[ ... ],
103 | )
104 |
105 | # create validation and training dataset
106 | validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True)
107 | batch_size = 128
108 | train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
109 | val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2)
110 |
111 | # define trainer with early stopping
112 | early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
113 | lr_logger = LearningRateMonitor()
114 | trainer = pl.Trainer(
115 | max_epochs=100,
116 | accelerator="auto",
117 | gradient_clip_val=0.1,
118 | limit_train_batches=30,
119 | callbacks=[lr_logger, early_stop_callback],
120 | )
121 |
122 | # create the model
123 | tft = TemporalFusionTransformer.from_dataset(
124 | training,
125 | learning_rate=0.03,
126 | hidden_size=32,
127 | attention_head_size=1,
128 | dropout=0.1,
129 | hidden_continuous_size=16,
130 | output_size=7,
131 | loss=QuantileLoss(),
132 | log_interval=2,
133 | reduce_on_plateau_patience=4
134 | )
135 | print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")
136 |
137 | # find optimal learning rate (set limit_train_batches to 1.0 and log_interval = -1)
138 | res = Tuner(trainer).lr_find(
139 | tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, early_stop_threshold=1000.0, max_lr=0.3,
140 | )
141 |
142 | print(f"suggested learning rate: {res.suggestion()}")
143 | fig = res.plot(show=True, suggest=True)
144 | fig.show()
145 |
146 | # fit the model
147 | trainer.fit(
148 | tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader,
149 | )
150 |
151 | Main API
152 | ---------
153 |
154 | .. currentmodule:: pytorch_forecasting
155 |
156 | .. moduleautosummary::
157 | :toctree: api
158 | :template: custom-module-template.rst
159 | :recursive:
160 |
161 | pytorch_forecasting
162 |
--------------------------------------------------------------------------------
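The example in getting-started.rst above ends with ``trainer.fit``; the last step of the listed workflow (loading the best checkpoint and applying it to new data) could look roughly like the following sketch, which assumes the ``trainer``, ``val_dataloader`` and default Lightning checkpointing from that example:

    from pytorch_forecasting import TemporalFusionTransformer

    # path of the best checkpoint written during training (Lightning's default ModelCheckpoint)
    best_model_path = trainer.checkpoint_callback.best_model_path
    best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

    # apply the restored model to new data, e.g. the validation dataloader
    predictions = best_tft.predict(val_dataloader)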
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. pytorch-forecasting documentation master file, created by
2 | sphinx-quickstart on Sun Aug 16 22:17:24 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | PyTorch Forecasting Documentation
7 | ==================================
8 |
9 | .. raw:: html
10 |
11 | GitHub
12 |
13 |
14 | Our article on `Towards Data Science `_
15 | introduces the package and provides background information.
16 |
17 | PyTorch Forecasting aims to ease state-of-the-art
18 | timeseries forecasting with neural networks for both real-world cases and
19 | research alike. The goal is to provide a high-level API with maximum flexibility for
20 | professionals and reasonable defaults for beginners.
21 | Specifically, the package provides
22 |
23 | * A timeseries dataset class which abstracts handling variable transformations, missing values,
24 | randomized subsampling, multiple history lengths, etc.
25 | * A base model class which provides basic training of timeseries models along with logging in tensorboard
26 | and generic visualizations such as actuals vs predictions and dependency plots
27 | * Multiple neural network architectures for timeseries forecasting that have been enhanced
28 | for real-world deployment and come with in-built interpretation capabilities
29 | * Multi-horizon timeseries metrics
30 | * Hyperparameter tuning with `optuna `_
31 |
32 | The package is built on `PyTorch Lightning `_ to allow
33 | training on CPUs, single and multiple GPUs out-of-the-box.
34 |
35 | If you do not have PyTorch already installed, follow the :ref:`detailed installation instructions`.
36 |
37 | Otherwise, proceed to install the package by executing
38 |
39 | .. code-block::
40 |
41 | pip install pytorch-forecasting
42 |
43 | or to install via conda
44 |
45 | .. code-block::
46 |
47 | conda install pytorch-forecasting pytorch>=1.7 -c pytorch -c conda-forge
48 |
49 | To use the MQF2 loss (multivariate quantile loss), also execute
50 |
51 | .. code-block::
52 |
53 | pip install pytorch-forecasting[mqf2]
54 |
55 | Visit :ref:`Getting started ` to learn more about the package and detailed installation instructions.
56 | The :ref:`Tutorials ` section provides guidance on how to use models and implement new ones.
57 |
58 | .. toctree::
59 | :titlesonly:
60 | :hidden:
61 | :maxdepth: 6
62 |
63 | getting-started
64 | tutorials
65 | data
66 | models
67 | metrics
68 | faq
69 | installation
70 | api
71 | CHANGELOG
72 |
73 |
74 | Indices and tables
75 | ==================
76 |
77 | * :ref:`genindex`
78 | * :ref:`modindex`
79 | * :ref:`search`
80 |
--------------------------------------------------------------------------------
/docs/source/metrics.rst:
--------------------------------------------------------------------------------
1 | Metrics
2 | ==========
3 |
4 | Multiple metrics have been implemented to ease adaptation.
5 |
6 | In particular, these metrics can be applied to the multi-horizon forecasting problem, i.e.
7 | can take tensors that are not only of shape ``n_samples`` but also ``n_samples x prediction_horizon``
8 | or even ``n_samples x prediction_horizon x n_outputs``, where ``n_outputs`` could be the number
9 | of forecasted quantiles.
10 |
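
For example, a point metric such as ``MAE`` can be evaluated directly on multi-horizon tensors. A minimal sketch, assuming the torchmetrics-style call interface of the metrics:

.. code-block:: python

    import torch

    from pytorch_forecasting.metrics import MAE

    mae = MAE()
    y_pred = torch.randn(32, 6)    # n_samples x prediction_horizon
    y_actual = torch.randn(32, 6)
    value = mae(y_pred, y_actual)  # scalar tensor with the mean absolute error
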
11 | Metrics can be easily combined by addition, e.g.
12 |
13 | .. code-block:: python
14 |
15 | from pytorch_forecasting.metrics import SMAPE, MAE
16 |
17 | composite_metric = SMAPE() + 1e-4 * MAE()
18 |
19 | Such composite metrics are useful when training because they can reduce outliers in other metrics.
20 | In the example, SMAPE is mostly optimized, while large outliers in MAE are avoided.
21 |
22 | Further, one can modify a loss metric to reduce a mean prediction bias, i.e. ensure that
23 | predictions add up to the actuals on aggregate. For example:
24 |
25 | .. code-block:: python
26 |
27 | from pytorch_forecasting.metrics import MAE, AggregationMetric
28 |
29 | composite_metric = MAE() + AggregationMetric(metric=MAE())
30 |
31 | Here we add to MAE an additional loss. This additional loss is the MAE calculated on the mean predictions
32 | and actuals. We can also use other metrics such as SMAPE to ensure aggregated results are unbiased in that metric.
33 | One important point to keep in mind is that this metric is calculated across samples, i.e. it will vary depending
34 | on the batch size. In particular, errors tend to average out with increased batch sizes.
35 |
36 |
37 | Details
38 | --------
39 |
40 | See the API documentation for further details on available metrics:
41 |
42 | .. currentmodule:: pytorch_forecasting
43 |
44 | .. moduleautosummary::
45 | :toctree: api
46 | :template: custom-module-template.rst
47 | :recursive:
48 |
49 | pytorch_forecasting.metrics
50 |
--------------------------------------------------------------------------------
/docs/source/tutorials.rst:
--------------------------------------------------------------------------------
1 | Tutorials
2 | ==========
3 |
4 | .. _tutorials:
5 |
6 | The following tutorials can also be found as `notebooks on GitHub `_.
7 |
8 | .. toctree::
9 | :titlesonly:
10 | :maxdepth: 2
11 |
12 | tutorials/stallion
13 | tutorials/ar
14 | tutorials/building
15 | tutorials/deepar
16 | tutorials/nhits
17 |
--------------------------------------------------------------------------------
/examples/ar.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import lightning.pytorch as pl
4 | from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
5 | import pandas as pd
6 | from pandas.core.common import SettingWithCopyWarning
7 | import torch
8 |
9 | from pytorch_forecasting import GroupNormalizer, TimeSeriesDataSet
10 | from pytorch_forecasting.data import NaNLabelEncoder
11 | from pytorch_forecasting.data.examples import generate_ar_data
12 | from pytorch_forecasting.metrics import NormalDistributionLoss
13 | from pytorch_forecasting.models.deepar import DeepAR
14 |
15 | warnings.simplefilter("error", category=SettingWithCopyWarning)
16 |
17 |
18 | data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100)
19 | data["static"] = "2"
20 | data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
21 | validation = data.series.sample(20)
22 |
23 | max_encoder_length = 60
24 | max_prediction_length = 20
25 |
26 | training_cutoff = data["time_idx"].max() - max_prediction_length
27 |
28 | training = TimeSeriesDataSet(
29 | data[lambda x: ~x.series.isin(validation)],
30 | time_idx="time_idx",
31 | target="value",
32 | categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
33 | group_ids=["series"],
34 | static_categoricals=["static"],
35 | min_encoder_length=max_encoder_length,
36 | max_encoder_length=max_encoder_length,
37 | min_prediction_length=max_prediction_length,
38 | max_prediction_length=max_prediction_length,
39 | time_varying_unknown_reals=["value"],
40 | time_varying_known_reals=["time_idx"],
41 | target_normalizer=GroupNormalizer(groups=["series"]),
42 | add_relative_time_idx=False,
43 | add_target_scales=True,
44 | randomize_length=None,
45 | )
46 |
47 | validation = TimeSeriesDataSet.from_dataset(
48 | training,
49 | data[lambda x: x.series.isin(validation)],
50 | # predict=True,
51 | stop_randomization=True,
52 | )
53 | batch_size = 64
54 | train_dataloader = training.to_dataloader(
55 | train=True, batch_size=batch_size, num_workers=0
56 | )
57 | val_dataloader = validation.to_dataloader(
58 | train=False, batch_size=batch_size, num_workers=0
59 | )
60 |
61 | # save datasets
62 | training.save("training.pkl")
63 | validation.save("validation.pkl")
64 |
65 | early_stop_callback = EarlyStopping(
66 | monitor="val_loss", min_delta=1e-4, patience=5, verbose=False, mode="min"
67 | )
68 | lr_logger = LearningRateMonitor()
69 |
70 | trainer = pl.Trainer(
71 | max_epochs=10,
72 | accelerator="gpu",
73 | devices="auto",
74 | gradient_clip_val=0.1,
75 | limit_train_batches=30,
76 | limit_val_batches=3,
77 | # fast_dev_run=True,
78 | # logger=logger,
79 | # profiler=True,
80 | callbacks=[lr_logger, early_stop_callback],
81 | )
82 |
83 |
84 | deepar = DeepAR.from_dataset(
85 | training,
86 | learning_rate=0.1,
87 | hidden_size=32,
88 | dropout=0.1,
89 | loss=NormalDistributionLoss(),
90 | log_interval=10,
91 | log_val_interval=3,
92 | # reduce_on_plateau_patience=3,
93 | )
94 | print(f"Number of parameters in network: {deepar.size() / 1e3:.1f}k")
95 |
96 | # # find optimal learning rate
97 | # deepar.hparams.log_interval = -1
98 | # deepar.hparams.log_val_interval = -1
99 | # trainer.limit_train_batches = 1.0
100 | # res = Tuner(trainer).lr_find(
101 | # deepar, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # noqa: E501
102 | # )
103 |
104 | # print(f"suggested learning rate: {res.suggestion()}")
105 | # fig = res.plot(show=True, suggest=True)
106 | # fig.show()
107 | # deepar.hparams.learning_rate = res.suggestion()
108 |
109 | torch.set_num_threads(10)
110 | trainer.fit(
111 | deepar,
112 | train_dataloaders=train_dataloader,
113 | val_dataloaders=val_dataloader,
114 | )
115 |
116 | # calculate mean absolute error on the validation set
117 | actuals = torch.cat([y for x, (y, weight) in iter(val_dataloader)])
118 | predictions = deepar.predict(val_dataloader)
119 | print(f"Mean absolute error of model: {(actuals - predictions).abs().mean()}")
120 |
121 | # # plot actual vs. predictions
122 | # raw_predictions, x = deepar.predict(val_dataloader, mode="raw", return_x=True)
123 | # for idx in range(10): # plot 10 examples
124 | # deepar.plot_prediction(x, raw_predictions, idx=idx, add_loss_to_title=True)
125 |
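
Since the script saves both datasets, a later session can restore them without rebuilding from the raw dataframe. A minimal sketch, assuming the ``TimeSeriesDataSet.load`` counterpart to the ``save`` calls above:

    # restore the datasets saved by the script (sketch)
    from pytorch_forecasting import TimeSeriesDataSet

    training = TimeSeriesDataSet.load("training.pkl")
    validation = TimeSeriesDataSet.load("validation.pkl")
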
--------------------------------------------------------------------------------
/examples/data/stallion.parquet:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sktime/pytorch-forecasting/4384140418bfe8e8c64cc6ce5fc19372508bfccd/examples/data/stallion.parquet
--------------------------------------------------------------------------------
/examples/nbeats.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import lightning.pytorch as pl
4 | from lightning.pytorch.callbacks import EarlyStopping
5 | import pandas as pd
6 |
7 | from pytorch_forecasting import NBeats, TimeSeriesDataSet
8 | from pytorch_forecasting.data import NaNLabelEncoder
9 | from pytorch_forecasting.data.examples import generate_ar_data
10 |
11 | sys.path.append("..")
12 |
13 |
14 | print("load data")
15 | data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100)
16 | data["static"] = 2
17 | data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
18 | validation = data.series.sample(20)
19 |
20 |
21 | max_encoder_length = 150
22 | max_prediction_length = 20
23 |
24 | training_cutoff = data["time_idx"].max() - max_prediction_length
25 |
26 | context_length = max_encoder_length
27 | prediction_length = max_prediction_length
28 |
29 | training = TimeSeriesDataSet(
30 | data[lambda x: x.time_idx < training_cutoff],
31 | time_idx="time_idx",
32 | target="value",
33 | categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
34 | group_ids=["series"],
35 | min_encoder_length=context_length,
36 | max_encoder_length=context_length,
37 | max_prediction_length=prediction_length,
38 | min_prediction_length=prediction_length,
39 | time_varying_unknown_reals=["value"],
40 | randomize_length=None,
41 | add_relative_time_idx=False,
42 | add_target_scales=False,
43 | )
44 |
45 | validation = TimeSeriesDataSet.from_dataset(
46 | training, data, min_prediction_idx=training_cutoff
47 | )
48 | batch_size = 128
49 | train_dataloader = training.to_dataloader(
50 | train=True, batch_size=batch_size, num_workers=2
51 | )
52 | val_dataloader = validation.to_dataloader(
53 | train=False, batch_size=batch_size, num_workers=2
54 | )
55 |
56 |
57 | early_stop_callback = EarlyStopping(
58 | monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
59 | )
60 | trainer = pl.Trainer(
61 | max_epochs=100,
62 | accelerator="auto",
63 | gradient_clip_val=0.1,
64 | callbacks=[early_stop_callback],
65 | limit_train_batches=15,
66 | # limit_val_batches=1,
67 | # fast_dev_run=True,
68 | # logger=logger,
69 | # profiler=True,
70 | )
71 |
72 |
73 | net = NBeats.from_dataset(
74 | training,
75 | learning_rate=3e-2,
76 | log_interval=10,
77 | log_val_interval=1,
78 | log_gradient_flow=False,
79 | weight_decay=1e-2,
80 | )
81 | print(f"Number of parameters in network: {net.size() / 1e3:.1f}k")
82 |
83 | # # find optimal learning rate
84 | # # remove logging and artificial epoch size
85 | # net.hparams.log_interval = -1
86 | # net.hparams.log_val_interval = -1
87 | # trainer.limit_train_batches = 1.0
88 | # # run learning rate finder
89 | # res = Tuner(trainer).lr_find(
90 | # net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # noqa: E501
91 | # )
92 | # print(f"suggested learning rate: {res.suggestion()}")
93 | # fig = res.plot(show=True, suggest=True)
94 | # fig.show()
95 | # net.hparams.learning_rate = res.suggestion()
96 |
97 | trainer.fit(
98 | net,
99 | train_dataloaders=train_dataloader,
100 | val_dataloaders=val_dataloader,
101 | )
102 |
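
After training with early stopping, one would typically reload the best checkpoint before forecasting. A sketch, assuming Lightning's default checkpointing and the ``predict`` method used in the other example scripts:

    # forecast on the validation set with the best checkpoint (sketch)
    best_model_path = trainer.checkpoint_callback.best_model_path
    best_net = NBeats.load_from_checkpoint(best_model_path)
    predictions = best_net.predict(val_dataloader)
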
--------------------------------------------------------------------------------
/examples/stallion.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import warnings
3 |
4 | import lightning.pytorch as pl
5 | from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
6 | from lightning.pytorch.loggers import TensorBoardLogger
7 | import numpy as np
8 | from pandas.core.common import SettingWithCopyWarning
9 |
10 | from pytorch_forecasting import (
11 | GroupNormalizer,
12 | TemporalFusionTransformer,
13 | TimeSeriesDataSet,
14 | )
15 | from pytorch_forecasting.data.examples import get_stallion_data
16 | from pytorch_forecasting.metrics import QuantileLoss
17 | from pytorch_forecasting.models.temporal_fusion_transformer.tuning import (
18 | optimize_hyperparameters,
19 | )
20 |
21 | warnings.simplefilter("error", category=SettingWithCopyWarning)
22 |
23 |
24 | data = get_stallion_data()
25 |
26 | data["month"] = data.date.dt.month.astype("str").astype("category")
27 | data["log_volume"] = np.log(data.volume + 1e-8)
28 |
29 | data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
30 | data["time_idx"] -= data["time_idx"].min()
31 | data["avg_volume_by_sku"] = data.groupby(
32 | ["time_idx", "sku"], observed=True
33 | ).volume.transform("mean")
34 | data["avg_volume_by_agency"] = data.groupby(
35 | ["time_idx", "agency"], observed=True
36 | ).volume.transform("mean")
37 | # data = data[lambda x: (x.sku == data.iloc[0]["sku"]) & (x.agency == data.iloc[0]["agency"])] # noqa: E501
38 | special_days = [
39 | "easter_day",
40 | "good_friday",
41 | "new_year",
42 | "christmas",
43 | "labor_day",
44 | "independence_day",
45 | "revolution_day_memorial",
46 | "regional_games",
47 | "fifa_u_17_world_cup",
48 | "football_gold_cup",
49 | "beer_capital",
50 | "music_fest",
51 | ]
52 | data[special_days] = (
53 | data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category")
54 | )
55 |
56 | training_cutoff = data["time_idx"].max() - 6
57 | max_encoder_length = 36
58 | max_prediction_length = 6
59 |
60 | training = TimeSeriesDataSet(
61 | data[lambda x: x.time_idx <= training_cutoff],
62 | time_idx="time_idx",
63 | target="volume",
64 | group_ids=["agency", "sku"],
65 | min_encoder_length=max_encoder_length
66 |     // 2, # allow encoder lengths from max_encoder_length // 2 to max_encoder_length
67 | max_encoder_length=max_encoder_length,
68 | min_prediction_length=1,
69 | max_prediction_length=max_prediction_length,
70 | static_categoricals=["agency", "sku"],
71 | static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
72 | time_varying_known_categoricals=["special_days", "month"],
73 | variable_groups={
74 | "special_days": special_days
75 | }, # group of categorical variables can be treated as one variable
76 | time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
77 | time_varying_unknown_categoricals=[],
78 | time_varying_unknown_reals=[
79 | "volume",
80 | "log_volume",
81 | "industry_volume",
82 | "soda_volume",
83 | "avg_max_temp",
84 | "avg_volume_by_agency",
85 | "avg_volume_by_sku",
86 | ],
87 | target_normalizer=GroupNormalizer(
88 | groups=["agency", "sku"], transformation="softplus", center=False
89 | ), # use softplus with beta=1.0 and normalize by group
90 | add_relative_time_idx=True,
91 | add_target_scales=True,
92 | add_encoder_length=True,
93 | )
94 |
95 |
96 | validation = TimeSeriesDataSet.from_dataset(
97 | training, data, predict=True, stop_randomization=True
98 | )
99 | batch_size = 64
100 | train_dataloader = training.to_dataloader(
101 | train=True, batch_size=batch_size, num_workers=0
102 | )
103 | val_dataloader = validation.to_dataloader(
104 | train=False, batch_size=batch_size, num_workers=0
105 | )
106 |
107 |
108 | # save datasets
109 | training.save("training.pkl")
110 | validation.save("validation.pkl")
111 |
112 | early_stop_callback = EarlyStopping(
113 | monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
114 | )
115 | lr_logger = LearningRateMonitor()
116 | logger = TensorBoardLogger(log_graph=True)
117 |
118 | trainer = pl.Trainer(
119 | max_epochs=100,
120 | accelerator="auto",
121 | gradient_clip_val=0.1,
122 | limit_train_batches=30,
123 | # val_check_interval=20,
124 | # limit_val_batches=1,
125 | # fast_dev_run=True,
126 | logger=logger,
127 | # profiler=True,
128 | callbacks=[lr_logger, early_stop_callback],
129 | )
130 |
131 |
132 | tft = TemporalFusionTransformer.from_dataset(
133 | training,
134 | learning_rate=0.03,
135 | hidden_size=16,
136 | attention_head_size=1,
137 | dropout=0.1,
138 | hidden_continuous_size=8,
139 | output_size=7,
140 | loss=QuantileLoss(),
141 | log_interval=10,
142 | log_val_interval=1,
143 | reduce_on_plateau_patience=3,
144 | )
145 | print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")
146 |
147 | # # find optimal learning rate
148 | # # remove logging and artificial epoch size
149 | # tft.hparams.log_interval = -1
150 | # tft.hparams.log_val_interval = -1
151 | # trainer.limit_train_batches = 1.0
152 | # # run learning rate finder
153 | # res = Tuner(trainer).lr_find(
154 | # tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # noqa: E501
155 | # )
156 | # print(f"suggested learning rate: {res.suggestion()}")
157 | # fig = res.plot(show=True, suggest=True)
158 | # fig.show()
159 | # tft.hparams.learning_rate = res.suggestion()
160 |
161 | # trainer.fit(
162 | # tft,
163 | # train_dataloaders=train_dataloader,
164 | # val_dataloaders=val_dataloader,
165 | # )
166 |
167 | # # make a prediction on entire validation set
168 | # preds, index = tft.predict(val_dataloader, return_index=True, fast_dev_run=True)
169 |
170 |
171 | # tune
172 | study = optimize_hyperparameters(
173 | train_dataloader,
174 | val_dataloader,
175 | model_path="optuna_test",
176 | n_trials=200,
177 | max_epochs=50,
178 | gradient_clip_val_range=(0.01, 1.0),
179 | hidden_size_range=(8, 128),
180 | hidden_continuous_size_range=(8, 128),
181 | attention_head_size_range=(1, 4),
182 | learning_rate_range=(0.001, 0.1),
183 | dropout_range=(0.1, 0.3),
184 | trainer_kwargs=dict(limit_train_batches=30),
185 | reduce_on_plateau_patience=4,
186 | use_learning_rate_finder=False,
187 | )
188 | with open("test_study.pkl", "wb") as fout:
189 | pickle.dump(study, fout)
190 |
191 |
192 | # profile speed
193 | # profile(
194 | # trainer.fit,
195 | # profile_fname="profile.prof",
196 | # model=tft,
197 | # period=0.001,
198 | # filter="pytorch_forecasting",
199 | # train_dataloaders=train_dataloader,
200 | # val_dataloaders=val_dataloader,
201 | # )
202 |
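
Because the study is pickled, the best hyperparameters can be inspected later via standard optuna attributes. A short sketch:

    # inspect the best hyperparameters found by the search (sketch)
    import pickle

    with open("test_study.pkl", "rb") as fin:
        study = pickle.load(fin)
    print(study.best_trial.params)
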
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "pytorch-forecasting"
3 | readme = "README.md" # Markdown files are supported
4 | version = "1.3.0" # is being replaced automatically
5 |
6 | authors = [
7 | {name = "Jan Beitner"},
8 | ]
9 | requires-python = ">=3.9,<3.14"
10 | classifiers = [
11 | "Intended Audience :: Developers",
12 | "Intended Audience :: Science/Research",
13 | "Programming Language :: Python :: 3",
14 | "Programming Language :: Python :: 3.9",
15 | "Programming Language :: Python :: 3.10",
16 | "Programming Language :: Python :: 3.11",
17 | "Programming Language :: Python :: 3.12",
18 | "Programming Language :: Python :: 3.13",
19 | "Topic :: Scientific/Engineering",
20 | "Topic :: Scientific/Engineering :: Mathematics",
21 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
22 | "Topic :: Software Development",
23 | "Topic :: Software Development :: Libraries",
24 | "Topic :: Software Development :: Libraries :: Python Modules",
25 | "License :: OSI Approved :: MIT License",
26 | ]
27 | description = "Forecasting timeseries with PyTorch - dataloaders, normalizers, metrics and models"
28 |
29 | dependencies = [
30 | "numpy<=3.0.0",
31 | "torch >=2.0.0,!=2.0.1,<3.0.0",
32 | "lightning >=2.0.0,<3.0.0",
33 | "scipy >=1.8,<2.0",
34 | "pandas >=1.3.0,<3.0.0",
35 | "scikit-learn >=1.2,<2.0",
36 | ]
37 |
38 | [project.optional-dependencies]
39 | # there are the following dependency sets:
40 | # - all_extras - all soft dependencies
41 | # - granular dependency sets:
42 | # - tuning - dependencies for tuning hyperparameters via optuna
43 | # - mqf2 - dependencies for multivariate quantile loss
44 | # - graph - dependencies for graph based forecasting
45 | # - dev - the developer dependency set, for contributors to pytorch-forecasting
46 | # - CI related: e.g., dev, github-actions. Not for users.
47 | #
48 | # soft dependencies are not required for the core functionality of pytorch-forecasting
49 | # but are required by popular estimators, e.g., prophet, tbats, etc.
50 |
51 | # all soft dependencies
52 | #
53 | # users can install via "pip install pytorch-forecasting[all_extras]"
54 | #
55 | all_extras = [
56 | "cpflows",
57 | "matplotlib",
58 | "optuna >=3.1.0,<5.0.0",
59 | "optuna-integration",
60 | "pytorch_optimizer >=2.5.1,<4.0.0",
61 | "statsmodels",
62 | ]
63 |
64 | tuning = [
65 | "optuna >=3.1.0,<5.0.0",
66 | "optuna-integration",
67 | "statsmodels",
68 | ]
69 |
70 | mqf2 = ["cpflows"]
71 |
72 | # the graph set is not currently used within pytorch-forecasting
73 | # but is kept for future development, as it has already been released
74 | graph = ["networkx"]
75 |
76 | # dev - the developer dependency set, for contributors to pytorch-forecasting
77 | dev = [
78 | "pydocstyle >=6.1.1,<7.0.0",
79 | # checks and make tools
80 | "pre-commit >=3.2.0,<5.0.0",
81 | "invoke",
82 | "mypy",
83 | "pylint",
84 | "ruff",
85 | # pytest
86 | "pytest",
87 | "pytest-xdist",
88 | "pytest-cov",
89 | "pytest-sugar",
90 | "coverage",
91 | "pyarrow",
92 | # jupyter notebook
93 | "ipykernel",
94 | "nbconvert",
95 | "black[jupyter]",
96 |     # documentation
97 | "sphinx",
98 | "pydata-sphinx-theme",
99 | "nbsphinx",
100 | "recommonmark",
101 | "ipywidgets>=8.0.1,<9.0.0",
102 | "pytest-dotenv>=0.5.2,<1.0.0",
103 | "tensorboard>=2.12.1,<3.0.0",
104 | "pandoc>=2.3,<3.0.0",
105 | "scikit-base",
106 | ]
107 |
108 | # docs - dependencies for building the documentation
109 | docs = [
110 | "sphinx>3.2,<8.2.4",
111 | "pydata-sphinx-theme",
112 | "nbsphinx",
113 | "pandoc",
114 | "nbconvert",
115 | "recommonmark",
116 | "docutils",
117 | ]
118 |
119 | github-actions = ["pytest-github-actions-annotate-failures"]
120 |
121 | [tool.setuptools.packages.find]
122 | exclude = ["build_tools"]
123 |
124 | [build-system]
125 | build-backend = "setuptools.build_meta"
126 | requires = [
127 | "setuptools>=70.0.0",
128 | ]
129 |
130 | [tool.ruff]
131 | line-length = 88
132 | exclude = [
133 | "docs/build/",
134 | "node_modules/",
135 | ".eggs/",
136 | "versioneer.py",
137 | "venv/",
138 | ".venv/",
139 | ".git/",
140 | ".history/",
141 | "docs/source/tutorials/",
142 | ]
143 | target-version = "py39"
144 |
145 | [tool.ruff.format]
146 | # Enable formatting
147 | quote-style = "double"
148 | indent-style = "space"
149 | skip-magic-trailing-comma = false
150 | line-ending = "auto"
151 |
152 | [tool.ruff.lint]
153 | select = ["E", "F", "W", "C4", "S"]
154 | extend-select = [
155 | "I", # isort
156 | "UP", # pyupgrade
157 | "C4", # https://pypi.org/project/flake8-comprehensions
158 | ]
159 | extend-ignore = [
160 | "E203", # space before : (needed for how black formats slicing)
161 | "E402", # module level import not at top of file
162 | "E731", # do not assign a lambda expression, use a def
163 | "E741", # ignore not easy to read variables like i l I etc.
164 | "C406", # Unnecessary list literal - rewrite as a dict literal.
165 | "C408", # Unnecessary dict call - rewrite as a literal.
166 | "C409", # Unnecessary list passed to tuple() - rewrite as a tuple literal.
167 | "F401", # unused imports
168 | "S101", # use of assert
169 | ]
170 |
171 | [tool.ruff.lint.isort]
172 | known-first-party = ["pytorch_forecasting"]
173 | combine-as-imports = true
174 | force-sort-within-sections = true
175 |
176 | [tool.ruff.lint.per-file-ignores]
177 | "pytorch_forecasting/data/timeseries.py" = [
178 | "E501", # Line too long being fixed in #1746 To be removed after merging
179 | ]
180 |
181 | [tool.nbqa.mutate]
182 | ruff = 1
183 | black = 1
184 |
185 | [tool.nbqa.exclude]
186 | ruff = "docs/source/tutorials/" # ToDo: Remove this when fixing notebooks
187 |
188 | [tool.coverage.report]
189 | ignore_errors = false
190 | show_missing = true
191 |
192 | [tool.mypy]
193 | ignore_missing_imports = true
194 | no_implicit_optional = true
195 | check_untyped_defs = true
196 | cache_dir = ".cache/mypy/"
197 |
198 | [tool.pytest.ini_options]
199 | addopts = [
200 | "-rsxX",
201 | "-vv",
202 | "--cov-config=.coveragerc",
203 | "--cov=pytorch_forecasting",
204 | "--cov-report=html",
205 | "--cov-report=term-missing:skip-covered",
206 | "--no-cov-on-fail"
207 | ]
208 | markers = []
209 | testpaths = ["tests/", "pytorch_forecasting/tests/"]
210 | log_cli_level = "ERROR"
211 | log_format = "%(asctime)s %(levelname)s %(message)s"
212 | log_date_format = "%Y-%m-%d %H:%M:%S"
213 | cache_dir = ".cache"
214 | filterwarnings = [
215 | "ignore:Found \\d+ unknown classes which were set to NaN:UserWarning",
216 | "ignore:Less than \\d+ samples available for \\d+ prediction times. Use ba:UserWarning",
217 | "ignore:scale is below 1e-7 - consider not centering the data or using data with:UserWarning",
218 | "ignore:You defined a `validation_step` but have no `val_dataloader`:UserWarning",
219 | "ignore:ReduceLROnPlateau conditioned on metric:RuntimeWarning",
220 | "ignore:The number of training samples \\(\\d+\\) is smaller than the logging interval Trainer\\(:UserWarning",
221 | "ignore:The dataloader, [\\_\\s]+ \\d+, does not have many workers which may be a bottleneck.:UserWarning",
222 | "ignore:Consider increasing the value of the `num_workers` argument`:UserWarning",
223 | "ignore::UserWarning"
224 | ]
225 |
--------------------------------------------------------------------------------
/pytorch_forecasting/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | PyTorch Forecasting package for timeseries forecasting with PyTorch.
3 | """
4 |
5 | __version__ = "1.3.0"
6 |
7 | from pytorch_forecasting.data import (
8 | EncoderNormalizer,
9 | GroupNormalizer,
10 | MultiNormalizer,
11 | NaNLabelEncoder,
12 | TimeSeriesDataSet,
13 | )
14 | from pytorch_forecasting.metrics import (
15 | MAE,
16 | MAPE,
17 | MASE,
18 | RMSE,
19 | SMAPE,
20 | BetaDistributionLoss,
21 | CrossEntropy,
22 | DistributionLoss,
23 | ImplicitQuantileNetworkDistributionLoss,
24 | LogNormalDistributionLoss,
25 | MQF2DistributionLoss,
26 | MultiHorizonMetric,
27 | MultiLoss,
28 | MultivariateNormalDistributionLoss,
29 | NegativeBinomialDistributionLoss,
30 | NormalDistributionLoss,
31 | PoissonLoss,
32 | QuantileLoss,
33 | )
34 | from pytorch_forecasting.models import (
35 | GRU,
36 | LSTM,
37 | AutoRegressiveBaseModel,
38 | AutoRegressiveBaseModelWithCovariates,
39 | Baseline,
40 | BaseModel,
41 | BaseModelWithCovariates,
42 | DecoderMLP,
43 | DeepAR,
44 | MultiEmbedding,
45 | NBeats,
46 | NHiTS,
47 | RecurrentNetwork,
48 | TemporalFusionTransformer,
49 | TiDEModel,
50 | get_rnn,
51 | )
52 | from pytorch_forecasting.utils import (
53 | apply_to_list,
54 | autocorrelation,
55 | create_mask,
56 | detach,
57 | get_embedding_size,
58 | groupby_apply,
59 | integer_histogram,
60 | move_to_device,
61 | profile,
62 | to_list,
63 | unpack_sequence,
64 | )
65 | from pytorch_forecasting.utils._maint._show_versions import show_versions
66 |
67 | __all__ = [
68 | "TimeSeriesDataSet",
69 | "GroupNormalizer",
70 | "EncoderNormalizer",
71 | "NaNLabelEncoder",
72 | "MultiNormalizer",
73 | "TemporalFusionTransformer",
74 | "TiDEModel",
75 | "NBeats",
76 | "NHiTS",
77 | "Baseline",
78 | "DeepAR",
79 | "BaseModel",
80 | "BaseModelWithCovariates",
81 | "AutoRegressiveBaseModel",
82 | "AutoRegressiveBaseModelWithCovariates",
83 | "MultiHorizonMetric",
84 | "MultiLoss",
85 | "MAE",
86 | "MAPE",
87 | "MASE",
88 | "SMAPE",
89 | "DistributionLoss",
90 | "BetaDistributionLoss",
91 | "LogNormalDistributionLoss",
92 | "NegativeBinomialDistributionLoss",
93 | "NormalDistributionLoss",
94 | "ImplicitQuantileNetworkDistributionLoss",
95 | "MultivariateNormalDistributionLoss",
96 | "MQF2DistributionLoss",
97 | "CrossEntropy",
98 | "PoissonLoss",
99 | "QuantileLoss",
100 | "RMSE",
101 | "get_rnn",
102 | "LSTM",
103 | "GRU",
104 | "MultiEmbedding",
105 | "apply_to_list",
106 | "autocorrelation",
107 | "get_embedding_size",
108 | "create_mask",
109 | "to_list",
110 | "RecurrentNetwork",
111 | "DecoderMLP",
112 | "detach",
113 | "move_to_device",
114 | "integer_histogram",
115 | "groupby_apply",
116 | "profile",
117 | "show_versions",
118 | "unpack_sequence",
119 | ]
120 |
--------------------------------------------------------------------------------
/pytorch_forecasting/_registry/__init__.py:
--------------------------------------------------------------------------------
1 | """PyTorch Forecasting registry."""
2 |
3 | from pytorch_forecasting._registry._lookup import all_objects
4 |
5 | __all__ = ["all_objects"]
6 |
--------------------------------------------------------------------------------
/pytorch_forecasting/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Datasets, etc. for timeseries data.
3 |
4 | Handling timeseries data is not trivial. It requires special treatment.
5 | This sub-package provides the necessary tools to abstract this work.
6 | """
7 |
8 | from pytorch_forecasting.data.encoders import (
9 | EncoderNormalizer,
10 | GroupNormalizer,
11 | MultiNormalizer,
12 | NaNLabelEncoder,
13 | TorchNormalizer,
14 | )
15 | from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler
16 | from pytorch_forecasting.data.timeseries import TimeSeries, TimeSeriesDataSet
17 |
18 | __all__ = [
19 | "TimeSeriesDataSet",
20 | "TimeSeries",
21 | "NaNLabelEncoder",
22 | "GroupNormalizer",
23 | "TorchNormalizer",
24 | "EncoderNormalizer",
25 | "TimeSynchronizedBatchSampler",
26 | "MultiNormalizer",
27 | ]
28 |
--------------------------------------------------------------------------------
/pytorch_forecasting/data/examples.py:
--------------------------------------------------------------------------------
1 | """
2 | Example datasets for tutorials and testing.
3 | """
4 |
5 | from pathlib import Path
6 | from urllib.request import urlretrieve
7 |
8 | import numpy as np
9 | import pandas as pd
10 |
11 | BASE_URL = "https://github.com/sktime/pytorch-forecasting/raw/main/examples/data/"
12 |
13 | DATA_PATH = Path(__file__).parent
14 |
15 |
16 | def _get_data_by_filename(fname: str) -> Path:
17 | """
18 |     Download file or use cached version.
19 |
20 | Args:
21 | fname (str): name of file to download
22 |
23 | Returns:
24 | Path: path at which file lives
25 | """
26 | full_fname = DATA_PATH.joinpath(fname)
27 |
28 | # check if file exists - download if necessary
29 | if not full_fname.exists():
30 | url = BASE_URL + fname
31 | urlretrieve(url, full_fname) # noqa: S310
32 |
33 | return full_fname
34 |
35 |
36 | def get_stallion_data() -> pd.DataFrame:
37 | """
38 | Demand data with covariates.
39 |
40 |     ~20k samples of 350 timeseries. Important columns:
41 |
42 | * Timeseries can be identified by ``agency`` and ``sku``.
43 | * ``volume`` is the demand
44 | * ``date`` is the month of the demand.
45 |
46 | Returns:
47 | pd.DataFrame: data
48 | """
49 | fname = _get_data_by_filename("stallion.parquet")
50 | return pd.read_parquet(fname)
51 |
52 |
53 | def generate_ar_data(
54 | n_series: int = 10,
55 | timesteps: int = 400,
56 | seasonality: float = 3.0,
57 | trend: float = 3.0,
58 | noise: float = 0.1,
59 | level: float = 1.0,
60 | exp: bool = False,
61 | seed: int = 213,
62 | ) -> pd.DataFrame:
63 | """
64 | Generate multivariate data without covariates.
65 |
66 |     Each timeseries is generated from seasonality and trend. Important columns:
67 |
68 | * ``series``: series ID
69 | * ``time_idx``: time index
70 | * ``value``: target value
71 |
72 | Args:
73 | n_series (int, optional): Number of series. Defaults to 10.
74 | timesteps (int, optional): Number of timesteps. Defaults to 400.
75 | seasonality (float, optional): Normalized frequency, i.e. frequency is ``seasonality / timesteps``.
76 | Defaults to 3.0.
77 | trend (float, optional): Trend multiplier (seasonality is multiplied with 1.0). Defaults to 3.0.
78 | noise (float, optional): Level of gaussian noise. Defaults to 0.1.
79 |         level (float, optional): Level multiplier (level is a constant to be added to timeseries). Defaults to 1.0.
80 |         exp (bool, optional): Whether to return the exponential of timeseries values. Defaults to False.
81 | seed (int, optional): Random seed. Defaults to 213.
82 |
83 | Returns:
84 | pd.DataFrame: data
85 | """ # noqa: E501
86 | # sample parameters
87 | np.random.seed(seed)
88 | linear_trends = np.random.normal(size=n_series)[:, None] / timesteps
89 | quadratic_trends = np.random.normal(size=n_series)[:, None] / timesteps**2
90 | seasonalities = np.random.normal(size=n_series)[:, None]
91 | levels = level * np.random.normal(size=n_series)[:, None]
92 |
93 | # generate series
94 | x = np.arange(timesteps)[None, :]
95 | series = (
96 | x * linear_trends + x**2 * quadratic_trends
97 | ) * trend + seasonalities * np.sin(2 * np.pi * seasonality * x / timesteps)
98 | # add noise
99 | series = levels * series * (1 + noise * np.random.normal(size=series.shape))
100 | if exp:
101 | series = np.exp(series)
102 |
103 | # insert into dataframe
104 | data = (
105 | pd.DataFrame(series)
106 | .stack()
107 | .reset_index()
108 | .rename(columns={"level_0": "series", "level_1": "time_idx", 0: "value"})
109 | )
110 |
111 | return data
112 |
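
A quick usage sketch of ``generate_ar_data``, based on the column names assigned above:

    from pytorch_forecasting.data.examples import generate_ar_data

    data = generate_ar_data(n_series=5, timesteps=100, seasonality=5.0, seed=42)
    print(data.columns.tolist())  # ['series', 'time_idx', 'value']
    print(data.shape)             # (500, 3): one row per series and time step
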
--------------------------------------------------------------------------------
/pytorch_forecasting/data/samplers.py:
--------------------------------------------------------------------------------
1 | """
2 | Samplers for sampling time series from the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`
3 | """ # noqa: E501
4 |
5 | import warnings
6 |
7 | import numpy as np
8 | import pandas as pd
9 | from sklearn.utils import shuffle
10 | from torch.utils.data.sampler import Sampler
11 |
12 |
13 | class GroupedSampler(Sampler):
14 | """
15 | Samples mini-batches randomly but in a grouped manner.
16 |
17 |     This means that items from the same group are always sampled together in a mini-batch.
18 | This is an abstract class. Implement the :py:meth:`~get_groups` method which creates groups to be sampled from.
19 | """ # noqa: E501
20 |
21 | def __init__(
22 | self,
23 | sampler: Sampler,
24 | batch_size: int = 64,
25 | shuffle: bool = False,
26 | drop_last: bool = False,
27 | ):
28 | """
29 | Initialize.
30 |
31 | Args:
32 | sampler (Sampler or Iterable): Base sampler. Can be any iterable object
33 |             drop_last (bool): whether to drop the last mini-batch from a group if it is smaller than batch_size.
34 |                 Defaults to False.
35 |             shuffle (bool): whether to shuffle the dataset. Defaults to False.
36 | batch_size (int, optional): Number of samples in a mini-batch. This is rather the maximum number
37 | of samples. Because mini-batches are grouped by prediction time, chances are that there
38 |                 are multiple batches where the actual size will be smaller than the maximum. Defaults to 64.
39 | """ # noqa: E501
40 | # Since collections.abc.Iterable does not check for `__getitem__`, which
41 | # is one way for an object to be an iterable, we don't do an `isinstance`
42 | # check here.
43 | if (
44 | not isinstance(batch_size, int)
45 | or isinstance(batch_size, bool)
46 | or batch_size <= 0
47 | ):
48 | raise ValueError(
49 | "batch_size should be a positive integer value, "
50 | f"but got batch_size={batch_size}"
51 | )
52 | if not isinstance(drop_last, bool):
53 | raise ValueError(
54 | f"drop_last should be a boolean value, but got drop_last={drop_last}"
55 | )
56 | self.sampler = sampler
57 | self.batch_size = batch_size
58 | self.drop_last = drop_last
59 | self.shuffle = shuffle
60 | # make groups and construct new index to sample from
61 | groups = self.get_groups(self.sampler)
62 | self.construct_batch_groups(groups)
63 |
64 | def get_groups(self, sampler: Sampler):
65 | """
66 | Create the groups which can be sampled.
67 |
68 | Args:
69 | sampler (Sampler): will have attribute data_source which is of type TimeSeriesDataSet.
70 |
71 | Returns:
72 | dict-like: dictionary-like object with data_source.index as values and group names as keys
73 | """ # noqa: E501
74 | raise NotImplementedError()
75 |
76 | def construct_batch_groups(self, groups):
77 | """
78 |         Construct the index of batches from which samples can be drawn.
79 | """
80 | self._groups = groups
81 | # calculate sizes of groups
82 | self._group_sizes = {}
83 | warns = []
84 | for name, group in self._groups.items(): # iterate over groups
85 | if self.drop_last:
86 | self._group_sizes[name] = len(group) // self.batch_size
87 | else:
88 | self._group_sizes[name] = (
89 | len(group) + self.batch_size - 1
90 | ) // self.batch_size
91 | if self._group_sizes[name] == 0:
92 | self._group_sizes[name] = 1
93 | warns.append(name)
94 | if len(warns) > 0:
95 | warnings.warn(
96 | f"Less than {self.batch_size} samples available for "
97 | f"{len(warns)} prediction times. "
98 | f"Use batch size smaller than {self.batch_size}. "
99 | f"First 10 prediction times with small batch sizes: {warns[:10]}"
100 | )
101 | # create index from which can be sampled: index is equal to number of batches
102 | # associate index with prediction time
103 | self._group_index = np.repeat(
104 | list(self._group_sizes.keys()), list(self._group_sizes.values())
105 | )
106 | # associate index with batch within prediction time group
107 | self._sub_group_index = np.concatenate(
108 | [np.arange(size) for size in self._group_sizes.values()]
109 | )
110 |
111 | def __iter__(self):
112 | if self.shuffle: # shuffle samples
113 | groups = {name: shuffle(group) for name, group in self._groups.items()}
114 | batch_samples = np.random.permutation(len(self))
115 | else:
116 | groups = self._groups
117 | batch_samples = np.arange(len(self))
118 |
119 | for idx in batch_samples:
120 | name = self._group_index[idx]
121 | sub_group = self._sub_group_index[idx]
122 | sub_group_start = sub_group * self.batch_size
123 | sub_group_end = sub_group_start + self.batch_size
124 | batch = groups[name][sub_group_start:sub_group_end]
125 | yield batch
126 |
127 | def __len__(self):
128 | return len(self._group_index)
129 |
130 |
131 | class TimeSynchronizedBatchSampler(GroupedSampler):
132 | """
133 | Samples mini-batches randomly but in a time-synchronised manner.
134 |
135 |     Time-synchronisation means that the time indices of the first decoder samples are aligned across the batch.
136 | This sampler does not support missing values in the dataset.
137 | """ # noqa: E501
138 |
139 | def get_groups(self, sampler: Sampler):
140 | data_source = sampler.data_source
141 | index = data_source.index
142 | # get groups, i.e. group all samples by first predict time
143 | last_time = data_source.data["time"][index["index_end"].to_numpy()].numpy()
144 | decoder_lengths = data_source.calculate_decoder_length(
145 | last_time, index.sequence_length
146 | )
147 | first_prediction_time = index.time + index.sequence_length - decoder_lengths + 1
148 | groups = pd.RangeIndex(0, len(index.index)).groupby(first_prediction_time)
149 | return groups
150 |
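
A usage sketch for the time-synchronised sampler; it assumes that ``training`` is a ``TimeSeriesDataSet`` and that ``to_dataloader`` accepts the ``"synchronized"`` alias for ``TimeSynchronizedBatchSampler``:

    # draw mini-batches whose first prediction times are aligned (sketch)
    dataloader = training.to_dataloader(
        train=True,
        batch_size=64,
        batch_sampler="synchronized",
    )
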
--------------------------------------------------------------------------------
/pytorch_forecasting/data/timeseries/__init__.py:
--------------------------------------------------------------------------------
1 | """Data loaders for time series data."""
2 |
3 | from pytorch_forecasting.data.timeseries._timeseries import (
4 | TimeSeriesDataSet,
5 | _find_end_indices,
6 | check_for_nonfinite,
7 | )
8 | from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries
9 |
10 | __all__ = [
11 | "_find_end_indices",
12 | "check_for_nonfinite",
13 | "TimeSeriesDataSet",
14 | "TimeSeries",
15 | ]
16 |
--------------------------------------------------------------------------------
/pytorch_forecasting/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Metrics for (multi-horizon) timeseries forecasting.
3 | """
4 |
5 | from pytorch_forecasting.metrics.base_metrics import (
6 | DistributionLoss,
7 | Metric,
8 | MultiHorizonMetric,
9 | MultiLoss,
10 | MultivariateDistributionLoss,
11 | convert_torchmetric_to_pytorch_forecasting_metric,
12 | )
13 | from pytorch_forecasting.metrics.distributions import (
14 | BetaDistributionLoss,
15 | ImplicitQuantileNetworkDistributionLoss,
16 | LogNormalDistributionLoss,
17 | MQF2DistributionLoss,
18 | MultivariateNormalDistributionLoss,
19 | NegativeBinomialDistributionLoss,
20 | NormalDistributionLoss,
21 | )
22 | from pytorch_forecasting.metrics.point import (
23 | MAE,
24 | MAPE,
25 | MASE,
26 | RMSE,
27 | SMAPE,
28 | CrossEntropy,
29 | PoissonLoss,
30 | TweedieLoss,
31 | )
32 | from pytorch_forecasting.metrics.quantile import QuantileLoss
33 |
34 | __all__ = [
35 | "MultiHorizonMetric",
36 | "DistributionLoss",
37 | "MultivariateDistributionLoss",
38 | "MultiLoss",
39 | "Metric",
40 | "convert_torchmetric_to_pytorch_forecasting_metric",
41 | "MAE",
42 | "MAPE",
43 | "MASE",
44 | "PoissonLoss",
45 | "TweedieLoss",
46 | "CrossEntropy",
47 | "SMAPE",
48 | "RMSE",
49 | "BetaDistributionLoss",
50 | "NegativeBinomialDistributionLoss",
51 | "NormalDistributionLoss",
52 | "LogNormalDistributionLoss",
53 | "MultivariateNormalDistributionLoss",
54 | "ImplicitQuantileNetworkDistributionLoss",
55 | "QuantileLoss",
56 | "MQF2DistributionLoss",
57 | ]
58 |
--------------------------------------------------------------------------------
/pytorch_forecasting/metrics/quantile.py:
--------------------------------------------------------------------------------
1 | """Quantile metrics for forecasting multiple quantiles per time step."""
2 |
3 | from typing import Optional
4 |
5 | import torch
6 |
7 | from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric
8 |
9 |
10 | class QuantileLoss(MultiHorizonMetric):
11 | """
12 |     Quantile loss, i.e. a quantile of ``q=0.5`` will give half of the mean absolute error,
13 |     as it is calculated as ``max(q * (y - y_pred), (1 - q) * (y_pred - y))``.
14 |
15 | """ # noqa: E501
16 |
17 | def __init__(
18 | self,
19 | quantiles: Optional[list[float]] = None,
20 | **kwargs,
21 | ):
22 | """
23 | Quantile loss
24 |
25 | Args:
26 | quantiles: quantiles for metric
27 | """
28 | if quantiles is None:
29 | quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
30 | super().__init__(quantiles=quantiles, **kwargs)
31 |
32 | def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
33 | # calculate quantile loss
34 | losses = []
35 | for i, q in enumerate(self.quantiles):
36 | errors = target - y_pred[..., i]
37 | losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))
38 | losses = 2 * torch.cat(losses, dim=2)
39 |
40 | return losses
41 |
42 | def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
43 | """
44 | Convert network prediction into a point prediction.
45 |
46 | Args:
47 | y_pred: prediction output of network
48 |
49 | Returns:
50 | torch.Tensor: point prediction
51 | """
52 | if y_pred.ndim == 3:
53 | idx = self.quantiles.index(0.5)
54 | y_pred = y_pred[..., idx]
55 | return y_pred
56 |
57 | def to_quantiles(self, y_pred: torch.Tensor) -> torch.Tensor:
58 | """
59 | Convert network prediction into a quantile prediction.
60 |
61 | Args:
62 | y_pred: prediction output of network
63 |
64 | Returns:
65 | torch.Tensor: prediction quantiles
66 | """
67 | return y_pred
68 |
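
A minimal sketch of the tensor shapes ``QuantileLoss`` expects:

    import torch

    from pytorch_forecasting.metrics import QuantileLoss

    loss = QuantileLoss()  # seven default quantiles
    y_pred = torch.randn(4, 6, len(loss.quantiles))  # samples x horizon x quantiles
    y_actual = torch.randn(4, 6)                     # samples x horizon
    print(loss(y_pred, y_actual))                    # scalar loss value
    print(loss.to_prediction(y_pred).shape)          # torch.Size([4, 6]), median forecast
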
--------------------------------------------------------------------------------
/pytorch_forecasting/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Models for timeseries forecasting.
3 | """
4 |
5 | from pytorch_forecasting.models.base import (
6 | AutoRegressiveBaseModel,
7 | AutoRegressiveBaseModelWithCovariates,
8 | BaseModel,
9 | BaseModelWithCovariates,
10 | )
11 | from pytorch_forecasting.models.baseline import Baseline
12 | from pytorch_forecasting.models.deepar import DeepAR
13 | from pytorch_forecasting.models.mlp import DecoderMLP
14 | from pytorch_forecasting.models.nbeats import NBeats
15 | from pytorch_forecasting.models.nhits import NHiTS
16 | from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn
17 | from pytorch_forecasting.models.rnn import RecurrentNetwork
18 | from pytorch_forecasting.models.temporal_fusion_transformer import (
19 | TemporalFusionTransformer,
20 | )
21 | from pytorch_forecasting.models.tide import TiDEModel
22 |
23 | __all__ = [
24 | "NBeats",
25 | "NHiTS",
26 | "TemporalFusionTransformer",
27 | "RecurrentNetwork",
28 | "DeepAR",
29 | "BaseModel",
30 | "Baseline",
31 | "BaseModelWithCovariates",
32 | "AutoRegressiveBaseModel",
33 | "AutoRegressiveBaseModelWithCovariates",
34 | "get_rnn",
35 | "LSTM",
36 | "GRU",
37 | "MultiEmbedding",
38 | "DecoderMLP",
39 | "TiDEModel",
40 | ]
41 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/base/__init__.py:
--------------------------------------------------------------------------------
1 | """Base classes for pytorch-foercasting models."""
2 |
3 | from pytorch_forecasting.models.base._base_model import (
4 | AutoRegressiveBaseModel,
5 | AutoRegressiveBaseModelWithCovariates,
6 | BaseModel,
7 | BaseModelWithCovariates,
8 | Prediction,
9 | )
10 | from pytorch_forecasting.models.base._base_object import (
11 | _BaseObject,
12 | _BasePtForecaster,
13 | )
14 |
15 | __all__ = [
16 | "_BaseObject",
17 | "_BasePtForecaster",
18 | "AutoRegressiveBaseModel",
19 | "AutoRegressiveBaseModelWithCovariates",
20 | "BaseModel",
21 | "BaseModelWithCovariates",
22 | "Prediction",
23 | ]
24 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/base/_base_object.py:
--------------------------------------------------------------------------------
1 | """Base Classes for pytorch-forecasting models, skbase compatible for indexing."""
2 |
3 | import inspect
4 |
5 | from pytorch_forecasting.utils._dependencies import _safe_import
6 |
7 | _SkbaseBaseObject = _safe_import("skbase.base.BaseObject", pkg_name="scikit-base")
8 |
9 |
10 | class _BaseObject(_SkbaseBaseObject):
11 | pass
12 |
13 |
14 | class _BasePtForecaster(_BaseObject):
15 | """Base class for all PyTorch Forecasting forecaster metadata.
16 |
17 | This class points to model objects and contains metadata as tags.
18 | """
19 |
20 | _tags = {
21 | "object_type": "forecaster_pytorch",
22 | }
23 |
24 | @classmethod
25 | def get_model_cls(cls):
26 | """Get model class."""
27 | raise NotImplementedError
28 |
29 | @classmethod
30 | def name(cls):
31 | """Get model name."""
32 | name = cls.get_class_tags().get("info:name", None)
33 | if name is None:
34 | name = cls.get_model_cls().__name__
35 | return name
36 |
37 | @classmethod
38 | def create_test_instance(cls, parameter_set="default"):
39 | """Construct an instance of the class, using first test parameter set.
40 |
41 | Parameters
42 | ----------
43 | parameter_set : str, default="default"
44 | Name of the set of test parameters to return, for use in tests. If no
45 | special parameters are defined for a value, will return `"default"` set.
46 |
47 | Returns
48 | -------
49 | instance : instance of the class with default parameters
50 |
51 | """
52 | if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
53 | params = cls.get_test_params(parameter_set=parameter_set)
54 | else:
55 | params = cls.get_test_params()
56 |
57 | if isinstance(params, list) and isinstance(params[0], dict):
58 | params = params[0]
59 | elif isinstance(params, dict):
60 | pass
61 | else:
62 | raise TypeError(
63 | "get_test_params should either return a dict or list of dict."
64 | )
65 |
66 | return cls.get_model_cls()(**params)
67 |
68 | @classmethod
69 | def create_test_instances_and_names(cls, parameter_set="default"):
70 | """Create list of all test instances and a list of names for them.
71 |
72 | Parameters
73 | ----------
74 | parameter_set : str, default="default"
75 | Name of the set of test parameters to return, for use in tests. If no
76 | special parameters are defined for a value, will return `"default"` set.
77 |
78 | Returns
79 | -------
80 | objs : list of instances of cls
81 | i-th instance is ``cls(**cls.get_test_params()[i])``
82 | names : list of str, same length as objs
83 | i-th element is name of i-th instance of obj in tests.
84 | The naming convention is ``{cls.__name__}-{i}`` if more than one instance,
85 | otherwise ``{cls.__name__}``
86 | """
87 | if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args:
88 | param_list = cls.get_test_params(parameter_set=parameter_set)
89 | else:
90 | param_list = cls.get_test_params()
91 |
92 | objs = []
93 | if not isinstance(param_list, (dict, list)):
94 | raise RuntimeError(
95 | f"Error in {cls.__name__}.get_test_params, "
96 | "return must be param dict for class, or list thereof"
97 | )
98 | if isinstance(param_list, dict):
99 | param_list = [param_list]
100 | for params in param_list:
101 | if not isinstance(params, dict):
102 | raise RuntimeError(
103 | f"Error in {cls.__name__}.get_test_params, "
104 | "return must be param dict for class, or list thereof"
105 | )
106 | objs += [cls.get_model_cls()(**params)]
107 |
108 | num_instances = len(param_list)
109 | if num_instances > 1:
110 | names = [cls.__name__ + "-" + str(i) for i in range(num_instances)]
111 | else:
112 | names = [cls.__name__]
113 |
114 | return objs, names
115 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/base_model.py:
--------------------------------------------------------------------------------
1 | """Base classes for pytorch-foercasting models."""
2 |
3 | from pytorch_forecasting.models.base import (
4 | AutoRegressiveBaseModel,
5 | AutoRegressiveBaseModelWithCovariates,
6 | BaseModel,
7 | BaseModelWithCovariates,
8 | Prediction,
9 | )
10 |
11 | __all__ = [
12 | "AutoRegressiveBaseModel",
13 | "AutoRegressiveBaseModelWithCovariates",
14 | "BaseModel",
15 | "BaseModelWithCovariates",
16 | "Prediction",
17 | ]
18 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/baseline.py:
--------------------------------------------------------------------------------
1 | """
2 | Baseline model.
3 | """
4 |
5 | from typing import Any
6 |
7 | import torch
8 |
9 | from pytorch_forecasting.models import BaseModel
10 |
11 |
12 | class Baseline(BaseModel):
13 | """
14 | Baseline model that uses last known target value to make prediction.
15 |
16 | Example:
17 |
18 | .. code-block:: python
19 |
20 | from pytorch_forecasting import BaseModel, MAE
21 |
22 | # generating predictions
23 | predictions = Baseline().predict(dataloader)
24 |
25 | # calculate baseline performance in terms of mean absolute error (MAE)
26 | metric = MAE()
27 | model = Baseline()
28 | for x, y in dataloader:
29 | metric.update(model(x), y)
30 |
31 | metric.compute()
32 | """
33 |
34 | def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
35 | """
36 | Network forward pass.
37 |
38 | Args:
39 | x (Dict[str, torch.Tensor]): network input
40 |
41 | Returns:
42 |             Dict[str, torch.Tensor]: network outputs
43 | """
44 | if isinstance(x["encoder_target"], (list, tuple)): # multiple targets
45 | prediction = [
46 | self.forward_one_target(
47 | encoder_lengths=x["encoder_lengths"],
48 | decoder_lengths=x["decoder_lengths"],
49 | encoder_target=encoder_target,
50 | )
51 | for encoder_target in x["encoder_target"]
52 | ]
53 | else: # one target
54 | prediction = self.forward_one_target(
55 | encoder_lengths=x["encoder_lengths"],
56 | decoder_lengths=x["decoder_lengths"],
57 | encoder_target=x["encoder_target"],
58 | )
59 | return self.to_network_output(prediction=prediction)
60 |
61 | def forward_one_target(
62 | self,
63 | encoder_lengths: torch.Tensor,
64 | decoder_lengths: torch.Tensor,
65 | encoder_target: torch.Tensor,
66 | ):
67 | max_prediction_length = decoder_lengths.max()
68 | assert (
69 | encoder_lengths.min() > 0
70 | ), "Encoder lengths of at least 1 required to obtain last value"
71 | last_values = encoder_target[
72 | torch.arange(encoder_target.size(0)), encoder_lengths - 1
73 | ]
74 | prediction = last_values[:, None].expand(-1, max_prediction_length)
75 | return prediction
76 |
77 | def to_prediction(self, out: dict[str, Any], use_metric: bool = True, **kwargs):
78 | return out.prediction
79 |
80 | def to_quantiles(self, out: dict[str, Any], use_metric: bool = True, **kwargs):
81 | return out.prediction[..., None]
82 |
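
A toy sketch of what ``forward_one_target`` computes for two series with different encoder lengths:

    import torch

    encoder_target = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]])
    encoder_lengths = torch.tensor([3, 2])
    # last observed value of each series: tensor([3., 5.])
    last_values = encoder_target[torch.arange(2), encoder_lengths - 1]
    # repeat the last value over a horizon of 4 steps -> shape (2, 4)
    prediction = last_values[:, None].expand(-1, 4)
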
--------------------------------------------------------------------------------
/pytorch_forecasting/models/deepar/__init__.py:
--------------------------------------------------------------------------------
1 | """DeepAR: Probabilistic forecasting with autoregressive recurrent networks."""
2 |
3 | from pytorch_forecasting.models.deepar._deepar import DeepAR
4 | from pytorch_forecasting.models.deepar._deepar_metadata import DeepARMetadata
5 |
6 | __all__ = ["DeepAR", "DeepARMetadata"]
7 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/deepar/_deepar_metadata.py:
--------------------------------------------------------------------------------
1 | """DeepAR metadata container."""
2 |
3 | from pytorch_forecasting.models.base._base_object import _BasePtForecaster
4 |
5 |
6 | class DeepARMetadata(_BasePtForecaster):
7 | """DeepAR metadata container."""
8 |
9 | _tags = {
10 | "info:name": "DeepAR",
11 | "info:compute": 3,
12 | "authors": ["jdb78"],
13 | "capability:exogenous": True,
14 | "capability:multivariate": True,
15 | "capability:pred_int": True,
16 | "capability:flexible_history_length": True,
17 | "capability:cold_start": False,
18 | }
19 |
20 | @classmethod
21 | def get_model_cls(cls):
22 | """Get model class."""
23 | from pytorch_forecasting.models import DeepAR
24 |
25 | return DeepAR
26 |
27 | @classmethod
28 | def get_test_train_params(cls):
29 | """Return testing parameter settings for the trainer.
30 |
31 | Returns
32 | -------
33 | params : dict or list of dict, default = {}
34 | Parameters to create testing instances of the class
35 |         Each dict contains parameters to construct an "interesting" test instance, i.e.,
36 | `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
37 | `create_test_instance` uses the first (or only) dictionary in `params`
38 | """
39 | from pytorch_forecasting.data.encoders import GroupNormalizer
40 | from pytorch_forecasting.metrics import (
41 | BetaDistributionLoss,
42 | ImplicitQuantileNetworkDistributionLoss,
43 | LogNormalDistributionLoss,
44 | MultivariateNormalDistributionLoss,
45 | NegativeBinomialDistributionLoss,
46 | )
47 |
48 | params = [
49 | {},
50 | {"cell_type": "GRU"},
51 | dict(
52 | loss=LogNormalDistributionLoss(),
53 | clip_target=True,
54 | data_loader_kwargs=dict(
55 | target_normalizer=GroupNormalizer(
56 | groups=["agency", "sku"], transformation="log"
57 | )
58 | ),
59 | ),
60 | dict(
61 | loss=NegativeBinomialDistributionLoss(),
62 | clip_target=False,
63 | data_loader_kwargs=dict(
64 | target_normalizer=GroupNormalizer(
65 | groups=["agency", "sku"], center=False
66 | )
67 | ),
68 | ),
69 | dict(
70 | loss=BetaDistributionLoss(),
71 | clip_target=True,
72 | data_loader_kwargs=dict(
73 | target_normalizer=GroupNormalizer(
74 | groups=["agency", "sku"], transformation="logit"
75 | )
76 | ),
77 | ),
78 | dict(
79 | data_loader_kwargs=dict(
80 | lags={"volume": [2, 5]},
81 | target="volume",
82 | time_varying_unknown_reals=["volume"],
83 | min_encoder_length=2,
84 | ),
85 | ),
86 | dict(
87 | data_loader_kwargs=dict(
88 | time_varying_unknown_reals=["volume", "discount"],
89 | target=["volume", "discount"],
90 | lags={"volume": [2], "discount": [2]},
91 | ),
92 | ),
93 | dict(
94 | loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8),
95 | ),
96 | dict(
97 | loss=MultivariateNormalDistributionLoss(),
98 | trainer_kwargs=dict(accelerator="cpu"),
99 | ),
100 | dict(
101 | loss=MultivariateNormalDistributionLoss(),
102 | data_loader_kwargs=dict(
103 | target_normalizer=GroupNormalizer(
104 | groups=["agency", "sku"], transformation="log1p"
105 | )
106 | ),
107 | trainer_kwargs=dict(accelerator="cpu"),
108 | ),
109 | ]
110 | defaults = {
111 | "hidden_size": 5,
112 | "cell_type": "LSTM",
113 | "n_plotting_samples": 100,
114 | }
115 | for param in params:
116 | param.update(defaults)
117 | return params
118 |
119 | @classmethod
120 | def _get_test_dataloaders_from(cls, params):
121 | """Get dataloaders from parameters.
122 |
123 | Parameters
124 | ----------
125 | params : dict
126 | Parameters to create dataloaders.
127 | One of the elements in the list returned by ``get_test_train_params``.
128 |
129 | Returns
130 | -------
131 | dataloaders : dict with keys "train", "val", "test", values torch DataLoader
132 | Dict of dataloaders created from the parameters.
133 | Train, validation, and test dataloaders.
134 | """
135 | loss = params.get("loss", None)
136 | clip_target = params.get("clip_target", False)
137 | data_loader_kwargs = params.get("data_loader_kwargs", {})
138 |
139 | from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss
140 | from pytorch_forecasting.tests._conftest import make_dataloaders
141 | from pytorch_forecasting.tests._data_scenarios import data_with_covariates
142 |
143 | dwc = data_with_covariates()
144 |
145 | if isinstance(loss, NegativeBinomialDistributionLoss):
146 | dwc = dwc.assign(volume=lambda x: x.volume.round())
147 |
148 | dwc = dwc.copy()
149 | if clip_target:
150 | dwc["target"] = dwc["volume"].clip(1e-3, 1.0)
151 | else:
152 | dwc["target"] = dwc["volume"]
153 | data_loader_default_kwargs = dict(
154 | target="target",
155 | time_varying_known_reals=["price_actual"],
156 | time_varying_unknown_reals=["target"],
157 | static_categoricals=["agency"],
158 | add_relative_time_idx=True,
159 | )
160 | data_loader_default_kwargs.update(data_loader_kwargs)
161 | dataloaders_w_covariates = make_dataloaders(dwc, **data_loader_default_kwargs)
162 | return dataloaders_w_covariates
163 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/mlp/__init__.py:
--------------------------------------------------------------------------------
1 | """Simple models based on fully connected networks."""
2 |
3 | from pytorch_forecasting.models.mlp._decodermlp import DecoderMLP
4 | from pytorch_forecasting.models.mlp.submodules import FullyConnectedModule
5 |
6 | __all__ = ["DecoderMLP", "FullyConnectedModule"]
7 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/mlp/submodules.py:
--------------------------------------------------------------------------------
1 | """
2 | MLP implementation
3 | """
4 |
5 | import torch
6 | from torch import nn
7 |
8 |
9 | class FullyConnectedModule(nn.Module):
10 | def __init__(
11 | self,
12 | input_size: int,
13 | output_size: int,
14 | hidden_size: int,
15 | n_hidden_layers: int,
16 | activation_class: nn.ReLU,
17 | dropout: float = None,
18 | norm: bool = True,
19 | ):
20 | super().__init__()
21 | self.input_size = input_size
22 | self.output_size = output_size
23 | self.hidden_size = hidden_size
24 | self.n_hidden_layers = n_hidden_layers
25 | self.activation_class = activation_class
26 | self.dropout = dropout
27 | self.norm = norm
28 |
29 | # input layer
30 | module_list = [nn.Linear(input_size, hidden_size), activation_class()]
31 | if dropout is not None:
32 | module_list.append(nn.Dropout(dropout))
33 | if norm:
34 | module_list.append(nn.LayerNorm(hidden_size))
35 | # hidden layers
36 | for _ in range(n_hidden_layers):
37 | module_list.extend(
38 | [nn.Linear(hidden_size, hidden_size), activation_class()]
39 | )
40 | if dropout is not None:
41 | module_list.append(nn.Dropout(dropout))
42 | if norm:
43 | module_list.append(nn.LayerNorm(hidden_size))
44 | # output layer
45 | module_list.append(nn.Linear(hidden_size, output_size))
46 |
47 | self.sequential = nn.Sequential(*module_list)
48 |
49 | def forward(self, x: torch.Tensor) -> torch.Tensor:
50 | # x of shape: batch_size x n_timesteps_in
51 | # output of shape batch_size x n_timesteps_out
52 | return self.sequential(x)
53 |
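
A minimal usage sketch for FullyConnectedModule (not part of the file above; shapes follow the comments in forward):

    import torch
    from pytorch_forecasting.models.mlp.submodules import FullyConnectedModule

    # 20 input timesteps mapped to 5 output timesteps through 2 hidden layers
    net = FullyConnectedModule(
        input_size=20,
        output_size=5,
        hidden_size=32,
        n_hidden_layers=2,
        activation_class=torch.nn.ReLU,
        dropout=0.1,
        norm=True,
    )
    x = torch.randn(64, 20)   # batch_size x n_timesteps_in
    y = net(x)                # batch_size x n_timesteps_out -> torch.Size([64, 5])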
--------------------------------------------------------------------------------
/pytorch_forecasting/models/nbeats/__init__.py:
--------------------------------------------------------------------------------
1 | """N-Beats model for timeseries forecasting without covariates."""
2 |
3 | from pytorch_forecasting.models.nbeats._nbeats import NBeats
4 | from pytorch_forecasting.models.nbeats._nbeats_metadata import NBeatsMetadata
5 | from pytorch_forecasting.models.nbeats.sub_modules import (
6 | NBEATSGenericBlock,
7 | NBEATSSeasonalBlock,
8 | NBEATSTrendBlock,
9 | )
10 |
11 | __all__ = [
12 | "NBeats",
13 | "NBEATSGenericBlock",
14 | "NBeatsMetadata",
15 | "NBEATSSeasonalBlock",
16 | "NBEATSTrendBlock",
17 | ]
18 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/nbeats/_nbeats_metadata.py:
--------------------------------------------------------------------------------
1 | """NBeats metadata container."""
2 |
3 | from pytorch_forecasting.models.base._base_object import _BasePtForecaster
4 |
5 |
6 | class NBeatsMetadata(_BasePtForecaster):
7 | """NBeats metadata container."""
8 |
9 | _tags = {
10 | "info:name": "NBeats",
11 | "info:compute": 1,
12 | "authors": ["jdb78"],
13 | "capability:exogenous": False,
14 | "capability:multivariate": False,
15 | "capability:pred_int": False,
16 | "capability:flexible_history_length": False,
17 | "capability:cold_start": False,
18 | }
19 |
20 | @classmethod
21 | def get_model_cls(cls):
22 | """Get model class."""
23 | from pytorch_forecasting.models import NBeats
24 |
25 | return NBeats
26 |
27 | @classmethod
28 | def get_test_train_params(cls):
29 | """Return testing parameter settings for the trainer.
30 |
31 | Returns
32 | -------
33 | params : dict or list of dict, default = {}
34 | Parameters to create testing instances of the class
35 |             Each dict contains parameters to construct an "interesting" test instance, i.e.,
36 | `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
37 | `create_test_instance` uses the first (or only) dictionary in `params`
38 | """
39 | return [{"backcast_loss_ratio": 1.0}]
40 |
41 | @classmethod
42 | def _get_test_dataloaders_from(cls, params):
43 | """Get dataloaders from parameters.
44 |
45 | Parameters
46 | ----------
47 | params : dict
48 | Parameters to create dataloaders.
49 | One of the elements in the list returned by ``get_test_train_params``.
50 |
51 | Returns
52 | -------
53 | dataloaders : dict with keys "train", "val", "test", values torch DataLoader
54 | Dict of dataloaders created from the parameters.
55 | Train, validation, and test dataloaders, in this order.
56 | """
57 | from pytorch_forecasting.tests._conftest import (
58 | _dataloaders_fixed_window_without_covariates,
59 | )
60 |
61 | return _dataloaders_fixed_window_without_covariates()
62 |
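
For orientation, a short sketch of how this metadata contract is typically consumed by the test framework (an assumption based on the method docstrings above, not code from the repository):

    from pytorch_forecasting.models.nbeats import NBeatsMetadata

    model_cls = NBeatsMetadata.get_model_cls()        # -> NBeats
    params = NBeatsMetadata.get_test_train_params()   # -> [{"backcast_loss_ratio": 1.0}]
    loaders = NBeatsMetadata._get_test_dataloaders_from(params[0])
    # loaders["train"], loaders["val"], loaders["test"] are torch DataLoaders;
    # the model is then built from the training dataset, e.g.
    # net = model_cls.from_dataset(loaders["train"].dataset, **params[0])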
--------------------------------------------------------------------------------
/pytorch_forecasting/models/nbeats/sub_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of ``nn.Modules`` for N-Beats model.
3 | """
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def linear(input_size, output_size, bias=True, dropout: float = None):
12 | lin = nn.Linear(input_size, output_size, bias=bias)
13 | if dropout is not None:
14 | return nn.Sequential(nn.Dropout(dropout), lin)
15 | else:
16 | return lin
17 |
18 |
19 | def linspace(
20 | backcast_length: int, forecast_length: int, centered: bool = False
21 | ) -> tuple[np.ndarray, np.ndarray]:
22 | if centered:
23 | norm = max(backcast_length, forecast_length)
24 | start = -backcast_length
25 | stop = forecast_length - 1
26 | else:
27 | norm = backcast_length + forecast_length
28 | start = 0
29 | stop = backcast_length + forecast_length - 1
30 | lin_space = np.linspace(
31 | start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float32
32 | )
33 | b_ls = lin_space[:backcast_length]
34 | f_ls = lin_space[backcast_length:]
35 | return b_ls, f_ls
36 |
37 |
38 | class NBEATSBlock(nn.Module):
39 | def __init__(
40 | self,
41 | units,
42 | thetas_dim,
43 | num_block_layers=4,
44 | backcast_length=10,
45 | forecast_length=5,
46 | share_thetas=False,
47 | dropout=0.1,
48 | ):
49 | super().__init__()
50 | self.units = units
51 | self.thetas_dim = thetas_dim
52 | self.backcast_length = backcast_length
53 | self.forecast_length = forecast_length
54 | self.share_thetas = share_thetas
55 |
56 | fc_stack = [
57 | nn.Linear(backcast_length, units),
58 | nn.ReLU(),
59 | ]
60 | for _ in range(num_block_layers - 1):
61 | fc_stack.extend([linear(units, units, dropout=dropout), nn.ReLU()])
62 | self.fc = nn.Sequential(*fc_stack)
63 |
64 | if share_thetas:
65 | self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False)
66 | else:
67 | self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False)
68 | self.theta_f_fc = nn.Linear(units, thetas_dim, bias=False)
69 |
70 | def forward(self, x):
71 | return self.fc(x)
72 |
73 |
74 | class NBEATSSeasonalBlock(NBEATSBlock):
75 | def __init__(
76 | self,
77 | units,
78 | thetas_dim=None,
79 | num_block_layers=4,
80 | backcast_length=10,
81 | forecast_length=5,
82 | nb_harmonics=None,
83 | min_period=1,
84 | dropout=0.1,
85 | ):
86 | if nb_harmonics:
87 | thetas_dim = nb_harmonics
88 | else:
89 | thetas_dim = forecast_length
90 | self.min_period = min_period
91 |
92 | super().__init__(
93 | units=units,
94 | thetas_dim=thetas_dim,
95 | num_block_layers=num_block_layers,
96 | backcast_length=backcast_length,
97 | forecast_length=forecast_length,
98 | share_thetas=True,
99 | dropout=dropout,
100 | )
101 |
102 | backcast_linspace, forecast_linspace = linspace(
103 | backcast_length, forecast_length, centered=False
104 | )
105 |
106 | p1, p2 = (
107 | (thetas_dim // 2, thetas_dim // 2)
108 | if thetas_dim % 2 == 0
109 | else (thetas_dim // 2, thetas_dim // 2 + 1)
110 | )
111 | s1_b = torch.tensor(
112 | np.cos(2 * np.pi * self.get_frequencies(p1)[:, None] * backcast_linspace),
113 | dtype=torch.float32,
114 | ) # H/2-1
115 | s2_b = torch.tensor(
116 | np.sin(2 * np.pi * self.get_frequencies(p2)[:, None] * backcast_linspace),
117 | dtype=torch.float32,
118 | )
119 | self.register_buffer("S_backcast", torch.cat([s1_b, s2_b]))
120 |
121 | s1_f = torch.tensor(
122 | np.cos(2 * np.pi * self.get_frequencies(p1)[:, None] * forecast_linspace),
123 | dtype=torch.float32,
124 | ) # H/2-1
125 | s2_f = torch.tensor(
126 | np.sin(2 * np.pi * self.get_frequencies(p2)[:, None] * forecast_linspace),
127 | dtype=torch.float32,
128 | )
129 | self.register_buffer("S_forecast", torch.cat([s1_f, s2_f]))
130 |
131 | def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]:
132 | x = super().forward(x)
133 | amplitudes_backward = self.theta_b_fc(x)
134 | backcast = amplitudes_backward.mm(self.S_backcast)
135 | amplitudes_forward = self.theta_f_fc(x)
136 | forecast = amplitudes_forward.mm(self.S_forecast)
137 |
138 | return backcast, forecast
139 |
140 | def get_frequencies(self, n):
141 | return np.linspace(
142 | 0, (self.backcast_length + self.forecast_length) / self.min_period, n
143 | )
144 |
145 |
146 | class NBEATSTrendBlock(NBEATSBlock):
147 | def __init__(
148 | self,
149 | units,
150 | thetas_dim,
151 | num_block_layers=4,
152 | backcast_length=10,
153 | forecast_length=5,
154 | dropout=0.1,
155 | ):
156 | super().__init__(
157 | units=units,
158 | thetas_dim=thetas_dim,
159 | num_block_layers=num_block_layers,
160 | backcast_length=backcast_length,
161 | forecast_length=forecast_length,
162 | share_thetas=True,
163 | dropout=dropout,
164 | )
165 |
166 | backcast_linspace, forecast_linspace = linspace(
167 | backcast_length, forecast_length, centered=True
168 | )
169 | norm = np.sqrt(
170 | forecast_length / thetas_dim
171 | ) # ensure range of predictions is comparable to input
172 | thetas_dims_range = np.array(range(thetas_dim))
173 | coefficients = torch.tensor(
174 | backcast_linspace ** thetas_dims_range[:, None],
175 | dtype=torch.float32,
176 | )
177 | self.register_buffer("T_backcast", coefficients * norm)
178 | coefficients = torch.tensor(
179 | forecast_linspace ** thetas_dims_range[:, None],
180 | dtype=torch.float32,
181 | )
182 | self.register_buffer("T_forecast", coefficients * norm)
183 |
184 | def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]:
185 | x = super().forward(x)
186 | backcast = self.theta_b_fc(x).mm(self.T_backcast)
187 | forecast = self.theta_f_fc(x).mm(self.T_forecast)
188 | return backcast, forecast
189 |
190 |
191 | class NBEATSGenericBlock(NBEATSBlock):
192 | def __init__(
193 | self,
194 | units,
195 | thetas_dim,
196 | num_block_layers=4,
197 | backcast_length=10,
198 | forecast_length=5,
199 | dropout=0.1,
200 | ):
201 | super().__init__(
202 | units=units,
203 | thetas_dim=thetas_dim,
204 | num_block_layers=num_block_layers,
205 | backcast_length=backcast_length,
206 | forecast_length=forecast_length,
207 | dropout=dropout,
208 | )
209 |
210 | self.backcast_fc = nn.Linear(thetas_dim, backcast_length)
211 | self.forecast_fc = nn.Linear(thetas_dim, forecast_length)
212 |
213 | def forward(self, x):
214 | x = super().forward(x)
215 |
216 | theta_b = F.relu(self.theta_b_fc(x))
217 | theta_f = F.relu(self.theta_f_fc(x))
218 |
219 | return self.backcast_fc(theta_b), self.forecast_fc(theta_f)
220 |
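
A small shape check for the generic block above (illustrative only, not part of the file):

    import torch
    from pytorch_forecasting.models.nbeats.sub_modules import NBEATSGenericBlock

    block = NBEATSGenericBlock(
        units=16, thetas_dim=8, backcast_length=10, forecast_length=5
    )
    x = torch.randn(32, 10)        # batch x backcast_length
    backcast, forecast = block(x)  # shapes: (32, 10) and (32, 5)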
--------------------------------------------------------------------------------
/pytorch_forecasting/models/nhits/__init__.py:
--------------------------------------------------------------------------------
1 | """N-HiTS model for timeseries forecasting with covariates."""
2 |
3 | from pytorch_forecasting.models.nhits._nhits import NHiTS
4 | from pytorch_forecasting.models.nhits.sub_modules import NHiTS as NHiTSModule
5 |
6 | __all__ = ["NHiTS", "NHiTSModule"]
7 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from pytorch_forecasting.models.nn.embeddings import MultiEmbedding
2 | from pytorch_forecasting.models.nn.rnn import GRU, LSTM, HiddenState, get_rnn
3 | from pytorch_forecasting.utils import TupleOutputMixIn
4 |
5 | __all__ = [
6 | "MultiEmbedding",
7 | "get_rnn",
8 | "LSTM",
9 | "GRU",
10 | "HiddenState",
11 | "TupleOutputMixIn",
12 | ]
13 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/rnn/__init__.py:
--------------------------------------------------------------------------------
1 | """Simple recurrent model - either with LSTM or GRU cells."""
2 |
3 | from pytorch_forecasting.models.rnn._rnn import RecurrentNetwork
4 |
5 | __all__ = ["RecurrentNetwork"]
6 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | """Temporal fusion transformer for forecasting timeseries."""
2 |
3 | from pytorch_forecasting.models.temporal_fusion_transformer._tft import (
4 | TemporalFusionTransformer,
5 | )
6 | from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import (
7 | AddNorm,
8 | GateAddNorm,
9 | GatedLinearUnit,
10 | GatedResidualNetwork,
11 | InterpretableMultiHeadAttention,
12 | VariableSelectionNetwork,
13 | )
14 |
15 | __all__ = [
16 | "TemporalFusionTransformer",
17 | "AddNorm",
18 | "GateAddNorm",
19 | "GatedLinearUnit",
20 | "GatedResidualNetwork",
21 | "InterpretableMultiHeadAttention",
22 | "VariableSelectionNetwork",
23 | ]
24 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/tide/__init__.py:
--------------------------------------------------------------------------------
1 | """Tide model."""
2 |
3 | from pytorch_forecasting.models.tide._tide import TiDEModel
4 | from pytorch_forecasting.models.tide._tide_metadata import TiDEModelMetadata
5 | from pytorch_forecasting.models.tide.sub_modules import _TideModule
6 |
7 | __all__ = [
8 | "_TideModule",
9 | "TiDEModel",
10 | "TiDEModelMetadata",
11 | ]
12 |
--------------------------------------------------------------------------------
/pytorch_forecasting/models/tide/_tide_metadata.py:
--------------------------------------------------------------------------------
1 | """TiDE metadata container."""
2 |
3 | from pytorch_forecasting.models.base._base_object import _BasePtForecaster
4 |
5 |
6 | class TiDEModelMetadata(_BasePtForecaster):
7 | """Metadata container for TiDE Model."""
8 |
9 | _tags = {
10 | "info:name": "TiDEModel",
11 | "info:compute": 3,
12 | "authors": ["Sohaib-Ahmed21"],
13 | "capability:exogenous": True,
14 | "capability:multivariate": True,
15 | "capability:pred_int": True,
16 | "capability:flexible_history_length": True,
17 | "capability:cold_start": False,
18 | }
19 |
20 | @classmethod
21 | def get_model_cls(cls):
22 | """Get model class."""
23 | from pytorch_forecasting.models.tide import TiDEModel
24 |
25 | return TiDEModel
26 |
27 | @classmethod
28 | def get_test_train_params(cls):
29 | """Return testing parameter settings for the trainer.
30 |
31 | Returns
32 | -------
33 | params : dict or list of dict, default = {}
34 | Parameters to create testing instances of the class.
35 | """
36 |
37 | from pytorch_forecasting.data.encoders import GroupNormalizer
38 | from pytorch_forecasting.metrics import SMAPE
39 |
40 | params = [
41 | {
42 | "data_loader_kwargs": dict(
43 | add_relative_time_idx=False,
44 |                     # must be included every time since data_loader_default_kwargs
45 |                     # sets add_relative_time_idx to True by default.
46 | )
47 | },
48 | {
49 | "temporal_decoder_hidden": 16,
50 | "data_loader_kwargs": dict(add_relative_time_idx=False),
51 | },
52 | {
53 | "dropout": 0.2,
54 | "use_layer_norm": True,
55 | "loss": SMAPE(),
56 | "data_loader_kwargs": dict(
57 | target_normalizer=GroupNormalizer(
58 | groups=["agency", "sku"], transformation="softplus"
59 | ),
60 | add_relative_time_idx=False,
61 | ),
62 | },
63 | ]
64 | defaults = {"hidden_size": 5}
65 | for param in params:
66 | param.update(defaults)
67 | return params
68 |
69 | @classmethod
70 | def _get_test_dataloaders_from(cls, params):
71 | """Get dataloaders from parameters.
72 |
73 | Parameters
74 | ----------
75 | params : dict
76 | Parameters to create dataloaders.
77 | One of the elements in the list returned by ``get_test_train_params``.
78 |
79 | Returns
80 | -------
81 | dataloaders : dict with keys "train", "val", "test", values torch DataLoader
82 | Dict of dataloaders created from the parameters.
83 | Train, validation, and test dataloaders.
84 | """
85 | trainer_kwargs = params.get("trainer_kwargs", {})
86 | clip_target = params.get("clip_target", False)
87 | data_loader_kwargs = params.get("data_loader_kwargs", {})
88 |
89 | from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss
90 | from pytorch_forecasting.tests._conftest import make_dataloaders
91 | from pytorch_forecasting.tests._data_scenarios import data_with_covariates
92 |
93 | dwc = data_with_covariates()
94 |
95 | if "loss" in trainer_kwargs and isinstance(
96 | trainer_kwargs["loss"], NegativeBinomialDistributionLoss
97 | ):
98 | dwc = dwc.assign(volume=lambda x: x.volume.round())
99 |
100 | dwc = dwc.copy()
101 | if clip_target:
102 | dwc["target"] = dwc["volume"].clip(1e-3, 1.0)
103 | else:
104 | dwc["target"] = dwc["volume"]
105 | data_loader_default_kwargs = dict(
106 | target="target",
107 | time_varying_known_reals=["price_actual"],
108 | time_varying_unknown_reals=["target"],
109 | static_categoricals=["agency"],
110 | add_relative_time_idx=True,
111 | )
112 | data_loader_default_kwargs.update(data_loader_kwargs)
113 | dataloaders_w_covariates = make_dataloaders(dwc, **data_loader_default_kwargs)
114 | return dataloaders_w_covariates
115 |
--------------------------------------------------------------------------------
/pytorch_forecasting/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """PyTorch Forecasting test suite."""
2 |
--------------------------------------------------------------------------------
/pytorch_forecasting/tests/_config.py:
--------------------------------------------------------------------------------
1 | """Test configs."""
2 |
3 | # list of str, names of estimators to exclude from testing
4 | # WARNING: tests for these estimators will be skipped
5 | EXCLUDE_ESTIMATORS = [
6 | "DummySkipped",
7 | "ClassName", # exclude classes from extension templates
8 | ]
9 |
10 | # dictionary of lists of str, names of tests to exclude from testing
11 | # keys are class names of estimators, values are lists of test names to exclude
12 | # WARNING: tests with these names will be skipped
13 | EXCLUDED_TESTS = {}
14 |
--------------------------------------------------------------------------------
/pytorch_forecasting/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities and helper functions for the PyTorch Forecasting package.
3 | """
4 |
5 | from pytorch_forecasting.utils._utils import (
6 | InitialParameterRepresenterMixIn,
7 | OutputMixIn,
8 | TupleOutputMixIn,
9 | apply_to_list,
10 | autocorrelation,
11 | concat_sequences,
12 | create_mask,
13 | detach,
14 | get_embedding_size,
15 | groupby_apply,
16 | integer_histogram,
17 | masked_op,
18 | move_to_device,
19 | padded_stack,
20 | profile,
21 | redirect_stdout,
22 | repr_class,
23 | to_list,
24 | unpack_sequence,
25 | unsqueeze_like,
26 | )
27 |
28 | __all__ = [
29 | "InitialParameterRepresenterMixIn",
30 | "OutputMixIn",
31 | "TupleOutputMixIn",
32 | "apply_to_list",
33 | "autocorrelation",
34 | "get_embedding_size",
35 | "concat_sequences",
36 | "create_mask",
37 | "to_list",
38 |     "detach",
39 |     "masked_op",
40 |     "move_to_device",
41 |     "integer_histogram",
42 |     "groupby_apply",
43 |     "padded_stack",
44 |     "profile",
45 |     "redirect_stdout",
46 |     "repr_class",
47 |     "unpack_sequence",
48 |     "unsqueeze_like",
49 | ]
50 |
--------------------------------------------------------------------------------
/pytorch_forecasting/utils/_coerce.py:
--------------------------------------------------------------------------------
1 | """Coercion functions for various data types."""
2 |
3 | from copy import deepcopy
4 |
5 |
6 | def _coerce_to_list(obj):
7 | """Coerce object to list.
8 |
9 |     None is coerced to an empty list, a str is wrapped, otherwise list() is used.
10 | """
11 | if obj is None:
12 | return []
13 | if isinstance(obj, str):
14 | return [obj]
15 | return list(obj)
16 |
17 |
18 | def _coerce_to_dict(obj):
19 | """Coerce object to dict.
20 |
21 |     None is coerced to an empty dict, otherwise a deepcopy is returned.
22 | """
23 | if obj is None:
24 | return {}
25 | return deepcopy(obj)
26 |
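
Expected behavior of the two helpers, as a sketch (follows directly from the code above):

    _coerce_to_list(None)         # -> []
    _coerce_to_list("volume")     # -> ["volume"]  (a str is wrapped, not iterated)
    _coerce_to_list(("a", "b"))   # -> ["a", "b"]

    _coerce_to_dict(None)         # -> {}
    cfg = {"lags": {"volume": [2]}}
    _coerce_to_dict(cfg) is cfg   # -> False, a deep copy is returned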
--------------------------------------------------------------------------------
/pytorch_forecasting/utils/_dependencies/__init__.py:
--------------------------------------------------------------------------------
1 | """Utilities for managing dependencies."""
2 |
3 | from pytorch_forecasting.utils._dependencies._dependencies import (
4 | _check_matplotlib,
5 | _get_installed_packages,
6 | )
7 | from pytorch_forecasting.utils._dependencies._safe_import import _safe_import
8 |
9 | __all__ = [
10 | "_get_installed_packages",
11 | "_check_matplotlib",
12 | "_safe_import",
13 | ]
14 |
--------------------------------------------------------------------------------
/pytorch_forecasting/utils/_dependencies/_dependencies.py:
--------------------------------------------------------------------------------
1 | """Utilities for managing dependencies.
2 |
3 | Copied from sktime/skbase.
4 | """
5 |
6 | from functools import lru_cache
7 |
8 |
9 | @lru_cache
10 | def _get_installed_packages_private():
11 | """Get a dictionary of installed packages and their versions.
12 |
13 | Same as _get_installed_packages, but internal to avoid mutating the lru_cache
14 | by accident.
15 | """
16 | from importlib.metadata import distributions, version
17 |
18 | dists = distributions()
19 | package_names = {dist.metadata["Name"] for dist in dists}
20 | package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names}
21 | # developer note:
22 | # we cannot just use distributions naively,
23 | # because the same top level package name may appear *twice*,
24 | # e.g., in a situation where a virtual env overrides a base env,
25 | # such as in deployment environments like databricks.
26 | # the "version" contract ensures we always get the version that corresponds
27 | # to the importable distribution, i.e., the top one in the sys.path.
28 | return package_versions
29 |
30 |
31 | def _get_installed_packages():
32 | """Get a dictionary of installed packages and their versions.
33 |
34 | Returns
35 | -------
36 | dict : dictionary of installed packages and their versions
37 | keys are PEP 440 compatible package names, values are package versions
38 | MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3"
39 | """
40 | return _get_installed_packages_private().copy()
41 |
42 |
43 | def _check_matplotlib(ref="This feature", raise_error=True):
44 | """Check if matplotlib is installed.
45 |
46 | Parameters
47 | ----------
48 | ref : str, optional (default="This feature")
49 | reference to the feature that requires matplotlib, used in error message
50 | raise_error : bool, optional (default=True)
51 | whether to raise an error if matplotlib is not installed
52 |
53 | Returns
54 | -------
55 | bool : whether matplotlib is installed
56 | """
57 | pkgs = _get_installed_packages()
58 |
59 | if raise_error and "matplotlib" not in pkgs:
60 | raise ImportError(
61 | f"{ref} requires matplotlib."
62 | " Please install matplotlib with `pip install matplotlib`."
63 | )
64 |
65 | return "matplotlib" in pkgs
66 |
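
A short usage sketch for the two helpers above (illustrative; package versions are examples, not guaranteed values):

    pkgs = _get_installed_packages()   # e.g. {"torch": "2.4.0", "pandas": "2.2.2", ...}
    if "matplotlib" in pkgs:
        print("matplotlib", pkgs["matplotlib"])

    # raise_error=False turns the check into a plain boolean probe
    has_mpl = _check_matplotlib(ref="Plotting", raise_error=False)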
--------------------------------------------------------------------------------
/pytorch_forecasting/utils/_dependencies/_safe_import.py:
--------------------------------------------------------------------------------
1 | """Import a module/class, return a Mock object if import fails.
2 |
3 | Copied from sktime/skbase.
4 |
5 | Should be refactored and moved to a common location in skbase.
6 | """
7 |
8 | import importlib
9 | from unittest.mock import MagicMock
10 |
11 | from pytorch_forecasting.utils._dependencies import _get_installed_packages
12 |
13 |
14 | def _safe_import(import_path, pkg_name=None):
15 | """Import a module/class, return a Mock object if import fails.
16 |
17 | Idiomatic usage is ``obj = _safe_import("a.b.c.obj")``.
18 | The function supports importing both top-level modules and nested attributes:
19 |
20 | - Top-level module: ``"torch"`` -> same as ``import torch``
21 |     - Nested module: ``"torch.nn"`` -> same as ``from torch import nn``
22 | - Class/function: ``"torch.nn.Linear"`` -> same as ``from torch.nn import Linear``
23 |
24 | Parameters
25 | ----------
26 | import_path : str
27 | The path to the module/class to import. Can be:
28 |
29 | - Single module: ``"torch"``
30 | - Nested module: ``"torch.nn"``
31 | - Class/attribute: ``"torch.nn.ReLU"``
32 |
33 | Note: The dots in the path determine the import behavior:
34 |
35 | - No dots: Imports as a single module
36 | - One dot: Imports as a submodule
37 | - Multiple dots: Last part is treated as an attribute to import
38 |
39 | pkg_name : str, default=None
40 | The name of the package to check for installation. This is useful when
41 | the import name differs from the package name, for example:
42 |
43 | - import: ``"sklearn"`` -> ``pkg_name="scikit-learn"``
44 | - import: ``"cv2"`` -> ``pkg_name="opencv-python"``
45 |
46 | If ``None``, uses the first part of ``import_path`` before the dot.
47 |
48 | Returns
49 | -------
50 | object
51 |         If the import path and ``pkg_name`` are present, one of the following:
52 |
53 | - The imported module if ``import_path`` has no dots
54 | - The imported submodule if ``import_path`` has one dot
55 | - The imported class/function if ``import_path`` has multiple dots
56 |
57 | If the package or import path are not found:
58 | a unique ``MagicMock`` object per unique import path.
59 |
60 | Examples
61 | --------
62 |     >>> from pytorch_forecasting.utils._dependencies._safe_import import _safe_import
63 |
64 | >>> # Import a top-level module
65 | >>> torch = _safe_import("torch")
66 |
67 | >>> # Import a submodule
68 | >>> nn = _safe_import("torch.nn")
69 |
70 | >>> # Import a specific class
71 | >>> Linear = _safe_import("torch.nn.Linear")
72 |
73 | >>> # Import with different package name
74 | >>> cv2 = _safe_import("cv2", pkg_name="opencv-python")
75 | """
76 | path_list = import_path.split(".")
77 |
78 | if pkg_name is None:
79 | pkg_name = path_list[0]
80 | obj_name = path_list[-1]
81 |
82 | if pkg_name in _get_installed_packages():
83 | try:
84 | if len(path_list) == 1:
85 | return importlib.import_module(pkg_name)
86 | module_name, attr_name = import_path.rsplit(".", 1)
87 | module = importlib.import_module(module_name)
88 | return getattr(module, attr_name)
89 | except (ImportError, AttributeError):
90 | pass
91 |
92 | mock_obj = _create_mock_class(obj_name)
93 | return mock_obj
94 |
95 |
96 | class CommonMagicMeta(type):
97 | def __getattr__(cls, name):
98 | return MagicMock()
99 |
100 | def __setattr__(cls, name, value):
101 | pass # Ignore attribute writes
102 |
103 |
104 | class MagicAttribute(metaclass=CommonMagicMeta):
105 | def __getattr__(self, name):
106 | return MagicMock()
107 |
108 | def __setattr__(self, name, value):
109 | pass # Ignore attribute writes
110 |
111 | def __call__(self, *args, **kwargs):
112 | return self # Ensures instantiation returns the same object
113 |
114 |
115 | def _create_mock_class(name: str, bases=()):
116 | """Create new dynamic mock class similar to MagicMock.
117 |
118 | Parameters
119 | ----------
120 | name : str
121 | The name of the new class.
122 | bases : tuple, default=()
123 | The base classes of the new class.
124 |
125 | Returns
126 | -------
127 | a new class that behaves like MagicMock, with name ``name``.
128 | Forwards all attribute access to a MagicMock object stored in the instance.
129 | """
130 | return type(name, (MagicAttribute,), {"__metaclass__": CommonMagicMeta})
131 |
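
To complement the doctest examples above, a sketch of the fallback behavior when the package is missing (consistent with the mock classes defined in this file):

    Missing = _safe_import("definitely_not_installed.sub.Thing")

    obj = Missing()           # a dynamically created placeholder class, instantiable
    obj.any_attribute         # attribute access yields a MagicMock, no AttributeError
    obj(1, 2, key="value")    # calling the instance returns the instance itself

    class Derived(Missing):   # subclassing (also multiple inheritance) works,
        pass                  # as exercised in tests/test_safe_import.py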
--------------------------------------------------------------------------------
/pytorch_forecasting/utils/_dependencies/tests/__init__.py:
--------------------------------------------------------------------------------
1 | """Tests for dependency utilities."""
2 |
--------------------------------------------------------------------------------
/pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py:
--------------------------------------------------------------------------------
1 | __author__ = ["jgyasu", "fkiraly"]
2 |
3 | from pytorch_forecasting.utils._dependencies import (
4 | _get_installed_packages,
5 | _safe_import,
6 | )
7 |
8 |
9 | def test_import_present_module():
10 | """Test importing a dependency that is installed."""
11 | result = _safe_import("pandas")
12 | assert result is not None
13 | assert "pandas" in _get_installed_packages()
14 |
15 |
16 | def test_import_missing_module():
17 | """Test importing a dependency that is not installed."""
18 | result = _safe_import("nonexistent_module")
19 | assert hasattr(result, "__name__")
20 | assert result.__name__ == "nonexistent_module"
21 |
22 |
23 | def test_import_without_pkg_name():
24 | """Test importing a dependency with the same name as package name."""
25 | result = _safe_import("torch", pkg_name="torch")
26 | assert result is not None
27 |
28 |
29 | def test_import_with_different_pkg_name_1():
30 | """Test importing a dependency with a different package name."""
31 | result = _safe_import("skbase", pkg_name="scikit-base")
32 | assert result is not None
33 |
34 |
35 | def test_import_with_different_pkg_name_2():
36 | """Test importing another dependency with a different package name."""
37 | result = _safe_import("cv2", pkg_name="opencv-python")
38 | assert result is not None
39 |
40 |
41 | def test_import_submodule():
42 | """Test importing a submodule."""
43 | result = _safe_import("torch.nn")
44 | assert result is not None
45 |
46 |
47 | def test_import_class():
48 | """Test importing a class."""
49 | result = _safe_import("torch.nn.Linear")
50 | assert result is not None
51 |
52 |
53 | def test_import_existing_object():
54 | """Test importing an existing object."""
55 | result = _safe_import("pandas.DataFrame")
56 | assert result is not None
57 | assert result.__name__ == "DataFrame"
58 | from pandas import DataFrame
59 |
60 | assert result is DataFrame
61 |
62 |
63 | def test_multiple_inheritance_from_mock():
64 | """Test multiple inheritance from dynamic MagicMock."""
65 | Class1 = _safe_import("foobar.foo.FooBar")
66 | Class2 = _safe_import("barfoobar.BarFooBar")
67 |
68 | class NewClass(Class1, Class2):
69 | """This should not trigger an error.
70 |
71 | The class definition would trigger an error if multiple inheritance
72 | from Class1 and Class2 does not work, e.g., if it is simply
73 | identical to MagicMock.
74 | """
75 |
76 | pass
77 |
--------------------------------------------------------------------------------
/pytorch_forecasting/utils/_maint/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sktime/pytorch-forecasting/4384140418bfe8e8c64cc6ce5fc19372508bfccd/pytorch_forecasting/utils/_maint/__init__.py
--------------------------------------------------------------------------------
/pytorch_forecasting/utils/_maint/_show_versions.py:
--------------------------------------------------------------------------------
1 | # License: BSD 3 clause
2 | """Utility methods to print system info for debugging.
3 |
4 | adapted from
5 | :func:`sklearn.show_versions` and :func:`sktime.show_versions`
6 | """
7 |
8 | __all__ = ["show_versions"]
9 |
10 | import importlib
11 | import platform
12 | import sys
13 |
14 |
15 | def _get_sys_info():
16 | """System information.
17 |
18 |     Returns
19 |     -------
20 | sys_info : dict
21 | system and Python version information
22 | """
23 | python = sys.version.replace("\n", " ")
24 |
25 | blob = [
26 | ("python", python),
27 | ("executable", sys.executable),
28 | ("machine", platform.platform()),
29 | ]
30 |
31 | return dict(blob)
32 |
33 |
34 | # dependencies to print versions of, by default
35 | DEFAULT_DEPS_TO_SHOW = [
36 | "pip",
37 | "pytorch-forecasting",
38 | "torch",
39 | "lightning",
40 | "numpy",
41 | "scipy",
42 | "pandas",
43 | "cpflows",
44 | "matplotlib",
45 | "optuna",
46 | "optuna-integration",
47 | "pytorch_optimizer",
48 | "scikit-learn",
49 | "scikit-base",
50 | "statsmodels",
51 | ]
52 |
53 |
54 | def _get_deps_info(deps=None, source="distributions"):
55 | """Overview of the installed version of main dependencies.
56 |
57 | Parameters
58 | ----------
59 | deps : optional, list of strings with package names
60 | if None, behaves as deps = ["pytorch-forecasting"].
61 |
62 | source : str, optional one of "distributions" (default) or "import"
63 | source of version information
64 |
65 |         * "distributions" - uses importlib.metadata.distributions. In this case,
66 | strings in deps are assumed to be PEP 440 package strings,
67 | e.g., scikit-learn, not sklearn.
68 | * "import" - uses the __version__ attribute of the module.
69 | In this case, strings in deps are assumed to be import names,
70 | e.g., sklearn, not scikit-learn.
71 |
72 | Returns
73 | -------
74 | deps_info: dict
75 | version information on libraries in `deps`
76 | keys are package names, import names if source is "import",
77 | and PEP 440 package strings if source is "distributions";
78 | values are PEP 440 version strings
79 | of the import as present in the current python environment
80 | """
81 | if deps is None:
82 | deps = ["pytorch-forecasting"]
83 |
84 | if source == "distributions":
85 | from pytorch_forecasting.utils._dependencies import _get_installed_packages
86 |
87 | KEY_ALIAS = {"sklearn": "scikit-learn", "skbase": "scikit-base"}
88 |
89 | pkgs = _get_installed_packages()
90 |
91 | deps_info = {}
92 | for modname in deps:
93 | pkg_name = KEY_ALIAS.get(modname, modname)
94 | deps_info[modname] = pkgs.get(pkg_name, None)
95 |
96 | return deps_info
97 |
98 | def get_version(module):
99 | return getattr(module, "__version__", None)
100 |
101 | deps_info = {}
102 |
103 | for modname in deps:
104 | try:
105 | if modname in sys.modules:
106 | mod = sys.modules[modname]
107 | else:
108 | mod = importlib.import_module(modname)
109 | except ImportError:
110 | deps_info[modname] = None
111 | else:
112 | ver = get_version(mod)
113 | deps_info[modname] = ver
114 |
115 | return deps_info
116 |
117 |
118 | def show_versions():
119 |     """Print python version, OS version, pytorch-forecasting version, selected dependency versions.
120 |
121 | Pretty prints:
122 |
123 | * python version of environment
124 | * python executable location
125 | * OS version
126 | * list of import name and version number for selected python dependencies
127 |
128 | Developer note:
129 | Python version/executable and OS version are from `_get_sys_info`
130 | Package versions are retrieved by `_get_deps_info`
131 | Selected dependencies are as in the DEFAULT_DEPS_TO_SHOW variable
132 | """
133 | sys_info = _get_sys_info()
134 | deps_info = _get_deps_info(deps=DEFAULT_DEPS_TO_SHOW)
135 |
136 | print("\nSystem:") # noqa: T001, T201
137 | for k, stat in sys_info.items():
138 | print(f"{k:>10}: {stat}") # noqa: T001, T201
139 |
140 | print("\nPython dependencies:") # noqa: T001, T201
141 | for k, stat in deps_info.items():
142 | print(f"{k:>13}: {stat}") # noqa: T001, T201
143 |
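
A usage sketch for the function above (output abbreviated; values depend on the environment):

    from pytorch_forecasting.utils._maint._show_versions import show_versions

    show_versions()
    # System:
    #     python: 3.11.9 ...
    # executable: /usr/bin/python3
    #    machine: Linux-...
    #
    # Python dependencies:
    #          pip: 24.0
    #        torch: 2.4.0
    #        ...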
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 | import pytest
6 |
7 | sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../.."))) # isort:skip
8 |
9 |
10 | from pytorch_forecasting import TimeSeriesDataSet # isort:skip
11 | from pytorch_forecasting.data.examples import get_stallion_data # isort:skip
12 |
13 |
14 | # for vscode debugging: https://stackoverflow.com/a/62563106/14121677
15 | if os.getenv("_PYTEST_RAISE", "0") != "0":
16 |
17 | @pytest.hookimpl(tryfirst=True)
18 | def pytest_exception_interact(call):
19 | raise call.excinfo.value
20 |
21 | @pytest.hookimpl(tryfirst=True)
22 | def pytest_internalerror(excinfo):
23 | raise excinfo.value
24 |
25 |
26 | @pytest.fixture(scope="session")
27 | def test_data():
28 | data = get_stallion_data()
29 | data["month"] = data.date.dt.month.astype(str)
30 | data["log_volume"] = np.log1p(data.volume)
31 | data["weight"] = 1 + np.sqrt(data.volume)
32 |
33 | data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
34 | data["time_idx"] -= data["time_idx"].min()
35 |
36 | special_days = [
37 | "easter_day",
38 | "good_friday",
39 | "new_year",
40 | "christmas",
41 | "labor_day",
42 | "independence_day",
43 | "revolution_day_memorial",
44 | "regional_games",
45 | "fifa_u_17_world_cup",
46 | "football_gold_cup",
47 | "beer_capital",
48 | "music_fest",
49 | ]
50 | data[special_days] = (
51 | data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category")
52 | )
53 |
54 | data = data[lambda x: x.time_idx < 10] # downsample
55 | return data
56 |
57 |
58 | @pytest.fixture(scope="session")
59 | def test_dataset(test_data):
60 | training = TimeSeriesDataSet(
61 | test_data.copy(),
62 | time_idx="time_idx",
63 | target="volume",
64 | time_varying_known_reals=["price_regular", "time_idx"],
65 | group_ids=["agency", "sku"],
66 | static_categoricals=["agency"],
67 | max_encoder_length=5,
68 | max_prediction_length=2,
69 | min_prediction_length=1,
70 | min_encoder_length=0,
71 | randomize_length=None,
72 | )
73 | return training
74 |
75 |
76 | @pytest.fixture(autouse=True)
77 | def disable_mps(monkeypatch):
78 | """Disable MPS for all tests"""
79 | monkeypatch.setattr("torch._C._mps_is_available", lambda: False)
80 |
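
A sketch of how the session-scoped ``test_dataset`` fixture can be consumed in a test (illustrative only; ``to_dataloader`` is part of the ``TimeSeriesDataSet`` public API):

    def test_example(test_dataset):
        # build a small training dataloader from the fixture's dataset
        train_loader = test_dataset.to_dataloader(train=True, batch_size=4)
        x, y = next(iter(train_loader))
        # the input dict carries encoder/decoder tensors used by the models
        assert "encoder_cont" in x and "decoder_cont" in x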
--------------------------------------------------------------------------------
/tests/test_data/test_encoders.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import itertools
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import pytest
7 | from sklearn.utils.validation import NotFittedError, check_is_fitted
8 | import torch
9 |
10 | from pytorch_forecasting.data import (
11 | EncoderNormalizer,
12 | GroupNormalizer,
13 | MultiNormalizer,
14 | NaNLabelEncoder,
15 | TorchNormalizer,
16 | )
17 |
18 |
19 | @pytest.mark.parametrize(
20 | "data,allow_nan",
21 | itertools.product(
22 | [
23 | (np.array([2, 3, 4]), np.array([1, 2, 3, 5, np.nan])),
24 | (np.array(["a", "b", "c"]), np.array(["q", "a", "nan"])),
25 | ],
26 | [True, False],
27 | ),
28 | )
29 | def test_NaNLabelEncoder(data, allow_nan):
30 | fit_data, transform_data = data
31 | encoder = NaNLabelEncoder(warn=False, add_nan=allow_nan)
32 | encoder.fit(fit_data)
33 | assert np.array_equal(
34 | encoder.inverse_transform(encoder.transform(fit_data)), fit_data
35 | ), "Inverse transform should reverse transform"
36 | if not allow_nan:
37 | with pytest.raises(KeyError):
38 | encoder.transform(transform_data)
39 | else:
40 | assert (
41 | encoder.transform(transform_data)[0] == 0
42 | ), "First value should be translated to 0 if nan"
43 | assert (
44 | encoder.transform(transform_data)[-1] == 0
45 | ), "Last value should be translated to 0 if nan"
46 | assert (
47 | encoder.transform(fit_data)[0] > 0
48 | ), "First value should not be 0 if not nan"
49 |
50 |
51 | def test_NaNLabelEncoder_add():
52 | encoder = NaNLabelEncoder(add_nan=False)
53 | encoder.fit(np.array(["a", "b", "c"]))
54 | encoder2 = deepcopy(encoder)
55 | encoder2.fit(np.array(["d"]))
56 | assert encoder2.transform(np.array(["a"]))[0] == 0, "a must be encoded as 0"
57 | assert encoder2.transform(np.array(["d"]))[0] == 3, "d must be encoded as 3"
58 |
59 |
60 | @pytest.mark.parametrize(
61 | "kwargs",
62 | [
63 | dict(method="robust"),
64 | dict(method="robust", method_kwargs=dict(upper=1.0, lower=0.0)),
65 | dict(method="robust", data=np.random.randn(100)),
66 | dict(data=np.random.randn(100)),
67 | dict(transformation="log"),
68 | dict(transformation="softplus"),
69 | dict(transformation="log1p"),
70 | dict(transformation="relu"),
71 | dict(method="identity"),
72 | dict(method="identity", data=np.random.randn(100)),
73 | dict(center=False),
74 | dict(max_length=5),
75 | dict(data=pd.Series(np.random.randn(100))),
76 | dict(max_length=[1, 2]),
77 | ],
78 | )
79 | def test_EncoderNormalizer(kwargs):
80 | kwargs.setdefault("method", "standard")
81 | kwargs.setdefault("center", True)
82 | kwargs.setdefault("data", torch.rand(100))
83 | data = kwargs.pop("data")
84 |
85 | normalizer = EncoderNormalizer(**kwargs)
86 |
87 | if kwargs.get("transformation") in ["relu", "softplus", "log1p"]:
88 | assert (
89 | normalizer.inverse_transform(
90 | torch.as_tensor(normalizer.fit_transform(data))
91 | )
92 | >= 0
93 | ).all(), "Inverse transform should yield only positive values"
94 | else:
95 | assert torch.isclose(
96 | normalizer.inverse_transform(
97 | torch.as_tensor(normalizer.fit_transform(data))
98 | ),
99 | torch.as_tensor(data),
100 | atol=1e-5,
101 | ).all(), "Inverse transform should reverse transform"
102 |
103 |
104 | @pytest.mark.parametrize(
105 | "kwargs,groups",
106 | itertools.product(
107 | [
108 | dict(method="robust"),
109 | dict(transformation="log"),
110 | dict(transformation="relu"),
111 | dict(center=False),
112 | dict(transformation="log1p"),
113 | dict(transformation="softplus"),
114 | dict(scale_by_group=True),
115 | ],
116 | [[], ["a"]],
117 | ),
118 | )
119 | def test_GroupNormalizer(kwargs, groups):
120 | data = pd.DataFrame(dict(a=[1, 1, 2, 2, 3], b=[1.1, 1.1, 1.0, 0.0, 1.1]))
121 | defaults = dict(
122 | method="standard", transformation=None, center=True, scale_by_group=False
123 | )
124 | defaults.update(kwargs)
125 | kwargs = defaults
126 | kwargs["groups"] = groups
127 | kwargs["scale_by_group"] = kwargs["scale_by_group"] and len(kwargs["groups"]) > 0
128 |
129 | normalizer = GroupNormalizer(**kwargs)
130 | encoded = normalizer.fit_transform(data["b"], data)
131 |
132 | test_data = dict(
133 | prediction=torch.tensor([encoded[0]]),
134 | target_scale=torch.tensor(normalizer.get_parameters([1])).unsqueeze(0),
135 | )
136 |
137 | if kwargs.get("transformation") in ["relu", "softplus", "log1p", "log"]:
138 | assert (
139 | normalizer(test_data) >= 0
140 | ).all(), "Inverse transform should yield only positive values"
141 | else:
142 | assert torch.isclose(
143 | normalizer(test_data), torch.tensor(data.b.iloc[0]), atol=1e-5
144 | ).all(), "Inverse transform should reverse transform"
145 |
146 |
147 | def test_EncoderNormalizer_with_limited_history():
148 | data = torch.rand(100)
149 | normalizer = EncoderNormalizer(max_length=[1, 2]).fit(data)
150 | assert normalizer.center_ == data[-1]
151 |
152 |
153 | def test_MultiNormalizer_fitted():
154 | data = pd.DataFrame(
155 | dict(
156 | a=[1, 1, 2, 2, 3], b=[1.1, 1.1, 1.0, 5.0, 1.1], c=[1.1, 1.1, 1.0, 5.0, 1.1]
157 | )
158 | )
159 |
160 | normalizer = MultiNormalizer([GroupNormalizer(groups=["a"]), TorchNormalizer()])
161 |
162 | with pytest.raises(NotFittedError):
163 | check_is_fitted(normalizer)
164 |
165 | normalizer.fit(data, data)
166 |
167 | try:
168 | check_is_fitted(normalizer.normalizers[0])
169 | check_is_fitted(normalizer.normalizers[1])
170 | check_is_fitted(normalizer)
171 | except NotFittedError:
172 | pytest.fail(f"{NotFittedError}")
173 |
174 |
175 | def test_TorchNormalizer_dtype_consistency():
176 | """
177 | - Ensures that even for float64 `target_scale`, the transformation will not change the prediction dtype.
178 | - Ensure that target_scale will be of type float32 if method is 'identity'
179 | """ # noqa: E501
180 | parameters = torch.tensor([[[366.4587]]])
181 | target_scale = torch.tensor([[427875.7500, 80367.4766]], dtype=torch.float64)
182 | assert (
183 | TorchNormalizer()(dict(prediction=parameters, target_scale=target_scale)).dtype
184 | == torch.float32
185 | )
186 | assert (
187 | TorchNormalizer().transform(parameters, target_scale=target_scale).dtype
188 | == torch.float32
189 | )
190 |
191 | y = np.array([1, 2, 3], dtype=np.float32)
192 | assert (
193 | TorchNormalizer(method="identity").fit(y).get_parameters().dtype
194 | == torch.float32
195 | )
196 |
--------------------------------------------------------------------------------
/tests/test_data/test_samplers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.utils.data.sampler import SequentialSampler
4 |
5 | from pytorch_forecasting.data import TimeSynchronizedBatchSampler
6 |
7 |
8 | @pytest.mark.parametrize(
9 | "drop_last,shuffle,as_string,batch_size",
10 | [
11 | (True, True, True, 64),
12 | (False, False, False, 64),
13 | (True, False, False, 1000),
14 | ],
15 | )
16 | def test_TimeSynchronizedBatchSampler(
17 | test_dataset, shuffle, drop_last, as_string, batch_size
18 | ):
19 | if as_string:
20 | dataloader = test_dataset.to_dataloader(
21 | batch_sampler="synchronized",
22 | shuffle=shuffle,
23 | drop_last=drop_last,
24 | batch_size=batch_size,
25 | )
26 | else:
27 | sampler = TimeSynchronizedBatchSampler(
28 | SequentialSampler(test_dataset),
29 | shuffle=shuffle,
30 | drop_last=drop_last,
31 | batch_size=batch_size,
32 | )
33 | dataloader = test_dataset.to_dataloader(batch_sampler=sampler)
34 |
35 | time_idx_pos = test_dataset.reals.index("time_idx")
36 | for x, _ in iter(dataloader): # check all samples
37 | time_idx_of_first_prediction = x["decoder_cont"][:, 0, time_idx_pos]
38 | assert torch.isclose(
39 | time_idx_of_first_prediction, time_idx_of_first_prediction[0]
40 | ).all(), "Time index should be the same for the first prediction"
41 |
--------------------------------------------------------------------------------
/tests/test_models/test_baseline.py:
--------------------------------------------------------------------------------
1 | from pytorch_forecasting import Baseline
2 |
3 |
4 | def test_integration(multiple_dataloaders_with_covariates):
5 | dataloader = multiple_dataloaders_with_covariates["val"]
6 | Baseline().predict(dataloader, fast_dev_run=True)
7 | repr(Baseline())
8 |
--------------------------------------------------------------------------------
/tests/test_models/test_deepar.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import shutil
3 |
4 | import lightning.pytorch as pl
5 | from lightning.pytorch.callbacks import EarlyStopping
6 | from lightning.pytorch.loggers import TensorBoardLogger
7 | import pytest
8 | from test_models.conftest import make_dataloaders
9 | from torch import nn
10 |
11 | from pytorch_forecasting.data.encoders import GroupNormalizer
12 | from pytorch_forecasting.metrics import (
13 | BetaDistributionLoss,
14 | ImplicitQuantileNetworkDistributionLoss,
15 | LogNormalDistributionLoss,
16 | MultivariateNormalDistributionLoss,
17 | NegativeBinomialDistributionLoss,
18 | NormalDistributionLoss,
19 | )
20 | from pytorch_forecasting.models import DeepAR
21 |
22 |
23 | def _integration(
24 | data_with_covariates,
25 | tmp_path,
26 | cell_type="LSTM",
27 | data_loader_kwargs={},
28 | clip_target: bool = False,
29 | trainer_kwargs=None,
30 | **kwargs,
31 | ):
32 | data_with_covariates = data_with_covariates.copy()
33 | if clip_target:
34 | data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0)
35 | else:
36 | data_with_covariates["target"] = data_with_covariates["volume"]
37 | data_loader_default_kwargs = dict(
38 | target="target",
39 | time_varying_known_reals=["price_actual"],
40 | time_varying_unknown_reals=["target"],
41 | static_categoricals=["agency"],
42 | add_relative_time_idx=True,
43 | )
44 | data_loader_default_kwargs.update(data_loader_kwargs)
45 | dataloaders_with_covariates = make_dataloaders(
46 | data_with_covariates, **data_loader_default_kwargs
47 | )
48 |
49 | train_dataloader = dataloaders_with_covariates["train"]
50 | val_dataloader = dataloaders_with_covariates["val"]
51 | test_dataloader = dataloaders_with_covariates["test"]
52 |
53 | early_stop_callback = EarlyStopping(
54 | monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
55 | )
56 |
57 | logger = TensorBoardLogger(tmp_path)
58 | if trainer_kwargs is None:
59 | trainer_kwargs = {}
60 | trainer = pl.Trainer(
61 | max_epochs=3,
62 | gradient_clip_val=0.1,
63 | callbacks=[early_stop_callback],
64 | enable_checkpointing=True,
65 | default_root_dir=tmp_path,
66 | limit_train_batches=2,
67 | limit_val_batches=2,
68 | limit_test_batches=2,
69 | logger=logger,
70 | **trainer_kwargs,
71 | )
72 |
73 | net = DeepAR.from_dataset(
74 | train_dataloader.dataset,
75 | hidden_size=5,
76 | cell_type=cell_type,
77 | learning_rate=0.01,
78 | log_gradient_flow=True,
79 | log_interval=1000,
80 | n_plotting_samples=100,
81 | **kwargs,
82 | )
83 | net.size()
84 | try:
85 | trainer.fit(
86 | net,
87 | train_dataloaders=train_dataloader,
88 | val_dataloaders=val_dataloader,
89 | )
90 | test_outputs = trainer.test(net, dataloaders=test_dataloader)
91 | assert len(test_outputs) > 0
92 | # check loading
93 | net = DeepAR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
94 |
95 | # check prediction
96 | net.predict(
97 | val_dataloader,
98 | fast_dev_run=True,
99 | return_index=True,
100 | return_decoder_lengths=True,
101 | trainer_kwargs=trainer_kwargs,
102 | )
103 | finally:
104 | shutil.rmtree(tmp_path, ignore_errors=True)
105 |
106 | net.predict(
107 | val_dataloader,
108 | fast_dev_run=True,
109 | return_index=True,
110 | return_decoder_lengths=True,
111 | trainer_kwargs=trainer_kwargs,
112 | )
113 |
114 |
115 | @pytest.mark.parametrize(
116 | "kwargs",
117 | [
118 | {},
119 | {"cell_type": "GRU"},
120 | dict(
121 | loss=LogNormalDistributionLoss(),
122 | clip_target=True,
123 | data_loader_kwargs=dict(
124 | target_normalizer=GroupNormalizer(
125 | groups=["agency", "sku"], transformation="log"
126 | )
127 | ),
128 | ),
129 | dict(
130 | loss=NegativeBinomialDistributionLoss(),
131 | clip_target=False,
132 | data_loader_kwargs=dict(
133 | target_normalizer=GroupNormalizer(
134 | groups=["agency", "sku"], center=False
135 | )
136 | ),
137 | ),
138 | dict(
139 | loss=BetaDistributionLoss(),
140 | clip_target=True,
141 | data_loader_kwargs=dict(
142 | target_normalizer=GroupNormalizer(
143 | groups=["agency", "sku"], transformation="logit"
144 | )
145 | ),
146 | ),
147 | dict(
148 | data_loader_kwargs=dict(
149 | lags={"volume": [2, 5]},
150 | target="volume",
151 | time_varying_unknown_reals=["volume"],
152 | min_encoder_length=2,
153 | )
154 | ),
155 | dict(
156 | data_loader_kwargs=dict(
157 | time_varying_unknown_reals=["volume", "discount"],
158 | target=["volume", "discount"],
159 | lags={"volume": [2], "discount": [2]},
160 | )
161 | ),
162 | dict(
163 | loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8),
164 | ),
165 | dict(
166 | loss=MultivariateNormalDistributionLoss(),
167 | trainer_kwargs=dict(accelerator="cpu"),
168 | ),
169 | dict(
170 | loss=MultivariateNormalDistributionLoss(),
171 | data_loader_kwargs=dict(
172 | target_normalizer=GroupNormalizer(
173 | groups=["agency", "sku"], transformation="log1p"
174 | )
175 | ),
176 | trainer_kwargs=dict(accelerator="cpu"),
177 | ),
178 | ],
179 | )
180 | def test_integration(data_with_covariates, tmp_path, kwargs):
181 | if "loss" in kwargs and isinstance(
182 | kwargs["loss"], NegativeBinomialDistributionLoss
183 | ):
184 | data_with_covariates = data_with_covariates.assign(
185 | volume=lambda x: x.volume.round()
186 | )
187 | _integration(data_with_covariates, tmp_path, **kwargs)
188 |
189 |
190 | @pytest.fixture
191 | def model(dataloaders_with_covariates):
192 | dataset = dataloaders_with_covariates["train"].dataset
193 | net = DeepAR.from_dataset(
194 | dataset,
195 | hidden_size=5,
196 | learning_rate=0.15,
197 | log_gradient_flow=True,
198 | log_interval=1000,
199 | )
200 | return net
201 |
202 |
203 | def test_predict_average(model, dataloaders_with_covariates):
204 | prediction = model.predict(
205 | dataloaders_with_covariates["val"],
206 | fast_dev_run=True,
207 | mode="prediction",
208 | n_samples=100,
209 | )
210 | assert prediction.ndim == 2, "expected averaging of samples"
211 |
212 |
213 | def test_predict_samples(model, dataloaders_with_covariates):
214 | prediction = model.predict(
215 | dataloaders_with_covariates["val"],
216 | fast_dev_run=True,
217 | mode="samples",
218 | n_samples=100,
219 | )
220 | assert prediction.size()[-1] == 100, "expected raw samples"
221 |
222 |
223 | @pytest.mark.parametrize(
224 | "loss", [NormalDistributionLoss(), MultivariateNormalDistributionLoss()]
225 | )
226 | def test_pickle(dataloaders_with_covariates, loss):
227 | dataset = dataloaders_with_covariates["train"].dataset
228 | model = DeepAR.from_dataset(
229 | dataset,
230 | hidden_size=5,
231 | learning_rate=0.15,
232 | log_gradient_flow=True,
233 | log_interval=1000,
234 | loss=loss,
235 | )
236 | pkl = pickle.dumps(model)
237 | pickle.loads(pkl) # noqa: S301
238 |
--------------------------------------------------------------------------------
/tests/test_models/test_mlp.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import shutil
3 |
4 | import lightning.pytorch as pl
5 | from lightning.pytorch.callbacks import EarlyStopping
6 | from lightning.pytorch.loggers import TensorBoardLogger
7 | import pytest
8 | from test_models.conftest import make_dataloaders
9 | from torchmetrics import MeanSquaredError
10 |
11 | from pytorch_forecasting.metrics import MAE, CrossEntropy, MultiLoss, QuantileLoss
12 | from pytorch_forecasting.models import DecoderMLP
13 |
14 |
15 | def _integration(
16 | data_with_covariates, tmp_path, data_loader_kwargs={}, train_only=False, **kwargs
17 | ):
18 | data_loader_default_kwargs = dict(
19 | target="target",
20 | time_varying_known_reals=["price_actual"],
21 | time_varying_unknown_reals=["target"],
22 | static_categoricals=["agency"],
23 | add_relative_time_idx=True,
24 | )
25 | data_loader_default_kwargs.update(data_loader_kwargs)
26 | dataloaders_with_covariates = make_dataloaders(
27 | data_with_covariates, **data_loader_default_kwargs
28 | )
29 | train_dataloader = dataloaders_with_covariates["train"]
30 | val_dataloader = dataloaders_with_covariates["val"]
31 | test_dataloader = dataloaders_with_covariates["test"]
32 | early_stop_callback = EarlyStopping(
33 | monitor="val_loss",
34 | min_delta=1e-4,
35 | patience=1,
36 | verbose=False,
37 | mode="min",
38 | strict=False,
39 | )
40 |
41 | logger = TensorBoardLogger(tmp_path)
42 | trainer = pl.Trainer(
43 | max_epochs=3,
44 | gradient_clip_val=0.1,
45 | callbacks=[early_stop_callback],
46 | enable_checkpointing=True,
47 | default_root_dir=tmp_path,
48 | limit_train_batches=2,
49 | limit_val_batches=2,
50 | limit_test_batches=2,
51 | logger=logger,
52 | )
53 |
54 | net = DecoderMLP.from_dataset(
55 | train_dataloader.dataset,
56 | learning_rate=0.015,
57 | log_gradient_flow=True,
58 | log_interval=1000,
59 | hidden_size=10,
60 | **kwargs,
61 | )
62 | net.size()
63 | try:
64 | if train_only:
65 | trainer.fit(net, train_dataloaders=train_dataloader)
66 | else:
67 | trainer.fit(
68 | net,
69 | train_dataloaders=train_dataloader,
70 | val_dataloaders=val_dataloader,
71 | )
72 | # check loading
73 | net = DecoderMLP.load_from_checkpoint(
74 | trainer.checkpoint_callback.best_model_path
75 | )
76 |
77 | # check prediction
78 | net.predict(
79 | val_dataloader,
80 | fast_dev_run=True,
81 | return_index=True,
82 | return_decoder_lengths=True,
83 | )
84 | # check test dataloader
85 | test_outputs = trainer.test(net, dataloaders=test_dataloader)
86 | assert len(test_outputs) > 0
87 | finally:
88 | shutil.rmtree(tmp_path, ignore_errors=True)
89 |
90 | net.predict(
91 | val_dataloader,
92 | fast_dev_run=True,
93 | return_index=True,
94 | return_decoder_lengths=True,
95 | )
96 |
97 |
98 | @pytest.mark.parametrize(
99 | "kwargs",
100 | [
101 | {},
102 | dict(train_only=True),
103 | dict(
104 | loss=MultiLoss([QuantileLoss(), MAE()]),
105 | data_loader_kwargs=dict(
106 | time_varying_unknown_reals=["volume", "discount"],
107 | target=["volume", "discount"],
108 | ),
109 | ),
110 | dict(
111 | loss=CrossEntropy(),
112 | data_loader_kwargs=dict(
113 | target="agency",
114 | ),
115 | ),
116 | dict(loss=MeanSquaredError()),
117 | dict(
118 | loss=MeanSquaredError(),
119 | data_loader_kwargs=dict(min_prediction_length=1, min_encoder_length=1),
120 | ),
121 | ],
122 | )
123 | def test_integration(data_with_covariates, tmp_path, kwargs):
124 | _integration(
125 | data_with_covariates.assign(target=lambda x: x.volume), tmp_path, **kwargs
126 | )
127 |
128 |
129 | @pytest.fixture
130 | def model(dataloaders_with_covariates):
131 | dataset = dataloaders_with_covariates["train"].dataset
132 | net = DecoderMLP.from_dataset(
133 | dataset,
134 | learning_rate=0.15,
135 | log_gradient_flow=True,
136 | log_interval=1000,
137 | hidden_size=10,
138 | )
139 | return net
140 |
141 |
142 | def test_pickle(model):
143 | pkl = pickle.dumps(model)
144 | pickle.loads(pkl) # noqa: S301
145 |
--------------------------------------------------------------------------------
/tests/test_models/test_nbeats.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import shutil
3 |
4 | import lightning.pytorch as pl
5 | from lightning.pytorch.callbacks import EarlyStopping
6 | from lightning.pytorch.loggers import TensorBoardLogger
7 | import pytest
8 |
9 | from pytorch_forecasting.models import NBeats
10 | from pytorch_forecasting.utils._dependencies import _get_installed_packages
11 |
12 |
13 | def test_integration(dataloaders_fixed_window_without_covariates, tmp_path):
14 | train_dataloader = dataloaders_fixed_window_without_covariates["train"]
15 | val_dataloader = dataloaders_fixed_window_without_covariates["val"]
16 | test_dataloader = dataloaders_fixed_window_without_covariates["test"]
17 |
18 | early_stop_callback = EarlyStopping(
19 | monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
20 | )
21 |
22 | logger = TensorBoardLogger(tmp_path)
23 | trainer = pl.Trainer(
24 | max_epochs=2,
25 | gradient_clip_val=0.1,
26 | callbacks=[early_stop_callback],
27 | enable_checkpointing=True,
28 | default_root_dir=tmp_path,
29 | limit_train_batches=2,
30 | limit_val_batches=2,
31 | limit_test_batches=2,
32 | logger=logger,
33 | )
34 |
35 | net = NBeats.from_dataset(
36 | train_dataloader.dataset,
37 | learning_rate=0.15,
38 | log_gradient_flow=True,
39 | widths=[4, 4, 4],
40 | log_interval=1000,
41 | backcast_loss_ratio=1.0,
42 | )
43 | net.size()
44 | try:
45 | trainer.fit(
46 | net,
47 | train_dataloaders=train_dataloader,
48 | val_dataloaders=val_dataloader,
49 | )
50 | test_outputs = trainer.test(net, dataloaders=test_dataloader)
51 | assert len(test_outputs) > 0
52 | # check loading
53 | net = NBeats.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
54 |
55 | # check prediction
56 | net.predict(
57 | val_dataloader,
58 | fast_dev_run=True,
59 | return_index=True,
60 | return_decoder_lengths=True,
61 | )
62 | finally:
63 | shutil.rmtree(tmp_path, ignore_errors=True)
64 |
65 | net.predict(
66 | val_dataloader,
67 | fast_dev_run=True,
68 | return_index=True,
69 | return_decoder_lengths=True,
70 | )
71 |
72 |
73 | @pytest.fixture(scope="session")
74 | def model(dataloaders_fixed_window_without_covariates):
75 | dataset = dataloaders_fixed_window_without_covariates["train"].dataset
76 | net = NBeats.from_dataset(
77 | dataset,
78 | learning_rate=0.15,
79 | log_gradient_flow=True,
80 | widths=[4, 4, 4],
81 | log_interval=1000,
82 | backcast_loss_ratio=1.0,
83 | )
84 | return net
85 |
86 |
87 | def test_pickle(model):
88 | pkl = pickle.dumps(model)
89 | pickle.loads(pkl) # noqa: S301
90 |
91 |
92 | @pytest.mark.skipif(
93 | "matplotlib" not in _get_installed_packages(),
94 | reason="skip test if required package matplotlib not installed",
95 | )
96 | def test_interpretation(model, dataloaders_fixed_window_without_covariates):
97 | raw_predictions = model.predict(
98 | dataloaders_fixed_window_without_covariates["val"],
99 | mode="raw",
100 | return_x=True,
101 | fast_dev_run=True,
102 | )
103 | model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=0)
104 |
--------------------------------------------------------------------------------
/tests/test_models/test_nhits.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import shutil
3 |
4 | import lightning.pytorch as pl
5 | from lightning.pytorch.callbacks import EarlyStopping
6 | from lightning.pytorch.loggers import TensorBoardLogger
7 | import numpy as np
8 | import pandas as pd
9 | import pytest
10 |
11 | from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
12 | from pytorch_forecasting.metrics import MQF2DistributionLoss, QuantileLoss
13 | from pytorch_forecasting.metrics.distributions import (
14 | ImplicitQuantileNetworkDistributionLoss,
15 | )
16 | from pytorch_forecasting.models import NHiTS
17 | from pytorch_forecasting.utils._dependencies import _get_installed_packages
18 |
19 |
20 | def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
21 | train_dataloader = dataloader["train"]
22 | val_dataloader = dataloader["val"]
23 | test_dataloader = dataloader["test"]
24 |
25 | early_stop_callback = EarlyStopping(
26 | monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
27 | )
28 |
29 | logger = TensorBoardLogger(tmp_path)
30 | if trainer_kwargs is None:
31 | trainer_kwargs = {}
32 | trainer = pl.Trainer(
33 | max_epochs=2,
34 | gradient_clip_val=0.1,
35 | callbacks=[early_stop_callback],
36 | enable_checkpointing=True,
37 | default_root_dir=tmp_path,
38 | limit_train_batches=2,
39 | limit_val_batches=2,
40 | limit_test_batches=2,
41 | logger=logger,
42 | **trainer_kwargs,
43 | )
44 |
45 | kwargs.setdefault("learning_rate", 0.15)
46 | kwargs.setdefault("weight_decay", 1e-2)
47 |
48 | net = NHiTS.from_dataset(
49 | train_dataloader.dataset,
50 | log_gradient_flow=True,
51 | log_interval=1000,
52 | hidden_size=8,
53 | **kwargs,
54 | )
55 | net.size()
56 | try:
57 | trainer.fit(
58 | net,
59 | train_dataloaders=train_dataloader,
60 | val_dataloaders=val_dataloader,
61 | )
62 |         # todo: testing somehow disables grad computation even though
63 |         # it is explicitly turned on.
64 |         # The MQF2 loss is computed via gradients, so trainer.test is skipped for it.
65 | if not isinstance(net.loss, MQF2DistributionLoss):
66 | test_outputs = trainer.test(net, dataloaders=test_dataloader)
67 | assert len(test_outputs) > 0
68 | # check loading
69 | net = NHiTS.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
70 |
71 | # check prediction
72 | net.predict(
73 | val_dataloader,
74 | fast_dev_run=True,
75 | return_index=True,
76 | return_decoder_lengths=True,
77 | trainer_kwargs=trainer_kwargs,
78 | )
79 | finally:
80 | shutil.rmtree(tmp_path, ignore_errors=True)
81 |
82 | net.predict(
83 | val_dataloader,
84 | fast_dev_run=True,
85 | return_index=True,
86 | return_decoder_lengths=True,
87 | )
88 |
89 |
90 | LOADERS = [
91 | "with_covariates",
92 | "different_encoder_decoder_size",
93 | "fixed_window_without_covariates",
94 | "multi_target",
95 | "quantiles",
96 | "implicit-quantiles",
97 | ]
98 |
99 | if "cpflows" in _get_installed_packages():
100 | LOADERS += ["multivariate-quantiles"]
101 |
102 |
103 | @pytest.mark.parametrize("dataloader", LOADERS)
104 | def test_integration(
105 | dataloaders_with_covariates,
106 | dataloaders_with_different_encoder_decoder_length,
107 | dataloaders_fixed_window_without_covariates,
108 | dataloaders_multi_target,
109 | tmp_path,
110 | dataloader,
111 | ):
112 | kwargs = {}
113 | if dataloader == "with_covariates":
114 | dataloader = dataloaders_with_covariates
115 | kwargs["backcast_loss_ratio"] = 0.5
116 | elif dataloader == "different_encoder_decoder_size":
117 | dataloader = dataloaders_with_different_encoder_decoder_length
118 | elif dataloader == "fixed_window_without_covariates":
119 | dataloader = dataloaders_fixed_window_without_covariates
120 | elif dataloader == "multi_target":
121 | dataloader = dataloaders_multi_target
122 | kwargs["loss"] = QuantileLoss()
123 | elif dataloader == "quantiles":
124 | dataloader = dataloaders_with_covariates
125 | kwargs["loss"] = QuantileLoss()
126 | elif dataloader == "implicit-quantiles":
127 | dataloader = dataloaders_with_covariates
128 | kwargs["loss"] = ImplicitQuantileNetworkDistributionLoss()
129 | elif dataloader == "multivariate-quantiles":
130 | dataloader = dataloaders_with_covariates
131 | kwargs["loss"] = MQF2DistributionLoss(
132 | prediction_length=dataloader["train"].dataset.max_prediction_length
133 | )
134 | kwargs["learning_rate"] = 1e-9
135 | kwargs["trainer_kwargs"] = dict(accelerator="cpu")
136 | else:
137 | raise ValueError(f"dataloader {dataloader} unknown")
138 | _integration(dataloader, tmp_path=tmp_path, **kwargs)
139 |
140 |
141 | @pytest.fixture(scope="session")
142 | def model(dataloaders_with_covariates):
143 | dataset = dataloaders_with_covariates["train"].dataset
144 | net = NHiTS.from_dataset(
145 | dataset,
146 | learning_rate=0.15,
147 | hidden_size=8,
148 | log_gradient_flow=True,
149 | log_interval=1000,
150 | backcast_loss_ratio=1.0,
151 | )
152 | return net
153 |
154 |
155 | def test_pickle(model):
156 | pkl = pickle.dumps(model)
157 |     pickle.loads(pkl)  # noqa: S301
158 |
159 |
160 | @pytest.mark.skipif(
161 | "matplotlib" not in _get_installed_packages(),
162 | reason="skip test if required package matplotlib not installed",
163 | )
164 | def test_interpretation(model, dataloaders_with_covariates):
165 | raw_predictions = model.predict(
166 | dataloaders_with_covariates["val"], mode="raw", return_x=True, fast_dev_run=True
167 | )
168 | model.plot_prediction(
169 | raw_predictions.x, raw_predictions.output, idx=0, add_loss_to_title=True
170 | )
171 | model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=0)
172 |
173 |
174 | # Bug when max_prediction_length=1 #1571
175 | @pytest.mark.parametrize("max_prediction_length", [1, 5])
176 | def test_prediction_length(max_prediction_length: int):
177 | n_timeseries = 10
178 | time_points = 10
179 | data = pd.DataFrame(
180 | data={
181 | "target": np.random.rand(time_points * n_timeseries),
182 | "time_varying_known_real_1": np.random.rand(time_points * n_timeseries),
183 | "time_idx": np.tile(np.arange(time_points), n_timeseries),
184 | "group_id": np.repeat(np.arange(n_timeseries), time_points),
185 | }
186 | )
187 | training_dataset = TimeSeriesDataSet(
188 | data=data,
189 | time_idx="time_idx",
190 | target="target",
191 | group_ids=["group_id"],
192 | time_varying_unknown_reals=["target"],
193 |         time_varying_known_reals=["time_varying_known_real_1"],
194 | max_prediction_length=max_prediction_length,
195 | max_encoder_length=3,
196 | )
197 | training_data_loader = training_dataset.to_dataloader(train=True)
198 | forecaster = NHiTS.from_dataset(training_dataset, log_val_interval=1)
199 | trainer = pl.Trainer(
200 | accelerator="cpu",
201 | max_epochs=3,
202 | min_epochs=2,
203 | limit_train_batches=10,
204 | )
205 | trainer.fit(
206 | forecaster,
207 | train_dataloaders=training_data_loader,
208 | )
209 | validation_dataset = TimeSeriesDataSet.from_dataset(
210 | training_dataset, data, stop_randomization=True, predict=True
211 | )
212 | validation_data_loader = validation_dataset.to_dataloader(train=False)
213 | forecaster.predict(
214 | validation_data_loader,
215 | fast_dev_run=True,
216 | return_index=True,
217 | return_decoder_lengths=True,
218 | )
219 |
--------------------------------------------------------------------------------
/tests/test_models/test_nn/test_embeddings.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from pytorch_forecasting import MultiEmbedding
5 |
6 |
7 | @pytest.mark.parametrize(
8 | "kwargs",
9 | [
10 | dict(embedding_sizes=(10, 10, 10)),
11 | dict(embedding_sizes=((10, 3), (10, 2), (10, 1))),
12 | dict(x_categoricals=["x1", "x2", "x3"], embedding_sizes=dict(x1=(10, 10))),
13 | dict(
14 | x_categoricals=["x1", "x2", "x3"],
15 | embedding_sizes=dict(x1=(10, 2), xg1=(10, 3)),
16 | categorical_groups=dict(xg1=["x2", "x3"]),
17 | ),
18 | ],
19 | )
20 | def test_MultiEmbedding(kwargs):
21 | x = torch.randint(0, 10, size=(4, 3))
22 | embedding = MultiEmbedding(**kwargs)
23 | assert embedding.input_size == x.size(
24 | 1
25 | ), "Input size should be equal to number of features"
26 | out = embedding(x)
27 | if isinstance(out, dict):
28 | assert isinstance(kwargs["embedding_sizes"], dict)
29 | for name, o in out.items():
30 | assert (
31 | o.size(1) == embedding.output_size[name]
32 | ), "Output size should be equal to number of embedding dimensions"
33 | elif isinstance(out, torch.Tensor):
34 | assert isinstance(kwargs["embedding_sizes"], (tuple, list))
35 | assert (
36 | out.size(1) == embedding.output_size
37 | ), "Output size should be equal to number of summed embedding dimensions"
38 | else:
39 | raise ValueError(f"Unknown output type {type(out)}")
40 |
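41 |
42 | # Note on the assertions above: MultiEmbedding has two output modes. With a
43 | # dict of embedding_sizes the forward pass returns a dict of per-variable
44 | # tensors (each sized by ``embedding.output_size[name]``); with a tuple or
45 | # list it returns a single tensor whose width is the summed ``output_size``.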
--------------------------------------------------------------------------------
/tests/test_models/test_nn/test_rnn.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import pytest
4 | import torch
5 | from torch import nn
6 |
7 | from pytorch_forecasting.models.nn.rnn import GRU, LSTM, get_rnn
8 |
9 |
10 | def test_get_lstm_cell():
11 | cell = get_rnn("LSTM")(10, 10)
12 | assert isinstance(cell, LSTM)
13 | assert isinstance(cell, nn.LSTM)
14 |
15 |
16 | def test_get_gru_cell():
17 | cell = get_rnn("GRU")(10, 10)
18 | assert isinstance(cell, GRU)
19 | assert isinstance(cell, nn.GRU)
20 |
21 |
22 | def test_get_cell_raises_value_error():
23 | pytest.raises(ValueError, lambda: get_rnn("ABCDEF"))
24 |
25 |
26 | @pytest.mark.parametrize(
27 | "klass,rnn_kwargs",
28 | itertools.product(
29 | [LSTM, GRU],
30 | [
31 | dict(batch_first=True, num_layers=1),
32 | dict(batch_first=False, num_layers=2),
33 | ],
34 | ),
35 | )
36 | def test_zero_length_sequence(klass, rnn_kwargs):
37 | rnn = klass(input_size=2, hidden_size=5, **rnn_kwargs)
38 | x = torch.rand(100, 3, 2)
39 | lengths = torch.randint(0, 3, size=([3, 100][rnn_kwargs["batch_first"]],))
40 | _, hidden_state = rnn(x, lengths=lengths, enforce_sorted=False)
41 | init_hidden_state = rnn.init_hidden_state(x)
42 |
43 | if isinstance(hidden_state, torch.Tensor):
44 | hidden_state = [hidden_state]
45 | init_hidden_state = [init_hidden_state]
46 |
47 | for idx in range(len(hidden_state)):
48 | assert (
49 | hidden_state[idx].size() == init_hidden_state[idx].size()
50 | ), "Hidden state sizes should be equal"
51 | assert (hidden_state[idx][:, lengths == 0] == 0).all() and (
52 | hidden_state[idx][:, lengths > 0] != 0
53 | ).all(), "Hidden state should be zero for zero-length sequences"
54 |
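55 |
56 | # Note on the assertions above: the LSTM/GRU wrappers returned by ``get_rnn``
57 | # handle zero-length (fully padded) sequences - such entries keep an all-zero
58 | # hidden state matching ``init_hidden_state``, while non-empty sequences
59 | # produce non-zero hidden states of the same shape.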
--------------------------------------------------------------------------------
/tests/test_models/test_rnn_model.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import shutil
3 |
4 | import lightning.pytorch as pl
5 | from lightning.pytorch.callbacks import EarlyStopping
6 | from lightning.pytorch.loggers import TensorBoardLogger
7 | import pytest
8 | from test_models.conftest import make_dataloaders
9 |
10 | from pytorch_forecasting.data.encoders import GroupNormalizer
11 | from pytorch_forecasting.models import RecurrentNetwork
12 |
13 |
14 | def _integration(
15 | data_with_covariates,
16 | tmp_path,
17 | cell_type="LSTM",
18 | data_loader_kwargs={},
19 | clip_target: bool = False,
20 | **kwargs,
21 | ):
22 | data_with_covariates = data_with_covariates.copy()
23 | if clip_target:
24 | data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0)
25 | else:
26 | data_with_covariates["target"] = data_with_covariates["volume"]
27 | data_loader_default_kwargs = dict(
28 | target="target",
29 | time_varying_known_reals=["price_actual"],
30 | time_varying_unknown_reals=["target"],
31 | static_categoricals=["agency"],
32 | add_relative_time_idx=True,
33 | )
34 | data_loader_default_kwargs.update(data_loader_kwargs)
35 | dataloaders_with_covariates = make_dataloaders(
36 | data_with_covariates, **data_loader_default_kwargs
37 | )
38 | train_dataloader = dataloaders_with_covariates["train"]
39 | val_dataloader = dataloaders_with_covariates["val"]
40 | test_dataloader = dataloaders_with_covariates["test"]
41 |
42 | early_stop_callback = EarlyStopping(
43 | monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
44 | )
45 |
46 | logger = TensorBoardLogger(tmp_path)
47 | trainer = pl.Trainer(
48 | max_epochs=3,
49 | gradient_clip_val=0.1,
50 | callbacks=[early_stop_callback],
51 | enable_checkpointing=True,
52 | default_root_dir=tmp_path,
53 | limit_train_batches=2,
54 | limit_val_batches=2,
55 | limit_test_batches=2,
56 | logger=logger,
57 | )
58 |
59 | net = RecurrentNetwork.from_dataset(
60 | train_dataloader.dataset,
61 | cell_type=cell_type,
62 | learning_rate=0.15,
63 | log_gradient_flow=True,
64 | log_interval=1000,
65 | hidden_size=5,
66 | **kwargs,
67 | )
68 | net.size()
69 | try:
70 | trainer.fit(
71 | net,
72 | train_dataloaders=train_dataloader,
73 | val_dataloaders=val_dataloader,
74 | )
75 | test_outputs = trainer.test(net, dataloaders=test_dataloader)
76 | assert len(test_outputs) > 0
77 | # check loading
78 | net = RecurrentNetwork.load_from_checkpoint(
79 | trainer.checkpoint_callback.best_model_path
80 | )
81 |
82 | # check prediction
83 | net.predict(
84 | val_dataloader,
85 | fast_dev_run=True,
86 | return_index=True,
87 | return_decoder_lengths=True,
88 | )
89 | finally:
90 | shutil.rmtree(tmp_path, ignore_errors=True)
91 |
92 | net.predict(
93 | val_dataloader,
94 | fast_dev_run=True,
95 | return_index=True,
96 | return_decoder_lengths=True,
97 | )
98 |
99 |
100 | @pytest.mark.parametrize(
101 | "kwargs",
102 | [
103 | {},
104 | {"cell_type": "GRU"},
105 | dict(
106 | data_loader_kwargs=dict(
107 | target_normalizer=GroupNormalizer(
108 | groups=["agency", "sku"], center=False
109 | )
110 | ),
111 | ),
112 | dict(
113 | data_loader_kwargs=dict(
114 | lags={"volume": [2, 5]},
115 | target="volume",
116 | time_varying_unknown_reals=["volume"],
117 | min_encoder_length=2,
118 | )
119 | ),
120 | dict(
121 | data_loader_kwargs=dict(
122 | time_varying_unknown_reals=["volume", "discount"],
123 | target=["volume", "discount"],
124 | lags={"volume": [2], "discount": [2]},
125 | )
126 | ),
127 | ],
128 | )
129 | def test_integration(data_with_covariates, tmp_path, kwargs):
130 | _integration(data_with_covariates, tmp_path, **kwargs)
131 |
132 |
133 | @pytest.fixture(scope="session")
134 | def model(dataloaders_with_covariates):
135 | dataset = dataloaders_with_covariates["train"].dataset
136 | net = RecurrentNetwork.from_dataset(
137 | dataset,
138 | learning_rate=0.15,
139 | log_gradient_flow=True,
140 | log_interval=1000,
141 | hidden_size=5,
142 | )
143 | return net
144 |
145 |
146 | def test_pickle(model):
147 | pkl = pickle.dumps(model)
148 | pickle.loads(pkl) # noqa: S301
149 |
--------------------------------------------------------------------------------
/tests/test_models/test_tide.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import shutil
3 |
4 | import lightning.pytorch as pl
5 | from lightning.pytorch.callbacks import EarlyStopping
6 | from lightning.pytorch.loggers import TensorBoardLogger
7 | import numpy as np
8 | import pandas as pd
9 | import pytest
10 |
11 | from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
12 | from pytorch_forecasting.metrics import SMAPE
13 | from pytorch_forecasting.models import TiDEModel
14 | from pytorch_forecasting.tests._conftest import make_dataloaders
15 | from pytorch_forecasting.utils._dependencies import _get_installed_packages
16 |
17 |
18 | def _integration(
19 | estimator_cls,
20 | data_with_covariates,
21 | tmp_path,
22 | data_loader_kwargs={},
23 | clip_target: bool = False,
24 | trainer_kwargs=None,
25 | **kwargs,
26 | ):
27 | data_with_covariates = data_with_covariates.copy()
28 | if clip_target:
29 | data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0)
30 | else:
31 | data_with_covariates["target"] = data_with_covariates["volume"]
32 | data_loader_default_kwargs = dict(
33 | target="target",
34 | time_varying_known_reals=["price_actual"],
35 | time_varying_unknown_reals=["target"],
36 | static_categoricals=["agency"],
37 | add_relative_time_idx=True,
38 | )
39 | data_loader_default_kwargs.update(data_loader_kwargs)
40 | dataloaders_with_covariates = make_dataloaders(
41 | data_with_covariates, **data_loader_default_kwargs
42 | )
43 |
44 | train_dataloader = dataloaders_with_covariates["train"]
45 | val_dataloader = dataloaders_with_covariates["val"]
46 | test_dataloader = dataloaders_with_covariates["test"]
47 |
48 | early_stop_callback = EarlyStopping(
49 | monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
50 | )
51 |
52 | logger = TensorBoardLogger(tmp_path)
53 | if trainer_kwargs is None:
54 | trainer_kwargs = {}
55 | trainer = pl.Trainer(
56 | max_epochs=3,
57 | gradient_clip_val=0.1,
58 | callbacks=[early_stop_callback],
59 | enable_checkpointing=True,
60 | default_root_dir=tmp_path,
61 | limit_train_batches=2,
62 | limit_val_batches=2,
63 | limit_test_batches=2,
64 | logger=logger,
65 | **trainer_kwargs,
66 | )
67 |
68 | net = estimator_cls.from_dataset(
69 | train_dataloader.dataset,
70 | hidden_size=5,
71 | learning_rate=0.01,
72 | log_gradient_flow=True,
73 | log_interval=1000,
74 | **kwargs,
75 | )
76 | net.size()
77 | try:
78 | trainer.fit(
79 | net,
80 | train_dataloaders=train_dataloader,
81 | val_dataloaders=val_dataloader,
82 | )
83 | test_outputs = trainer.test(net, dataloaders=test_dataloader)
84 | assert len(test_outputs) > 0
85 | # check loading
86 | net = estimator_cls.load_from_checkpoint(
87 | trainer.checkpoint_callback.best_model_path
88 | )
89 |
90 | # check prediction
91 | net.predict(
92 | val_dataloader,
93 | fast_dev_run=True,
94 | return_index=True,
95 | return_decoder_lengths=True,
96 | trainer_kwargs=trainer_kwargs,
97 | )
98 | finally:
99 | shutil.rmtree(tmp_path, ignore_errors=True)
100 |
101 | net.predict(
102 | val_dataloader,
103 | fast_dev_run=True,
104 | return_index=True,
105 | return_decoder_lengths=True,
106 | trainer_kwargs=trainer_kwargs,
107 | )
108 |
109 |
110 | def _tide_integration(dataloaders, tmp_path, trainer_kwargs=None, **kwargs):
111 | """TiDE specific wrapper around the common integration test function.
112 |
113 | Args:
114 | dataloaders: Dictionary of dataloaders for train, val, and test.
115 | tmp_path: Temporary path for saving the model.
116 | trainer_kwargs: Additional arguments for the Trainer.
117 | **kwargs: Additional arguments for the TiDEModel.
118 |
119 |     Returns:
120 |         None; ``_integration`` only runs training and prediction checks.
121 | """
122 | from pytorch_forecasting.tests._data_scenarios import data_with_covariates
123 |
124 | df = data_with_covariates()
125 |
126 | tide_kwargs = {
127 | "temporal_decoder_hidden": 8,
128 | "temporal_width_future": 4,
129 | "dropout": 0.1,
130 | }
131 |
132 | tide_kwargs.update(kwargs)
133 | train_dataset = dataloaders["train"].dataset
134 |
135 | data_loader_kwargs = {
136 | "target": train_dataset.target,
137 | "group_ids": train_dataset.group_ids,
138 | "time_varying_known_reals": train_dataset.time_varying_known_reals,
139 | "time_varying_unknown_reals": train_dataset.time_varying_unknown_reals,
140 | "static_categoricals": train_dataset.static_categoricals,
141 | "static_reals": train_dataset.static_reals,
142 | "add_relative_time_idx": train_dataset.add_relative_time_idx,
143 | }
144 | return _integration(
145 | TiDEModel,
146 | df,
147 | tmp_path,
148 | data_loader_kwargs=data_loader_kwargs,
149 | trainer_kwargs=trainer_kwargs,
150 | **tide_kwargs,
151 | )
152 |
153 |
154 | @pytest.mark.parametrize(
155 | "kwargs",
156 | [
157 | {},
158 | {"loss": SMAPE()},
159 | {"temporal_decoder_hidden": 16},
160 | {"dropout": 0.2, "use_layer_norm": True},
161 | ],
162 | )
163 | def test_integration(dataloaders_with_covariates, tmp_path, kwargs):
164 | _tide_integration(dataloaders_with_covariates, tmp_path, **kwargs)
165 |
166 |
167 | @pytest.mark.parametrize(
168 | "kwargs",
169 | [
170 | {},
171 | ],
172 | )
173 | def test_multi_target_integration(dataloaders_multi_target, tmp_path, kwargs):
174 | _tide_integration(dataloaders_multi_target, tmp_path, **kwargs)
175 |
176 |
177 | @pytest.fixture
178 | def model(dataloaders_with_covariates):
179 | dataset = dataloaders_with_covariates["train"].dataset
180 | net = TiDEModel.from_dataset(
181 | dataset,
182 | hidden_size=16,
183 | dropout=0.1,
184 | temporal_width_future=4,
185 | )
186 | return net
187 |
188 |
189 | def test_pickle(model):
190 | pkl = pickle.dumps(model)
191 | pickle.loads(pkl) # noqa: S301
192 |
193 |
194 | @pytest.mark.skipif(
195 | "matplotlib" not in _get_installed_packages(),
196 | reason="skip test if required package matplotlib not installed",
197 | )
198 | def test_prediction_visualization(model, dataloaders_with_covariates):
199 | raw_predictions = model.predict(
200 | dataloaders_with_covariates["val"],
201 | mode="raw",
202 | return_x=True,
203 | fast_dev_run=True,
204 | )
205 | model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0)
206 |
207 |
208 | def test_prediction_with_kwargs(model, dataloaders_with_covariates):
209 | # Tests prediction works with different keyword arguments
210 | model.predict(
211 | dataloaders_with_covariates["val"], return_index=True, fast_dev_run=True
212 | )
213 | model.predict(
214 | dataloaders_with_covariates["val"],
215 | return_x=True,
216 | return_y=True,
217 | fast_dev_run=True,
218 | )
219 |
220 |
221 | def test_no_exogenous_variable():
222 | data = pd.DataFrame(
223 | {
224 | "target": np.ones(1600),
225 | "group_id": np.repeat(np.arange(16), 100),
226 | "time_idx": np.tile(np.arange(100), 16),
227 | }
228 | )
229 | training_dataset = TimeSeriesDataSet(
230 | data=data,
231 | time_idx="time_idx",
232 | target="target",
233 | group_ids=["group_id"],
234 | max_encoder_length=10,
235 | max_prediction_length=5,
236 | time_varying_unknown_reals=["target"],
237 | time_varying_known_reals=[],
238 | )
239 | validation_dataset = TimeSeriesDataSet.from_dataset(
240 | training_dataset, data, stop_randomization=True, predict=True
241 | )
242 | training_data_loader = training_dataset.to_dataloader(
243 | train=True, batch_size=8, num_workers=0
244 | )
245 | validation_data_loader = validation_dataset.to_dataloader(
246 | train=False, batch_size=8, num_workers=0
247 | )
248 | forecaster = TiDEModel.from_dataset(
249 | training_dataset,
250 | )
251 | from lightning.pytorch import Trainer
252 |
253 | trainer = Trainer(
254 | max_epochs=2,
255 | limit_train_batches=8,
256 | limit_val_batches=8,
257 | )
258 | trainer.fit(
259 | forecaster,
260 | train_dataloaders=training_data_loader,
261 | val_dataloaders=validation_data_loader,
262 | )
263 | best_model_path = trainer.checkpoint_callback.best_model_path
264 | best_model = TiDEModel.load_from_checkpoint(best_model_path)
265 | best_model.predict(
266 | validation_data_loader,
267 | fast_dev_run=True,
268 | return_x=True,
269 | return_y=True,
270 | return_index=True,
271 | )
272 |
--------------------------------------------------------------------------------
/tests/test_utils/test_autocorrelation.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 | from pytorch_forecasting.utils import autocorrelation
6 |
7 |
8 | def test_autocorrelation():
9 | x = torch.sin(torch.linspace(0, 2 * 2 * math.pi, 201))
10 | corr = autocorrelation(x, dim=-1)
11 |     assert corr[0] == 1, "Autocorrelation at lag 0 should be 1."
12 |     assert corr[101] > 0.99, "Autocorrelation should be near 1 at a full period (2*pi)"
13 |     assert corr[50] < -0.99, "Autocorrelation should be near -1 at half a period (pi)"
14 |
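15 |
16 | # Note on the lags used above: the signal spans two full periods over 201
17 | # samples, so one period is about 100 samples. A lag of roughly one period
18 | # (index 101) re-aligns the sine with itself (correlation near 1), while a
19 | # lag of 50 samples (half a period) anti-aligns it (correlation near -1).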
--------------------------------------------------------------------------------
/tests/test_utils/test_safe_import.py:
--------------------------------------------------------------------------------
1 | from pytorch_forecasting.utils._dependencies import _safe_import
2 |
3 |
4 | def test_present_module():
5 | """Test importing a dependency that is installed."""
6 | module = _safe_import("torch")
7 | assert module is not None
8 |
9 |
10 | def test_import_missing_module():
11 | """Test importing a dependency that is not installed."""
12 | result = _safe_import("nonexistent_module")
13 | assert hasattr(result, "__name__")
14 | assert result.__name__ == "nonexistent_module"
15 |
16 |
17 | def test_import_without_pkg_name():
18 |     """Test importing a dependency, passing pkg_name equal to the module name."""
19 | result = _safe_import("torch", pkg_name="torch")
20 | assert result is not None
21 |
22 |
23 | def test_import_with_different_pkg_name_1():
24 | """Test importing a dependency with a different package name."""
25 | result = _safe_import("skbase", pkg_name="scikit-base")
26 | assert result is not None
27 |
28 |
29 | def test_import_with_different_pkg_name_2():
30 | """Test importing another dependency with a different package name."""
31 | result = _safe_import("cv2", pkg_name="opencv-python")
32 | assert result is not None
33 |
34 |
35 | def test_import_submodule():
36 | """Test importing a submodule."""
37 | result = _safe_import("torch.nn")
38 | assert result is not None
39 |
40 |
41 | def test_import_class():
42 | """Test importing a class."""
43 | result = _safe_import("torch.nn.Linear")
44 | assert result is not None
45 |
46 |
47 | def test_import_existing_object():
48 | """Test importing an existing object."""
49 | result = _safe_import("pandas.DataFrame")
50 | assert result is not None
51 | assert result.__name__ == "DataFrame"
52 | from pandas import DataFrame
53 |
54 | assert result is DataFrame
55 |
56 |
57 | def test_multiple_inheritance_from_mock():
58 | """Test multiple inheritance from dynamic MagicMock."""
59 | Class1 = _safe_import("foobar.foo.FooBar")
60 | Class2 = _safe_import("barfoobar.BarFooBar")
61 |
62 | class NewClass(Class1, Class2):
63 | """This should not trigger an error.
64 |
65 | The class definition would trigger an error if multiple inheritance
66 | from Class1 and Class2 does not work, e.g., if it is simply
67 | identical to MagicMock.
68 | """
69 |
70 | pass
71 |
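72 |
73 | # Minimal usage sketch - assuming, consistent with the tests above, that
74 | # ``_safe_import`` is meant to guard optional imports: it returns the real
75 | # module, class, or object when the dependency is installed and a mock-like
76 | # placeholder (with ``__name__`` set) otherwise, e.g.
77 | #
78 | #     matplotlib = _safe_import("matplotlib")   # placeholder if missing
79 | #     Linear = _safe_import("torch.nn.Linear")  # real class when installed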
--------------------------------------------------------------------------------
/tests/test_utils/test_show_versions.py:
--------------------------------------------------------------------------------
1 | """Tests for the show_versions utility."""
2 |
3 | import pathlib
4 | import uuid
5 |
6 | from pytorch_forecasting.utils._maint._show_versions import (
7 | DEFAULT_DEPS_TO_SHOW,
8 | _get_deps_info,
9 | show_versions,
10 | )
11 |
12 |
13 | def test_show_versions_runs():
14 | """Test that show_versions runs without exceptions."""
15 | # only prints, should return None
16 | assert show_versions() is None
17 |
18 |
19 | def test_show_versions_import_loc():
20 |     """Test that show_versions can be imported from root."""
21 | from pytorch_forecasting import show_versions as show_versions_imported
22 |
23 | assert show_versions == show_versions_imported
24 |
25 |
26 | def test_deps_info():
27 | """Test that _get_deps_info returns package/version dict as per contract."""
28 | deps_info = _get_deps_info()
29 | assert isinstance(deps_info, dict)
30 | assert set(deps_info.keys()) == {"pytorch-forecasting"}
31 |
32 | deps_info_default = _get_deps_info(DEFAULT_DEPS_TO_SHOW)
33 | assert isinstance(deps_info_default, dict)
34 | assert set(deps_info_default.keys()) == set(DEFAULT_DEPS_TO_SHOW)
35 |
36 |
37 | def test_deps_info_deps_missing_package_present_directory():
38 | """Test that _get_deps_info does not fail if a dependency is missing."""
39 | dummy_package_name = uuid.uuid4().hex
40 |
41 | dummy_folder_path = pathlib.Path(dummy_package_name)
42 | dummy_folder_path.mkdir()
43 |
44 | assert _get_deps_info([dummy_package_name]) == {dummy_package_name: None}
45 |
46 | dummy_folder_path.rmdir()
47 |
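48 |
49 | # Note on the last test: a directory with a package-like name exists, but no
50 | # matching distribution is installed, so ``_get_deps_info`` is expected to
51 | # report ``None`` for that entry rather than raise.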
--------------------------------------------------------------------------------