├── .github ├── ISSUE_TEMPLATE.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ ├── autoapprove.yml.disabled │ ├── automerge.yml.disabled │ ├── pypi_release.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── CHANGELOG.md ├── CODEOWNERS ├── LICENSE ├── README.md ├── build_tools ├── changelog.py └── run_examples.sh ├── codecov.yml ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _static │ ├── custom.css │ ├── favicon.png │ ├── favicon.svg │ └── logo.svg │ ├── _templates │ ├── custom-base-template.rst │ ├── custom-class-template.rst │ └── custom-module-template.rst │ ├── api.rst │ ├── conf.py │ ├── data.rst │ ├── faq.rst │ ├── getting-started.rst │ ├── index.rst │ ├── installation.rst │ ├── metrics.rst │ ├── models.rst │ ├── tutorials.rst │ └── tutorials │ ├── ar.ipynb │ ├── building.ipynb │ ├── deepar.ipynb │ ├── nhits.ipynb │ └── stallion.ipynb ├── examples ├── ar.py ├── data │ └── stallion.parquet ├── nbeats.py └── stallion.py ├── pyproject.toml ├── pytorch_forecasting ├── __init__.py ├── _registry │ ├── __init__.py │ └── _lookup.py ├── data │ ├── __init__.py │ ├── data_module.py │ ├── encoders.py │ ├── examples.py │ ├── samplers.py │ └── timeseries │ │ ├── __init__.py │ │ ├── _timeseries.py │ │ └── _timeseries_v2.py ├── metrics │ ├── __init__.py │ ├── _mqf2_utils.py │ ├── base_metrics.py │ ├── distributions.py │ ├── point.py │ └── quantile.py ├── models │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── _base_model.py │ │ ├── _base_model_v2.py │ │ └── _base_object.py │ ├── base_model.py │ ├── baseline.py │ ├── deepar │ │ ├── __init__.py │ │ ├── _deepar.py │ │ └── _deepar_metadata.py │ ├── mlp │ │ ├── __init__.py │ │ ├── _decodermlp.py │ │ └── submodules.py │ ├── nbeats │ │ ├── __init__.py │ │ ├── _nbeats.py │ │ ├── _nbeats_metadata.py │ │ └── sub_modules.py │ ├── nhits │ │ ├── __init__.py │ │ ├── _nhits.py │ │ └── sub_modules.py │ ├── nn │ │ ├── __init__.py │ │ ├── embeddings.py │ 
│ └── rnn.py │ ├── rnn │ │ ├── __init__.py │ │ └── _rnn.py │ ├── temporal_fusion_transformer │ │ ├── __init__.py │ │ ├── _tft.py │ │ ├── _tft_v2.py │ │ ├── sub_modules.py │ │ └── tuning.py │ └── tide │ │ ├── __init__.py │ │ ├── _tide.py │ │ ├── _tide_metadata.py │ │ └── sub_modules.py ├── tests │ ├── __init__.py │ ├── _config.py │ ├── _conftest.py │ ├── _data_scenarios.py │ └── test_all_estimators.py └── utils │ ├── __init__.py │ ├── _coerce.py │ ├── _dependencies │ ├── __init__.py │ ├── _dependencies.py │ ├── _safe_import.py │ └── tests │ │ ├── __init__.py │ │ └── test_safe_import.py │ ├── _maint │ ├── __init__.py │ └── _show_versions.py │ └── _utils.py └── tests ├── conftest.py ├── test_data ├── test_d1.py ├── test_data_module.py ├── test_encoders.py ├── test_samplers.py └── test_timeseries.py ├── test_metrics.py ├── test_models ├── _test_tft_v2.py ├── conftest.py ├── test_baseline.py ├── test_deepar.py ├── test_mlp.py ├── test_nbeats.py ├── test_nhits.py ├── test_nn │ ├── test_embeddings.py │ └── test_rnn.py ├── test_rnn_model.py ├── test_temporal_fusion_transformer.py └── test_tide.py └── test_utils ├── test_autocorrelation.py ├── test_safe_import.py └── test_show_versions.py /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | - PyTorch-Forecasting version: 2 | - PyTorch version: 3 | - Python version: 4 | - Operating System: 5 | 6 | ### Expected behavior 7 | 8 | I executed code ... in order to ... and expected to get result ... 9 | 10 | ### Actual behavior 11 | 12 | However, result was .... I think it has to do with ... because of ... 13 | 14 | ### Code to reproduce the problem 15 | 16 | ``` 17 | 18 | ``` 19 | 20 | Paste the command(s) you ran and the output. Including a link to a colab notebook will speed up issue resolution. 21 | If there was a crash, please include the traceback here. 22 | The code used to initialize the TimeSeriesDataSet and model should be also included. 
23 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### Description 2 | 3 | This PR ... 4 | 5 | ### Checklist 6 | 7 | - [ ] Linked issues (if existing) 8 | - [ ] Amended changelog for large changes (and added myself there as contributor) 9 | - [ ] Added/modified tests 10 | - [ ] Used pre-commit hooks when committing to ensure that code is compliant with hooks. Install hooks with `pre-commit install`. 11 | To run hooks independent of commit, execute `pre-commit run --all-files` 12 | 13 | Make sure to have fun coding! 14 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | commit-message: 8 | prefix: "[MNT] [Dependabot]" 9 | include: "scope" 10 | labels: 11 | - "maintenance" 12 | - package-ecosystem: "github-actions" 13 | directory: "/" 14 | schedule: 15 | interval: "daily" 16 | commit-message: 17 | prefix: "[MNT] [Dependabot]" 18 | include: "scope" 19 | labels: 20 | - "maintenance" 21 | -------------------------------------------------------------------------------- /.github/workflows/autoapprove.yml.disabled: -------------------------------------------------------------------------------- 1 | name: Dependabot auto-approve 2 | on: pull_request 3 | 4 | permissions: 5 | pull-requests: write 6 | 7 | jobs: 8 | dependabot: 9 | runs-on: ubuntu-latest 10 | if: ${{ github.actor == 'dependabot[bot]' }} 11 | steps: 12 | - name: Dependabot metadata 13 | id: metadata 14 | uses: dependabot/fetch-metadata@v1.1.1 15 | with: 16 | github-token: "${{ secrets.GITHUB_TOKEN }}" 17 | - name: Approve a PR 18 | run: gh pr review --approve "$PR_URL" 19 | env: 20 | PR_URL: 
${{github.event.pull_request.html_url}} 21 | GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} 22 | -------------------------------------------------------------------------------- /.github/workflows/automerge.yml.disabled: -------------------------------------------------------------------------------- 1 | name: Dependabot auto-merge 2 | on: pull_request 3 | 4 | permissions: 5 | pull-requests: write 6 | contents: write 7 | 8 | jobs: 9 | dependabot: 10 | runs-on: ubuntu-latest 11 | if: ${{ github.actor == 'dependabot[bot]' }} 12 | steps: 13 | - name: Dependabot metadata 14 | id: metadata 15 | uses: dependabot/fetch-metadata@v1.1.1 16 | with: 17 | github-token: "${{ secrets.GITHUB_TOKEN }}" 18 | - name: Enable auto-merge for Dependabot PRs 19 | run: gh pr merge --auto --merge "$PR_URL" 20 | env: 21 | PR_URL: ${{github.event.pull_request.html_url}} 22 | GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} 23 | -------------------------------------------------------------------------------- /.github/workflows/pypi_release.yml: -------------------------------------------------------------------------------- 1 | name: PyPI Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | build_wheels: 9 | name: Build wheels 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - uses: actions/setup-python@v5 16 | with: 17 | python-version: '3.11' 18 | 19 | - name: Build wheel 20 | run: | 21 | python -m pip install build 22 | python -m build --wheel --sdist --outdir wheelhouse 23 | 24 | - name: Store wheels 25 | uses: actions/upload-artifact@v4 26 | with: 27 | name: wheels 28 | path: wheelhouse/* 29 | 30 | pytest-nosoftdeps: 31 | name: no-softdeps 32 | runs-on: ${{ matrix.os }} 33 | needs: [build_wheels] 34 | strategy: 35 | fail-fast: false 36 | matrix: 37 | os: [ubuntu-latest, macos-latest, windows-latest] 38 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 39 | 40 | steps: 41 | - uses: actions/checkout@v4 42 | 43 | - name: Set up Python ${{ 
matrix.python-version }} 44 | uses: actions/setup-python@v5 45 | with: 46 | python-version: ${{ matrix.python-version }} 47 | 48 | - name: Setup macOS 49 | if: runner.os == 'macOS' 50 | run: | 51 | brew install libomp # https://github.com/pytorch/pytorch/issues/20030 52 | 53 | - name: Get full Python version 54 | id: full-python-version 55 | shell: bash 56 | run: echo version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))") >> $GITHUB_OUTPUT 57 | 58 | - name: Install dependencies 59 | shell: bash 60 | run: | 61 | pip install ".[dev,github-actions]" 62 | 63 | - name: Show dependencies 64 | run: python -m pip list 65 | 66 | - name: Run pytest 67 | shell: bash 68 | run: python -m pytest tests 69 | 70 | upload_wheels: 71 | name: Upload wheels to PyPI 72 | runs-on: ubuntu-latest 73 | needs: [pytest-nosoftdeps] 74 | 75 | permissions: 76 | id-token: write 77 | 78 | steps: 79 | - uses: actions/download-artifact@v4 80 | with: 81 | name: wheels 82 | path: wheelhouse 83 | 84 | - name: Publish package to PyPI 85 | uses: pypa/gh-action-pypi-publish@release/v1 86 | with: 87 | packages-dir: wheelhouse/ 88 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Test 5 | 6 | on: 7 | push: 8 | branches: [main] 9 | pull_request: 10 | branches: [main] 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}-${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | code-quality: 18 | name: code-quality 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: repository checkout step 22 | uses: actions/checkout@v4 23 | 24 | - name: python environment step 25 | uses: 
actions/setup-python@v5 26 | with: 27 | python-version: "3.11" 28 | 29 | - name: install pre-commit 30 | run: python3 -m pip install pre-commit 31 | 32 | - name: Checkout code 33 | uses: actions/checkout@v4 34 | with: 35 | fetch-depth: 0 36 | 37 | - name: Get changed files 38 | id: changed-files 39 | run: | 40 | CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.sha }} | tr '\n' ' ') 41 | echo "CHANGED_FILES=${CHANGED_FILES}" >> $GITHUB_ENV 42 | 43 | - name: Print changed files 44 | run: | 45 | echo "Changed files:" && echo "$CHANGED_FILES" | tr ' ' '\n' 46 | 47 | - name: Run pre-commit on changed files 48 | run: | 49 | if [ -n "$CHANGED_FILES" ]; then 50 | pre-commit run --color always --files $CHANGED_FILES --show-diff-on-failure 51 | else 52 | echo "No changed files to check." 53 | fi 54 | 55 | run-notebook-tutorials: 56 | name: Run notebook tutorials 57 | needs: code-quality 58 | runs-on: ubuntu-latest 59 | steps: 60 | - uses: actions/checkout@v4 61 | - name: Set up Python 62 | uses: actions/setup-python@v5 63 | with: 64 | python-version: 3.9 65 | - name: Install dependencies 66 | run: | 67 | python -m pip install --upgrade pip 68 | python -m pip install ".[dev,all_extras,github-actions]" 69 | 70 | - name: Show dependencies 71 | run: python -m pip list 72 | 73 | - name: Run example notebooks 74 | run: build_tools/run_examples.sh 75 | shell: bash 76 | 77 | pytest-nosoftdeps: 78 | name: no-softdeps 79 | needs: code-quality 80 | runs-on: ${{ matrix.os }} 81 | strategy: 82 | fail-fast: false 83 | matrix: 84 | os: [ubuntu-latest, macos-latest, windows-latest] 85 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 86 | 87 | steps: 88 | - uses: actions/checkout@v4 89 | 90 | - name: Set up Python ${{ matrix.python-version }} 91 | uses: actions/setup-python@v5 92 | with: 93 | python-version: ${{ matrix.python-version }} 94 | 95 | - name: Setup macOS 96 | if: runner.os == 'macOS' 97 | run: | 98 | brew install libomp # 
https://github.com/pytorch/pytorch/issues/20030 99 | 100 | - name: Get full Python version 101 | id: full-python-version 102 | shell: bash 103 | run: echo version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))") >> $GITHUB_OUTPUT 104 | 105 | - name: Install dependencies 106 | shell: bash 107 | run: | 108 | pip install ".[dev,github-actions]" 109 | 110 | - name: Show dependencies 111 | run: python -m pip list 112 | 113 | - name: Run pytest 114 | shell: bash 115 | run: python -m pytest 116 | 117 | pytest: 118 | name: Run pytest 119 | needs: pytest-nosoftdeps 120 | runs-on: ${{ matrix.os }} 121 | strategy: 122 | fail-fast: false 123 | matrix: 124 | os: [ubuntu-latest, macos-latest, windows-latest] 125 | python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] 126 | 127 | steps: 128 | - uses: actions/checkout@v4 129 | 130 | - name: Set up Python ${{ matrix.python-version }} 131 | uses: actions/setup-python@v5 132 | with: 133 | python-version: ${{ matrix.python-version }} 134 | 135 | - name: Setup macOS 136 | if: runner.os == 'macOS' 137 | run: | 138 | brew install libomp # https://github.com/pytorch/pytorch/issues/20030 139 | 140 | - name: Get full Python version 141 | id: full-python-version 142 | shell: bash 143 | run: echo version=$(python -c "import sys; print('-'.join(str(v) for v in sys.version_info))") >> $GITHUB_OUTPUT 144 | 145 | - name: Install dependencies 146 | shell: bash 147 | run: | 148 | pip install ".[dev,all_extras,github-actions]" 149 | 150 | - name: Show dependencies 151 | run: python -m pip list 152 | 153 | - name: Run pytest 154 | shell: bash 155 | run: python -m pytest 156 | 157 | - name: Statistics 158 | run: | 159 | pip install coverage 160 | coverage report 161 | coverage xml 162 | 163 | - name: Upload coverage to Codecov 164 | uses: codecov/codecov-action@v5 165 | if: always() 166 | continue-on-error: true 167 | with: 168 | token: ${{ secrets.CODECOV_TOKEN }} 169 | files: coverage.xml 170 | flags: cpu,pytest 171 |
name: CPU-coverage 172 | fail_ci_if_error: false 173 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | docs/source/api/ 74 | docs/source/CHANGELOG.md 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # pycharm 134 | .idea 135 | 136 | # vscode 137 | .vscode 138 | 139 | # logs 140 | lightning_logs 141 | .history 142 | 143 | # checkpoints 144 | *.ckpt 145 | *.pkl 146 | .DS_Store 147 | 148 | # data 149 | pytorch_forecasting/data/*.parquet 150 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.6.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-ast 11 | - repo: https://github.com/astral-sh/ruff-pre-commit 12 | rev: v0.6.9 13 | hooks: 14 | - id: ruff 15 | args: [--fix] 16 | - id: ruff-format 17 | - repo: https://github.com/nbQA-dev/nbQA 18 | rev: 1.8.7 19 | hooks: 20 | - id: nbqa-black 21 | - id: nbqa-ruff 22 | - id: nbqa-check-ast 23 | 
-------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | # reference: https://docs.readthedocs.io/en/stable/config-file/v2.html#sphinx 10 | sphinx: 11 | configuration: docs/source/conf.py 12 | # fail_on_warning: true 13 | 14 | # Build documentation with MkDocs 15 | #mkdocs: 16 | # configuration: mkdocs.yml 17 | 18 | # Optionally build your docs in additional formats such as PDF and ePub 19 | formats: 20 | - htmlzip 21 | 22 | build: 23 | os: ubuntu-22.04 24 | tools: 25 | python: "3.12" 26 | 27 | python: 28 | install: 29 | - method: pip 30 | path: . 31 | extra_requirements: 32 | - docs 33 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # The file specifies framework level core developers for automated review requests 2 | 3 | * @benheid @fkiraly @fnhirwa @geetu040 @jdb78 @pranavvp16 @XinyuWuu @yarnabrina 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | THE MIT License 2 | 3 | Copyright (c) 2020 - present, the pytorch-forecasting developers 4 | Copyright (c) 2020 Jan Beitner 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to 
whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. 23 | -------------------------------------------------------------------------------- /build_tools/changelog.py: -------------------------------------------------------------------------------- 1 | """RestructuredText changelog generator.""" 2 | 3 | from collections import defaultdict 4 | import os 5 | 6 | HEADERS = { 7 | "Accept": "application/vnd.github.v3+json", 8 | } 9 | 10 | if os.getenv("GITHUB_TOKEN") is not None: 11 | HEADERS["Authorization"] = f"token {os.getenv('GITHUB_TOKEN')}" 12 | 13 | OWNER = "jdb78" 14 | REPO = "pytorch-forecasting" 15 | GITHUB_REPOS = "https://api.github.com/repos" 16 | 17 | 18 | def fetch_merged_pull_requests(page: int = 1) -> list[dict]: 19 | """Fetch a page of merged pull requests. 20 | 21 | Parameters 22 | ---------- 23 | page : int, optional 24 | Page number to fetch, by default 1. 25 | Returns all merged pull request from the ``page``-th page of closed PRs, 26 | where pages are in descending order of last update. 27 | 28 | Returns 29 | ------- 30 | list 31 | List of merged pull requests from the ``page``-th page of closed PRs. 32 | Elements of list are dictionaries with PR details, as obtained 33 | from the GitHub API via ``httpx.get``, from the ``pulls`` endpoint. 
34 | """ 35 | import httpx 36 | 37 | params = { 38 | "base": "main", 39 | "state": "closed", 40 | "page": page, 41 | "per_page": 50, 42 | "sort": "updated", 43 | "direction": "desc", 44 | } 45 | r = httpx.get( 46 | f"{GITHUB_REPOS}/{OWNER}/{REPO}/pulls", 47 | headers=HEADERS, 48 | params=params, 49 | ) 50 | return [pr for pr in r.json() if pr["merged_at"]] 51 | 52 | 53 | def fetch_latest_release(): # noqa: D103 54 | """Fetch the latest release from the GitHub API. 55 | 56 | Returns 57 | ------- 58 | dict 59 | Dictionary with details of the latest release. 60 | Dictionary is as obtained from the GitHub API via ``httpx.get``, 61 | for ``releases/latest`` endpoint. 62 | """ 63 | import httpx 64 | 65 | response = httpx.get( 66 | f"{GITHUB_REPOS}/{OWNER}/{REPO}/releases/latest", headers=HEADERS 67 | ) 68 | 69 | if response.status_code == 200: 70 | return response.json() 71 | else: 72 | raise ValueError(response.text, response.status_code) 73 | 74 | 75 | def fetch_pull_requests_since_last_release() -> list[dict]: 76 | """Fetch all pull requests merged since last release. 77 | 78 | Returns 79 | ------- 80 | list 81 | List of pull requests merged since the latest release. 82 | Elements of list are dictionaries with PR details, as obtained 83 | from the GitHub API via ``httpx.get``, through ``fetch_merged_pull_requests``. 
84 | """ 85 | from dateutil import parser 86 | 87 | release = fetch_latest_release() 88 | published_at = parser.parse(release["published_at"]) 89 | print(f"Latest release {release['tag_name']} was published at {published_at}") 90 | 91 | is_exhausted = False 92 | page = 1 93 | all_pulls = [] 94 | while not is_exhausted: 95 | pulls = fetch_merged_pull_requests(page=page) 96 | all_pulls.extend( 97 | [p for p in pulls if parser.parse(p["merged_at"]) > published_at] 98 | ) 99 | is_exhausted = any(parser.parse(p["updated_at"]) < published_at for p in pulls) 100 | page += 1 101 | return all_pulls 102 | 103 | 104 | def github_compare_tags(tag_left: str, tag_right: str = "HEAD"): 105 | """Compare commit between two tags.""" 106 | import httpx 107 | 108 | response = httpx.get( 109 | f"{GITHUB_REPOS}/{OWNER}/{REPO}/compare/{tag_left}...{tag_right}" 110 | ) 111 | if response.status_code == 200: 112 | return response.json() 113 | else: 114 | raise ValueError(response.text, response.status_code) 115 | 116 | 117 | def render_contributors(prs: list, fmt: str = "rst"): 118 | """Find unique authors and print a list in given format.""" 119 | authors = sorted({pr["user"]["login"] for pr in prs}, key=lambda x: x.lower()) 120 | 121 | header = "Contributors" 122 | if fmt == "github": 123 | print(f"### {header}") 124 | print(", ".join(f"@{user}" for user in authors)) 125 | elif fmt == "rst": 126 | print(header) 127 | print("~" * len(header), end="\n\n") 128 | print(",\n".join(f":user:`{user}`" for user in authors)) 129 | 130 | 131 | def assign_prs(prs, categs: list[dict[str, list[str]]]): 132 | """Assign PR to categories based on labels.""" 133 | assigned = defaultdict(list) 134 | 135 | for i, pr in enumerate(prs): 136 | for cat in categs: 137 | pr_labels = [label["name"] for label in pr["labels"]] 138 | if not set(cat["labels"]).isdisjoint(set(pr_labels)): 139 | assigned[cat["title"]].append(i) 140 | 141 | # if any(l.startswith("module") for l in pr_labels): 142 | # print(i, pr_labels) 
143 | 144 | assigned["Other"] = list( 145 | set(range(len(prs))) - {i for _, j in assigned.items() for i in j} 146 | ) 147 | 148 | return assigned 149 | 150 | 151 | def render_row(pr): 152 | """Render a single row with PR in restructuredText format.""" 153 | print( 154 | "*", 155 | pr["title"].replace("`", "``"), 156 | f"(:pr:`{pr['number']}`)", 157 | f":user:`{pr['user']['login']}`", 158 | ) 159 | 160 | 161 | def render_changelog(prs, assigned): 162 | # sourcery skip: use-named-expression 163 | """Render changelog.""" 164 | from dateutil import parser 165 | 166 | for title, _ in assigned.items(): 167 | pr_group = [prs[i] for i in assigned[title]] 168 | if pr_group: 169 | print(f"\n{title}") 170 | print("~" * len(title), end="\n\n") 171 | 172 | for pr in sorted(pr_group, key=lambda x: parser.parse(x["merged_at"])): 173 | render_row(pr) 174 | 175 | 176 | if __name__ == "__main__": 177 | categories = [ 178 | {"title": "Enhancements", "labels": ["feature", "enhancement"]}, 179 | {"title": "Fixes", "labels": ["bug", "fix", "bugfix"]}, 180 | {"title": "Maintenance", "labels": ["maintenance", "chore"]}, 181 | {"title": "Refactored", "labels": ["refactor"]}, 182 | {"title": "Documentation", "labels": ["documentation"]}, 183 | ] 184 | 185 | pulls = fetch_pull_requests_since_last_release() 186 | print(f"Found {len(pulls)} merged PRs since last release") 187 | assigned = assign_prs(pulls, categories) 188 | render_changelog(pulls, assigned) 189 | print() 190 | render_contributors(pulls) 191 | 192 | release = fetch_latest_release() 193 | diff = github_compare_tags(release["tag_name"]) 194 | if diff["total_commits"] != len(pulls): 195 | raise ValueError( 196 | "Something went wrong and not all PR were fetched. " 197 | f'There are {len(pulls)} PRs but {diff["total_commits"]} in the diff. ' 198 | "Please verify that all PRs are included in the changelog." 
199 | ) 200 | -------------------------------------------------------------------------------- /build_tools/run_examples.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script to run all example notebooks. 4 | # copy-paste from sktime's run_examples.sh 5 | set -euxo pipefail 6 | 7 | CMD="jupyter nbconvert --to notebook --inplace --execute --ExecutePreprocessor.timeout=1200" 8 | 9 | for notebook in docs/source/tutorials/*.ipynb; do 10 | echo "Running: $notebook" 11 | $CMD "$notebook" 12 | done 13 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 2 3 | round: down 4 | range: "70...100" 5 | status: 6 | project: 7 | default: 8 | threshold: 0.2% 9 | patch: 10 | default: 11 | informational: true 12 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx >3.2 2 | nbsphinx 3 | pandoc 4 | docutils 5 | pydata-sphinx-theme 6 | lightning >=2.0.0 7 | cloudpickle 8 | torch >=2.0,!=2.0.1 9 | optuna >=3.1.0 10 | optuna-integration 11 | scipy 12 | pandas >=1.3 13 | scikit-learn >1.2 14 | matplotlib 15 | statsmodels 16 | ipython 17 | nbconvert >=6.3.0 18 | recommonmark >=0.7.1 19 | pytorch-optimizer >=2.5.1 20 | fastapi >0.80 21 | cpflows 22 | -------------------------------------------------------------------------------- /docs/source/_static/custom.css: -------------------------------------------------------------------------------- 1 | .container-xl { 2 | 
max-width: 4000px; 3 | } 4 | 5 | .bd-content { 6 | flex-grow: 1; 7 | max-width: 100%; 8 | } 9 | .bd-page-width { 10 | max-width: 100rem; 11 | } 12 | .bd-main .bd-content .bd-article-container { 13 | max-width: 100%; 14 | } 15 | 16 | a.reference.internal.nav-link { 17 | color: #727991 !important; 18 | } 19 | 20 | html[data-theme="light"] { 21 | --pst-color-primary: #ee4c2c; 22 | } 23 | a.nav-link 24 | { 25 | color: #647db6 !important; 26 | } 27 | 28 | a.nav-link[href="https://github.com/sktime/pytorch-forecasting"] 29 | { 30 | color: #ee4c2c !important; 31 | } 32 | 33 | code { 34 | color: #d14 !important; 35 | } 36 | 37 | pre code, 38 | a > code, 39 | .reference { 40 | color: #647db6 !important; 41 | } 42 | 43 | dt:target, 44 | span.highlighted { 45 | background-color: #fff !important; 46 | } 47 | 48 | /* code highlighting */ 49 | .highlight > pre > .mi { 50 | color: #d14 !important; 51 | } 52 | .highlight > pre > .mf { 53 | color: #d14 !important; 54 | } 55 | 56 | .highlight > pre > .s2 { 57 | color: #647db6 !important; 58 | } 59 | -------------------------------------------------------------------------------- /docs/source/_static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sktime/pytorch-forecasting/4384140418bfe8e8c64cc6ce5fc19372508bfccd/docs/source/_static/favicon.png -------------------------------------------------------------------------------- /docs/source/_static/favicon.svg: -------------------------------------------------------------------------------- 1 | 2 | 20 | 22 | 50 | 52 | 53 | 55 | image/svg+xml 56 | 58 | 59 | 60 | 61 | 62 | 67 | 70 | 76 | 82 | 88 | 94 | 99 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /docs/source/_static/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 17 | 19 | 45 | 47 | 48 | 50 | image/svg+xml 51 | 53 | 54 | 55 | 56 | 57 | 62 | 
PyTorch 73 | Forecasting 85 | 88 | 94 | 100 | 106 | 112 | 117 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /docs/source/_templates/custom-base-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname.split(".")[-1] | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. auto{{ objtype }}:: {{ objname }} 6 | -------------------------------------------------------------------------------- /docs/source/_templates/custom-class-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname.split(".")[-1] | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | .. autoclass:: {{ objname }} 7 | :members: 8 | :show-inheritance: 9 | :exclude-members: __init__ 10 | {% set allow_inherited = "zero_grad" not in inherited_members %} {# no inheritance for torch.nn.Modules #} 11 | {%if allow_inherited %} 12 | :inherited-members: 13 | {% endif %} 14 | 15 | {% block methods %} 16 | {% set allowed_methods = [] %} 17 | {% for item in methods %}{% if not item.startswith("_") and (item not in inherited_members or allow_inherited) %} 18 | {% set a=allowed_methods.append(item) %} 19 | {% endif %}{%- endfor %} 20 | {% if allowed_methods %} 21 | .. rubric:: {{ _('Methods') }} 22 | 23 | .. autosummary:: 24 | {% for item in allowed_methods %} 25 | ~{{ name }}.{{ item }} 26 | {%- endfor %} 27 | {% endif %} 28 | {% endblock %} 29 | 30 | {% block attributes %} 31 | {% set dynamic_attributes = [] %} {# dynamic attributes are not documented #} 32 | {% set allowed_attributes = [] %} 33 | {% for item in attributes %}{% if not item.startswith("_") and (item not in inherited_members or allow_inherited) and (item not in dynamic_attributes) and allow_inherited %} 34 | {% set a=allowed_attributes.append(item) %} 35 | {% endif %}{%- endfor %} 36 | {% if allowed_attributes %} 37 | .. 
rubric:: {{ _('Attributes') }} 38 | 39 | .. autosummary:: 40 | {% for item in allowed_attributes %} 41 | ~{{ name }}.{{ item }} 42 | {%- endfor %} 43 | {% endif %} 44 | {% endblock %} 45 | -------------------------------------------------------------------------------- /docs/source/_templates/custom-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname.split(".")[-1] | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block attributes %} 6 | {% if attributes %} 7 | .. rubric:: Module Attributes 8 | 9 | .. autosummary:: 10 | :toctree: 11 | {% for item in attributes %} 12 | {{ item }} 13 | {%- endfor %} 14 | {% endif %} 15 | {% endblock %} 16 | 17 | {% block functions %} 18 | {% if functions %} 19 | .. rubric:: {{ _('Functions') }} 20 | 21 | .. autosummary:: 22 | :toctree: 23 | :template: custom-base-template.rst 24 | {% for item in functions %} 25 | {{ item }} 26 | {%- endfor %} 27 | {% endif %} 28 | {% endblock %} 29 | 30 | {% block classes %} 31 | {% if classes %} 32 | .. rubric:: {{ _('Classes') }} 33 | 34 | .. autosummary:: 35 | :toctree: 36 | :template: custom-class-template.rst 37 | {% for item in classes %} 38 | {{ item }} 39 | {%- endfor %} 40 | {% endif %} 41 | {% endblock %} 42 | 43 | {% block exceptions %} 44 | {% if exceptions %} 45 | .. rubric:: {{ _('Exceptions') }} 46 | 47 | .. autosummary:: 48 | :toctree: 49 | {% for item in exceptions %} 50 | {{ item }} 51 | {%- endfor %} 52 | {% endif %} 53 | {% endblock %} 54 | 55 | {% block modules %} 56 | {% if all_modules %} 57 | .. rubric:: Modules 58 | 59 | .. 
autosummary:: 60 | :toctree: 61 | :template: custom-module-template.rst 62 | :recursive: 63 | {% for item in all_modules %} 64 | {{ item }} 65 | {%- endfor %} 66 | {% endif %} 67 | {% endblock %} 68 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API 2 | ==== 3 | 4 | .. currentmodule:: pytorch_forecasting 5 | 6 | .. autosummary:: 7 | :toctree: api 8 | :template: custom-module-template.rst 9 | :recursive: 10 | 11 | data 12 | models 13 | metrics 14 | utils 15 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/main/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
#
import os
from pathlib import Path
import shutil
import sys

from sphinx.application import Sphinx
from sphinx.ext.autosummary import Autosummary
from sphinx.pycode import ModuleAnalyzer

# Use an absolute directory so every path derived from it (including the
# sys.path entry below) stays valid independent of the current working directory.
SOURCE_PATH = Path(os.path.dirname(os.path.abspath(__file__)))  # noqa # docs source
PROJECT_PATH = SOURCE_PATH.joinpath("../..")  # noqa # project root

sys.path.insert(0, str(PROJECT_PATH))  # noqa

import pytorch_forecasting  # isort:skip

# -- Project information -----------------------------------------------------

project = "pytorch-forecasting"
copyright = "2020, Jan Beitner"
author = "Jan Beitner"


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "nbsphinx",
    "recommonmark",
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.doctest",
    "sphinx.ext.intersphinx",
    "sphinx.ext.mathjax",
    "sphinx.ext.viewcode",
    "sphinx.ext.githubpages",
    "sphinx.ext.napoleon",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["**/.ipynb_checkpoints"]


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = "pydata_sphinx_theme"
html_logo = "_static/logo.svg"
html_favicon = "_static/favicon.png"

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]


# setup configuration
def skip(app, what, name, obj, skip, options):
    """
    Exclude ``__init__`` methods from the autodoc member listing.

    Handler for Sphinx's ``autodoc-skip-member`` event (connected in
    :func:`setup`). Returning ``True`` tells autodoc to skip the member;
    for every other member the decision autodoc already made (``skip``)
    is passed through unchanged. The ``__init__`` docstring is still shown
    as part of the class docstring because ``autoclass_content = "both"``
    is configured in this file.
    """
    if name == "__init__":
        return True
    return skip


apidoc_output_folder = SOURCE_PATH.joinpath("api")

PACKAGES = [pytorch_forecasting.__name__]


def get_by_name(string: str):
    """
    Import by name and return imported module/function/class.

    Parameters
    ----------
    string : str
        Module/function/class to import, e.g. ``'pandas.read_csv'``
        will return the ``read_csv`` function as defined by pandas.
        A name without a dot is looked up in this configuration module.

    Returns
    -------
    object
        The imported object.
    """
    class_name = string.split(".")[-1]
    module_name = ".".join(string.split(".")[:-1])

    if module_name == "":
        return getattr(sys.modules[__name__], class_name)

    # ``fromlist`` makes ``__import__`` return the leaf module (and import
    # ``class_name`` as a submodule if it is one) instead of the top package.
    mod = __import__(module_name, fromlist=[class_name])
    return getattr(mod, class_name)


class ModuleAutoSummary(Autosummary):
    """``autosummary`` variant that expands each listed module into its members.

    Registered as the ``moduleautosummary`` directive in :func:`setup`. Each
    entry of the directive must be the name of an already imported module.
    The entry is replaced by that module's public members (``__all__`` if
    defined, otherwise the module namespace). Each member is qualified with
    the module in which it is defined. Only names starting with
    ``pytorch_forecasting`` are kept.
    """

    def get_items(self, names):
        """Expand module ``names`` into sorted, fully qualified member names.

        Parameters
        ----------
        names : list of str
            Module names; each must already be present in ``sys.modules``.

        Returns
        -------
        list
            The result of :meth:`Autosummary.get_items` for the expanded,
            sorted member names.
        """
        new_names = []
        for name in names:
            mod = sys.modules[name]
            mod_items = getattr(mod, "__all__", mod.__dict__)
            for t in mod_items:
                # ignore private members and dotted (non-identifier) names
                if "." in t or t.startswith("_"):
                    continue
                obj = get_by_name(f"{name}.{t}")
                # reference the defining module, not the re-exporting one
                if hasattr(obj, "__module__"):
                    t = f"{obj.__module__}.{t}"
                if t.startswith("pytorch_forecasting"):
                    new_names.append(t)
        return super().get_items(sorted(new_names))


def setup(app: Sphinx):
    """Register custom CSS/JS, the skip handler and ``moduleautosummary``."""
    app.add_css_file("custom.css")
    app.connect("autodoc-skip-member", skip)
    app.add_directive("moduleautosummary", ModuleAutoSummary)
    app.add_js_file("https://buttons.github.io/buttons.js", **{"async": "async"})


# extension configuration
mathjax_path = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/MathJax.js?config=TeX-MML-AM_CHTML"  # noqa: E501

# theme options
html_theme_options = {
    "github_url": "https://github.com/sktime/pytorch-forecasting",
    "navbar_end": ["navbar-icon-links.html", "search-field.html"],
    "show_nav_level": 2,
    "header_links_before_dropdown": 10,
    "external_links": [
        {"name": "GitHub", "url": "https://github.com/sktime/pytorch-forecasting"}
    ],
}

html_sidebars = {
    "index": [],
    # "getting-started": [],
    # "data": [],
    # "models": [],
    # "metrics": [],
    "faq": [],
    "contribute": [],
    "CHANGELOG": [],
}


autodoc_member_order = "groupwise"
# render both the class docstring and the __init__ docstring for classes
autoclass_content = "both"

# autosummary
autosummary_generate = True
# delete previously generated autosummary pages (the ``api`` toctree output)
# so that they are rebuilt from scratch
shutil.rmtree(apidoc_output_folder, ignore_errors=True)

# copy changelog into the docs source so the ``CHANGELOG`` toctree entry
# resolves; paths are anchored to this file rather than the working directory
shutil.copy(
    PROJECT_PATH.joinpath("CHANGELOG.md"),
    SOURCE_PATH.joinpath("CHANGELOG.md"),
)

intersphinx_mapping = {
    "sklearn": ("https://scikit-learn.org/stable/", None),
}

suppress_warnings = [
    "autosummary.import_cycle",
]

# -----------nbsphinx extension ----------
nbsphinx_execute = "never"  # always
nbsphinx_allow_errors = False  # False
192 | nbsphinx_timeout = 600 # seconds 193 | -------------------------------------------------------------------------------- /docs/source/data.rst: -------------------------------------------------------------------------------- 1 | Data 2 | ===== 3 | 4 | .. currentmodule:: pytorch_forecasting.data 5 | 6 | Loading data for timeseries forecasting is not trivial - in particular if covariates are included and values are missing. 7 | PyTorch Forecasting provides the :py:class:`~timeseries.TimeSeriesDataSet` which comes with a :py:meth:`~timeseries.TimeSeriesDataSet.to_dataloader` 8 | method to convert it to a dataloader and a :py:meth:`~timeseries.TimeSeriesDataSet.from_dataset` method to create, e.g. a validation 9 | or test dataset from a training dataset using the same label encoders and data normalization. 10 | 11 | Further, timeseries have to be (almost always) normalized for a neural network to learn efficiently. PyTorch Forecasting 12 | provides multiple such target normalizers (some of which can also be used for normalizing covariates). 13 | 14 | 15 | Time series data set 16 | --------------------- 17 | 18 | The time series dataset is the central data-holding object in PyTorch Forecasting. It primarily takes 19 | a pandas DataFrame along with some metadata. See the :ref:`tutorial on passing data to models ` to learn more about how it is coupled to models. 20 | 21 | .. autoclass:: pytorch_forecasting.data.timeseries.TimeSeriesDataSet 22 | :noindex: 23 | :members: __init__ 24 | 25 | Details 26 | -------- 27 | 28 | See the API documentation for further details on available data encoders and the :py:class:`~timeseries.TimeSeriesDataSet`: 29 | 30 | .. currentmodule:: pytorch_forecasting 31 | 32 | ..
moduleautosummary:: 33 | :toctree: api/ 34 | :template: custom-module-template.rst 35 | :recursive: 36 | 37 | pytorch_forecasting.data 38 | -------------------------------------------------------------------------------- /docs/source/faq.rst: -------------------------------------------------------------------------------- 1 | FAQ 2 | ==== 3 | 4 | .. currentmodule:: pytorch_forecasting 5 | 6 | Common issues and answers. Other places to seek help from: 7 | 8 | * :ref:`Tutorials ` 9 | * `PyTorch Lightning documentation `_ and issues 10 | * `PyTorch documentation `_ and issues 11 | * `Stack Overflow `_ 12 | 13 | 14 | Creating datasets 15 | ----------------- 16 | 17 | * **How do I create a dataset for new samples?** 18 | 19 | Use the :py:meth:`~data.timeseries.TimeSeriesDataSet.from_dataset` method of your training dataset to 20 | create datasets on which you can run inference. 21 | 22 | * **How long should the encoder and decoder/prediction length be?** 23 | 24 | .. _faq_encoder_decoder_length: 25 | 26 | Choose something reasonably long, but not much longer than 500 for the encoder length and 27 | 200 for the decoder length. Consider that longer lengths increase the time it takes 28 | for your model to train. 29 | 30 | The ratio of decoder and encoder length depends on the algorithm used. 31 | Look at :ref:`documentation ` to get clues. 32 | 33 | * **It takes very long to create the dataset. Why is that?** 34 | 35 | If you set ``allow_missing_timesteps=True`` in your dataset, the creation of an index 36 | might take far more time as all missing values in the timeseries have to be identified. 37 | The algorithm might be possible to speed up but currently, it might be faster for you to 38 | not allow missing values and fill them yourself. 39 | 40 | 41 | * **How are missing values treated?** 42 | 43 | #. Missing values between time points are either filled up with a fill 44 | forward or a constant fill-in strategy 45 | #.
Missing values indicated by NaNs are a problem and 46 | should be filled in up-front, e.g. with the median value and another missing indicator categorical variable. 47 | #. Missing values in the future (out of range) are not filled in and 48 | simply not predicted. You have to provide values into the future. 49 | If those values are amongst the unknown future values, they will simply be ignored. 50 | 51 | 52 | Training models 53 | --------------- 54 | 55 | * **My training seems to freeze - nothing seems to be happening although my CPU/GPU is working at 100%. 56 | How do I fix this issue?** 57 | 58 | Probably, your model is too big (check the number of parameters with ``model.size()``) or 59 | the dataset encoder and decoder lengths are unrealistically large. See 60 | :ref:`How long should the encoder and decoder/prediction length be? ` 61 | 62 | * **Why does the learning rate finder not finish?** 63 | 64 | First, ensure that the trainer does not have the keyword ``fast_dev_run=True`` and 65 | ``limit_train_batches=...`` set. Second, use a target normalizer in your training dataset. 66 | Third, increase the ``early_stop_threshold`` argument 67 | of the ``lr_find`` method to a large number. 68 | 69 | * **Why do I get lots of matplotlib warnings when running the learning rate finder?** 70 | 71 | This is because you keep on creating plots for logging but without a logger. 72 | Set ``log_interval=-1`` in your model to avoid this behaviour. 73 | 74 | * **How do I choose hyperparameters?** 75 | 76 | Consult the :ref:`model documentation ` to understand which parameters 77 | are important and which ranges are reasonable. Choose the learning rate with 78 | the learning rate finder. To tune hyperparameters, the `optuna package `_ 79 | is a great place to start.
80 | 81 | 82 | Interpreting models 83 | ------------------- 84 | 85 | * **What interpretation is built into PyTorch Forecasting?** 86 | 87 | Look up the :ref:`model documentation ` for the model you use for model-specific interpretation. 88 | Further, all models come with some basic methods inherited from :py:class:`~models.base_model.BaseModel`. 89 | -------------------------------------------------------------------------------- /docs/source/getting-started.rst: -------------------------------------------------------------------------------- 1 | Getting started 2 | =============== 3 | 4 | .. _getting-started: 5 | 6 | 7 | Installation 8 | -------------- 9 | 10 | .. _install: 11 | 12 | If you are working on Windows, you need to first install PyTorch with 13 | 14 | .. code-block:: bash 15 | 16 | pip install torch -f https://download.pytorch.org/whl/torch_stable.html 17 | 18 | Otherwise, you can proceed with 19 | 20 | .. code-block:: bash 21 | 22 | pip install pytorch-forecasting 23 | 24 | 25 | Alternatively, to install the package via ``conda``: 26 | .. code-block:: bash 27 | 28 | conda install pytorch-forecasting "pytorch>=1.7" -c pytorch -c conda-forge 29 | 30 | PyTorch Forecasting is now installed from the conda-forge channel while PyTorch is installed from the pytorch channel. 31 | 32 | To use the MQF2 loss (multivariate quantile loss), also install 33 | 34 | .. code-block:: bash 35 | 36 | pip install pytorch-forecasting[mqf2] 37 | 38 | 39 | Usage 40 | ------------- 41 | 42 | .. currentmodule:: pytorch_forecasting 43 | 44 | The library builds strongly upon `PyTorch Lightning `_ which allows you to train models with ease, 45 | spot bugs quickly and train on multiple GPUs out-of-the-box. 46 | 47 | Further, we rely on `Tensorboard `_ for logging training progress. 48 | 49 | The general setup for training and testing a model is 50 | 51 | #. Create training dataset using :py:class:`~data.timeseries.TimeSeriesDataSet`. 52 | #.
Using the training dataset, create a validation dataset with :py:meth:`~data.timeseries.TimeSeriesDataSet.from_dataset`. 53 | Similarly, a test dataset or later a dataset for inference can be created. You can store the dataset parameters 54 | directly if you do not wish to load the entire training dataset at inference time. 55 | 56 | #. Instantiate a model using the ``.from_dataset()`` method. 57 | #. Create a ``lightning.pytorch.Trainer()`` object. 58 | #. Find the optimal learning rate with ``lightning.pytorch.tuner.Tuner(trainer).lr_find()``. 59 | #. Train the model with early stopping on the training dataset and use the tensorboard logs 60 | to understand if it has converged with acceptable accuracy. 61 | #. Tune the hyperparameters of the model with your 62 | `favourite package `_. 63 | #. Train the model with the same learning rate schedule on the entire dataset. 64 | #. Load the model from the model checkpoint and apply it to new data. 65 | 66 | 67 | The :ref:`Tutorials ` section provides detailed guidance and examples on how to use models and implement new ones. 68 | 69 | 70 | Example 71 | -------- 72 | 73 | 74 | .. code-block:: python 75 | 76 | import lightning.pytorch as pl 77 | from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor 78 | from lightning.pytorch.tuner import Tuner 79 | from pytorch_forecasting import QuantileLoss, TimeSeriesDataSet, TemporalFusionTransformer 80 | 81 | # load data 82 | data = ... 83 | 84 | # define dataset 85 | max_encoder_length = 36 86 | max_prediction_length = 6 87 | training_cutoff = "YYYY-MM-DD" # day for cutoff 88 | 89 | training = TimeSeriesDataSet( 90 | data[lambda x: x.date < training_cutoff], 91 | time_idx= ..., 92 | target= ..., 93 | # weight="weight", 94 | group_ids=[ ... ], 95 | max_encoder_length=max_encoder_length, 96 | max_prediction_length=max_prediction_length, 97 | static_categoricals=[ ... ], 98 | static_reals=[ ... ], 99 | time_varying_known_categoricals=[ ... ], 100 | time_varying_known_reals=[ ...
], 101 | time_varying_unknown_categoricals=[ ... ], 102 | time_varying_unknown_reals=[ ... ], 103 | ) 104 | 105 | # create validation and training dataset 106 | validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True) 107 | batch_size = 128 108 | train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2) 109 | val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2) 110 | 111 | # define trainer with early stopping 112 | early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min") 113 | lr_logger = LearningRateMonitor() 114 | trainer = pl.Trainer( 115 | max_epochs=100, 116 | accelerator="auto", 117 | gradient_clip_val=0.1, 118 | limit_train_batches=30, 119 | callbacks=[lr_logger, early_stop_callback], 120 | ) 121 | 122 | # create the model 123 | tft = TemporalFusionTransformer.from_dataset( 124 | training, 125 | learning_rate=0.03, 126 | hidden_size=32, 127 | attention_head_size=1, 128 | dropout=0.1, 129 | hidden_continuous_size=16, 130 | output_size=7, 131 | loss=QuantileLoss(), 132 | log_interval=2, 133 | reduce_on_plateau_patience=4 134 | ) 135 | print(f"Number of parameters in network: {tft.size()/1e3:.1f}k") 136 | 137 | # find optimal learning rate (set limit_train_batches to 1.0 and log_interval = -1) 138 | res = Tuner(trainer).lr_find( 139 | tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, early_stop_threshold=1000.0, max_lr=0.3, 140 | ) 141 | 142 | print(f"suggested learning rate: {res.suggestion()}") 143 | fig = res.plot(show=True, suggest=True) 144 | fig.show() 145 | 146 | # fit the model 147 | trainer.fit( 148 | tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, 149 | ) 150 | 151 | Main API 152 | --------- 153 | 154 | .. currentmodule:: pytorch_forecasting 155 | 156 | .. 
moduleautosummary:: 157 | :toctree: api 158 | :template: custom-module-template.rst 159 | :recursive: 160 | 161 | pytorch_forecasting 162 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. pytorch-forecasting documentation master file, created by 2 | sphinx-quickstart on Sun Aug 16 22:17:24 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | PyTorch Forecasting Documentation 7 | ================================== 8 | 9 | .. raw:: html 10 | 11 | GitHub 12 | 13 | 14 | Our article on `Towards Data Science `_ 15 | introduces the package and provides background information. 16 | 17 | PyTorch Forecasting aims to ease state-of-the-art 18 | timeseries forecasting with neural networks for both real-world cases and 19 | research alike. The goal is to provide a high-level API with maximum flexibility for 20 | professionals and reasonable defaults for beginners. 21 | Specifically, the package provides 22 | 23 | * A timeseries dataset class which abstracts handling variable transformations, missing values, 24 | randomized subsampling, multiple history lengths, etc. 25 | * A base model class which provides basic training of timeseries models along with logging in tensorboard 26 | and generic visualizations such as actual vs. predictions and dependency plots 27 | * Multiple neural network architectures for timeseries forecasting that have been enhanced 28 | for real-world deployment and come with in-built interpretation capabilities 29 | * Multi-horizon timeseries metrics 30 | * Hyperparameter tuning with `optuna `_ 31 | 32 | The package is built on `PyTorch Lightning `_ to allow 33 | training on CPUs, single and multiple GPUs out-of-the-box. 34 | 35 | If you do not have pytorch already installed, follow the :ref:`detailed installation instructions`.
36 | 37 | Otherwise, proceed to install the package by executing 38 | 39 | .. code-block:: 40 | 41 | pip install pytorch-forecasting 42 | 43 | or to install via conda 44 | 45 | .. code-block:: 46 | 47 | conda install pytorch-forecasting "pytorch>=1.7" -c pytorch -c conda-forge 48 | 49 | To use the MQF2 loss (multivariate quantile loss), also execute 50 | 51 | .. code-block:: 52 | 53 | pip install pytorch-forecasting[mqf2] 54 | 55 | Visit :ref:`Getting started ` to learn more about the package and for detailed installation instructions. 56 | The :ref:`Tutorials ` section provides guidance on how to use models and implement new ones. 57 | 58 | .. toctree:: 59 | :titlesonly: 60 | :hidden: 61 | :maxdepth: 6 62 | 63 | getting-started 64 | tutorials 65 | data 66 | models 67 | metrics 68 | faq 69 | installation 70 | api 71 | CHANGELOG 72 | 73 | 74 | Indices and tables 75 | ================== 76 | 77 | * :ref:`genindex` 78 | * :ref:`modindex` 79 | * :ref:`search` 80 | -------------------------------------------------------------------------------- /docs/source/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | ========== 3 | 4 | Multiple metrics have been implemented to ease adaptation. 5 | 6 | In particular, these metrics can be applied to the multi-horizon forecasting problem, i.e. 7 | can take tensors that are not only of shape ``n_samples`` but also ``n_samples x prediction_horizon`` 8 | or even ``n_samples x prediction_horizon x n_outputs``, where ``n_outputs`` could be the number 9 | of forecasted quantiles. 10 | 11 | Metrics can be easily combined by addition, e.g. 12 | 13 | .. code-block:: python 14 | 15 | from pytorch_forecasting.metrics import SMAPE, MAE 16 | 17 | composite_metric = SMAPE() + 1e-4 * MAE() 18 | 19 | Such composite metrics are useful when training because they can reduce outliers in other metrics. 20 | In the example, SMAPE is mostly optimized, while large outliers in MAE are avoided.
21 | 22 | Further, one can modify a loss metric to reduce a mean prediction bias, i.e. ensure that 23 | predictions add up. For example: 24 | 25 | .. code-block:: python 26 | 27 | from pytorch_forecasting.metrics import MAE, AggregationMetric 28 | 29 | composite_metric = MAE() + AggregationMetric(metric=MAE()) 30 | 31 | Here we add an additional loss to MAE. This additional loss is the MAE calculated on the mean predictions 32 | and actuals. We can also use other metrics such as SMAPE to ensure aggregated results are unbiased in that metric. 33 | One important point to keep in mind is that this metric is calculated across samples, i.e. it will vary depending 34 | on the batch size. In particular, errors tend to average out with increased batch sizes. 35 | 36 | 37 | Details 38 | -------- 39 | 40 | See the API documentation for further details on available metrics: 41 | 42 | .. currentmodule:: pytorch_forecasting 43 | 44 | .. moduleautosummary:: 45 | :toctree: api 46 | :template: custom-module-template.rst 47 | :recursive: 48 | 49 | pytorch_forecasting.metrics 50 | -------------------------------------------------------------------------------- /docs/source/tutorials.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========== 3 | 4 | .. _tutorials: 5 | 6 | The following tutorials can also be found as `notebooks on GitHub `_. 7 | 8 | ..
toctree:: 9 | :titlesonly: 10 | :maxdepth: 2 11 | 12 | tutorials/stallion 13 | tutorials/ar 14 | tutorials/building 15 | tutorials/deepar 16 | tutorials/nhits 17 | -------------------------------------------------------------------------------- /examples/ar.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import lightning.pytorch as pl 4 | from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor 5 | import pandas as pd 6 | from pandas.core.common import SettingWithCopyWarning 7 | import torch 8 | 9 | from pytorch_forecasting import GroupNormalizer, TimeSeriesDataSet 10 | from pytorch_forecasting.data import NaNLabelEncoder 11 | from pytorch_forecasting.data.examples import generate_ar_data 12 | from pytorch_forecasting.metrics import NormalDistributionLoss 13 | from pytorch_forecasting.models.deepar import DeepAR 14 | 15 | warnings.simplefilter("error", category=SettingWithCopyWarning) 16 | 17 | 18 | data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100) 19 | data["static"] = "2" 20 | data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") 21 | validation = data.series.sample(20) 22 | 23 | max_encoder_length = 60 24 | max_prediction_length = 20 25 | 26 | training_cutoff = data["time_idx"].max() - max_prediction_length 27 | 28 | training = TimeSeriesDataSet( 29 | data[lambda x: ~x.series.isin(validation)], 30 | time_idx="time_idx", 31 | target="value", 32 | categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, 33 | group_ids=["series"], 34 | static_categoricals=["static"], 35 | min_encoder_length=max_encoder_length, 36 | max_encoder_length=max_encoder_length, 37 | min_prediction_length=max_prediction_length, 38 | max_prediction_length=max_prediction_length, 39 | time_varying_unknown_reals=["value"], 40 | time_varying_known_reals=["time_idx"], 41 | target_normalizer=GroupNormalizer(groups=["series"]), 42 | 
add_relative_time_idx=False, 43 | add_target_scales=True, 44 | randomize_length=None, 45 | ) 46 | 47 | validation = TimeSeriesDataSet.from_dataset( 48 | training, 49 | data[lambda x: x.series.isin(validation)], 50 | # predict=True, 51 | stop_randomization=True, 52 | ) 53 | batch_size = 64 54 | train_dataloader = training.to_dataloader( 55 | train=True, batch_size=batch_size, num_workers=0 56 | ) 57 | val_dataloader = validation.to_dataloader( 58 | train=False, batch_size=batch_size, num_workers=0 59 | ) 60 | 61 | # save datasets 62 | training.save("training.pkl") 63 | validation.save("validation.pkl") 64 | 65 | early_stop_callback = EarlyStopping( 66 | monitor="val_loss", min_delta=1e-4, patience=5, verbose=False, mode="min" 67 | ) 68 | lr_logger = LearningRateMonitor() 69 | 70 | trainer = pl.Trainer( 71 | max_epochs=10, 72 | accelerator="gpu", 73 | devices="auto", 74 | gradient_clip_val=0.1, 75 | limit_train_batches=30, 76 | limit_val_batches=3, 77 | # fast_dev_run=True, 78 | # logger=logger, 79 | # profiler=True, 80 | callbacks=[lr_logger, early_stop_callback], 81 | ) 82 | 83 | 84 | deepar = DeepAR.from_dataset( 85 | training, 86 | learning_rate=0.1, 87 | hidden_size=32, 88 | dropout=0.1, 89 | loss=NormalDistributionLoss(), 90 | log_interval=10, 91 | log_val_interval=3, 92 | # reduce_on_plateau_patience=3, 93 | ) 94 | print(f"Number of parameters in network: {deepar.size() / 1e3:.1f}k") 95 | 96 | # # find optimal learning rate 97 | # deepar.hparams.log_interval = -1 98 | # deepar.hparams.log_val_interval = -1 99 | # trainer.limit_train_batches = 1.0 100 | # res = Tuner(trainer).lr_find( 101 | # deepar, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # noqa: E501 102 | # ) 103 | 104 | # print(f"suggested learning rate: {res.suggestion()}") 105 | # fig = res.plot(show=True, suggest=True) 106 | # fig.show() 107 | # deepar.hparams.learning_rate = res.suggestion() 108 | 109 | torch.set_num_threads(10) 110 | trainer.fit( 111 
| deepar, 112 | train_dataloaders=train_dataloader, 113 | val_dataloaders=val_dataloader, 114 | ) 115 | 116 | # calcualte mean absolute error on validation set 117 | actuals = torch.cat([y for x, (y, weight) in iter(val_dataloader)]) 118 | predictions = deepar.predict(val_dataloader) 119 | print(f"Mean absolute error of model: {(actuals - predictions).abs().mean()}") 120 | 121 | # # plot actual vs. predictions 122 | # raw_predictions, x = deepar.predict(val_dataloader, mode="raw", return_x=True) 123 | # for idx in range(10): # plot 10 examples 124 | # deepar.plot_prediction(x, raw_predictions, idx=idx, add_loss_to_title=True) 125 | -------------------------------------------------------------------------------- /examples/data/stallion.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sktime/pytorch-forecasting/4384140418bfe8e8c64cc6ce5fc19372508bfccd/examples/data/stallion.parquet -------------------------------------------------------------------------------- /examples/nbeats.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import lightning.pytorch as pl 4 | from lightning.pytorch.callbacks import EarlyStopping 5 | import pandas as pd 6 | 7 | from pytorch_forecasting import NBeats, TimeSeriesDataSet 8 | from pytorch_forecasting.data import NaNLabelEncoder 9 | from pytorch_forecasting.data.examples import generate_ar_data 10 | 11 | sys.path.append("..") 12 | 13 | 14 | print("load data") 15 | data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100) 16 | data["static"] = 2 17 | data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") 18 | validation = data.series.sample(20) 19 | 20 | 21 | max_encoder_length = 150 22 | max_prediction_length = 20 23 | 24 | training_cutoff = data["time_idx"].max() - max_prediction_length 25 | 26 | context_length = max_encoder_length 27 | prediction_length = 
max_prediction_length 28 | 29 | training = TimeSeriesDataSet( 30 | data[lambda x: x.time_idx < training_cutoff], 31 | time_idx="time_idx", 32 | target="value", 33 | categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, 34 | group_ids=["series"], 35 | min_encoder_length=context_length, 36 | max_encoder_length=context_length, 37 | max_prediction_length=prediction_length, 38 | min_prediction_length=prediction_length, 39 | time_varying_unknown_reals=["value"], 40 | randomize_length=None, 41 | add_relative_time_idx=False, 42 | add_target_scales=False, 43 | ) 44 | 45 | validation = TimeSeriesDataSet.from_dataset( 46 | training, data, min_prediction_idx=training_cutoff 47 | ) 48 | batch_size = 128 49 | train_dataloader = training.to_dataloader( 50 | train=True, batch_size=batch_size, num_workers=2 51 | ) 52 | val_dataloader = validation.to_dataloader( 53 | train=False, batch_size=batch_size, num_workers=2 54 | ) 55 | 56 | 57 | early_stop_callback = EarlyStopping( 58 | monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min" 59 | ) 60 | trainer = pl.Trainer( 61 | max_epochs=100, 62 | accelerator="auto", 63 | gradient_clip_val=0.1, 64 | callbacks=[early_stop_callback], 65 | limit_train_batches=15, 66 | # limit_val_batches=1, 67 | # fast_dev_run=True, 68 | # logger=logger, 69 | # profiler=True, 70 | ) 71 | 72 | 73 | net = NBeats.from_dataset( 74 | training, 75 | learning_rate=3e-2, 76 | log_interval=10, 77 | log_val_interval=1, 78 | log_gradient_flow=False, 79 | weight_decay=1e-2, 80 | ) 81 | print(f"Number of parameters in network: {net.size() / 1e3:.1f}k") 82 | 83 | # # find optimal learning rate 84 | # # remove logging and artificial epoch size 85 | # net.hparams.log_interval = -1 86 | # net.hparams.log_val_interval = -1 87 | # trainer.limit_train_batches = 1.0 88 | # # run learning rate finder 89 | # res = Tuner(trainer).lr_find( 90 | # net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # 
noqa: E501 91 | # ) 92 | # print(f"suggested learning rate: {res.suggestion()}") 93 | # fig = res.plot(show=True, suggest=True) 94 | # fig.show() 95 | # net.hparams.learning_rate = res.suggestion() 96 | 97 | trainer.fit( 98 | net, 99 | train_dataloaders=train_dataloader, 100 | val_dataloaders=val_dataloader, 101 | ) 102 | -------------------------------------------------------------------------------- /examples/stallion.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import warnings 3 | 4 | import lightning.pytorch as pl 5 | from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor 6 | from lightning.pytorch.loggers import TensorBoardLogger 7 | import numpy as np 8 | from pandas.core.common import SettingWithCopyWarning 9 | 10 | from pytorch_forecasting import ( 11 | GroupNormalizer, 12 | TemporalFusionTransformer, 13 | TimeSeriesDataSet, 14 | ) 15 | from pytorch_forecasting.data.examples import get_stallion_data 16 | from pytorch_forecasting.metrics import QuantileLoss 17 | from pytorch_forecasting.models.temporal_fusion_transformer.tuning import ( 18 | optimize_hyperparameters, 19 | ) 20 | 21 | warnings.simplefilter("error", category=SettingWithCopyWarning) 22 | 23 | 24 | data = get_stallion_data() 25 | 26 | data["month"] = data.date.dt.month.astype("str").astype("category") 27 | data["log_volume"] = np.log(data.volume + 1e-8) 28 | 29 | data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month 30 | data["time_idx"] -= data["time_idx"].min() 31 | data["avg_volume_by_sku"] = data.groupby( 32 | ["time_idx", "sku"], observed=True 33 | ).volume.transform("mean") 34 | data["avg_volume_by_agency"] = data.groupby( 35 | ["time_idx", "agency"], observed=True 36 | ).volume.transform("mean") 37 | # data = data[lambda x: (x.sku == data.iloc[0]["sku"]) & (x.agency == data.iloc[0]["agency"])] # noqa: E501 38 | special_days = [ 39 | "easter_day", 40 | "good_friday", 41 | "new_year", 42 | 
"christmas", 43 | "labor_day", 44 | "independence_day", 45 | "revolution_day_memorial", 46 | "regional_games", 47 | "fifa_u_17_world_cup", 48 | "football_gold_cup", 49 | "beer_capital", 50 | "music_fest", 51 | ] 52 | data[special_days] = ( 53 | data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") 54 | ) 55 | 56 | training_cutoff = data["time_idx"].max() - 6 57 | max_encoder_length = 36 58 | max_prediction_length = 6 59 | 60 | training = TimeSeriesDataSet( 61 | data[lambda x: x.time_idx <= training_cutoff], 62 | time_idx="time_idx", 63 | target="volume", 64 | group_ids=["agency", "sku"], 65 | min_encoder_length=max_encoder_length 66 | // 2, # allow encoder lengths from 0 to max_prediction_length 67 | max_encoder_length=max_encoder_length, 68 | min_prediction_length=1, 69 | max_prediction_length=max_prediction_length, 70 | static_categoricals=["agency", "sku"], 71 | static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], 72 | time_varying_known_categoricals=["special_days", "month"], 73 | variable_groups={ 74 | "special_days": special_days 75 | }, # group of categorical variables can be treated as one variable 76 | time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"], 77 | time_varying_unknown_categoricals=[], 78 | time_varying_unknown_reals=[ 79 | "volume", 80 | "log_volume", 81 | "industry_volume", 82 | "soda_volume", 83 | "avg_max_temp", 84 | "avg_volume_by_agency", 85 | "avg_volume_by_sku", 86 | ], 87 | target_normalizer=GroupNormalizer( 88 | groups=["agency", "sku"], transformation="softplus", center=False 89 | ), # use softplus with beta=1.0 and normalize by group 90 | add_relative_time_idx=True, 91 | add_target_scales=True, 92 | add_encoder_length=True, 93 | ) 94 | 95 | 96 | validation = TimeSeriesDataSet.from_dataset( 97 | training, data, predict=True, stop_randomization=True 98 | ) 99 | batch_size = 64 100 | train_dataloader = training.to_dataloader( 101 | train=True, 
batch_size=batch_size, num_workers=0 102 | ) 103 | val_dataloader = validation.to_dataloader( 104 | train=False, batch_size=batch_size, num_workers=0 105 | ) 106 | 107 | 108 | # save datasets 109 | training.save("t raining.pkl") 110 | validation.save("validation.pkl") 111 | 112 | early_stop_callback = EarlyStopping( 113 | monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min" 114 | ) 115 | lr_logger = LearningRateMonitor() 116 | logger = TensorBoardLogger(log_graph=True) 117 | 118 | trainer = pl.Trainer( 119 | max_epochs=100, 120 | accelerator="auto", 121 | gradient_clip_val=0.1, 122 | limit_train_batches=30, 123 | # val_check_interval=20, 124 | # limit_val_batches=1, 125 | # fast_dev_run=True, 126 | logger=logger, 127 | # profiler=True, 128 | callbacks=[lr_logger, early_stop_callback], 129 | ) 130 | 131 | 132 | tft = TemporalFusionTransformer.from_dataset( 133 | training, 134 | learning_rate=0.03, 135 | hidden_size=16, 136 | attention_head_size=1, 137 | dropout=0.1, 138 | hidden_continuous_size=8, 139 | output_size=7, 140 | loss=QuantileLoss(), 141 | log_interval=10, 142 | log_val_interval=1, 143 | reduce_on_plateau_patience=3, 144 | ) 145 | print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k") 146 | 147 | # # find optimal learning rate 148 | # # remove logging and artificial epoch size 149 | # tft.hparams.log_interval = -1 150 | # tft.hparams.log_val_interval = -1 151 | # trainer.limit_train_batches = 1.0 152 | # # run learning rate finder 153 | # res = Tuner(trainer).lr_find( 154 | # tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2 # noqa: E501 155 | # ) 156 | # print(f"suggested learning rate: {res.suggestion()}") 157 | # fig = res.plot(show=True, suggest=True) 158 | # fig.show() 159 | # tft.hparams.learning_rate = res.suggestion() 160 | 161 | # trainer.fit( 162 | # tft, 163 | # train_dataloaders=train_dataloader, 164 | # val_dataloaders=val_dataloader, 165 | # ) 166 | 167 | # # 
make a prediction on entire validation set 168 | # preds, index = tft.predict(val_dataloader, return_index=True, fast_dev_run=True) 169 | 170 | 171 | # tune 172 | study = optimize_hyperparameters( 173 | train_dataloader, 174 | val_dataloader, 175 | model_path="optuna_test", 176 | n_trials=200, 177 | max_epochs=50, 178 | gradient_clip_val_range=(0.01, 1.0), 179 | hidden_size_range=(8, 128), 180 | hidden_continuous_size_range=(8, 128), 181 | attention_head_size_range=(1, 4), 182 | learning_rate_range=(0.001, 0.1), 183 | dropout_range=(0.1, 0.3), 184 | trainer_kwargs=dict(limit_train_batches=30), 185 | reduce_on_plateau_patience=4, 186 | use_learning_rate_finder=False, 187 | ) 188 | with open("test_study.pkl", "wb") as fout: 189 | pickle.dump(study, fout) 190 | 191 | 192 | # profile speed 193 | # profile( 194 | # trainer.fit, 195 | # profile_fname="profile.prof", 196 | # model=tft, 197 | # period=0.001, 198 | # filter="pytorch_forecasting", 199 | # train_dataloaders=train_dataloader, 200 | # val_dataloaders=val_dataloader, 201 | # ) 202 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pytorch-forecasting" 3 | readme = "README.md" # Markdown files are supported 4 | version = "1.3.0" # is being replaced automatically 5 | 6 | authors = [ 7 | {name = "Jan Beitner"}, 8 | ] 9 | requires-python = ">=3.9,<3.14" 10 | classifiers = [ 11 | "Intended Audience :: Developers", 12 | "Intended Audience :: Science/Research", 13 | "Programming Language :: Python :: 3", 14 | "Programming Language :: Python :: 3.9", 15 | "Programming Language :: Python :: 3.10", 16 | "Programming Language :: Python :: 3.11", 17 | "Programming Language :: Python :: 3.12", 18 | "Programming Language :: Python :: 3.13", 19 | "Topic :: Scientific/Engineering", 20 | "Topic :: Scientific/Engineering :: Mathematics", 21 | "Topic :: Scientific/Engineering 
:: Artificial Intelligence", 22 | "Topic :: Software Development", 23 | "Topic :: Software Development :: Libraries", 24 | "Topic :: Software Development :: Libraries :: Python Modules", 25 | "License :: OSI Approved :: MIT License", 26 | ] 27 | description = "Forecasting timeseries with PyTorch - dataloaders, normalizers, metrics and models" 28 | 29 | dependencies = [ 30 | "numpy<=3.0.0", 31 | "torch >=2.0.0,!=2.0.1,<3.0.0", 32 | "lightning >=2.0.0,<3.0.0", 33 | "scipy >=1.8,<2.0", 34 | "pandas >=1.3.0,<3.0.0", 35 | "scikit-learn >=1.2,<2.0", 36 | ] 37 | 38 | [project.optional-dependencies] 39 | # there are the following dependency sets: 40 | # - all_extras - all soft dependencies 41 | # - granular dependency sets: 42 | # - tuning - dependencies for tuning hyperparameters via optuna 43 | # - mqf2 - dependencies for multivariate quantile loss 44 | # - graph - dependencies for graph based forecasting 45 | # - dev - the developer dependency set, for contributors to pytorch-forecasting 46 | # - CI related: e.g., dev, github-actions. Not for users. 47 | # 48 | # soft dependencies are not required for the core functionality of pytorch-forecasting 49 | # but are required by popular estimators, e.g., prophet, tbats, etc. 
50 | 51 | # all soft dependencies 52 | # 53 | # users can install via "pip install pytorch-forecasting[all_extras]" 54 | # 55 | all_extras = [ 56 | "cpflows", 57 | "matplotlib", 58 | "optuna >=3.1.0,<5.0.0", 59 | "optuna-integration", 60 | "pytorch_optimizer >=2.5.1,<4.0.0", 61 | "statsmodels", 62 | ] 63 | 64 | tuning = [ 65 | "optuna >=3.1.0,<5.0.0", 66 | "optuna-integration", 67 | "statsmodels", 68 | ] 69 | 70 | mqf2 = ["cpflows"] 71 | 72 | # the graph set is not currently used within pytorch-forecasting 73 | # but is kept for future development, as it has already been released 74 | graph = ["networkx"] 75 | 76 | # dev - the developer dependency set, for contributors to pytorch-forecasting 77 | dev = [ 78 | "pydocstyle >=6.1.1,<7.0.0", 79 | # checks and make tools 80 | "pre-commit >=3.2.0,<5.0.0", 81 | "invoke", 82 | "mypy", 83 | "pylint", 84 | "ruff", 85 | # pytest 86 | "pytest", 87 | "pytest-xdist", 88 | "pytest-cov", 89 | "pytest-sugar", 90 | "coverage", 91 | "pyarrow", 92 | # jupyter notebook 93 | "ipykernel", 94 | "nbconvert", 95 | "black[jupyter]", 96 | # documentatation 97 | "sphinx", 98 | "pydata-sphinx-theme", 99 | "nbsphinx", 100 | "recommonmark", 101 | "ipywidgets>=8.0.1,<9.0.0", 102 | "pytest-dotenv>=0.5.2,<1.0.0", 103 | "tensorboard>=2.12.1,<3.0.0", 104 | "pandoc>=2.3,<3.0.0", 105 | "scikit-base", 106 | ] 107 | 108 | # docs - dependencies for building the documentation 109 | docs = [ 110 | "sphinx>3.2,<8.2.4", 111 | "pydata-sphinx-theme", 112 | "nbsphinx", 113 | "pandoc", 114 | "nbconvert", 115 | "recommonmark", 116 | "docutils", 117 | ] 118 | 119 | github-actions = ["pytest-github-actions-annotate-failures"] 120 | 121 | [tool.setuptools.packages.find] 122 | exclude = ["build_tools"] 123 | 124 | [build-system] 125 | build-backend = "setuptools.build_meta" 126 | requires = [ 127 | "setuptools>=70.0.0", 128 | ] 129 | 130 | [tool.ruff] 131 | line-length = 88 132 | exclude = [ 133 | "docs/build/", 134 | "node_modules/", 135 | ".eggs/", 136 | 
"versioneer.py", 137 | "venv/", 138 | ".venv/", 139 | ".git/", 140 | ".history/", 141 | "docs/source/tutorials/", 142 | ] 143 | target-version = "py39" 144 | 145 | [tool.ruff.format] 146 | # Enable formatting 147 | quote-style = "double" 148 | indent-style = "space" 149 | skip-magic-trailing-comma = false 150 | line-ending = "auto" 151 | 152 | [tool.ruff.lint] 153 | select = ["E", "F", "W", "C4", "S"] 154 | extend-select = [ 155 | "I", # isort 156 | "UP", # pyupgrade 157 | "C4", # https://pypi.org/project/flake8-comprehensions 158 | ] 159 | extend-ignore = [ 160 | "E203", # space before : (needed for how black formats slicing) 161 | "E402", # module level import not at top of file 162 | "E731", # do not assign a lambda expression, use a def 163 | "E741", # ignore not easy to read variables like i l I etc. 164 | "C406", # Unnecessary list literal - rewrite as a dict literal. 165 | "C408", # Unnecessary dict call - rewrite as a literal. 166 | "C409", # Unnecessary list passed to tuple() - rewrite as a tuple literal. 
167 | "F401", # unused imports 168 | "S101", # use of assert 169 | ] 170 | 171 | [tool.ruff.lint.isort] 172 | known-first-party = ["pytorch_forecasting"] 173 | combine-as-imports = true 174 | force-sort-within-sections = true 175 | 176 | [tool.ruff.lint.per-file-ignores] 177 | "pytorch_forecasting/data/timeseries.py" = [ 178 | "E501", # Line too long being fixed in #1746 To be removed after merging 179 | ] 180 | 181 | [tool.nbqa.mutate] 182 | ruff = 1 183 | black = 1 184 | 185 | [tool.nbqa.exclude] 186 | ruff = "docs/source/tutorials/" # ToDo: Remove this when fixing notebooks 187 | 188 | [tool.coverage.report] 189 | ignore_errors = false 190 | show_missing = true 191 | 192 | [tool.mypy] 193 | ignore_missing_imports = true 194 | no_implicit_optional = true 195 | check_untyped_defs = true 196 | cache_dir = ".cache/mypy/" 197 | 198 | [tool.pytest.ini_options] 199 | addopts = [ 200 | "-rsxX", 201 | "-vv", 202 | "--cov-config=.coveragerc", 203 | "--cov=pytorch_forecasting", 204 | "--cov-report=html", 205 | "--cov-report=term-missing:skip-covered", 206 | "--no-cov-on-fail" 207 | ] 208 | markers = [] 209 | testpaths = ["tests/", "pytorch_forecasting/tests/"] 210 | log_cli_level = "ERROR" 211 | log_format = "%(asctime)s %(levelname)s %(message)s" 212 | log_date_format = "%Y-%m-%d %H:%M:%S" 213 | cache_dir = ".cache" 214 | filterwarnings = [ 215 | "ignore:Found \\d+ unknown classes which were set to NaN:UserWarning", 216 | "ignore:Less than \\d+ samples available for \\d+ prediction times. 
Use ba:UserWarning", 217 | "ignore:scale is below 1e-7 - consider not centering the data or using data with:UserWarning", 218 | "ignore:You defined a `validation_step` but have no `val_dataloader`:UserWarning", 219 | "ignore:ReduceLROnPlateau conditioned on metric:RuntimeWarning", 220 | "ignore:The number of training samples \\(\\d+\\) is smaller than the logging interval Trainer\\(:UserWarning", 221 | "ignore:The dataloader, [\\_\\s]+ \\d+, does not have many workers which may be a bottleneck.:UserWarning", 222 | "ignore:Consider increasing the value of the `num_workers` argument`:UserWarning", 223 | "ignore::UserWarning" 224 | ] 225 | -------------------------------------------------------------------------------- /pytorch_forecasting/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch Forecasting package for timeseries forecasting with PyTorch. 3 | """ 4 | 5 | __version__ = "1.3.0" 6 | 7 | from pytorch_forecasting.data import ( 8 | EncoderNormalizer, 9 | GroupNormalizer, 10 | MultiNormalizer, 11 | NaNLabelEncoder, 12 | TimeSeriesDataSet, 13 | ) 14 | from pytorch_forecasting.metrics import ( 15 | MAE, 16 | MAPE, 17 | MASE, 18 | RMSE, 19 | SMAPE, 20 | BetaDistributionLoss, 21 | CrossEntropy, 22 | DistributionLoss, 23 | ImplicitQuantileNetworkDistributionLoss, 24 | LogNormalDistributionLoss, 25 | MQF2DistributionLoss, 26 | MultiHorizonMetric, 27 | MultiLoss, 28 | MultivariateNormalDistributionLoss, 29 | NegativeBinomialDistributionLoss, 30 | NormalDistributionLoss, 31 | PoissonLoss, 32 | QuantileLoss, 33 | ) 34 | from pytorch_forecasting.models import ( 35 | GRU, 36 | LSTM, 37 | AutoRegressiveBaseModel, 38 | AutoRegressiveBaseModelWithCovariates, 39 | Baseline, 40 | BaseModel, 41 | BaseModelWithCovariates, 42 | DecoderMLP, 43 | DeepAR, 44 | MultiEmbedding, 45 | NBeats, 46 | NHiTS, 47 | RecurrentNetwork, 48 | TemporalFusionTransformer, 49 | TiDEModel, 50 | get_rnn, 51 | ) 52 | from 
pytorch_forecasting.utils import ( 53 | apply_to_list, 54 | autocorrelation, 55 | create_mask, 56 | detach, 57 | get_embedding_size, 58 | groupby_apply, 59 | integer_histogram, 60 | move_to_device, 61 | profile, 62 | to_list, 63 | unpack_sequence, 64 | ) 65 | from pytorch_forecasting.utils._maint._show_versions import show_versions 66 | 67 | __all__ = [ 68 | "TimeSeriesDataSet", 69 | "GroupNormalizer", 70 | "EncoderNormalizer", 71 | "NaNLabelEncoder", 72 | "MultiNormalizer", 73 | "TemporalFusionTransformer", 74 | "TiDEModel", 75 | "NBeats", 76 | "NHiTS", 77 | "Baseline", 78 | "DeepAR", 79 | "BaseModel", 80 | "BaseModelWithCovariates", 81 | "AutoRegressiveBaseModel", 82 | "AutoRegressiveBaseModelWithCovariates", 83 | "MultiHorizonMetric", 84 | "MultiLoss", 85 | "MAE", 86 | "MAPE", 87 | "MASE", 88 | "SMAPE", 89 | "DistributionLoss", 90 | "BetaDistributionLoss", 91 | "LogNormalDistributionLoss", 92 | "NegativeBinomialDistributionLoss", 93 | "NormalDistributionLoss", 94 | "ImplicitQuantileNetworkDistributionLoss", 95 | "MultivariateNormalDistributionLoss", 96 | "MQF2DistributionLoss", 97 | "CrossEntropy", 98 | "PoissonLoss", 99 | "QuantileLoss", 100 | "RMSE", 101 | "get_rnn", 102 | "LSTM", 103 | "GRU", 104 | "MultiEmbedding", 105 | "apply_to_list", 106 | "autocorrelation", 107 | "get_embedding_size", 108 | "create_mask", 109 | "to_list", 110 | "RecurrentNetwork", 111 | "DecoderMLP", 112 | "detach", 113 | "move_to_device", 114 | "integer_histogram", 115 | "groupby_apply", 116 | "profile", 117 | "show_versions", 118 | "unpack_sequence", 119 | ] 120 | -------------------------------------------------------------------------------- /pytorch_forecasting/_registry/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch Forecasting registry.""" 2 | 3 | from pytorch_forecasting._registry._lookup import all_objects 4 | 5 | __all__ = ["all_objects"] 6 | 
-------------------------------------------------------------------------------- /pytorch_forecasting/data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Datasets, etc. for timeseries data. 3 | 4 | Handling timeseries data is not trivial. It requires special treatment. 5 | This sub-package provides the necessary tools to abstracts the necessary work. 6 | """ 7 | 8 | from pytorch_forecasting.data.encoders import ( 9 | EncoderNormalizer, 10 | GroupNormalizer, 11 | MultiNormalizer, 12 | NaNLabelEncoder, 13 | TorchNormalizer, 14 | ) 15 | from pytorch_forecasting.data.samplers import TimeSynchronizedBatchSampler 16 | from pytorch_forecasting.data.timeseries import TimeSeries, TimeSeriesDataSet 17 | 18 | __all__ = [ 19 | "TimeSeriesDataSet", 20 | "TimeSeries", 21 | "NaNLabelEncoder", 22 | "GroupNormalizer", 23 | "TorchNormalizer", 24 | "EncoderNormalizer", 25 | "TimeSynchronizedBatchSampler", 26 | "MultiNormalizer", 27 | ] 28 | -------------------------------------------------------------------------------- /pytorch_forecasting/data/examples.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example datasets for tutorials and testing. 3 | """ 4 | 5 | from pathlib import Path 6 | from urllib.request import urlretrieve 7 | 8 | import numpy as np 9 | import pandas as pd 10 | 11 | BASE_URL = "https://github.com/sktime/pytorch-forecasting/raw/main/examples/data/" 12 | 13 | DATA_PATH = Path(__file__).parent 14 | 15 | 16 | def _get_data_by_filename(fname: str) -> Path: 17 | """ 18 | Download file or used cached version. 
19 | 20 | Args: 21 | fname (str): name of file to download 22 | 23 | Returns: 24 | Path: path at which file lives 25 | """ 26 | full_fname = DATA_PATH.joinpath(fname) 27 | 28 | # check if file exists - download if necessary 29 | if not full_fname.exists(): 30 | url = BASE_URL + fname 31 | urlretrieve(url, full_fname) # noqa: S310 32 | 33 | return full_fname 34 | 35 | 36 | def get_stallion_data() -> pd.DataFrame: 37 | """ 38 | Demand data with covariates. 39 | 40 | ~20k samples of 350 timeseries. Important columns 41 | 42 | * Timeseries can be identified by ``agency`` and ``sku``. 43 | * ``volume`` is the demand 44 | * ``date`` is the month of the demand. 45 | 46 | Returns: 47 | pd.DataFrame: data 48 | """ 49 | fname = _get_data_by_filename("stallion.parquet") 50 | return pd.read_parquet(fname) 51 | 52 | 53 | def generate_ar_data( 54 | n_series: int = 10, 55 | timesteps: int = 400, 56 | seasonality: float = 3.0, 57 | trend: float = 3.0, 58 | noise: float = 0.1, 59 | level: float = 1.0, 60 | exp: bool = False, 61 | seed: int = 213, 62 | ) -> pd.DataFrame: 63 | """ 64 | Generate multivariate data without covariates. 65 | 66 | Eeach timeseries is generated from seasonality and trend. Important columns: 67 | 68 | * ``series``: series ID 69 | * ``time_idx``: time index 70 | * ``value``: target value 71 | 72 | Args: 73 | n_series (int, optional): Number of series. Defaults to 10. 74 | timesteps (int, optional): Number of timesteps. Defaults to 400. 75 | seasonality (float, optional): Normalized frequency, i.e. frequency is ``seasonality / timesteps``. 76 | Defaults to 3.0. 77 | trend (float, optional): Trend multiplier (seasonality is multiplied with 1.0). Defaults to 3.0. 78 | noise (float, optional): Level of gaussian noise. Defaults to 0.1. 79 | level (float, optional): Level multiplier (level is a constant to be aded to timeseries). Defaults to 1.0. 80 | exp (bool, optional): If to return exponential of timeseries values. Defaults to False. 
81 | seed (int, optional): Random seed. Defaults to 213. 82 | 83 | Returns: 84 | pd.DataFrame: data 85 | """ # noqa: E501 86 | # sample parameters 87 | np.random.seed(seed) 88 | linear_trends = np.random.normal(size=n_series)[:, None] / timesteps 89 | quadratic_trends = np.random.normal(size=n_series)[:, None] / timesteps**2 90 | seasonalities = np.random.normal(size=n_series)[:, None] 91 | levels = level * np.random.normal(size=n_series)[:, None] 92 | 93 | # generate series 94 | x = np.arange(timesteps)[None, :] 95 | series = ( 96 | x * linear_trends + x**2 * quadratic_trends 97 | ) * trend + seasonalities * np.sin(2 * np.pi * seasonality * x / timesteps) 98 | # add noise 99 | series = levels * series * (1 + noise * np.random.normal(size=series.shape)) 100 | if exp: 101 | series = np.exp(series) 102 | 103 | # insert into dataframe 104 | data = ( 105 | pd.DataFrame(series) 106 | .stack() 107 | .reset_index() 108 | .rename(columns={"level_0": "series", "level_1": "time_idx", 0: "value"}) 109 | ) 110 | 111 | return data 112 | -------------------------------------------------------------------------------- /pytorch_forecasting/data/samplers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Samplers for sampling time series from the :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet` 3 | """ # noqa: E501 4 | 5 | import warnings 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from sklearn.utils import shuffle 10 | from torch.utils.data.sampler import Sampler 11 | 12 | 13 | class GroupedSampler(Sampler): 14 | """ 15 | Samples mini-batches randomly but in a grouped manner. 16 | 17 | This means that the items from the different groups are always sampled together. 18 | This is an abstract class. Implement the :py:meth:`~get_groups` method which creates groups to be sampled from. 
19 | """ # noqa: E501 20 | 21 | def __init__( 22 | self, 23 | sampler: Sampler, 24 | batch_size: int = 64, 25 | shuffle: bool = False, 26 | drop_last: bool = False, 27 | ): 28 | """ 29 | Initialize. 30 | 31 | Args: 32 | sampler (Sampler or Iterable): Base sampler. Can be any iterable object 33 | drop_last (bool): if to drop last mini-batch from a group if it is smaller than batch_size. 34 | Defaults to False. 35 | shuffle (bool): if to shuffle dataset. Defaults to False. 36 | batch_size (int, optional): Number of samples in a mini-batch. This is rather the maximum number 37 | of samples. Because mini-batches are grouped by prediction time, chances are that there 38 | are multiple where batch size will be smaller than the maximum. Defaults to 64. 39 | """ # noqa: E501 40 | # Since collections.abc.Iterable does not check for `__getitem__`, which 41 | # is one way for an object to be an iterable, we don't do an `isinstance` 42 | # check here. 43 | if ( 44 | not isinstance(batch_size, int) 45 | or isinstance(batch_size, bool) 46 | or batch_size <= 0 47 | ): 48 | raise ValueError( 49 | "batch_size should be a positive integer value, " 50 | f"but got batch_size={batch_size}" 51 | ) 52 | if not isinstance(drop_last, bool): 53 | raise ValueError( 54 | f"drop_last should be a boolean value, but got drop_last={drop_last}" 55 | ) 56 | self.sampler = sampler 57 | self.batch_size = batch_size 58 | self.drop_last = drop_last 59 | self.shuffle = shuffle 60 | # make groups and construct new index to sample from 61 | groups = self.get_groups(self.sampler) 62 | self.construct_batch_groups(groups) 63 | 64 | def get_groups(self, sampler: Sampler): 65 | """ 66 | Create the groups which can be sampled. 67 | 68 | Args: 69 | sampler (Sampler): will have attribute data_source which is of type TimeSeriesDataSet. 
70 | 71 | Returns: 72 | dict-like: dictionary-like object with data_source.index as values and group names as keys 73 | """ # noqa: E501 74 | raise NotImplementedError() 75 | 76 | def construct_batch_groups(self, groups): 77 | """ 78 | Construct index of batches from which can be sampled 79 | """ 80 | self._groups = groups 81 | # calculate sizes of groups 82 | self._group_sizes = {} 83 | warns = [] 84 | for name, group in self._groups.items(): # iterate over groups 85 | if self.drop_last: 86 | self._group_sizes[name] = len(group) // self.batch_size 87 | else: 88 | self._group_sizes[name] = ( 89 | len(group) + self.batch_size - 1 90 | ) // self.batch_size 91 | if self._group_sizes[name] == 0: 92 | self._group_sizes[name] = 1 93 | warns.append(name) 94 | if len(warns) > 0: 95 | warnings.warn( 96 | f"Less than {self.batch_size} samples available for " 97 | f"{len(warns)} prediction times. " 98 | f"Use batch size smaller than {self.batch_size}. " 99 | f"First 10 prediction times with small batch sizes: {warns[:10]}" 100 | ) 101 | # create index from which can be sampled: index is equal to number of batches 102 | # associate index with prediction time 103 | self._group_index = np.repeat( 104 | list(self._group_sizes.keys()), list(self._group_sizes.values()) 105 | ) 106 | # associate index with batch within prediction time group 107 | self._sub_group_index = np.concatenate( 108 | [np.arange(size) for size in self._group_sizes.values()] 109 | ) 110 | 111 | def __iter__(self): 112 | if self.shuffle: # shuffle samples 113 | groups = {name: shuffle(group) for name, group in self._groups.items()} 114 | batch_samples = np.random.permutation(len(self)) 115 | else: 116 | groups = self._groups 117 | batch_samples = np.arange(len(self)) 118 | 119 | for idx in batch_samples: 120 | name = self._group_index[idx] 121 | sub_group = self._sub_group_index[idx] 122 | sub_group_start = sub_group * self.batch_size 123 | sub_group_end = sub_group_start + self.batch_size 124 | batch = 
groups[name][sub_group_start:sub_group_end] 125 | yield batch 126 | 127 | def __len__(self): 128 | return len(self._group_index) 129 | 130 | 131 | class TimeSynchronizedBatchSampler(GroupedSampler): 132 | """ 133 | Samples mini-batches randomly but in a time-synchronised manner. 134 | 135 | Time-synchornisation means that the time index of the first decoder samples are aligned across the batch. 136 | This sampler does not support missing values in the dataset. 137 | """ # noqa: E501 138 | 139 | def get_groups(self, sampler: Sampler): 140 | data_source = sampler.data_source 141 | index = data_source.index 142 | # get groups, i.e. group all samples by first predict time 143 | last_time = data_source.data["time"][index["index_end"].to_numpy()].numpy() 144 | decoder_lengths = data_source.calculate_decoder_length( 145 | last_time, index.sequence_length 146 | ) 147 | first_prediction_time = index.time + index.sequence_length - decoder_lengths + 1 148 | groups = pd.RangeIndex(0, len(index.index)).groupby(first_prediction_time) 149 | return groups 150 | -------------------------------------------------------------------------------- /pytorch_forecasting/data/timeseries/__init__.py: -------------------------------------------------------------------------------- 1 | """Data loaders for time series data.""" 2 | 3 | from pytorch_forecasting.data.timeseries._timeseries import ( 4 | TimeSeriesDataSet, 5 | _find_end_indices, 6 | check_for_nonfinite, 7 | ) 8 | from pytorch_forecasting.data.timeseries._timeseries_v2 import TimeSeries 9 | 10 | __all__ = [ 11 | "_find_end_indices", 12 | "check_for_nonfinite", 13 | "TimeSeriesDataSet", 14 | "TimeSeries", 15 | ] 16 | -------------------------------------------------------------------------------- /pytorch_forecasting/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metrics for (mulit-horizon) timeseries forecasting. 
3 | """ 4 | 5 | from pytorch_forecasting.metrics.base_metrics import ( 6 | DistributionLoss, 7 | Metric, 8 | MultiHorizonMetric, 9 | MultiLoss, 10 | MultivariateDistributionLoss, 11 | convert_torchmetric_to_pytorch_forecasting_metric, 12 | ) 13 | from pytorch_forecasting.metrics.distributions import ( 14 | BetaDistributionLoss, 15 | ImplicitQuantileNetworkDistributionLoss, 16 | LogNormalDistributionLoss, 17 | MQF2DistributionLoss, 18 | MultivariateNormalDistributionLoss, 19 | NegativeBinomialDistributionLoss, 20 | NormalDistributionLoss, 21 | ) 22 | from pytorch_forecasting.metrics.point import ( 23 | MAE, 24 | MAPE, 25 | MASE, 26 | RMSE, 27 | SMAPE, 28 | CrossEntropy, 29 | PoissonLoss, 30 | TweedieLoss, 31 | ) 32 | from pytorch_forecasting.metrics.quantile import QuantileLoss 33 | 34 | __all__ = [ 35 | "MultiHorizonMetric", 36 | "DistributionLoss", 37 | "MultivariateDistributionLoss", 38 | "MultiLoss", 39 | "Metric", 40 | "convert_torchmetric_to_pytorch_forecasting_metric", 41 | "MAE", 42 | "MAPE", 43 | "MASE", 44 | "PoissonLoss", 45 | "TweedieLoss", 46 | "CrossEntropy", 47 | "SMAPE", 48 | "RMSE", 49 | "BetaDistributionLoss", 50 | "NegativeBinomialDistributionLoss", 51 | "NormalDistributionLoss", 52 | "LogNormalDistributionLoss", 53 | "MultivariateNormalDistributionLoss", 54 | "ImplicitQuantileNetworkDistributionLoss", 55 | "QuantileLoss", 56 | "MQF2DistributionLoss", 57 | ] 58 | -------------------------------------------------------------------------------- /pytorch_forecasting/metrics/quantile.py: -------------------------------------------------------------------------------- 1 | """Quantile metrics for forecasting multiple quantiles per time step.""" 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | from pytorch_forecasting.metrics.base_metrics import MultiHorizonMetric 8 | 9 | 10 | class QuantileLoss(MultiHorizonMetric): 11 | """ 12 | Quantile loss, i.e. 
a quantile of ``q=0.5`` gives the (mean) absolute error, because the pinball loss is multiplied by 2. 13 | 14 | Defined as ``2 * max(q * (y-y_pred), (1-q) * (y_pred-y))`` 15 | """ # noqa: E501 16 | 17 | def __init__( 18 | self, 19 | quantiles: Optional[list[float]] = None, 20 | **kwargs, 21 | ): 22 | """ 23 | Quantile loss 24 | 25 | Args: 26 | quantiles: quantiles for metric. Defaults to ``[0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]``. 27 | """ 28 | if quantiles is None: 29 | quantiles = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98] 30 | super().__init__(quantiles=quantiles, **kwargs) 31 | 32 | def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 33 | # pinball loss per quantile; the last dimension of y_pred holds one entry per quantile 34 | losses = [] 35 | for i, q in enumerate(self.quantiles): 36 | errors = target - y_pred[..., i] 37 | losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1)) 38 | losses = 2 * torch.cat(losses, dim=2) # factor 2 makes q=0.5 equal to the absolute error 39 | 40 | return losses 41 | 42 | def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor: 43 | """ 44 | Convert network prediction into a point prediction by selecting the median (``q=0.5``) quantile, which must be among ``quantiles``. 45 | 46 | Args: 47 | y_pred: prediction output of network 48 | 49 | Returns: 50 | torch.Tensor: point prediction 51 | """ 52 | if y_pred.ndim == 3: 53 | idx = self.quantiles.index(0.5) 54 | y_pred = y_pred[..., idx] 55 | return y_pred 56 | 57 | def to_quantiles(self, y_pred: torch.Tensor) -> torch.Tensor: 58 | """ 59 | Convert network prediction into a quantile prediction. 60 | 61 | Args: 62 | y_pred: prediction output of network 63 | 64 | Returns: 65 | torch.Tensor: prediction quantiles 66 | """ 67 | return y_pred 68 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Models for timeseries forecasting.
3 | """ 4 | 5 | from pytorch_forecasting.models.base import ( 6 | AutoRegressiveBaseModel, 7 | AutoRegressiveBaseModelWithCovariates, 8 | BaseModel, 9 | BaseModelWithCovariates, 10 | ) 11 | from pytorch_forecasting.models.baseline import Baseline 12 | from pytorch_forecasting.models.deepar import DeepAR 13 | from pytorch_forecasting.models.mlp import DecoderMLP 14 | from pytorch_forecasting.models.nbeats import NBeats 15 | from pytorch_forecasting.models.nhits import NHiTS 16 | from pytorch_forecasting.models.nn import GRU, LSTM, MultiEmbedding, get_rnn 17 | from pytorch_forecasting.models.rnn import RecurrentNetwork 18 | from pytorch_forecasting.models.temporal_fusion_transformer import ( 19 | TemporalFusionTransformer, 20 | ) 21 | from pytorch_forecasting.models.tide import TiDEModel 22 | 23 | __all__ = [ 24 | "NBeats", 25 | "NHiTS", 26 | "TemporalFusionTransformer", 27 | "RecurrentNetwork", 28 | "DeepAR", 29 | "BaseModel", 30 | "Baseline", 31 | "BaseModelWithCovariates", 32 | "AutoRegressiveBaseModel", 33 | "AutoRegressiveBaseModelWithCovariates", 34 | "get_rnn", 35 | "LSTM", 36 | "GRU", 37 | "MultiEmbedding", 38 | "DecoderMLP", 39 | "TiDEModel", 40 | ] 41 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/base/__init__.py: -------------------------------------------------------------------------------- 1 | """Base classes for pytorch-forecasting models.""" 2 | 3 | from pytorch_forecasting.models.base._base_model import ( 4 | AutoRegressiveBaseModel, 5 | AutoRegressiveBaseModelWithCovariates, 6 | BaseModel, 7 | BaseModelWithCovariates, 8 | Prediction, 9 | ) 10 | from pytorch_forecasting.models.base._base_object import ( 11 | _BaseObject, 12 | _BasePtForecaster, 13 | ) 14 | 15 | __all__ = [ 16 | "_BaseObject", 17 | "_BasePtForecaster", 18 | "AutoRegressiveBaseModel", 19 | "AutoRegressiveBaseModelWithCovariates", 20 | "BaseModel", 21 | "BaseModelWithCovariates", 22 | "Prediction", 23 | ] 24 |
-------------------------------------------------------------------------------- /pytorch_forecasting/models/base/_base_object.py: -------------------------------------------------------------------------------- 1 | """Base Classes for pytorch-forecasting models, skbase compatible for indexing.""" 2 | 3 | import inspect 4 | 5 | from pytorch_forecasting.utils._dependencies import _safe_import 6 | 7 | _SkbaseBaseObject = _safe_import("skbase.base.BaseObject", pkg_name="scikit-base") 8 | 9 | 10 | class _BaseObject(_SkbaseBaseObject): 11 | pass 12 | 13 | 14 | class _BasePtForecaster(_BaseObject): 15 | """Base class for all PyTorch Forecasting forecaster metadata. 16 | 17 | This class points to model objects and contains metadata as tags. 18 | """ 19 | 20 | _tags = { 21 | "object_type": "forecaster_pytorch", 22 | } 23 | 24 | @classmethod 25 | def get_model_cls(cls): 26 | """Get model class.""" 27 | raise NotImplementedError 28 | 29 | @classmethod 30 | def name(cls): 31 | """Get model name.""" 32 | name = cls.get_class_tags().get("info:name", None) 33 | if name is None: 34 | name = cls.get_model_cls().__name__ 35 | return name 36 | 37 | @classmethod 38 | def create_test_instance(cls, parameter_set="default"): 39 | """Construct an instance of the class, using first test parameter set. 40 | 41 | Parameters 42 | ---------- 43 | parameter_set : str, default="default" 44 | Name of the set of test parameters to return, for use in tests. If no 45 | special parameters are defined for a value, will return `"default"` set. 
46 | 47 | Returns 48 | ------- 49 | instance : instance of the class with default parameters 50 | 51 | """ 52 | if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: 53 | params = cls.get_test_params(parameter_set=parameter_set) 54 | else: 55 | params = cls.get_test_params() 56 | 57 | if isinstance(params, list) and isinstance(params[0], dict): 58 | params = params[0] 59 | elif isinstance(params, dict): 60 | pass 61 | else: 62 | raise TypeError( 63 | "get_test_params should either return a dict or list of dict." 64 | ) 65 | 66 | return cls.get_model_cls()(**params) 67 | 68 | @classmethod 69 | def create_test_instances_and_names(cls, parameter_set="default"): 70 | """Create list of all test instances and a list of names for them. 71 | 72 | Parameters 73 | ---------- 74 | parameter_set : str, default="default" 75 | Name of the set of test parameters to return, for use in tests. If no 76 | special parameters are defined for a value, will return `"default"` set. 77 | 78 | Returns 79 | ------- 80 | objs : list of instances of cls 81 | i-th instance is ``cls(**cls.get_test_params()[i])`` 82 | names : list of str, same length as objs 83 | i-th element is name of i-th instance of obj in tests. 
84 | The naming convention is ``{cls.__name__}-{i}`` if more than one instance, 85 | otherwise ``{cls.__name__}`` 86 | """ 87 | if "parameter_set" in inspect.getfullargspec(cls.get_test_params).args: 88 | param_list = cls.get_test_params(parameter_set=parameter_set) 89 | else: 90 | param_list = cls.get_test_params() 91 | 92 | objs = [] 93 | if not isinstance(param_list, (dict, list)): 94 | raise RuntimeError( 95 | f"Error in {cls.__name__}.get_test_params, " 96 | "return must be param dict for class, or list thereof" 97 | ) 98 | if isinstance(param_list, dict): 99 | param_list = [param_list] 100 | for params in param_list: 101 | if not isinstance(params, dict): 102 | raise RuntimeError( 103 | f"Error in {cls.__name__}.get_test_params, " 104 | "return must be param dict for class, or list thereof" 105 | ) 106 | objs += [cls.get_model_cls()(**params)] 107 | 108 | num_instances = len(param_list) 109 | if num_instances > 1: 110 | names = [cls.__name__ + "-" + str(i) for i in range(num_instances)] 111 | else: 112 | names = [cls.__name__] 113 | 114 | return objs, names 115 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/base_model.py: -------------------------------------------------------------------------------- 1 | """Base classes for pytorch-forecasting models (re-exported from ``pytorch_forecasting.models.base``).""" 2 | 3 | from pytorch_forecasting.models.base import ( 4 | AutoRegressiveBaseModel, 5 | AutoRegressiveBaseModelWithCovariates, 6 | BaseModel, 7 | BaseModelWithCovariates, 8 | Prediction, 9 | ) 10 | 11 | __all__ = [ 12 | "AutoRegressiveBaseModel", 13 | "AutoRegressiveBaseModelWithCovariates", 14 | "BaseModel", 15 | "BaseModelWithCovariates", 16 | "Prediction", 17 | ] 18 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/baseline.py: -------------------------------------------------------------------------------- 1 | """ 2 | Baseline model.
3 | """ 4 | 5 | from typing import Any 6 | 7 | import torch 8 | 9 | from pytorch_forecasting.models import BaseModel 10 | 11 | 12 | class Baseline(BaseModel): 13 | """ 14 | Baseline model that uses last known target value to make prediction. 15 | 16 | Example: 17 | 18 | .. code-block:: python 19 | 20 | from pytorch_forecasting import Baseline, MAE 21 | 22 | # generating predictions 23 | predictions = Baseline().predict(dataloader) 24 | 25 | # calculate baseline performance in terms of mean absolute error (MAE) 26 | metric = MAE() 27 | model = Baseline() 28 | for x, y in dataloader: 29 | metric.update(model(x), y) 30 | 31 | metric.compute() 32 | """ 33 | 34 | def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: 35 | """ 36 | Network forward pass. 37 | 38 | Args: 39 | x (Dict[str, torch.Tensor]): network input 40 | 41 | Returns: 42 | Dict[str, torch.Tensor]: network outputs 43 | """ 44 | if isinstance(x["encoder_target"], (list, tuple)): # multiple targets 45 | prediction = [ 46 | self.forward_one_target( 47 | encoder_lengths=x["encoder_lengths"], 48 | decoder_lengths=x["decoder_lengths"], 49 | encoder_target=encoder_target, 50 | ) 51 | for encoder_target in x["encoder_target"] 52 | ] 53 | else: # one target 54 | prediction = self.forward_one_target( 55 | encoder_lengths=x["encoder_lengths"], 56 | decoder_lengths=x["decoder_lengths"], 57 | encoder_target=x["encoder_target"], 58 | ) 59 | return self.to_network_output(prediction=prediction) 60 | 61 | def forward_one_target( 62 | self, 63 | encoder_lengths: torch.Tensor, 64 | decoder_lengths: torch.Tensor, 65 | encoder_target: torch.Tensor, 66 | ): 67 | max_prediction_length = decoder_lengths.max() 68 | assert ( 69 | encoder_lengths.min() > 0 70 | ), "Encoder lengths of at least 1 required to obtain last value" 71 | last_values = encoder_target[ 72 | torch.arange(encoder_target.size(0)), encoder_lengths - 1 73 | ] 74 | prediction = last_values[:, None].expand(-1, max_prediction_length) # repeat each series' last observed value over the horizon 75 | return
prediction 76 | 77 | def to_prediction(self, out: dict[str, Any], use_metric: bool = True, **kwargs): 78 | return out.prediction 79 | 80 | def to_quantiles(self, out: dict[str, Any], use_metric: bool = True, **kwargs): 81 | return out.prediction[..., None] 82 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/deepar/__init__.py: -------------------------------------------------------------------------------- 1 | """DeepAR: Probabilistic forecasting with autoregressive recurrent networks.""" 2 | 3 | from pytorch_forecasting.models.deepar._deepar import DeepAR 4 | from pytorch_forecasting.models.deepar._deepar_metadata import DeepARMetadata 5 | 6 | __all__ = ["DeepAR", "DeepARMetadata"] 7 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/deepar/_deepar_metadata.py: -------------------------------------------------------------------------------- 1 | """DeepAR metadata container.""" 2 | 3 | from pytorch_forecasting.models.base._base_object import _BasePtForecaster 4 | 5 | 6 | class DeepARMetadata(_BasePtForecaster): 7 | """DeepAR metadata container.""" 8 | 9 | _tags = { 10 | "info:name": "DeepAR", 11 | "info:compute": 3, 12 | "authors": ["jdb78"], 13 | "capability:exogenous": True, 14 | "capability:multivariate": True, 15 | "capability:pred_int": True, 16 | "capability:flexible_history_length": True, 17 | "capability:cold_start": False, 18 | } 19 | 20 | @classmethod 21 | def get_model_cls(cls): 22 | """Get model class.""" 23 | from pytorch_forecasting.models import DeepAR 24 | 25 | return DeepAR 26 | 27 | @classmethod 28 | def get_test_train_params(cls): 29 | """Return testing parameter settings for the trainer. 
30 | 31 | Returns 32 | ------- 33 | params : dict or list of dict, default = {} 34 | Parameters to create testing instances of the class 35 | Each dict are parameters to construct an "interesting" test instance, i.e., 36 | `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 37 | `create_test_instance` uses the first (or only) dictionary in `params` 38 | """ 39 | from pytorch_forecasting.data.encoders import GroupNormalizer 40 | from pytorch_forecasting.metrics import ( 41 | BetaDistributionLoss, 42 | ImplicitQuantileNetworkDistributionLoss, 43 | LogNormalDistributionLoss, 44 | MultivariateNormalDistributionLoss, 45 | NegativeBinomialDistributionLoss, 46 | ) 47 | 48 | params = [ 49 | {}, 50 | {"cell_type": "GRU"}, 51 | dict( 52 | loss=LogNormalDistributionLoss(), 53 | clip_target=True, 54 | data_loader_kwargs=dict( 55 | target_normalizer=GroupNormalizer( 56 | groups=["agency", "sku"], transformation="log" 57 | ) 58 | ), 59 | ), 60 | dict( 61 | loss=NegativeBinomialDistributionLoss(), 62 | clip_target=False, 63 | data_loader_kwargs=dict( 64 | target_normalizer=GroupNormalizer( 65 | groups=["agency", "sku"], center=False 66 | ) 67 | ), 68 | ), 69 | dict( 70 | loss=BetaDistributionLoss(), 71 | clip_target=True, 72 | data_loader_kwargs=dict( 73 | target_normalizer=GroupNormalizer( 74 | groups=["agency", "sku"], transformation="logit" 75 | ) 76 | ), 77 | ), 78 | dict( 79 | data_loader_kwargs=dict( 80 | lags={"volume": [2, 5]}, 81 | target="volume", 82 | time_varying_unknown_reals=["volume"], 83 | min_encoder_length=2, 84 | ), 85 | ), 86 | dict( 87 | data_loader_kwargs=dict( 88 | time_varying_unknown_reals=["volume", "discount"], 89 | target=["volume", "discount"], 90 | lags={"volume": [2], "discount": [2]}, 91 | ), 92 | ), 93 | dict( 94 | loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8), 95 | ), 96 | dict( 97 | loss=MultivariateNormalDistributionLoss(), 98 | trainer_kwargs=dict(accelerator="cpu"), 99 | ), 100 | dict( 101 | 
loss=MultivariateNormalDistributionLoss(), 102 | data_loader_kwargs=dict( 103 | target_normalizer=GroupNormalizer( 104 | groups=["agency", "sku"], transformation="log1p" 105 | ) 106 | ), 107 | trainer_kwargs=dict(accelerator="cpu"), 108 | ), 109 | ] 110 | defaults = { 111 | "hidden_size": 5, 112 | "cell_type": "LSTM", 113 | "n_plotting_samples": 100, 114 | } 115 | for param in params: 116 | param.update(defaults) 117 | return params 118 | 119 | @classmethod 120 | def _get_test_dataloaders_from(cls, params): 121 | """Get dataloaders from parameters. 122 | 123 | Parameters 124 | ---------- 125 | params : dict 126 | Parameters to create dataloaders. 127 | One of the elements in the list returned by ``get_test_train_params``. 128 | 129 | Returns 130 | ------- 131 | dataloaders : dict with keys "train", "val", "test", values torch DataLoader 132 | Dict of dataloaders created from the parameters. 133 | Train, validation, and test dataloaders. 134 | """ 135 | loss = params.get("loss", None) 136 | clip_target = params.get("clip_target", False) 137 | data_loader_kwargs = params.get("data_loader_kwargs", {}) 138 | 139 | from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss 140 | from pytorch_forecasting.tests._conftest import make_dataloaders 141 | from pytorch_forecasting.tests._data_scenarios import data_with_covariates 142 | 143 | dwc = data_with_covariates() 144 | 145 | if isinstance(loss, NegativeBinomialDistributionLoss): 146 | dwc = dwc.assign(volume=lambda x: x.volume.round()) 147 | 148 | dwc = dwc.copy() 149 | if clip_target: 150 | dwc["target"] = dwc["volume"].clip(1e-3, 1.0) 151 | else: 152 | dwc["target"] = dwc["volume"] 153 | data_loader_default_kwargs = dict( 154 | target="target", 155 | time_varying_known_reals=["price_actual"], 156 | time_varying_unknown_reals=["target"], 157 | static_categoricals=["agency"], 158 | add_relative_time_idx=True, 159 | ) 160 | data_loader_default_kwargs.update(data_loader_kwargs) 161 | dataloaders_w_covariates 
= make_dataloaders(dwc, **data_loader_default_kwargs) 162 | return dataloaders_w_covariates 163 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/mlp/__init__.py: -------------------------------------------------------------------------------- 1 | """Simple models based on fully connected networks.""" 2 | 3 | from pytorch_forecasting.models.mlp._decodermlp import DecoderMLP 4 | from pytorch_forecasting.models.mlp.submodules import FullyConnectedModule 5 | 6 | __all__ = ["DecoderMLP", "FullyConnectedModule"] 7 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/mlp/submodules.py: -------------------------------------------------------------------------------- 1 | """ 2 | MLP implementation 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class FullyConnectedModule(nn.Module): 10 | def __init__( 11 | self, 12 | input_size: int, 13 | output_size: int, 14 | hidden_size: int, 15 | n_hidden_layers: int, 16 | activation_class: nn.ReLU, 17 | dropout: float = None, 18 | norm: bool = True, 19 | ): 20 | super().__init__() 21 | self.input_size = input_size 22 | self.output_size = output_size 23 | self.hidden_size = hidden_size 24 | self.n_hidden_layers = n_hidden_layers 25 | self.activation_class = activation_class 26 | self.dropout = dropout 27 | self.norm = norm 28 | 29 | # input layer 30 | module_list = [nn.Linear(input_size, hidden_size), activation_class()] 31 | if dropout is not None: 32 | module_list.append(nn.Dropout(dropout)) 33 | if norm: 34 | module_list.append(nn.LayerNorm(hidden_size)) 35 | # hidden layers 36 | for _ in range(n_hidden_layers): 37 | module_list.extend( 38 | [nn.Linear(hidden_size, hidden_size), activation_class()] 39 | ) 40 | if dropout is not None: 41 | module_list.append(nn.Dropout(dropout)) 42 | if norm: 43 | module_list.append(nn.LayerNorm(hidden_size)) 44 | # output layer 45 | 
module_list.append(nn.Linear(hidden_size, output_size)) 46 | 47 | self.sequential = nn.Sequential(*module_list) 48 | 49 | def forward(self, x: torch.Tensor) -> torch.Tensor: 50 | # x of shape: batch_size x n_timesteps_in 51 | # output of shape batch_size x n_timesteps_out 52 | return self.sequential(x) 53 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/nbeats/__init__.py: -------------------------------------------------------------------------------- 1 | """N-Beats model for timeseries forecasting without covariates.""" 2 | 3 | from pytorch_forecasting.models.nbeats._nbeats import NBeats 4 | from pytorch_forecasting.models.nbeats._nbeats_metadata import NBeatsMetadata 5 | from pytorch_forecasting.models.nbeats.sub_modules import ( 6 | NBEATSGenericBlock, 7 | NBEATSSeasonalBlock, 8 | NBEATSTrendBlock, 9 | ) 10 | 11 | __all__ = [ 12 | "NBeats", 13 | "NBEATSGenericBlock", 14 | "NBeatsMetadata", 15 | "NBEATSSeasonalBlock", 16 | "NBEATSTrendBlock", 17 | ] 18 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/nbeats/_nbeats_metadata.py: -------------------------------------------------------------------------------- 1 | """NBeats metadata container.""" 2 | 3 | from pytorch_forecasting.models.base._base_object import _BasePtForecaster 4 | 5 | 6 | class NBeatsMetadata(_BasePtForecaster): 7 | """NBeats metadata container.""" 8 | 9 | _tags = { 10 | "info:name": "NBeats", 11 | "info:compute": 1, 12 | "authors": ["jdb78"], 13 | "capability:exogenous": False, 14 | "capability:multivariate": False, 15 | "capability:pred_int": False, 16 | "capability:flexible_history_length": False, 17 | "capability:cold_start": False, 18 | } 19 | 20 | @classmethod 21 | def get_model_cls(cls): 22 | """Get model class.""" 23 | from pytorch_forecasting.models import NBeats 24 | 25 | return NBeats 26 | 27 | @classmethod 28 | def get_test_train_params(cls): 29 | """Return 
testing parameter settings for the trainer. 30 | 31 | Returns 32 | ------- 33 | params : dict or list of dict, default = {} 34 | Parameters to create testing instances of the class 35 | Each dict are parameters to construct an "interesting" test instance, i.e., 36 | `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. 37 | `create_test_instance` uses the first (or only) dictionary in `params` 38 | """ 39 | return [{"backcast_loss_ratio": 1.0}] 40 | 41 | @classmethod 42 | def _get_test_dataloaders_from(cls, params): 43 | """Get dataloaders from parameters. 44 | 45 | Parameters 46 | ---------- 47 | params : dict 48 | Parameters to create dataloaders. 49 | One of the elements in the list returned by ``get_test_train_params``. 50 | 51 | Returns 52 | ------- 53 | dataloaders : dict with keys "train", "val", "test", values torch DataLoader 54 | Dict of dataloaders created from the parameters. 55 | Train, validation, and test dataloaders, in this order. 56 | """ 57 | from pytorch_forecasting.tests._conftest import ( 58 | _dataloaders_fixed_window_without_covariates, 59 | ) 60 | 61 | return _dataloaders_fixed_window_without_covariates() 62 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/nbeats/sub_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of ``nn.Modules`` for N-Beats model. 
3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def linear(input_size, output_size, bias=True, dropout: int = None): 12 | lin = nn.Linear(input_size, output_size, bias=bias) 13 | if dropout is not None: 14 | return nn.Sequential(nn.Dropout(dropout), lin) 15 | else: 16 | return lin 17 | 18 | 19 | def linspace( 20 | backcast_length: int, forecast_length: int, centered: bool = False 21 | ) -> tuple[np.ndarray, np.ndarray]: 22 | if centered: 23 | norm = max(backcast_length, forecast_length) 24 | start = -backcast_length 25 | stop = forecast_length - 1 26 | else: 27 | norm = backcast_length + forecast_length 28 | start = 0 29 | stop = backcast_length + forecast_length - 1 30 | lin_space = np.linspace( 31 | start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float32 32 | ) 33 | b_ls = lin_space[:backcast_length] 34 | f_ls = lin_space[backcast_length:] 35 | return b_ls, f_ls 36 | 37 | 38 | class NBEATSBlock(nn.Module): 39 | def __init__( 40 | self, 41 | units, 42 | thetas_dim, 43 | num_block_layers=4, 44 | backcast_length=10, 45 | forecast_length=5, 46 | share_thetas=False, 47 | dropout=0.1, 48 | ): 49 | super().__init__() 50 | self.units = units 51 | self.thetas_dim = thetas_dim 52 | self.backcast_length = backcast_length 53 | self.forecast_length = forecast_length 54 | self.share_thetas = share_thetas 55 | 56 | fc_stack = [ 57 | nn.Linear(backcast_length, units), 58 | nn.ReLU(), 59 | ] 60 | for _ in range(num_block_layers - 1): 61 | fc_stack.extend([linear(units, units, dropout=dropout), nn.ReLU()]) 62 | self.fc = nn.Sequential(*fc_stack) 63 | 64 | if share_thetas: 65 | self.theta_f_fc = self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) 66 | else: 67 | self.theta_b_fc = nn.Linear(units, thetas_dim, bias=False) 68 | self.theta_f_fc = nn.Linear(units, thetas_dim, bias=False) 69 | 70 | def forward(self, x): 71 | return self.fc(x) 72 | 73 | 74 | class 
NBEATSSeasonalBlock(NBEATSBlock): 75 | def __init__( 76 | self, 77 | units, 78 | thetas_dim=None, 79 | num_block_layers=4, 80 | backcast_length=10, 81 | forecast_length=5, 82 | nb_harmonics=None, 83 | min_period=1, 84 | dropout=0.1, 85 | ): 86 | if nb_harmonics: 87 | thetas_dim = nb_harmonics 88 | else: 89 | thetas_dim = forecast_length 90 | self.min_period = min_period 91 | 92 | super().__init__( 93 | units=units, 94 | thetas_dim=thetas_dim, 95 | num_block_layers=num_block_layers, 96 | backcast_length=backcast_length, 97 | forecast_length=forecast_length, 98 | share_thetas=True, 99 | dropout=dropout, 100 | ) 101 | 102 | backcast_linspace, forecast_linspace = linspace( 103 | backcast_length, forecast_length, centered=False 104 | ) 105 | 106 | p1, p2 = ( 107 | (thetas_dim // 2, thetas_dim // 2) 108 | if thetas_dim % 2 == 0 109 | else (thetas_dim // 2, thetas_dim // 2 + 1) 110 | ) 111 | s1_b = torch.tensor( 112 | np.cos(2 * np.pi * self.get_frequencies(p1)[:, None] * backcast_linspace), 113 | dtype=torch.float32, 114 | ) # H/2-1 115 | s2_b = torch.tensor( 116 | np.sin(2 * np.pi * self.get_frequencies(p2)[:, None] * backcast_linspace), 117 | dtype=torch.float32, 118 | ) 119 | self.register_buffer("S_backcast", torch.cat([s1_b, s2_b])) 120 | 121 | s1_f = torch.tensor( 122 | np.cos(2 * np.pi * self.get_frequencies(p1)[:, None] * forecast_linspace), 123 | dtype=torch.float32, 124 | ) # H/2-1 125 | s2_f = torch.tensor( 126 | np.sin(2 * np.pi * self.get_frequencies(p2)[:, None] * forecast_linspace), 127 | dtype=torch.float32, 128 | ) 129 | self.register_buffer("S_forecast", torch.cat([s1_f, s2_f])) 130 | 131 | def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]: 132 | x = super().forward(x) 133 | amplitudes_backward = self.theta_b_fc(x) 134 | backcast = amplitudes_backward.mm(self.S_backcast) 135 | amplitudes_forward = self.theta_f_fc(x) 136 | forecast = amplitudes_forward.mm(self.S_forecast) 137 | 138 | return backcast, forecast 139 | 140 | def get_frequencies(self, 
n): 141 | return np.linspace( 142 | 0, (self.backcast_length + self.forecast_length) / self.min_period, n 143 | ) 144 | 145 | 146 | class NBEATSTrendBlock(NBEATSBlock): 147 | def __init__( 148 | self, 149 | units, 150 | thetas_dim, 151 | num_block_layers=4, 152 | backcast_length=10, 153 | forecast_length=5, 154 | dropout=0.1, 155 | ): 156 | super().__init__( 157 | units=units, 158 | thetas_dim=thetas_dim, 159 | num_block_layers=num_block_layers, 160 | backcast_length=backcast_length, 161 | forecast_length=forecast_length, 162 | share_thetas=True, 163 | dropout=dropout, 164 | ) 165 | 166 | backcast_linspace, forecast_linspace = linspace( 167 | backcast_length, forecast_length, centered=True 168 | ) 169 | norm = np.sqrt( 170 | forecast_length / thetas_dim 171 | ) # ensure range of predictions is comparable to input 172 | thetas_dims_range = np.array(range(thetas_dim)) 173 | coefficients = torch.tensor( 174 | backcast_linspace ** thetas_dims_range[:, None], 175 | dtype=torch.float32, 176 | ) 177 | self.register_buffer("T_backcast", coefficients * norm) 178 | coefficients = torch.tensor( 179 | forecast_linspace ** thetas_dims_range[:, None], 180 | dtype=torch.float32, 181 | ) 182 | self.register_buffer("T_forecast", coefficients * norm) 183 | 184 | def forward(self, x) -> tuple[torch.Tensor, torch.Tensor]: 185 | x = super().forward(x) 186 | backcast = self.theta_b_fc(x).mm(self.T_backcast) 187 | forecast = self.theta_f_fc(x).mm(self.T_forecast) 188 | return backcast, forecast 189 | 190 | 191 | class NBEATSGenericBlock(NBEATSBlock): 192 | def __init__( 193 | self, 194 | units, 195 | thetas_dim, 196 | num_block_layers=4, 197 | backcast_length=10, 198 | forecast_length=5, 199 | dropout=0.1, 200 | ): 201 | super().__init__( 202 | units=units, 203 | thetas_dim=thetas_dim, 204 | num_block_layers=num_block_layers, 205 | backcast_length=backcast_length, 206 | forecast_length=forecast_length, 207 | dropout=dropout, 208 | ) 209 | 210 | self.backcast_fc = nn.Linear(thetas_dim, 
backcast_length) 211 | self.forecast_fc = nn.Linear(thetas_dim, forecast_length) 212 | 213 | def forward(self, x): 214 | x = super().forward(x) 215 | 216 | theta_b = F.relu(self.theta_b_fc(x)) 217 | theta_f = F.relu(self.theta_f_fc(x)) 218 | 219 | return self.backcast_fc(theta_b), self.forecast_fc(theta_f) 220 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/nhits/__init__.py: -------------------------------------------------------------------------------- 1 | """N-HiTS model for timeseries forecasting with covariates.""" 2 | 3 | from pytorch_forecasting.models.nhits._nhits import NHiTS 4 | from pytorch_forecasting.models.nhits.sub_modules import NHiTS as NHiTSModule 5 | 6 | __all__ = ["NHits", "NHiTSModule"] 7 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_forecasting.models.nn.embeddings import MultiEmbedding 2 | from pytorch_forecasting.models.nn.rnn import GRU, LSTM, HiddenState, get_rnn 3 | from pytorch_forecasting.utils import TupleOutputMixIn 4 | 5 | __all__ = [ 6 | "MultiEmbedding", 7 | "get_rnn", 8 | "LSTM", 9 | "GRU", 10 | "HiddenState", 11 | "TupleOutputMixIn", 12 | ] 13 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/rnn/__init__.py: -------------------------------------------------------------------------------- 1 | """Simple recurrent model - either with LSTM or GRU cells.""" 2 | 3 | from pytorch_forecasting.models.rnn._rnn import RecurrentNetwork 4 | 5 | __all__ = ["RecurrentNetwork"] 6 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/temporal_fusion_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | """Temporal fusion transformer for 
forecasting timeseries.""" 2 | 3 | from pytorch_forecasting.models.temporal_fusion_transformer._tft import ( 4 | TemporalFusionTransformer, 5 | ) 6 | from pytorch_forecasting.models.temporal_fusion_transformer.sub_modules import ( 7 | AddNorm, 8 | GateAddNorm, 9 | GatedLinearUnit, 10 | GatedResidualNetwork, 11 | InterpretableMultiHeadAttention, 12 | VariableSelectionNetwork, 13 | ) 14 | 15 | __all__ = [ 16 | "TemporalFusionTransformer", 17 | "AddNorm", 18 | "GateAddNorm", 19 | "GatedLinearUnit", 20 | "GatedResidualNetwork", 21 | "InterpretableMultiHeadAttention", 22 | "VariableSelectionNetwork", 23 | ] 24 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/tide/__init__.py: -------------------------------------------------------------------------------- 1 | """Tide model.""" 2 | 3 | from pytorch_forecasting.models.tide._tide import TiDEModel 4 | from pytorch_forecasting.models.tide._tide_metadata import TiDEModelMetadata 5 | from pytorch_forecasting.models.tide.sub_modules import _TideModule 6 | 7 | __all__ = [ 8 | "_TideModule", 9 | "TiDEModel", 10 | "TiDEModelMetadata", 11 | ] 12 | -------------------------------------------------------------------------------- /pytorch_forecasting/models/tide/_tide_metadata.py: -------------------------------------------------------------------------------- 1 | """TiDE metadata container.""" 2 | 3 | from pytorch_forecasting.models.base._base_object import _BasePtForecaster 4 | 5 | 6 | class TiDEModelMetadata(_BasePtForecaster): 7 | """Metadata container for TiDE Model.""" 8 | 9 | _tags = { 10 | "info:name": "TiDEModel", 11 | "info:compute": 3, 12 | "authors": ["Sohaib-Ahmed21"], 13 | "capability:exogenous": True, 14 | "capability:multivariate": True, 15 | "capability:pred_int": True, 16 | "capability:flexible_history_length": True, 17 | "capability:cold_start": False, 18 | } 19 | 20 | @classmethod 21 | def get_model_cls(cls): 22 | """Get model class.""" 23 | from 
pytorch_forecasting.models.tide import TiDEModel 24 | 25 | return TiDEModel 26 | 27 | @classmethod 28 | def get_test_train_params(cls): 29 | """Return testing parameter settings for the trainer. 30 | 31 | Returns 32 | ------- 33 | params : dict or list of dict, default = {} 34 | Parameters to create testing instances of the class. 35 | """ 36 | 37 | from pytorch_forecasting.data.encoders import GroupNormalizer 38 | from pytorch_forecasting.metrics import SMAPE 39 | 40 | params = [ 41 | { 42 | "data_loader_kwargs": dict( 43 | add_relative_time_idx=False, 44 | # must include this everytime since the data_loader_default_kwargs 45 | # include this to be True. 46 | ) 47 | }, 48 | { 49 | "temporal_decoder_hidden": 16, 50 | "data_loader_kwargs": dict(add_relative_time_idx=False), 51 | }, 52 | { 53 | "dropout": 0.2, 54 | "use_layer_norm": True, 55 | "loss": SMAPE(), 56 | "data_loader_kwargs": dict( 57 | target_normalizer=GroupNormalizer( 58 | groups=["agency", "sku"], transformation="softplus" 59 | ), 60 | add_relative_time_idx=False, 61 | ), 62 | }, 63 | ] 64 | defaults = {"hidden_size": 5} 65 | for param in params: 66 | param.update(defaults) 67 | return params 68 | 69 | @classmethod 70 | def _get_test_dataloaders_from(cls, params): 71 | """Get dataloaders from parameters. 72 | 73 | Parameters 74 | ---------- 75 | params : dict 76 | Parameters to create dataloaders. 77 | One of the elements in the list returned by ``get_test_train_params``. 78 | 79 | Returns 80 | ------- 81 | dataloaders : dict with keys "train", "val", "test", values torch DataLoader 82 | Dict of dataloaders created from the parameters. 83 | Train, validation, and test dataloaders. 
84 | """ 85 | trainer_kwargs = params.get("trainer_kwargs", {}) 86 | clip_target = params.get("clip_target", False) 87 | data_loader_kwargs = params.get("data_loader_kwargs", {}) 88 | 89 | from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss 90 | from pytorch_forecasting.tests._conftest import make_dataloaders 91 | from pytorch_forecasting.tests._data_scenarios import data_with_covariates 92 | 93 | dwc = data_with_covariates() 94 | 95 | if "loss" in trainer_kwargs and isinstance( 96 | trainer_kwargs["loss"], NegativeBinomialDistributionLoss 97 | ): 98 | dwc = dwc.assign(volume=lambda x: x.volume.round()) 99 | 100 | dwc = dwc.copy() 101 | if clip_target: 102 | dwc["target"] = dwc["volume"].clip(1e-3, 1.0) 103 | else: 104 | dwc["target"] = dwc["volume"] 105 | data_loader_default_kwargs = dict( 106 | target="target", 107 | time_varying_known_reals=["price_actual"], 108 | time_varying_unknown_reals=["target"], 109 | static_categoricals=["agency"], 110 | add_relative_time_idx=True, 111 | ) 112 | data_loader_default_kwargs.update(data_loader_kwargs) 113 | dataloaders_w_covariates = make_dataloaders(dwc, **data_loader_default_kwargs) 114 | return dataloaders_w_covariates 115 | -------------------------------------------------------------------------------- /pytorch_forecasting/tests/__init__.py: -------------------------------------------------------------------------------- 1 | """PyTorch Forecasting test suite.""" 2 | -------------------------------------------------------------------------------- /pytorch_forecasting/tests/_config.py: -------------------------------------------------------------------------------- 1 | """Test configs.""" 2 | 3 | # list of str, names of estimators to exclude from testing 4 | # WARNING: tests for these estimators will be skipped 5 | EXCLUDE_ESTIMATORS = [ 6 | "DummySkipped", 7 | "ClassName", # exclude classes from extension templates 8 | ] 9 | 10 | # dictionary of lists of str, names of tests to exclude from testing 
11 | # keys are class names of estimators, values are lists of test names to exclude 12 | # WARNING: tests with these names will be skipped 13 | EXCLUDED_TESTS = {} 14 | -------------------------------------------------------------------------------- /pytorch_forecasting/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch Forecasting package for timeseries forecasting with PyTorch. 3 | """ 4 | 5 | from pytorch_forecasting.utils._utils import ( 6 | InitialParameterRepresenterMixIn, 7 | OutputMixIn, 8 | TupleOutputMixIn, 9 | apply_to_list, 10 | autocorrelation, 11 | concat_sequences, 12 | create_mask, 13 | detach, 14 | get_embedding_size, 15 | groupby_apply, 16 | integer_histogram, 17 | masked_op, 18 | move_to_device, 19 | padded_stack, 20 | profile, 21 | redirect_stdout, 22 | repr_class, 23 | to_list, 24 | unpack_sequence, 25 | unsqueeze_like, 26 | ) 27 | 28 | __all__ = [ 29 | "InitialParameterRepresenterMixIn", 30 | "OutputMixIn", 31 | "TupleOutputMixIn", 32 | "apply_to_list", 33 | "autocorrelation", 34 | "get_embedding_size", 35 | "concat_sequences", 36 | "create_mask", 37 | "to_list", 38 | "RecurrentNetwork", 39 | "DecoderMLP", 40 | "detach", 41 | "masked_op", 42 | "move_to_device", 43 | "integer_histogram", 44 | "groupby_apply", 45 | "padded_stack", 46 | "profile", 47 | "redirect_stdout", 48 | "repr_class", 49 | "unpack_sequence", 50 | "unsqueeze_like", 51 | ] 52 | -------------------------------------------------------------------------------- /pytorch_forecasting/utils/_coerce.py: -------------------------------------------------------------------------------- 1 | """Coercion functions for various data types.""" 2 | 3 | from copy import deepcopy 4 | 5 | 6 | def _coerce_to_list(obj): 7 | """Coerce object to list. 8 | 9 | None is coerced to empty list, otherwise list constructor is used. 
10 | """ 11 | if obj is None: 12 | return [] 13 | if isinstance(obj, str): 14 | return [obj] 15 | return list(obj) 16 | 17 | 18 | def _coerce_to_dict(obj): 19 | """Coerce object to dict. 20 | 21 | None is coerce to empty dict, otherwise deepcopy is used. 22 | """ 23 | if obj is None: 24 | return {} 25 | return deepcopy(obj) 26 | -------------------------------------------------------------------------------- /pytorch_forecasting/utils/_dependencies/__init__.py: -------------------------------------------------------------------------------- 1 | """Utilities for managing dependencies.""" 2 | 3 | from pytorch_forecasting.utils._dependencies._dependencies import ( 4 | _check_matplotlib, 5 | _get_installed_packages, 6 | ) 7 | from pytorch_forecasting.utils._dependencies._safe_import import _safe_import 8 | 9 | __all__ = [ 10 | "_get_installed_packages", 11 | "_check_matplotlib", 12 | "_safe_import", 13 | ] 14 | -------------------------------------------------------------------------------- /pytorch_forecasting/utils/_dependencies/_dependencies.py: -------------------------------------------------------------------------------- 1 | """Utilities for managing dependencies. 2 | 3 | Copied from sktime/skbase. 4 | """ 5 | 6 | from functools import lru_cache 7 | 8 | 9 | @lru_cache 10 | def _get_installed_packages_private(): 11 | """Get a dictionary of installed packages and their versions. 12 | 13 | Same as _get_installed_packages, but internal to avoid mutating the lru_cache 14 | by accident. 
15 | """ 16 | from importlib.metadata import distributions, version 17 | 18 | dists = distributions() 19 | package_names = {dist.metadata["Name"] for dist in dists} 20 | package_versions = {pkg_name: version(pkg_name) for pkg_name in package_names} 21 | # developer note: 22 | # we cannot just use distributions naively, 23 | # because the same top level package name may appear *twice*, 24 | # e.g., in a situation where a virtual env overrides a base env, 25 | # such as in deployment environments like databricks. 26 | # the "version" contract ensures we always get the version that corresponds 27 | # to the importable distribution, i.e., the top one in the sys.path. 28 | return package_versions 29 | 30 | 31 | def _get_installed_packages(): 32 | """Get a dictionary of installed packages and their versions. 33 | 34 | Returns 35 | ------- 36 | dict : dictionary of installed packages and their versions 37 | keys are PEP 440 compatible package names, values are package versions 38 | MAJOR.MINOR.PATCH version format is used for versions, e.g., "1.2.3" 39 | """ 40 | return _get_installed_packages_private().copy() 41 | 42 | 43 | def _check_matplotlib(ref="This feature", raise_error=True): 44 | """Check if matplotlib is installed. 45 | 46 | Parameters 47 | ---------- 48 | ref : str, optional (default="This feature") 49 | reference to the feature that requires matplotlib, used in error message 50 | raise_error : bool, optional (default=True) 51 | whether to raise an error if matplotlib is not installed 52 | 53 | Returns 54 | ------- 55 | bool : whether matplotlib is installed 56 | """ 57 | pkgs = _get_installed_packages() 58 | 59 | if raise_error and "matplotlib" not in pkgs: 60 | raise ImportError( 61 | f"{ref} requires matplotlib." 62 | " Please install matplotlib with `pip install matplotlib`." 
63 | ) 64 | 65 | return "matplotlib" in pkgs 66 | -------------------------------------------------------------------------------- /pytorch_forecasting/utils/_dependencies/_safe_import.py: -------------------------------------------------------------------------------- 1 | """Import a module/class, return a Mock object if import fails. 2 | 3 | Copied from sktime/skbase. 4 | 5 | Should be refactored and moved to a common location in skbase. 6 | """ 7 | 8 | import importlib 9 | from unittest.mock import MagicMock 10 | 11 | from pytorch_forecasting.utils._dependencies import _get_installed_packages 12 | 13 | 14 | def _safe_import(import_path, pkg_name=None): 15 | """Import a module/class, return a Mock object if import fails. 16 | 17 | Idiomatic usage is ``obj = _safe_import("a.b.c.obj")``. 18 | The function supports importing both top-level modules and nested attributes: 19 | 20 | - Top-level module: ``"torch"`` -> same as ``import torch`` 21 | - Nested module: ``"torch.nn"`` -> same as``from torch import nn`` 22 | - Class/function: ``"torch.nn.Linear"`` -> same as ``from torch.nn import Linear`` 23 | 24 | Parameters 25 | ---------- 26 | import_path : str 27 | The path to the module/class to import. Can be: 28 | 29 | - Single module: ``"torch"`` 30 | - Nested module: ``"torch.nn"`` 31 | - Class/attribute: ``"torch.nn.ReLU"`` 32 | 33 | Note: The dots in the path determine the import behavior: 34 | 35 | - No dots: Imports as a single module 36 | - One dot: Imports as a submodule 37 | - Multiple dots: Last part is treated as an attribute to import 38 | 39 | pkg_name : str, default=None 40 | The name of the package to check for installation. This is useful when 41 | the import name differs from the package name, for example: 42 | 43 | - import: ``"sklearn"`` -> ``pkg_name="scikit-learn"`` 44 | - import: ``"cv2"`` -> ``pkg_name="opencv-python"`` 45 | 46 | If ``None``, uses the first part of ``import_path`` before the dot. 
47 | 48 | Returns 49 | ------- 50 | object 51 | If the import path and ``pkg_name`` is present, one of the following: 52 | 53 | - The imported module if ``import_path`` has no dots 54 | - The imported submodule if ``import_path`` has one dot 55 | - The imported class/function if ``import_path`` has multiple dots 56 | 57 | If the package or import path are not found: 58 | a unique ``MagicMock`` object per unique import path. 59 | 60 | Examples 61 | -------- 62 | >>> from pytorch_forecasting.utils.dependencies._safe_import import _safe_import 63 | 64 | >>> # Import a top-level module 65 | >>> torch = _safe_import("torch") 66 | 67 | >>> # Import a submodule 68 | >>> nn = _safe_import("torch.nn") 69 | 70 | >>> # Import a specific class 71 | >>> Linear = _safe_import("torch.nn.Linear") 72 | 73 | >>> # Import with different package name 74 | >>> cv2 = _safe_import("cv2", pkg_name="opencv-python") 75 | """ 76 | path_list = import_path.split(".") 77 | 78 | if pkg_name is None: 79 | pkg_name = path_list[0] 80 | obj_name = path_list[-1] 81 | 82 | if pkg_name in _get_installed_packages(): 83 | try: 84 | if len(path_list) == 1: 85 | return importlib.import_module(pkg_name) 86 | module_name, attr_name = import_path.rsplit(".", 1) 87 | module = importlib.import_module(module_name) 88 | return getattr(module, attr_name) 89 | except (ImportError, AttributeError): 90 | pass 91 | 92 | mock_obj = _create_mock_class(obj_name) 93 | return mock_obj 94 | 95 | 96 | class CommonMagicMeta(type): 97 | def __getattr__(cls, name): 98 | return MagicMock() 99 | 100 | def __setattr__(cls, name, value): 101 | pass # Ignore attribute writes 102 | 103 | 104 | class MagicAttribute(metaclass=CommonMagicMeta): 105 | def __getattr__(self, name): 106 | return MagicMock() 107 | 108 | def __setattr__(self, name, value): 109 | pass # Ignore attribute writes 110 | 111 | def __call__(self, *args, **kwargs): 112 | return self # Ensures instantiation returns the same object 113 | 114 | 115 | def 
_create_mock_class(name: str, bases=()): 116 | """Create new dynamic mock class similar to MagicMock. 117 | 118 | Parameters 119 | ---------- 120 | name : str 121 | The name of the new class. 122 | bases : tuple, default=() 123 | The base classes of the new class. 124 | 125 | Returns 126 | ------- 127 | a new class that behaves like MagicMock, with name ``name``. 128 | Forwards all attribute access to a MagicMock object stored in the instance. 129 | """ 130 | return type(name, (MagicAttribute,), {"__metaclass__": CommonMagicMeta}) 131 | -------------------------------------------------------------------------------- /pytorch_forecasting/utils/_dependencies/tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Tests for dependency utilities.""" 2 | -------------------------------------------------------------------------------- /pytorch_forecasting/utils/_dependencies/tests/test_safe_import.py: -------------------------------------------------------------------------------- 1 | __author__ = ["jgyasu", "fkiraly"] 2 | 3 | from pytorch_forecasting.utils._dependencies import ( 4 | _get_installed_packages, 5 | _safe_import, 6 | ) 7 | 8 | 9 | def test_import_present_module(): 10 | """Test importing a dependency that is installed.""" 11 | result = _safe_import("pandas") 12 | assert result is not None 13 | assert "pandas" in _get_installed_packages() 14 | 15 | 16 | def test_import_missing_module(): 17 | """Test importing a dependency that is not installed.""" 18 | result = _safe_import("nonexistent_module") 19 | assert hasattr(result, "__name__") 20 | assert result.__name__ == "nonexistent_module" 21 | 22 | 23 | def test_import_without_pkg_name(): 24 | """Test importing a dependency with the same name as package name.""" 25 | result = _safe_import("torch", pkg_name="torch") 26 | assert result is not None 27 | 28 | 29 | def test_import_with_different_pkg_name_1(): 30 | """Test importing a dependency with a different 
package name.""" 31 | result = _safe_import("skbase", pkg_name="scikit-base") 32 | assert result is not None 33 | 34 | 35 | def test_import_with_different_pkg_name_2(): 36 | """Test importing another dependency with a different package name.""" 37 | result = _safe_import("cv2", pkg_name="opencv-python") 38 | assert result is not None 39 | 40 | 41 | def test_import_submodule(): 42 | """Test importing a submodule.""" 43 | result = _safe_import("torch.nn") 44 | assert result is not None 45 | 46 | 47 | def test_import_class(): 48 | """Test importing a class.""" 49 | result = _safe_import("torch.nn.Linear") 50 | assert result is not None 51 | 52 | 53 | def test_import_existing_object(): 54 | """Test importing an existing object.""" 55 | result = _safe_import("pandas.DataFrame") 56 | assert result is not None 57 | assert result.__name__ == "DataFrame" 58 | from pandas import DataFrame 59 | 60 | assert result is DataFrame 61 | 62 | 63 | def test_multiple_inheritance_from_mock(): 64 | """Test multiple inheritance from dynamic MagicMock.""" 65 | Class1 = _safe_import("foobar.foo.FooBar") 66 | Class2 = _safe_import("barfoobar.BarFooBar") 67 | 68 | class NewClass(Class1, Class2): 69 | """This should not trigger an error. 70 | 71 | The class definition would trigger an error if multiple inheritance 72 | from Class1 and Class2 does not work, e.g., if it is simply 73 | identical to MagicMock. 
74 | """ 75 | 76 | pass 77 | -------------------------------------------------------------------------------- /pytorch_forecasting/utils/_maint/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sktime/pytorch-forecasting/4384140418bfe8e8c64cc6ce5fc19372508bfccd/pytorch_forecasting/utils/_maint/__init__.py -------------------------------------------------------------------------------- /pytorch_forecasting/utils/_maint/_show_versions.py: -------------------------------------------------------------------------------- 1 | # License: BSD 3 clause 2 | """Utility methods to print system info for debugging. 3 | 4 | adapted from 5 | :func: `sklearn.show_versions` and `sktime.show_versions` 6 | """ 7 | 8 | __all__ = ["show_versions"] 9 | 10 | import importlib 11 | import platform 12 | import sys 13 | 14 | 15 | def _get_sys_info(): 16 | """System information. 17 | 18 | Return 19 | ------ 20 | sys_info : dict 21 | system and Python version information 22 | """ 23 | python = sys.version.replace("\n", " ") 24 | 25 | blob = [ 26 | ("python", python), 27 | ("executable", sys.executable), 28 | ("machine", platform.platform()), 29 | ] 30 | 31 | return dict(blob) 32 | 33 | 34 | # dependencies to print versions of, by default 35 | DEFAULT_DEPS_TO_SHOW = [ 36 | "pip", 37 | "pytorch-forecasting", 38 | "torch", 39 | "lightning", 40 | "numpy", 41 | "scipy", 42 | "pandas", 43 | "cpflows", 44 | "matplotlib", 45 | "optuna", 46 | "optuna-integration", 47 | "pytorch_optimizer", 48 | "scikit-learn", 49 | "scikit-base", 50 | "statsmodels", 51 | ] 52 | 53 | 54 | def _get_deps_info(deps=None, source="distributions"): 55 | """Overview of the installed version of main dependencies. 56 | 57 | Parameters 58 | ---------- 59 | deps : optional, list of strings with package names 60 | if None, behaves as deps = ["pytorch-forecasting"]. 
61 | 62 | source : str, optional one of "distributions" (default) or "import" 63 | source of version information 64 | 65 | * "distributions" - uses importlib.distributions. In this case, 66 | strings in deps are assumed to be PEP 440 package strings, 67 | e.g., scikit-learn, not sklearn. 68 | * "import" - uses the __version__ attribute of the module. 69 | In this case, strings in deps are assumed to be import names, 70 | e.g., sklearn, not scikit-learn. 71 | 72 | Returns 73 | ------- 74 | deps_info: dict 75 | version information on libraries in `deps` 76 | keys are package names, import names if source is "import", 77 | and PEP 440 package strings if source is "distributions"; 78 | values are PEP 440 version strings 79 | of the import as present in the current python environment 80 | """ 81 | if deps is None: 82 | deps = ["pytorch-forecasting"] 83 | 84 | if source == "distributions": 85 | from pytorch_forecasting.utils._dependencies import _get_installed_packages 86 | 87 | KEY_ALIAS = {"sklearn": "scikit-learn", "skbase": "scikit-base"} 88 | 89 | pkgs = _get_installed_packages() 90 | 91 | deps_info = {} 92 | for modname in deps: 93 | pkg_name = KEY_ALIAS.get(modname, modname) 94 | deps_info[modname] = pkgs.get(pkg_name, None) 95 | 96 | return deps_info 97 | 98 | def get_version(module): 99 | return getattr(module, "__version__", None) 100 | 101 | deps_info = {} 102 | 103 | for modname in deps: 104 | try: 105 | if modname in sys.modules: 106 | mod = sys.modules[modname] 107 | else: 108 | mod = importlib.import_module(modname) 109 | except ImportError: 110 | deps_info[modname] = None 111 | else: 112 | ver = get_version(mod) 113 | deps_info[modname] = ver 114 | 115 | return deps_info 116 | 117 | 118 | def show_versions(): 119 | """Print python version, OS version, sktime version, selected dependency versions. 
120 | 121 | Pretty prints: 122 | 123 | * python version of environment 124 | * python executable location 125 | * OS version 126 | * list of import name and version number for selected python dependencies 127 | 128 | Developer note: 129 | Python version/executable and OS version are from `_get_sys_info` 130 | Package versions are retrieved by `_get_deps_info` 131 | Selected dependencies are as in the DEFAULT_DEPS_TO_SHOW variable 132 | """ 133 | sys_info = _get_sys_info() 134 | deps_info = _get_deps_info(deps=DEFAULT_DEPS_TO_SHOW) 135 | 136 | print("\nSystem:") # noqa: T001, T201 137 | for k, stat in sys_info.items(): 138 | print(f"{k:>10}: {stat}") # noqa: T001, T201 139 | 140 | print("\nPython dependencies:") # noqa: T001, T201 141 | for k, stat in deps_info.items(): 142 | print(f"{k:>13}: {stat}") # noqa: T001, T201 143 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import pytest 6 | 7 | sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../.."))) # isort:skip 8 | 9 | 10 | from pytorch_forecasting import TimeSeriesDataSet # isort:skip 11 | from pytorch_forecasting.data.examples import get_stallion_data # isort:skip 12 | 13 | 14 | # for vscode debugging: https://stackoverflow.com/a/62563106/14121677 15 | if os.getenv("_PYTEST_RAISE", "0") != "0": 16 | 17 | @pytest.hookimpl(tryfirst=True) 18 | def pytest_exception_interact(call): 19 | raise call.excinfo.value 20 | 21 | @pytest.hookimpl(tryfirst=True) 22 | def pytest_internalerror(excinfo): 23 | raise excinfo.value 24 | 25 | 26 | @pytest.fixture(scope="session") 27 | def test_data(): 28 | data = get_stallion_data() 29 | data["month"] = data.date.dt.month.astype(str) 30 | data["log_volume"] = np.log1p(data.volume) 31 | data["weight"] = 1 + np.sqrt(data.volume) 32 | 33 | data["time_idx"] = data["date"].dt.year * 12 
+ data["date"].dt.month 34 | data["time_idx"] -= data["time_idx"].min() 35 | 36 | special_days = [ 37 | "easter_day", 38 | "good_friday", 39 | "new_year", 40 | "christmas", 41 | "labor_day", 42 | "independence_day", 43 | "revolution_day_memorial", 44 | "regional_games", 45 | "fifa_u_17_world_cup", 46 | "football_gold_cup", 47 | "beer_capital", 48 | "music_fest", 49 | ] 50 | data[special_days] = ( 51 | data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category") 52 | ) 53 | 54 | data = data[lambda x: x.time_idx < 10] # downsample 55 | return data 56 | 57 | 58 | @pytest.fixture(scope="session") 59 | def test_dataset(test_data): 60 | training = TimeSeriesDataSet( 61 | test_data.copy(), 62 | time_idx="time_idx", 63 | target="volume", 64 | time_varying_known_reals=["price_regular", "time_idx"], 65 | group_ids=["agency", "sku"], 66 | static_categoricals=["agency"], 67 | max_encoder_length=5, 68 | max_prediction_length=2, 69 | min_prediction_length=1, 70 | min_encoder_length=0, 71 | randomize_length=None, 72 | ) 73 | return training 74 | 75 | 76 | @pytest.fixture(autouse=True) 77 | def disable_mps(monkeypatch): 78 | """Disable MPS for all tests""" 79 | monkeypatch.setattr("torch._C._mps_is_available", lambda: False) 80 | -------------------------------------------------------------------------------- /tests/test_data/test_encoders.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import itertools 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | from sklearn.utils.validation import NotFittedError, check_is_fitted 8 | import torch 9 | 10 | from pytorch_forecasting.data import ( 11 | EncoderNormalizer, 12 | GroupNormalizer, 13 | MultiNormalizer, 14 | NaNLabelEncoder, 15 | TorchNormalizer, 16 | ) 17 | 18 | 19 | @pytest.mark.parametrize( 20 | "data,allow_nan", 21 | itertools.product( 22 | [ 23 | (np.array([2, 3, 4]), np.array([1, 2, 3, 5, np.nan])), 24 | 
(np.array(["a", "b", "c"]), np.array(["q", "a", "nan"])), 25 | ], 26 | [True, False], 27 | ), 28 | ) 29 | def test_NaNLabelEncoder(data, allow_nan): 30 | fit_data, transform_data = data 31 | encoder = NaNLabelEncoder(warn=False, add_nan=allow_nan) 32 | encoder.fit(fit_data) 33 | assert np.array_equal( 34 | encoder.inverse_transform(encoder.transform(fit_data)), fit_data 35 | ), "Inverse transform should reverse transform" 36 | if not allow_nan: 37 | with pytest.raises(KeyError): 38 | encoder.transform(transform_data) 39 | else: 40 | assert ( 41 | encoder.transform(transform_data)[0] == 0 42 | ), "First value should be translated to 0 if nan" 43 | assert ( 44 | encoder.transform(transform_data)[-1] == 0 45 | ), "Last value should be translated to 0 if nan" 46 | assert ( 47 | encoder.transform(fit_data)[0] > 0 48 | ), "First value should not be 0 if not nan" 49 | 50 | 51 | def test_NaNLabelEncoder_add(): 52 | encoder = NaNLabelEncoder(add_nan=False) 53 | encoder.fit(np.array(["a", "b", "c"])) 54 | encoder2 = deepcopy(encoder) 55 | encoder2.fit(np.array(["d"])) 56 | assert encoder2.transform(np.array(["a"]))[0] == 0, "a must be encoded as 0" 57 | assert encoder2.transform(np.array(["d"]))[0] == 3, "d must be encoded as 3" 58 | 59 | 60 | @pytest.mark.parametrize( 61 | "kwargs", 62 | [ 63 | dict(method="robust"), 64 | dict(method="robust", method_kwargs=dict(upper=1.0, lower=0.0)), 65 | dict(method="robust", data=np.random.randn(100)), 66 | dict(data=np.random.randn(100)), 67 | dict(transformation="log"), 68 | dict(transformation="softplus"), 69 | dict(transformation="log1p"), 70 | dict(transformation="relu"), 71 | dict(method="identity"), 72 | dict(method="identity", data=np.random.randn(100)), 73 | dict(center=False), 74 | dict(max_length=5), 75 | dict(data=pd.Series(np.random.randn(100))), 76 | dict(max_length=[1, 2]), 77 | ], 78 | ) 79 | def test_EncoderNormalizer(kwargs): 80 | kwargs.setdefault("method", "standard") 81 | kwargs.setdefault("center", True) 82 | 
kwargs.setdefault("data", torch.rand(100)) 83 | data = kwargs.pop("data") 84 | 85 | normalizer = EncoderNormalizer(**kwargs) 86 | 87 | if kwargs.get("transformation") in ["relu", "softplus", "log1p"]: 88 | assert ( 89 | normalizer.inverse_transform( 90 | torch.as_tensor(normalizer.fit_transform(data)) 91 | ) 92 | >= 0 93 | ).all(), "Inverse transform should yield only positive values" 94 | else: 95 | assert torch.isclose( 96 | normalizer.inverse_transform( 97 | torch.as_tensor(normalizer.fit_transform(data)) 98 | ), 99 | torch.as_tensor(data), 100 | atol=1e-5, 101 | ).all(), "Inverse transform should reverse transform" 102 | 103 | 104 | @pytest.mark.parametrize( 105 | "kwargs,groups", 106 | itertools.product( 107 | [ 108 | dict(method="robust"), 109 | dict(transformation="log"), 110 | dict(transformation="relu"), 111 | dict(center=False), 112 | dict(transformation="log1p"), 113 | dict(transformation="softplus"), 114 | dict(scale_by_group=True), 115 | ], 116 | [[], ["a"]], 117 | ), 118 | ) 119 | def test_GroupNormalizer(kwargs, groups): 120 | data = pd.DataFrame(dict(a=[1, 1, 2, 2, 3], b=[1.1, 1.1, 1.0, 0.0, 1.1])) 121 | defaults = dict( 122 | method="standard", transformation=None, center=True, scale_by_group=False 123 | ) 124 | defaults.update(kwargs) 125 | kwargs = defaults 126 | kwargs["groups"] = groups 127 | kwargs["scale_by_group"] = kwargs["scale_by_group"] and len(kwargs["groups"]) > 0 128 | 129 | normalizer = GroupNormalizer(**kwargs) 130 | encoded = normalizer.fit_transform(data["b"], data) 131 | 132 | test_data = dict( 133 | prediction=torch.tensor([encoded[0]]), 134 | target_scale=torch.tensor(normalizer.get_parameters([1])).unsqueeze(0), 135 | ) 136 | 137 | if kwargs.get("transformation") in ["relu", "softplus", "log1p", "log"]: 138 | assert ( 139 | normalizer(test_data) >= 0 140 | ).all(), "Inverse transform should yield only positive values" 141 | else: 142 | assert torch.isclose( 143 | normalizer(test_data), torch.tensor(data.b.iloc[0]), atol=1e-5 
144 | ).all(), "Inverse transform should reverse transform" 145 | 146 | 147 | def test_EncoderNormalizer_with_limited_history(): 148 | data = torch.rand(100) 149 | normalizer = EncoderNormalizer(max_length=[1, 2]).fit(data) 150 | assert normalizer.center_ == data[-1] 151 | 152 | 153 | def test_MultiNormalizer_fitted(): 154 | data = pd.DataFrame( 155 | dict( 156 | a=[1, 1, 2, 2, 3], b=[1.1, 1.1, 1.0, 5.0, 1.1], c=[1.1, 1.1, 1.0, 5.0, 1.1] 157 | ) 158 | ) 159 | 160 | normalizer = MultiNormalizer([GroupNormalizer(groups=["a"]), TorchNormalizer()]) 161 | 162 | with pytest.raises(NotFittedError): 163 | check_is_fitted(normalizer) 164 | 165 | normalizer.fit(data, data) 166 | 167 | try: 168 | check_is_fitted(normalizer.normalizers[0]) 169 | check_is_fitted(normalizer.normalizers[1]) 170 | check_is_fitted(normalizer) 171 | except NotFittedError: 172 | pytest.fail(f"{NotFittedError}") 173 | 174 | 175 | def test_TorchNormalizer_dtype_consistency(): 176 | """ 177 | - Ensures that even for float64 `target_scale`, the transformation will not change the prediction dtype. 
178 | - Ensure that target_scale will be of type float32 if method is 'identity' 179 | """ # noqa: E501 180 | parameters = torch.tensor([[[366.4587]]]) 181 | target_scale = torch.tensor([[427875.7500, 80367.4766]], dtype=torch.float64) 182 | assert ( 183 | TorchNormalizer()(dict(prediction=parameters, target_scale=target_scale)).dtype 184 | == torch.float32 185 | ) 186 | assert ( 187 | TorchNormalizer().transform(parameters, target_scale=target_scale).dtype 188 | == torch.float32 189 | ) 190 | 191 | y = np.array([1, 2, 3], dtype=np.float32) 192 | assert ( 193 | TorchNormalizer(method="identity").fit(y).get_parameters().dtype 194 | == torch.float32 195 | ) 196 | -------------------------------------------------------------------------------- /tests/test_data/test_samplers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.utils.data.sampler import SequentialSampler 4 | 5 | from pytorch_forecasting.data import TimeSynchronizedBatchSampler 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "drop_last,shuffle,as_string,batch_size", 10 | [ 11 | (True, True, True, 64), 12 | (False, False, False, 64), 13 | (True, False, False, 1000), 14 | ], 15 | ) 16 | def test_TimeSynchronizedBatchSampler( 17 | test_dataset, shuffle, drop_last, as_string, batch_size 18 | ): 19 | if as_string: 20 | dataloader = test_dataset.to_dataloader( 21 | batch_sampler="synchronized", 22 | shuffle=shuffle, 23 | drop_last=drop_last, 24 | batch_size=batch_size, 25 | ) 26 | else: 27 | sampler = TimeSynchronizedBatchSampler( 28 | SequentialSampler(test_dataset), 29 | shuffle=shuffle, 30 | drop_last=drop_last, 31 | batch_size=batch_size, 32 | ) 33 | dataloader = test_dataset.to_dataloader(batch_sampler=sampler) 34 | 35 | time_idx_pos = test_dataset.reals.index("time_idx") 36 | for x, _ in iter(dataloader): # check all samples 37 | time_idx_of_first_prediction = x["decoder_cont"][:, 0, time_idx_pos] 38 | assert torch.isclose( 39 | 
time_idx_of_first_prediction, time_idx_of_first_prediction[0] 40 | ).all(), "Time index should be the same for the first prediction" 41 | -------------------------------------------------------------------------------- /tests/test_models/test_baseline.py: -------------------------------------------------------------------------------- 1 | from pytorch_forecasting import Baseline 2 | 3 | 4 | def test_integration(multiple_dataloaders_with_covariates): 5 | dataloader = multiple_dataloaders_with_covariates["val"] 6 | Baseline().predict(dataloader, fast_dev_run=True) 7 | repr(Baseline()) 8 | -------------------------------------------------------------------------------- /tests/test_models/test_deepar.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import shutil 3 | 4 | import lightning.pytorch as pl 5 | from lightning.pytorch.callbacks import EarlyStopping 6 | from lightning.pytorch.loggers import TensorBoardLogger 7 | import pytest 8 | from test_models.conftest import make_dataloaders 9 | from torch import nn 10 | 11 | from pytorch_forecasting.data.encoders import GroupNormalizer 12 | from pytorch_forecasting.metrics import ( 13 | BetaDistributionLoss, 14 | ImplicitQuantileNetworkDistributionLoss, 15 | LogNormalDistributionLoss, 16 | MultivariateNormalDistributionLoss, 17 | NegativeBinomialDistributionLoss, 18 | NormalDistributionLoss, 19 | ) 20 | from pytorch_forecasting.models import DeepAR 21 | 22 | 23 | def _integration( 24 | data_with_covariates, 25 | tmp_path, 26 | cell_type="LSTM", 27 | data_loader_kwargs={}, 28 | clip_target: bool = False, 29 | trainer_kwargs=None, 30 | **kwargs, 31 | ): 32 | data_with_covariates = data_with_covariates.copy() 33 | if clip_target: 34 | data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0) 35 | else: 36 | data_with_covariates["target"] = data_with_covariates["volume"] 37 | data_loader_default_kwargs = dict( 38 | target="target", 39 | 
time_varying_known_reals=["price_actual"], 40 | time_varying_unknown_reals=["target"], 41 | static_categoricals=["agency"], 42 | add_relative_time_idx=True, 43 | ) 44 | data_loader_default_kwargs.update(data_loader_kwargs) 45 | dataloaders_with_covariates = make_dataloaders( 46 | data_with_covariates, **data_loader_default_kwargs 47 | ) 48 | 49 | train_dataloader = dataloaders_with_covariates["train"] 50 | val_dataloader = dataloaders_with_covariates["val"] 51 | test_dataloader = dataloaders_with_covariates["test"] 52 | 53 | early_stop_callback = EarlyStopping( 54 | monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" 55 | ) 56 | 57 | logger = TensorBoardLogger(tmp_path) 58 | if trainer_kwargs is None: 59 | trainer_kwargs = {} 60 | trainer = pl.Trainer( 61 | max_epochs=3, 62 | gradient_clip_val=0.1, 63 | callbacks=[early_stop_callback], 64 | enable_checkpointing=True, 65 | default_root_dir=tmp_path, 66 | limit_train_batches=2, 67 | limit_val_batches=2, 68 | limit_test_batches=2, 69 | logger=logger, 70 | **trainer_kwargs, 71 | ) 72 | 73 | net = DeepAR.from_dataset( 74 | train_dataloader.dataset, 75 | hidden_size=5, 76 | cell_type=cell_type, 77 | learning_rate=0.01, 78 | log_gradient_flow=True, 79 | log_interval=1000, 80 | n_plotting_samples=100, 81 | **kwargs, 82 | ) 83 | net.size() 84 | try: 85 | trainer.fit( 86 | net, 87 | train_dataloaders=train_dataloader, 88 | val_dataloaders=val_dataloader, 89 | ) 90 | test_outputs = trainer.test(net, dataloaders=test_dataloader) 91 | assert len(test_outputs) > 0 92 | # check loading 93 | net = DeepAR.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) 94 | 95 | # check prediction 96 | net.predict( 97 | val_dataloader, 98 | fast_dev_run=True, 99 | return_index=True, 100 | return_decoder_lengths=True, 101 | trainer_kwargs=trainer_kwargs, 102 | ) 103 | finally: 104 | shutil.rmtree(tmp_path, ignore_errors=True) 105 | 106 | net.predict( 107 | val_dataloader, 108 | fast_dev_run=True, 109 | 
return_index=True, 110 | return_decoder_lengths=True, 111 | trainer_kwargs=trainer_kwargs, 112 | ) 113 | 114 | 115 | @pytest.mark.parametrize( 116 | "kwargs", 117 | [ 118 | {}, 119 | {"cell_type": "GRU"}, 120 | dict( 121 | loss=LogNormalDistributionLoss(), 122 | clip_target=True, 123 | data_loader_kwargs=dict( 124 | target_normalizer=GroupNormalizer( 125 | groups=["agency", "sku"], transformation="log" 126 | ) 127 | ), 128 | ), 129 | dict( 130 | loss=NegativeBinomialDistributionLoss(), 131 | clip_target=False, 132 | data_loader_kwargs=dict( 133 | target_normalizer=GroupNormalizer( 134 | groups=["agency", "sku"], center=False 135 | ) 136 | ), 137 | ), 138 | dict( 139 | loss=BetaDistributionLoss(), 140 | clip_target=True, 141 | data_loader_kwargs=dict( 142 | target_normalizer=GroupNormalizer( 143 | groups=["agency", "sku"], transformation="logit" 144 | ) 145 | ), 146 | ), 147 | dict( 148 | data_loader_kwargs=dict( 149 | lags={"volume": [2, 5]}, 150 | target="volume", 151 | time_varying_unknown_reals=["volume"], 152 | min_encoder_length=2, 153 | ) 154 | ), 155 | dict( 156 | data_loader_kwargs=dict( 157 | time_varying_unknown_reals=["volume", "discount"], 158 | target=["volume", "discount"], 159 | lags={"volume": [2], "discount": [2]}, 160 | ) 161 | ), 162 | dict( 163 | loss=ImplicitQuantileNetworkDistributionLoss(hidden_size=8), 164 | ), 165 | dict( 166 | loss=MultivariateNormalDistributionLoss(), 167 | trainer_kwargs=dict(accelerator="cpu"), 168 | ), 169 | dict( 170 | loss=MultivariateNormalDistributionLoss(), 171 | data_loader_kwargs=dict( 172 | target_normalizer=GroupNormalizer( 173 | groups=["agency", "sku"], transformation="log1p" 174 | ) 175 | ), 176 | trainer_kwargs=dict(accelerator="cpu"), 177 | ), 178 | ], 179 | ) 180 | def test_integration(data_with_covariates, tmp_path, kwargs): 181 | if "loss" in kwargs and isinstance( 182 | kwargs["loss"], NegativeBinomialDistributionLoss 183 | ): 184 | data_with_covariates = data_with_covariates.assign( 185 | 
volume=lambda x: x.volume.round() 186 | ) 187 | _integration(data_with_covariates, tmp_path, **kwargs) 188 | 189 | 190 | @pytest.fixture 191 | def model(dataloaders_with_covariates): 192 | dataset = dataloaders_with_covariates["train"].dataset 193 | net = DeepAR.from_dataset( 194 | dataset, 195 | hidden_size=5, 196 | learning_rate=0.15, 197 | log_gradient_flow=True, 198 | log_interval=1000, 199 | ) 200 | return net 201 | 202 | 203 | def test_predict_average(model, dataloaders_with_covariates): 204 | prediction = model.predict( 205 | dataloaders_with_covariates["val"], 206 | fast_dev_run=True, 207 | mode="prediction", 208 | n_samples=100, 209 | ) 210 | assert prediction.ndim == 2, "expected averaging of samples" 211 | 212 | 213 | def test_predict_samples(model, dataloaders_with_covariates): 214 | prediction = model.predict( 215 | dataloaders_with_covariates["val"], 216 | fast_dev_run=True, 217 | mode="samples", 218 | n_samples=100, 219 | ) 220 | assert prediction.size()[-1] == 100, "expected raw samples" 221 | 222 | 223 | @pytest.mark.parametrize( 224 | "loss", [NormalDistributionLoss(), MultivariateNormalDistributionLoss()] 225 | ) 226 | def test_pickle(dataloaders_with_covariates, loss): 227 | dataset = dataloaders_with_covariates["train"].dataset 228 | model = DeepAR.from_dataset( 229 | dataset, 230 | hidden_size=5, 231 | learning_rate=0.15, 232 | log_gradient_flow=True, 233 | log_interval=1000, 234 | loss=loss, 235 | ) 236 | pkl = pickle.dumps(model) 237 | pickle.loads(pkl) # noqa: S301 238 | -------------------------------------------------------------------------------- /tests/test_models/test_mlp.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import shutil 3 | 4 | import lightning.pytorch as pl 5 | from lightning.pytorch.callbacks import EarlyStopping 6 | from lightning.pytorch.loggers import TensorBoardLogger 7 | import pytest 8 | from test_models.conftest import make_dataloaders 9 | from 
torchmetrics import MeanSquaredError 10 | 11 | from pytorch_forecasting.metrics import MAE, CrossEntropy, MultiLoss, QuantileLoss 12 | from pytorch_forecasting.models import DecoderMLP 13 | 14 | 15 | def _integration( 16 | data_with_covariates, tmp_path, data_loader_kwargs={}, train_only=False, **kwargs 17 | ): 18 | data_loader_default_kwargs = dict( 19 | target="target", 20 | time_varying_known_reals=["price_actual"], 21 | time_varying_unknown_reals=["target"], 22 | static_categoricals=["agency"], 23 | add_relative_time_idx=True, 24 | ) 25 | data_loader_default_kwargs.update(data_loader_kwargs) 26 | dataloaders_with_covariates = make_dataloaders( 27 | data_with_covariates, **data_loader_default_kwargs 28 | ) 29 | train_dataloader = dataloaders_with_covariates["train"] 30 | val_dataloader = dataloaders_with_covariates["val"] 31 | test_dataloader = dataloaders_with_covariates["test"] 32 | early_stop_callback = EarlyStopping( 33 | monitor="val_loss", 34 | min_delta=1e-4, 35 | patience=1, 36 | verbose=False, 37 | mode="min", 38 | strict=False, 39 | ) 40 | 41 | logger = TensorBoardLogger(tmp_path) 42 | trainer = pl.Trainer( 43 | max_epochs=3, 44 | gradient_clip_val=0.1, 45 | callbacks=[early_stop_callback], 46 | enable_checkpointing=True, 47 | default_root_dir=tmp_path, 48 | limit_train_batches=2, 49 | limit_val_batches=2, 50 | limit_test_batches=2, 51 | logger=logger, 52 | ) 53 | 54 | net = DecoderMLP.from_dataset( 55 | train_dataloader.dataset, 56 | learning_rate=0.015, 57 | log_gradient_flow=True, 58 | log_interval=1000, 59 | hidden_size=10, 60 | **kwargs, 61 | ) 62 | net.size() 63 | try: 64 | if train_only: 65 | trainer.fit(net, train_dataloaders=train_dataloader) 66 | else: 67 | trainer.fit( 68 | net, 69 | train_dataloaders=train_dataloader, 70 | val_dataloaders=val_dataloader, 71 | ) 72 | # check loading 73 | net = DecoderMLP.load_from_checkpoint( 74 | trainer.checkpoint_callback.best_model_path 75 | ) 76 | 77 | # check prediction 78 | net.predict( 79 | 
val_dataloader, 80 | fast_dev_run=True, 81 | return_index=True, 82 | return_decoder_lengths=True, 83 | ) 84 | # check test dataloader 85 | test_outputs = trainer.test(net, dataloaders=test_dataloader) 86 | assert len(test_outputs) > 0 87 | finally: 88 | shutil.rmtree(tmp_path, ignore_errors=True) 89 | 90 | net.predict( 91 | val_dataloader, 92 | fast_dev_run=True, 93 | return_index=True, 94 | return_decoder_lengths=True, 95 | ) 96 | 97 | 98 | @pytest.mark.parametrize( 99 | "kwargs", 100 | [ 101 | {}, 102 | dict(train_only=True), 103 | dict( 104 | loss=MultiLoss([QuantileLoss(), MAE()]), 105 | data_loader_kwargs=dict( 106 | time_varying_unknown_reals=["volume", "discount"], 107 | target=["volume", "discount"], 108 | ), 109 | ), 110 | dict( 111 | loss=CrossEntropy(), 112 | data_loader_kwargs=dict( 113 | target="agency", 114 | ), 115 | ), 116 | dict(loss=MeanSquaredError()), 117 | dict( 118 | loss=MeanSquaredError(), 119 | data_loader_kwargs=dict(min_prediction_length=1, min_encoder_length=1), 120 | ), 121 | ], 122 | ) 123 | def test_integration(data_with_covariates, tmp_path, kwargs): 124 | _integration( 125 | data_with_covariates.assign(target=lambda x: x.volume), tmp_path, **kwargs 126 | ) 127 | 128 | 129 | @pytest.fixture 130 | def model(dataloaders_with_covariates): 131 | dataset = dataloaders_with_covariates["train"].dataset 132 | net = DecoderMLP.from_dataset( 133 | dataset, 134 | learning_rate=0.15, 135 | log_gradient_flow=True, 136 | log_interval=1000, 137 | hidden_size=10, 138 | ) 139 | return net 140 | 141 | 142 | def test_pickle(model): 143 | pkl = pickle.dumps(model) 144 | pickle.loads(pkl) # noqa: S301 145 | -------------------------------------------------------------------------------- /tests/test_models/test_nbeats.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import shutil 3 | 4 | import lightning.pytorch as pl 5 | from lightning.pytorch.callbacks import EarlyStopping 6 | from 
lightning.pytorch.loggers import TensorBoardLogger 7 | import pytest 8 | 9 | from pytorch_forecasting.models import NBeats 10 | from pytorch_forecasting.utils._dependencies import _get_installed_packages 11 | 12 | 13 | def test_integration(dataloaders_fixed_window_without_covariates, tmp_path): 14 | train_dataloader = dataloaders_fixed_window_without_covariates["train"] 15 | val_dataloader = dataloaders_fixed_window_without_covariates["val"] 16 | test_dataloader = dataloaders_fixed_window_without_covariates["test"] 17 | 18 | early_stop_callback = EarlyStopping( 19 | monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" 20 | ) 21 | 22 | logger = TensorBoardLogger(tmp_path) 23 | trainer = pl.Trainer( 24 | max_epochs=2, 25 | gradient_clip_val=0.1, 26 | callbacks=[early_stop_callback], 27 | enable_checkpointing=True, 28 | default_root_dir=tmp_path, 29 | limit_train_batches=2, 30 | limit_val_batches=2, 31 | limit_test_batches=2, 32 | logger=logger, 33 | ) 34 | 35 | net = NBeats.from_dataset( 36 | train_dataloader.dataset, 37 | learning_rate=0.15, 38 | log_gradient_flow=True, 39 | widths=[4, 4, 4], 40 | log_interval=1000, 41 | backcast_loss_ratio=1.0, 42 | ) 43 | net.size() 44 | try: 45 | trainer.fit( 46 | net, 47 | train_dataloaders=train_dataloader, 48 | val_dataloaders=val_dataloader, 49 | ) 50 | test_outputs = trainer.test(net, dataloaders=test_dataloader) 51 | assert len(test_outputs) > 0 52 | # check loading 53 | net = NBeats.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) 54 | 55 | # check prediction 56 | net.predict( 57 | val_dataloader, 58 | fast_dev_run=True, 59 | return_index=True, 60 | return_decoder_lengths=True, 61 | ) 62 | finally: 63 | shutil.rmtree(tmp_path, ignore_errors=True) 64 | 65 | net.predict( 66 | val_dataloader, 67 | fast_dev_run=True, 68 | return_index=True, 69 | return_decoder_lengths=True, 70 | ) 71 | 72 | 73 | @pytest.fixture(scope="session") 74 | def model(dataloaders_fixed_window_without_covariates): 
75 | dataset = dataloaders_fixed_window_without_covariates["train"].dataset 76 | net = NBeats.from_dataset( 77 | dataset, 78 | learning_rate=0.15, 79 | log_gradient_flow=True, 80 | widths=[4, 4, 4], 81 | log_interval=1000, 82 | backcast_loss_ratio=1.0, 83 | ) 84 | return net 85 | 86 | 87 | def test_pickle(model): 88 | pkl = pickle.dumps(model) 89 | pickle.loads(pkl) # noqa: S301 90 | 91 | 92 | @pytest.mark.skipif( 93 | "matplotlib" not in _get_installed_packages(), 94 | reason="skip test if required package matplotlib not installed", 95 | ) 96 | def test_interpretation(model, dataloaders_fixed_window_without_covariates): 97 | raw_predictions = model.predict( 98 | dataloaders_fixed_window_without_covariates["val"], 99 | mode="raw", 100 | return_x=True, 101 | fast_dev_run=True, 102 | ) 103 | model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=0) 104 | -------------------------------------------------------------------------------- /tests/test_models/test_nhits.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import shutil 3 | 4 | import lightning.pytorch as pl 5 | from lightning.pytorch.callbacks import EarlyStopping 6 | from lightning.pytorch.loggers import TensorBoardLogger 7 | import numpy as np 8 | import pandas as pd 9 | import pytest 10 | 11 | from pytorch_forecasting.data.timeseries import TimeSeriesDataSet 12 | from pytorch_forecasting.metrics import MQF2DistributionLoss, QuantileLoss 13 | from pytorch_forecasting.metrics.distributions import ( 14 | ImplicitQuantileNetworkDistributionLoss, 15 | ) 16 | from pytorch_forecasting.models import NHiTS 17 | from pytorch_forecasting.utils._dependencies import _get_installed_packages 18 | 19 | 20 | def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs): 21 | train_dataloader = dataloader["train"] 22 | val_dataloader = dataloader["val"] 23 | test_dataloader = dataloader["test"] 24 | 25 | early_stop_callback = EarlyStopping( 26 
| monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min" 27 | ) 28 | 29 | logger = TensorBoardLogger(tmp_path) 30 | if trainer_kwargs is None: 31 | trainer_kwargs = {} 32 | trainer = pl.Trainer( 33 | max_epochs=2, 34 | gradient_clip_val=0.1, 35 | callbacks=[early_stop_callback], 36 | enable_checkpointing=True, 37 | default_root_dir=tmp_path, 38 | limit_train_batches=2, 39 | limit_val_batches=2, 40 | limit_test_batches=2, 41 | logger=logger, 42 | **trainer_kwargs, 43 | ) 44 | 45 | kwargs.setdefault("learning_rate", 0.15) 46 | kwargs.setdefault("weight_decay", 1e-2) 47 | 48 | net = NHiTS.from_dataset( 49 | train_dataloader.dataset, 50 | log_gradient_flow=True, 51 | log_interval=1000, 52 | hidden_size=8, 53 | **kwargs, 54 | ) 55 | net.size() 56 | try: 57 | trainer.fit( 58 | net, 59 | train_dataloaders=train_dataloader, 60 | val_dataloaders=val_dataloader, 61 | ) 62 | # todo: testing somehow disables grad computation even though 63 | # it is explicitly turned on 64 | # loss is calculated as "grad" for MQF2 65 | if not isinstance(net.loss, MQF2DistributionLoss): 66 | test_outputs = trainer.test(net, dataloaders=test_dataloader) 67 | assert len(test_outputs) > 0 68 | # check loading 69 | net = NHiTS.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) 70 | 71 | # check prediction 72 | net.predict( 73 | val_dataloader, 74 | fast_dev_run=True, 75 | return_index=True, 76 | return_decoder_lengths=True, 77 | trainer_kwargs=trainer_kwargs, 78 | ) 79 | finally: 80 | shutil.rmtree(tmp_path, ignore_errors=True) 81 | 82 | net.predict( 83 | val_dataloader, 84 | fast_dev_run=True, 85 | return_index=True, 86 | return_decoder_lengths=True, 87 | ) 88 | 89 | 90 | LOADERS = [ 91 | "with_covariates", 92 | "different_encoder_decoder_size", 93 | "fixed_window_without_covariates", 94 | "multi_target", 95 | "quantiles", 96 | "implicit-quantiles", 97 | ] 98 | 99 | if "cpflows" in _get_installed_packages(): 100 | LOADERS += ["multivariate-quantiles"] 101 | 102 | 
@pytest.mark.parametrize("dataloader", LOADERS)
def test_integration(
    dataloaders_with_covariates,
    dataloaders_with_different_encoder_decoder_length,
    dataloaders_fixed_window_without_covariates,
    dataloaders_multi_target,
    tmp_path,
    dataloader,
):
    """Run the NHiTS fit/test/reload/predict cycle for one scenario of ``LOADERS``.

    ``dataloader`` arrives as a scenario name and is replaced by the matching
    dataloader fixture; scenario-specific model/trainer kwargs are collected in
    ``kwargs`` and forwarded to ``_integration``.

    Raises:
        ValueError: if the scenario name is not handled below.
    """
    kwargs = {}
    if dataloader == "with_covariates":
        dataloader = dataloaders_with_covariates
        kwargs["backcast_loss_ratio"] = 0.5
    elif dataloader == "different_encoder_decoder_size":
        dataloader = dataloaders_with_different_encoder_decoder_length
    elif dataloader == "fixed_window_without_covariates":
        dataloader = dataloaders_fixed_window_without_covariates
    elif dataloader == "multi_target":
        dataloader = dataloaders_multi_target
        kwargs["loss"] = QuantileLoss()
    elif dataloader == "quantiles":
        dataloader = dataloaders_with_covariates
        kwargs["loss"] = QuantileLoss()
    elif dataloader == "implicit-quantiles":
        dataloader = dataloaders_with_covariates
        kwargs["loss"] = ImplicitQuantileNetworkDistributionLoss()
    elif dataloader == "multivariate-quantiles":
        dataloader = dataloaders_with_covariates
        kwargs["loss"] = MQF2DistributionLoss(
            prediction_length=dataloader["train"].dataset.max_prediction_length
        )
        kwargs["learning_rate"] = 1e-9
        kwargs["trainer_kwargs"] = dict(accelerator="cpu")
    else:
        raise ValueError(f"dataloader {dataloader} unknown")
    _integration(dataloader, tmp_path=tmp_path, **kwargs)


@pytest.fixture(scope="session")
def model(dataloaders_with_covariates):
    """Session-wide untrained NHiTS model built from the covariate training set."""
    dataset = dataloaders_with_covariates["train"].dataset
    net = NHiTS.from_dataset(
        dataset,
        learning_rate=0.15,
        hidden_size=8,
        log_gradient_flow=True,
        log_interval=1000,
        backcast_loss_ratio=1.0,
    )
    return net


def test_pickle(model):
    """The model must survive a pickle round trip."""
    pkl = pickle.dumps(model)
    # trusted, locally produced bytes -> S301 is a false positive here
    pickle.loads(pkl)  # noqa: S301


@pytest.mark.skipif(
    "matplotlib" not in _get_installed_packages(),
    reason="skip test if required package matplotlib not installed",
)
def test_interpretation(model, dataloaders_with_covariates):
    """Smoke-test prediction and interpretation plotting on raw outputs."""
    raw_predictions = model.predict(
        dataloaders_with_covariates["val"], mode="raw", return_x=True, fast_dev_run=True
    )
    model.plot_prediction(
        raw_predictions.x, raw_predictions.output, idx=0, add_loss_to_title=True
    )
    model.plot_interpretation(raw_predictions.x, raw_predictions.output, idx=0)


# Regression test for "Bug when max_prediction_length=1" (#1571)
@pytest.mark.parametrize("max_prediction_length", [1, 5])
def test_prediction_length(max_prediction_length: int, tmp_path):
    """Train and predict NHiTS for short and longer prediction horizons.

    Uses ``tmp_path`` as the trainer root so logs and checkpoints are not written
    into the current working directory, and a seeded generator so the synthetic
    data (and thus any failure) is reproducible.
    """
    n_timeseries = 10
    time_points = 10
    n_rows = time_points * n_timeseries
    rng = np.random.default_rng(1571)
    data = pd.DataFrame(
        data={
            "target": rng.random(n_rows),
            "time_varying_known_real_1": rng.random(n_rows),
            "time_idx": np.tile(np.arange(time_points), n_timeseries),
            "group_id": np.repeat(np.arange(n_timeseries), time_points),
        }
    )
    training_dataset = TimeSeriesDataSet(
        data=data,
        time_idx="time_idx",
        target="target",
        group_ids=["group_id"],
        time_varying_unknown_reals=["target"],
        time_varying_known_reals=["time_varying_known_real_1"],
        max_prediction_length=max_prediction_length,
        max_encoder_length=3,
    )
    training_data_loader = training_dataset.to_dataloader(train=True)
    forecaster = NHiTS.from_dataset(training_dataset, log_val_interval=1)
    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=3,
        min_epochs=2,
        limit_train_batches=10,
        default_root_dir=tmp_path,
    )
    trainer.fit(
        forecaster,
        train_dataloaders=training_data_loader,
    )
    validation_dataset = TimeSeriesDataSet.from_dataset(
        training_dataset, data, stop_randomization=True, predict=True
    )
    validation_data_loader = validation_dataset.to_dataloader(train=False)
    forecaster.predict(
        validation_data_loader,
        fast_dev_run=True,
        return_index=True,
        return_decoder_lengths=True,
    )


# --------------------------------------------------------------------------------
# /tests/test_models/test_nn/test_embeddings.py
# --------------------------------------------------------------------------------
import pytest
import torch

from pytorch_forecasting import MultiEmbedding


@pytest.mark.parametrize(
    "kwargs",
    [
        dict(embedding_sizes=(10, 10, 10)),
        dict(embedding_sizes=((10, 3), (10, 2), (10, 1))),
        dict(x_categoricals=["x1", "x2", "x3"], embedding_sizes=dict(x1=(10, 10))),
        dict(
            x_categoricals=["x1", "x2", "x3"],
            embedding_sizes=dict(x1=(10, 2), xg1=(10, 3)),
            categorical_groups=dict(xg1=["x2", "x3"]),
        ),
    ],
)
def test_MultiEmbedding(kwargs):
    """Check MultiEmbedding input/output sizes.

    A dict of ``embedding_sizes`` is expected to yield a dict of tensors, a
    tuple/list to yield a single tensor; any other output type fails the test.
    """
    x = torch.randint(0, 10, size=(4, 3))
    embedding = MultiEmbedding(**kwargs)
    assert embedding.input_size == x.size(
        1
    ), "Input size should be equal to number of features"
    out = embedding(x)
    if isinstance(out, dict):
        assert isinstance(kwargs["embedding_sizes"], dict)
        for name, o in out.items():
            assert (
                o.size(1) == embedding.output_size[name]
            ), "Output size should be equal to number of embedding dimensions"
    elif isinstance(out, torch.Tensor):
        assert isinstance(kwargs["embedding_sizes"], (tuple, list))
        assert (
            out.size(1) == embedding.output_size
        ), "Output size should be equal to number of summed embedding dimensions"
    else:
        raise ValueError(f"Unknown output type {type(out)}")


# --------------------------------------------------------------------------------
# /tests/test_models/test_nn/test_rnn.py
# --------------------------------------------------------------------------------
import itertools
import pytest
import torch
from torch import nn

from pytorch_forecasting.models.nn.rnn import GRU, LSTM, get_rnn


def test_get_lstm_cell():
    """``get_rnn("LSTM")`` returns the project LSTM, a subclass of ``nn.LSTM``."""
    cell = get_rnn("LSTM")(10, 10)
    assert isinstance(cell, LSTM)
    assert isinstance(cell, nn.LSTM)


def test_get_gru_cell():
    """``get_rnn("GRU")`` returns the project GRU, a subclass of ``nn.GRU``."""
    cell = get_rnn("GRU")(10, 10)
    assert isinstance(cell, GRU)
    assert isinstance(cell, nn.GRU)


def test_get_cell_raises_value_error():
    """An unknown cell type name is rejected with ``ValueError``."""
    with pytest.raises(ValueError):
        get_rnn("ABCDEF")


@pytest.mark.parametrize(
    "klass,rnn_kwargs",
    itertools.product(
        [LSTM, GRU],
        [
            dict(batch_first=True, num_layers=1),
            dict(batch_first=False, num_layers=2),
        ],
    ),
)
def test_zero_length_sequence(klass, rnn_kwargs):
    """Zero-length sequences give a zero final hidden state, others a non-zero one.

    ``x`` has shape ``(100, 3, 2)``; the batch axis is dim 0 when
    ``batch_first`` is set and dim 1 otherwise, so ``lengths`` is sized from
    that axis. Lengths are drawn from ``{0, 1, 2}``.
    """
    rnn = klass(input_size=2, hidden_size=5, **rnn_kwargs)
    x = torch.rand(100, 3, 2)
    batch_size = x.size(0) if rnn_kwargs["batch_first"] else x.size(1)
    lengths = torch.randint(0, 3, size=(batch_size,))
    _, hidden_state = rnn(x, lengths=lengths, enforce_sorted=False)
    init_hidden_state = rnn.init_hidden_state(x)

    # a single-tensor hidden state (e.g. GRU) is wrapped so that it can be
    # iterated like a multi-tensor state (e.g. LSTM's (h, c))
    if isinstance(hidden_state, torch.Tensor):
        hidden_state = [hidden_state]
        init_hidden_state = [init_hidden_state]

    for state, init_state in zip(hidden_state, init_hidden_state):
        assert state.size() == init_state.size(), "Hidden state sizes should be equal"
        assert (
            state[:, lengths == 0] == 0
        ).all(), "Hidden state should be zero for zero-length sequences"
        assert (
            state[:, lengths > 0] != 0
        ).all(), "Hidden state should be non-zero for non-empty sequences"


# --------------------------------------------------------------------------------
# /tests/test_models/test_rnn_model.py
# --------------------------------------------------------------------------------
import pickle
import shutil

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
import pytest
from test_models.conftest import make_dataloaders

from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.models import RecurrentNetwork


def _integration(
    data_with_covariates,
    tmp_path,
    cell_type="LSTM",
    data_loader_kwargs=None,
    clip_target: bool = False,
    **kwargs,
):
    """Fit, test, reload from checkpoint and predict with ``RecurrentNetwork``.

    Args:
        data_with_covariates: dataframe with a ``volume`` column; it is copied
            before a ``target`` column is added.
        tmp_path: directory for logs and checkpoints; removed after the run.
        cell_type: RNN cell type forwarded to ``RecurrentNetwork.from_dataset``.
        data_loader_kwargs: overrides for the default dataset kwargs passed to
            ``make_dataloaders``; ``None`` means no overrides.
        clip_target: if True, ``target`` is ``volume`` clipped to ``[1e-3, 1.0]``.
        **kwargs: forwarded to ``RecurrentNetwork.from_dataset``.
    """
    data_with_covariates = data_with_covariates.copy()
    if clip_target:
        data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0)
    else:
        data_with_covariates["target"] = data_with_covariates["volume"]
    data_loader_default_kwargs = dict(
        target="target",
        time_varying_known_reals=["price_actual"],
        time_varying_unknown_reals=["target"],
        static_categoricals=["agency"],
        add_relative_time_idx=True,
    )
    # caller-supplied entries take precedence over the defaults above
    data_loader_default_kwargs.update(data_loader_kwargs or {})
    dataloaders_with_covariates = make_dataloaders(
        data_with_covariates, **data_loader_default_kwargs
    )
    train_dataloader = dataloaders_with_covariates["train"]
    val_dataloader = dataloaders_with_covariates["val"]
    test_dataloader = dataloaders_with_covariates["test"]

    early_stop_callback = EarlyStopping(
        monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
    )

    logger = TensorBoardLogger(tmp_path)
    trainer = pl.Trainer(
        max_epochs=3,
        gradient_clip_val=0.1,
        callbacks=[early_stop_callback],
        enable_checkpointing=True,
        default_root_dir=tmp_path,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        logger=logger,
    )

    net = RecurrentNetwork.from_dataset(
        train_dataloader.dataset,
        cell_type=cell_type,
        learning_rate=0.15,
        log_gradient_flow=True,
        log_interval=1000,
        hidden_size=5,
        **kwargs,
    )
    net.size()
    try:
        trainer.fit(
            net,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )
        test_outputs = trainer.test(net, dataloaders=test_dataloader)
        assert len(test_outputs) > 0
        # check loading
        net = RecurrentNetwork.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )

        # check prediction
        net.predict(
            val_dataloader,
            fast_dev_run=True,
            return_index=True,
            return_decoder_lengths=True,
        )
    finally:
        shutil.rmtree(tmp_path, ignore_errors=True)

    # prediction must still work once the checkpoint directory is gone
    net.predict(
        val_dataloader,
        fast_dev_run=True,
        return_index=True,
        return_decoder_lengths=True,
    )


@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"cell_type": "GRU"},
        dict(
            data_loader_kwargs=dict(
                target_normalizer=GroupNormalizer(
                    groups=["agency", "sku"], center=False
                )
            ),
        ),
        dict(
            data_loader_kwargs=dict(
                lags={"volume": [2, 5]},
                target="volume",
                time_varying_unknown_reals=["volume"],
                min_encoder_length=2,
            )
        ),
        dict(
            data_loader_kwargs=dict(
                time_varying_unknown_reals=["volume", "discount"],
                target=["volume", "discount"],
                lags={"volume": [2], "discount": [2]},
            )
        ),
    ],
)
def test_integration(data_with_covariates, tmp_path, kwargs):
    """Run the RecurrentNetwork integration cycle for each parameter set."""
    _integration(data_with_covariates, tmp_path, **kwargs)


@pytest.fixture(scope="session")
def model(dataloaders_with_covariates):
    """Session-wide untrained RecurrentNetwork built from the covariate training set."""
    dataset = dataloaders_with_covariates["train"].dataset
    net = RecurrentNetwork.from_dataset(
        dataset,
        learning_rate=0.15,
        log_gradient_flow=True,
        log_interval=1000,
        hidden_size=5,
    )
    return net


def test_pickle(model):
    """The model must survive a pickle round trip."""
    pkl = pickle.dumps(model)
    pickle.loads(pkl)  # noqa: S301


# --------------------------------------------------------------------------------
# /tests/test_models/test_tide.py
# --------------------------------------------------------------------------------
import pickle
import shutil

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
import numpy as np
import pandas as pd
import pytest

from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
from pytorch_forecasting.metrics import SMAPE
from pytorch_forecasting.models import TiDEModel
from pytorch_forecasting.tests._conftest import make_dataloaders
from pytorch_forecasting.utils._dependencies import _get_installed_packages


def _integration(
    estimator_cls,
    data_with_covariates,
    tmp_path,
    data_loader_kwargs=None,
    clip_target: bool = False,
    trainer_kwargs=None,
    **kwargs,
):
    """Fit, test, reload from checkpoint and predict with ``estimator_cls``.

    Args:
        estimator_cls: model class providing ``from_dataset`` and
            ``load_from_checkpoint``.
        data_with_covariates: dataframe with a ``volume`` column; it is copied
            before a ``target`` column is added.
        tmp_path: directory for logs and checkpoints; removed after the run.
        data_loader_kwargs: overrides for the default dataset kwargs passed to
            ``make_dataloaders``; ``None`` means no overrides.
        clip_target: if True, ``target`` is ``volume`` clipped to ``[1e-3, 1.0]``.
        trainer_kwargs: extra kwargs for ``pl.Trainer``, also forwarded to
            ``predict``; ``None`` means no extras.
        **kwargs: forwarded to ``estimator_cls.from_dataset``.

    Returns:
        None.
    """
    data_with_covariates = data_with_covariates.copy()
    if clip_target:
        data_with_covariates["target"] = data_with_covariates["volume"].clip(1e-3, 1.0)
    else:
        data_with_covariates["target"] = data_with_covariates["volume"]
    data_loader_default_kwargs = dict(
        target="target",
        time_varying_known_reals=["price_actual"],
        time_varying_unknown_reals=["target"],
        static_categoricals=["agency"],
        add_relative_time_idx=True,
    )
    # caller-supplied entries take precedence over the defaults above
    data_loader_default_kwargs.update(data_loader_kwargs or {})
    dataloaders_with_covariates = make_dataloaders(
        data_with_covariates, **data_loader_default_kwargs
    )

    train_dataloader = dataloaders_with_covariates["train"]
    val_dataloader = dataloaders_with_covariates["val"]
    test_dataloader = dataloaders_with_covariates["test"]

    early_stop_callback = EarlyStopping(
        monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
    )

    logger = TensorBoardLogger(tmp_path)
    if trainer_kwargs is None:
        trainer_kwargs = {}
    trainer = pl.Trainer(
        max_epochs=3,
        gradient_clip_val=0.1,
        callbacks=[early_stop_callback],
        enable_checkpointing=True,
        default_root_dir=tmp_path,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        logger=logger,
        **trainer_kwargs,
    )

    net = estimator_cls.from_dataset(
        train_dataloader.dataset,
        hidden_size=5,
        learning_rate=0.01,
        log_gradient_flow=True,
        log_interval=1000,
        **kwargs,
    )
    net.size()
    try:
        trainer.fit(
            net,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )
        test_outputs = trainer.test(net, dataloaders=test_dataloader)
        assert len(test_outputs) > 0
        # check loading
        net = estimator_cls.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )

        # check prediction
        net.predict(
            val_dataloader,
            fast_dev_run=True,
            return_index=True,
            return_decoder_lengths=True,
            trainer_kwargs=trainer_kwargs,
        )
    finally:
        shutil.rmtree(tmp_path, ignore_errors=True)

    # prediction must still work once the checkpoint directory is gone
    net.predict(
        val_dataloader,
        fast_dev_run=True,
        return_index=True,
        return_decoder_lengths=True,
        trainer_kwargs=trainer_kwargs,
    )


def _tide_integration(dataloaders, tmp_path, trainer_kwargs=None, **kwargs):
    """TiDE-specific wrapper around the common ``_integration`` test function.

    The dataset configuration (target, group ids, covariates, relative time
    index flag) is copied from ``dataloaders["train"].dataset``. The data itself
    is rebuilt from the ``data_with_covariates`` scenario.

    Args:
        dataloaders: dictionary of dataloaders; only ``"train"`` is read.
        tmp_path: temporary path for logs and checkpoints.
        trainer_kwargs: additional arguments for the Trainer.
        **kwargs: TiDEModel arguments; they override the small defaults below.

    Returns:
        None. The check passes if training, testing, reloading and
        prediction complete without error.
    """
    from pytorch_forecasting.tests._data_scenarios import data_with_covariates

    df = data_with_covariates()

    # small defaults keep the model fast to train in tests
    tide_kwargs = {
        "temporal_decoder_hidden": 8,
        "temporal_width_future": 4,
        "dropout": 0.1,
    }
    tide_kwargs.update(kwargs)

    train_dataset = dataloaders["train"].dataset
    data_loader_kwargs = {
        "target": train_dataset.target,
        "group_ids": train_dataset.group_ids,
        "time_varying_known_reals": train_dataset.time_varying_known_reals,
        "time_varying_unknown_reals": train_dataset.time_varying_unknown_reals,
        "static_categoricals": train_dataset.static_categoricals,
        "static_reals": train_dataset.static_reals,
        "add_relative_time_idx": train_dataset.add_relative_time_idx,
    }
    _integration(
        TiDEModel,
        df,
        tmp_path,
        data_loader_kwargs=data_loader_kwargs,
        trainer_kwargs=trainer_kwargs,
        **tide_kwargs,
    )


@pytest.mark.parametrize(
    "kwargs",
    [
        {},
        {"loss": SMAPE()},
        {"temporal_decoder_hidden": 16},
        {"dropout": 0.2, "use_layer_norm": True},
    ],
)
def test_integration(dataloaders_with_covariates, tmp_path, kwargs):
    """Run the TiDE integration cycle on the single-target covariate setup."""
    _tide_integration(dataloaders_with_covariates, tmp_path, **kwargs)


@pytest.mark.parametrize(
    "kwargs",
    [
        {},
    ],
)
def test_multi_target_integration(dataloaders_multi_target, tmp_path, kwargs):
    """Run the TiDE integration cycle on the multi-target setup."""
    _tide_integration(dataloaders_multi_target, tmp_path, **kwargs)


@pytest.fixture
def model(dataloaders_with_covariates):
    """Untrained TiDEModel built from the covariate training set."""
    dataset = dataloaders_with_covariates["train"].dataset
    net = TiDEModel.from_dataset(
        dataset,
        hidden_size=16,
        dropout=0.1,
        temporal_width_future=4,
    )
    return net


def test_pickle(model):
    """The model must survive a pickle round trip."""
    pkl = pickle.dumps(model)
    pickle.loads(pkl)  # noqa: S301
@pytest.mark.skipif(
    "matplotlib" not in _get_installed_packages(),
    reason="skip test if required package matplotlib not installed",
)
def test_prediction_visualization(model, dataloaders_with_covariates):
    """Smoke-test ``plot_prediction`` on raw TiDE outputs."""
    raw_predictions = model.predict(
        dataloaders_with_covariates["val"],
        mode="raw",
        return_x=True,
        fast_dev_run=True,
    )
    model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0)


def test_prediction_with_kwargs(model, dataloaders_with_covariates):
    """Prediction works with different combinations of ``return_*`` flags."""
    model.predict(
        dataloaders_with_covariates["val"], return_index=True, fast_dev_run=True
    )
    model.predict(
        dataloaders_with_covariates["val"],
        return_x=True,
        return_y=True,
        fast_dev_run=True,
    )


def test_no_exogenous_variable(tmp_path):
    """TiDE trains, reloads and predicts with no known/unknown covariates besides the target.

    The data is 16 constant series of 100 steps each. ``tmp_path`` is used as
    the trainer root so logs and checkpoints are not written into the current
    working directory.
    """
    data = pd.DataFrame(
        {
            "target": np.ones(1600),
            "group_id": np.repeat(np.arange(16), 100),
            "time_idx": np.tile(np.arange(100), 16),
        }
    )
    training_dataset = TimeSeriesDataSet(
        data=data,
        time_idx="time_idx",
        target="target",
        group_ids=["group_id"],
        max_encoder_length=10,
        max_prediction_length=5,
        time_varying_unknown_reals=["target"],
        time_varying_known_reals=[],
    )
    validation_dataset = TimeSeriesDataSet.from_dataset(
        training_dataset, data, stop_randomization=True, predict=True
    )
    training_data_loader = training_dataset.to_dataloader(
        train=True, batch_size=8, num_workers=0
    )
    validation_data_loader = validation_dataset.to_dataloader(
        train=False, batch_size=8, num_workers=0
    )
    forecaster = TiDEModel.from_dataset(training_dataset)

    trainer = pl.Trainer(
        max_epochs=2,
        limit_train_batches=8,
        limit_val_batches=8,
        default_root_dir=tmp_path,
    )
    trainer.fit(
        forecaster,
        train_dataloaders=training_data_loader,
        val_dataloaders=validation_data_loader,
    )
    best_model_path = trainer.checkpoint_callback.best_model_path
    best_model = TiDEModel.load_from_checkpoint(best_model_path)
    best_model.predict(
        validation_data_loader,
        fast_dev_run=True,
        return_x=True,
        return_y=True,
        return_index=True,
    )


# --------------------------------------------------------------------------------
# /tests/test_utils/test_autocorrelation.py
# --------------------------------------------------------------------------------
import math

import torch

from pytorch_forecasting.utils import autocorrelation


def test_autocorrelation():
    """Autocorrelation of two sine periods sampled at 201 points.

    Lag 0 is 1. Lag 101 is roughly one full period, so it is near 1. Lag 50
    is roughly half a period, so it is near -1.
    """
    x = torch.sin(torch.linspace(0, 2 * 2 * math.pi, 201))
    corr = autocorrelation(x, dim=-1)
    assert corr[0] == 1, "Autocorrelation of first element should be 1."
    assert corr[101] > 0.99, "Autocorrelation should be near 1 for sin(2*pi)"
    assert corr[50] < -0.99, "Autocorrelation should be near -1 for sin(pi)"


# --------------------------------------------------------------------------------
# /tests/test_utils/test_safe_import.py
# --------------------------------------------------------------------------------
from pytorch_forecasting.utils._dependencies import _safe_import


def test_present_module():
    """Test importing a dependency that is installed."""
    module = _safe_import("torch")
    assert module is not None


def test_import_missing_module():
    """Test importing a dependency that is not installed."""
    result = _safe_import("nonexistent_module")
    assert hasattr(result, "__name__")
    assert result.__name__ == "nonexistent_module"


def test_import_without_pkg_name():
    """Test importing a dependency with the same name as package name."""
    result = _safe_import("torch", pkg_name="torch")
    assert result is not None
def test_import_with_different_pkg_name_1():
    """Test importing a dependency with a different package name."""
    result = _safe_import("skbase", pkg_name="scikit-base")
    assert result is not None


def test_import_with_different_pkg_name_2():
    """Test importing another dependency with a different package name."""
    # NOTE(review): if opencv-python is not installed, _safe_import returns a
    # placeholder (see test_import_missing_module), so this assertion holds either way.
    result = _safe_import("cv2", pkg_name="opencv-python")
    assert result is not None


def test_import_submodule():
    """Test importing a submodule."""
    result = _safe_import("torch.nn")
    assert result is not None


def test_import_class():
    """Test importing a class."""
    result = _safe_import("torch.nn.Linear")
    assert result is not None


def test_import_existing_object():
    """Test importing an existing object.

    The result must be the real object, not a placeholder.
    """
    result = _safe_import("pandas.DataFrame")
    assert result is not None
    assert result.__name__ == "DataFrame"
    from pandas import DataFrame

    assert result is DataFrame


def test_multiple_inheritance_from_mock():
    """Test multiple inheritance from dynamic MagicMock."""
    Class1 = _safe_import("foobar.foo.FooBar")
    Class2 = _safe_import("barfoobar.BarFooBar")

    class NewClass(Class1, Class2):
        """This should not trigger an error.

        The class definition would trigger an error if multiple inheritance
        from Class1 and Class2 does not work, e.g., if it is simply
        identical to MagicMock.
        """

        pass


# --------------------------------------------------------------------------------
# /tests/test_utils/test_show_versions.py:
# --------------------------------------------------------------------------------
"""Tests for the show_versions utility."""

import pathlib
import uuid

from pytorch_forecasting.utils._maint._show_versions import (
    DEFAULT_DEPS_TO_SHOW,
    _get_deps_info,
    show_versions,
)


def test_show_versions_runs():
    """Test that show_versions runs without exceptions."""
    # only prints, should return None
    assert show_versions() is None


def test_show_versions_import_loc():
    """Test that show_versions can be imported from the package root."""
    from pytorch_forecasting import show_versions as show_versions_imported

    # identity, not just equality: the root export must be the same function object
    assert show_versions is show_versions_imported


def test_deps_info():
    """Test that _get_deps_info returns package/version dict as per contract.

    Without arguments only ``pytorch-forecasting`` is reported. With
    ``DEFAULT_DEPS_TO_SHOW``, exactly those packages are reported.
    """
    deps_info = _get_deps_info()
    assert isinstance(deps_info, dict)
    assert set(deps_info.keys()) == {"pytorch-forecasting"}

    deps_info_default = _get_deps_info(DEFAULT_DEPS_TO_SHOW)
    assert isinstance(deps_info_default, dict)
    assert set(deps_info_default.keys()) == set(DEFAULT_DEPS_TO_SHOW)


def test_deps_info_deps_missing_package_present_directory():
    """Test that _get_deps_info does not fail if a dependency is missing.

    A directory named after the (non-existent) package is created in the
    current working directory. Presumably this mimics a bare folder being
    picked up by the import machinery. The missing package must still be
    reported with a ``None`` version.

    The directory is removed in a ``finally`` clause, so a failing assertion
    does not leave it behind.
    """
    dummy_package_name = uuid.uuid4().hex

    dummy_folder_path = pathlib.Path(dummy_package_name)
    dummy_folder_path.mkdir()
    try:
        assert _get_deps_info([dummy_package_name]) == {dummy_package_name: None}
    finally:
        # clean up even if _get_deps_info raises or the assertion fails
        dummy_folder_path.rmdir()


# --------------------------------------------------------------------------------